client.go | 5 +++-- torrent.go | 23 ++++++++++++++++++++++- webrtc.go | 43 +++++++++++++++++++++++++++++++++++++++++++ webtorrent/client.go | 34 ++++++++++++---------------------- diff --git a/client.go b/client.go index 8a94e021756b5ff5c8a75a6af77043b169071cc8..43a4c46fd5431897f53637e3d8296ed466b20dd1 100644 --- a/client.go +++ b/client.go @@ -621,8 +621,8 @@ delete(t.halfOpen, addr) t.openNewConns() } -// Performs initiator handshakes and returns a connection. Returns nil -// *connection if no connection for valid reasons. +// Performs initiator handshakes and returns a connection. Returns nil *connection if no connection +// for valid reasons. func (cl *Client) handshakesConnection(ctx context.Context, nc net.Conn, t *Torrent, encryptHeader bool, remoteAddr net.Addr, network string) (c *PeerConn, err error) { c = cl.newConnection(nc, true, remoteAddr, network) c.headerEncrypted = encryptHeader @@ -850,6 +850,7 @@ defer cl.unlock() cl.runHandshookConn(c, t) } +// Client lock must be held before entering this. func (cl *Client) runHandshookConn(c *PeerConn, t *Torrent) { c.setTorrent(t) if c.PeerID == cl.peerID { diff --git a/torrent.go b/torrent.go index ca8b06f55817b49bdbf9d48447a65a3acfc48835..de508e61be721cfd364e9f15294282cfb285edc7 100644 --- a/torrent.go +++ b/torrent.go @@ -2,6 +2,7 @@ package torrent import ( "container/heap" + "context" "crypto/sha1" "errors" "fmt" @@ -14,6 +15,7 @@ "time" "unsafe" "github.com/davecgh/go-spew/spew" + "github.com/pion/datachannel" "github.com/anacrolix/dht/v2" "github.com/anacrolix/log" @@ -1262,6 +1264,25 @@ } return true } +func (t *Torrent) onWebRtcConn( + c datachannel.ReadWriteCloser, + initiatedLocally bool, // Whether we offered first, or they did. +) { + defer c.Close() + pc, err := t.cl.handshakesConnection(context.Background(), webrtcNetConn{c}, t, false, nil, "webrtc") + if err != nil { + t.logger.Printf("error in handshaking webrtc connection: %v", err) + } + if initiatedLocally { + pc.Discovery = PeerSourceTracker + } else { + pc.Discovery = PeerSourceIncoming + } + t.cl.lock() + defer t.cl.unlock() + t.cl.runHandshookConn(pc, t) +} + func (t *Torrent) startScrapingTracker(_url string) { if _url == "" { return @@ -1288,7 +1309,7 @@ } sl := func() torrentTrackerAnnouncer { switch u.Scheme { case "ws", "wss": - wst := websocketTracker{*u, webtorrent.NewClient(t.cl.peerID, t.infoHash)} + wst := websocketTracker{*u, webtorrent.NewClient(t.cl.peerID, t.infoHash, t.onWebRtcConn)} go func() { err := wst.Client.Run(t.announceRequest(tracker.Started)) if err != nil { diff --git a/webrtc.go b/webrtc.go new file mode 100644 index 0000000000000000000000000000000000000000..d805b54bd3fd5864d9090b2d15ec756e2081f4a2 --- /dev/null +++ b/webrtc.go @@ -0,0 +1,43 @@ +package torrent + +import ( + "net" + "time" + + "github.com/pion/datachannel" +) + +type webrtcNetConn struct { + datachannel.ReadWriteCloser +} + +type webrtcNetAddr struct { +} + +func (webrtcNetAddr) Network() string { + return "webrtc" +} + +func (webrtcNetAddr) String() string { + return "" +} + +func (w webrtcNetConn) LocalAddr() net.Addr { + return webrtcNetAddr{} +} + +func (w webrtcNetConn) RemoteAddr() net.Addr { + return webrtcNetAddr{} +} + +func (w webrtcNetConn) SetDeadline(t time.Time) error { + return nil +} + +func (w webrtcNetConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (w webrtcNetConn) SetWriteDeadline(t time.Time) error { + return nil +} diff --git a/webtorrent/client.go b/webtorrent/client.go index 9b49e15df865bdbecfff8552a8df8868c036aaf7..d4dd604954caf6fc690cd8e0163196bb1c814ef8 100644 --- a/webtorrent/client.go +++ b/webtorrent/client.go @@ -3,7 +3,6 @@ import ( "encoding/json" "fmt" - "io" "sync" "github.com/anacrolix/log" @@ -26,6 +25,7 @@ peerIDBinary string infoHashBinary string offeredPeers map[string]Peer // OfferID to Peer tracker *websocket.Conn + onConn func(_ datachannel.ReadWriteCloser, initiatedLocally bool) } // Peer represents a remote peer @@ -42,11 +42,14 @@ } return string(seq) } -func NewClient(peerId, infoHash [20]byte) *Client { +type onDataChannelOpen func(_ datachannel.ReadWriteCloser, initiatedLocally bool) + +func NewClient(peerId, infoHash [20]byte, onConn onDataChannelOpen) *Client { return &Client{ offeredPeers: make(map[string]Peer), peerIDBinary: binaryToJsonString(peerId[:]), infoHashBinary: binaryToJsonString(infoHash[:]), + onConn: onConn, } } @@ -134,7 +137,9 @@ continue } switch { case ar.Offer != nil: - t, answer, err := NewTransportFromOffer(*ar.Offer, c.handleDataChannel) + t, answer, err := NewTransportFromOffer(*ar.Offer, func(dc datachannel.ReadWriteCloser) { + c.onConn(dc, false) + }) if err != nil { return fmt.Errorf("write AnnounceResponse: %w", err) } @@ -170,29 +175,14 @@ if !ok { log.Printf("could not find peer for offer %q", ar.OfferID) continue } - log.Printf("offer %q got answer %q", ar.OfferID, ar.Answer) - err = peer.transport.SetAnswer(*ar.Answer, c.handleDataChannel) + log.Printf("offer %q got answer %v", ar.OfferID, *ar.Answer) + err = peer.transport.SetAnswer(*ar.Answer, func(dc datachannel.ReadWriteCloser) { + c.onConn(dc, true) + }) if err != nil { return fmt.Errorf("failed to sent answer: %v", err) } } - } -} - -func (c *Client) handleDataChannel(dc datachannel.ReadWriteCloser) { - go c.dcReadLoop(dc) - //go c.dcWriteLoop(dc) -} - -func (c *Client) dcReadLoop(d io.Reader) { - for { - buffer := make([]byte, 1024) - n, err := d.Read(buffer) - if err != nil { - log.Printf("Datachannel closed; Exit the readloop: %v", err) - } - - fmt.Printf("Message from DataChannel: %s\n", string(buffer[:n])) } }