]> Sergey Matveev's repositories - btrtrc.git/blob - tracker/udp/client.go
Don't reconnect before sending requests with current conn ID
[btrtrc.git] / tracker / udp / client.go
1 package udp
2
3 import (
4         "bytes"
5         "context"
6         "encoding/binary"
7         "fmt"
8         "io"
9         "net"
10         "sync"
11         "time"
12
13         "github.com/anacrolix/dht/v2/krpc"
14 )
15
16 // Client interacts with UDP trackers via its Writer and Dispatcher. It has no knowledge of
17 // connection specifics.
18 type Client struct {
19         mu           sync.Mutex
20         connId       ConnectionId
21         connIdIssued time.Time
22
23         shouldReconnectOverride func() bool
24
25         Dispatcher *Dispatcher
26         Writer     io.Writer
27 }
28
29 func (cl *Client) Announce(
30         ctx context.Context, req AnnounceRequest, opts Options,
31         // Decides whether the response body is IPv6 or IPv4, see BEP 15.
32         ipv6 func(net.Addr) bool,
33 ) (
34         respHdr AnnounceResponseHeader,
35         // A slice of krpc.NodeAddr, likely wrapped in an appropriate unmarshalling wrapper.
36         peers AnnounceResponsePeers,
37         err error,
38 ) {
39         respBody, addr, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
40         if err != nil {
41                 return
42         }
43         r := bytes.NewBuffer(respBody)
44         err = Read(r, &respHdr)
45         if err != nil {
46                 err = fmt.Errorf("reading response header: %w", err)
47                 return
48         }
49         if ipv6(addr) {
50                 peers = &krpc.CompactIPv6NodeAddrs{}
51         } else {
52                 peers = &krpc.CompactIPv4NodeAddrs{}
53         }
54         err = peers.UnmarshalBinary(r.Bytes())
55         if err != nil {
56                 err = fmt.Errorf("reading response peers: %w", err)
57         }
58         return
59 }
60
61 // There's no way to pass options in a scrape, since we don't when the request body ends.
62 func (cl *Client) Scrape(
63         ctx context.Context, ihs []InfoHash,
64 ) (
65         out ScrapeResponse, err error,
66 ) {
67         respBody, _, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
68         if err != nil {
69                 return
70         }
71         r := bytes.NewBuffer(respBody)
72         for r.Len() != 0 {
73                 var item ScrapeInfohashResult
74                 err = Read(r, &item)
75                 if err != nil {
76                         return
77                 }
78                 out = append(out, item)
79         }
80         if len(out) > len(ihs) {
81                 err = fmt.Errorf("got %v results but expected %v", len(out), len(ihs))
82                 return
83         }
84         return
85 }
86
87 func (cl *Client) shouldReconnectDefault() bool {
88         return cl.connIdIssued.IsZero() || time.Since(cl.connIdIssued) >= time.Minute
89 }
90
91 func (cl *Client) shouldReconnect() bool {
92         if cl.shouldReconnectOverride != nil {
93                 return cl.shouldReconnectOverride()
94         }
95         return cl.shouldReconnectDefault()
96 }
97
98 func (cl *Client) connect(ctx context.Context) (err error) {
99         if !cl.shouldReconnect() {
100                 return nil
101         }
102         return cl.doConnectRoundTrip(ctx)
103 }
104
105 // This just does the connect request and updates local state if it succeeds.
106 func (cl *Client) doConnectRoundTrip(ctx context.Context) (err error) {
107         respBody, _, err := cl.request(ctx, ActionConnect, nil)
108         if err != nil {
109                 return err
110         }
111         var connResp ConnectionResponse
112         err = binary.Read(bytes.NewReader(respBody), binary.BigEndian, &connResp)
113         if err != nil {
114                 return
115         }
116         cl.connId = connResp.ConnectionId
117         cl.connIdIssued = time.Now()
118         //log.Printf("conn id set to %x", cl.connId)
119         return
120 }
121
122 func (cl *Client) connIdForRequest(ctx context.Context, action Action) (id ConnectionId, err error) {
123         if action == ActionConnect {
124                 id = ConnectRequestConnectionId
125                 return
126         }
127         err = cl.connect(ctx)
128         if err != nil {
129                 return
130         }
131         id = cl.connId
132         return
133 }
134
135 func (cl *Client) writeRequest(
136         ctx context.Context, action Action, body []byte, tId TransactionId, buf *bytes.Buffer,
137 ) (
138         err error,
139 ) {
140         var connId ConnectionId
141         if action == ActionConnect {
142                 connId = ConnectRequestConnectionId
143         } else {
144                 // We lock here while establishing a connection ID, and then ensuring that the request is
145                 // written before allowing the connection ID to change again. This is to ensure the server
146                 // doesn't assign us another ID before we've sent this request. Note that this doesn't allow
147                 // for us to return if the context is cancelled while we wait to obtain a new ID.
148                 cl.mu.Lock()
149                 defer cl.mu.Unlock()
150                 connId, err = cl.connIdForRequest(ctx, action)
151                 if err != nil {
152                         return
153                 }
154         }
155         buf.Reset()
156         err = Write(buf, RequestHeader{
157                 ConnectionId:  connId,
158                 Action:        action,
159                 TransactionId: tId,
160         })
161         if err != nil {
162                 panic(err)
163         }
164         buf.Write(body)
165         _, err = cl.Writer.Write(buf.Bytes())
166         //log.Printf("sent request with conn id %x", connId)
167         return
168 }
169
170 func (cl *Client) requestWriter(ctx context.Context, action Action, body []byte, tId TransactionId) (err error) {
171         var buf bytes.Buffer
172         for n := 0; ; n++ {
173                 err = cl.writeRequest(ctx, action, body, tId, &buf)
174                 if err != nil {
175                         return
176                 }
177                 select {
178                 case <-ctx.Done():
179                         return ctx.Err()
180                 case <-time.After(timeout(n)):
181                 }
182         }
183 }
184
185 func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, addr net.Addr, err error) {
186         respChan := make(chan DispatchedResponse, 1)
187         t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) {
188                 respChan <- dr
189         })
190         defer t.End()
191         ctx, cancel := context.WithCancel(ctx)
192         defer cancel()
193         writeErr := make(chan error, 1)
194         go func() {
195                 writeErr <- cl.requestWriter(ctx, action, body, t.Id())
196         }()
197         select {
198         case dr := <-respChan:
199                 if dr.Header.Action == action {
200                         respBody = dr.Body
201                         addr = dr.Addr
202                 } else if dr.Header.Action == ActionError {
203                         // I've seen "Connection ID mismatch.^@" in less and other tools, I think they're just
204                         // not handling a trailing \x00 nicely.
205                         err = fmt.Errorf("error response: %#q", dr.Body)
206                 } else {
207                         err = fmt.Errorf("unexpected response action %v", dr.Header.Action)
208                 }
209         case err = <-writeErr:
210                 err = fmt.Errorf("write error: %w", err)
211         case <-ctx.Done():
212                 err = ctx.Err()
213         }
214         return
215 }