]> Sergey Matveev's repositories - btrtrc.git/blobdiff - client.go
Handle v2 Torrents added by short infohash only
[btrtrc.git] / client.go
index 4b8e9df1ff58b28b5048952d33856672638523ea..64b053fce260741095b2da51cc4bde2a9cb05d8d 100644 (file)
--- a/client.go
+++ b/client.go
@@ -78,7 +78,12 @@ type Client struct {
        // through legitimate channels.
        dopplegangerAddrs map[string]struct{}
        badPeerIPs        map[netip.Addr]struct{}
-       torrents          map[InfoHash]*Torrent
+       // All Torrents once.
+       torrents map[*Torrent]struct{}
+       // All Torrents by their short infohashes (v1 if valid, and truncated v2 if valid). Unless the
+       // info has been obtained, there's no knowing if an infohash belongs to v1 or v2.
+       torrentsByShortHash map[InfoHash]*Torrent
+
        pieceRequestOrder map[interface{}]*request_strategy.PieceRequestOrder
 
        acceptLimiter map[ipStr]int
@@ -200,7 +205,9 @@ func (cl *Client) announceKey() int32 {
 func (cl *Client) init(cfg *ClientConfig) {
        cl.config = cfg
        g.MakeMap(&cl.dopplegangerAddrs)
-       cl.torrents = make(map[metainfo.Hash]*Torrent)
+       g.MakeMap(&cl.torrentsByShortHash)
+       g.MakeMap(&cl.torrents)
+       cl.torrentsByShortHash = make(map[metainfo.Hash]*Torrent)
        cl.activeAnnounceLimiter.SlotsPerKey = 2
        cl.event.L = cl.locker()
        cl.ipBlockList = cfg.IPBlocklist
@@ -313,7 +320,7 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
                ) {
                        cl.lock()
                        defer cl.unlock()
-                       t, ok := cl.torrents[infoHash]
+                       t, ok := cl.torrentsByShortHash[infoHash]
                        if !ok {
                                return tracker.AnnounceRequest{}, errors.New("torrent not tracked by client")
                        }
@@ -326,7 +333,7 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
                OnConn: func(dc datachannel.ReadWriteCloser, dcc webtorrent.DataChannelContext) {
                        cl.lock()
                        defer cl.unlock()
-                       t, ok := cl.torrents[dcc.InfoHash]
+                       t, ok := cl.torrentsByShortHash[dcc.InfoHash]
                        if !ok {
                                cl.logger.WithDefaultLevel(log.Warning).Printf(
                                        "got webrtc conn for unloaded torrent with infohash %x",
@@ -352,7 +359,7 @@ func (cl *Client) AddDialer(d Dialer) {
        cl.lock()
        defer cl.unlock()
        cl.dialers = append(cl.dialers, d)
-       for _, t := range cl.torrents {
+       for t := range cl.torrents {
                t.openNewConns()
        }
 }
@@ -448,7 +455,7 @@ func (cl *Client) eachDhtServer(f func(DhtServer)) {
 func (cl *Client) Close() (errs []error) {
        var closeGroup sync.WaitGroup // For concurrent cleanup to complete before returning
        cl.lock()
-       for _, t := range cl.torrents {
+       for t := range cl.torrents {
                err := t.close(&closeGroup)
                if err != nil {
                        errs = append(errs, err)
@@ -480,7 +487,7 @@ func (cl *Client) wantConns() bool {
        if cl.config.AlwaysWantConns {
                return true
        }
-       for _, t := range cl.torrents {
+       for t := range cl.torrents {
                if t.wantIncomingConns() {
                        return true
                }
@@ -609,14 +616,10 @@ func (cl *Client) incomingConnection(nc net.Conn) {
 func (cl *Client) Torrent(ih metainfo.Hash) (t *Torrent, ok bool) {
        cl.rLock()
        defer cl.rUnlock()
-       t, ok = cl.torrents[ih]
+       t, ok = cl.torrentsByShortHash[ih]
        return
 }
 
-func (cl *Client) torrent(ih metainfo.Hash) *Torrent {
-       return cl.torrents[ih]
-}
-
 type DialResult struct {
        Conn   net.Conn
        Dialer Dialer
@@ -686,13 +689,13 @@ func (cl *Client) noLongerHalfOpen(t *Torrent, addr string, attemptKey outgoingC
        if cl.numHalfOpen < 0 {
                panic("should not be possible")
        }
-       for _, t := range cl.torrents {
+       for t := range cl.torrents {
                t.openNewConns()
        }
 }
 
 func (cl *Client) countHalfOpenFromTorrents() (count int) {
-       for _, t := range cl.torrents {
+       for t := range cl.torrents {
                count += t.numHalfOpenAttempts()
        }
        return
@@ -946,18 +949,18 @@ func (cl *Client) forSkeys(f func([]byte) bool) {
        defer cl.rUnlock()
        if false { // Emulate the bug from #114
                var firstIh InfoHash
-               for ih := range cl.torrents {
+               for ih := range cl.torrentsByShortHash {
                        firstIh = ih
                        break
                }
-               for range cl.torrents {
+               for range cl.torrentsByShortHash {
                        if !f(firstIh[:]) {
                                break
                        }
                }
                return
        }
-       for ih := range cl.torrents {
+       for ih := range cl.torrentsByShortHash {
                if !f(ih[:]) {
                        break
                }
@@ -975,7 +978,12 @@ func (cl *Client) handshakeReceiverSecretKeys() mse.SecretKeyIter {
 func (cl *Client) receiveHandshakes(c *PeerConn) (t *Torrent, err error) {
        defer perf.ScopeTimerErr(&err)()
        var rw io.ReadWriter
-       rw, c.headerEncrypted, c.cryptoMethod, err = handleEncryption(c.rw(), cl.handshakeReceiverSecretKeys(), cl.config.HeaderObfuscationPolicy, cl.config.CryptoSelector)
+       rw, c.headerEncrypted, c.cryptoMethod, err = handleEncryption(
+               c.rw(),
+               cl.handshakeReceiverSecretKeys(),
+               cl.config.HeaderObfuscationPolicy,
+               cl.config.CryptoSelector,
+       )
        c.setRW(rw)
        if err == nil || err == mse.ErrNoSecretKeyMatch {
                if c.headerEncrypted {
@@ -1001,7 +1009,7 @@ func (cl *Client) receiveHandshakes(c *PeerConn) (t *Torrent, err error) {
                return nil, fmt.Errorf("during bt handshake: %w", err)
        }
        cl.lock()
-       t = cl.torrents[ih]
+       t = cl.torrentsByShortHash[ih]
        cl.unlock()
        return
 }
@@ -1368,7 +1376,7 @@ func (cl *Client) AddTorrentInfoHashWithStorage(
 ) (t *Torrent, new bool) {
        cl.lock()
        defer cl.unlock()
-       t, ok := cl.torrents[infoHash]
+       t, ok := cl.torrentsByShortHash[infoHash]
        if ok {
                return
        }
@@ -1380,7 +1388,8 @@ func (cl *Client) AddTorrentInfoHashWithStorage(
                        go t.dhtAnnouncer(s)
                }
        })
-       cl.torrents[infoHash] = t
+       cl.torrentsByShortHash[infoHash] = t
+       cl.torrents[t] = struct{}{}
        cl.clearAcceptLimits()
        t.updateWantPeersEvent()
        // Tickle Client.waitAccept, new torrent may want conns.
@@ -1394,10 +1403,16 @@ func (cl *Client) AddTorrentOpt(opts AddTorrentOpts) (t *Torrent, new bool) {
        infoHash := opts.InfoHash
        cl.lock()
        defer cl.unlock()
-       t, ok := cl.torrents[infoHash]
+       t, ok := cl.torrentsByShortHash[infoHash]
        if ok {
                return
        }
+       if opts.InfoHashV2.Ok {
+               t, ok = cl.torrentsByShortHash[*opts.InfoHashV2.Value.ToShort()]
+               if ok {
+                       return
+               }
+       }
        new = true
 
        t = cl.newTorrentOpt(opts)
@@ -1406,7 +1421,8 @@ func (cl *Client) AddTorrentOpt(opts AddTorrentOpts) (t *Torrent, new bool) {
                        go t.dhtAnnouncer(s)
                }
        })
-       cl.torrents[infoHash] = t
+       cl.torrentsByShortHash[infoHash] = t
+       cl.torrents[t] = struct{}{}
        t.setInfoBytesLocked(opts.InfoBytes)
        cl.clearAcceptLimits()
        t.updateWantPeersEvent()
@@ -1484,19 +1500,17 @@ func (t *Torrent) MergeSpec(spec *TorrentSpec) error {
        return t.AddPieceLayers(spec.PieceLayers)
 }
 
-func (cl *Client) dropTorrent(infoHash metainfo.Hash, wg *sync.WaitGroup) (err error) {
-       t, ok := cl.torrents[infoHash]
-       if !ok {
-               err = fmt.Errorf("no such torrent")
-               return
-       }
+func (cl *Client) dropTorrent(t *Torrent, wg *sync.WaitGroup) (err error) {
+       t.eachShortInfohash(func(short [20]byte) {
+               delete(cl.torrentsByShortHash, short)
+       })
        err = t.close(wg)
-       delete(cl.torrents, infoHash)
+       delete(cl.torrents, t)
        return
 }
 
 func (cl *Client) allTorrentsCompleted() bool {
-       for _, t := range cl.torrents {
+       for t := range cl.torrents {
                if !t.haveInfo() {
                        return false
                }
@@ -1529,7 +1543,7 @@ func (cl *Client) Torrents() []*Torrent {
 }
 
 func (cl *Client) torrentsAsSlice() (ret []*Torrent) {
-       for _, t := range cl.torrents {
+       for t := range cl.torrents {
                ret = append(ret, t)
        }
        return
@@ -1593,7 +1607,7 @@ func (cl *Client) banPeerIP(ip net.IP) {
                panic(ip)
        }
        g.MakeMapIfNilAndSet(&cl.badPeerIPs, ipAddr, struct{}{})
-       for _, t := range cl.torrents {
+       for t := range cl.torrents {
                t.iterPeers(func(p *Peer) {
                        if p.remoteIp().Equal(ip) {
                                t.logger.Levelf(log.Warning, "dropping peer %v with banned ip %v", p, ip)
@@ -1662,7 +1676,7 @@ func (cl *Client) newConnection(nc net.Conn, opts newConnectionOpts) (c *PeerCon
 func (cl *Client) onDHTAnnouncePeer(ih metainfo.Hash, ip net.IP, port int, portOk bool) {
        cl.lock()
        defer cl.unlock()
-       t := cl.torrent(ih)
+       t := cl.torrentsByShortHash[ih]
        if t == nil {
                return
        }