Sergey Matveev's repositories - btrtrc.git/blob - webtorrent/tracker-client.go
Remove relevant webtorrent offers when closing Torrent
[btrtrc.git] / webtorrent / tracker-client.go
1 package webtorrent
2
3 import (
4         "crypto/rand"
5         "encoding/json"
6         "fmt"
7         "sync"
8         "time"
9
10         "github.com/anacrolix/log"
11
12         "github.com/anacrolix/torrent/tracker"
13         "github.com/gorilla/websocket"
14         "github.com/pion/datachannel"
15         "github.com/pion/webrtc/v3"
16 )
17
// TrackerClientStats holds counters accumulated by a TrackerClient.
type TrackerClientStats struct {
	// Dials counts attempts to open a websocket connection to the tracker.
	Dials int64
	// ConvertedInboundConns and ConvertedOutboundConns count data channels handed to OnConn
	// for offers received from remote peers and for offers this client sent, respectively.
	ConvertedInboundConns, ConvertedOutboundConns int64
}
23
24 // Client represents the webtorrent client
25 type TrackerClient struct {
26         Url                string
27         GetAnnounceRequest func(_ tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
28         PeerId             [20]byte
29         OnConn             onDataChannelOpen
30         Logger             log.Logger
31         Dialer             *websocket.Dialer
32
33         mu             sync.Mutex
34         cond           sync.Cond
35         outboundOffers map[string]outboundOffer // OfferID to outboundOffer
36         wsConn         *websocket.Conn
37         closed         bool
38         stats          TrackerClientStats
39         pingTicker     *time.Ticker
40 }
41
42 func (me *TrackerClient) Stats() TrackerClientStats {
43         me.mu.Lock()
44         defer me.mu.Unlock()
45         return me.stats
46 }
47
48 func (me *TrackerClient) peerIdBinary() string {
49         return binaryToJsonString(me.PeerId[:])
50 }
51
52 // outboundOffer represents an outstanding offer.
53 type outboundOffer struct {
54         originalOffer  webrtc.SessionDescription
55         peerConnection *wrappedPeerConnection
56         dataChannel    *webrtc.DataChannel
57         infoHash       [20]byte
58 }
59
60 type DataChannelContext struct {
61         // Can these be obtained by just calling the relevant methods on peerConnection?
62         Local, Remote webrtc.SessionDescription
63         OfferId       string
64         LocalOffered  bool
65         InfoHash      [20]byte
66         // This is private as some methods might not be appropriate with data channel context.
67         peerConnection *wrappedPeerConnection
68 }
69
70 func (me *DataChannelContext) GetSelectedIceCandidatePair() (*webrtc.ICECandidatePair, error) {
71         return me.peerConnection.SCTP().Transport().ICETransport().GetSelectedCandidatePair()
72 }
73
74 type onDataChannelOpen func(_ datachannel.ReadWriteCloser, dcc DataChannelContext)
75
76 func (tc *TrackerClient) doWebsocket() error {
77         metrics.Add("websocket dials", 1)
78         tc.mu.Lock()
79         tc.stats.Dials++
80         tc.mu.Unlock()
81         c, _, err := tc.Dialer.Dial(tc.Url, nil)
82         if err != nil {
83                 return fmt.Errorf("dialing tracker: %w", err)
84         }
85         defer c.Close()
86         tc.Logger.WithDefaultLevel(log.Info).Printf("connected")
87         tc.mu.Lock()
88         tc.wsConn = c
89         tc.cond.Broadcast()
90         tc.mu.Unlock()
91         tc.announceOffers()
92         closeChan := make(chan struct{})
93         go func() {
94                 for {
95                         select {
96                         case <-tc.pingTicker.C:
97                                 tc.mu.Lock()
98                                 err := c.WriteMessage(websocket.PingMessage, []byte{})
99                                 tc.mu.Unlock()
100                                 if err != nil {
101                                         return
102                                 }
103                         case <-closeChan:
104                                 return
105
106                         }
107                 }
108         }()
109         err = tc.trackerReadLoop(tc.wsConn)
110         close(closeChan)
111         tc.mu.Lock()
112         c.Close()
113         tc.mu.Unlock()
114         return err
115 }
116
117 // Finishes initialization and spawns the run routine, calling onStop when it completes with the
118 // result. We don't let the caller just spawn the runner directly, since then we can race against
119 // .Close to finish initialization.
120 func (tc *TrackerClient) Start(onStop func(error)) {
121         tc.pingTicker = time.NewTicker(60 * time.Second)
122         tc.cond.L = &tc.mu
123         go func() {
124                 onStop(tc.run())
125         }()
126 }
127
128 func (tc *TrackerClient) run() error {
129         tc.mu.Lock()
130         for !tc.closed {
131                 tc.mu.Unlock()
132                 err := tc.doWebsocket()
133                 level := log.Info
134                 tc.mu.Lock()
135                 if tc.closed {
136                         level = log.Debug
137                 }
138                 tc.mu.Unlock()
139                 tc.Logger.WithDefaultLevel(level).Printf("websocket instance ended: %v", err)
140                 time.Sleep(time.Minute)
141                 tc.mu.Lock()
142         }
143         tc.mu.Unlock()
144         return nil
145 }
146
147 func (tc *TrackerClient) Close() error {
148         tc.mu.Lock()
149         tc.closed = true
150         if tc.wsConn != nil {
151                 tc.wsConn.Close()
152         }
153         tc.closeUnusedOffers()
154         tc.pingTicker.Stop()
155         tc.mu.Unlock()
156         tc.cond.Broadcast()
157         return nil
158 }
159
160 func (tc *TrackerClient) announceOffers() {
161         // tc.Announce grabs a lock on tc.outboundOffers. It also handles the case where outboundOffers
162         // is nil. Take ownership of outboundOffers here.
163         tc.mu.Lock()
164         offers := tc.outboundOffers
165         tc.outboundOffers = nil
166         tc.mu.Unlock()
167
168         if offers == nil {
169                 return
170         }
171
172         // Iterate over our locally-owned offers, close any existing "invalid" ones from before the
173         // socket reconnected, reannounce the infohash, adding it back into the tc.outboundOffers.
174         tc.Logger.WithDefaultLevel(log.Info).Printf("reannouncing %d infohashes after restart", len(offers))
175         for _, offer := range offers {
176                 // TODO: Capture the errors? Are we even in a position to do anything with them?
177                 offer.peerConnection.Close()
178                 // Use goroutine here to allow read loop to start and ensure the buffer drains.
179                 go tc.Announce(tracker.Started, offer.infoHash)
180         }
181 }
182
183 func (tc *TrackerClient) closeUnusedOffers() {
184         for _, offer := range tc.outboundOffers {
185                 offer.peerConnection.Close()
186         }
187         tc.outboundOffers = nil
188 }
189
190 func (tc *TrackerClient) CloseOffersForInfohash(infoHash [20]byte) {
191         tc.mu.Lock()
192         defer tc.mu.Unlock()
193         for key, offer := range tc.outboundOffers {
194                 if offer.infoHash == infoHash {
195                         offer.peerConnection.Close()
196                         delete(tc.outboundOffers, key)
197                 }
198         }
199 }
200
201 func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte) error {
202         metrics.Add("outbound announces", 1)
203         var randOfferId [20]byte
204         _, err := rand.Read(randOfferId[:])
205         if err != nil {
206                 return fmt.Errorf("generating offer_id bytes: %w", err)
207         }
208         offerIDBinary := binaryToJsonString(randOfferId[:])
209
210         pc, dc, offer, err := newOffer()
211         if err != nil {
212                 return fmt.Errorf("creating offer: %w", err)
213         }
214
215         request, err := tc.GetAnnounceRequest(event, infoHash)
216         if err != nil {
217                 pc.Close()
218                 return fmt.Errorf("getting announce parameters: %w", err)
219         }
220
221         req := AnnounceRequest{
222                 Numwant:    1, // If higher we need to create equal amount of offers.
223                 Uploaded:   request.Uploaded,
224                 Downloaded: request.Downloaded,
225                 Left:       request.Left,
226                 Event:      request.Event.String(),
227                 Action:     "announce",
228                 InfoHash:   binaryToJsonString(infoHash[:]),
229                 PeerID:     tc.peerIdBinary(),
230                 Offers: []Offer{{
231                         OfferID: offerIDBinary,
232                         Offer:   offer,
233                 }},
234         }
235
236         data, err := json.Marshal(req)
237         if err != nil {
238                 pc.Close()
239                 return fmt.Errorf("marshalling request: %w", err)
240         }
241
242         tc.mu.Lock()
243         defer tc.mu.Unlock()
244         err = tc.writeMessage(data)
245         if err != nil {
246                 pc.Close()
247                 return fmt.Errorf("write AnnounceRequest: %w", err)
248         }
249         if tc.outboundOffers == nil {
250                 tc.outboundOffers = make(map[string]outboundOffer)
251         }
252         tc.outboundOffers[offerIDBinary] = outboundOffer{
253                 peerConnection: pc,
254                 dataChannel:    dc,
255                 originalOffer:  offer,
256                 infoHash:       infoHash,
257         }
258         return nil
259 }
260
261 func (tc *TrackerClient) writeMessage(data []byte) error {
262         for tc.wsConn == nil {
263                 if tc.closed {
264                         return fmt.Errorf("%T closed", tc)
265                 }
266                 tc.cond.Wait()
267         }
268         return tc.wsConn.WriteMessage(websocket.TextMessage, data)
269 }
270
271 func (tc *TrackerClient) trackerReadLoop(tracker *websocket.Conn) error {
272         for {
273                 _, message, err := tracker.ReadMessage()
274                 if err != nil {
275                         return fmt.Errorf("read message error: %w", err)
276                 }
277                 // tc.Logger.WithDefaultLevel(log.Debug).Printf("received message from tracker: %q", message)
278
279                 var ar AnnounceResponse
280                 if err := json.Unmarshal(message, &ar); err != nil {
281                         tc.Logger.WithDefaultLevel(log.Warning).Printf("error unmarshalling announce response: %v", err)
282                         continue
283                 }
284                 switch {
285                 case ar.Offer != nil:
286                         ih, err := jsonStringToInfoHash(ar.InfoHash)
287                         if err != nil {
288                                 tc.Logger.WithDefaultLevel(log.Warning).Printf("error decoding info_hash in offer: %v", err)
289                                 break
290                         }
291                         tc.handleOffer(*ar.Offer, ar.OfferID, ih, ar.PeerID)
292                 case ar.Answer != nil:
293                         tc.handleAnswer(ar.OfferID, *ar.Answer)
294                 }
295         }
296 }
297
298 func (tc *TrackerClient) handleOffer(
299         offer webrtc.SessionDescription,
300         offerId string,
301         infoHash [20]byte,
302         peerId string,
303 ) error {
304         peerConnection, answer, err := newAnsweringPeerConnection(offer)
305         if err != nil {
306                 return fmt.Errorf("write AnnounceResponse: %w", err)
307         }
308         response := AnnounceResponse{
309                 Action:   "announce",
310                 InfoHash: binaryToJsonString(infoHash[:]),
311                 PeerID:   tc.peerIdBinary(),
312                 ToPeerID: peerId,
313                 Answer:   &answer,
314                 OfferID:  offerId,
315         }
316         data, err := json.Marshal(response)
317         if err != nil {
318                 peerConnection.Close()
319                 return fmt.Errorf("marshalling response: %w", err)
320         }
321         tc.mu.Lock()
322         defer tc.mu.Unlock()
323         if err := tc.writeMessage(data); err != nil {
324                 peerConnection.Close()
325                 return fmt.Errorf("writing response: %w", err)
326         }
327         timer := time.AfterFunc(30*time.Second, func() {
328                 metrics.Add("answering peer connections timed out", 1)
329                 peerConnection.Close()
330         })
331         peerConnection.OnDataChannel(func(d *webrtc.DataChannel) {
332                 setDataChannelOnOpen(d, peerConnection, func(dc datachannel.ReadWriteCloser) {
333                         timer.Stop()
334                         metrics.Add("answering peer connection conversions", 1)
335                         tc.mu.Lock()
336                         tc.stats.ConvertedInboundConns++
337                         tc.mu.Unlock()
338                         tc.OnConn(dc, DataChannelContext{
339                                 Local:          answer,
340                                 Remote:         offer,
341                                 OfferId:        offerId,
342                                 LocalOffered:   false,
343                                 InfoHash:       infoHash,
344                                 peerConnection: peerConnection,
345                         })
346                 })
347         })
348         return nil
349 }
350
351 func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescription) {
352         tc.mu.Lock()
353         defer tc.mu.Unlock()
354         offer, ok := tc.outboundOffers[offerId]
355         if !ok {
356                 tc.Logger.WithDefaultLevel(log.Warning).Printf("could not find offer for id %+q", offerId)
357                 return
358         }
359         // tc.Logger.WithDefaultLevel(log.Debug).Printf("offer %q got answer %v", offerId, answer)
360         metrics.Add("outbound offers answered", 1)
361         err := offer.setAnswer(answer, func(dc datachannel.ReadWriteCloser) {
362                 metrics.Add("outbound offers answered with datachannel", 1)
363                 tc.mu.Lock()
364                 tc.stats.ConvertedOutboundConns++
365                 tc.mu.Unlock()
366                 tc.OnConn(dc, DataChannelContext{
367                         Local:          offer.originalOffer,
368                         Remote:         answer,
369                         OfferId:        offerId,
370                         LocalOffered:   true,
371                         InfoHash:       offer.infoHash,
372                         peerConnection: offer.peerConnection,
373                 })
374         })
375         if err != nil {
376                 tc.Logger.WithDefaultLevel(log.Warning).Printf("error using outbound offer answer: %v", err)
377                 return
378         }
379         delete(tc.outboundOffers, offerId)
380         go tc.Announce(tracker.None, offer.infoHash)
381 }