]> Sergey Matveev's repositories - btrtrc.git/blob - webtorrent/tracker-client.go
Relax TestTcpSimultaneousOpen
[btrtrc.git] / webtorrent / tracker-client.go
1 package webtorrent
2
3 import (
4         "context"
5         "crypto/rand"
6         "encoding/json"
7         "fmt"
8         "net/http"
9         "sync"
10         "time"
11
12         g "github.com/anacrolix/generics"
13         "github.com/anacrolix/log"
14         "github.com/gorilla/websocket"
15         "github.com/pion/datachannel"
16         "github.com/pion/webrtc/v3"
17         "go.opentelemetry.io/otel/trace"
18
19         "github.com/anacrolix/torrent/tracker"
20 )
21
// TrackerClientStats holds counters describing a TrackerClient's activity. A snapshot is
// returned by TrackerClient.Stats.
type TrackerClientStats struct {
	// Dials counts websocket dial attempts to the tracker (incremented in doWebsocket).
	Dials int64
	// ConvertedInboundConns and ConvertedOutboundConns are not updated in this part of the file;
	// presumably they count inbound/outbound data channels handed to OnConn — TODO confirm.
	ConvertedInboundConns  int64
	ConvertedOutboundConns int64
}
27
// TrackerClient represents the webtorrent client side of a websocket tracker connection: it
// announces WebRTC offers for info hashes to the tracker at Url, answers offers relayed by the
// tracker, and reconnects after the websocket ends. Call Start to begin and Close to stop.
type TrackerClient struct {
	// Url is the websocket tracker URL that is dialed.
	Url string
	// GetAnnounceRequest supplies the upload/download/left statistics and event for an announce.
	GetAnnounceRequest func(_ tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
	// PeerId is sent as this client's peer ID in tracker messages.
	PeerId [20]byte
	// OnConn is not referenced in this part of the file; presumably it is invoked when a data
	// channel opens — TODO confirm against newOffer/newAnsweringPeerConnection.
	OnConn onDataChannelOpen
	Logger log.Logger
	// Dialer is used to open the websocket connection to Url.
	Dialer *websocket.Dialer

	// mu guards the fields below; writes to wsConn are also performed while holding it.
	mu sync.Mutex
	// cond (with L set to &mu by Start) is broadcast when wsConn is set and when Close is called,
	// waking writeMessage callers waiting for a connection.
	cond           sync.Cond
	outboundOffers map[string]outboundOfferValue // OfferID to outboundOfferValue
	// wsConn is the most recently established tracker connection, or nil before the first one.
	wsConn *websocket.Conn
	// closed is set by Close and stops the reconnect loop in run.
	closed bool
	stats  TrackerClientStats
	// pingTicker is created by Start (60s period), drives websocket pings, and is stopped by Close.
	pingTicker *time.Ticker

	// WebsocketTrackerHttpHeader, if non-nil, supplies extra HTTP headers for each websocket dial.
	WebsocketTrackerHttpHeader func() http.Header
}
47
48 func (me *TrackerClient) Stats() TrackerClientStats {
49         me.mu.Lock()
50         defer me.mu.Unlock()
51         return me.stats
52 }
53
54 func (me *TrackerClient) peerIdBinary() string {
55         return binaryToJsonString(me.PeerId[:])
56 }
57
// outboundOffer pairs an outstanding offer with the offer ID it is announced under. announce
// sends it to the tracker and then stores it in TrackerClient.outboundOffers keyed by offerId.
type outboundOffer struct {
	// offerId is the binaryToJsonString-encoded random offer ID.
	offerId string
	outboundOfferValue
}
62
// outboundOfferValue represents an outstanding offer: one this client sent to the tracker and
// for which no answer has yet been successfully applied.
type outboundOfferValue struct {
	// originalOffer is the session description sent to the tracker in the announce.
	originalOffer webrtc.SessionDescription
	// peerConnection receives the remote answer in handleAnswer.
	peerConnection *wrappedPeerConnection
	// infoHash is the torrent the offer was announced for.
	infoHash [20]byte
	// dataChannel was created together with the offer by newOffer.
	dataChannel *webrtc.DataChannel
}
70
// DataChannelContext carries metadata about a WebRTC data channel; it is the second argument to
// onDataChannelOpen callbacks.
type DataChannelContext struct {
	// OfferId identifies the offer the data channel was negotiated through.
	OfferId string
	// LocalOffered presumably reports whether this client created the offer (as opposed to
	// answering a remote one) — TODO confirm where the context is constructed.
	LocalOffered bool
	// InfoHash is the torrent the offer relates to.
	InfoHash [20]byte
	// This is private as some methods might not be appropriate with data channel context.
	peerConnection *wrappedPeerConnection
	// Span and Context are not populated in this part of the file; presumably the tracing span
	// and context for the connection — verify at the construction site.
	Span    trace.Span
	Context context.Context
}
80
81 func (me *DataChannelContext) GetSelectedIceCandidatePair() (*webrtc.ICECandidatePair, error) {
82         return me.peerConnection.SCTP().Transport().ICETransport().GetSelectedCandidatePair()
83 }
84
// onDataChannelOpen is the callback type of TrackerClient.OnConn. Judging by its name it is
// called with a data channel once it opens, together with that channel's DataChannelContext;
// the invocation site is not in this part of the file — confirm there.
type onDataChannelOpen func(_ datachannel.ReadWriteCloser, dcc DataChannelContext)
86
87 func (tc *TrackerClient) doWebsocket() error {
88         metrics.Add("websocket dials", 1)
89         tc.mu.Lock()
90         tc.stats.Dials++
91         tc.mu.Unlock()
92
93         var header http.Header
94         if tc.WebsocketTrackerHttpHeader != nil {
95                 header = tc.WebsocketTrackerHttpHeader()
96         }
97
98         c, _, err := tc.Dialer.Dial(tc.Url, header)
99         if err != nil {
100                 return fmt.Errorf("dialing tracker: %w", err)
101         }
102         defer c.Close()
103         tc.Logger.WithDefaultLevel(log.Info).Printf("connected")
104         tc.mu.Lock()
105         tc.wsConn = c
106         tc.cond.Broadcast()
107         tc.mu.Unlock()
108         tc.announceOffers()
109         closeChan := make(chan struct{})
110         go func() {
111                 for {
112                         select {
113                         case <-tc.pingTicker.C:
114                                 tc.mu.Lock()
115                                 err := c.WriteMessage(websocket.PingMessage, []byte{})
116                                 tc.mu.Unlock()
117                                 if err != nil {
118                                         return
119                                 }
120                         case <-closeChan:
121                                 return
122
123                         }
124                 }
125         }()
126         err = tc.trackerReadLoop(tc.wsConn)
127         close(closeChan)
128         tc.mu.Lock()
129         c.Close()
130         tc.mu.Unlock()
131         return err
132 }
133
134 // Finishes initialization and spawns the run routine, calling onStop when it completes with the
135 // result. We don't let the caller just spawn the runner directly, since then we can race against
136 // .Close to finish initialization.
137 func (tc *TrackerClient) Start(onStop func(error)) {
138         tc.pingTicker = time.NewTicker(60 * time.Second)
139         tc.cond.L = &tc.mu
140         go func() {
141                 onStop(tc.run())
142         }()
143 }
144
145 func (tc *TrackerClient) run() error {
146         tc.mu.Lock()
147         for !tc.closed {
148                 tc.mu.Unlock()
149                 err := tc.doWebsocket()
150                 level := log.Info
151                 tc.mu.Lock()
152                 if tc.closed {
153                         level = log.Debug
154                 }
155                 tc.mu.Unlock()
156                 tc.Logger.WithDefaultLevel(level).Printf("websocket instance ended: %v", err)
157                 time.Sleep(time.Minute)
158                 tc.mu.Lock()
159         }
160         tc.mu.Unlock()
161         return nil
162 }
163
164 func (tc *TrackerClient) Close() error {
165         tc.mu.Lock()
166         tc.closed = true
167         if tc.wsConn != nil {
168                 tc.wsConn.Close()
169         }
170         tc.closeUnusedOffers()
171         tc.pingTicker.Stop()
172         tc.mu.Unlock()
173         tc.cond.Broadcast()
174         return nil
175 }
176
177 func (tc *TrackerClient) announceOffers() {
178         // tc.Announce grabs a lock on tc.outboundOffers. It also handles the case where outboundOffers
179         // is nil. Take ownership of outboundOffers here.
180         tc.mu.Lock()
181         offers := tc.outboundOffers
182         tc.outboundOffers = nil
183         tc.mu.Unlock()
184
185         if offers == nil {
186                 return
187         }
188
189         // Iterate over our locally-owned offers, close any existing "invalid" ones from before the
190         // socket reconnected, reannounce the infohash, adding it back into the tc.outboundOffers.
191         tc.Logger.WithDefaultLevel(log.Info).Printf("reannouncing %d infohashes after restart", len(offers))
192         for _, offer := range offers {
193                 // TODO: Capture the errors? Are we even in a position to do anything with them?
194                 offer.peerConnection.Close()
195                 // Use goroutine here to allow read loop to start and ensure the buffer drains.
196                 go tc.Announce(tracker.Started, offer.infoHash)
197         }
198 }
199
200 func (tc *TrackerClient) closeUnusedOffers() {
201         for _, offer := range tc.outboundOffers {
202                 offer.peerConnection.Close()
203                 offer.dataChannel.Close()
204         }
205         tc.outboundOffers = nil
206 }
207
208 func (tc *TrackerClient) CloseOffersForInfohash(infoHash [20]byte) {
209         tc.mu.Lock()
210         defer tc.mu.Unlock()
211         for key, offer := range tc.outboundOffers {
212                 if offer.infoHash == infoHash {
213                         offer.peerConnection.Close()
214                         delete(tc.outboundOffers, key)
215                 }
216         }
217 }
218
219 func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte) error {
220         metrics.Add("outbound announces", 1)
221         if event == tracker.Stopped {
222                 return tc.announce(event, infoHash, nil)
223         }
224         var randOfferId [20]byte
225         _, err := rand.Read(randOfferId[:])
226         if err != nil {
227                 return fmt.Errorf("generating offer_id bytes: %w", err)
228         }
229         offerIDBinary := binaryToJsonString(randOfferId[:])
230
231         pc, dc, offer, err := tc.newOffer(tc.Logger, offerIDBinary, infoHash)
232         if err != nil {
233                 return fmt.Errorf("creating offer: %w", err)
234         }
235
236         err = tc.announce(event, infoHash, []outboundOffer{
237                 {
238                         offerId: offerIDBinary,
239                         outboundOfferValue: outboundOfferValue{
240                                 originalOffer:  offer,
241                                 peerConnection: pc,
242                                 infoHash:       infoHash,
243                                 dataChannel:    dc,
244                         },
245                 },
246         })
247         if err != nil {
248                 dc.Close()
249                 pc.Close()
250         }
251         return err
252 }
253
254 func (tc *TrackerClient) announce(event tracker.AnnounceEvent, infoHash [20]byte, offers []outboundOffer) error {
255         request, err := tc.GetAnnounceRequest(event, infoHash)
256         if err != nil {
257                 return fmt.Errorf("getting announce parameters: %w", err)
258         }
259
260         req := AnnounceRequest{
261                 Numwant:    len(offers),
262                 Uploaded:   request.Uploaded,
263                 Downloaded: request.Downloaded,
264                 Left:       request.Left,
265                 Event:      request.Event.String(),
266                 Action:     "announce",
267                 InfoHash:   binaryToJsonString(infoHash[:]),
268                 PeerID:     tc.peerIdBinary(),
269         }
270         for _, offer := range offers {
271                 req.Offers = append(req.Offers, Offer{
272                         OfferID: offer.offerId,
273                         Offer:   offer.originalOffer,
274                 })
275         }
276
277         data, err := json.Marshal(req)
278         if err != nil {
279                 return fmt.Errorf("marshalling request: %w", err)
280         }
281
282         tc.mu.Lock()
283         defer tc.mu.Unlock()
284         err = tc.writeMessage(data)
285         if err != nil {
286                 return fmt.Errorf("write AnnounceRequest: %w", err)
287         }
288         for _, offer := range offers {
289                 g.MakeMapIfNilAndSet(&tc.outboundOffers, offer.offerId, offer.outboundOfferValue)
290         }
291         return nil
292 }
293
294 func (tc *TrackerClient) writeMessage(data []byte) error {
295         for tc.wsConn == nil {
296                 if tc.closed {
297                         return fmt.Errorf("%T closed", tc)
298                 }
299                 tc.cond.Wait()
300         }
301         return tc.wsConn.WriteMessage(websocket.TextMessage, data)
302 }
303
304 func (tc *TrackerClient) trackerReadLoop(tracker *websocket.Conn) error {
305         for {
306                 _, message, err := tracker.ReadMessage()
307                 if err != nil {
308                         return fmt.Errorf("read message error: %w", err)
309                 }
310                 // tc.Logger.WithDefaultLevel(log.Debug).Printf("received message from tracker: %q", message)
311
312                 var ar AnnounceResponse
313                 if err := json.Unmarshal(message, &ar); err != nil {
314                         tc.Logger.WithDefaultLevel(log.Warning).Printf("error unmarshalling announce response: %v", err)
315                         continue
316                 }
317                 switch {
318                 case ar.Offer != nil:
319                         ih, err := jsonStringToInfoHash(ar.InfoHash)
320                         if err != nil {
321                                 tc.Logger.WithDefaultLevel(log.Warning).Printf("error decoding info_hash in offer: %v", err)
322                                 break
323                         }
324                         err = tc.handleOffer(offerContext{
325                                 SessDesc: *ar.Offer,
326                                 Id:       ar.OfferID,
327                                 InfoHash: ih,
328                         }, ar.PeerID)
329                         if err != nil {
330                                 tc.Logger.Levelf(log.Error, "handling offer for infohash %x: %v", ih, err)
331                         }
332                 case ar.Answer != nil:
333                         tc.handleAnswer(ar.OfferID, *ar.Answer)
334                 default:
335                         tc.Logger.Levelf(log.Warning, "unhandled announce response %q", message)
336                 }
337         }
338 }
339
// offerContext describes an offer relayed by the tracker from a remote peer, as decoded in
// trackerReadLoop and passed to handleOffer.
type offerContext struct {
	// SessDesc is the remote peer's offer session description.
	SessDesc webrtc.SessionDescription
	// Id is the offer ID, echoed back as OfferID in the answer.
	Id string
	// InfoHash is the torrent the offer is for.
	InfoHash [20]byte
}
345
346 func (tc *TrackerClient) handleOffer(
347         offerContext offerContext,
348         peerId string,
349 ) error {
350         peerConnection, answer, err := tc.newAnsweringPeerConnection(offerContext)
351         if err != nil {
352                 return fmt.Errorf("creating answering peer connection: %w", err)
353         }
354         response := AnnounceResponse{
355                 Action:   "announce",
356                 InfoHash: binaryToJsonString(offerContext.InfoHash[:]),
357                 PeerID:   tc.peerIdBinary(),
358                 ToPeerID: peerId,
359                 Answer:   &answer,
360                 OfferID:  offerContext.Id,
361         }
362         data, err := json.Marshal(response)
363         if err != nil {
364                 peerConnection.Close()
365                 return fmt.Errorf("marshalling response: %w", err)
366         }
367         tc.mu.Lock()
368         defer tc.mu.Unlock()
369         if err := tc.writeMessage(data); err != nil {
370                 peerConnection.Close()
371                 return fmt.Errorf("writing response: %w", err)
372         }
373         return nil
374 }
375
376 func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescription) {
377         tc.mu.Lock()
378         defer tc.mu.Unlock()
379         offer, ok := tc.outboundOffers[offerId]
380         if !ok {
381                 tc.Logger.WithDefaultLevel(log.Warning).Printf("could not find offer for id %+q", offerId)
382                 return
383         }
384         // tc.Logger.WithDefaultLevel(log.Debug).Printf("offer %q got answer %v", offerId, answer)
385         metrics.Add("outbound offers answered", 1)
386         err := offer.peerConnection.SetRemoteDescription(answer)
387         if err != nil {
388                 err = fmt.Errorf("using outbound offer answer: %w", err)
389                 offer.peerConnection.span.RecordError(err)
390                 tc.Logger.LevelPrint(log.Error, err)
391                 return
392         }
393         delete(tc.outboundOffers, offerId)
394         go tc.Announce(tracker.None, offer.infoHash)
395 }