]> Sergey Matveev's repositories - btrtrc.git/blob - tracker/udp/client.go
Don't dial in UDP tracking
[btrtrc.git] / tracker / udp / client.go
1 package udp
2
3 import (
4         "bytes"
5         "context"
6         "encoding/binary"
7         "errors"
8         "fmt"
9         "io"
10         "net"
11         "sync"
12         "time"
13
14         "github.com/anacrolix/dht/v2/krpc"
15 )
16
17 // Client interacts with UDP trackers via its Writer and Dispatcher. It has no knowledge of
18 // connection specifics.
19 type Client struct {
20         mu           sync.Mutex
21         connId       ConnectionId
22         connIdIssued time.Time
23         Dispatcher   *Dispatcher
24         Writer       io.Writer
25 }
26
27 func (cl *Client) Announce(
28         ctx context.Context, req AnnounceRequest, opts Options,
29         // Decides whether the response body is IPv6 or IPv4, see BEP 15.
30         ipv6 func(net.Addr) bool,
31 ) (
32         respHdr AnnounceResponseHeader,
33         // A slice of krpc.NodeAddr, likely wrapped in an appropriate unmarshalling wrapper.
34         peers AnnounceResponsePeers,
35         err error,
36 ) {
37         respBody, addr, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
38         if err != nil {
39                 return
40         }
41         r := bytes.NewBuffer(respBody)
42         err = Read(r, &respHdr)
43         if err != nil {
44                 err = fmt.Errorf("reading response header: %w", err)
45                 return
46         }
47         if ipv6(addr) {
48                 peers = &krpc.CompactIPv6NodeAddrs{}
49         } else {
50                 peers = &krpc.CompactIPv4NodeAddrs{}
51         }
52         err = peers.UnmarshalBinary(r.Bytes())
53         if err != nil {
54                 err = fmt.Errorf("reading response peers: %w", err)
55         }
56         return
57 }
58
59 // There's no way to pass options in a scrape, since we don't when the request body ends.
60 func (cl *Client) Scrape(
61         ctx context.Context, ihs []InfoHash,
62 ) (
63         out ScrapeResponse, err error,
64 ) {
65         respBody, _, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
66         if err != nil {
67                 return
68         }
69         r := bytes.NewBuffer(respBody)
70         for r.Len() != 0 {
71                 var item ScrapeInfohashResult
72                 err = Read(r, &item)
73                 if err != nil {
74                         return
75                 }
76                 out = append(out, item)
77         }
78         if len(out) > len(ihs) {
79                 err = fmt.Errorf("got %v results but expected %v", len(out), len(ihs))
80                 return
81         }
82         return
83 }
84
85 func (cl *Client) connect(ctx context.Context) (err error) {
86         // We could get fancier here and use RWMutex, and even fire off the connection asynchronously
87         // and provide a grace period while it resolves.
88         cl.mu.Lock()
89         defer cl.mu.Unlock()
90         if !cl.connIdIssued.IsZero() && time.Since(cl.connIdIssued) < time.Minute {
91                 return nil
92         }
93         respBody, _, err := cl.request(ctx, ActionConnect, nil)
94         if err != nil {
95                 return err
96         }
97         var connResp ConnectionResponse
98         err = binary.Read(bytes.NewReader(respBody), binary.BigEndian, &connResp)
99         if err != nil {
100                 return
101         }
102         cl.connId = connResp.ConnectionId
103         cl.connIdIssued = time.Now()
104         return
105 }
106
107 func (cl *Client) connIdForRequest(ctx context.Context, action Action) (id ConnectionId, err error) {
108         if action == ActionConnect {
109                 id = ConnectRequestConnectionId
110                 return
111         }
112         err = cl.connect(ctx)
113         if err != nil {
114                 return
115         }
116         id = cl.connId
117         return
118 }
119
120 func (cl *Client) requestWriter(ctx context.Context, action Action, body []byte, tId TransactionId) (err error) {
121         var buf bytes.Buffer
122         for n := 0; ; n++ {
123                 var connId ConnectionId
124                 connId, err = cl.connIdForRequest(ctx, action)
125                 if err != nil {
126                         return
127                 }
128                 buf.Reset()
129                 err = binary.Write(&buf, binary.BigEndian, RequestHeader{
130                         ConnectionId:  connId,
131                         Action:        action,
132                         TransactionId: tId,
133                 })
134                 if err != nil {
135                         panic(err)
136                 }
137                 buf.Write(body)
138                 _, err = cl.Writer.Write(buf.Bytes())
139                 if err != nil {
140                         return
141                 }
142                 select {
143                 case <-ctx.Done():
144                         return ctx.Err()
145                 case <-time.After(timeout(n)):
146                 }
147         }
148 }
149
150 func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, addr net.Addr, err error) {
151         respChan := make(chan DispatchedResponse, 1)
152         t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) {
153                 respChan <- dr
154         })
155         defer t.End()
156         ctx, cancel := context.WithCancel(ctx)
157         defer cancel()
158         writeErr := make(chan error, 1)
159         go func() {
160                 writeErr <- cl.requestWriter(ctx, action, body, t.Id())
161         }()
162         select {
163         case dr := <-respChan:
164                 if dr.Header.Action == action {
165                         respBody = dr.Body
166                         addr = dr.Addr
167                 } else if dr.Header.Action == ActionError {
168                         err = errors.New(string(dr.Body))
169                 } else {
170                         err = fmt.Errorf("unexpected response action %v", dr.Header.Action)
171                 }
172         case err = <-writeErr:
173                 err = fmt.Errorf("write error: %w", err)
174         case <-ctx.Done():
175                 err = ctx.Err()
176         }
177         return
178 }