]> Sergey Matveev's repositories - btrtrc.git/commitdiff
webseed: Close unused part responses after error
authorMatt Joiner <anacrolix@gmail.com>
Fri, 12 Nov 2021 01:40:15 +0000 (12:40 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Fri, 12 Nov 2021 01:40:15 +0000 (12:40 +1100)
Also don't bother to read their response bodies.

webseed/client.go

index 3a03fb1b05956ff81b8f3e09488411a9615e5d76..6acf90d88098f54c3055e2dd93b250f418dded41 100644 (file)
@@ -94,7 +94,7 @@ func (ws *Client) NewRequest(r RequestSpec) Request {
                Result: make(chan RequestResult, 1),
        }
        go func() {
-               b, err := readRequestPartResponses(requestParts)
+               b, err := readRequestPartResponses(ctx, requestParts)
                req.Result <- RequestResult{
                        Bytes: b,
                        Err:   err,
@@ -112,12 +112,15 @@ func (me ErrBadResponse) Error() string {
        return me.Msg
 }
 
-func recvPartResult(buf io.Writer, part requestPart) error {
+func recvPartResult(ctx context.Context, buf io.Writer, part requestPart) error {
        result := <-part.result
        if result.err != nil {
                return result.err
        }
        defer result.resp.Body.Close()
+       if ctx.Err() != nil {
+               return ctx.Err()
+       }
        switch result.resp.StatusCode {
        case http.StatusPartialContent:
        case http.StatusOK:
@@ -140,13 +143,29 @@ func recvPartResult(buf io.Writer, part requestPart) error {
        return nil
 }
 
-func readRequestPartResponses(parts []requestPart) ([]byte, error) {
+func readRequestPartResponses(ctx context.Context, parts []requestPart) ([]byte, error) {
+       ctx, cancel := context.WithCancel(ctx)
+       defer cancel()
        var buf bytes.Buffer
-       for _, part := range parts {
-               err := recvPartResult(&buf, part)
-               if err != nil {
-                       return buf.Bytes(), fmt.Errorf("reading %q at %q: %w", part.req.URL, part.req.Header.Get("Range"), err)
+       firstErr := make(chan error, 1)
+       go func() {
+               for _, part := range parts {
+                       err := recvPartResult(ctx, &buf, part)
+                       if err != nil {
+                               // Ensure no further unnecessary response reads occur.
+                               cancel()
+                               select {
+                               case firstErr <- fmt.Errorf("reading %q at %q: %w", part.req.URL, part.req.Header.Get("Range"), err):
+                               default:
+                               }
+                       }
                }
-       }
-       return buf.Bytes(), nil
+               select {
+               case firstErr <- nil:
+               default:
+               }
+       }()
+       // This can't be merged into the return statement, because buf.Bytes is called first!
+       err := <-firstErr
+       return buf.Bytes(), err
 }