]> Sergey Matveev's repositories - btrtrc.git/blob - webtorrent/tracker-client.go
Also close created data channels when cleaning up webrtc conns
[btrtrc.git] / webtorrent / tracker-client.go
1 package webtorrent
2
3 import (
4         "context"
5         "crypto/rand"
6         "encoding/json"
7         "fmt"
8         "github.com/anacrolix/generics"
9         "go.opentelemetry.io/otel/trace"
10         "sync"
11         "time"
12
13         "github.com/anacrolix/log"
14
15         "github.com/anacrolix/torrent/tracker"
16         "github.com/gorilla/websocket"
17         "github.com/pion/datachannel"
18         "github.com/pion/webrtc/v3"
19 )
20
// TrackerClientStats holds counters describing a TrackerClient's activity. A copy is returned by
// TrackerClient.Stats.
type TrackerClientStats struct {
	// Dials counts websocket dial attempts to the tracker, whether or not they succeed.
	Dials                  int64
	// NOTE(review): the two counters below are not updated anywhere in this file; presumably they
	// count data channels converted into peer connections in each direction — TODO confirm.
	ConvertedInboundConns  int64
	ConvertedOutboundConns int64
}
26
// TrackerClient maintains a websocket connection to a single WebTorrent tracker at Url. It
// announces infohashes together with WebRTC offers and answers offers relayed from other peers.
// Set the exported fields, then call Start; call Close to shut it down.
type TrackerClient struct {
	// Websocket URL of the tracker; dialed once per session.
	Url                string
	// Builds the announce parameters (uploaded, downloaded, left, event) for an event and infohash.
	GetAnnounceRequest func(_ tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
	// Our peer ID, sent (JSON-string encoded) in every announce and answer.
	PeerId             [20]byte
	// Callback for opened data channels. NOTE(review): it is not invoked in this file; presumably
	// the peer-connection code calls it once a data channel opens — TODO confirm.
	OnConn             onDataChannelOpen
	Logger             log.Logger
	// Used to dial Url for each websocket session.
	Dialer             *websocket.Dialer

	// mu guards the mutable state below and serializes all writes to wsConn.
	mu             sync.Mutex
	// cond uses mu as its Locker (set in Start). It is broadcast when wsConn is set or the client
	// is closed, waking writers blocked in writeMessage.
	cond           sync.Cond
	outboundOffers map[string]outboundOfferValue // OfferID to outboundOfferValue
	// The current (or most recently used) tracker connection; nil until the first successful dial.
	wsConn         *websocket.Conn
	// Set by Close. It stops reconnection and makes waiting writers fail.
	closed         bool
	stats          TrackerClientStats
	// Created in Start; each tick sends a websocket ping on the active connection.
	pingTicker     *time.Ticker
}
44
45 func (me *TrackerClient) Stats() TrackerClientStats {
46         me.mu.Lock()
47         defer me.mu.Unlock()
48         return me.stats
49 }
50
51 func (me *TrackerClient) peerIdBinary() string {
52         return binaryToJsonString(me.PeerId[:])
53 }
54
// outboundOffer pairs an offer we are announcing with its offer ID. Announce generates the ID as
// the binaryToJsonString encoding of random bytes. The ID is also the key under which the offer is
// stored in TrackerClient.outboundOffers.
type outboundOffer struct {
	offerId string
	outboundOfferValue
}
59
// outboundOfferValue represents an outstanding offer: one that has been written to the tracker and
// has not yet been successfully answered.
type outboundOfferValue struct {
	// The local session description that was sent to the tracker in the announce.
	originalOffer  webrtc.SessionDescription
	// The peer connection that created the offer; a received answer is applied to it.
	peerConnection *wrappedPeerConnection
	// The infohash the offer was announced for.
	infoHash       [20]byte
	// The data channel created together with the offer; closed when the offer is discarded.
	dataChannel    *webrtc.DataChannel
}
67
// DataChannelContext carries information about a WebRTC data channel. It is passed, together with
// the channel itself, to an onDataChannelOpen callback (the type of TrackerClient.OnConn).
type DataChannelContext struct {
	// ID of the offer this channel was negotiated from.
	OfferId      string
	// NOTE(review): presumably true when we made the offer and false when we answered a remote
	// offer — TODO confirm against the code that constructs this value.
	LocalOffered bool
	// The infohash the offer was made for.
	InfoHash     [20]byte
	// This is private as some methods might not be appropriate with data channel context.
	peerConnection *wrappedPeerConnection
	// NOTE(review): presumably the tracing span and context covering this connection — TODO confirm.
	Span           trace.Span
	Context        context.Context
}
77
78 func (me *DataChannelContext) GetSelectedIceCandidatePair() (*webrtc.ICECandidatePair, error) {
79         return me.peerConnection.SCTP().Transport().ICETransport().GetSelectedCandidatePair()
80 }
81
// onDataChannelOpen is the type of TrackerClient.OnConn. It receives a data channel and its
// DataChannelContext. NOTE(review): it is not invoked in this file; going by its name, it is called
// once the channel is open — TODO confirm.
type onDataChannelOpen func(_ datachannel.ReadWriteCloser, dcc DataChannelContext)
83
84 func (tc *TrackerClient) doWebsocket() error {
85         metrics.Add("websocket dials", 1)
86         tc.mu.Lock()
87         tc.stats.Dials++
88         tc.mu.Unlock()
89         c, _, err := tc.Dialer.Dial(tc.Url, nil)
90         if err != nil {
91                 return fmt.Errorf("dialing tracker: %w", err)
92         }
93         defer c.Close()
94         tc.Logger.WithDefaultLevel(log.Info).Printf("connected")
95         tc.mu.Lock()
96         tc.wsConn = c
97         tc.cond.Broadcast()
98         tc.mu.Unlock()
99         tc.announceOffers()
100         closeChan := make(chan struct{})
101         go func() {
102                 for {
103                         select {
104                         case <-tc.pingTicker.C:
105                                 tc.mu.Lock()
106                                 err := c.WriteMessage(websocket.PingMessage, []byte{})
107                                 tc.mu.Unlock()
108                                 if err != nil {
109                                         return
110                                 }
111                         case <-closeChan:
112                                 return
113
114                         }
115                 }
116         }()
117         err = tc.trackerReadLoop(tc.wsConn)
118         close(closeChan)
119         tc.mu.Lock()
120         c.Close()
121         tc.mu.Unlock()
122         return err
123 }
124
125 // Finishes initialization and spawns the run routine, calling onStop when it completes with the
126 // result. We don't let the caller just spawn the runner directly, since then we can race against
127 // .Close to finish initialization.
128 func (tc *TrackerClient) Start(onStop func(error)) {
129         tc.pingTicker = time.NewTicker(60 * time.Second)
130         tc.cond.L = &tc.mu
131         go func() {
132                 onStop(tc.run())
133         }()
134 }
135
136 func (tc *TrackerClient) run() error {
137         tc.mu.Lock()
138         for !tc.closed {
139                 tc.mu.Unlock()
140                 err := tc.doWebsocket()
141                 level := log.Info
142                 tc.mu.Lock()
143                 if tc.closed {
144                         level = log.Debug
145                 }
146                 tc.mu.Unlock()
147                 tc.Logger.WithDefaultLevel(level).Printf("websocket instance ended: %v", err)
148                 time.Sleep(time.Minute)
149                 tc.mu.Lock()
150         }
151         tc.mu.Unlock()
152         return nil
153 }
154
155 func (tc *TrackerClient) Close() error {
156         tc.mu.Lock()
157         tc.closed = true
158         if tc.wsConn != nil {
159                 tc.wsConn.Close()
160         }
161         tc.closeUnusedOffers()
162         tc.pingTicker.Stop()
163         tc.mu.Unlock()
164         tc.cond.Broadcast()
165         return nil
166 }
167
168 func (tc *TrackerClient) announceOffers() {
169         // tc.Announce grabs a lock on tc.outboundOffers. It also handles the case where outboundOffers
170         // is nil. Take ownership of outboundOffers here.
171         tc.mu.Lock()
172         offers := tc.outboundOffers
173         tc.outboundOffers = nil
174         tc.mu.Unlock()
175
176         if offers == nil {
177                 return
178         }
179
180         // Iterate over our locally-owned offers, close any existing "invalid" ones from before the
181         // socket reconnected, reannounce the infohash, adding it back into the tc.outboundOffers.
182         tc.Logger.WithDefaultLevel(log.Info).Printf("reannouncing %d infohashes after restart", len(offers))
183         for _, offer := range offers {
184                 // TODO: Capture the errors? Are we even in a position to do anything with them?
185                 offer.peerConnection.Close()
186                 // Use goroutine here to allow read loop to start and ensure the buffer drains.
187                 go tc.Announce(tracker.Started, offer.infoHash)
188         }
189 }
190
191 func (tc *TrackerClient) closeUnusedOffers() {
192         for _, offer := range tc.outboundOffers {
193                 offer.peerConnection.Close()
194                 offer.dataChannel.Close()
195         }
196         tc.outboundOffers = nil
197 }
198
199 func (tc *TrackerClient) CloseOffersForInfohash(infoHash [20]byte) {
200         tc.mu.Lock()
201         defer tc.mu.Unlock()
202         for key, offer := range tc.outboundOffers {
203                 if offer.infoHash == infoHash {
204                         offer.peerConnection.Close()
205                         delete(tc.outboundOffers, key)
206                 }
207         }
208 }
209
// Announce sends an announce for infoHash to the tracker.
//
// For tracker.Stopped, no WebRTC offer is attached. For any other event, a new offer with a random
// offer ID is created and included; it is recorded as outstanding once the message has been
// written. If creating or sending the announce fails, the offer's peer connection and data channel
// are closed and the error is returned.
func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte) error {
	metrics.Add("outbound announces", 1)
	if event == tracker.Stopped {
		return tc.announce(event, infoHash, nil)
	}
	var randOfferId [20]byte
	if _, err := rand.Read(randOfferId[:]); err != nil {
		return fmt.Errorf("generating offer_id bytes: %w", err)
	}
	offerIDBinary := binaryToJsonString(randOfferId[:])

	pc, dc, offer, err := tc.newOffer(tc.Logger, offerIDBinary, infoHash)
	if err != nil {
		return fmt.Errorf("creating offer: %w", err)
	}

	err = tc.announce(event, infoHash, []outboundOffer{{
		offerId: offerIDBinary,
		outboundOfferValue: outboundOfferValue{
			originalOffer:  offer,
			peerConnection: pc,
			infoHash:       infoHash,
			dataChannel:    dc,
		},
	}})
	if err != nil {
		// The offer was never recorded in tc.outboundOffers, so nothing else will clean it up.
		// Close the data channel as well as the peer connection, matching closeUnusedOffers.
		pc.Close()
		dc.Close()
	}
	return err
}
241
// announce builds an "announce" AnnounceRequest for infoHash, using parameters from
// GetAnnounceRequest and attaching the given offers (Numwant is the number of offers). It writes
// the request to the tracker, waiting for a connection if necessary. Only after a successful write
// are the offers recorded in tc.outboundOffers; this happens under the same lock, so handleAnswer
// cannot observe the gap.
func (tc *TrackerClient) announce(event tracker.AnnounceEvent, infoHash [20]byte, offers []outboundOffer) error {
	params, err := tc.GetAnnounceRequest(event, infoHash)
	if err != nil {
		return fmt.Errorf("getting announce parameters: %w", err)
	}

	// Left nil when there are no offers, exactly as before.
	var wireOffers []Offer
	for i := range offers {
		wireOffers = append(wireOffers, Offer{
			OfferID: offers[i].offerId,
			Offer:   offers[i].originalOffer,
		})
	}

	data, err := json.Marshal(AnnounceRequest{
		Numwant:    len(offers),
		Uploaded:   params.Uploaded,
		Downloaded: params.Downloaded,
		Left:       params.Left,
		Event:      params.Event.String(),
		Action:     "announce",
		InfoHash:   binaryToJsonString(infoHash[:]),
		PeerID:     tc.peerIdBinary(),
		Offers:     wireOffers,
	})
	if err != nil {
		return fmt.Errorf("marshalling request: %w", err)
	}

	tc.mu.Lock()
	defer tc.mu.Unlock()
	if err := tc.writeMessage(data); err != nil {
		return fmt.Errorf("write AnnounceRequest: %w", err)
	}
	for _, o := range offers {
		generics.MakeMapIfNilAndSet(&tc.outboundOffers, o.offerId, o.outboundOfferValue)
	}
	return nil
}
281
282 func (tc *TrackerClient) writeMessage(data []byte) error {
283         for tc.wsConn == nil {
284                 if tc.closed {
285                         return fmt.Errorf("%T closed", tc)
286                 }
287                 tc.cond.Wait()
288         }
289         return tc.wsConn.WriteMessage(websocket.TextMessage, data)
290 }
291
// trackerReadLoop reads messages from conn until a read fails and returns the wrapped read error.
// Each message is decoded as an AnnounceResponse:
//   - offers from remote peers are answered via handleOffer;
//   - answers to our own offers are applied via handleAnswer;
//   - anything else, or anything that fails to decode, is logged and skipped.
func (tc *TrackerClient) trackerReadLoop(conn *websocket.Conn) error {
	for {
		_, message, err := conn.ReadMessage()
		if err != nil {
			return fmt.Errorf("read message error: %w", err)
		}
		// tc.Logger.WithDefaultLevel(log.Debug).Printf("received message from tracker: %q", message)

		var ar AnnounceResponse
		if err := json.Unmarshal(message, &ar); err != nil {
			tc.Logger.WithDefaultLevel(log.Warning).Printf("error unmarshalling announce response: %v", err)
			continue
		}
		switch {
		case ar.Offer != nil:
			ih, err := jsonStringToInfoHash(ar.InfoHash)
			if err != nil {
				tc.Logger.WithDefaultLevel(log.Warning).Printf("error decoding info_hash in offer: %v", err)
				break
			}
			err = tc.handleOffer(offerContext{
				SessDesc: *ar.Offer,
				Id:       ar.OfferID,
				InfoHash: ih,
			}, ar.PeerID)
			if err != nil {
				tc.Logger.Levelf(log.Error, "handling offer for infohash %x: %v", ih, err)
			}
		case ar.Answer != nil:
			tc.handleAnswer(ar.OfferID, *ar.Answer)
		default:
			tc.Logger.Levelf(log.Warning, "unhandled announce response %q", message)
		}
	}
}
327
// offerContext describes an offer received from a remote peer through the tracker. It is built by
// trackerReadLoop and consumed by handleOffer.
type offerContext struct {
	// The remote peer's session description (the offer itself).
	SessDesc webrtc.SessionDescription
	// The offer ID relayed by the tracker; echoed back as OfferID in our answer.
	Id       string
	// The decoded infohash the offer refers to.
	InfoHash [20]byte
}
333
334 func (tc *TrackerClient) handleOffer(
335         offerContext offerContext,
336         peerId string,
337 ) error {
338         peerConnection, answer, err := tc.newAnsweringPeerConnection(offerContext)
339         if err != nil {
340                 return fmt.Errorf("creating answering peer connection: %w", err)
341         }
342         response := AnnounceResponse{
343                 Action:   "announce",
344                 InfoHash: binaryToJsonString(offerContext.InfoHash[:]),
345                 PeerID:   tc.peerIdBinary(),
346                 ToPeerID: peerId,
347                 Answer:   &answer,
348                 OfferID:  offerContext.Id,
349         }
350         data, err := json.Marshal(response)
351         if err != nil {
352                 peerConnection.Close()
353                 return fmt.Errorf("marshalling response: %w", err)
354         }
355         tc.mu.Lock()
356         defer tc.mu.Unlock()
357         if err := tc.writeMessage(data); err != nil {
358                 peerConnection.Close()
359                 return fmt.Errorf("writing response: %w", err)
360         }
361         return nil
362 }
363
364 func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescription) {
365         tc.mu.Lock()
366         defer tc.mu.Unlock()
367         offer, ok := tc.outboundOffers[offerId]
368         if !ok {
369                 tc.Logger.WithDefaultLevel(log.Warning).Printf("could not find offer for id %+q", offerId)
370                 return
371         }
372         // tc.Logger.WithDefaultLevel(log.Debug).Printf("offer %q got answer %v", offerId, answer)
373         metrics.Add("outbound offers answered", 1)
374         err := offer.peerConnection.SetRemoteDescription(answer)
375         if err != nil {
376                 err = fmt.Errorf("using outbound offer answer: %w", err)
377                 offer.peerConnection.span.RecordError(err)
378                 tc.Logger.LevelPrint(log.Error, err)
379                 return
380         }
381         delete(tc.outboundOffers, offerId)
382         go tc.Announce(tracker.None, offer.infoHash)
383 }