From 0ab6d108be82c505b3ceac5e55bf25ab5d77dd6a Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Tue, 21 Apr 2020 18:08:43 +1000 Subject: [PATCH] Pool webtorrent tracker websockets at the Client level --- client.go | 51 ++++++++++- go.mod | 2 +- go.sum | 2 + torrent.go | 52 +++++------ webtorrent/tracker_client.go | 161 ++++++++++++++++++++++----------- webtorrent/tracker_protocol.go | 21 +++++ webtorrent/transport.go | 30 ++---- wstracker.go | 58 ++++++++++++ 8 files changed, 266 insertions(+), 111 deletions(-) diff --git a/client.go b/client.go index c4e81190..0226edfc 100644 --- a/client.go +++ b/client.go @@ -23,9 +23,12 @@ import ( "github.com/anacrolix/missinggo/slices" "github.com/anacrolix/missinggo/v2/pproffd" "github.com/anacrolix/sync" + "github.com/anacrolix/torrent/tracker" + "github.com/anacrolix/torrent/webtorrent" "github.com/davecgh/go-spew/spew" "github.com/dustin/go-humanize" "github.com/google/btree" + "github.com/pion/datachannel" "golang.org/x/time/rate" "golang.org/x/xerrors" @@ -73,6 +76,8 @@ type Client struct { acceptLimiter map[ipStr]int dialRateLimiter *rate.Limiter + + websocketTrackers websocketTrackers } type ipStr string @@ -241,6 +246,32 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) { } } + cl.websocketTrackers = websocketTrackers{ + PeerId: cl.peerID, + Logger: cl.logger.WithMap(func(msg log.Msg) log.Msg { + return msg.SetLevel(log.Critical) + }), + GetAnnounceRequest: func(event tracker.AnnounceEvent, infoHash [20]byte) tracker.AnnounceRequest { + cl.lock() + defer cl.unlock() + return cl.torrents[infoHash].announceRequest(event) + }, + OnConn: func(dc datachannel.ReadWriteCloser, dcc webtorrent.DataChannelContext) { + cl.lock() + defer cl.unlock() + t, ok := cl.torrents[dcc.InfoHash] + if !ok { + cl.logger.WithDefaultLevel(log.Warning).Printf( + "got webrtc conn for unloaded torrent with infohash %x", + dcc.InfoHash, + ) + dc.Close() + return + } + go t.onWebRtcConn(dc, dcc) + }, + } + return } @@ -627,9 +658,17 @@ func (cl *Client) noLongerHalfOpen(t *Torrent, addr string) { // Performs initiator handshakes and returns a connection. Returns nil *connection if no connection // for valid reasons. -func (cl *Client) handshakesConnection(ctx context.Context, nc net.Conn, t *Torrent, encryptHeader bool, remoteAddr net.Addr, - network, connString string, -) (c *PeerConn, err error) { +func (cl *Client) handshakesConnection( + ctx context.Context, + nc net.Conn, + t *Torrent, + encryptHeader bool, + remoteAddr net.Addr, + network, + connString string, +) ( + c *PeerConn, err error, +) { c = cl.newConnection(nc, true, remoteAddr, network, connString) c.headerEncrypted = encryptHeader ctx, cancel := context.WithTimeout(ctx, cl.config.HandshakesTimeout) @@ -881,7 +920,11 @@ func (cl *Client) runHandshookConn(c *PeerConn, t *Torrent) error { defer t.dropConnection(c) go c.writer(time.Minute) cl.sendInitialMessages(c, t) - return fmt.Errorf("main read loop: %w", c.mainReadLoop()) + err := c.mainReadLoop() + if err != nil { + return fmt.Errorf("main read loop: %w", err) + } + return nil } // See the order given in Transmission's tr_peerMsgsNew. diff --git a/go.mod b/go.mod index 92272424..300f2927 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/anacrolix/dht/v2 v2.6.1-0.20200416071723-3850fa1b802a github.com/anacrolix/envpprof v1.1.0 github.com/anacrolix/go-libutp v1.0.2 - github.com/anacrolix/log v0.6.1-0.20200416071330-f58a030e6149 + github.com/anacrolix/log v0.7.0 github.com/anacrolix/missinggo v1.2.1 github.com/anacrolix/missinggo/perf v1.0.0 github.com/anacrolix/missinggo/v2 v2.4.1-0.20200419051441-747d9d7544c6 diff --git a/go.sum b/go.sum index 3b637b4a..c921fcc0 100644 --- a/go.sum +++ b/go.sum @@ -61,6 +61,8 @@ github.com/anacrolix/log v0.6.0 h1:5y+wtTWoecbrAWWuoBCH7UuGFiD6q2jnQxrLK01RC+Q= github.com/anacrolix/log v0.6.0/go.mod h1:lWvLTqzAnCWPJA08T2HCstZi0L1y2Wyvm3FJgwU9jwU= github.com/anacrolix/log v0.6.1-0.20200416071330-f58a030e6149 h1:3cEyLU9ObAfTnBHCev8uuWGhbHfol8uTwyMRkLIpZGg= github.com/anacrolix/log v0.6.1-0.20200416071330-f58a030e6149/go.mod h1:s5yBP/j046fm9odtUTbHOfDUq/zh1W8OkPpJtnX0oQI= +github.com/anacrolix/log v0.7.0 h1:koGkC/K0LjIbrhLhwfpsfMuvu8nhvY7J4TmLVc1mAwE= +github.com/anacrolix/log v0.7.0/go.mod h1:s5yBP/j046fm9odtUTbHOfDUq/zh1W8OkPpJtnX0oQI= github.com/anacrolix/missinggo v0.0.0-20180522035225-b4a5853e62ff/go.mod h1:b0p+7cn+rWMIphK1gDH2hrDuwGOcbB6V4VXeSsEfHVk= github.com/anacrolix/missinggo v0.0.0-20180725070939-60ef2fbf63df/go.mod h1:kwGiTUTZ0+p4vAz3VbAI5a30t2YbvemcmspjKwrAz5s= github.com/anacrolix/missinggo v0.2.1-0.20190310234110-9fbdc9f242a8 h1:E2Xb2SBsVzHJ1tNMW9QcckYEQcyBKz1ee8qVjeVRWys= diff --git a/torrent.go b/torrent.go index 63952351..11fb8eba 100644 --- a/torrent.go +++ b/torrent.go @@ -1308,7 +1308,9 @@ func (t *Torrent) onWebRtcConn( t.cl.lock() defer t.cl.unlock() err = t.cl.runHandshookConn(pc, t) - t.logger.WithDefaultLevel(log.Critical).Printf("error running handshook webrtc conn: %v", err) + if err != nil { + t.logger.WithDefaultLevel(log.Critical).Printf("error running handshook webrtc conn: %v", err) + } } func (t *Torrent) logRunHandshookConn(pc *PeerConn, logAll bool, level log.Level) { @@ -1322,6 +1324,26 @@ func (t *Torrent) runHandshookConnLoggingErr(pc *PeerConn) { t.logRunHandshookConn(pc, false, log.Debug) } +func (t *Torrent) startWebsocketAnnouncer(u url.URL) torrentTrackerAnnouncer { + wtc, release := t.cl.websocketTrackers.Get(u.String()) + go func() { + <-t.closed.LockedChan(t.cl.locker()) + release() + }() + wst := websocketTracker{u, wtc} + go func() { + err := wtc.Announce(tracker.Started, t.infoHash) + if err != nil { + t.logger.WithDefaultLevel(log.Warning).Printf( + "error in initial announce to %q: %v", + u.String(), err, + ) + } + }() + return wst + +} + func (t *Torrent) startScrapingTracker(_url string) { if _url == "" { return @@ -1348,33 +1370,7 @@ func (t *Torrent) startScrapingTracker(_url string) { sl := func() torrentTrackerAnnouncer { switch u.Scheme { case "ws", "wss": - wst := websocketTracker{ - *u, - &webtorrent.TrackerClient{ - Url: u.String(), - GetAnnounceRequest: func(event tracker.AnnounceEvent) tracker.AnnounceRequest { - t.cl.lock() - defer t.cl.unlock() - return t.announceRequest(event) - }, - PeerId: t.cl.peerID, - InfoHash: t.infoHash, - OnConn: t.onWebRtcConn, - Logger: t.logger.WithText(func(m log.Msg) string { - return fmt.Sprintf("%q: %v", u.String(), m.Text()) - }).WithDefaultLevel(log.Debug), - }, - } - go func() { - err := wst.TrackerClient.Run() - if err != nil { - t.logger.WithDefaultLevel(log.Warning).Printf( - "error running websocket tracker announcer for %q: %v", - u.String(), err, - ) - } - }() - return wst + return t.startWebsocketAnnouncer(*u) } if u.Scheme == "udp4" && (t.cl.config.DisableIPv4Peers || t.cl.config.DisableIPv4) { return nil diff --git a/webtorrent/tracker_client.go b/webtorrent/tracker_client.go index 5c84a729..14b62ed2 100644 --- a/webtorrent/tracker_client.go +++ b/webtorrent/tracker_client.go @@ -18,36 +18,35 @@ import ( // Client represents the webtorrent client type TrackerClient struct { Url string - GetAnnounceRequest func(tracker.AnnounceEvent) tracker.AnnounceRequest + GetAnnounceRequest func(_ tracker.AnnounceEvent, infoHash [20]byte) tracker.AnnounceRequest PeerId [20]byte - InfoHash [20]byte OnConn onDataChannelOpen Logger log.Logger - lock sync.Mutex + mu sync.Mutex + cond sync.Cond outboundOffers map[string]outboundOffer // OfferID to outboundOffer wsConn *websocket.Conn + closed bool } func (me *TrackerClient) peerIdBinary() string { return binaryToJsonString(me.PeerId[:]) } -func (me *TrackerClient) infoHashBinary() string { - return binaryToJsonString(me.InfoHash[:]) -} - // outboundOffer represents an outstanding offer. type outboundOffer struct { originalOffer webrtc.SessionDescription peerConnection wrappedPeerConnection dataChannel *webrtc.DataChannel + infoHash [20]byte } type DataChannelContext struct { Local, Remote webrtc.SessionDescription OfferId string LocalOffered bool + InfoHash [20]byte } type onDataChannelOpen func(_ datachannel.ReadWriteCloser, dcc DataChannelContext) @@ -60,26 +59,41 @@ func (tc *TrackerClient) doWebsocket() error { } defer c.Close() tc.Logger.WithDefaultLevel(log.Debug).Printf("dialed tracker %q", tc.Url) + tc.mu.Lock() tc.wsConn = c - go func() { - err := tc.announce(tracker.Started) - if err != nil { - tc.Logger.WithDefaultLevel(log.Error).Printf("error in initial announce: %v", err) - } - }() + tc.cond.Broadcast() + tc.mu.Unlock() err = tc.trackerReadLoop(tc.wsConn) - tc.lock.Lock() + tc.mu.Lock() tc.closeUnusedOffers() - tc.lock.Unlock() + c.Close() + tc.mu.Unlock() return err } func (tc *TrackerClient) Run() error { - for { + tc.cond.L = &tc.mu + tc.mu.Lock() + for !tc.closed { + tc.mu.Unlock() err := tc.doWebsocket() tc.Logger.WithDefaultLevel(log.Warning).Printf("websocket instance ended: %v", err) time.Sleep(time.Minute) + tc.mu.Lock() + } + tc.mu.Unlock() + return nil +} + +func (tc *TrackerClient) Close() error { + tc.mu.Lock() + tc.closed = true + if tc.wsConn != nil { + tc.wsConn.Close() } + tc.mu.Unlock() + tc.cond.Broadcast() + return nil } func (tc *TrackerClient) closeUnusedOffers() { @@ -89,7 +103,7 @@ func (tc *TrackerClient) closeUnusedOffers() { tc.outboundOffers = nil } -func (tc *TrackerClient) announce(event tracker.AnnounceEvent) error { +func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte) error { metrics.Add("outbound announces", 1) var randOfferId [20]byte _, err := rand.Read(randOfferId[:]) @@ -103,7 +117,7 @@ func (tc *TrackerClient) announce(event tracker.AnnounceEvent) error { return fmt.Errorf("creating offer: %w", err) } - request := tc.GetAnnounceRequest(event) + request := tc.GetAnnounceRequest(event, infoHash) req := AnnounceRequest{ Numwant: 1, // If higher we need to create equal amount of offers. @@ -112,7 +126,7 @@ func (tc *TrackerClient) announce(event tracker.AnnounceEvent) error { Left: request.Left, Event: request.Event.String(), Action: "announce", - InfoHash: tc.infoHashBinary(), + InfoHash: binaryToJsonString(infoHash[:]), PeerID: tc.peerIdBinary(), Offers: []Offer{{ OfferID: offerIDBinary, @@ -125,10 +139,9 @@ func (tc *TrackerClient) announce(event tracker.AnnounceEvent) error { return fmt.Errorf("marshalling request: %w", err) } - tc.lock.Lock() - defer tc.lock.Unlock() - - err = tc.wsConn.WriteMessage(websocket.TextMessage, data) + tc.mu.Lock() + defer tc.mu.Unlock() + err = tc.writeMessage(data) if err != nil { pc.Close() return fmt.Errorf("write AnnounceRequest: %w", err) @@ -140,69 +153,106 @@ func (tc *TrackerClient) announce(event tracker.AnnounceEvent) error { peerConnection: pc, dataChannel: dc, originalOffer: offer, + infoHash: infoHash, } return nil } +func (tc *TrackerClient) writeMessage(data []byte) error { + for tc.wsConn == nil { + if tc.closed { + return fmt.Errorf("%T closed", tc) + } + tc.cond.Wait() + } + return tc.wsConn.WriteMessage(websocket.TextMessage, data) +} + func (tc *TrackerClient) trackerReadLoop(tracker *websocket.Conn) error { for { _, message, err := tracker.ReadMessage() if err != nil { return fmt.Errorf("read message error: %w", err) } - tc.Logger.WithDefaultLevel(log.Debug).Printf("received message from tracker: %q", message) + //tc.Logger.WithDefaultLevel(log.Debug).Printf("received message from tracker: %q", message) var ar AnnounceResponse if err := json.Unmarshal(message, &ar); err != nil { tc.Logger.WithDefaultLevel(log.Warning).Printf("error unmarshalling announce response: %v", err) continue } - if ar.InfoHash != tc.infoHashBinary() { - tc.Logger.Printf("announce response for different hash: expected %q got %q", tc.infoHashBinary(), ar.InfoHash) - continue - } switch { case ar.Offer != nil: - answer, err := getAnswerForOffer(*ar.Offer, tc.OnConn, ar.OfferID) - if err != nil { - return fmt.Errorf("write AnnounceResponse: %w", err) - } - - req := AnnounceResponse{ - Action: "announce", - InfoHash: tc.infoHashBinary(), - PeerID: tc.peerIdBinary(), - ToPeerID: ar.PeerID, - Answer: &answer, - OfferID: ar.OfferID, - } - data, err := json.Marshal(req) + ih, err := jsonStringToInfoHash(ar.InfoHash) if err != nil { - return fmt.Errorf("failed to marshal request: %w", err) + tc.Logger.WithDefaultLevel(log.Warning).Printf("error decoding info_hash in offer: %v", err) + break } - - tc.lock.Lock() - err = tracker.WriteMessage(websocket.TextMessage, data) - if err != nil { - return fmt.Errorf("write AnnounceResponse: %w", err) - tc.lock.Unlock() - } - tc.lock.Unlock() + tc.handleOffer(*ar.Offer, ar.OfferID, ih, ar.PeerID) case ar.Answer != nil: tc.handleAnswer(ar.OfferID, *ar.Answer) } } } +func (tc *TrackerClient) handleOffer( + offer webrtc.SessionDescription, + offerId string, + infoHash [20]byte, + peerId string, +) error { + peerConnection, answer, err := newAnsweringPeerConnection(offer) + if err != nil { + return fmt.Errorf("write AnnounceResponse: %w", err) + } + response := AnnounceResponse{ + Action: "announce", + InfoHash: binaryToJsonString(infoHash[:]), + PeerID: tc.peerIdBinary(), + ToPeerID: peerId, + Answer: &answer, + OfferID: offerId, + } + data, err := json.Marshal(response) + if err != nil { + peerConnection.Close() + return fmt.Errorf("marshalling response: %w", err) + } + tc.mu.Lock() + defer tc.mu.Unlock() + if err := tc.writeMessage(data); err != nil { + peerConnection.Close() + return fmt.Errorf("writing response: %w", err) + } + timer := time.AfterFunc(30*time.Second, func() { + metrics.Add("answering peer connections timed out", 1) + peerConnection.Close() + }) + peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { + setDataChannelOnOpen(d, peerConnection, func(dc datachannel.ReadWriteCloser) { + timer.Stop() + metrics.Add("answering peer connection conversions", 1) + tc.OnConn(dc, DataChannelContext{ + Local: answer, + Remote: offer, + OfferId: offerId, + LocalOffered: false, + InfoHash: infoHash, + }) + }) + }) + return nil +} + func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescription) { - tc.lock.Lock() - defer tc.lock.Unlock() + tc.mu.Lock() + defer tc.mu.Unlock() offer, ok := tc.outboundOffers[offerId] if !ok { tc.Logger.WithDefaultLevel(log.Warning).Printf("could not find offer for id %q", offerId) return } - tc.Logger.Printf("offer %q got answer %v", offerId, answer) + //tc.Logger.WithDefaultLevel(log.Debug).Printf("offer %q got answer %v", offerId, answer) metrics.Add("outbound offers answered", 1) err := offer.setAnswer(answer, func(dc datachannel.ReadWriteCloser) { metrics.Add("outbound offers answered with datachannel", 1) @@ -211,6 +261,7 @@ func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescr Remote: answer, OfferId: offerId, LocalOffered: true, + InfoHash: offer.infoHash, }) }) if err != nil { @@ -218,5 +269,5 @@ func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescr return } delete(tc.outboundOffers, offerId) - go tc.announce(tracker.None) + go tc.Announce(tracker.None, offer.infoHash) } diff --git a/webtorrent/tracker_protocol.go b/webtorrent/tracker_protocol.go index 548122ff..167044a6 100644 --- a/webtorrent/tracker_protocol.go +++ b/webtorrent/tracker_protocol.go @@ -1,6 +1,9 @@ package webtorrent import ( + "fmt" + "math" + "github.com/pion/webrtc/v2" ) @@ -43,3 +46,21 @@ func binaryToJsonString(b []byte) string { } return string(seq) } + +func jsonStringToInfoHash(s string) (ih [20]byte, err error) { + defer func() { + r := recover() + if r == nil { + return + } + panic(fmt.Sprintf("%q", s)) + }() + for i, c := range []rune(s) { + if c < 0 || c > math.MaxUint8 { + err = fmt.Errorf("bad infohash string: %v", s) + return + } + ih[i] = byte(c) + } + return +} diff --git a/webtorrent/transport.go b/webtorrent/transport.go index 37989b84..c0d03bc7 100644 --- a/webtorrent/transport.go +++ b/webtorrent/transport.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "sync" - "time" "github.com/anacrolix/missinggo/v2/pproffd" "github.com/pion/datachannel" @@ -78,9 +77,7 @@ func newOffer() ( func initAnsweringPeerConnection( peerConnection wrappedPeerConnection, - offerId string, offer webrtc.SessionDescription, - onOpen onDataChannelOpen, ) (answer webrtc.SessionDescription, err error) { err = peerConnection.SetRemoteDescription(offer) if err != nil { @@ -94,35 +91,22 @@ func initAnsweringPeerConnection( if err != nil { return } - timer := time.AfterFunc(30*time.Second, func() { - metrics.Add("answering peer connections timed out", 1) - peerConnection.Close() - }) - peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { - setDataChannelOnOpen(d, peerConnection, func(dc datachannel.ReadWriteCloser) { - timer.Stop() - metrics.Add("answering peer connection conversions", 1) - onOpen(dc, DataChannelContext{answer, offer, offerId, false}) - }) - }) return } -// getAnswerForOffer creates a transport from a WebRTC offer and and returns a WebRTC answer to be +// newAnsweringPeerConnection creates a transport from a WebRTC offer and and returns a WebRTC answer to be // announced. -func getAnswerForOffer( - offer webrtc.SessionDescription, onOpen onDataChannelOpen, offerId string, -) ( - answer webrtc.SessionDescription, err error, +func newAnsweringPeerConnection(offer webrtc.SessionDescription) ( + peerConn wrappedPeerConnection, answer webrtc.SessionDescription, err error, ) { - peerConnection, err := newPeerConnection() + peerConn, err = newPeerConnection() if err != nil { - err = fmt.Errorf("failed to peer connection: %w", err) + err = fmt.Errorf("failed to create new connection: %w", err) return } - answer, err = initAnsweringPeerConnection(peerConnection, offerId, offer, onOpen) + answer, err = initAnsweringPeerConnection(peerConn, offer) if err != nil { - peerConnection.Close() + peerConn.Close() } return } diff --git a/wstracker.go b/wstracker.go index 8cf66e86..95780d77 100644 --- a/wstracker.go +++ b/wstracker.go @@ -3,8 +3,13 @@ package torrent import ( "fmt" "net/url" + "sync" + "github.com/anacrolix/log" + + "github.com/anacrolix/torrent/tracker" "github.com/anacrolix/torrent/webtorrent" + "github.com/pion/datachannel" ) type websocketTracker struct { @@ -19,3 +24,56 @@ func (me websocketTracker) statusLine() string { func (me websocketTracker) URL() url.URL { return me.url } + +type refCountedWebtorrentTrackerClient struct { + webtorrent.TrackerClient + refCount int +} + +type websocketTrackers struct { + PeerId [20]byte + Logger log.Logger + GetAnnounceRequest func(event tracker.AnnounceEvent, infoHash [20]byte) tracker.AnnounceRequest + OnConn func(datachannel.ReadWriteCloser, webtorrent.DataChannelContext) + mu sync.Mutex + clients map[string]*refCountedWebtorrentTrackerClient +} + +func (me *websocketTrackers) Get(url string) (*webtorrent.TrackerClient, func()) { + me.mu.Lock() + defer me.mu.Unlock() + value, ok := me.clients[url] + if !ok { + value = &refCountedWebtorrentTrackerClient{ + TrackerClient: webtorrent.TrackerClient{ + Url: url, + GetAnnounceRequest: me.GetAnnounceRequest, + PeerId: me.PeerId, + OnConn: me.OnConn, + Logger: me.Logger.WithText(func(m log.Msg) string { + return fmt.Sprintf("tracker client for %q: %v", url, m) + }), + }, + } + go func() { + err := value.TrackerClient.Run() + if err != nil { + me.Logger.Printf("error running tracker client for %q: %v", url, err) + } + }() + if me.clients == nil { + me.clients = make(map[string]*refCountedWebtorrentTrackerClient) + } + me.clients[url] = value + } + value.refCount++ + return &value.TrackerClient, func() { + me.mu.Lock() + defer me.mu.Unlock() + value.refCount-- + if value.refCount == 0 { + value.TrackerClient.Close() + delete(me.clients, url) + } + } +} -- 2.48.1