]> Sergey Matveev's repositories - btrtrc.git/blob - webtorrent/tracker-client.go
64885bf4d1fea79fe35803b6a76a9b14bb55f705
[btrtrc.git] / webtorrent / tracker-client.go
1 package webtorrent
2
3 import (
4         "context"
5         "crypto/rand"
6         "encoding/json"
7         "fmt"
8         "sync"
9         "time"
10
11         "github.com/anacrolix/generics"
12         "github.com/anacrolix/log"
13         "github.com/gorilla/websocket"
14         "github.com/pion/datachannel"
15         "github.com/pion/webrtc/v3"
16         "go.opentelemetry.io/otel/trace"
17
18         "github.com/anacrolix/torrent/tracker"
19 )
20
// TrackerClientStats holds the counters kept by a TrackerClient. A copy is
// handed out by TrackerClient.Stats.
type TrackerClientStats struct {
	// Dials is incremented every time the client starts dialing the tracker
	// websocket.
	Dials int64
	// NOTE(review): the two counters below are not updated in this part of the
	// file; presumably they count inbound/outbound peer connections that were
	// converted into usable data channels — confirm against the code that
	// writes them.
	ConvertedInboundConns  int64
	ConvertedOutboundConns int64
}
26
27 // Client represents the webtorrent client
28 type TrackerClient struct {
29         Url                string
30         GetAnnounceRequest func(_ tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
31         PeerId             [20]byte
32         OnConn             onDataChannelOpen
33         Logger             log.Logger
34         Dialer             *websocket.Dialer
35
36         mu             sync.Mutex
37         cond           sync.Cond
38         outboundOffers map[string]outboundOfferValue // OfferID to outboundOfferValue
39         wsConn         *websocket.Conn
40         closed         bool
41         stats          TrackerClientStats
42         pingTicker     *time.Ticker
43 }
44
45 func (me *TrackerClient) Stats() TrackerClientStats {
46         me.mu.Lock()
47         defer me.mu.Unlock()
48         return me.stats
49 }
50
51 func (me *TrackerClient) peerIdBinary() string {
52         return binaryToJsonString(me.PeerId[:])
53 }
54
55 type outboundOffer struct {
56         offerId string
57         outboundOfferValue
58 }
59
60 // outboundOfferValue represents an outstanding offer.
61 type outboundOfferValue struct {
62         originalOffer  webrtc.SessionDescription
63         peerConnection *wrappedPeerConnection
64         infoHash       [20]byte
65         dataChannel    *webrtc.DataChannel
66 }
67
68 type DataChannelContext struct {
69         OfferId      string
70         LocalOffered bool
71         InfoHash     [20]byte
72         // This is private as some methods might not be appropriate with data channel context.
73         peerConnection *wrappedPeerConnection
74         Span           trace.Span
75         Context        context.Context
76 }
77
78 func (me *DataChannelContext) GetSelectedIceCandidatePair() (*webrtc.ICECandidatePair, error) {
79         return me.peerConnection.SCTP().Transport().ICETransport().GetSelectedCandidatePair()
80 }
81
82 type onDataChannelOpen func(_ datachannel.ReadWriteCloser, dcc DataChannelContext)
83
84 func (tc *TrackerClient) doWebsocket() error {
85         metrics.Add("websocket dials", 1)
86         tc.mu.Lock()
87         tc.stats.Dials++
88         tc.mu.Unlock()
89         c, _, err := tc.Dialer.Dial(tc.Url, nil)
90         if err != nil {
91                 return fmt.Errorf("dialing tracker: %w", err)
92         }
93         defer c.Close()
94         tc.Logger.WithDefaultLevel(log.Info).Printf("connected")
95         tc.mu.Lock()
96         tc.wsConn = c
97         tc.cond.Broadcast()
98         tc.mu.Unlock()
99         tc.announceOffers()
100         closeChan := make(chan struct{})
101         go func() {
102                 for {
103                         select {
104                         case <-tc.pingTicker.C:
105                                 tc.mu.Lock()
106                                 err := c.WriteMessage(websocket.PingMessage, []byte{})
107                                 tc.mu.Unlock()
108                                 if err != nil {
109                                         return
110                                 }
111                         case <-closeChan:
112                                 return
113
114                         }
115                 }
116         }()
117         err = tc.trackerReadLoop(tc.wsConn)
118         close(closeChan)
119         tc.mu.Lock()
120         c.Close()
121         tc.mu.Unlock()
122         return err
123 }
124
125 // Finishes initialization and spawns the run routine, calling onStop when it completes with the
126 // result. We don't let the caller just spawn the runner directly, since then we can race against
127 // .Close to finish initialization.
128 func (tc *TrackerClient) Start(onStop func(error)) {
129         tc.pingTicker = time.NewTicker(60 * time.Second)
130         tc.cond.L = &tc.mu
131         go func() {
132                 onStop(tc.run())
133         }()
134 }
135
136 func (tc *TrackerClient) run() error {
137         tc.mu.Lock()
138         for !tc.closed {
139                 tc.mu.Unlock()
140                 err := tc.doWebsocket()
141                 level := log.Info
142                 tc.mu.Lock()
143                 if tc.closed {
144                         level = log.Debug
145                 }
146                 tc.mu.Unlock()
147                 tc.Logger.WithDefaultLevel(level).Printf("websocket instance ended: %v", err)
148                 time.Sleep(time.Minute)
149                 tc.mu.Lock()
150         }
151         tc.mu.Unlock()
152         return nil
153 }
154
155 func (tc *TrackerClient) Close() error {
156         tc.mu.Lock()
157         tc.closed = true
158         if tc.wsConn != nil {
159                 tc.wsConn.Close()
160         }
161         tc.closeUnusedOffers()
162         tc.pingTicker.Stop()
163         tc.mu.Unlock()
164         tc.cond.Broadcast()
165         return nil
166 }
167
168 func (tc *TrackerClient) announceOffers() {
169         // tc.Announce grabs a lock on tc.outboundOffers. It also handles the case where outboundOffers
170         // is nil. Take ownership of outboundOffers here.
171         tc.mu.Lock()
172         offers := tc.outboundOffers
173         tc.outboundOffers = nil
174         tc.mu.Unlock()
175
176         if offers == nil {
177                 return
178         }
179
180         // Iterate over our locally-owned offers, close any existing "invalid" ones from before the
181         // socket reconnected, reannounce the infohash, adding it back into the tc.outboundOffers.
182         tc.Logger.WithDefaultLevel(log.Info).Printf("reannouncing %d infohashes after restart", len(offers))
183         for _, offer := range offers {
184                 // TODO: Capture the errors? Are we even in a position to do anything with them?
185                 offer.peerConnection.Close()
186                 // Use goroutine here to allow read loop to start and ensure the buffer drains.
187                 go tc.Announce(tracker.Started, offer.infoHash)
188         }
189 }
190
191 func (tc *TrackerClient) closeUnusedOffers() {
192         for _, offer := range tc.outboundOffers {
193                 offer.peerConnection.Close()
194                 offer.dataChannel.Close()
195         }
196         tc.outboundOffers = nil
197 }
198
199 func (tc *TrackerClient) CloseOffersForInfohash(infoHash [20]byte) {
200         tc.mu.Lock()
201         defer tc.mu.Unlock()
202         for key, offer := range tc.outboundOffers {
203                 if offer.infoHash == infoHash {
204                         offer.peerConnection.Close()
205                         delete(tc.outboundOffers, key)
206                 }
207         }
208 }
209
210 func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte) error {
211         metrics.Add("outbound announces", 1)
212         if event == tracker.Stopped {
213                 return tc.announce(event, infoHash, nil)
214         }
215         var randOfferId [20]byte
216         _, err := rand.Read(randOfferId[:])
217         if err != nil {
218                 return fmt.Errorf("generating offer_id bytes: %w", err)
219         }
220         offerIDBinary := binaryToJsonString(randOfferId[:])
221
222         pc, dc, offer, err := tc.newOffer(tc.Logger, offerIDBinary, infoHash)
223         if err != nil {
224                 return fmt.Errorf("creating offer: %w", err)
225         }
226
227         err = tc.announce(event, infoHash, []outboundOffer{{
228                 offerId: offerIDBinary,
229                 outboundOfferValue: outboundOfferValue{
230                         originalOffer:  offer,
231                         peerConnection: pc,
232                         infoHash:       infoHash,
233                         dataChannel:    dc,
234                 }},
235         })
236         if err != nil {
237                 dc.Close()
238                 pc.Close()
239         }
240         return err
241 }
242
243 func (tc *TrackerClient) announce(event tracker.AnnounceEvent, infoHash [20]byte, offers []outboundOffer) error {
244         request, err := tc.GetAnnounceRequest(event, infoHash)
245         if err != nil {
246                 return fmt.Errorf("getting announce parameters: %w", err)
247         }
248
249         req := AnnounceRequest{
250                 Numwant:    len(offers),
251                 Uploaded:   request.Uploaded,
252                 Downloaded: request.Downloaded,
253                 Left:       request.Left,
254                 Event:      request.Event.String(),
255                 Action:     "announce",
256                 InfoHash:   binaryToJsonString(infoHash[:]),
257                 PeerID:     tc.peerIdBinary(),
258         }
259         for _, offer := range offers {
260                 req.Offers = append(req.Offers, Offer{
261                         OfferID: offer.offerId,
262                         Offer:   offer.originalOffer,
263                 })
264         }
265
266         data, err := json.Marshal(req)
267         if err != nil {
268                 return fmt.Errorf("marshalling request: %w", err)
269         }
270
271         tc.mu.Lock()
272         defer tc.mu.Unlock()
273         err = tc.writeMessage(data)
274         if err != nil {
275                 return fmt.Errorf("write AnnounceRequest: %w", err)
276         }
277         for _, offer := range offers {
278                 generics.MakeMapIfNilAndSet(&tc.outboundOffers, offer.offerId, offer.outboundOfferValue)
279         }
280         return nil
281 }
282
283 func (tc *TrackerClient) writeMessage(data []byte) error {
284         for tc.wsConn == nil {
285                 if tc.closed {
286                         return fmt.Errorf("%T closed", tc)
287                 }
288                 tc.cond.Wait()
289         }
290         return tc.wsConn.WriteMessage(websocket.TextMessage, data)
291 }
292
293 func (tc *TrackerClient) trackerReadLoop(tracker *websocket.Conn) error {
294         for {
295                 _, message, err := tracker.ReadMessage()
296                 if err != nil {
297                         return fmt.Errorf("read message error: %w", err)
298                 }
299                 // tc.Logger.WithDefaultLevel(log.Debug).Printf("received message from tracker: %q", message)
300
301                 var ar AnnounceResponse
302                 if err := json.Unmarshal(message, &ar); err != nil {
303                         tc.Logger.WithDefaultLevel(log.Warning).Printf("error unmarshalling announce response: %v", err)
304                         continue
305                 }
306                 switch {
307                 case ar.Offer != nil:
308                         ih, err := jsonStringToInfoHash(ar.InfoHash)
309                         if err != nil {
310                                 tc.Logger.WithDefaultLevel(log.Warning).Printf("error decoding info_hash in offer: %v", err)
311                                 break
312                         }
313                         err = tc.handleOffer(offerContext{
314                                 SessDesc: *ar.Offer,
315                                 Id:       ar.OfferID,
316                                 InfoHash: ih,
317                         }, ar.PeerID)
318                         if err != nil {
319                                 tc.Logger.Levelf(log.Error, "handling offer for infohash %x: %v", ih, err)
320                         }
321                 case ar.Answer != nil:
322                         tc.handleAnswer(ar.OfferID, *ar.Answer)
323                 default:
324                         tc.Logger.Levelf(log.Warning, "unhandled announce response %q", message)
325                 }
326         }
327 }
328
329 type offerContext struct {
330         SessDesc webrtc.SessionDescription
331         Id       string
332         InfoHash [20]byte
333 }
334
335 func (tc *TrackerClient) handleOffer(
336         offerContext offerContext,
337         peerId string,
338 ) error {
339         peerConnection, answer, err := tc.newAnsweringPeerConnection(offerContext)
340         if err != nil {
341                 return fmt.Errorf("creating answering peer connection: %w", err)
342         }
343         response := AnnounceResponse{
344                 Action:   "announce",
345                 InfoHash: binaryToJsonString(offerContext.InfoHash[:]),
346                 PeerID:   tc.peerIdBinary(),
347                 ToPeerID: peerId,
348                 Answer:   &answer,
349                 OfferID:  offerContext.Id,
350         }
351         data, err := json.Marshal(response)
352         if err != nil {
353                 peerConnection.Close()
354                 return fmt.Errorf("marshalling response: %w", err)
355         }
356         tc.mu.Lock()
357         defer tc.mu.Unlock()
358         if err := tc.writeMessage(data); err != nil {
359                 peerConnection.Close()
360                 return fmt.Errorf("writing response: %w", err)
361         }
362         return nil
363 }
364
365 func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescription) {
366         tc.mu.Lock()
367         defer tc.mu.Unlock()
368         offer, ok := tc.outboundOffers[offerId]
369         if !ok {
370                 tc.Logger.WithDefaultLevel(log.Warning).Printf("could not find offer for id %+q", offerId)
371                 return
372         }
373         // tc.Logger.WithDefaultLevel(log.Debug).Printf("offer %q got answer %v", offerId, answer)
374         metrics.Add("outbound offers answered", 1)
375         err := offer.peerConnection.SetRemoteDescription(answer)
376         if err != nil {
377                 err = fmt.Errorf("using outbound offer answer: %w", err)
378                 offer.peerConnection.span.RecordError(err)
379                 tc.Logger.LevelPrint(log.Error, err)
380                 return
381         }
382         delete(tc.outboundOffers, offerId)
383         go tc.Announce(tracker.None, offer.infoHash)
384 }