]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Pool webtorrent tracker websockets at the Client level
authorMatt Joiner <anacrolix@gmail.com>
Tue, 21 Apr 2020 08:08:43 +0000 (18:08 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Tue, 21 Apr 2020 08:08:43 +0000 (18:08 +1000)
client.go
go.mod
go.sum
torrent.go
webtorrent/tracker_client.go
webtorrent/tracker_protocol.go
webtorrent/transport.go
wstracker.go

index c4e81190e8b0778418f12b77de588eda3709a262..0226edfce431aedbaa46a2487f62b7f981a25a8f 100644 (file)
--- a/client.go
+++ b/client.go
@@ -23,9 +23,12 @@ import (
        "github.com/anacrolix/missinggo/slices"
        "github.com/anacrolix/missinggo/v2/pproffd"
        "github.com/anacrolix/sync"
+       "github.com/anacrolix/torrent/tracker"
+       "github.com/anacrolix/torrent/webtorrent"
        "github.com/davecgh/go-spew/spew"
        "github.com/dustin/go-humanize"
        "github.com/google/btree"
+       "github.com/pion/datachannel"
        "golang.org/x/time/rate"
        "golang.org/x/xerrors"
 
@@ -73,6 +76,8 @@ type Client struct {
 
        acceptLimiter   map[ipStr]int
        dialRateLimiter *rate.Limiter
+
+       websocketTrackers websocketTrackers
 }
 
 type ipStr string
@@ -241,6 +246,32 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
                }
        }
 
+       cl.websocketTrackers = websocketTrackers{
+               PeerId: cl.peerID,
+               Logger: cl.logger.WithMap(func(msg log.Msg) log.Msg {
+                       return msg.SetLevel(log.Critical)
+               }),
+               GetAnnounceRequest: func(event tracker.AnnounceEvent, infoHash [20]byte) tracker.AnnounceRequest {
+                       cl.lock()
+                       defer cl.unlock()
+                       return cl.torrents[infoHash].announceRequest(event)
+               },
+               OnConn: func(dc datachannel.ReadWriteCloser, dcc webtorrent.DataChannelContext) {
+                       cl.lock()
+                       defer cl.unlock()
+                       t, ok := cl.torrents[dcc.InfoHash]
+                       if !ok {
+                               cl.logger.WithDefaultLevel(log.Warning).Printf(
+                                       "got webrtc conn for unloaded torrent with infohash %x",
+                                       dcc.InfoHash,
+                               )
+                               dc.Close()
+                               return
+                       }
+                       go t.onWebRtcConn(dc, dcc)
+               },
+       }
+
        return
 }
 
@@ -627,9 +658,17 @@ func (cl *Client) noLongerHalfOpen(t *Torrent, addr string) {
 
 // Performs initiator handshakes and returns a connection. Returns nil *connection if no connection
 // for valid reasons.
-func (cl *Client) handshakesConnection(ctx context.Context, nc net.Conn, t *Torrent, encryptHeader bool, remoteAddr net.Addr,
-       network, connString string,
-) (c *PeerConn, err error) {
+func (cl *Client) handshakesConnection(
+       ctx context.Context,
+       nc net.Conn,
+       t *Torrent,
+       encryptHeader bool,
+       remoteAddr net.Addr,
+       network,
+       connString string,
+) (
+       c *PeerConn, err error,
+) {
        c = cl.newConnection(nc, true, remoteAddr, network, connString)
        c.headerEncrypted = encryptHeader
        ctx, cancel := context.WithTimeout(ctx, cl.config.HandshakesTimeout)
@@ -881,7 +920,11 @@ func (cl *Client) runHandshookConn(c *PeerConn, t *Torrent) error {
        defer t.dropConnection(c)
        go c.writer(time.Minute)
        cl.sendInitialMessages(c, t)
-       return fmt.Errorf("main read loop: %w", c.mainReadLoop())
+       err := c.mainReadLoop()
+       if err != nil {
+               return fmt.Errorf("main read loop: %w", err)
+       }
+       return nil
 }
 
 // See the order given in Transmission's tr_peerMsgsNew.
diff --git a/go.mod b/go.mod
index 922724247e086301a8cad02e0c3f1866ed2a6c4a..300f29278d491f4ce8a51740a41a215cce875971 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -6,7 +6,7 @@ require (
        github.com/anacrolix/dht/v2 v2.6.1-0.20200416071723-3850fa1b802a
        github.com/anacrolix/envpprof v1.1.0
        github.com/anacrolix/go-libutp v1.0.2
-       github.com/anacrolix/log v0.6.1-0.20200416071330-f58a030e6149
+       github.com/anacrolix/log v0.7.0
        github.com/anacrolix/missinggo v1.2.1
        github.com/anacrolix/missinggo/perf v1.0.0
        github.com/anacrolix/missinggo/v2 v2.4.1-0.20200419051441-747d9d7544c6
diff --git a/go.sum b/go.sum
index 3b637b4a881db32a18cb7395d042d4312f8f1d40..c921fcc0beac8a4f47e9c9e49866cb3de68e23c7 100644 (file)
--- a/go.sum
+++ b/go.sum
@@ -61,6 +61,8 @@ github.com/anacrolix/log v0.6.0 h1:5y+wtTWoecbrAWWuoBCH7UuGFiD6q2jnQxrLK01RC+Q=
 github.com/anacrolix/log v0.6.0/go.mod h1:lWvLTqzAnCWPJA08T2HCstZi0L1y2Wyvm3FJgwU9jwU=
 github.com/anacrolix/log v0.6.1-0.20200416071330-f58a030e6149 h1:3cEyLU9ObAfTnBHCev8uuWGhbHfol8uTwyMRkLIpZGg=
 github.com/anacrolix/log v0.6.1-0.20200416071330-f58a030e6149/go.mod h1:s5yBP/j046fm9odtUTbHOfDUq/zh1W8OkPpJtnX0oQI=
+github.com/anacrolix/log v0.7.0 h1:koGkC/K0LjIbrhLhwfpsfMuvu8nhvY7J4TmLVc1mAwE=
+github.com/anacrolix/log v0.7.0/go.mod h1:s5yBP/j046fm9odtUTbHOfDUq/zh1W8OkPpJtnX0oQI=
 github.com/anacrolix/missinggo v0.0.0-20180522035225-b4a5853e62ff/go.mod h1:b0p+7cn+rWMIphK1gDH2hrDuwGOcbB6V4VXeSsEfHVk=
 github.com/anacrolix/missinggo v0.0.0-20180725070939-60ef2fbf63df/go.mod h1:kwGiTUTZ0+p4vAz3VbAI5a30t2YbvemcmspjKwrAz5s=
 github.com/anacrolix/missinggo v0.2.1-0.20190310234110-9fbdc9f242a8 h1:E2Xb2SBsVzHJ1tNMW9QcckYEQcyBKz1ee8qVjeVRWys=
index 63952351622d1fb00e5cd94f50fe1f4dc791f97a..11fb8eba458734de8248246e056cd08cddc1efeb 100644 (file)
@@ -1308,7 +1308,9 @@ func (t *Torrent) onWebRtcConn(
        t.cl.lock()
        defer t.cl.unlock()
        err = t.cl.runHandshookConn(pc, t)
-       t.logger.WithDefaultLevel(log.Critical).Printf("error running handshook webrtc conn: %v", err)
+       if err != nil {
+               t.logger.WithDefaultLevel(log.Critical).Printf("error running handshook webrtc conn: %v", err)
+       }
 }
 
 func (t *Torrent) logRunHandshookConn(pc *PeerConn, logAll bool, level log.Level) {
@@ -1322,6 +1324,26 @@ func (t *Torrent) runHandshookConnLoggingErr(pc *PeerConn) {
        t.logRunHandshookConn(pc, false, log.Debug)
 }
 
+func (t *Torrent) startWebsocketAnnouncer(u url.URL) torrentTrackerAnnouncer {
+       wtc, release := t.cl.websocketTrackers.Get(u.String())
+       go func() {
+               <-t.closed.LockedChan(t.cl.locker())
+               release()
+       }()
+       wst := websocketTracker{u, wtc}
+       go func() {
+               err := wtc.Announce(tracker.Started, t.infoHash)
+               if err != nil {
+                       t.logger.WithDefaultLevel(log.Warning).Printf(
+                               "error in initial announce to %q: %v",
+                               u.String(), err,
+                       )
+               }
+       }()
+       return wst
+
+}
+
 func (t *Torrent) startScrapingTracker(_url string) {
        if _url == "" {
                return
@@ -1348,33 +1370,7 @@ func (t *Torrent) startScrapingTracker(_url string) {
        sl := func() torrentTrackerAnnouncer {
                switch u.Scheme {
                case "ws", "wss":
-                       wst := websocketTracker{
-                               *u,
-                               &webtorrent.TrackerClient{
-                                       Url: u.String(),
-                                       GetAnnounceRequest: func(event tracker.AnnounceEvent) tracker.AnnounceRequest {
-                                               t.cl.lock()
-                                               defer t.cl.unlock()
-                                               return t.announceRequest(event)
-                                       },
-                                       PeerId:   t.cl.peerID,
-                                       InfoHash: t.infoHash,
-                                       OnConn:   t.onWebRtcConn,
-                                       Logger: t.logger.WithText(func(m log.Msg) string {
-                                               return fmt.Sprintf("%q: %v", u.String(), m.Text())
-                                       }).WithDefaultLevel(log.Debug),
-                               },
-                       }
-                       go func() {
-                               err := wst.TrackerClient.Run()
-                               if err != nil {
-                                       t.logger.WithDefaultLevel(log.Warning).Printf(
-                                               "error running websocket tracker announcer for %q: %v",
-                                               u.String(), err,
-                                       )
-                               }
-                       }()
-                       return wst
+                       return t.startWebsocketAnnouncer(*u)
                }
                if u.Scheme == "udp4" && (t.cl.config.DisableIPv4Peers || t.cl.config.DisableIPv4) {
                        return nil
index 5c84a72930866ed449f4d0396df411786718926b..14b62ed28a5d6faca5542b627cab64b8df090d3d 100644 (file)
@@ -18,36 +18,35 @@ import (
 // Client represents the webtorrent client
 type TrackerClient struct {
        Url                string
-       GetAnnounceRequest func(tracker.AnnounceEvent) tracker.AnnounceRequest
+       GetAnnounceRequest func(_ tracker.AnnounceEvent, infoHash [20]byte) tracker.AnnounceRequest
        PeerId             [20]byte
-       InfoHash           [20]byte
        OnConn             onDataChannelOpen
        Logger             log.Logger
 
-       lock           sync.Mutex
+       mu             sync.Mutex
+       cond           sync.Cond
        outboundOffers map[string]outboundOffer // OfferID to outboundOffer
        wsConn         *websocket.Conn
+       closed         bool
 }
 
 func (me *TrackerClient) peerIdBinary() string {
        return binaryToJsonString(me.PeerId[:])
 }
 
-func (me *TrackerClient) infoHashBinary() string {
-       return binaryToJsonString(me.InfoHash[:])
-}
-
 // outboundOffer represents an outstanding offer.
 type outboundOffer struct {
        originalOffer  webrtc.SessionDescription
        peerConnection wrappedPeerConnection
        dataChannel    *webrtc.DataChannel
+       infoHash       [20]byte
 }
 
 type DataChannelContext struct {
        Local, Remote webrtc.SessionDescription
        OfferId       string
        LocalOffered  bool
+       InfoHash      [20]byte
 }
 
 type onDataChannelOpen func(_ datachannel.ReadWriteCloser, dcc DataChannelContext)
@@ -60,26 +59,41 @@ func (tc *TrackerClient) doWebsocket() error {
        }
        defer c.Close()
        tc.Logger.WithDefaultLevel(log.Debug).Printf("dialed tracker %q", tc.Url)
+       tc.mu.Lock()
        tc.wsConn = c
-       go func() {
-               err := tc.announce(tracker.Started)
-               if err != nil {
-                       tc.Logger.WithDefaultLevel(log.Error).Printf("error in initial announce: %v", err)
-               }
-       }()
+       tc.cond.Broadcast()
+       tc.mu.Unlock()
        err = tc.trackerReadLoop(tc.wsConn)
-       tc.lock.Lock()
+       tc.mu.Lock()
        tc.closeUnusedOffers()
-       tc.lock.Unlock()
+       c.Close()
+       tc.mu.Unlock()
        return err
 }
 
 func (tc *TrackerClient) Run() error {
-       for {
+       tc.cond.L = &tc.mu
+       tc.mu.Lock()
+       for !tc.closed {
+               tc.mu.Unlock()
                err := tc.doWebsocket()
                tc.Logger.WithDefaultLevel(log.Warning).Printf("websocket instance ended: %v", err)
                time.Sleep(time.Minute)
+               tc.mu.Lock()
+       }
+       tc.mu.Unlock()
+       return nil
+}
+
+func (tc *TrackerClient) Close() error {
+       tc.mu.Lock()
+       tc.closed = true
+       if tc.wsConn != nil {
+               tc.wsConn.Close()
        }
+       tc.mu.Unlock()
+       tc.cond.Broadcast()
+       return nil
 }
 
 func (tc *TrackerClient) closeUnusedOffers() {
@@ -89,7 +103,7 @@ func (tc *TrackerClient) closeUnusedOffers() {
        tc.outboundOffers = nil
 }
 
-func (tc *TrackerClient) announce(event tracker.AnnounceEvent) error {
+func (tc *TrackerClient) Announce(event tracker.AnnounceEvent, infoHash [20]byte) error {
        metrics.Add("outbound announces", 1)
        var randOfferId [20]byte
        _, err := rand.Read(randOfferId[:])
@@ -103,7 +117,7 @@ func (tc *TrackerClient) announce(event tracker.AnnounceEvent) error {
                return fmt.Errorf("creating offer: %w", err)
        }
 
-       request := tc.GetAnnounceRequest(event)
+       request := tc.GetAnnounceRequest(event, infoHash)
 
        req := AnnounceRequest{
                Numwant:    1, // If higher we need to create equal amount of offers.
@@ -112,7 +126,7 @@ func (tc *TrackerClient) announce(event tracker.AnnounceEvent) error {
                Left:       request.Left,
                Event:      request.Event.String(),
                Action:     "announce",
-               InfoHash:   tc.infoHashBinary(),
+               InfoHash:   binaryToJsonString(infoHash[:]),
                PeerID:     tc.peerIdBinary(),
                Offers: []Offer{{
                        OfferID: offerIDBinary,
@@ -125,10 +139,9 @@ func (tc *TrackerClient) announce(event tracker.AnnounceEvent) error {
                return fmt.Errorf("marshalling request: %w", err)
        }
 
-       tc.lock.Lock()
-       defer tc.lock.Unlock()
-
-       err = tc.wsConn.WriteMessage(websocket.TextMessage, data)
+       tc.mu.Lock()
+       defer tc.mu.Unlock()
+       err = tc.writeMessage(data)
        if err != nil {
                pc.Close()
                return fmt.Errorf("write AnnounceRequest: %w", err)
@@ -140,69 +153,106 @@ func (tc *TrackerClient) announce(event tracker.AnnounceEvent) error {
                peerConnection: pc,
                dataChannel:    dc,
                originalOffer:  offer,
+               infoHash:       infoHash,
        }
        return nil
 }
 
+func (tc *TrackerClient) writeMessage(data []byte) error {
+       for tc.wsConn == nil {
+               if tc.closed {
+                       return fmt.Errorf("%T closed", tc)
+               }
+               tc.cond.Wait()
+       }
+       return tc.wsConn.WriteMessage(websocket.TextMessage, data)
+}
+
 func (tc *TrackerClient) trackerReadLoop(tracker *websocket.Conn) error {
        for {
                _, message, err := tracker.ReadMessage()
                if err != nil {
                        return fmt.Errorf("read message error: %w", err)
                }
-               tc.Logger.WithDefaultLevel(log.Debug).Printf("received message from tracker: %q", message)
+               //tc.Logger.WithDefaultLevel(log.Debug).Printf("received message from tracker: %q", message)
 
                var ar AnnounceResponse
                if err := json.Unmarshal(message, &ar); err != nil {
                        tc.Logger.WithDefaultLevel(log.Warning).Printf("error unmarshalling announce response: %v", err)
                        continue
                }
-               if ar.InfoHash != tc.infoHashBinary() {
-                       tc.Logger.Printf("announce response for different hash: expected %q got %q", tc.infoHashBinary(), ar.InfoHash)
-                       continue
-               }
                switch {
                case ar.Offer != nil:
-                       answer, err := getAnswerForOffer(*ar.Offer, tc.OnConn, ar.OfferID)
-                       if err != nil {
-                               return fmt.Errorf("write AnnounceResponse: %w", err)
-                       }
-
-                       req := AnnounceResponse{
-                               Action:   "announce",
-                               InfoHash: tc.infoHashBinary(),
-                               PeerID:   tc.peerIdBinary(),
-                               ToPeerID: ar.PeerID,
-                               Answer:   &answer,
-                               OfferID:  ar.OfferID,
-                       }
-                       data, err := json.Marshal(req)
+                       ih, err := jsonStringToInfoHash(ar.InfoHash)
                        if err != nil {
-                               return fmt.Errorf("failed to marshal request: %w", err)
+                               tc.Logger.WithDefaultLevel(log.Warning).Printf("error decoding info_hash in offer: %v", err)
+                               break
                        }
-
-                       tc.lock.Lock()
-                       err = tracker.WriteMessage(websocket.TextMessage, data)
-                       if err != nil {
-                               return fmt.Errorf("write AnnounceResponse: %w", err)
-                               tc.lock.Unlock()
-                       }
-                       tc.lock.Unlock()
+                       tc.handleOffer(*ar.Offer, ar.OfferID, ih, ar.PeerID)
                case ar.Answer != nil:
                        tc.handleAnswer(ar.OfferID, *ar.Answer)
                }
        }
 }
 
+func (tc *TrackerClient) handleOffer(
+       offer webrtc.SessionDescription,
+       offerId string,
+       infoHash [20]byte,
+       peerId string,
+) error {
+       peerConnection, answer, err := newAnsweringPeerConnection(offer)
+       if err != nil {
+               return fmt.Errorf("write AnnounceResponse: %w", err)
+       }
+       response := AnnounceResponse{
+               Action:   "announce",
+               InfoHash: binaryToJsonString(infoHash[:]),
+               PeerID:   tc.peerIdBinary(),
+               ToPeerID: peerId,
+               Answer:   &answer,
+               OfferID:  offerId,
+       }
+       data, err := json.Marshal(response)
+       if err != nil {
+               peerConnection.Close()
+               return fmt.Errorf("marshalling response: %w", err)
+       }
+       tc.mu.Lock()
+       defer tc.mu.Unlock()
+       if err := tc.writeMessage(data); err != nil {
+               peerConnection.Close()
+               return fmt.Errorf("writing response: %w", err)
+       }
+       timer := time.AfterFunc(30*time.Second, func() {
+               metrics.Add("answering peer connections timed out", 1)
+               peerConnection.Close()
+       })
+       peerConnection.OnDataChannel(func(d *webrtc.DataChannel) {
+               setDataChannelOnOpen(d, peerConnection, func(dc datachannel.ReadWriteCloser) {
+                       timer.Stop()
+                       metrics.Add("answering peer connection conversions", 1)
+                       tc.OnConn(dc, DataChannelContext{
+                               Local:        answer,
+                               Remote:       offer,
+                               OfferId:      offerId,
+                               LocalOffered: false,
+                               InfoHash:     infoHash,
+                       })
+               })
+       })
+       return nil
+}
+
 func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescription) {
-       tc.lock.Lock()
-       defer tc.lock.Unlock()
+       tc.mu.Lock()
+       defer tc.mu.Unlock()
        offer, ok := tc.outboundOffers[offerId]
        if !ok {
                tc.Logger.WithDefaultLevel(log.Warning).Printf("could not find offer for id %q", offerId)
                return
        }
-       tc.Logger.Printf("offer %q got answer %v", offerId, answer)
+       //tc.Logger.WithDefaultLevel(log.Debug).Printf("offer %q got answer %v", offerId, answer)
        metrics.Add("outbound offers answered", 1)
        err := offer.setAnswer(answer, func(dc datachannel.ReadWriteCloser) {
                metrics.Add("outbound offers answered with datachannel", 1)
@@ -211,6 +261,7 @@ func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescr
                        Remote:       answer,
                        OfferId:      offerId,
                        LocalOffered: true,
+                       InfoHash:     offer.infoHash,
                })
        })
        if err != nil {
@@ -218,5 +269,5 @@ func (tc *TrackerClient) handleAnswer(offerId string, answer webrtc.SessionDescr
                return
        }
        delete(tc.outboundOffers, offerId)
-       go tc.announce(tracker.None)
+       go tc.Announce(tracker.None, offer.infoHash)
 }
index 548122ffe2f4061671f3f70fe004030d7ed8c5bf..167044a6315510e833634c0b9fcb1d93f95787f6 100644 (file)
@@ -1,6 +1,9 @@
 package webtorrent
 
 import (
+       "fmt"
+       "math"
+
        "github.com/pion/webrtc/v2"
 )
 
@@ -43,3 +46,21 @@ func binaryToJsonString(b []byte) string {
        }
        return string(seq)
 }
+
+func jsonStringToInfoHash(s string) (ih [20]byte, err error) {
+       defer func() {
+               r := recover()
+               if r == nil {
+                       return
+               }
+               panic(fmt.Sprintf("%q", s))
+       }()
+       for i, c := range []rune(s) {
+               if c < 0 || c > math.MaxUint8 {
+                       err = fmt.Errorf("bad infohash string: %v", s)
+                       return
+               }
+               ih[i] = byte(c)
+       }
+       return
+}
index 37989b842e56c046945563172640ea322a892bc8..c0d03bc7b6ef3c89c1e15e6bc7dbf2f4e402f436 100644 (file)
@@ -5,7 +5,6 @@ import (
        "fmt"
        "io"
        "sync"
-       "time"
 
        "github.com/anacrolix/missinggo/v2/pproffd"
        "github.com/pion/datachannel"
@@ -78,9 +77,7 @@ func newOffer() (
 
 func initAnsweringPeerConnection(
        peerConnection wrappedPeerConnection,
-       offerId string,
        offer webrtc.SessionDescription,
-       onOpen onDataChannelOpen,
 ) (answer webrtc.SessionDescription, err error) {
        err = peerConnection.SetRemoteDescription(offer)
        if err != nil {
@@ -94,35 +91,22 @@ func initAnsweringPeerConnection(
        if err != nil {
                return
        }
-       timer := time.AfterFunc(30*time.Second, func() {
-               metrics.Add("answering peer connections timed out", 1)
-               peerConnection.Close()
-       })
-       peerConnection.OnDataChannel(func(d *webrtc.DataChannel) {
-               setDataChannelOnOpen(d, peerConnection, func(dc datachannel.ReadWriteCloser) {
-                       timer.Stop()
-                       metrics.Add("answering peer connection conversions", 1)
-                       onOpen(dc, DataChannelContext{answer, offer, offerId, false})
-               })
-       })
        return
 }
 
-// getAnswerForOffer creates a transport from a WebRTC offer and and returns a WebRTC answer to be
+// newAnsweringPeerConnection creates a transport from a WebRTC offer and and returns a WebRTC answer to be
 // announced.
-func getAnswerForOffer(
-       offer webrtc.SessionDescription, onOpen onDataChannelOpen, offerId string,
-) (
-       answer webrtc.SessionDescription, err error,
+func newAnsweringPeerConnection(offer webrtc.SessionDescription) (
+       peerConn wrappedPeerConnection, answer webrtc.SessionDescription, err error,
 ) {
-       peerConnection, err := newPeerConnection()
+       peerConn, err = newPeerConnection()
        if err != nil {
-               err = fmt.Errorf("failed to peer connection: %w", err)
+               err = fmt.Errorf("failed to create new connection: %w", err)
                return
        }
-       answer, err = initAnsweringPeerConnection(peerConnection, offerId, offer, onOpen)
+       answer, err = initAnsweringPeerConnection(peerConn, offer)
        if err != nil {
-               peerConnection.Close()
+               peerConn.Close()
        }
        return
 }
index 8cf66e86a1b9297686c4bb6f814b8a875ad2a9b5..95780d774eb5fca1b69d550abf9c74bae4efc3d7 100644 (file)
@@ -3,8 +3,13 @@ package torrent
 import (
        "fmt"
        "net/url"
+       "sync"
 
+       "github.com/anacrolix/log"
+
+       "github.com/anacrolix/torrent/tracker"
        "github.com/anacrolix/torrent/webtorrent"
+       "github.com/pion/datachannel"
 )
 
 type websocketTracker struct {
@@ -19,3 +24,56 @@ func (me websocketTracker) statusLine() string {
 func (me websocketTracker) URL() url.URL {
        return me.url
 }
+
+type refCountedWebtorrentTrackerClient struct {
+       webtorrent.TrackerClient
+       refCount int
+}
+
+type websocketTrackers struct {
+       PeerId             [20]byte
+       Logger             log.Logger
+       GetAnnounceRequest func(event tracker.AnnounceEvent, infoHash [20]byte) tracker.AnnounceRequest
+       OnConn             func(datachannel.ReadWriteCloser, webtorrent.DataChannelContext)
+       mu                 sync.Mutex
+       clients            map[string]*refCountedWebtorrentTrackerClient
+}
+
+func (me *websocketTrackers) Get(url string) (*webtorrent.TrackerClient, func()) {
+       me.mu.Lock()
+       defer me.mu.Unlock()
+       value, ok := me.clients[url]
+       if !ok {
+               value = &refCountedWebtorrentTrackerClient{
+                       TrackerClient: webtorrent.TrackerClient{
+                               Url:                url,
+                               GetAnnounceRequest: me.GetAnnounceRequest,
+                               PeerId:             me.PeerId,
+                               OnConn:             me.OnConn,
+                               Logger: me.Logger.WithText(func(m log.Msg) string {
+                                       return fmt.Sprintf("tracker client for %q: %v", url, m)
+                               }),
+                       },
+               }
+               go func() {
+                       err := value.TrackerClient.Run()
+                       if err != nil {
+                               me.Logger.Printf("error running tracker client for %q: %v", url, err)
+                       }
+               }()
+               if me.clients == nil {
+                       me.clients = make(map[string]*refCountedWebtorrentTrackerClient)
+               }
+               me.clients[url] = value
+       }
+       value.refCount++
+       return &value.TrackerClient, func() {
+               me.mu.Lock()
+               defer me.mu.Unlock()
+               value.refCount--
+               if value.refCount == 0 {
+                       value.TrackerClient.Close()
+                       delete(me.clients, url)
+               }
+       }
+}