]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Add webseed response content length checks
authorMatt Joiner <anacrolix@gmail.com>
Wed, 21 May 2025 02:42:26 +0000 (12:42 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Wed, 21 May 2025 02:42:26 +0000 (12:42 +1000)
webseed/client.go

index 36f2f569cc860c24c6cb5fb95042eef266fa110a..ca59b6115358d7ed649418de4b955c08f66af76a 100644 (file)
@@ -6,10 +6,12 @@ import (
        "fmt"
        "io"
        "log"
+       "log/slog"
        "net/http"
        "strings"
 
        "github.com/RoaringBitmap/roaring"
+
        "github.com/anacrolix/missinggo/v2/panicif"
 
        "github.com/anacrolix/torrent/metainfo"
@@ -38,6 +40,7 @@ func (r Request) Cancel() {
 }
 
 type Client struct {
+       Logger     *slog.Logger
        HttpClient *http.Client
        Url        string
        // Max concurrent requests to a WebSeed for a given torrent.
@@ -104,7 +107,7 @@ func (ws *Client) StartNewRequest(r RequestSpec) Request {
                Body:   body,
        }
        go func() {
-               err := readRequestPartResponses(ctx, w, requestParts)
+               err := ws.readRequestPartResponses(ctx, w, requestParts)
                panicif.Err(w.CloseWithError(err))
        }()
        return req
@@ -119,8 +122,26 @@ func (me ErrBadResponse) Error() string {
        return me.Msg
 }
 
+// Warn about bad content-lengths.
+func (me *Client) checkContentLength(resp *http.Response, part requestPart, expectedLen int64) {
+       if resp.ContentLength == -1 {
+               return
+       }
+       switch resp.Header.Get("Content-Encoding") {
+       case "identity", "":
+       default:
+               return
+       }
+       if resp.ContentLength != expectedLen {
+               me.Logger.Warn("unexpected identity response Content-Length value",
+                       "actual", resp.ContentLength,
+                       "expected", expectedLen,
+                       "url", part.req.URL)
+       }
+}
+
 // Reads the part in full. All expected bytes must be returned or there will an error returned.
-func recvPartResult(ctx context.Context, w io.Writer, part requestPart, resp *http.Response) error {
+func (me *Client) recvPartResult(ctx context.Context, w io.Writer, part requestPart, resp *http.Response) error {
        defer resp.Body.Close()
        var body io.Reader = resp.Body
        if part.responseBodyWrapper != nil {
@@ -133,6 +154,8 @@ func recvPartResult(ctx context.Context, w io.Writer, part requestPart, resp *ht
        }
        switch resp.StatusCode {
        case http.StatusPartialContent:
+               // The response should be just as long as we requested.
+               me.checkContentLength(resp, part, part.e.Length)
                copied, err := io.Copy(w, body)
                if err != nil {
                        return err
@@ -142,6 +165,8 @@ func recvPartResult(ctx context.Context, w io.Writer, part requestPart, resp *ht
                }
                return nil
        case http.StatusOK:
+               // The response is from the beginning.
+               me.checkContentLength(resp, part, part.e.End())
                // This number is based on
                // https://archive.org/download/BloodyPitOfHorror/BloodyPitOfHorror.asr.srt. It seems that
                // archive.org might be using a webserver implementation that refuses to do partial
@@ -168,6 +193,7 @@ func recvPartResult(ctx context.Context, w io.Writer, part requestPart, resp *ht
                _, err = io.CopyN(w, body, part.e.Length)
                return err
        case http.StatusServiceUnavailable:
+               // TODO: Include all of Erigon's cases here?
                return ErrTooFast
        default:
                // TODO: Could we have a slog.Valuer or something to allow callers to unpack reasonable values?
@@ -180,12 +206,12 @@ func recvPartResult(ctx context.Context, w io.Writer, part requestPart, resp *ht
 
 var ErrTooFast = errors.New("making requests too fast")
 
-func readRequestPartResponses(ctx context.Context, w io.Writer, parts []requestPart) (err error) {
+func (me *Client) readRequestPartResponses(ctx context.Context, w io.Writer, parts []requestPart) (err error) {
        for _, part := range parts {
                var resp *http.Response
                resp, err = part.do()
                if err == nil {
-                       err = recvPartResult(ctx, w, part, resp)
+                       err = me.recvPartResult(ctx, w, part, resp)
                }
                if err != nil {
                        err = fmt.Errorf("reading %q at %q: %w", part.req.URL, part.req.Header.Get("Range"), err)