]> Sergey Matveev's repositories - btrtrc.git/blob - webtorrent/tracker-client.go
95d87ff45af7d3855441bc4e53efe47aba55d785
[btrtrc.git] / webtorrent / tracker-client.go
1 package webtorrent
2
3 import (
4         "context"
5         "crypto/rand"
6         "encoding/json"
7         "fmt"
8         "go.opentelemetry.io/otel/codes"
9         "go.opentelemetry.io/otel/trace"
10         "sync"
11         "time"
12
13         "github.com/anacrolix/log"
14
15         "github.com/anacrolix/torrent/tracker"
16         "github.com/gorilla/websocket"
17         "github.com/pion/datachannel"
18         "github.com/pion/webrtc/v3"
19 )
20
// TrackerClientStats is a snapshot of the counters a TrackerClient keeps.
type TrackerClientStats struct {
	// Dials counts websocket dial attempts to the tracker (counted before
	// the dial, whether or not it succeeds).
	Dials int64
	// ConvertedInboundConns counts answered (remotely offered) peer
	// connections whose data channel opened and was passed to OnConn.
	ConvertedInboundConns int64
	// ConvertedOutboundConns counts locally offered peer connections whose
	// data channel opened and was passed to OnConn.
	ConvertedOutboundConns int64
}
26
27 // Client represents the webtorrent client
28 type TrackerClient struct {
29         Url                string
30         GetAnnounceRequest func(_ tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
31         PeerId             [20]byte
32         OnConn             onDataChannelOpen
33         Logger             log.Logger
34         Dialer             *websocket.Dialer
35
36         mu             sync.Mutex
37         cond           sync.Cond
38         outboundOffers map[string]outboundOffer // OfferID to outboundOffer
39         wsConn         *websocket.Conn
40         closed         bool
41         stats          TrackerClientStats
42         pingTicker     *time.Ticker
43 }
44
45 func (me *TrackerClient) Stats() TrackerClientStats {
46         me.mu.Lock()
47         defer me.mu.Unlock()
48         return me.stats
49 }
50
51 func (me *TrackerClient) peerIdBinary() string {
52         return binaryToJsonString(me.PeerId[:])
53 }
54
55 // outboundOffer represents an outstanding offer.
56 type outboundOffer struct {
57         originalOffer  webrtc.SessionDescription
58         peerConnection *wrappedPeerConnection
59         infoHash       [20]byte
60 }
61
62 type DataChannelContext struct {
63         // Can these be obtained by just calling the relevant methods on peerConnection?
64         Local, Remote webrtc.SessionDescription
65         OfferId       string
66         LocalOffered  bool
67         InfoHash      [20]byte
68         // This is private as some methods might not be appropriate with data channel context.
69         peerConnection *wrappedPeerConnection
70         span           trace.Span
71         ctx            context.Context
72 }
73
74 func (me *DataChannelContext) GetSelectedIceCandidatePair() (*webrtc.ICECandidatePair, error) {
75         return me.peerConnection.SCTP().Transport().ICETransport().GetSelectedCandidatePair()
76 }
77
78 type onDataChannelOpen func(_ datachannel.ReadWriteCloser, dcc DataChannelContext)
79
80 func (tc *TrackerClient) doWebsocket() error {
81         metrics.Add("websocket dials", 1)
82         tc.mu.Lock()
83         tc.stats.Dials++
84         tc.mu.Unlock()
85         c, _, err := tc.Dialer.Dial(tc.Url, nil)
86         if err != nil {
87                 return fmt.Errorf("dialing tracker: %w", err)
88         }
89         defer c.Close()
90         tc.Logger.WithDefaultLevel(log.Info).Printf("connected")
91         tc.mu.Lock()
92         tc.wsConn = c
93         tc.cond.Broadcast()
94         tc.mu.Unlock()
95         tc.announceOffers()
96         closeChan := make(chan struct{})
97         go func() {
98                 for {
99                         select {
100                         case <-tc.pingTicker.C:
101                                 tc.mu.Lock()
102                                 err := c.WriteMessage(websocket.PingMessage, []byte{})
103                                 tc.mu.Unlock()
104                                 if err != nil {
105                                         return
106                                 }
107                         case <-closeChan:
108                                 return
109
110                         }
111                 }
112         }()
113         err = tc.trackerReadLoop(tc.wsConn)
114         close(closeChan)
115         tc.mu.Lock()
116         c.Close()
117         tc.mu.Unlock()
118         return err
119 }
120
121 // Finishes initialization and spawns the run routine, calling onStop when it completes with the
122 // result. We don't let the caller just spawn the runner directly, since then we can race against
123 // .Close to finish initialization.
124 func (tc *TrackerClient) Start(onStop func(error)) {
125         tc.pingTicker = time.NewTicker(60 * time.Second)
126         tc.cond.L = &tc.mu
127         go func() {
128                 onStop(tc.run())
129         }()
130 }
131
132 func (tc *TrackerClient) run() error {
133         tc.mu.Lock()
134         for !tc.closed {
135                 tc.mu.Unlock()
136                 err := tc.doWebsocket()
137                 level := log.Info
138                 tc.mu.Lock()
139                 if tc.closed {
140                         level = log.Debug
141                 }
142                 tc.mu.Unlock()
143                 tc.Logger.WithDefaultLevel(level).Printf("websocket instance ended: %v", err)
144                 time.Sleep(time.Minute)
145                 tc.mu.Lock()
146         }
147         tc.mu.Unlock()
148         return nil
149 }
150
151 func (tc *TrackerClient) Close() error {
152         tc.mu.Lock()
153         tc.closed = true
154         if tc.wsConn != nil {
155                 tc.wsConn.Close()
156         }
157         tc.closeUnusedOffers()
158         tc.pingTicker.Stop()
159         tc.mu.Unlock()
160         tc.cond.Broadcast()
161         return nil
162 }
163
164 func (tc *TrackerClient) announceOffers() {
165         // tc.Announce grabs a lock on tc.outboundOffers. It also handles the case where outboundOffers
166         // is nil. Take ownership of outboundOffers here.
167         tc.mu.Lock()
168         offers := tc.outboundOffers
169         tc.outboundOffers = nil
170         tc.mu.Unlock()
171
172         if offers == nil {
173                 return
174         }
175
176         // Iterate over our locally-owned offers, close any existing "invalid" ones from before the
177         // socket reconnected, reannounce the infohash, adding it back into the tc.outboundOffers.
178         tc.Logger.WithDefaultLevel(log.Info).Printf("reannouncing %d infohashes after restart", len(offers))
179         for _, offer := range offers {
180                 // TODO: Capture the errors? Are we even in a position to do anything with them?
181                 offer.peerConnection.Close()
182                 // Use goroutine here to allow read loop to start and ensure the buffer drains.
183                 go tc.Announce(tracker.Started, offer.infoHash)
184         }
185 }
186
187 func (tc *TrackerClient) closeUnusedOffers() {
188         for _, offer := range tc.outboundOffers {
189                 offer.peerConnection.Close()
190         }
191         tc.outboundOffers = nil
192 }
193
194 func (tc *TrackerClient) CloseOffersForInfohash(infoHash [20]byte) {
195         tc.mu.Lock()
196         defer tc.mu.Unlock()
197         for key, offer := range tc.outboundOffers {
198                 if offer.infoHash == infoHash {
199                         offer.peerConnection.Close()
200                         delete(tc.outboundOffers, key)
201                 }
202         }
203 }
204
205 func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte) error {
206         metrics.Add("outbound announces", 1)
207         var randOfferId [20]byte
208         _, err := rand.Read(randOfferId[:])
209         if err != nil {
210                 return fmt.Errorf("generating offer_id bytes: %w", err)
211         }
212         offerIDBinary := binaryToJsonString(randOfferId[:])
213
214         pc, offer, err := newOffer(tc.Logger)
215         if err != nil {
216                 return fmt.Errorf("creating offer: %w", err)
217         }
218
219         request, err := tc.GetAnnounceRequest(event, infoHash)
220         if err != nil {
221                 pc.Close()
222                 return fmt.Errorf("getting announce parameters: %w", err)
223         }
224
225         req := AnnounceRequest{
226                 Numwant:    1, // If higher we need to create equal amount of offers.
227                 Uploaded:   request.Uploaded,
228                 Downloaded: request.Downloaded,
229                 Left:       request.Left,
230                 Event:      request.Event.String(),
231                 Action:     "announce",
232                 InfoHash:   binaryToJsonString(infoHash[:]),
233                 PeerID:     tc.peerIdBinary(),
234                 Offers: []Offer{{
235                         OfferID: offerIDBinary,
236                         Offer:   offer,
237                 }},
238         }
239
240         data, err := json.Marshal(req)
241         if err != nil {
242                 pc.Close()
243                 return fmt.Errorf("marshalling request: %w", err)
244         }
245
246         tc.mu.Lock()
247         defer tc.mu.Unlock()
248         err = tc.writeMessage(data)
249         if err != nil {
250                 pc.Close()
251                 return fmt.Errorf("write AnnounceRequest: %w", err)
252         }
253         if tc.outboundOffers == nil {
254                 tc.outboundOffers = make(map[string]outboundOffer)
255         }
256         tc.outboundOffers[offerIDBinary] = outboundOffer{
257                 peerConnection: pc,
258                 originalOffer:  offer,
259                 infoHash:       infoHash,
260         }
261         return nil
262 }
263
264 func (tc *TrackerClient) writeMessage(data []byte) error {
265         for tc.wsConn == nil {
266                 if tc.closed {
267                         return fmt.Errorf("%T closed", tc)
268                 }
269                 tc.cond.Wait()
270         }
271         return tc.wsConn.WriteMessage(websocket.TextMessage, data)
272 }
273
274 func (tc *TrackerClient) trackerReadLoop(tracker *websocket.Conn) error {
275         for {
276                 _, message, err := tracker.ReadMessage()
277                 if err != nil {
278                         return fmt.Errorf("read message error: %w", err)
279                 }
280                 // tc.Logger.WithDefaultLevel(log.Debug).Printf("received message from tracker: %q", message)
281
282                 var ar AnnounceResponse
283                 if err := json.Unmarshal(message, &ar); err != nil {
284                         tc.Logger.WithDefaultLevel(log.Warning).Printf("error unmarshalling announce response: %v", err)
285                         continue
286                 }
287                 switch {
288                 case ar.Offer != nil:
289                         ih, err := jsonStringToInfoHash(ar.InfoHash)
290                         if err != nil {
291                                 tc.Logger.WithDefaultLevel(log.Warning).Printf("error decoding info_hash in offer: %v", err)
292                                 break
293                         }
294                         tc.handleOffer(*ar.Offer, ar.OfferID, ih, ar.PeerID)
295                 case ar.Answer != nil:
296                         tc.handleAnswer(ar.OfferID, *ar.Answer)
297                 }
298         }
299 }
300
301 func (tc *TrackerClient) handleOffer(
302         offer webrtc.SessionDescription,
303         offerId string,
304         infoHash [20]byte,
305         peerId string,
306 ) error {
307         peerConnection, answer, err := newAnsweringPeerConnection(tc.Logger, offer)
308         if err != nil {
309                 return fmt.Errorf("write AnnounceResponse: %w", err)
310         }
311         response := AnnounceResponse{
312                 Action:   "announce",
313                 InfoHash: binaryToJsonString(infoHash[:]),
314                 PeerID:   tc.peerIdBinary(),
315                 ToPeerID: peerId,
316                 Answer:   &answer,
317                 OfferID:  offerId,
318         }
319         data, err := json.Marshal(response)
320         if err != nil {
321                 peerConnection.Close()
322                 return fmt.Errorf("marshalling response: %w", err)
323         }
324         tc.mu.Lock()
325         defer tc.mu.Unlock()
326         if err := tc.writeMessage(data); err != nil {
327                 peerConnection.Close()
328                 return fmt.Errorf("writing response: %w", err)
329         }
330         timer := time.AfterFunc(30*time.Second, func() {
331                 peerConnection.span.SetStatus(codes.Error, "answer timeout")
332                 metrics.Add("answering peer connections timed out", 1)
333                 peerConnection.Close()
334         })
335         peerConnection.OnDataChannel(func(d *webrtc.DataChannel) {
336                 ctx, span := dataChannelStarted(peerConnection.ctx, d)
337                 setDataChannelOnOpen(ctx, d, peerConnection, func(dc datachannel.ReadWriteCloser) {
338                         timer.Stop()
339                         metrics.Add("answering peer connection conversions", 1)
340                         tc.mu.Lock()
341                         tc.stats.ConvertedInboundConns++
342                         tc.mu.Unlock()
343                         tc.OnConn(dc, DataChannelContext{
344                                 Local:          answer,
345                                 Remote:         offer,
346                                 OfferId:        offerId,
347                                 LocalOffered:   false,
348                                 InfoHash:       infoHash,
349                                 peerConnection: peerConnection,
350                                 ctx:            ctx,
351                                 span:           span,
352                         })
353                 })
354         })
355         return nil
356 }
357
358 func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescription) {
359         tc.mu.Lock()
360         defer tc.mu.Unlock()
361         offer, ok := tc.outboundOffers[offerId]
362         if !ok {
363                 tc.Logger.WithDefaultLevel(log.Warning).Printf("could not find offer for id %+q", offerId)
364                 return
365         }
366         // tc.Logger.WithDefaultLevel(log.Debug).Printf("offer %q got answer %v", offerId, answer)
367         metrics.Add("outbound offers answered", 1)
368         // Why do we create the data channel before setting the remote description? Are we trying to avoid the peer
369         // initiating?
370         dataChannel, err := offer.peerConnection.CreateDataChannel("webrtc-datachannel", nil)
371         if err != nil {
372                 err = fmt.Errorf("creating data channel: %w", err)
373                 tc.Logger.LevelPrint(log.Error, err)
374                 offer.peerConnection.span.RecordError(err)
375                 offer.peerConnection.Close()
376                 goto deleteOffer
377         }
378         {
379                 ctx, span := dataChannelStarted(offer.peerConnection.ctx, dataChannel)
380                 setDataChannelOnOpen(ctx, dataChannel, offer.peerConnection, func(dc datachannel.ReadWriteCloser) {
381                         metrics.Add("outbound offers answered with datachannel", 1)
382                         tc.mu.Lock()
383                         tc.stats.ConvertedOutboundConns++
384                         tc.mu.Unlock()
385                         tc.OnConn(dc, DataChannelContext{
386                                 Local:          offer.originalOffer,
387                                 Remote:         answer,
388                                 OfferId:        offerId,
389                                 LocalOffered:   true,
390                                 InfoHash:       offer.infoHash,
391                                 peerConnection: offer.peerConnection,
392                                 ctx:            ctx,
393                                 span:           span,
394                         })
395                 })
396         }
397         err = offer.peerConnection.SetRemoteDescription(answer)
398         if err != nil {
399                 err = fmt.Errorf("using outbound offer answer: %w", err)
400                 offer.peerConnection.span.RecordError(err)
401                 dataChannel.Close()
402                 tc.Logger.WithDefaultLevel(log.Error).Print(err)
403                 return
404         }
405 deleteOffer:
406         delete(tc.outboundOffers, offerId)
407         go tc.Announce(tracker.None, offer.infoHash)
408 }