Sergey Matveev's repositories - btrtrc.git/blob - webtorrent/tracker_client.go
Ping the websocket to ensure the connection remains online.
[btrtrc.git] / webtorrent / tracker_client.go
1 package webtorrent
2
3 import (
4         "crypto/rand"
5         "encoding/json"
6         "fmt"
7         "sync"
8         "time"
9
10         "github.com/anacrolix/log"
11
12         "github.com/anacrolix/torrent/tracker"
13         "github.com/gorilla/websocket"
14         "github.com/pion/datachannel"
15         "github.com/pion/webrtc/v2"
16 )
17
// TrackerClientStats holds the counters reported by TrackerClient.Stats.
type TrackerClientStats struct {
	// Dials counts websocket connection attempts made to the tracker.
	Dials                  int64
	// ConvertedInboundConns counts peer connections answered by this client
	// whose data channel opened and was handed to OnConn.
	ConvertedInboundConns  int64
	// ConvertedOutboundConns counts peer connections offered by this client
	// whose data channel opened after an answer and was handed to OnConn.
	ConvertedOutboundConns int64
}
23
24 // Client represents the webtorrent client
25 type TrackerClient struct {
26         Url                string
27         GetAnnounceRequest func(_ tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
28         PeerId             [20]byte
29         OnConn             onDataChannelOpen
30         Logger             log.Logger
31
32         mu             sync.Mutex
33         cond           sync.Cond
34         outboundOffers map[string]outboundOffer // OfferID to outboundOffer
35         wsConn         *websocket.Conn
36         closed         bool
37         stats          TrackerClientStats
38         pingTicker     *time.Ticker
39 }
40
41 func (me *TrackerClient) Stats() TrackerClientStats {
42         me.mu.Lock()
43         defer me.mu.Unlock()
44         return me.stats
45 }
46
47 func (me *TrackerClient) peerIdBinary() string {
48         return binaryToJsonString(me.PeerId[:])
49 }
50
51 // outboundOffer represents an outstanding offer.
52 type outboundOffer struct {
53         originalOffer  webrtc.SessionDescription
54         peerConnection *wrappedPeerConnection
55         dataChannel    *webrtc.DataChannel
56         infoHash       [20]byte
57 }
58
59 type DataChannelContext struct {
60         Local, Remote webrtc.SessionDescription
61         OfferId       string
62         LocalOffered  bool
63         InfoHash      [20]byte
64 }
65
66 type onDataChannelOpen func(_ datachannel.ReadWriteCloser, dcc DataChannelContext)
67
68 func (tc *TrackerClient) doWebsocket() error {
69         metrics.Add("websocket dials", 1)
70         tc.stats.Dials++
71         c, _, err := websocket.DefaultDialer.Dial(tc.Url, nil)
72         if err != nil {
73                 return fmt.Errorf("dialing tracker: %w", err)
74         }
75         defer c.Close()
76         tc.Logger.WithDefaultLevel(log.Info).Printf("connected")
77         tc.mu.Lock()
78         tc.wsConn = c
79         tc.cond.Broadcast()
80         tc.mu.Unlock()
81         tc.announceOffers()
82         closeChan := make(chan struct{})
83         go func() {
84                 for {
85                         select {
86                         case <-tc.pingTicker.C:
87                                 err := c.WriteMessage(websocket.PingMessage, []byte{})
88                                 if err != nil {
89                                         return
90                                 }
91                         case <-closeChan:
92                                 return
93                         default:
94                         }
95                 }
96         }()
97         err = tc.trackerReadLoop(tc.wsConn)
98         close(closeChan)
99         tc.mu.Lock()
100         c.Close()
101         tc.mu.Unlock()
102         return err
103 }
104
105 func (tc *TrackerClient) Run() error {
106         tc.pingTicker = time.NewTicker(60 * time.Second)
107         tc.cond.L = &tc.mu
108         tc.mu.Lock()
109         for !tc.closed {
110                 tc.mu.Unlock()
111                 err := tc.doWebsocket()
112                 level := log.Info
113                 tc.mu.Lock()
114                 if tc.closed {
115                         level = log.Debug
116                 }
117                 tc.mu.Unlock()
118                 tc.Logger.WithDefaultLevel(level).Printf("websocket instance ended: %v", err)
119                 time.Sleep(time.Minute)
120                 tc.mu.Lock()
121         }
122         tc.mu.Unlock()
123         return nil
124 }
125
126 func (tc *TrackerClient) Close() error {
127         tc.mu.Lock()
128         tc.closed = true
129         if tc.wsConn != nil {
130                 tc.wsConn.Close()
131         }
132         tc.closeUnusedOffers()
133         tc.pingTicker.Stop()
134         tc.mu.Unlock()
135         tc.cond.Broadcast()
136         return nil
137 }
138
139 func (tc *TrackerClient) announceOffers() {
140
141         // tc.Announce grabs a lock on tc.outboundOffers. It also handles the case where outboundOffers
142         // is nil. Take ownership of outboundOffers here.
143         tc.mu.Lock()
144         offers := tc.outboundOffers
145         tc.outboundOffers = nil
146         tc.mu.Unlock()
147
148         if offers == nil {
149                 return
150         }
151
152         // Iterate over our locally-owned offers, close any existing "invalid" ones from before the
153         // socket reconnected, reannounce the infohash, adding it back into the tc.outboundOffers.
154         tc.Logger.WithDefaultLevel(log.Info).Printf("reannouncing %d infohashes after restart", len(offers))
155         for _, offer := range offers {
156                 // TODO: Capture the errors? Are we even in a position to do anything with them?
157                 offer.peerConnection.Close()
158                 // Use goroutine here to allow read loop to start and ensure the buffer drains.
159                 go tc.Announce(tracker.Started, offer.infoHash)
160         }
161 }
162
163 func (tc *TrackerClient) closeUnusedOffers() {
164         for _, offer := range tc.outboundOffers {
165                 offer.peerConnection.Close()
166         }
167         tc.outboundOffers = nil
168 }
169
170 func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte) error {
171         metrics.Add("outbound announces", 1)
172         var randOfferId [20]byte
173         _, err := rand.Read(randOfferId[:])
174         if err != nil {
175                 return fmt.Errorf("generating offer_id bytes: %w", err)
176         }
177         offerIDBinary := binaryToJsonString(randOfferId[:])
178
179         pc, dc, offer, err := newOffer()
180         if err != nil {
181                 return fmt.Errorf("creating offer: %w", err)
182         }
183
184         request, err := tc.GetAnnounceRequest(event, infoHash)
185         if err != nil {
186                 pc.Close()
187                 return fmt.Errorf("getting announce parameters: %w", err)
188         }
189
190         req := AnnounceRequest{
191                 Numwant:    1, // If higher we need to create equal amount of offers.
192                 Uploaded:   request.Uploaded,
193                 Downloaded: request.Downloaded,
194                 Left:       request.Left,
195                 Event:      request.Event.String(),
196                 Action:     "announce",
197                 InfoHash:   binaryToJsonString(infoHash[:]),
198                 PeerID:     tc.peerIdBinary(),
199                 Offers: []Offer{{
200                         OfferID: offerIDBinary,
201                         Offer:   offer,
202                 }},
203         }
204
205         data, err := json.Marshal(req)
206         if err != nil {
207                 pc.Close()
208                 return fmt.Errorf("marshalling request: %w", err)
209         }
210
211         tc.mu.Lock()
212         defer tc.mu.Unlock()
213         err = tc.writeMessage(data)
214         if err != nil {
215                 pc.Close()
216                 return fmt.Errorf("write AnnounceRequest: %w", err)
217         }
218         if tc.outboundOffers == nil {
219                 tc.outboundOffers = make(map[string]outboundOffer)
220         }
221         tc.outboundOffers[offerIDBinary] = outboundOffer{
222                 peerConnection: pc,
223                 dataChannel:    dc,
224                 originalOffer:  offer,
225                 infoHash:       infoHash,
226         }
227         return nil
228 }
229
230 func (tc *TrackerClient) writeMessage(data []byte) error {
231         for tc.wsConn == nil {
232                 if tc.closed {
233                         return fmt.Errorf("%T closed", tc)
234                 }
235                 tc.cond.Wait()
236         }
237         return tc.wsConn.WriteMessage(websocket.TextMessage, data)
238 }
239
240 func (tc *TrackerClient) trackerReadLoop(tracker *websocket.Conn) error {
241         for {
242                 _, message, err := tracker.ReadMessage()
243                 if err != nil {
244                         return fmt.Errorf("read message error: %w", err)
245                 }
246                 //tc.Logger.WithDefaultLevel(log.Debug).Printf("received message from tracker: %q", message)
247
248                 var ar AnnounceResponse
249                 if err := json.Unmarshal(message, &ar); err != nil {
250                         tc.Logger.WithDefaultLevel(log.Warning).Printf("error unmarshalling announce response: %v", err)
251                         continue
252                 }
253                 switch {
254                 case ar.Offer != nil:
255                         ih, err := jsonStringToInfoHash(ar.InfoHash)
256                         if err != nil {
257                                 tc.Logger.WithDefaultLevel(log.Warning).Printf("error decoding info_hash in offer: %v", err)
258                                 break
259                         }
260                         tc.handleOffer(*ar.Offer, ar.OfferID, ih, ar.PeerID)
261                 case ar.Answer != nil:
262                         tc.handleAnswer(ar.OfferID, *ar.Answer)
263                 }
264         }
265 }
266
267 func (tc *TrackerClient) handleOffer(
268         offer webrtc.SessionDescription,
269         offerId string,
270         infoHash [20]byte,
271         peerId string,
272 ) error {
273         peerConnection, answer, err := newAnsweringPeerConnection(offer)
274         if err != nil {
275                 return fmt.Errorf("write AnnounceResponse: %w", err)
276         }
277         response := AnnounceResponse{
278                 Action:   "announce",
279                 InfoHash: binaryToJsonString(infoHash[:]),
280                 PeerID:   tc.peerIdBinary(),
281                 ToPeerID: peerId,
282                 Answer:   &answer,
283                 OfferID:  offerId,
284         }
285         data, err := json.Marshal(response)
286         if err != nil {
287                 peerConnection.Close()
288                 return fmt.Errorf("marshalling response: %w", err)
289         }
290         tc.mu.Lock()
291         defer tc.mu.Unlock()
292         if err := tc.writeMessage(data); err != nil {
293                 peerConnection.Close()
294                 return fmt.Errorf("writing response: %w", err)
295         }
296         timer := time.AfterFunc(30*time.Second, func() {
297                 metrics.Add("answering peer connections timed out", 1)
298                 peerConnection.Close()
299         })
300         peerConnection.OnDataChannel(func(d *webrtc.DataChannel) {
301                 setDataChannelOnOpen(d, peerConnection, func(dc datachannel.ReadWriteCloser) {
302                         timer.Stop()
303                         metrics.Add("answering peer connection conversions", 1)
304                         tc.mu.Lock()
305                         tc.stats.ConvertedInboundConns++
306                         tc.mu.Unlock()
307                         tc.OnConn(dc, DataChannelContext{
308                                 Local:        answer,
309                                 Remote:       offer,
310                                 OfferId:      offerId,
311                                 LocalOffered: false,
312                                 InfoHash:     infoHash,
313                         })
314                 })
315         })
316         return nil
317 }
318
319 func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescription) {
320         tc.mu.Lock()
321         defer tc.mu.Unlock()
322         offer, ok := tc.outboundOffers[offerId]
323         if !ok {
324                 tc.Logger.WithDefaultLevel(log.Warning).Printf("could not find offer for id %q", offerId)
325                 return
326         }
327         //tc.Logger.WithDefaultLevel(log.Debug).Printf("offer %q got answer %v", offerId, answer)
328         metrics.Add("outbound offers answered", 1)
329         err := offer.setAnswer(answer, func(dc datachannel.ReadWriteCloser) {
330                 metrics.Add("outbound offers answered with datachannel", 1)
331                 tc.mu.Lock()
332                 tc.stats.ConvertedOutboundConns++
333                 tc.mu.Unlock()
334                 tc.OnConn(dc, DataChannelContext{
335                         Local:        offer.originalOffer,
336                         Remote:       answer,
337                         OfferId:      offerId,
338                         LocalOffered: true,
339                         InfoHash:     offer.infoHash,
340                 })
341         })
342         if err != nil {
343                 tc.Logger.WithDefaultLevel(log.Warning).Printf("error using outbound offer answer: %v", err)
344                 return
345         }
346         delete(tc.outboundOffers, offerId)
347         go tc.Announce(tracker.None, offer.infoHash)
348 }