Sergey Matveev's repositories - btrtrc.git/blob - webtorrent/tracker-client.go
Add WebRTC ICE servers config (#824)
[btrtrc.git] / webtorrent / tracker-client.go
1 package webtorrent
2
3 import (
4         "context"
5         "crypto/rand"
6         "encoding/json"
7         "fmt"
8         "net/http"
9         "sync"
10         "time"
11
12         g "github.com/anacrolix/generics"
13         "github.com/anacrolix/log"
14         "github.com/gorilla/websocket"
15         "github.com/pion/datachannel"
16         "github.com/pion/webrtc/v3"
17         "go.opentelemetry.io/otel/trace"
18
19         "github.com/anacrolix/torrent/tracker"
20 )
21
// TrackerClientStats is a snapshot of the counters kept by a TrackerClient; TrackerClient.Stats
// returns a copy of it.
type TrackerClientStats struct {
	// Dials counts websocket dial attempts. doWebsocket increments it before dialling, so failed
	// dials are included.
	Dials                  int64
	// NOTE(review): the two counters below are not updated in this part of the file; presumably
	// they count data channels converted into peer connections — confirm at their call sites.
	ConvertedInboundConns  int64
	ConvertedOutboundConns int64
}
27
// TrackerClient is a client for a single WebTorrent websocket tracker. It announces WebRTC
// offers for info hashes, answers offers that the tracker relays from other peers, and keeps
// redialling the tracker until Close is called. Start begins the connection loop; Close stops
// it.
type TrackerClient struct {
	// Url is the websocket URL of the tracker to dial.
	Url                string
	// GetAnnounceRequest supplies the announce parameters (uploaded, downloaded, left, event)
	// used when announcing infoHash.
	GetAnnounceRequest func(_ tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
	// PeerId is this client's peer ID, sent with every announce and every answer.
	PeerId             [20]byte
	// NOTE(review): OnConn is not invoked in this part of the file; presumably it is called when
	// a WebRTC data channel opens (see onDataChannelOpen) — confirm at the call site.
	OnConn             onDataChannelOpen
	Logger             log.Logger
	// Dialer opens the tracker websocket.
	Dialer             *websocket.Dialer

	// mu protects the mutable state below (outboundOffers, wsConn, closed, stats) and is held
	// for every websocket write. cond uses mu as its Locker (assigned in Start) and is broadcast
	// when wsConn is set and when the client is closed.
	mu             sync.Mutex
	cond           sync.Cond
	outboundOffers map[string]outboundOfferValue // OfferID to outboundOfferValue
	// wsConn is the tracker connection; nil until the first successful dial.
	wsConn         *websocket.Conn
	closed         bool
	stats          TrackerClientStats
	// pingTicker paces websocket pings; created in Start and stopped in Close.
	pingTicker     *time.Ticker

	// WebsocketTrackerHttpHeader, if non-nil, is called before each dial to obtain the HTTP
	// headers sent with the websocket handshake.
	WebsocketTrackerHttpHeader func() http.Header
	// ICEServers is the WebRTC ICE server configuration.
	// NOTE(review): not referenced in this part of the file; presumably consumed when creating
	// peer connections — confirm.
	ICEServers                 []string
}
48
49 func (me *TrackerClient) Stats() TrackerClientStats {
50         me.mu.Lock()
51         defer me.mu.Unlock()
52         return me.stats
53 }
54
55 func (me *TrackerClient) peerIdBinary() string {
56         return binaryToJsonString(me.PeerId[:])
57 }
58
// outboundOffer pairs an offer ID with the state of an offer this client is announcing. The
// embedded value is what gets stored in TrackerClient.outboundOffers under offerId once the
// announce has been written.
type outboundOffer struct {
	offerId string
	outboundOfferValue
}
63
// outboundOfferValue represents an outstanding offer: one this client has announced to the
// tracker and for which it is awaiting an answer (kept in TrackerClient.outboundOffers).
type outboundOfferValue struct {
	// originalOffer is the local session description that was sent in the announce.
	originalOffer  webrtc.SessionDescription
	peerConnection *wrappedPeerConnection
	// infoHash is the torrent the offer was made for; it is announced again once the offer is
	// answered or the tracker connection is re-established.
	infoHash       [20]byte
	dataChannel    *webrtc.DataChannel
}
71
// DataChannelContext describes a WebRTC data channel; it is the context argument of the
// onDataChannelOpen callback type (TrackerClient.OnConn).
type DataChannelContext struct {
	// OfferId is the tracker offer ID.
	// NOTE(review): presumably the offer through which this channel was negotiated — confirm.
	OfferId      string
	// NOTE(review): presumably true when this client sent the offer and false when it answered
	// a relayed offer — confirm where the context is built.
	LocalOffered bool
	InfoHash     [20]byte
	// This is private as some methods might not be appropriate with data channel context.
	peerConnection *wrappedPeerConnection
	Span           trace.Span
	Context        context.Context
}
81
82 func (me *DataChannelContext) GetSelectedIceCandidatePair() (*webrtc.ICECandidatePair, error) {
83         return me.peerConnection.SCTP().Transport().ICETransport().GetSelectedCandidatePair()
84 }
85
// onDataChannelOpen is the type of TrackerClient.OnConn. It receives a data channel's
// ReadWriteCloser together with a DataChannelContext describing that channel.
// NOTE(review): presumably invoked once the channel has opened (per the name) — confirm at the
// call site.
type onDataChannelOpen func(_ datachannel.ReadWriteCloser, dcc DataChannelContext)
87
88 func (tc *TrackerClient) doWebsocket() error {
89         metrics.Add("websocket dials", 1)
90         tc.mu.Lock()
91         tc.stats.Dials++
92         tc.mu.Unlock()
93
94         var header http.Header
95         if tc.WebsocketTrackerHttpHeader != nil {
96                 header = tc.WebsocketTrackerHttpHeader()
97         }
98
99         c, _, err := tc.Dialer.Dial(tc.Url, header)
100         if err != nil {
101                 return fmt.Errorf("dialing tracker: %w", err)
102         }
103         defer c.Close()
104         tc.Logger.WithDefaultLevel(log.Info).Printf("connected")
105         tc.mu.Lock()
106         tc.wsConn = c
107         tc.cond.Broadcast()
108         tc.mu.Unlock()
109         tc.announceOffers()
110         closeChan := make(chan struct{})
111         go func() {
112                 for {
113                         select {
114                         case <-tc.pingTicker.C:
115                                 tc.mu.Lock()
116                                 err := c.WriteMessage(websocket.PingMessage, []byte{})
117                                 tc.mu.Unlock()
118                                 if err != nil {
119                                         return
120                                 }
121                         case <-closeChan:
122                                 return
123
124                         }
125                 }
126         }()
127         err = tc.trackerReadLoop(tc.wsConn)
128         close(closeChan)
129         tc.mu.Lock()
130         c.Close()
131         tc.mu.Unlock()
132         return err
133 }
134
135 // Finishes initialization and spawns the run routine, calling onStop when it completes with the
136 // result. We don't let the caller just spawn the runner directly, since then we can race against
137 // .Close to finish initialization.
138 func (tc *TrackerClient) Start(onStop func(error)) {
139         tc.pingTicker = time.NewTicker(60 * time.Second)
140         tc.cond.L = &tc.mu
141         go func() {
142                 onStop(tc.run())
143         }()
144 }
145
146 func (tc *TrackerClient) run() error {
147         tc.mu.Lock()
148         for !tc.closed {
149                 tc.mu.Unlock()
150                 err := tc.doWebsocket()
151                 level := log.Info
152                 tc.mu.Lock()
153                 if tc.closed {
154                         level = log.Debug
155                 }
156                 tc.mu.Unlock()
157                 tc.Logger.WithDefaultLevel(level).Printf("websocket instance ended: %v", err)
158                 time.Sleep(time.Minute)
159                 tc.mu.Lock()
160         }
161         tc.mu.Unlock()
162         return nil
163 }
164
165 func (tc *TrackerClient) Close() error {
166         tc.mu.Lock()
167         tc.closed = true
168         if tc.wsConn != nil {
169                 tc.wsConn.Close()
170         }
171         tc.closeUnusedOffers()
172         tc.pingTicker.Stop()
173         tc.mu.Unlock()
174         tc.cond.Broadcast()
175         return nil
176 }
177
178 func (tc *TrackerClient) announceOffers() {
179         // tc.Announce grabs a lock on tc.outboundOffers. It also handles the case where outboundOffers
180         // is nil. Take ownership of outboundOffers here.
181         tc.mu.Lock()
182         offers := tc.outboundOffers
183         tc.outboundOffers = nil
184         tc.mu.Unlock()
185
186         if offers == nil {
187                 return
188         }
189
190         // Iterate over our locally-owned offers, close any existing "invalid" ones from before the
191         // socket reconnected, reannounce the infohash, adding it back into the tc.outboundOffers.
192         tc.Logger.WithDefaultLevel(log.Info).Printf("reannouncing %d infohashes after restart", len(offers))
193         for _, offer := range offers {
194                 // TODO: Capture the errors? Are we even in a position to do anything with them?
195                 offer.peerConnection.Close()
196                 // Use goroutine here to allow read loop to start and ensure the buffer drains.
197                 go tc.Announce(tracker.Started, offer.infoHash)
198         }
199 }
200
201 func (tc *TrackerClient) closeUnusedOffers() {
202         for _, offer := range tc.outboundOffers {
203                 offer.peerConnection.Close()
204                 offer.dataChannel.Close()
205         }
206         tc.outboundOffers = nil
207 }
208
209 func (tc *TrackerClient) CloseOffersForInfohash(infoHash [20]byte) {
210         tc.mu.Lock()
211         defer tc.mu.Unlock()
212         for key, offer := range tc.outboundOffers {
213                 if offer.infoHash == infoHash {
214                         offer.peerConnection.Close()
215                         delete(tc.outboundOffers, key)
216                 }
217         }
218 }
219
220 func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte) error {
221         metrics.Add("outbound announces", 1)
222         if event == tracker.Stopped {
223                 return tc.announce(event, infoHash, nil)
224         }
225         var randOfferId [20]byte
226         _, err := rand.Read(randOfferId[:])
227         if err != nil {
228                 return fmt.Errorf("generating offer_id bytes: %w", err)
229         }
230         offerIDBinary := binaryToJsonString(randOfferId[:])
231
232         pc, dc, offer, err := tc.newOffer(tc.Logger, offerIDBinary, infoHash)
233         if err != nil {
234                 return fmt.Errorf("creating offer: %w", err)
235         }
236
237         err = tc.announce(event, infoHash, []outboundOffer{
238                 {
239                         offerId: offerIDBinary,
240                         outboundOfferValue: outboundOfferValue{
241                                 originalOffer:  offer,
242                                 peerConnection: pc,
243                                 infoHash:       infoHash,
244                                 dataChannel:    dc,
245                         },
246                 },
247         })
248         if err != nil {
249                 dc.Close()
250                 pc.Close()
251         }
252         return err
253 }
254
255 func (tc *TrackerClient) announce(event tracker.AnnounceEvent, infoHash [20]byte, offers []outboundOffer) error {
256         request, err := tc.GetAnnounceRequest(event, infoHash)
257         if err != nil {
258                 return fmt.Errorf("getting announce parameters: %w", err)
259         }
260
261         req := AnnounceRequest{
262                 Numwant:    len(offers),
263                 Uploaded:   request.Uploaded,
264                 Downloaded: request.Downloaded,
265                 Left:       request.Left,
266                 Event:      request.Event.String(),
267                 Action:     "announce",
268                 InfoHash:   binaryToJsonString(infoHash[:]),
269                 PeerID:     tc.peerIdBinary(),
270         }
271         for _, offer := range offers {
272                 req.Offers = append(req.Offers, Offer{
273                         OfferID: offer.offerId,
274                         Offer:   offer.originalOffer,
275                 })
276         }
277
278         data, err := json.Marshal(req)
279         if err != nil {
280                 return fmt.Errorf("marshalling request: %w", err)
281         }
282
283         tc.mu.Lock()
284         defer tc.mu.Unlock()
285         err = tc.writeMessage(data)
286         if err != nil {
287                 return fmt.Errorf("write AnnounceRequest: %w", err)
288         }
289         for _, offer := range offers {
290                 g.MakeMapIfNilAndSet(&tc.outboundOffers, offer.offerId, offer.outboundOfferValue)
291         }
292         return nil
293 }
294
295 func (tc *TrackerClient) writeMessage(data []byte) error {
296         for tc.wsConn == nil {
297                 if tc.closed {
298                         return fmt.Errorf("%T closed", tc)
299                 }
300                 tc.cond.Wait()
301         }
302         return tc.wsConn.WriteMessage(websocket.TextMessage, data)
303 }
304
305 func (tc *TrackerClient) trackerReadLoop(tracker *websocket.Conn) error {
306         for {
307                 _, message, err := tracker.ReadMessage()
308                 if err != nil {
309                         return fmt.Errorf("read message error: %w", err)
310                 }
311                 // tc.Logger.WithDefaultLevel(log.Debug).Printf("received message from tracker: %q", message)
312
313                 var ar AnnounceResponse
314                 if err := json.Unmarshal(message, &ar); err != nil {
315                         tc.Logger.WithDefaultLevel(log.Warning).Printf("error unmarshalling announce response: %v", err)
316                         continue
317                 }
318                 switch {
319                 case ar.Offer != nil:
320                         ih, err := jsonStringToInfoHash(ar.InfoHash)
321                         if err != nil {
322                                 tc.Logger.WithDefaultLevel(log.Warning).Printf("error decoding info_hash in offer: %v", err)
323                                 break
324                         }
325                         err = tc.handleOffer(offerContext{
326                                 SessDesc: *ar.Offer,
327                                 Id:       ar.OfferID,
328                                 InfoHash: ih,
329                         }, ar.PeerID)
330                         if err != nil {
331                                 tc.Logger.Levelf(log.Error, "handling offer for infohash %x: %v", ih, err)
332                         }
333                 case ar.Answer != nil:
334                         tc.handleAnswer(ar.OfferID, *ar.Answer)
335                 default:
336                         tc.Logger.Levelf(log.Warning, "unhandled announce response %q", message)
337                 }
338         }
339 }
340
// offerContext holds an offer relayed by the tracker from a remote peer: the offer's session
// description, its offer ID, and the info hash it was made for.
type offerContext struct {
	SessDesc webrtc.SessionDescription
	Id       string
	InfoHash [20]byte
}
346
347 func (tc *TrackerClient) handleOffer(
348         offerContext offerContext,
349         peerId string,
350 ) error {
351         peerConnection, answer, err := tc.newAnsweringPeerConnection(offerContext)
352         if err != nil {
353                 return fmt.Errorf("creating answering peer connection: %w", err)
354         }
355         response := AnnounceResponse{
356                 Action:   "announce",
357                 InfoHash: binaryToJsonString(offerContext.InfoHash[:]),
358                 PeerID:   tc.peerIdBinary(),
359                 ToPeerID: peerId,
360                 Answer:   &answer,
361                 OfferID:  offerContext.Id,
362         }
363         data, err := json.Marshal(response)
364         if err != nil {
365                 peerConnection.Close()
366                 return fmt.Errorf("marshalling response: %w", err)
367         }
368         tc.mu.Lock()
369         defer tc.mu.Unlock()
370         if err := tc.writeMessage(data); err != nil {
371                 peerConnection.Close()
372                 return fmt.Errorf("writing response: %w", err)
373         }
374         return nil
375 }
376
377 func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescription) {
378         tc.mu.Lock()
379         defer tc.mu.Unlock()
380         offer, ok := tc.outboundOffers[offerId]
381         if !ok {
382                 tc.Logger.WithDefaultLevel(log.Warning).Printf("could not find offer for id %+q", offerId)
383                 return
384         }
385         // tc.Logger.WithDefaultLevel(log.Debug).Printf("offer %q got answer %v", offerId, answer)
386         metrics.Add("outbound offers answered", 1)
387         err := offer.peerConnection.SetRemoteDescription(answer)
388         if err != nil {
389                 err = fmt.Errorf("using outbound offer answer: %w", err)
390                 offer.peerConnection.span.RecordError(err)
391                 tc.Logger.LevelPrint(log.Error, err)
392                 return
393         }
394         delete(tc.outboundOffers, offerId)
395         go tc.Announce(tracker.None, offer.infoHash)
396 }