]> Sergey Matveev's repositories - btrtrc.git/blob - tracker/udp/client.go
Doc comments
[btrtrc.git] / tracker / udp / client.go
1 package udp
2
3 import (
4         "bytes"
5         "context"
6         "encoding/binary"
7         "errors"
8         "fmt"
9         "io"
10         "sync"
11         "time"
12 )
13
14 // Client interacts with UDP trackers via its Writer and Dispatcher. It has no knowledge of
15 // connection specifics.
16 type Client struct {
17         mu           sync.Mutex
18         connId       ConnectionId
19         connIdIssued time.Time
20         Dispatcher   *Dispatcher
21         Writer       io.Writer
22 }
23
24 func (cl *Client) Announce(
25         ctx context.Context, req AnnounceRequest, peers AnnounceResponsePeers, opts Options,
26 ) (
27         respHdr AnnounceResponseHeader, err error,
28 ) {
29         respBody, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
30         if err != nil {
31                 return
32         }
33         r := bytes.NewBuffer(respBody)
34         err = Read(r, &respHdr)
35         if err != nil {
36                 err = fmt.Errorf("reading response header: %w", err)
37                 return
38         }
39         err = peers.UnmarshalBinary(r.Bytes())
40         if err != nil {
41                 err = fmt.Errorf("reading response peers: %w", err)
42         }
43         return
44 }
45
46 func (cl *Client) Scrape(
47         ctx context.Context, ihs []InfoHash,
48 ) (
49         out ScrapeResponse, err error,
50 ) {
51         // There's no way to pass options in a scrape, since we don't when the request body ends.
52         respBody, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
53         if err != nil {
54                 return
55         }
56         r := bytes.NewBuffer(respBody)
57         for r.Len() != 0 {
58                 var item ScrapeInfohashResult
59                 err = Read(r, &item)
60                 if err != nil {
61                         return
62                 }
63                 out = append(out, item)
64         }
65         if len(out) > len(ihs) {
66                 err = fmt.Errorf("got %v results but expected %v", len(out), len(ihs))
67                 return
68         }
69         return
70 }
71
72 func (cl *Client) connect(ctx context.Context) (err error) {
73         // We could get fancier here and use RWMutex, and even fire off the connection asynchronously
74         // and provide a grace period while it resolves.
75         cl.mu.Lock()
76         defer cl.mu.Unlock()
77         if !cl.connIdIssued.IsZero() && time.Since(cl.connIdIssued) < time.Minute {
78                 return nil
79         }
80         respBody, err := cl.request(ctx, ActionConnect, nil)
81         if err != nil {
82                 return err
83         }
84         var connResp ConnectionResponse
85         err = binary.Read(bytes.NewReader(respBody), binary.BigEndian, &connResp)
86         if err != nil {
87                 return
88         }
89         cl.connId = connResp.ConnectionId
90         cl.connIdIssued = time.Now()
91         return
92 }
93
94 func (cl *Client) connIdForRequest(ctx context.Context, action Action) (id ConnectionId, err error) {
95         if action == ActionConnect {
96                 id = ConnectRequestConnectionId
97                 return
98         }
99         err = cl.connect(ctx)
100         if err != nil {
101                 return
102         }
103         id = cl.connId
104         return
105 }
106
107 func (cl *Client) requestWriter(ctx context.Context, action Action, body []byte, tId TransactionId) (err error) {
108         var buf bytes.Buffer
109         for n := 0; ; n++ {
110                 var connId ConnectionId
111                 connId, err = cl.connIdForRequest(ctx, action)
112                 if err != nil {
113                         return
114                 }
115                 buf.Reset()
116                 err = binary.Write(&buf, binary.BigEndian, RequestHeader{
117                         ConnectionId:  connId,
118                         Action:        action,
119                         TransactionId: tId,
120                 })
121                 if err != nil {
122                         panic(err)
123                 }
124                 buf.Write(body)
125                 _, err = cl.Writer.Write(buf.Bytes())
126                 if err != nil {
127                         return
128                 }
129                 select {
130                 case <-ctx.Done():
131                         return ctx.Err()
132                 case <-time.After(timeout(n)):
133                 }
134         }
135 }
136
137 func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, err error) {
138         respChan := make(chan DispatchedResponse, 1)
139         t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) {
140                 respChan <- dr
141         })
142         defer t.End()
143         ctx, cancel := context.WithCancel(ctx)
144         defer cancel()
145         writeErr := make(chan error, 1)
146         go func() {
147                 writeErr <- cl.requestWriter(ctx, action, body, t.Id())
148         }()
149         select {
150         case dr := <-respChan:
151                 if dr.Header.Action == action {
152                         respBody = dr.Body
153                 } else if dr.Header.Action == ActionError {
154                         err = errors.New(string(dr.Body))
155                 } else {
156                         err = fmt.Errorf("unexpected response action %v", dr.Header.Action)
157                 }
158         case err = <-writeErr:
159                 err = fmt.Errorf("write error: %w", err)
160         case <-ctx.Done():
161                 err = ctx.Err()
162         }
163         return
164 }