Sergey Matveev's repositories - btrtrc.git/blob - webtorrent/tracker_client.go
fixed code quality issues using DeepSource
[btrtrc.git] / webtorrent / tracker_client.go
1 package webtorrent
2
3 import (
4         "crypto/rand"
5         "encoding/json"
6         "fmt"
7         "sync"
8         "time"
9
10         "github.com/anacrolix/log"
11
12         "github.com/anacrolix/torrent/tracker"
13         "github.com/gorilla/websocket"
14         "github.com/pion/datachannel"
15         "github.com/pion/webrtc/v2"
16 )
17
// TrackerClientStats holds counters maintained by a TrackerClient. A copy is returned by
// TrackerClient.Stats.
type TrackerClientStats struct {
	// Dials counts websocket connection attempts to the tracker.
	Dials int64
	// ConvertedInboundConns counts offers from remote peers that this client answered and
	// whose data channel then opened.
	ConvertedInboundConns int64
	// ConvertedOutboundConns counts offers made by this client that were answered and whose
	// data channel then opened.
	ConvertedOutboundConns int64
}
23
24 // Client represents the webtorrent client
25 type TrackerClient struct {
26         Url                string
27         GetAnnounceRequest func(_ tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
28         PeerId             [20]byte
29         OnConn             onDataChannelOpen
30         Logger             log.Logger
31
32         mu             sync.Mutex
33         cond           sync.Cond
34         outboundOffers map[string]outboundOffer // OfferID to outboundOffer
35         wsConn         *websocket.Conn
36         closed         bool
37         stats          TrackerClientStats
38         pingTicker     *time.Ticker
39 }
40
41 func (me *TrackerClient) Stats() TrackerClientStats {
42         me.mu.Lock()
43         defer me.mu.Unlock()
44         return me.stats
45 }
46
47 func (me *TrackerClient) peerIdBinary() string {
48         return binaryToJsonString(me.PeerId[:])
49 }
50
51 // outboundOffer represents an outstanding offer.
52 type outboundOffer struct {
53         originalOffer  webrtc.SessionDescription
54         peerConnection *wrappedPeerConnection
55         dataChannel    *webrtc.DataChannel
56         infoHash       [20]byte
57 }
58
59 type DataChannelContext struct {
60         Local, Remote webrtc.SessionDescription
61         OfferId       string
62         LocalOffered  bool
63         InfoHash      [20]byte
64 }
65
66 type onDataChannelOpen func(_ datachannel.ReadWriteCloser, dcc DataChannelContext)
67
68 func (tc *TrackerClient) doWebsocket() error {
69         metrics.Add("websocket dials", 1)
70         tc.stats.Dials++
71         c, _, err := websocket.DefaultDialer.Dial(tc.Url, nil)
72         if err != nil {
73                 return fmt.Errorf("dialing tracker: %w", err)
74         }
75         defer c.Close()
76         tc.Logger.WithDefaultLevel(log.Info).Printf("connected")
77         tc.mu.Lock()
78         tc.wsConn = c
79         tc.cond.Broadcast()
80         tc.mu.Unlock()
81         tc.announceOffers()
82         closeChan := make(chan struct{})
83         go func() {
84                 for {
85                         select {
86                         case <-tc.pingTicker.C:
87                                 tc.mu.Lock()
88                                 err := c.WriteMessage(websocket.PingMessage, []byte{})
89                                 tc.mu.Unlock()
90                                 if err != nil {
91                                         return
92                                 }
93                         case <-closeChan:
94                                 return
95
96                         }
97                 }
98         }()
99         err = tc.trackerReadLoop(tc.wsConn)
100         close(closeChan)
101         tc.mu.Lock()
102         c.Close()
103         tc.mu.Unlock()
104         return err
105 }
106
107 func (tc *TrackerClient) Run() error {
108         tc.pingTicker = time.NewTicker(60 * time.Second)
109         tc.cond.L = &tc.mu
110         tc.mu.Lock()
111         for !tc.closed {
112                 tc.mu.Unlock()
113                 err := tc.doWebsocket()
114                 level := log.Info
115                 tc.mu.Lock()
116                 if tc.closed {
117                         level = log.Debug
118                 }
119                 tc.mu.Unlock()
120                 tc.Logger.WithDefaultLevel(level).Printf("websocket instance ended: %v", err)
121                 time.Sleep(time.Minute)
122                 tc.mu.Lock()
123         }
124         tc.mu.Unlock()
125         return nil
126 }
127
128 func (tc *TrackerClient) Close() error {
129         tc.mu.Lock()
130         tc.closed = true
131         if tc.wsConn != nil {
132                 tc.wsConn.Close()
133         }
134         tc.closeUnusedOffers()
135         tc.pingTicker.Stop()
136         tc.mu.Unlock()
137         tc.cond.Broadcast()
138         return nil
139 }
140
141 func (tc *TrackerClient) announceOffers() {
142
143         // tc.Announce grabs a lock on tc.outboundOffers. It also handles the case where outboundOffers
144         // is nil. Take ownership of outboundOffers here.
145         tc.mu.Lock()
146         offers := tc.outboundOffers
147         tc.outboundOffers = nil
148         tc.mu.Unlock()
149
150         if offers == nil {
151                 return
152         }
153
154         // Iterate over our locally-owned offers, close any existing "invalid" ones from before the
155         // socket reconnected, reannounce the infohash, adding it back into the tc.outboundOffers.
156         tc.Logger.WithDefaultLevel(log.Info).Printf("reannouncing %d infohashes after restart", len(offers))
157         for _, offer := range offers {
158                 // TODO: Capture the errors? Are we even in a position to do anything with them?
159                 offer.peerConnection.Close()
160                 // Use goroutine here to allow read loop to start and ensure the buffer drains.
161                 go tc.Announce(tracker.Started, offer.infoHash)
162         }
163 }
164
165 func (tc *TrackerClient) closeUnusedOffers() {
166         for _, offer := range tc.outboundOffers {
167                 offer.peerConnection.Close()
168         }
169         tc.outboundOffers = nil
170 }
171
172 func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte) error {
173         metrics.Add("outbound announces", 1)
174         var randOfferId [20]byte
175         _, err := rand.Read(randOfferId[:])
176         if err != nil {
177                 return fmt.Errorf("generating offer_id bytes: %w", err)
178         }
179         offerIDBinary := binaryToJsonString(randOfferId[:])
180
181         pc, dc, offer, err := newOffer()
182         if err != nil {
183                 return fmt.Errorf("creating offer: %w", err)
184         }
185
186         request, err := tc.GetAnnounceRequest(event, infoHash)
187         if err != nil {
188                 pc.Close()
189                 return fmt.Errorf("getting announce parameters: %w", err)
190         }
191
192         req := AnnounceRequest{
193                 Numwant:    1, // If higher we need to create equal amount of offers.
194                 Uploaded:   request.Uploaded,
195                 Downloaded: request.Downloaded,
196                 Left:       request.Left,
197                 Event:      request.Event.String(),
198                 Action:     "announce",
199                 InfoHash:   binaryToJsonString(infoHash[:]),
200                 PeerID:     tc.peerIdBinary(),
201                 Offers: []Offer{{
202                         OfferID: offerIDBinary,
203                         Offer:   offer,
204                 }},
205         }
206
207         data, err := json.Marshal(req)
208         if err != nil {
209                 pc.Close()
210                 return fmt.Errorf("marshalling request: %w", err)
211         }
212
213         tc.mu.Lock()
214         defer tc.mu.Unlock()
215         err = tc.writeMessage(data)
216         if err != nil {
217                 pc.Close()
218                 return fmt.Errorf("write AnnounceRequest: %w", err)
219         }
220         if tc.outboundOffers == nil {
221                 tc.outboundOffers = make(map[string]outboundOffer)
222         }
223         tc.outboundOffers[offerIDBinary] = outboundOffer{
224                 peerConnection: pc,
225                 dataChannel:    dc,
226                 originalOffer:  offer,
227                 infoHash:       infoHash,
228         }
229         return nil
230 }
231
232 func (tc *TrackerClient) writeMessage(data []byte) error {
233         for tc.wsConn == nil {
234                 if tc.closed {
235                         return fmt.Errorf("%T closed", tc)
236                 }
237                 tc.cond.Wait()
238         }
239         return tc.wsConn.WriteMessage(websocket.TextMessage, data)
240 }
241
242 func (tc *TrackerClient) trackerReadLoop(tracker *websocket.Conn) error {
243         for {
244                 _, message, err := tracker.ReadMessage()
245                 if err != nil {
246                         return fmt.Errorf("read message error: %w", err)
247                 }
248                 //tc.Logger.WithDefaultLevel(log.Debug).Printf("received message from tracker: %q", message)
249
250                 var ar AnnounceResponse
251                 if err := json.Unmarshal(message, &ar); err != nil {
252                         tc.Logger.WithDefaultLevel(log.Warning).Printf("error unmarshalling announce response: %v", err)
253                         continue
254                 }
255                 switch {
256                 case ar.Offer != nil:
257                         ih, err := jsonStringToInfoHash(ar.InfoHash)
258                         if err != nil {
259                                 tc.Logger.WithDefaultLevel(log.Warning).Printf("error decoding info_hash in offer: %v", err)
260                                 break
261                         }
262                         tc.handleOffer(*ar.Offer, ar.OfferID, ih, ar.PeerID)
263                 case ar.Answer != nil:
264                         tc.handleAnswer(ar.OfferID, *ar.Answer)
265                 }
266         }
267 }
268
269 func (tc *TrackerClient) handleOffer(
270         offer webrtc.SessionDescription,
271         offerId string,
272         infoHash [20]byte,
273         peerId string,
274 ) error {
275         peerConnection, answer, err := newAnsweringPeerConnection(offer)
276         if err != nil {
277                 return fmt.Errorf("write AnnounceResponse: %w", err)
278         }
279         response := AnnounceResponse{
280                 Action:   "announce",
281                 InfoHash: binaryToJsonString(infoHash[:]),
282                 PeerID:   tc.peerIdBinary(),
283                 ToPeerID: peerId,
284                 Answer:   &answer,
285                 OfferID:  offerId,
286         }
287         data, err := json.Marshal(response)
288         if err != nil {
289                 peerConnection.Close()
290                 return fmt.Errorf("marshalling response: %w", err)
291         }
292         tc.mu.Lock()
293         defer tc.mu.Unlock()
294         if err := tc.writeMessage(data); err != nil {
295                 peerConnection.Close()
296                 return fmt.Errorf("writing response: %w", err)
297         }
298         timer := time.AfterFunc(30*time.Second, func() {
299                 metrics.Add("answering peer connections timed out", 1)
300                 peerConnection.Close()
301         })
302         peerConnection.OnDataChannel(func(d *webrtc.DataChannel) {
303                 setDataChannelOnOpen(d, peerConnection, func(dc datachannel.ReadWriteCloser) {
304                         timer.Stop()
305                         metrics.Add("answering peer connection conversions", 1)
306                         tc.mu.Lock()
307                         tc.stats.ConvertedInboundConns++
308                         tc.mu.Unlock()
309                         tc.OnConn(dc, DataChannelContext{
310                                 Local:        answer,
311                                 Remote:       offer,
312                                 OfferId:      offerId,
313                                 LocalOffered: false,
314                                 InfoHash:     infoHash,
315                         })
316                 })
317         })
318         return nil
319 }
320
321 func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescription) {
322         tc.mu.Lock()
323         defer tc.mu.Unlock()
324         offer, ok := tc.outboundOffers[offerId]
325         if !ok {
326                 tc.Logger.WithDefaultLevel(log.Warning).Printf("could not find offer for id %q", offerId)
327                 return
328         }
329         //tc.Logger.WithDefaultLevel(log.Debug).Printf("offer %q got answer %v", offerId, answer)
330         metrics.Add("outbound offers answered", 1)
331         err := offer.setAnswer(answer, func(dc datachannel.ReadWriteCloser) {
332                 metrics.Add("outbound offers answered with datachannel", 1)
333                 tc.mu.Lock()
334                 tc.stats.ConvertedOutboundConns++
335                 tc.mu.Unlock()
336                 tc.OnConn(dc, DataChannelContext{
337                         Local:        offer.originalOffer,
338                         Remote:       answer,
339                         OfferId:      offerId,
340                         LocalOffered: true,
341                         InfoHash:     offer.infoHash,
342                 })
343         })
344         if err != nil {
345                 tc.Logger.WithDefaultLevel(log.Warning).Printf("error using outbound offer answer: %v", err)
346                 return
347         }
348         delete(tc.outboundOffers, offerId)
349         go tc.Announce(tracker.None, offer.infoHash)
350 }