]> Sergey Matveev's repositories - btrtrc.git/commitdiff
close body in same go routine as request
authorMark Holt <mark@distributed.vision>
Sun, 21 Apr 2024 10:49:49 +0000 (11:49 +0100)
committerMark Holt <mark@distributed.vision>
Sun, 21 Apr 2024 10:49:49 +0000 (11:49 +0100)
webseed/client.go

index 4614a3e407a9bc1adb5c55e2a1cd902eccbd7e64..5e3c28b88522c22068e91e10234d552122f62cb9 100644 (file)
@@ -19,16 +19,10 @@ import (
 
 type RequestSpec = segments.Extent
 
-type requestPartResult struct {
-       resp *http.Response
-       err  error
-}
-
 type requestPart struct {
-       req    *http.Request
-       e      segments.Extent
-       result chan requestPartResult
-       start  func()
+       req *http.Request
+       e   segments.Extent
+       do  func() (*http.Response, error)
        // Wrap http response bodies for such things as download rate limiting.
        responseBodyWrapper ResponseBodyWrapper
 }
@@ -88,18 +82,11 @@ func (ws *Client) NewRequest(r RequestSpec) Request {
                }
                part := requestPart{
                        req:                 req,
-                       result:              make(chan requestPartResult, 1),
                        e:                   e,
                        responseBodyWrapper: ws.ResponseBodyWrapper,
                }
-               part.start = func() {
-                       go func() {
-                               resp, err := ws.HttpClient.Do(req)
-                               part.result <- requestPartResult{
-                                       resp: resp,
-                                       err:  err,
-                               }
-                       }()
+               part.do = func() (*http.Response, error) {
+                       return ws.HttpClient.Do(req)
                }
                requestParts = append(requestParts, part)
                return true
@@ -129,24 +116,18 @@ func (me ErrBadResponse) Error() string {
        return me.Msg
 }
 
-func recvPartResult(ctx context.Context, buf io.Writer, part requestPart) error {
-       result := <-part.result
-       // Make sure there's no further results coming, it should be a one-shot channel.
-       close(part.result)
-       if result.err != nil {
-               return result.err
-       }
-       defer result.resp.Body.Close()
-       var body io.Reader = result.resp.Body
+func recvPartResult(ctx context.Context, buf io.Writer, part requestPart, resp *http.Response) error {
+       defer resp.Body.Close()
+       var body io.Reader = resp.Body
        if part.responseBodyWrapper != nil {
                body = part.responseBodyWrapper(body)
        }
        // Prevent further accidental use
-       result.resp.Body = nil
+       resp.Body = nil
        if ctx.Err() != nil {
                return ctx.Err()
        }
-       switch result.resp.StatusCode {
+       switch resp.StatusCode {
        case http.StatusPartialContent:
                copied, err := io.Copy(buf, body)
                if err != nil {
@@ -178,14 +159,14 @@ func recvPartResult(ctx context.Context, buf io.Writer, part requestPart) error
                        _, err := io.CopyN(buf, body, part.e.Length)
                        return err
                } else {
-                       return ErrBadResponse{"resp status ok but requested range", result.resp}
+                       return ErrBadResponse{"resp status ok but requested range", resp}
                }
        case http.StatusServiceUnavailable:
                return ErrTooFast
        default:
                return ErrBadResponse{
-                       fmt.Sprintf("unhandled response status code (%v)", result.resp.StatusCode),
-                       result.resp,
+                       fmt.Sprintf("unhandled response status code (%v)", resp.StatusCode),
+                       resp,
                }
        }
 }
@@ -195,12 +176,16 @@ var ErrTooFast = errors.New("making requests too fast")
 func readRequestPartResponses(ctx context.Context, parts []requestPart) (_ []byte, err error) {
        var buf bytes.Buffer
        for _, part := range parts {
-               part.start()
-               err = recvPartResult(ctx, &buf, part)
+               result, err := part.do()
+
+               if err == nil {
+                       err = recvPartResult(ctx, &buf, part, result)
+               }
+
                if err != nil {
                        err = fmt.Errorf("reading %q at %q: %w", part.req.URL, part.req.Header.Get("Range"), err)
                        break
                }
        }
        return buf.Bytes(), err
-}
+}
\ No newline at end of file