]> Sergey Matveev's repositories - btrtrc.git/commitdiff
webtorrent: Create data channel earlier per webrtc examples
authorMatt Joiner <anacrolix@gmail.com>
Mon, 11 Jul 2022 01:39:54 +0000 (11:39 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Tue, 12 Jul 2022 06:15:50 +0000 (16:15 +1000)
otel.go [new file with mode: 0644]
webtorrent/otel.go
webtorrent/tracker-client.go
webtorrent/transport.go
webtorrent/transport_test.go [new file with mode: 0644]

diff --git a/otel.go b/otel.go
new file mode 100644 (file)
index 0000000..5dddd6a
--- /dev/null
+++ b/otel.go
@@ -0,0 +1,3 @@
+package torrent
+
+const tracerName = "anacrolix.torrent"
index d939c3f83401d7ac6c36cf039d2d54d8079fdc47..2c09964219914df4c2d0493c93b2089799b67854 100644 (file)
@@ -1,22 +1,6 @@
 package webtorrent
 
-import (
-       "context"
-       "github.com/pion/webrtc/v3"
-       "go.opentelemetry.io/otel"
-       "go.opentelemetry.io/otel/trace"
-)
-
 const (
        tracerName        = "anacrolix.torrent.webtorrent"
        webrtcConnTypeKey = "webtorrent.webrtc.conn.type"
 )
-
-func dataChannelStarted(peerConnectionCtx context.Context, dc *webrtc.DataChannel) (dataChannelCtx context.Context, span trace.Span) {
-       trace.SpanFromContext(peerConnectionCtx).AddEvent("starting data channel")
-       dataChannelCtx, span = otel.Tracer(tracerName).Start(peerConnectionCtx, "DataChannel")
-       dc.OnClose(func() {
-               span.End()
-       })
-       return
-}
index 95d87ff45af7d3855441bc4e53efe47aba55d785..d65dcab4d40f6359dcca145046a4f533ec0d98fb 100644 (file)
@@ -5,7 +5,6 @@ import (
        "crypto/rand"
        "encoding/json"
        "fmt"
-       "go.opentelemetry.io/otel/codes"
        "go.opentelemetry.io/otel/trace"
        "sync"
        "time"
@@ -57,18 +56,17 @@ type outboundOffer struct {
        originalOffer  webrtc.SessionDescription
        peerConnection *wrappedPeerConnection
        infoHash       [20]byte
+       dataChannel    *webrtc.DataChannel
 }
 
 type DataChannelContext struct {
-       // Can these be obtained by just calling the relevant methods on peerConnection?
-       Local, Remote webrtc.SessionDescription
-       OfferId       string
-       LocalOffered  bool
-       InfoHash      [20]byte
+       OfferId      string
+       LocalOffered bool
+       InfoHash     [20]byte
        // This is private as some methods might not be appropriate with data channel context.
        peerConnection *wrappedPeerConnection
-       span           trace.Span
-       ctx            context.Context
+       Span           trace.Span
+       Context        context.Context
 }
 
 func (me *DataChannelContext) GetSelectedIceCandidatePair() (*webrtc.ICECandidatePair, error) {
@@ -211,7 +209,7 @@ func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte
        }
        offerIDBinary := binaryToJsonString(randOfferId[:])
 
-       pc, offer, err := newOffer(tc.Logger)
+       pc, dc, offer, err := tc.newOffer(tc.Logger, offerIDBinary, infoHash)
        if err != nil {
                return fmt.Errorf("creating offer: %w", err)
        }
@@ -257,6 +255,7 @@ func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte
                peerConnection: pc,
                originalOffer:  offer,
                infoHash:       infoHash,
+               dataChannel:    dc,
        }
        return nil
 }
@@ -291,30 +290,43 @@ func (tc *TrackerClient) trackerReadLoop(tracker *websocket.Conn) error {
                                tc.Logger.WithDefaultLevel(log.Warning).Printf("error decoding info_hash in offer: %v", err)
                                break
                        }
-                       tc.handleOffer(*ar.Offer, ar.OfferID, ih, ar.PeerID)
+                       err = tc.handleOffer(offerContext{
+                               SessDesc: *ar.Offer,
+                               Id:       ar.OfferID,
+                               InfoHash: ih,
+                       }, ar.PeerID)
+                       if err != nil {
+                               tc.Logger.Levelf(log.Error, "handling offer for infohash %x: %v", ih, err)
+                       }
                case ar.Answer != nil:
                        tc.handleAnswer(ar.OfferID, *ar.Answer)
+               default:
+                       tc.Logger.Levelf(log.Warning, "unhandled announce response %q", message)
                }
        }
 }
 
+type offerContext struct {
+       SessDesc webrtc.SessionDescription
+       Id       string
+       InfoHash [20]byte
+}
+
 func (tc *TrackerClient) handleOffer(
-       offer webrtc.SessionDescription,
-       offerId string,
-       infoHash [20]byte,
+       offerContext offerContext,
        peerId string,
 ) error {
-       peerConnection, answer, err := newAnsweringPeerConnection(tc.Logger, offer)
+       peerConnection, answer, err := tc.newAnsweringPeerConnection(offerContext)
        if err != nil {
-               return fmt.Errorf("write AnnounceResponse: %w", err)
+               return fmt.Errorf("creating answering peer connection: %w", err)
        }
        response := AnnounceResponse{
                Action:   "announce",
-               InfoHash: binaryToJsonString(infoHash[:]),
+               InfoHash: binaryToJsonString(offerContext.InfoHash[:]),
                PeerID:   tc.peerIdBinary(),
                ToPeerID: peerId,
                Answer:   &answer,
-               OfferID:  offerId,
+               OfferID:  offerContext.Id,
        }
        data, err := json.Marshal(response)
        if err != nil {
@@ -327,31 +339,6 @@ func (tc *TrackerClient) handleOffer(
                peerConnection.Close()
                return fmt.Errorf("writing response: %w", err)
        }
-       timer := time.AfterFunc(30*time.Second, func() {
-               peerConnection.span.SetStatus(codes.Error, "answer timeout")
-               metrics.Add("answering peer connections timed out", 1)
-               peerConnection.Close()
-       })
-       peerConnection.OnDataChannel(func(d *webrtc.DataChannel) {
-               ctx, span := dataChannelStarted(peerConnection.ctx, d)
-               setDataChannelOnOpen(ctx, d, peerConnection, func(dc datachannel.ReadWriteCloser) {
-                       timer.Stop()
-                       metrics.Add("answering peer connection conversions", 1)
-                       tc.mu.Lock()
-                       tc.stats.ConvertedInboundConns++
-                       tc.mu.Unlock()
-                       tc.OnConn(dc, DataChannelContext{
-                               Local:          answer,
-                               Remote:         offer,
-                               OfferId:        offerId,
-                               LocalOffered:   false,
-                               InfoHash:       infoHash,
-                               peerConnection: peerConnection,
-                               ctx:            ctx,
-                               span:           span,
-                       })
-               })
-       })
        return nil
 }
 
@@ -365,44 +352,13 @@ func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescr
        }
        // tc.Logger.WithDefaultLevel(log.Debug).Printf("offer %q got answer %v", offerId, answer)
        metrics.Add("outbound offers answered", 1)
-       // Why do we create the data channel before setting the remote description? Are we trying to avoid the peer
-       // initiating?
-       dataChannel, err := offer.peerConnection.CreateDataChannel("webrtc-datachannel", nil)
-       if err != nil {
-               err = fmt.Errorf("creating data channel: %w", err)
-               tc.Logger.LevelPrint(log.Error, err)
-               offer.peerConnection.span.RecordError(err)
-               offer.peerConnection.Close()
-               goto deleteOffer
-       }
-       {
-               ctx, span := dataChannelStarted(offer.peerConnection.ctx, dataChannel)
-               setDataChannelOnOpen(ctx, dataChannel, offer.peerConnection, func(dc datachannel.ReadWriteCloser) {
-                       metrics.Add("outbound offers answered with datachannel", 1)
-                       tc.mu.Lock()
-                       tc.stats.ConvertedOutboundConns++
-                       tc.mu.Unlock()
-                       tc.OnConn(dc, DataChannelContext{
-                               Local:          offer.originalOffer,
-                               Remote:         answer,
-                               OfferId:        offerId,
-                               LocalOffered:   true,
-                               InfoHash:       offer.infoHash,
-                               peerConnection: offer.peerConnection,
-                               ctx:            ctx,
-                               span:           span,
-                       })
-               })
-       }
-       err = offer.peerConnection.SetRemoteDescription(answer)
+       err := offer.peerConnection.SetRemoteDescription(answer)
        if err != nil {
                err = fmt.Errorf("using outbound offer answer: %w", err)
                offer.peerConnection.span.RecordError(err)
-               dataChannel.Close()
-               tc.Logger.WithDefaultLevel(log.Error).Print(err)
+               tc.Logger.LevelPrint(log.Error, err)
                return
        }
-deleteOffer:
        delete(tc.outboundOffers, offerId)
        go tc.Announce(tracker.None, offer.infoHash)
 }
index 561a6b9fb5c98783dc5bd29bc629c64b49227660..da56e40d1a2ef1d9ec34d8b0ca780c489d139b3b 100644 (file)
@@ -4,17 +4,21 @@ import (
        "context"
        "expvar"
        "fmt"
+       "github.com/anacrolix/log"
+       "github.com/anacrolix/missinggo/v2/pproffd"
+       "github.com/pion/datachannel"
+       "github.com/pion/webrtc/v3"
+       "go.opentelemetry.io/otel"
        "go.opentelemetry.io/otel/attribute"
        "go.opentelemetry.io/otel/codes"
        "go.opentelemetry.io/otel/trace"
        "io"
        "sync"
+       "time"
+)
 
-       "github.com/anacrolix/log"
-       "github.com/anacrolix/missinggo/v2/pproffd"
-       "github.com/pion/datachannel"
-       "github.com/pion/webrtc/v3"
-       "go.opentelemetry.io/otel"
+const (
+       dataChannelLabel = "webrtc-datachannel"
 )
 
 var (
@@ -82,11 +86,15 @@ func setAndGatherLocalDescription(peerConnection *wrappedPeerConnection, sdp web
        return *peerConnection.LocalDescription(), nil
 }
 
-// newOffer creates a transport and returns a WebRTC offer to be announced
-func newOffer(
+// newOffer creates a transport and returns a WebRTC offer to be announced. See
+// https://github.com/pion/webrtc/blob/master/examples/data-channels/jsfiddle/main.go for what this is modelled on.
+func (tc *TrackerClient) newOffer(
        logger log.Logger,
+       offerId string,
+       infoHash [20]byte,
 ) (
        peerConnection *wrappedPeerConnection,
+       dataChannel *webrtc.DataChannel,
        offer webrtc.SessionDescription,
        err error,
 ) {
@@ -97,6 +105,26 @@ func newOffer(
 
        peerConnection.span.SetAttributes(attribute.String(webrtcConnTypeKey, "offer"))
 
+       dataChannel, err = peerConnection.CreateDataChannel(dataChannelLabel, nil)
+       if err != nil {
+               err = fmt.Errorf("creating data channel: %w", err)
+               peerConnection.Close()
+       }
+       initDataChannel(dataChannel, peerConnection, func(dc datachannel.ReadWriteCloser, dcCtx context.Context, dcSpan trace.Span) {
+               metrics.Add("outbound offers answered with datachannel", 1)
+               tc.mu.Lock()
+               tc.stats.ConvertedOutboundConns++
+               tc.mu.Unlock()
+               tc.OnConn(dc, DataChannelContext{
+                       OfferId:        offerId,
+                       LocalOffered:   true,
+                       InfoHash:       infoHash,
+                       peerConnection: peerConnection,
+                       Context:        dcCtx,
+                       Span:           dcSpan,
+               })
+       })
+
        offer, err = peerConnection.CreateOffer(nil)
        if err != nil {
                peerConnection.Close()
@@ -110,38 +138,62 @@ func newOffer(
        return
 }
 
-func initAnsweringPeerConnection(
-       peerConnection *wrappedPeerConnection,
-       offer webrtc.SessionDescription,
+type onDetachedDataChannelFunc func(detached datachannel.ReadWriteCloser, ctx context.Context, span trace.Span)
+
+func (tc *TrackerClient) initAnsweringPeerConnection(
+       peerConn *wrappedPeerConnection,
+       offerContext offerContext,
 ) (answer webrtc.SessionDescription, err error) {
-       peerConnection.span.SetAttributes(attribute.String(webrtcConnTypeKey, "answer"))
+       peerConn.span.SetAttributes(attribute.String(webrtcConnTypeKey, "answer"))
 
-       err = peerConnection.SetRemoteDescription(offer)
+       timer := time.AfterFunc(30*time.Second, func() {
+               peerConn.span.SetStatus(codes.Error, "answer timeout")
+               metrics.Add("answering peer connections timed out", 1)
+               peerConn.Close()
+       })
+       peerConn.OnDataChannel(func(d *webrtc.DataChannel) {
+               initDataChannel(d, peerConn, func(detached datachannel.ReadWriteCloser, ctx context.Context, span trace.Span) {
+                       timer.Stop()
+                       metrics.Add("answering peer connection conversions", 1)
+                       tc.mu.Lock()
+                       tc.stats.ConvertedInboundConns++
+                       tc.mu.Unlock()
+                       tc.OnConn(detached, DataChannelContext{
+                               OfferId:        offerContext.Id,
+                               LocalOffered:   false,
+                               InfoHash:       offerContext.InfoHash,
+                               peerConnection: peerConn,
+                               Context:        ctx,
+                               Span:           span,
+                       })
+               })
+       })
+
+       err = peerConn.SetRemoteDescription(offerContext.SessDesc)
        if err != nil {
                return
        }
-       answer, err = peerConnection.CreateAnswer(nil)
+       answer, err = peerConn.CreateAnswer(nil)
        if err != nil {
                return
        }
 
-       answer, err = setAndGatherLocalDescription(peerConnection, answer)
+       answer, err = setAndGatherLocalDescription(peerConn, answer)
        return
 }
 
 // newAnsweringPeerConnection creates a transport from a WebRTC offer and returns a WebRTC answer to be announced.
-func newAnsweringPeerConnection(
-       logger log.Logger,
-       offer webrtc.SessionDescription,
+func (tc *TrackerClient) newAnsweringPeerConnection(
+       offerContext offerContext,
 ) (
        peerConn *wrappedPeerConnection, answer webrtc.SessionDescription, err error,
 ) {
-       peerConn, err = newPeerConnection(logger)
+       peerConn, err = newPeerConnection(tc.Logger)
        if err != nil {
                err = fmt.Errorf("failed to create new connection: %w", err)
                return
        }
-       answer, err = initAnsweringPeerConnection(peerConn, offer)
+       answer, err = tc.initAnsweringPeerConnection(peerConn, offerContext)
        if err != nil {
                peerConn.span.RecordError(err)
                peerConn.Close()
@@ -162,22 +214,25 @@ func (me ioCloserFunc) Close() error {
        return me()
 }
 
-func setDataChannelOnOpen(
-       ctx context.Context,
+func initDataChannel(
        dc *webrtc.DataChannel,
        pc *wrappedPeerConnection,
-       onOpen func(closer datachannel.ReadWriteCloser),
+       onOpen onDetachedDataChannelFunc,
 ) {
+       var span trace.Span
+       dc.OnClose(func() {
+               span.End()
+       })
        dc.OnOpen(func() {
-               dataChannelSpan := trace.SpanFromContext(ctx)
-               dataChannelSpan.AddEvent("opened")
+               pc.span.AddEvent("data channel opened")
+               var ctx context.Context
+               ctx, span = otel.Tracer(tracerName).Start(pc.ctx, "DataChannel")
                raw, err := dc.Detach()
                if err != nil {
                        // This shouldn't happen if the API is configured correctly, and we call from OnOpen.
                        panic(err)
                }
-               //dc.OnClose()
-               onOpen(hookDataChannelCloser(raw, pc, dataChannelSpan))
+               onOpen(hookDataChannelCloser(raw, pc, span), ctx, span)
        })
 }
 
diff --git a/webtorrent/transport_test.go b/webtorrent/transport_test.go
new file mode 100644 (file)
index 0000000..b993487
--- /dev/null
@@ -0,0 +1,34 @@
+package webtorrent
+
+import (
+       "github.com/anacrolix/log"
+       qt "github.com/frankban/quicktest"
+       "github.com/pion/webrtc/v3"
+       "testing"
+)
+
+func TestClosingPeerConnectionDoesNotCloseUnopenedDataChannel(t *testing.T) {
+       c := qt.New(t)
+       var tc TrackerClient
+       pc, dc, _, err := tc.newOffer(log.Default, "", [20]byte{})
+       c.Assert(err, qt.IsNil)
+       defer pc.Close()
+       defer dc.Close()
+       peerConnClosed := make(chan struct{})
+       pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
+               if state == webrtc.PeerConnectionStateClosed {
+                       close(peerConnClosed)
+               }
+       })
+       dc.OnClose(func() {
+               // This should not be called because the DataChannel is never opened.
+               t.Fatal("DataChannel.OnClose handler called")
+       })
+       t.Logf("data channel ready state before close: %v", dc.ReadyState())
+       dc.OnError(func(err error) {
+               t.Logf("data channel error: %v", err)
+       })
+       pc.Close()
+       c.Check(dc.ReadyState(), qt.Equals, webrtc.DataChannelStateClosed)
+       <-peerConnClosed
+}