Sergey Matveev's repositories - btrtrc.git/blob - webtorrent/tracker_client.go
Remove unused outbound offer answered field
[btrtrc.git] / webtorrent / tracker_client.go
1 package webtorrent
2
3 import (
4         "crypto/rand"
5         "encoding/json"
6         "fmt"
7         "sync"
8         "time"
9
10         "github.com/anacrolix/log"
11
12         "github.com/anacrolix/torrent/tracker"
13         "github.com/gorilla/websocket"
14         "github.com/pion/datachannel"
15         "github.com/pion/webrtc/v2"
16 )
17
// TrackerClient represents a webtorrent client's session with a single
// websocket tracker, announcing for one info hash.
type TrackerClient struct {
	// Url is the tracker address dialed by doWebsocket.
	Url                string
	// GetAnnounceRequest is called by announce to obtain the Uploaded,
	// Downloaded, Left and Event values sent to the tracker.
	GetAnnounceRequest func(tracker.AnnounceEvent) tracker.AnnounceRequest
	// PeerId is sent (encoded via binaryToJsonString) in every announce and answer.
	PeerId             [20]byte
	// InfoHash is sent in every announce and is also used to discard tracker
	// messages that carry a different info hash.
	InfoHash           [20]byte
	// OnConn is handed data channels produced from answered offers, both
	// inbound (via getAnswerForOffer) and outbound (via handleAnswer).
	OnConn             onDataChannelOpen
	Logger             log.Logger

	// lock guards outboundOffers and serializes writes to wsConn.
	lock           sync.Mutex
	outboundOffers map[string]outboundOffer // OfferID to outboundOffer
	// wsConn is the websocket connection of the current session, set by doWebsocket.
	wsConn         *websocket.Conn
}
31
32 func (me *TrackerClient) peerIdBinary() string {
33         return binaryToJsonString(me.PeerId[:])
34 }
35
36 func (me *TrackerClient) infoHashBinary() string {
37         return binaryToJsonString(me.InfoHash[:])
38 }
39
// outboundOffer represents an outstanding offer: one that announce has sent
// to the tracker and stored in TrackerClient.outboundOffers until an answer
// arrives or the websocket session ends.
type outboundOffer struct {
	// originalOffer is the session description that was sent to the tracker.
	originalOffer  webrtc.SessionDescription
	// peerConnection is the connection returned by newOffer; closeUnusedOffers
	// closes it if the offer is still stored when the session ends.
	peerConnection wrappedPeerConnection
	// dataChannel is the data channel returned by newOffer together with the
	// peer connection.
	dataChannel    *webrtc.DataChannel
}
46
// DataChannelContext describes the signalling exchange behind a data channel
// passed to an onDataChannelOpen callback.
type DataChannelContext struct {
	// Local and Remote are the two session descriptions of the exchange. For
	// an answered outbound offer, handleAnswer sets Local to the original
	// offer and Remote to the received answer.
	Local, Remote webrtc.SessionDescription
	// OfferId is the offer ID the exchange was keyed by.
	OfferId       string
	// LocalOffered is set to true by handleAnswer, i.e. when the offer was
	// created by this client.
	LocalOffered  bool
}
52
// onDataChannelOpen is the callback type of TrackerClient.OnConn. It receives
// a data channel and the DataChannelContext describing how it was negotiated.
// NOTE(review): presumably invoked once the channel is open, as the name
// suggests; the invoking code (setAnswer, getAnswerForOffer) is not visible
// here, so confirm.
type onDataChannelOpen func(_ datachannel.ReadWriteCloser, dcc DataChannelContext)
54
55 func (tc *TrackerClient) doWebsocket() error {
56         c, _, err := websocket.DefaultDialer.Dial(tc.Url, nil)
57         if err != nil {
58                 return fmt.Errorf("dialing tracker: %w", err)
59         }
60         defer c.Close()
61         tc.Logger.WithDefaultLevel(log.Debug).Printf("dialed tracker %q", tc.Url)
62         tc.wsConn = c
63         go func() {
64                 err := tc.announce(tracker.Started)
65                 if err != nil {
66                         tc.Logger.WithDefaultLevel(log.Error).Printf("error in initial announce: %v", err)
67                 }
68         }()
69         err = tc.trackerReadLoop(tc.wsConn)
70         tc.lock.Lock()
71         tc.closeUnusedOffers()
72         tc.lock.Unlock()
73         return err
74 }
75
76 func (tc *TrackerClient) Run() error {
77         for {
78                 err := tc.doWebsocket()
79                 tc.Logger.Printf("websocket instance ended: %v", err)
80                 time.Sleep(time.Minute)
81         }
82 }
83
84 func (tc *TrackerClient) closeUnusedOffers() {
85         for _, offer := range tc.outboundOffers {
86                 offer.peerConnection.Close()
87         }
88         tc.outboundOffers = nil
89 }
90
91 func (tc *TrackerClient) announce(event tracker.AnnounceEvent) error {
92         var randOfferId [20]byte
93         _, err := rand.Read(randOfferId[:])
94         if err != nil {
95                 return fmt.Errorf("generating offer_id bytes: %w", err)
96         }
97         offerIDBinary := binaryToJsonString(randOfferId[:])
98
99         pc, dc, offer, err := newOffer()
100         if err != nil {
101                 return fmt.Errorf("creating offer: %w", err)
102         }
103
104         request := tc.GetAnnounceRequest(event)
105
106         req := AnnounceRequest{
107                 Numwant:    1, // If higher we need to create equal amount of offers.
108                 Uploaded:   request.Uploaded,
109                 Downloaded: request.Downloaded,
110                 Left:       request.Left,
111                 Event:      request.Event.String(),
112                 Action:     "announce",
113                 InfoHash:   tc.infoHashBinary(),
114                 PeerID:     tc.peerIdBinary(),
115                 Offers: []Offer{{
116                         OfferID: offerIDBinary,
117                         Offer:   offer,
118                 }},
119         }
120
121         data, err := json.Marshal(req)
122         if err != nil {
123                 return fmt.Errorf("marshalling request: %w", err)
124         }
125
126         tc.lock.Lock()
127         defer tc.lock.Unlock()
128
129         err = tc.wsConn.WriteMessage(websocket.TextMessage, data)
130         if err != nil {
131                 pc.Close()
132                 return fmt.Errorf("write AnnounceRequest: %w", err)
133         }
134         if tc.outboundOffers == nil {
135                 tc.outboundOffers = make(map[string]outboundOffer)
136         }
137         tc.outboundOffers[offerIDBinary] = outboundOffer{
138                 peerConnection: pc,
139                 dataChannel:    dc,
140                 originalOffer:  offer,
141         }
142         return nil
143 }
144
145 func (tc *TrackerClient) trackerReadLoop(tracker *websocket.Conn) error {
146         for {
147                 _, message, err := tracker.ReadMessage()
148                 if err != nil {
149                         return fmt.Errorf("read message error: %w", err)
150                 }
151                 tc.Logger.WithDefaultLevel(log.Debug).Printf("received message from tracker: %q", message)
152
153                 var ar AnnounceResponse
154                 if err := json.Unmarshal(message, &ar); err != nil {
155                         tc.Logger.WithDefaultLevel(log.Warning).Printf("error unmarshalling announce response: %v", err)
156                         continue
157                 }
158                 if ar.InfoHash != tc.infoHashBinary() {
159                         tc.Logger.Printf("announce response for different hash: expected %q got %q", tc.infoHashBinary(), ar.InfoHash)
160                         continue
161                 }
162                 switch {
163                 case ar.Offer != nil:
164                         answer, err := getAnswerForOffer(*ar.Offer, tc.OnConn, ar.OfferID)
165                         if err != nil {
166                                 return fmt.Errorf("write AnnounceResponse: %w", err)
167                         }
168
169                         req := AnnounceResponse{
170                                 Action:   "announce",
171                                 InfoHash: tc.infoHashBinary(),
172                                 PeerID:   tc.peerIdBinary(),
173                                 ToPeerID: ar.PeerID,
174                                 Answer:   &answer,
175                                 OfferID:  ar.OfferID,
176                         }
177                         data, err := json.Marshal(req)
178                         if err != nil {
179                                 return fmt.Errorf("failed to marshal request: %w", err)
180                         }
181
182                         tc.lock.Lock()
183                         err = tracker.WriteMessage(websocket.TextMessage, data)
184                         if err != nil {
185                                 return fmt.Errorf("write AnnounceResponse: %w", err)
186                                 tc.lock.Unlock()
187                         }
188                         tc.lock.Unlock()
189                 case ar.Answer != nil:
190                         tc.handleAnswer(ar.OfferID, *ar.Answer)
191                 }
192         }
193 }
194
195 func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescription) {
196         tc.lock.Lock()
197         defer tc.lock.Unlock()
198         offer, ok := tc.outboundOffers[offerId]
199         if !ok {
200                 tc.Logger.WithDefaultLevel(log.Warning).Printf("could not find offer for id %q", offerId)
201                 return
202         }
203         tc.Logger.Printf("offer %q got answer %v", offerId, answer)
204         err := offer.setAnswer(answer, func(dc datachannel.ReadWriteCloser) {
205                 tc.OnConn(dc, DataChannelContext{
206                         Local:        offer.originalOffer,
207                         Remote:       answer,
208                         OfferId:      offerId,
209                         LocalOffered: true,
210                 })
211         })
212         if err != nil {
213                 tc.Logger.WithDefaultLevel(log.Warning).Printf("error using outbound offer answer: %v", err)
214                 return
215         }
216         delete(tc.outboundOffers, offerId)
217         go tc.announce(tracker.None)
218 }