client.go | 51 +++++++++++++++++++++++++++++++++++++++++++++++---- go.mod | 2 +- go.sum | 2 ++ torrent.go | 52 ++++++++++++++++++++++++---------------------------- webtorrent/tracker_client.go | 161 +++++++++++++++++++++++++++++++++++------------------ webtorrent/tracker_protocol.go | 21 +++++++++++++++++++++ webtorrent/transport.go | 30 +++++++----------------------- wstracker.go | 58 +++++++++++++++++++++++++++++++++++++++++++++++++++++ diff --git a/client.go b/client.go index c4e81190e8b0778418f12b77de588eda3709a262..0226edfce431aedbaa46a2487f62b7f981a25a8f 100644 --- a/client.go +++ b/client.go @@ -23,9 +23,12 @@ "github.com/anacrolix/missinggo/pubsub" "github.com/anacrolix/missinggo/slices" "github.com/anacrolix/missinggo/v2/pproffd" "github.com/anacrolix/sync" + "github.com/anacrolix/torrent/tracker" + "github.com/anacrolix/torrent/webtorrent" "github.com/davecgh/go-spew/spew" "github.com/dustin/go-humanize" "github.com/google/btree" + "github.com/pion/datachannel" "golang.org/x/time/rate" "golang.org/x/xerrors" @@ -73,6 +76,8 @@ torrents map[InfoHash]*Torrent acceptLimiter map[ipStr]int dialRateLimiter *rate.Limiter + + websocketTrackers websocketTrackers } type ipStr string @@ -241,6 +246,32 @@ } } } + cl.websocketTrackers = websocketTrackers{ + PeerId: cl.peerID, + Logger: cl.logger.WithMap(func(msg log.Msg) log.Msg { + return msg.SetLevel(log.Critical) + }), + GetAnnounceRequest: func(event tracker.AnnounceEvent, infoHash [20]byte) tracker.AnnounceRequest { + cl.lock() + defer cl.unlock() + return cl.torrents[infoHash].announceRequest(event) + }, + OnConn: func(dc datachannel.ReadWriteCloser, dcc webtorrent.DataChannelContext) { + cl.lock() + defer cl.unlock() + t, ok := cl.torrents[dcc.InfoHash] + if !ok { + cl.logger.WithDefaultLevel(log.Warning).Printf( + "got webrtc conn for unloaded torrent with infohash %x", + dcc.InfoHash, + ) + dc.Close() + return + } + go t.onWebRtcConn(dc, dcc) + }, + } + return } @@ -627,9 +658,17 @@ } // Performs initiator handshakes and returns a connection. Returns nil *connection if no connection // for valid reasons. -func (cl *Client) handshakesConnection(ctx context.Context, nc net.Conn, t *Torrent, encryptHeader bool, remoteAddr net.Addr, - network, connString string, -) (c *PeerConn, err error) { +func (cl *Client) handshakesConnection( + ctx context.Context, + nc net.Conn, + t *Torrent, + encryptHeader bool, + remoteAddr net.Addr, + network, + connString string, +) ( + c *PeerConn, err error, +) { c = cl.newConnection(nc, true, remoteAddr, network, connString) c.headerEncrypted = encryptHeader ctx, cancel := context.WithTimeout(ctx, cl.config.HandshakesTimeout) @@ -881,7 +920,11 @@ } defer t.dropConnection(c) go c.writer(time.Minute) cl.sendInitialMessages(c, t) - return fmt.Errorf("main read loop: %w", c.mainReadLoop()) + err := c.mainReadLoop() + if err != nil { + return fmt.Errorf("main read loop: %w", err) + } + return nil } // See the order given in Transmission's tr_peerMsgsNew. diff --git a/go.mod b/go.mod index 922724247e086301a8cad02e0c3f1866ed2a6c4a..300f29278d491f4ce8a51740a41a215cce875971 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ github.com/alexflint/go-arg v1.2.0 github.com/anacrolix/dht/v2 v2.6.1-0.20200416071723-3850fa1b802a github.com/anacrolix/envpprof v1.1.0 github.com/anacrolix/go-libutp v1.0.2 - github.com/anacrolix/log v0.6.1-0.20200416071330-f58a030e6149 + github.com/anacrolix/log v0.7.0 github.com/anacrolix/missinggo v1.2.1 github.com/anacrolix/missinggo/perf v1.0.0 github.com/anacrolix/missinggo/v2 v2.4.1-0.20200419051441-747d9d7544c6 diff --git a/go.sum b/go.sum index 3b637b4a881db32a18cb7395d042d4312f8f1d40..c921fcc0beac8a4f47e9c9e49866cb3de68e23c7 100644 --- a/go.sum +++ b/go.sum @@ -61,6 +61,8 @@ github.com/anacrolix/log v0.6.0 h1:5y+wtTWoecbrAWWuoBCH7UuGFiD6q2jnQxrLK01RC+Q= github.com/anacrolix/log v0.6.0/go.mod h1:lWvLTqzAnCWPJA08T2HCstZi0L1y2Wyvm3FJgwU9jwU= github.com/anacrolix/log v0.6.1-0.20200416071330-f58a030e6149 h1:3cEyLU9ObAfTnBHCev8uuWGhbHfol8uTwyMRkLIpZGg= github.com/anacrolix/log v0.6.1-0.20200416071330-f58a030e6149/go.mod h1:s5yBP/j046fm9odtUTbHOfDUq/zh1W8OkPpJtnX0oQI= +github.com/anacrolix/log v0.7.0 h1:koGkC/K0LjIbrhLhwfpsfMuvu8nhvY7J4TmLVc1mAwE= +github.com/anacrolix/log v0.7.0/go.mod h1:s5yBP/j046fm9odtUTbHOfDUq/zh1W8OkPpJtnX0oQI= github.com/anacrolix/missinggo v0.0.0-20180522035225-b4a5853e62ff/go.mod h1:b0p+7cn+rWMIphK1gDH2hrDuwGOcbB6V4VXeSsEfHVk= github.com/anacrolix/missinggo v0.0.0-20180725070939-60ef2fbf63df/go.mod h1:kwGiTUTZ0+p4vAz3VbAI5a30t2YbvemcmspjKwrAz5s= github.com/anacrolix/missinggo v0.2.1-0.20190310234110-9fbdc9f242a8 h1:E2Xb2SBsVzHJ1tNMW9QcckYEQcyBKz1ee8qVjeVRWys= diff --git a/torrent.go b/torrent.go index 63952351622d1fb00e5cd94f50fe1f4dc791f97a..11fb8eba458734de8248246e056cd08cddc1efeb 100644 --- a/torrent.go +++ b/torrent.go @@ -1308,7 +1308,9 @@ } t.cl.lock() defer t.cl.unlock() err = t.cl.runHandshookConn(pc, t) - t.logger.WithDefaultLevel(log.Critical).Printf("error running handshook webrtc conn: %v", err) + if err != nil { + t.logger.WithDefaultLevel(log.Critical).Printf("error running handshook webrtc conn: %v", err) + } } func (t *Torrent) logRunHandshookConn(pc *PeerConn, logAll bool, level log.Level) { @@ -1322,6 +1324,26 @@ func (t *Torrent) runHandshookConnLoggingErr(pc *PeerConn) { t.logRunHandshookConn(pc, false, log.Debug) } +func (t *Torrent) startWebsocketAnnouncer(u url.URL) torrentTrackerAnnouncer { + wtc, release := t.cl.websocketTrackers.Get(u.String()) + go func() { + <-t.closed.LockedChan(t.cl.locker()) + release() + }() + wst := websocketTracker{u, wtc} + go func() { + err := wtc.Announce(tracker.Started, t.infoHash) + if err != nil { + t.logger.WithDefaultLevel(log.Warning).Printf( + "error in initial announce to %q: %v", + u.String(), err, + ) + } + }() + return wst + +} + func (t *Torrent) startScrapingTracker(_url string) { if _url == "" { return @@ -1348,33 +1370,7 @@ } sl := func() torrentTrackerAnnouncer { switch u.Scheme { case "ws", "wss": - wst := websocketTracker{ - *u, - &webtorrent.TrackerClient{ - Url: u.String(), - GetAnnounceRequest: func(event tracker.AnnounceEvent) tracker.AnnounceRequest { - t.cl.lock() - defer t.cl.unlock() - return t.announceRequest(event) - }, - PeerId: t.cl.peerID, - InfoHash: t.infoHash, - OnConn: t.onWebRtcConn, - Logger: t.logger.WithText(func(m log.Msg) string { - return fmt.Sprintf("%q: %v", u.String(), m.Text()) - }).WithDefaultLevel(log.Debug), - }, - } - go func() { - err := wst.TrackerClient.Run() - if err != nil { - t.logger.WithDefaultLevel(log.Warning).Printf( - "error running websocket tracker announcer for %q: %v", - u.String(), err, - ) - } - }() - return wst + return t.startWebsocketAnnouncer(*u) } if u.Scheme == "udp4" && (t.cl.config.DisableIPv4Peers || t.cl.config.DisableIPv4) { return nil diff --git a/webtorrent/tracker_client.go b/webtorrent/tracker_client.go index 5c84a72930866ed449f4d0396df411786718926b..14b62ed28a5d6faca5542b627cab64b8df090d3d 100644 --- a/webtorrent/tracker_client.go +++ b/webtorrent/tracker_client.go @@ -18,23 +18,20 @@ // Client represents the webtorrent client type TrackerClient struct { Url string - GetAnnounceRequest func(tracker.AnnounceEvent) tracker.AnnounceRequest + GetAnnounceRequest func(_ tracker.AnnounceEvent, infoHash [20]byte) tracker.AnnounceRequest PeerId [20]byte - InfoHash [20]byte OnConn onDataChannelOpen Logger log.Logger - lock sync.Mutex + mu sync.Mutex + cond sync.Cond outboundOffers map[string]outboundOffer // OfferID to outboundOffer wsConn *websocket.Conn + closed bool } func (me *TrackerClient) peerIdBinary() string { return binaryToJsonString(me.PeerId[:]) -} - -func (me *TrackerClient) infoHashBinary() string { - return binaryToJsonString(me.InfoHash[:]) } // outboundOffer represents an outstanding offer. @@ -42,12 +39,14 @@ type outboundOffer struct { originalOffer webrtc.SessionDescription peerConnection wrappedPeerConnection dataChannel *webrtc.DataChannel + infoHash [20]byte } type DataChannelContext struct { Local, Remote webrtc.SessionDescription OfferId string LocalOffered bool + InfoHash [20]byte } type onDataChannelOpen func(_ datachannel.ReadWriteCloser, dcc DataChannelContext) @@ -60,26 +59,41 @@ return fmt.Errorf("dialing tracker: %w", err) } defer c.Close() tc.Logger.WithDefaultLevel(log.Debug).Printf("dialed tracker %q", tc.Url) + tc.mu.Lock() tc.wsConn = c - go func() { - err := tc.announce(tracker.Started) - if err != nil { - tc.Logger.WithDefaultLevel(log.Error).Printf("error in initial announce: %v", err) - } - }() + tc.cond.Broadcast() + tc.mu.Unlock() err = tc.trackerReadLoop(tc.wsConn) - tc.lock.Lock() + tc.mu.Lock() tc.closeUnusedOffers() - tc.lock.Unlock() + c.Close() + tc.mu.Unlock() return err } func (tc *TrackerClient) Run() error { - for { + tc.cond.L = &tc.mu + tc.mu.Lock() + for !tc.closed { + tc.mu.Unlock() err := tc.doWebsocket() tc.Logger.WithDefaultLevel(log.Warning).Printf("websocket instance ended: %v", err) time.Sleep(time.Minute) + tc.mu.Lock() } + tc.mu.Unlock() + return nil +} + +func (tc *TrackerClient) Close() error { + tc.mu.Lock() + tc.closed = true + if tc.wsConn != nil { + tc.wsConn.Close() + } + tc.mu.Unlock() + tc.cond.Broadcast() + return nil } func (tc *TrackerClient) closeUnusedOffers() { @@ -89,7 +103,7 @@ } tc.outboundOffers = nil } -func (tc *TrackerClient) announce(event tracker.AnnounceEvent) error { +func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte) error { metrics.Add("outbound announces", 1) var randOfferId [20]byte _, err := rand.Read(randOfferId[:]) @@ -103,7 +117,7 @@ if err != nil { return fmt.Errorf("creating offer: %w", err) } - request := tc.GetAnnounceRequest(event) + request := tc.GetAnnounceRequest(event, infoHash) req := AnnounceRequest{ Numwant: 1, // If higher we need to create equal amount of offers. @@ -112,7 +126,7 @@ Downloaded: request.Downloaded, Left: request.Left, Event: request.Event.String(), Action: "announce", - InfoHash: tc.infoHashBinary(), + InfoHash: binaryToJsonString(infoHash[:]), PeerID: tc.peerIdBinary(), Offers: []Offer{{ OfferID: offerIDBinary, @@ -125,10 +139,9 @@ if err != nil { return fmt.Errorf("marshalling request: %w", err) } - tc.lock.Lock() - defer tc.lock.Unlock() - - err = tc.wsConn.WriteMessage(websocket.TextMessage, data) + tc.mu.Lock() + defer tc.mu.Unlock() + err = tc.writeMessage(data) if err != nil { pc.Close() return fmt.Errorf("write AnnounceRequest: %w", err) @@ -140,69 +153,106 @@ tc.outboundOffers[offerIDBinary] = outboundOffer{ peerConnection: pc, dataChannel: dc, originalOffer: offer, + infoHash: infoHash, } return nil } +func (tc *TrackerClient) writeMessage(data []byte) error { + for tc.wsConn == nil { + if tc.closed { + return fmt.Errorf("%T closed", tc) + } + tc.cond.Wait() + } + return tc.wsConn.WriteMessage(websocket.TextMessage, data) +} + func (tc *TrackerClient) trackerReadLoop(tracker *websocket.Conn) error { for { _, message, err := tracker.ReadMessage() if err != nil { return fmt.Errorf("read message error: %w", err) } - tc.Logger.WithDefaultLevel(log.Debug).Printf("received message from tracker: %q", message) + //tc.Logger.WithDefaultLevel(log.Debug).Printf("received message from tracker: %q", message) var ar AnnounceResponse if err := json.Unmarshal(message, &ar); err != nil { tc.Logger.WithDefaultLevel(log.Warning).Printf("error unmarshalling announce response: %v", err) - continue - } - if ar.InfoHash != tc.infoHashBinary() { - tc.Logger.Printf("announce response for different hash: expected %q got %q", tc.infoHashBinary(), ar.InfoHash) continue } switch { case ar.Offer != nil: - answer, err := getAnswerForOffer(*ar.Offer, tc.OnConn, ar.OfferID) + ih, err := jsonStringToInfoHash(ar.InfoHash) if err != nil { - return fmt.Errorf("write AnnounceResponse: %w", err) + tc.Logger.WithDefaultLevel(log.Warning).Printf("error decoding info_hash in offer: %v", err) + break } - - req := AnnounceResponse{ - Action: "announce", - InfoHash: tc.infoHashBinary(), - PeerID: tc.peerIdBinary(), - ToPeerID: ar.PeerID, - Answer: &answer, - OfferID: ar.OfferID, - } - data, err := json.Marshal(req) - if err != nil { - return fmt.Errorf("failed to marshal request: %w", err) - } - - tc.lock.Lock() - err = tracker.WriteMessage(websocket.TextMessage, data) - if err != nil { - return fmt.Errorf("write AnnounceResponse: %w", err) - tc.lock.Unlock() - } - tc.lock.Unlock() + tc.handleOffer(*ar.Offer, ar.OfferID, ih, ar.PeerID) case ar.Answer != nil: tc.handleAnswer(ar.OfferID, *ar.Answer) } } } +func (tc *TrackerClient) handleOffer( + offer webrtc.SessionDescription, + offerId string, + infoHash [20]byte, + peerId string, +) error { + peerConnection, answer, err := newAnsweringPeerConnection(offer) + if err != nil { + return fmt.Errorf("write AnnounceResponse: %w", err) + } + response := AnnounceResponse{ + Action: "announce", + InfoHash: binaryToJsonString(infoHash[:]), + PeerID: tc.peerIdBinary(), + ToPeerID: peerId, + Answer: &answer, + OfferID: offerId, + } + data, err := json.Marshal(response) + if err != nil { + peerConnection.Close() + return fmt.Errorf("marshalling response: %w", err) + } + tc.mu.Lock() + defer tc.mu.Unlock() + if err := tc.writeMessage(data); err != nil { + peerConnection.Close() + return fmt.Errorf("writing response: %w", err) + } + timer := time.AfterFunc(30*time.Second, func() { + metrics.Add("answering peer connections timed out", 1) + peerConnection.Close() + }) + peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { + setDataChannelOnOpen(d, peerConnection, func(dc datachannel.ReadWriteCloser) { + timer.Stop() + metrics.Add("answering peer connection conversions", 1) + tc.OnConn(dc, DataChannelContext{ + Local: answer, + Remote: offer, + OfferId: offerId, + LocalOffered: false, + InfoHash: infoHash, + }) + }) + }) + return nil +} + func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescription) { - tc.lock.Lock() - defer tc.lock.Unlock() + tc.mu.Lock() + defer tc.mu.Unlock() offer, ok := tc.outboundOffers[offerId] if !ok { tc.Logger.WithDefaultLevel(log.Warning).Printf("could not find offer for id %q", offerId) return } - tc.Logger.Printf("offer %q got answer %v", offerId, answer) + //tc.Logger.WithDefaultLevel(log.Debug).Printf("offer %q got answer %v", offerId, answer) metrics.Add("outbound offers answered", 1) err := offer.setAnswer(answer, func(dc datachannel.ReadWriteCloser) { metrics.Add("outbound offers answered with datachannel", 1) @@ -211,6 +261,7 @@ Local: offer.originalOffer, Remote: answer, OfferId: offerId, LocalOffered: true, + InfoHash: offer.infoHash, }) }) if err != nil { @@ -218,5 +269,5 @@ tc.Logger.WithDefaultLevel(log.Warning).Printf("error using outbound offer answer: %v", err) return } delete(tc.outboundOffers, offerId) - go tc.announce(tracker.None) + go tc.Announce(tracker.None, offer.infoHash) } diff --git a/webtorrent/tracker_protocol.go b/webtorrent/tracker_protocol.go index 548122ffe2f4061671f3f70fe004030d7ed8c5bf..167044a6315510e833634c0b9fcb1d93f95787f6 100644 --- a/webtorrent/tracker_protocol.go +++ b/webtorrent/tracker_protocol.go @@ -1,6 +1,9 @@ package webtorrent import ( + "fmt" + "math" + "github.com/pion/webrtc/v2" ) @@ -43,3 +46,21 @@ seq = append(seq, rune(v)) } return string(seq) } + +func jsonStringToInfoHash(s string) (ih [20]byte, err error) { + defer func() { + r := recover() + if r == nil { + return + } + panic(fmt.Sprintf("%q", s)) + }() + for i, c := range []rune(s) { + if c < 0 || c > math.MaxUint8 { + err = fmt.Errorf("bad infohash string: %v", s) + return + } + ih[i] = byte(c) + } + return +} diff --git a/webtorrent/transport.go b/webtorrent/transport.go index 37989b842e56c046945563172640ea322a892bc8..c0d03bc7b6ef3c89c1e15e6bc7dbf2f4e402f436 100644 --- a/webtorrent/transport.go +++ b/webtorrent/transport.go @@ -5,7 +5,6 @@ "expvar" "fmt" "io" "sync" - "time" "github.com/anacrolix/missinggo/v2/pproffd" "github.com/pion/datachannel" @@ -78,9 +77,7 @@ } func initAnsweringPeerConnection( peerConnection wrappedPeerConnection, - offerId string, offer webrtc.SessionDescription, - onOpen onDataChannelOpen, ) (answer webrtc.SessionDescription, err error) { err = peerConnection.SetRemoteDescription(offer) if err != nil { @@ -94,35 +91,22 @@ err = peerConnection.SetLocalDescription(answer) if err != nil { return } - timer := time.AfterFunc(30*time.Second, func() { - metrics.Add("answering peer connections timed out", 1) - peerConnection.Close() - }) - peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { - setDataChannelOnOpen(d, peerConnection, func(dc datachannel.ReadWriteCloser) { - timer.Stop() - metrics.Add("answering peer connection conversions", 1) - onOpen(dc, DataChannelContext{answer, offer, offerId, false}) - }) - }) return } -// getAnswerForOffer creates a transport from a WebRTC offer and and returns a WebRTC answer to be +// newAnsweringPeerConnection creates a transport from a WebRTC offer and and returns a WebRTC answer to be // announced. -func getAnswerForOffer( - offer webrtc.SessionDescription, onOpen onDataChannelOpen, offerId string, -) ( - answer webrtc.SessionDescription, err error, +func newAnsweringPeerConnection(offer webrtc.SessionDescription) ( + peerConn wrappedPeerConnection, answer webrtc.SessionDescription, err error, ) { - peerConnection, err := newPeerConnection() + peerConn, err = newPeerConnection() if err != nil { - err = fmt.Errorf("failed to peer connection: %w", err) + err = fmt.Errorf("failed to create new connection: %w", err) return } - answer, err = initAnsweringPeerConnection(peerConnection, offerId, offer, onOpen) + answer, err = initAnsweringPeerConnection(peerConn, offer) if err != nil { - peerConnection.Close() + peerConn.Close() } return } diff --git a/wstracker.go b/wstracker.go index 8cf66e86a1b9297686c4bb6f814b8a875ad2a9b5..95780d774eb5fca1b69d550abf9c74bae4efc3d7 100644 --- a/wstracker.go +++ b/wstracker.go @@ -3,8 +3,13 @@ import ( "fmt" "net/url" + "sync" + + "github.com/anacrolix/log" + "github.com/anacrolix/torrent/tracker" "github.com/anacrolix/torrent/webtorrent" + "github.com/pion/datachannel" ) type websocketTracker struct { @@ -19,3 +24,56 @@ func (me websocketTracker) URL() url.URL { return me.url } + +type refCountedWebtorrentTrackerClient struct { + webtorrent.TrackerClient + refCount int +} + +type websocketTrackers struct { + PeerId [20]byte + Logger log.Logger + GetAnnounceRequest func(event tracker.AnnounceEvent, infoHash [20]byte) tracker.AnnounceRequest + OnConn func(datachannel.ReadWriteCloser, webtorrent.DataChannelContext) + mu sync.Mutex + clients map[string]*refCountedWebtorrentTrackerClient +} + +func (me *websocketTrackers) Get(url string) (*webtorrent.TrackerClient, func()) { + me.mu.Lock() + defer me.mu.Unlock() + value, ok := me.clients[url] + if !ok { + value = &refCountedWebtorrentTrackerClient{ + TrackerClient: webtorrent.TrackerClient{ + Url: url, + GetAnnounceRequest: me.GetAnnounceRequest, + PeerId: me.PeerId, + OnConn: me.OnConn, + Logger: me.Logger.WithText(func(m log.Msg) string { + return fmt.Sprintf("tracker client for %q: %v", url, m) + }), + }, + } + go func() { + err := value.TrackerClient.Run() + if err != nil { + me.Logger.Printf("error running tracker client for %q: %v", url, err) + } + }() + if me.clients == nil { + me.clients = make(map[string]*refCountedWebtorrentTrackerClient) + } + me.clients[url] = value + } + value.refCount++ + return &value.TrackerClient, func() { + me.mu.Lock() + defer me.mu.Unlock() + value.refCount-- + if value.refCount == 0 { + value.TrackerClient.Close() + delete(me.clients, url) + } + } +}