]> Sergey Matveev's repositories - btrtrc.git/blobdiff - webtorrent/tracker-client.go
Add WebRTC ICE servers config (#824)
[btrtrc.git] / webtorrent / tracker-client.go
index 9f02c7e785f38e7767e86e73f3304eb61a322ef8..bc9dab312e2d64acf843f46d423b424e1514ebad 100644 (file)
@@ -1,18 +1,22 @@
 package webtorrent
 
 import (
+       "context"
        "crypto/rand"
        "encoding/json"
        "fmt"
+       "net/http"
        "sync"
        "time"
 
+       g "github.com/anacrolix/generics"
        "github.com/anacrolix/log"
-
-       "github.com/anacrolix/torrent/tracker"
        "github.com/gorilla/websocket"
        "github.com/pion/datachannel"
        "github.com/pion/webrtc/v3"
+       "go.opentelemetry.io/otel/trace"
+
+       "github.com/anacrolix/torrent/tracker"
 )
 
 type TrackerClientStats struct {
@@ -32,11 +36,14 @@ type TrackerClient struct {
 
        mu             sync.Mutex
        cond           sync.Cond
-       outboundOffers map[string]outboundOffer // OfferID to outboundOffer
+       outboundOffers map[string]outboundOfferValue // OfferID to outboundOfferValue
        wsConn         *websocket.Conn
        closed         bool
        stats          TrackerClientStats
        pingTicker     *time.Ticker
+
+       WebsocketTrackerHttpHeader func() http.Header
+       ICEServers                 []string
 }
 
 func (me *TrackerClient) Stats() TrackerClientStats {
@@ -49,19 +56,31 @@ func (me *TrackerClient) peerIdBinary() string {
        return binaryToJsonString(me.PeerId[:])
 }
 
-// outboundOffer represents an outstanding offer.
 type outboundOffer struct {
+       offerId string
+       outboundOfferValue
+}
+
+// outboundOfferValue represents an outstanding offer.
+type outboundOfferValue struct {
        originalOffer  webrtc.SessionDescription
        peerConnection *wrappedPeerConnection
-       dataChannel    *webrtc.DataChannel
        infoHash       [20]byte
+       dataChannel    *webrtc.DataChannel
 }
 
 type DataChannelContext struct {
-       Local, Remote webrtc.SessionDescription
-       OfferId       string
-       LocalOffered  bool
-       InfoHash      [20]byte
+       OfferId      string
+       LocalOffered bool
+       InfoHash     [20]byte
+       // This is private as some methods might not be appropriate with data channel context.
+       peerConnection *wrappedPeerConnection
+       Span           trace.Span
+       Context        context.Context
+}
+
+func (me *DataChannelContext) GetSelectedIceCandidatePair() (*webrtc.ICECandidatePair, error) {
+       return me.peerConnection.SCTP().Transport().ICETransport().GetSelectedCandidatePair()
 }
 
 type onDataChannelOpen func(_ datachannel.ReadWriteCloser, dcc DataChannelContext)
@@ -71,7 +90,13 @@ func (tc *TrackerClient) doWebsocket() error {
        tc.mu.Lock()
        tc.stats.Dials++
        tc.mu.Unlock()
-       c, _, err := tc.Dialer.Dial(tc.Url, nil)
+
+       var header http.Header
+       if tc.WebsocketTrackerHttpHeader != nil {
+               header = tc.WebsocketTrackerHttpHeader()
+       }
+
+       c, _, err := tc.Dialer.Dial(tc.Url, header)
        if err != nil {
                return fmt.Errorf("dialing tracker: %w", err)
        }
@@ -176,12 +201,27 @@ func (tc *TrackerClient) announceOffers() {
 func (tc *TrackerClient) closeUnusedOffers() {
        for _, offer := range tc.outboundOffers {
                offer.peerConnection.Close()
+               offer.dataChannel.Close()
        }
        tc.outboundOffers = nil
 }
 
+func (tc *TrackerClient) CloseOffersForInfohash(infoHash [20]byte) {
+       tc.mu.Lock()
+       defer tc.mu.Unlock()
+       for key, offer := range tc.outboundOffers {
+               if offer.infoHash == infoHash {
+                       offer.peerConnection.Close()
+                       delete(tc.outboundOffers, key)
+               }
+       }
+}
+
 func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte) error {
        metrics.Add("outbound announces", 1)
+       if event == tracker.Stopped {
+               return tc.announce(event, infoHash, nil)
+       }
        var randOfferId [20]byte
        _, err := rand.Read(randOfferId[:])
        if err != nil {
@@ -189,19 +229,37 @@ func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte
        }
        offerIDBinary := binaryToJsonString(randOfferId[:])
 
-       pc, dc, offer, err := newOffer()
+       pc, dc, offer, err := tc.newOffer(tc.Logger, offerIDBinary, infoHash)
        if err != nil {
                return fmt.Errorf("creating offer: %w", err)
        }
 
-       request, err := tc.GetAnnounceRequest(event, infoHash)
+       err = tc.announce(event, infoHash, []outboundOffer{
+               {
+                       offerId: offerIDBinary,
+                       outboundOfferValue: outboundOfferValue{
+                               originalOffer:  offer,
+                               peerConnection: pc,
+                               infoHash:       infoHash,
+                               dataChannel:    dc,
+                       },
+               },
+       })
        if err != nil {
+               dc.Close()
                pc.Close()
+       }
+       return err
+}
+
+func (tc *TrackerClient) announce(event tracker.AnnounceEvent, infoHash [20]byte, offers []outboundOffer) error {
+       request, err := tc.GetAnnounceRequest(event, infoHash)
+       if err != nil {
                return fmt.Errorf("getting announce parameters: %w", err)
        }
 
        req := AnnounceRequest{
-               Numwant:    1, // If higher we need to create equal amount of offers.
+               Numwant:    len(offers),
                Uploaded:   request.Uploaded,
                Downloaded: request.Downloaded,
                Left:       request.Left,
@@ -209,15 +267,16 @@ func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte
                Action:     "announce",
                InfoHash:   binaryToJsonString(infoHash[:]),
                PeerID:     tc.peerIdBinary(),
-               Offers: []Offer{{
-                       OfferID: offerIDBinary,
-                       Offer:   offer,
-               }},
+       }
+       for _, offer := range offers {
+               req.Offers = append(req.Offers, Offer{
+                       OfferID: offer.offerId,
+                       Offer:   offer.originalOffer,
+               })
        }
 
        data, err := json.Marshal(req)
        if err != nil {
-               pc.Close()
                return fmt.Errorf("marshalling request: %w", err)
        }
 
@@ -225,17 +284,10 @@ func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte
        defer tc.mu.Unlock()
        err = tc.writeMessage(data)
        if err != nil {
-               pc.Close()
                return fmt.Errorf("write AnnounceRequest: %w", err)
        }
-       if tc.outboundOffers == nil {
-               tc.outboundOffers = make(map[string]outboundOffer)
-       }
-       tc.outboundOffers[offerIDBinary] = outboundOffer{
-               peerConnection: pc,
-               dataChannel:    dc,
-               originalOffer:  offer,
-               infoHash:       infoHash,
+       for _, offer := range offers {
+               g.MakeMapIfNilAndSet(&tc.outboundOffers, offer.offerId, offer.outboundOfferValue)
        }
        return nil
 }
@@ -270,30 +322,43 @@ func (tc *TrackerClient) trackerReadLoop(tracker *websocket.Conn) error {
                                tc.Logger.WithDefaultLevel(log.Warning).Printf("error decoding info_hash in offer: %v", err)
                                break
                        }
-                       tc.handleOffer(*ar.Offer, ar.OfferID, ih, ar.PeerID)
+                       err = tc.handleOffer(offerContext{
+                               SessDesc: *ar.Offer,
+                               Id:       ar.OfferID,
+                               InfoHash: ih,
+                       }, ar.PeerID)
+                       if err != nil {
+                               tc.Logger.Levelf(log.Error, "handling offer for infohash %x: %v", ih, err)
+                       }
                case ar.Answer != nil:
                        tc.handleAnswer(ar.OfferID, *ar.Answer)
+               default:
+                       tc.Logger.Levelf(log.Warning, "unhandled announce response %q", message)
                }
        }
 }
 
+type offerContext struct {
+       SessDesc webrtc.SessionDescription
+       Id       string
+       InfoHash [20]byte
+}
+
 func (tc *TrackerClient) handleOffer(
-       offer webrtc.SessionDescription,
-       offerId string,
-       infoHash [20]byte,
+       offerContext offerContext,
        peerId string,
 ) error {
-       peerConnection, answer, err := newAnsweringPeerConnection(offer)
+       peerConnection, answer, err := tc.newAnsweringPeerConnection(offerContext)
        if err != nil {
-               return fmt.Errorf("write AnnounceResponse: %w", err)
+               return fmt.Errorf("creating answering peer connection: %w", err)
        }
        response := AnnounceResponse{
                Action:   "announce",
-               InfoHash: binaryToJsonString(infoHash[:]),
+               InfoHash: binaryToJsonString(offerContext.InfoHash[:]),
                PeerID:   tc.peerIdBinary(),
                ToPeerID: peerId,
                Answer:   &answer,
-               OfferID:  offerId,
+               OfferID:  offerContext.Id,
        }
        data, err := json.Marshal(response)
        if err != nil {
@@ -306,26 +371,6 @@ func (tc *TrackerClient) handleOffer(
                peerConnection.Close()
                return fmt.Errorf("writing response: %w", err)
        }
-       timer := time.AfterFunc(30*time.Second, func() {
-               metrics.Add("answering peer connections timed out", 1)
-               peerConnection.Close()
-       })
-       peerConnection.OnDataChannel(func(d *webrtc.DataChannel) {
-               setDataChannelOnOpen(d, peerConnection, func(dc datachannel.ReadWriteCloser) {
-                       timer.Stop()
-                       metrics.Add("answering peer connection conversions", 1)
-                       tc.mu.Lock()
-                       tc.stats.ConvertedInboundConns++
-                       tc.mu.Unlock()
-                       tc.OnConn(dc, DataChannelContext{
-                               Local:        answer,
-                               Remote:       offer,
-                               OfferId:      offerId,
-                               LocalOffered: false,
-                               InfoHash:     infoHash,
-                       })
-               })
-       })
        return nil
 }
 
@@ -339,21 +384,11 @@ func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescr
        }
        // tc.Logger.WithDefaultLevel(log.Debug).Printf("offer %q got answer %v", offerId, answer)
        metrics.Add("outbound offers answered", 1)
-       err := offer.setAnswer(answer, func(dc datachannel.ReadWriteCloser) {
-               metrics.Add("outbound offers answered with datachannel", 1)
-               tc.mu.Lock()
-               tc.stats.ConvertedOutboundConns++
-               tc.mu.Unlock()
-               tc.OnConn(dc, DataChannelContext{
-                       Local:        offer.originalOffer,
-                       Remote:       answer,
-                       OfferId:      offerId,
-                       LocalOffered: true,
-                       InfoHash:     offer.infoHash,
-               })
-       })
+       err := offer.peerConnection.SetRemoteDescription(answer)
        if err != nil {
-               tc.Logger.WithDefaultLevel(log.Warning).Printf("error using outbound offer answer: %v", err)
+               err = fmt.Errorf("using outbound offer answer: %w", err)
+               offer.peerConnection.span.RecordError(err)
+               tc.Logger.LevelPrint(log.Error, err)
                return
        }
        delete(tc.outboundOffers, offerId)