]> Sergey Matveev's repositories - btrtrc.git/blob - webtorrent/tracker_client.go
d75b9f9bbd5973335acc7ea1cbfec84530028669
[btrtrc.git] / webtorrent / tracker_client.go
1 package webtorrent
2
3 import (
4         "crypto/rand"
5         "encoding/json"
6         "fmt"
7         "sync"
8         "time"
9
10         "github.com/anacrolix/log"
11
12         "github.com/anacrolix/torrent/tracker"
13         "github.com/gorilla/websocket"
14         "github.com/pion/datachannel"
15         "github.com/pion/webrtc/v3"
16 )
17
// TrackerClientStats holds counters maintained by a TrackerClient; see TrackerClient.Stats.
type TrackerClientStats struct {
	// Dials counts websocket dial attempts to the tracker, successful or not.
	Dials                  int64
	// ConvertedInboundConns counts answered remote offers whose data channel opened and was
	// handed to OnConn.
	ConvertedInboundConns  int64
	// ConvertedOutboundConns counts local offers that were answered and whose data channel was
	// handed to OnConn.
	ConvertedOutboundConns int64
}
23
// TrackerClient is a client for a WebTorrent websocket tracker. It announces infohashes together
// with WebRTC offers, answers offers that the tracker relays from other peers, and hands each
// resulting data channel to OnConn. Fill in the exported fields, call Start, and call Close when
// done.
type TrackerClient struct {
	// Url is the tracker's websocket URL.
	Url                string
	// GetAnnounceRequest supplies the announce parameters (event, uploaded, downloaded, left)
	// used when announcing infoHash.
	GetAnnounceRequest func(_ tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
	// PeerId is sent as our peer ID in announces and in answers to relayed offers.
	PeerId             [20]byte
	// OnConn is called with each data channel that opens, both for offers we made and for
	// offers we answered.
	OnConn             onDataChannelOpen
	Logger             log.Logger

	// mu guards outboundOffers, wsConn, closed and stats, and is held around every write to the
	// websocket.
	mu             sync.Mutex
	// cond uses mu as its lock (set up by Start). It is broadcast when wsConn is set and when
	// the client is closed; writeMessage waits on it.
	cond           sync.Cond
	outboundOffers map[string]outboundOffer // OfferID to outboundOffer
	// wsConn is the tracker connection used for writes; nil before the first successful dial.
	wsConn         *websocket.Conn
	// closed is set by Close; run stops redialling once it is true.
	closed         bool
	// stats is returned by Stats.
	stats          TrackerClientStats
	// pingTicker is created by Start and stopped by Close; each tick sends a websocket ping on
	// the current connection.
	pingTicker     *time.Ticker
}
40
41 func (me *TrackerClient) Stats() TrackerClientStats {
42         me.mu.Lock()
43         defer me.mu.Unlock()
44         return me.stats
45 }
46
47 func (me *TrackerClient) peerIdBinary() string {
48         return binaryToJsonString(me.PeerId[:])
49 }
50
// outboundOffer represents an outstanding offer: a WebRTC offer we announced to the tracker for
// which no answer has been applied yet. Entries live in TrackerClient.outboundOffers, keyed by
// offer ID.
type outboundOffer struct {
	// originalOffer is the local session description sent in the announce.
	originalOffer  webrtc.SessionDescription
	// peerConnection is the connection that produced the offer; it is closed when the offer is
	// discarded.
	peerConnection *wrappedPeerConnection
	// dataChannel is the data channel created together with the offer.
	dataChannel    *webrtc.DataChannel
	// infoHash is the torrent the offer was announced for.
	infoHash       [20]byte
}
58
// DataChannelContext describes the signalling exchange that produced a data channel passed to
// TrackerClient.OnConn.
type DataChannelContext struct {
	// Local and Remote are our and the peer's session descriptions; which one is the offer
	// depends on LocalOffered.
	Local, Remote webrtc.SessionDescription
	// OfferId is the tracker offer ID the exchange was negotiated under.
	OfferId       string
	// LocalOffered is true when we created the offer and the peer answered it, and false when we
	// answered the peer's offer.
	LocalOffered  bool
	// InfoHash is the torrent the connection was negotiated for.
	InfoHash      [20]byte
}
65
66 type onDataChannelOpen func(_ datachannel.ReadWriteCloser, dcc DataChannelContext)
67
68 func (tc *TrackerClient) doWebsocket() error {
69         metrics.Add("websocket dials", 1)
70         tc.mu.Lock()
71         tc.stats.Dials++
72         tc.mu.Unlock()
73         c, _, err := websocket.DefaultDialer.Dial(tc.Url, nil)
74         if err != nil {
75                 return fmt.Errorf("dialing tracker: %w", err)
76         }
77         defer c.Close()
78         tc.Logger.WithDefaultLevel(log.Info).Printf("connected")
79         tc.mu.Lock()
80         tc.wsConn = c
81         tc.cond.Broadcast()
82         tc.mu.Unlock()
83         tc.announceOffers()
84         closeChan := make(chan struct{})
85         go func() {
86                 for {
87                         select {
88                         case <-tc.pingTicker.C:
89                                 tc.mu.Lock()
90                                 err := c.WriteMessage(websocket.PingMessage, []byte{})
91                                 tc.mu.Unlock()
92                                 if err != nil {
93                                         return
94                                 }
95                         case <-closeChan:
96                                 return
97
98                         }
99                 }
100         }()
101         err = tc.trackerReadLoop(tc.wsConn)
102         close(closeChan)
103         tc.mu.Lock()
104         c.Close()
105         tc.mu.Unlock()
106         return err
107 }
108
109 // Finishes initialization and spawns the run routine, calling onStop when it completes with the
110 // result. We don't let the caller just spawn the runner directly, since then we can race against
111 // .Close to finish initialization.
112 func (tc *TrackerClient) Start(onStop func(error)) {
113         tc.pingTicker = time.NewTicker(60 * time.Second)
114         tc.cond.L = &tc.mu
115         go func() {
116                 onStop(tc.run())
117         }()
118 }
119
120 func (tc *TrackerClient) run() error {
121         tc.mu.Lock()
122         for !tc.closed {
123                 tc.mu.Unlock()
124                 err := tc.doWebsocket()
125                 level := log.Info
126                 tc.mu.Lock()
127                 if tc.closed {
128                         level = log.Debug
129                 }
130                 tc.mu.Unlock()
131                 tc.Logger.WithDefaultLevel(level).Printf("websocket instance ended: %v", err)
132                 time.Sleep(time.Minute)
133                 tc.mu.Lock()
134         }
135         tc.mu.Unlock()
136         return nil
137 }
138
139 func (tc *TrackerClient) Close() error {
140         tc.mu.Lock()
141         tc.closed = true
142         if tc.wsConn != nil {
143                 tc.wsConn.Close()
144         }
145         tc.closeUnusedOffers()
146         tc.pingTicker.Stop()
147         tc.mu.Unlock()
148         tc.cond.Broadcast()
149         return nil
150 }
151
152 func (tc *TrackerClient) announceOffers() {
153
154         // tc.Announce grabs a lock on tc.outboundOffers. It also handles the case where outboundOffers
155         // is nil. Take ownership of outboundOffers here.
156         tc.mu.Lock()
157         offers := tc.outboundOffers
158         tc.outboundOffers = nil
159         tc.mu.Unlock()
160
161         if offers == nil {
162                 return
163         }
164
165         // Iterate over our locally-owned offers, close any existing "invalid" ones from before the
166         // socket reconnected, reannounce the infohash, adding it back into the tc.outboundOffers.
167         tc.Logger.WithDefaultLevel(log.Info).Printf("reannouncing %d infohashes after restart", len(offers))
168         for _, offer := range offers {
169                 // TODO: Capture the errors? Are we even in a position to do anything with them?
170                 offer.peerConnection.Close()
171                 // Use goroutine here to allow read loop to start and ensure the buffer drains.
172                 go tc.Announce(tracker.Started, offer.infoHash)
173         }
174 }
175
176 func (tc *TrackerClient) closeUnusedOffers() {
177         for _, offer := range tc.outboundOffers {
178                 offer.peerConnection.Close()
179         }
180         tc.outboundOffers = nil
181 }
182
183 func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte) error {
184         metrics.Add("outbound announces", 1)
185         var randOfferId [20]byte
186         _, err := rand.Read(randOfferId[:])
187         if err != nil {
188                 return fmt.Errorf("generating offer_id bytes: %w", err)
189         }
190         offerIDBinary := binaryToJsonString(randOfferId[:])
191
192         pc, dc, offer, err := newOffer()
193         if err != nil {
194                 return fmt.Errorf("creating offer: %w", err)
195         }
196
197         request, err := tc.GetAnnounceRequest(event, infoHash)
198         if err != nil {
199                 pc.Close()
200                 return fmt.Errorf("getting announce parameters: %w", err)
201         }
202
203         req := AnnounceRequest{
204                 Numwant:    1, // If higher we need to create equal amount of offers.
205                 Uploaded:   request.Uploaded,
206                 Downloaded: request.Downloaded,
207                 Left:       request.Left,
208                 Event:      request.Event.String(),
209                 Action:     "announce",
210                 InfoHash:   binaryToJsonString(infoHash[:]),
211                 PeerID:     tc.peerIdBinary(),
212                 Offers: []Offer{{
213                         OfferID: offerIDBinary,
214                         Offer:   offer,
215                 }},
216         }
217
218         data, err := json.Marshal(req)
219         if err != nil {
220                 pc.Close()
221                 return fmt.Errorf("marshalling request: %w", err)
222         }
223
224         tc.mu.Lock()
225         defer tc.mu.Unlock()
226         err = tc.writeMessage(data)
227         if err != nil {
228                 pc.Close()
229                 return fmt.Errorf("write AnnounceRequest: %w", err)
230         }
231         if tc.outboundOffers == nil {
232                 tc.outboundOffers = make(map[string]outboundOffer)
233         }
234         tc.outboundOffers[offerIDBinary] = outboundOffer{
235                 peerConnection: pc,
236                 dataChannel:    dc,
237                 originalOffer:  offer,
238                 infoHash:       infoHash,
239         }
240         return nil
241 }
242
243 func (tc *TrackerClient) writeMessage(data []byte) error {
244         for tc.wsConn == nil {
245                 if tc.closed {
246                         return fmt.Errorf("%T closed", tc)
247                 }
248                 tc.cond.Wait()
249         }
250         return tc.wsConn.WriteMessage(websocket.TextMessage, data)
251 }
252
253 func (tc *TrackerClient) trackerReadLoop(tracker *websocket.Conn) error {
254         for {
255                 _, message, err := tracker.ReadMessage()
256                 if err != nil {
257                         return fmt.Errorf("read message error: %w", err)
258                 }
259                 //tc.Logger.WithDefaultLevel(log.Debug).Printf("received message from tracker: %q", message)
260
261                 var ar AnnounceResponse
262                 if err := json.Unmarshal(message, &ar); err != nil {
263                         tc.Logger.WithDefaultLevel(log.Warning).Printf("error unmarshalling announce response: %v", err)
264                         continue
265                 }
266                 switch {
267                 case ar.Offer != nil:
268                         ih, err := jsonStringToInfoHash(ar.InfoHash)
269                         if err != nil {
270                                 tc.Logger.WithDefaultLevel(log.Warning).Printf("error decoding info_hash in offer: %v", err)
271                                 break
272                         }
273                         tc.handleOffer(*ar.Offer, ar.OfferID, ih, ar.PeerID)
274                 case ar.Answer != nil:
275                         tc.handleAnswer(ar.OfferID, *ar.Answer)
276                 }
277         }
278 }
279
280 func (tc *TrackerClient) handleOffer(
281         offer webrtc.SessionDescription,
282         offerId string,
283         infoHash [20]byte,
284         peerId string,
285 ) error {
286         peerConnection, answer, err := newAnsweringPeerConnection(offer)
287         if err != nil {
288                 return fmt.Errorf("write AnnounceResponse: %w", err)
289         }
290         response := AnnounceResponse{
291                 Action:   "announce",
292                 InfoHash: binaryToJsonString(infoHash[:]),
293                 PeerID:   tc.peerIdBinary(),
294                 ToPeerID: peerId,
295                 Answer:   &answer,
296                 OfferID:  offerId,
297         }
298         data, err := json.Marshal(response)
299         if err != nil {
300                 peerConnection.Close()
301                 return fmt.Errorf("marshalling response: %w", err)
302         }
303         tc.mu.Lock()
304         defer tc.mu.Unlock()
305         if err := tc.writeMessage(data); err != nil {
306                 peerConnection.Close()
307                 return fmt.Errorf("writing response: %w", err)
308         }
309         timer := time.AfterFunc(30*time.Second, func() {
310                 metrics.Add("answering peer connections timed out", 1)
311                 peerConnection.Close()
312         })
313         peerConnection.OnDataChannel(func(d *webrtc.DataChannel) {
314                 setDataChannelOnOpen(d, peerConnection, func(dc datachannel.ReadWriteCloser) {
315                         timer.Stop()
316                         metrics.Add("answering peer connection conversions", 1)
317                         tc.mu.Lock()
318                         tc.stats.ConvertedInboundConns++
319                         tc.mu.Unlock()
320                         tc.OnConn(dc, DataChannelContext{
321                                 Local:        answer,
322                                 Remote:       offer,
323                                 OfferId:      offerId,
324                                 LocalOffered: false,
325                                 InfoHash:     infoHash,
326                         })
327                 })
328         })
329         return nil
330 }
331
332 func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescription) {
333         tc.mu.Lock()
334         defer tc.mu.Unlock()
335         offer, ok := tc.outboundOffers[offerId]
336         if !ok {
337                 tc.Logger.WithDefaultLevel(log.Warning).Printf("could not find offer for id %+q", offerId)
338                 return
339         }
340         //tc.Logger.WithDefaultLevel(log.Debug).Printf("offer %q got answer %v", offerId, answer)
341         metrics.Add("outbound offers answered", 1)
342         err := offer.setAnswer(answer, func(dc datachannel.ReadWriteCloser) {
343                 metrics.Add("outbound offers answered with datachannel", 1)
344                 tc.mu.Lock()
345                 tc.stats.ConvertedOutboundConns++
346                 tc.mu.Unlock()
347                 tc.OnConn(dc, DataChannelContext{
348                         Local:        offer.originalOffer,
349                         Remote:       answer,
350                         OfferId:      offerId,
351                         LocalOffered: true,
352                         InfoHash:     offer.infoHash,
353                 })
354         })
355         if err != nil {
356                 tc.Logger.WithDefaultLevel(log.Warning).Printf("error using outbound offer answer: %v", err)
357                 return
358         }
359         delete(tc.outboundOffers, offerId)
360         go tc.Announce(tracker.None, offer.infoHash)
361 }