"bytes"
"context"
"encoding/binary"
- "errors"
"fmt"
"io"
+ "net"
+ "sync"
"time"
+
+ "github.com/anacrolix/dht/v2/krpc"
)
+// Client interacts with UDP trackers via its Writer and Dispatcher. It has no knowledge of
+// connection specifics.
type Client struct {
+ mu sync.Mutex
connId ConnectionId
connIdIssued time.Time
Dispatcher *Dispatcher
}
func (cl *Client) Announce(
- ctx context.Context, req AnnounceRequest, peers AnnounceResponsePeers, opts Options,
+ ctx context.Context, req AnnounceRequest, opts Options,
+ // Decides whether the response body is IPv6 or IPv4, see BEP 15.
+ ipv6 func(net.Addr) bool,
) (
- respHdr AnnounceResponseHeader, err error,
+ respHdr AnnounceResponseHeader,
+ // A slice of krpc.NodeAddr, likely wrapped in an appropriate unmarshalling wrapper.
+ peers AnnounceResponsePeers,
+ err error,
) {
- body, err := marshal(req)
- if err != nil {
- return
- }
- respBody, err := cl.request(ctx, ActionAnnounce, append(body, opts.Encode()...))
+ respBody, addr, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
if err != nil {
return
}
err = fmt.Errorf("reading response header: %w", err)
return
}
+ if ipv6(addr) {
+ peers = &krpc.CompactIPv6NodeAddrs{}
+ } else {
+ peers = &krpc.CompactIPv4NodeAddrs{}
+ }
err = peers.UnmarshalBinary(r.Bytes())
if err != nil {
err = fmt.Errorf("reading response peers: %w", err)
return
}
+// There's no way to pass options in a scrape, since we don't when the request body ends.
+func (cl *Client) Scrape(
+ ctx context.Context, ihs []InfoHash,
+) (
+ out ScrapeResponse, err error,
+) {
+ respBody, _, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
+ if err != nil {
+ return
+ }
+ r := bytes.NewBuffer(respBody)
+ for r.Len() != 0 {
+ var item ScrapeInfohashResult
+ err = Read(r, &item)
+ if err != nil {
+ return
+ }
+ out = append(out, item)
+ }
+ if len(out) > len(ihs) {
+ err = fmt.Errorf("got %v results but expected %v", len(out), len(ihs))
+ return
+ }
+ return
+}
+
func (cl *Client) connect(ctx context.Context) (err error) {
- if time.Since(cl.connIdIssued) < time.Minute {
+ // We could get fancier here and use RWMutex, and even fire off the connection asynchronously
+ // and provide a grace period while it resolves.
+ cl.mu.Lock()
+ defer cl.mu.Unlock()
+ if !cl.connIdIssued.IsZero() && time.Since(cl.connIdIssued) < time.Minute {
return nil
}
- respBody, err := cl.request(ctx, ActionConnect, nil)
+ respBody, _, err := cl.request(ctx, ActionConnect, nil)
if err != nil {
return err
}
}
}
-func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, err error) {
+func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, addr net.Addr, err error) {
respChan := make(chan DispatchedResponse, 1)
t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) {
respChan <- dr
case dr := <-respChan:
if dr.Header.Action == action {
respBody = dr.Body
+ addr = dr.Addr
} else if dr.Header.Action == ActionError {
- err = errors.New(string(dr.Body))
+ // I've seen "Connection ID mismatch.^@" in less and other tools, I think they're just
+ // not handling a trailing \x00 nicely.
+ err = fmt.Errorf("error response: %#q", dr.Body)
} else {
err = fmt.Errorf("unexpected response action %v", dr.Header.Action)
}