Sergey Matveev's repositories - btrtrc.git/blob - tracker/udp/client.go
Check that udp conn ID age is non-zero
[btrtrc.git] / tracker / udp / client.go
1 package udp
2
3 import (
4         "bytes"
5         "context"
6         "encoding/binary"
7         "errors"
8         "fmt"
9         "io"
10         "time"
11 )
12
13 type Client struct {
14         connId       ConnectionId
15         connIdIssued time.Time
16         Dispatcher   *Dispatcher
17         Writer       io.Writer
18 }
19
20 func (cl *Client) Announce(
21         ctx context.Context, req AnnounceRequest, peers AnnounceResponsePeers, opts Options,
22 ) (
23         respHdr AnnounceResponseHeader, err error,
24 ) {
25         respBody, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
26         if err != nil {
27                 return
28         }
29         r := bytes.NewBuffer(respBody)
30         err = Read(r, &respHdr)
31         if err != nil {
32                 err = fmt.Errorf("reading response header: %w", err)
33                 return
34         }
35         err = peers.UnmarshalBinary(r.Bytes())
36         if err != nil {
37                 err = fmt.Errorf("reading response peers: %w", err)
38         }
39         return
40 }
41
42 func (cl *Client) Scrape(
43         ctx context.Context, ihs []InfoHash,
44 ) (
45         out ScrapeResponse, err error,
46 ) {
47         // There's no way to pass options in a scrape, since we don't when the request body ends.
48         respBody, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
49         if err != nil {
50                 return
51         }
52         r := bytes.NewBuffer(respBody)
53         for r.Len() != 0 {
54                 var item ScrapeInfohashResult
55                 err = Read(r, &item)
56                 if err != nil {
57                         return
58                 }
59                 out = append(out, item)
60         }
61         if len(out) > len(ihs) {
62                 err = fmt.Errorf("got %v results but expected %v", len(out), len(ihs))
63                 return
64         }
65         return
66 }
67
68 func (cl *Client) connect(ctx context.Context) (err error) {
69         if !cl.connIdIssued.IsZero() && time.Since(cl.connIdIssued) < time.Minute {
70                 return nil
71         }
72         respBody, err := cl.request(ctx, ActionConnect, nil)
73         if err != nil {
74                 return err
75         }
76         var connResp ConnectionResponse
77         err = binary.Read(bytes.NewReader(respBody), binary.BigEndian, &connResp)
78         if err != nil {
79                 return
80         }
81         cl.connId = connResp.ConnectionId
82         cl.connIdIssued = time.Now()
83         return
84 }
85
86 func (cl *Client) connIdForRequest(ctx context.Context, action Action) (id ConnectionId, err error) {
87         if action == ActionConnect {
88                 id = ConnectRequestConnectionId
89                 return
90         }
91         err = cl.connect(ctx)
92         if err != nil {
93                 return
94         }
95         id = cl.connId
96         return
97 }
98
99 func (cl *Client) requestWriter(ctx context.Context, action Action, body []byte, tId TransactionId) (err error) {
100         var buf bytes.Buffer
101         for n := 0; ; n++ {
102                 var connId ConnectionId
103                 connId, err = cl.connIdForRequest(ctx, action)
104                 if err != nil {
105                         return
106                 }
107                 buf.Reset()
108                 err = binary.Write(&buf, binary.BigEndian, RequestHeader{
109                         ConnectionId:  connId,
110                         Action:        action,
111                         TransactionId: tId,
112                 })
113                 if err != nil {
114                         panic(err)
115                 }
116                 buf.Write(body)
117                 _, err = cl.Writer.Write(buf.Bytes())
118                 if err != nil {
119                         return
120                 }
121                 select {
122                 case <-ctx.Done():
123                         return ctx.Err()
124                 case <-time.After(timeout(n)):
125                 }
126         }
127 }
128
129 func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, err error) {
130         respChan := make(chan DispatchedResponse, 1)
131         t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) {
132                 respChan <- dr
133         })
134         defer t.End()
135         ctx, cancel := context.WithCancel(ctx)
136         defer cancel()
137         writeErr := make(chan error, 1)
138         go func() {
139                 writeErr <- cl.requestWriter(ctx, action, body, t.Id())
140         }()
141         select {
142         case dr := <-respChan:
143                 if dr.Header.Action == action {
144                         respBody = dr.Body
145                 } else if dr.Header.Action == ActionError {
146                         err = errors.New(string(dr.Body))
147                 } else {
148                         err = fmt.Errorf("unexpected response action %v", dr.Header.Action)
149                 }
150         case err = <-writeErr:
151                 err = fmt.Errorf("write error: %w", err)
152         case <-ctx.Done():
153                 err = ctx.Err()
154         }
155         return
156 }