]> Sergey Matveev's repositories - btrtrc.git/blobdiff - tracker/udp/client.go
Quote UDP tracker response error bodies with %#q
[btrtrc.git] / tracker / udp / client.go
index d66348e16928d4bd48a2499cf44340093461154c..0095a912aaac378995dd359f4bfa57d6d53e5380 100644 (file)
@@ -4,13 +4,19 @@ import (
        "bytes"
        "context"
        "encoding/binary"
-       "errors"
        "fmt"
        "io"
+       "net"
+       "sync"
        "time"
+
+       "github.com/anacrolix/dht/v2/krpc"
 )
 
+// Client interacts with UDP trackers via its Writer and Dispatcher. It has no knowledge of
+// connection specifics.
 type Client struct {
+       mu           sync.Mutex
        connId       ConnectionId
        connIdIssued time.Time
        Dispatcher   *Dispatcher
@@ -18,15 +24,16 @@ type Client struct {
 }
 
 func (cl *Client) Announce(
-       ctx context.Context, req AnnounceRequest, peers AnnounceResponsePeers, opts Options,
+       ctx context.Context, req AnnounceRequest, opts Options,
+       // Decides whether the response body is IPv6 or IPv4, see BEP 15.
+       ipv6 func(net.Addr) bool,
 ) (
-       respHdr AnnounceResponseHeader, err error,
+       respHdr AnnounceResponseHeader,
+       // A slice of krpc.NodeAddr, likely wrapped in an appropriate unmarshalling wrapper.
+       peers AnnounceResponsePeers,
+       err error,
 ) {
-       body, err := marshal(req)
-       if err != nil {
-               return
-       }
-       respBody, err := cl.request(ctx, ActionAnnounce, append(body, opts.Encode()...))
+       respBody, addr, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
        if err != nil {
                return
        }
@@ -36,6 +43,11 @@ func (cl *Client) Announce(
                err = fmt.Errorf("reading response header: %w", err)
                return
        }
+       if ipv6(addr) {
+               peers = &krpc.CompactIPv6NodeAddrs{}
+       } else {
+               peers = &krpc.CompactIPv4NodeAddrs{}
+       }
        err = peers.UnmarshalBinary(r.Bytes())
        if err != nil {
                err = fmt.Errorf("reading response peers: %w", err)
@@ -43,11 +55,41 @@ func (cl *Client) Announce(
        return
 }
 
+// There's no way to pass options in a scrape, since we don't when the request body ends.
+func (cl *Client) Scrape(
+       ctx context.Context, ihs []InfoHash,
+) (
+       out ScrapeResponse, err error,
+) {
+       respBody, _, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
+       if err != nil {
+               return
+       }
+       r := bytes.NewBuffer(respBody)
+       for r.Len() != 0 {
+               var item ScrapeInfohashResult
+               err = Read(r, &item)
+               if err != nil {
+                       return
+               }
+               out = append(out, item)
+       }
+       if len(out) > len(ihs) {
+               err = fmt.Errorf("got %v results but expected %v", len(out), len(ihs))
+               return
+       }
+       return
+}
+
 func (cl *Client) connect(ctx context.Context) (err error) {
-       if time.Since(cl.connIdIssued) < time.Minute {
+       // We could get fancier here and use RWMutex, and even fire off the connection asynchronously
+       // and provide a grace period while it resolves.
+       cl.mu.Lock()
+       defer cl.mu.Unlock()
+       if !cl.connIdIssued.IsZero() && time.Since(cl.connIdIssued) < time.Minute {
                return nil
        }
-       respBody, err := cl.request(ctx, ActionConnect, nil)
+       respBody, _, err := cl.request(ctx, ActionConnect, nil)
        if err != nil {
                return err
        }
@@ -104,7 +146,7 @@ func (cl *Client) requestWriter(ctx context.Context, action Action, body []byte,
        }
 }
 
-func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, err error) {
+func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, addr net.Addr, err error) {
        respChan := make(chan DispatchedResponse, 1)
        t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) {
                respChan <- dr
@@ -120,8 +162,11 @@ func (cl *Client) request(ctx context.Context, action Action, body []byte) (resp
        case dr := <-respChan:
                if dr.Header.Action == action {
                        respBody = dr.Body
+                       addr = dr.Addr
                } else if dr.Header.Action == ActionError {
-                       err = errors.New(string(dr.Body))
+                       // I've seen "Connection ID mismatch.^@" in less and other tools, I think they're just
+                       // not handling a trailing \x00 nicely.
+                       err = fmt.Errorf("error response: %#q", dr.Body)
                } else {
                        err = fmt.Errorf("unexpected response action %v", dr.Header.Action)
                }