]> Sergey Matveev's repositories - btrtrc.git/blob - tracker/udp/client.go
42f0d14e6a99e6a4a226a25adef43d91c4469b14
[btrtrc.git] / tracker / udp / client.go
1 package udp
2
3 import (
4         "bytes"
5         "context"
6         "encoding/binary"
7         "fmt"
8         "io"
9         "net"
10         "sync"
11         "time"
12
13         "github.com/anacrolix/dht/v2/krpc"
14 )
15
16 // Client interacts with UDP trackers via its Writer and Dispatcher. It has no knowledge of
17 // connection specifics.
18 type Client struct {
19         mu           sync.Mutex
20         connId       ConnectionId
21         connIdIssued time.Time
22         Dispatcher   *Dispatcher
23         Writer       io.Writer
24 }
25
26 func (cl *Client) Announce(
27         ctx context.Context, req AnnounceRequest, opts Options,
28         // Decides whether the response body is IPv6 or IPv4, see BEP 15.
29         ipv6 func(net.Addr) bool,
30 ) (
31         respHdr AnnounceResponseHeader,
32         // A slice of krpc.NodeAddr, likely wrapped in an appropriate unmarshalling wrapper.
33         peers AnnounceResponsePeers,
34         err error,
35 ) {
36         respBody, addr, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
37         if err != nil {
38                 return
39         }
40         r := bytes.NewBuffer(respBody)
41         err = Read(r, &respHdr)
42         if err != nil {
43                 err = fmt.Errorf("reading response header: %w", err)
44                 return
45         }
46         if ipv6(addr) {
47                 peers = &krpc.CompactIPv6NodeAddrs{}
48         } else {
49                 peers = &krpc.CompactIPv4NodeAddrs{}
50         }
51         err = peers.UnmarshalBinary(r.Bytes())
52         if err != nil {
53                 err = fmt.Errorf("reading response peers: %w", err)
54         }
55         return
56 }
57
58 // There's no way to pass options in a scrape, since we don't when the request body ends.
59 func (cl *Client) Scrape(
60         ctx context.Context, ihs []InfoHash,
61 ) (
62         out ScrapeResponse, err error,
63 ) {
64         respBody, _, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
65         if err != nil {
66                 return
67         }
68         r := bytes.NewBuffer(respBody)
69         for r.Len() != 0 {
70                 var item ScrapeInfohashResult
71                 err = Read(r, &item)
72                 if err != nil {
73                         return
74                 }
75                 out = append(out, item)
76         }
77         if len(out) > len(ihs) {
78                 err = fmt.Errorf("got %v results but expected %v", len(out), len(ihs))
79                 return
80         }
81         return
82 }
83
84 func (cl *Client) connect(ctx context.Context) (err error) {
85         // We could get fancier here and use RWMutex, and even fire off the connection asynchronously
86         // and provide a grace period while it resolves.
87         cl.mu.Lock()
88         defer cl.mu.Unlock()
89         if !cl.connIdIssued.IsZero() && time.Since(cl.connIdIssued) < time.Minute {
90                 return nil
91         }
92         respBody, _, err := cl.request(ctx, ActionConnect, nil)
93         if err != nil {
94                 return err
95         }
96         var connResp ConnectionResponse
97         err = binary.Read(bytes.NewReader(respBody), binary.BigEndian, &connResp)
98         if err != nil {
99                 return
100         }
101         cl.connId = connResp.ConnectionId
102         cl.connIdIssued = time.Now()
103         return
104 }
105
106 func (cl *Client) connIdForRequest(ctx context.Context, action Action) (id ConnectionId, err error) {
107         if action == ActionConnect {
108                 id = ConnectRequestConnectionId
109                 return
110         }
111         err = cl.connect(ctx)
112         if err != nil {
113                 return
114         }
115         id = cl.connId
116         return
117 }
118
119 func (cl *Client) requestWriter(ctx context.Context, action Action, body []byte, tId TransactionId) (err error) {
120         var buf bytes.Buffer
121         for n := 0; ; n++ {
122                 var connId ConnectionId
123                 connId, err = cl.connIdForRequest(ctx, action)
124                 if err != nil {
125                         return
126                 }
127                 buf.Reset()
128                 err = Write(&buf, RequestHeader{
129                         ConnectionId:  connId,
130                         Action:        action,
131                         TransactionId: tId,
132                 })
133                 if err != nil {
134                         panic(err)
135                 }
136                 buf.Write(body)
137                 _, err = cl.Writer.Write(buf.Bytes())
138                 if err != nil {
139                         return
140                 }
141                 select {
142                 case <-ctx.Done():
143                         return ctx.Err()
144                 case <-time.After(timeout(n)):
145                 }
146         }
147 }
148
149 func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, addr net.Addr, err error) {
150         respChan := make(chan DispatchedResponse, 1)
151         t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) {
152                 respChan <- dr
153         })
154         defer t.End()
155         ctx, cancel := context.WithCancel(ctx)
156         defer cancel()
157         writeErr := make(chan error, 1)
158         go func() {
159                 writeErr <- cl.requestWriter(ctx, action, body, t.Id())
160         }()
161         select {
162         case dr := <-respChan:
163                 if dr.Header.Action == action {
164                         respBody = dr.Body
165                         addr = dr.Addr
166                 } else if dr.Header.Action == ActionError {
167                         // I've seen "Connection ID mismatch.^@" in less and other tools, I think they're just
168                         // not handling a trailing \x00 nicely.
169                         err = fmt.Errorf("error response: %#q", dr.Body)
170                 } else {
171                         err = fmt.Errorf("unexpected response action %v", dr.Header.Action)
172                 }
173         case err = <-writeErr:
174                 err = fmt.Errorf("write error: %w", err)
175         case <-ctx.Done():
176                 err = ctx.Err()
177         }
178         return
179 }