From: Mark Holt Date: Sun, 21 Apr 2024 10:49:49 +0000 (+0100) Subject: close body in same go routine as request X-Git-Tag: v1.56.0~18^2 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=77dd51f1ed8a01606ab65c5abe397baa7745237f;p=btrtrc.git close body in same go routine as request --- diff --git a/webseed/client.go b/webseed/client.go index 4614a3e4..5e3c28b8 100644 --- a/webseed/client.go +++ b/webseed/client.go @@ -19,16 +19,10 @@ import ( type RequestSpec = segments.Extent -type requestPartResult struct { - resp *http.Response - err error -} - type requestPart struct { - req *http.Request - e segments.Extent - result chan requestPartResult - start func() + req *http.Request + e segments.Extent + do func() (*http.Response, error) // Wrap http response bodies for such things as download rate limiting. responseBodyWrapper ResponseBodyWrapper } @@ -88,18 +82,11 @@ func (ws *Client) NewRequest(r RequestSpec) Request { } part := requestPart{ req: req, - result: make(chan requestPartResult, 1), e: e, responseBodyWrapper: ws.ResponseBodyWrapper, } - part.start = func() { - go func() { - resp, err := ws.HttpClient.Do(req) - part.result <- requestPartResult{ - resp: resp, - err: err, - } - }() + part.do = func() (*http.Response, error) { + return ws.HttpClient.Do(req) } requestParts = append(requestParts, part) return true @@ -129,24 +116,18 @@ func (me ErrBadResponse) Error() string { return me.Msg } -func recvPartResult(ctx context.Context, buf io.Writer, part requestPart) error { - result := <-part.result - // Make sure there's no further results coming, it should be a one-shot channel. - close(part.result) - if result.err != nil { - return result.err - } - defer result.resp.Body.Close() - var body io.Reader = result.resp.Body +func recvPartResult(ctx context.Context, buf io.Writer, part requestPart, resp *http.Response) error { + defer resp.Body.Close() + var body io.Reader = resp.Body if part.responseBodyWrapper != nil { body = part.responseBodyWrapper(body) } // Prevent further accidental use - result.resp.Body = nil + resp.Body = nil if ctx.Err() != nil { return ctx.Err() } - switch result.resp.StatusCode { + switch resp.StatusCode { case http.StatusPartialContent: copied, err := io.Copy(buf, body) if err != nil { @@ -178,14 +159,14 @@ func recvPartResult(ctx context.Context, buf io.Writer, part requestPart) error _, err := io.CopyN(buf, body, part.e.Length) return err } else { - return ErrBadResponse{"resp status ok but requested range", result.resp} + return ErrBadResponse{"resp status ok but requested range", resp} } case http.StatusServiceUnavailable: return ErrTooFast default: return ErrBadResponse{ - fmt.Sprintf("unhandled response status code (%v)", result.resp.StatusCode), - result.resp, + fmt.Sprintf("unhandled response status code (%v)", resp.StatusCode), + resp, } } } @@ -195,12 +176,16 @@ var ErrTooFast = errors.New("making requests too fast") func readRequestPartResponses(ctx context.Context, parts []requestPart) (_ []byte, err error) { var buf bytes.Buffer for _, part := range parts { - part.start() - err = recvPartResult(ctx, &buf, part) + result, err := part.do() + + if err == nil { + err = recvPartResult(ctx, &buf, part, result) + } + if err != nil { err = fmt.Errorf("reading %q at %q: %w", part.req.URL, part.req.Header.Get("Range"), err) break } } return buf.Bytes(), err -} +} \ No newline at end of file