Sergey Matveev's repositories - btrtrc.git/blob - webtorrent/tracker_client.go
Tidy up the webtorrent package, remove buffer
[btrtrc.git] / webtorrent / tracker_client.go
1 package webtorrent
2
3 import (
4         "crypto/rand"
5         "encoding/json"
6         "fmt"
7         "sync"
8
9         "github.com/anacrolix/log"
10
11         "github.com/anacrolix/torrent/tracker"
12         "github.com/gorilla/websocket"
13         "github.com/pion/datachannel"
14         "github.com/pion/webrtc/v2"
15 )
16
17 // Client represents the webtorrent client
18 type TrackerClient struct {
19         lock           sync.Mutex
20         peerIDBinary   string
21         infoHashBinary string
22         outboundOffers map[string]outboundOffer // OfferID to outboundOffer
23         tracker        *websocket.Conn
24         onConn         onDataChannelOpen
25         logger         log.Logger
26 }
27
28 // outboundOffer represents an outstanding offer.
29 type outboundOffer struct {
30         originalOffer webrtc.SessionDescription
31         transport     *transport
32 }
33
34 type DataChannelContext struct {
35         Local, Remote webrtc.SessionDescription
36         OfferId       string
37         LocalOffered  bool
38 }
39
40 type onDataChannelOpen func(_ datachannel.ReadWriteCloser, dcc DataChannelContext)
41
42 func NewClient(peerId, infoHash [20]byte, onConn onDataChannelOpen, logger log.Logger) *TrackerClient {
43         return &TrackerClient{
44                 outboundOffers: make(map[string]outboundOffer),
45                 peerIDBinary:   binaryToJsonString(peerId[:]),
46                 infoHashBinary: binaryToJsonString(infoHash[:]),
47                 onConn:         onConn,
48                 logger:         logger,
49         }
50 }
51
52 func (c *TrackerClient) Run(ar tracker.AnnounceRequest, url string) error {
53         t, _, err := websocket.DefaultDialer.Dial(url, nil)
54         if err != nil {
55                 return fmt.Errorf("failed to dial tracker: %w", err)
56         }
57         defer t.Close()
58         c.logger.WithValues(log.Info).Printf("dialed tracker %q", url)
59         c.tracker = t
60
61         go func() {
62                 err := c.announce(ar)
63                 if err != nil {
64                         c.logger.WithValues(log.Error).Printf("error announcing: %v", err)
65                 }
66         }()
67         return c.trackerReadLoop()
68 }
69
70 func (c *TrackerClient) announce(request tracker.AnnounceRequest) error {
71         transport, offer, err := newTransport()
72         if err != nil {
73                 return fmt.Errorf("failed to create transport: %w", err)
74         }
75
76         var randOfferId [20]byte
77         _, err = rand.Read(randOfferId[:])
78         if err != nil {
79                 return fmt.Errorf("failed to generate bytes: %w", err)
80         }
81         offerIDBinary := binaryToJsonString(randOfferId[:])
82
83         c.lock.Lock()
84         c.outboundOffers[offerIDBinary] = outboundOffer{
85                 transport:     transport,
86                 originalOffer: offer,
87         }
88         c.lock.Unlock()
89
90         req := AnnounceRequest{
91                 Numwant:    1, // If higher we need to create equal amount of offers
92                 Uploaded:   0,
93                 Downloaded: 0,
94                 Left:       request.Left,
95                 Event:      "started",
96                 Action:     "announce",
97                 InfoHash:   c.infoHashBinary,
98                 PeerID:     c.peerIDBinary,
99                 Offers: []Offer{{
100                         OfferID: offerIDBinary,
101                         Offer:   offer,
102                 }},
103         }
104
105         data, err := json.Marshal(req)
106         if err != nil {
107                 return fmt.Errorf("failed to marshal request: %w", err)
108         }
109         c.lock.Lock()
110         tracker := c.tracker
111         err = tracker.WriteMessage(websocket.TextMessage, data)
112         if err != nil {
113                 return fmt.Errorf("write AnnounceRequest: %w", err)
114                 c.lock.Unlock()
115         }
116         c.lock.Unlock()
117         return nil
118 }
119
120 func (c *TrackerClient) trackerReadLoop() error {
121
122         c.lock.Lock()
123         tracker := c.tracker
124         c.lock.Unlock()
125         for {
126                 _, message, err := tracker.ReadMessage()
127                 if err != nil {
128                         return fmt.Errorf("read error: %w", err)
129                 }
130                 c.logger.WithValues(log.Debug).Printf("received message from tracker: %q", message)
131
132                 var ar AnnounceResponse
133                 if err := json.Unmarshal(message, &ar); err != nil {
134                         log.Printf("error unmarshaling announce response: %v", err)
135                         continue
136                 }
137                 if ar.InfoHash != c.infoHashBinary {
138                         log.Printf("announce response for different hash: expected %q got %q", c.infoHashBinary, ar.InfoHash)
139                         continue
140                 }
141                 switch {
142                 case ar.Offer != nil:
143                         _, answer, err := newTransportFromOffer(*ar.Offer, c.onConn, ar.OfferID)
144                         if err != nil {
145                                 return fmt.Errorf("write AnnounceResponse: %w", err)
146                         }
147
148                         req := AnnounceResponse{
149                                 Action:   "announce",
150                                 InfoHash: c.infoHashBinary,
151                                 PeerID:   c.peerIDBinary,
152                                 ToPeerID: ar.PeerID,
153                                 Answer:   &answer,
154                                 OfferID:  ar.OfferID,
155                         }
156                         data, err := json.Marshal(req)
157                         if err != nil {
158                                 return fmt.Errorf("failed to marshal request: %w", err)
159                         }
160
161                         c.lock.Lock()
162                         err = tracker.WriteMessage(websocket.TextMessage, data)
163                         if err != nil {
164                                 return fmt.Errorf("write AnnounceResponse: %w", err)
165                                 c.lock.Unlock()
166                         }
167                         c.lock.Unlock()
168                 case ar.Answer != nil:
169                         c.lock.Lock()
170                         offer, ok := c.outboundOffers[ar.OfferID]
171                         c.lock.Unlock()
172                         if !ok {
173                                 c.logger.WithValues(log.Warning).Printf("could not find offer for id %q", ar.OfferID)
174                                 continue
175                         }
176                         c.logger.Printf("offer %q got answer %v", ar.OfferID, *ar.Answer)
177                         err = offer.transport.SetAnswer(*ar.Answer, func(dc datachannel.ReadWriteCloser) {
178                                 c.onConn(dc, DataChannelContext{
179                                         Local:        offer.originalOffer,
180                                         Remote:       *ar.Answer,
181                                         OfferId:      ar.OfferID,
182                                         LocalOffered: true,
183                                 })
184                         })
185                         if err != nil {
186                                 return fmt.Errorf("failed to sent answer: %w", err)
187                         }
188                 }
189         }
190 }