client-piece-request-order.go | 2 +- client-tracker-announcer.go | 3 ++- client.go | 128 ++++++++++++++++++++++++----------------------------- mse/cmd/mse/main.go | 15 ++++++++++----- mse/mse.go | 21 ++++++++++++--------- mse/mse_test.go | 24 ++++++++++++++++-------- request-strategy-impls.go | 12 +++++++++--- torrent.go | 10 ++++------ torrents-by-short-infohash.go | 90 +++++++++++++++++++++++++++++++++++++++++++++++++++++ webseed-requesting.go | 3 ++- diff --git a/client-piece-request-order.go b/client-piece-request-order.go index 0984b40996f9051f161e892c8ec7863a1116bbaa..48749cbe5a7bb47706c95458f833e927f39d252a 100644 --- a/client-piece-request-order.go +++ b/client-piece-request-order.go @@ -31,7 +31,7 @@ func (c clientPieceRequestOrderSharedStorageTorrentKey) getRequestStrategyInput(cl *Client) requestStrategy.Input { return requestStrategyInputMultiTorrent{ requestStrategyInputCommon: cl.getRequestStrategyInputCommon(), - torrents: cl.torrentsByShortHash, + torrents: &cl.torrentsByShortHash, capFunc: c.inner, } } diff --git a/client-tracker-announcer.go b/client-tracker-announcer.go index 249359ba3b4b1edb4aa1968f422b8f217550330a..d4f300969251069a77e758d7e8e73700d7be9a87 100644 --- a/client-tracker-announcer.go +++ b/client-tracker-announcer.go @@ -413,7 +413,8 @@ } // Returns nil if the torrent was dropped. func (me *regularTrackerAnnounceDispatcher) torrentFromShortInfohash(short shortInfohash) *Torrent { - return me.torrentClient.torrentsByShortHash[short] + t, _ := me.torrentClient.torrentsByShortHash.Get(short) + return t } const maxConcurrentAnnouncesPerTracker = 2 diff --git a/client.go b/client.go index 59d6e61cc90cd8286a2d6a8f4b5d79185fb388a8..e4558032a7e9395ef039f1725549d0c95dff8c9f 100644 --- a/client.go +++ b/client.go @@ -93,7 +93,7 @@ torrents map[*Torrent]struct{} // All Torrents by their short infohashes (v1 if valid, and truncated v2 if valid). Unless the // info has been obtained, there's no knowing if an infohash belongs to v1 or v2. 
TODO: Make // this a weak pointer. - torrentsByShortHash map[InfoHash]*Torrent + torrentsByShortHash syncMapTorrentsByShortHash // Piece request orderings grouped by storage. Value is value type because all fields are // references. @@ -274,9 +274,8 @@ cl.initLogger() cl.regularTrackerAnnounceDispatcher.init(cl) cfg.setRateLimiterBursts() g.MakeMap(&cl.dopplegangerAddrs) - g.MakeMap(&cl.torrentsByShortHash) + cl.torrentsByShortHash.Init() g.MakeMap(&cl.torrents) - cl.torrentsByShortHash = make(map[metainfo.Hash]*Torrent) cl.event.L = cl.locker() cl.ipBlockList = cfg.IPBlocklist cl.httpClient = &http.Client{ @@ -331,7 +330,7 @@ tracker.AnnounceRequest, error, ) { cl.lock() defer cl.unlock() - t, ok := cl.torrentsByShortHash[infoHash] + t, ok := cl.torrentsByShortHash.Get(infoHash) if !ok { return tracker.AnnounceRequest{}, errors.New("torrent not tracked by client") } @@ -345,7 +344,7 @@ callbacks: &cl.config.Callbacks, OnConn: func(dc webtorrent.DataChannelConn, dcc webtorrent.DataChannelContext) { cl.lock() defer cl.unlock() - t, ok := cl.torrentsByShortHash[dcc.InfoHash] + t, ok := cl.torrentsByShortHash.Get(dcc.InfoHash) if !ok { cl.logger.WithDefaultLevel(log.Warning).Printf( "got webrtc conn for unloaded torrent with infohash %x", @@ -543,7 +542,6 @@ cl.dropTorrent(t, &closeGroup) } // Can we not modify cl.torrents as we delete from it? panicif.NotZero(len(cl.torrents)) - panicif.NotZero(len(cl.torrentsByShortHash)) cl.clearPortMappings() for i := range cl.onClose { cl.onClose[len(cl.onClose)-1-i]() @@ -696,10 +694,7 @@ } // Returns a handle to the given torrent, if it's present in the client. func (cl *Client) Torrent(ih metainfo.Hash) (t *Torrent, ok bool) { - cl.rLock() - defer cl.rUnlock() - t, ok = cl.torrentsByShortHash[ih] - return + return cl.torrentsByShortHash.Get(ih) } type DialResult struct { @@ -1022,32 +1017,17 @@ } // Calls f with any secret keys. 
Note that it takes the Client lock, and so must be used from code // that won't also try to take the lock. This saves us copying all the infohashes everytime. -func (cl *Client) forSkeys(f func([]byte) bool) { - cl.rLock() - defer cl.rUnlock() - if false { // Emulate the bug from #114 - var firstIh InfoHash - for ih := range cl.torrentsByShortHash { - firstIh = ih - break - } - for range cl.torrentsByShortHash { - if !f(firstIh[:]) { - break - } - } - return - } - for ih := range cl.torrentsByShortHash { - if !f(ih[:]) { - break +func (cl *Client) forSkeys(yield func([20]byte) bool) { + for ih := range cl.torrentsByShortHash.IterKeys { + if !yield(ih) { + return } } } func (cl *Client) handshakeReceiverSecretKeys() mse.SecretKeyIter { - if ret := cl.config.Callbacks.ReceiveEncryptedHandshakeSkeys; ret != nil { - return ret + if cb := cl.config.Callbacks.ReceiveEncryptedHandshakeSkeys; cb != nil { + return cb } return cl.forSkeys } @@ -1086,13 +1066,18 @@ if err != nil { return nil, fmt.Errorf("during bt handshake: %w", err) } - cl.lock() - t = cl.torrentsByShortHash[ih] - if t != nil && t.infoHashV2.Ok && *t.infoHashV2.Value.ToShort() == ih { - torrent.Add("v2 handshakes received", 1) - c.v2 = true + // Hooray for atomics. + t, _ = cl.torrentsByShortHash.Get(ih) + if t != nil { + cl.rLock() + isV2 := t.infoHashV2.Ok && *t.infoHashV2.Value.ToShort() == ih + cl.rUnlock() + if isV2 { + torrent.Add("v2 handshakes received", 1) + // PeerConn isn't owned by the Client yet. + c.v2 = true + } } - cl.unlock() return } @@ -1475,53 +1460,54 @@ func (cl *Client) AddTorrentInfoHash(infoHash metainfo.Hash) (t *Torrent, new bool) { return cl.AddTorrentInfoHashWithStorage(infoHash, nil) } -// Deprecated. Adds a torrent by InfoHash with a custom Storage implementation. -// If the torrent already exists then this Storage is ignored and the -// existing torrent returned with `new` set to `false` +// Deprecated. Adds a torrent by InfoHash with a custom Storage implementation. 
If the torrent +// already exists then this Storage is ignored and the existing torrent returned with `new` set to +// `false` func (cl *Client) AddTorrentInfoHashWithStorage( infoHash metainfo.Hash, specStorage storage.ClientImpl, ) (t *Torrent, new bool) { - cl.lock() - defer cl.unlock() - t, ok := cl.torrentsByShortHash[infoHash] - if ok { - return - } - new = true + return cl.AddTorrentOpt(AddTorrentOpts{ + InfoHash: infoHash, + Storage: specStorage, + }) +} - t = cl.newTorrent(infoHash, specStorage) - cl.eachDhtServer(func(s DhtServer) { - if cl.config.PeriodicallyAnnounceTorrentsToDht { - go t.dhtAnnouncer(s) +func (cl *Client) addTorrentReturningExisting(opts AddTorrentOpts) (t *Torrent, ok bool) { + t, ok = cl.torrentsByShortHash.Get(opts.InfoHash) + if !ok { + if opts.InfoHashV2.Ok { + t, ok = cl.torrentsByShortHash.Get(*opts.InfoHashV2.Value.ToShort()) } - }) - cl.torrentsByShortHash[infoHash] = t - cl.torrents[t] = struct{}{} - cl.clearAcceptLimits() - t.updateWantPeersEvent() - // Tickle Client.waitAccept, new torrent may want conns. - cl.event.Broadcast() + } return } // Adds a torrent by InfoHash with a custom Storage implementation. If the torrent already exists // then this Storage is ignored and the existing torrent returned with `new` set to `false`. func (cl *Client) AddTorrentOpt(opts AddTorrentOpts) (t *Torrent, new bool) { - infoHash := opts.InfoHash - panicif.Zero(infoHash) + panicif.Zero(opts.InfoHash) + + t, ok := cl.addTorrentReturningExisting(opts) + if ok && !t.closed.IsSet() { + return + } + cl.lock() defer cl.unlock() - t, ok := cl.torrentsByShortHash[infoHash] + + t, ok = cl.addTorrentReturningExisting(opts) if ok { - return - } - if opts.InfoHashV2.Ok { - t, ok = cl.torrentsByShortHash[*opts.InfoHashV2.Value.ToShort()] - if ok { + if !t.closed.IsSet() { return } + // Do we have to nuke this? Can't we just clobber it? 
+ t.eachShortInfohash(func(short [20]byte) { + cl.torrentsByShortHash.Delete(short) + }) } + + infoHash := opts.InfoHash new = true t = cl.newTorrentOpt(opts) @@ -1530,7 +1516,7 @@ if cl.config.PeriodicallyAnnounceTorrentsToDht { go t.dhtAnnouncer(s) } }) - cl.torrentsByShortHash[infoHash] = t + panicif.False(cl.torrentsByShortHash.Set(infoHash, t)) t.setInfoBytesLocked(opts.InfoBytes) cl.clearAcceptLimits() t.updateWantPeersEvent() @@ -1787,12 +1773,12 @@ } } func (cl *Client) onDHTAnnouncePeer(ih metainfo.Hash, ip net.IP, port int, portOk bool) { - cl.lock() - defer cl.unlock() - t := cl.torrentsByShortHash[ih] - if t == nil { + t, ok := cl.torrentsByShortHash.Get(ih) + if !ok { return } + cl.lock() + defer cl.unlock() t.addPeers([]PeerInfo{{ Addr: ipPortAddr{ip, port}, Source: PeerSourceDhtAnnouncePeer, diff --git a/mse/cmd/mse/main.go b/mse/cmd/mse/main.go index c96af097175bc9062c49dad442e5e3064c6d0395..2d1885f8ea79380cc5250b31410a8a55906e13f0 100644 --- a/mse/cmd/mse/main.go +++ b/mse/cmd/mse/main.go @@ -63,11 +63,16 @@ if err != nil { return fmt.Errorf("accepting: %w", err) } defer cn.Close() - rw, _, err := mse.ReceiveHandshake(context.TODO(), cn, func(f func([]byte) bool) { - for _, sk := range args.Listen.SecretKeys { - f([]byte(sk)) - } - }, mse.DefaultCryptoSelector) + rw, _, err := mse.ReceiveHandshake( + context.TODO(), + cn, + func(f func(mse.SecretKey) bool) { + for _, sk := range args.Listen.SecretKeys { + f(mse.SecretKey([]byte(sk))) + } + }, + mse.DefaultCryptoSelector, + ) if err != nil { log.Fatalf("error receiving: %v", err) } diff --git a/mse/mse.go b/mse/mse.go index f566c63f79f34acc05ceebef0211caf73d292ee3..e351a764a0b907ecec8fc14f94e6dc8311a23940 100644 --- a/mse/mse.go +++ b/mse/mse.go @@ -13,6 +13,7 @@ "errors" "expvar" "fmt" "io" + "iter" "math" "math/big" "strconv" @@ -437,19 +438,18 @@ expectedHash := hash(req3, h.s[:]) eachHash := sha1.New() var sum, xored [sha1.Size]byte err = ErrNoSecretKeyMatch - h.skeys(func(skey []byte) bool { + 
for skey := range h.skeys { eachHash.Reset() eachHash.Write(req2) - eachHash.Write(skey) + eachHash.Write(skey[:]) eachHash.Sum(sum[:0]) xorInPlace(xored[:], sum[:], expectedHash) if bytes.Equal(xored[:], b[:]) { - h.skey = skey + h.skey = skey[:] err = nil - return false + break } - return true - }) + } if err != nil { return } @@ -598,9 +598,12 @@ ret.SecretKey = h.skey return } -// A function that given a function, calls it with secret keys until it -// returns false or exhausted. -type SecretKeyIter func(callback func(skey []byte) (more bool)) +type ( + // For performance reasons prefer a static-sized array rather than []byte. + SecretKey = [20]byte + // A function that given a function, calls it with secret keys until it returns false or exhausted. + SecretKeyIter = iter.Seq[SecretKey] +) func DefaultCryptoSelector(provided CryptoMethod) CryptoMethod { // We prefer plaintext for performance reasons. diff --git a/mse/mse_test.go b/mse/mse_test.go index d97323f1c83b328093302ed6c5e99880b15ac776..60e13f9a8876a1a34da8ab380360077d1766b365 100644 --- a/mse/mse_test.go +++ b/mse/mse_test.go @@ -7,6 +7,7 @@ "crypto/rand" "crypto/rc4" "io" "net" + "slices" "sync" "testing" @@ -15,14 +16,20 @@ "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func padBytesToHash(b []byte) (ret []byte) { + var sk SecretKey + copy(sk[:], b) + return sk[:] +} + func sliceIter(skeys [][]byte) SecretKeyIter { - return func(callback func([]byte) bool) { - for _, sk := range skeys { - if !callback(sk) { - break - } - } + arraySlice := make([][20]byte, 0, len(skeys)) + for _, key := range skeys { + var sk SecretKey + copy(sk[:], key) + arraySlice = append(arraySlice, sk) } + return slices.Values(arraySlice) } func TestReadUntil(t *testing.T) { @@ -62,9 +69,10 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides CryptoMethod, cryptoSelect CryptoSelector) { a, b := net.Pipe() wg := sync.WaitGroup{} wg.Add(2) + senderSkey := 
padBytesToHash([]byte("yep")) go func() { defer wg.Done() - a, cm, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides) + a, cm, err := InitiateHandshake(a, senderSkey, ia, cryptoProvides) require.NoError(t, err) assert.Equal(t, cryptoSelect(cryptoProvides), cm) go a.Write([]byte(aData)) @@ -85,7 +93,7 @@ sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect, ) require.NoError(t, res.error) - assert.EqualValues(t, "yep", res.SecretKey) + assert.EqualValues(t, senderSkey, res.SecretKey) b := res.ReadWriter assert.Equal(t, cryptoSelect(cryptoProvides), res.CryptoMethod) go b.Write([]byte(bData)) diff --git a/request-strategy-impls.go b/request-strategy-impls.go index e67bb9f0e3c0d12c7999074adc172d11846ae879..d375c4b121631b2deb3ad2bc1eff3192c3f53ac2 100644 --- a/request-strategy-impls.go +++ b/request-strategy-impls.go @@ -1,7 +1,7 @@ package torrent import ( - g "github.com/anacrolix/generics" + "github.com/anacrolix/missinggo/v2/panicif" requestStrategy "github.com/anacrolix/torrent/internal/request-strategy" "github.com/anacrolix/torrent/metainfo" @@ -14,16 +14,22 @@ } func (r requestStrategyInputCommon) MaxUnverifiedBytes() int64 { return r.maxUnverifiedBytes +} + +type torrentFromHashGetter interface { + Get(shortInfohash) (*Torrent, bool) } type requestStrategyInputMultiTorrent struct { requestStrategyInputCommon - torrents map[metainfo.Hash]*Torrent + torrents torrentFromHashGetter capFunc storage.TorrentCapacity } func (r requestStrategyInputMultiTorrent) Torrent(ih metainfo.Hash) requestStrategy.Torrent { - return requestStrategyTorrent{g.MapMustGet(r.torrents, ih)} + t, ok := r.torrents.Get(ih) + panicif.False(ok) + return requestStrategyTorrent{t} } func (r requestStrategyInputMultiTorrent) Capacity() (int64, bool) { diff --git a/torrent.go b/torrent.go index 993234add6515d50179519563a43b92a05654fef..b5c2f9f642e8d01220e33f865b19fd6872f756f2 100644 --- a/torrent.go +++ b/torrent.go @@ -621,7 +621,7 @@ if t.infoHash.Ok && 
!t.infoHashV2.Ok { if v1Hash == t.infoHash.Value { if info.HasV2() { t.infoHashV2.Set(v2Hash) - cl.torrentsByShortHash[*v2Hash.ToShort()] = t + cl.torrentsByShortHash.Set(*v2Hash.ToShort(), t) } } else if *v2Hash.ToShort() == t.infoHash.Value { if !info.HasV2() { @@ -630,7 +630,7 @@ } t.infoHashV2.Set(v2Hash) t.infoHash.SetNone() if info.HasV1() { - cl.torrentsByShortHash[v1Hash] = t + cl.torrentsByShortHash.Set(v1Hash, t) t.infoHash.Set(v1Hash) } } @@ -647,7 +647,7 @@ return errors.New("incorrect v2 infohash") } if info.HasV1() { t.infoHash.Set(v1Hash) - cl.torrentsByShortHash[v1Hash] = t + cl.torrentsByShortHash.Set(v1Hash, t) } } else { panic("no expected infohashes") @@ -1102,9 +1102,7 @@ func (t *Torrent) close(wg *sync.WaitGroup) { // Should only be called from the Client. panicif.False(t.closed.Set()) - t.eachShortInfohash(func(short [20]byte) { - delete(t.cl.torrentsByShortHash, short) - }) + // We now keep a weak pointer in torrentsByShortHash for asynchronous cleanup like announcing Stopped. t.deferUpdateRegularTrackerAnnouncing() t.closedCtxCancel(errTorrentClosed) t.getInfoCtxCancel(errTorrentClosed) diff --git a/torrents-by-short-infohash.go b/torrents-by-short-infohash.go new file mode 100644 index 0000000000000000000000000000000000000000..f3ca331d318522779eff69380022d17a8a46c825 --- /dev/null +++ b/torrents-by-short-infohash.go @@ -0,0 +1,90 @@ +package torrent + +import ( + "iter" + "sync" + "unique" + "weak" +) + +type torrentsByShortHash interface { + Get(key shortInfohash) (*Torrent, bool) + IterKeys() iter.Seq[shortInfohash] + Init() +} + +type syncMapTorrentsByShortHash struct { + inner sync.Map +} + +type ( + syncMapTorrentsByShortHashKey = unique.Handle[shortInfohash] + syncMapTorrentsByShortHashValue = weak.Pointer[Torrent] +) + +// sync.Map is zero initialized, but we want this function in case we switch implementations. 
+func (me *syncMapTorrentsByShortHash) Init() {} + +func (me *syncMapTorrentsByShortHash) IterKeys(yield func(shortInfohash) bool) { + // Do we have to check the values for weak pointers? Probably should to keep the map clean + // to speed up iteration. I wonder if it will introduce overhead to forSkeys. + for key := range me.iter() { + if !yield(key) { + return + } + } +} + +func (me *syncMapTorrentsByShortHash) iter() iter.Seq2[shortInfohash, *Torrent] { + return func(yield func(shortInfohash, *Torrent) bool) { + me.inner.Range(func(key, value any) bool { + uk := key.(syncMapTorrentsByShortHashKey) + v, ok := me.derefValueOrDelete(uk, value.(syncMapTorrentsByShortHashValue)) + if !ok { + // Current value was lost, move on. + return true + } + return yield(uk.Value(), v) + }) + } +} + +func (me *syncMapTorrentsByShortHash) IsEmpty() bool { + for range me.iter() { + return false + } + return true +} + +func (me *syncMapTorrentsByShortHash) derefValueOrDelete( + key syncMapTorrentsByShortHashKey, + wp syncMapTorrentsByShortHashValue, +) (*Torrent, bool) { + t := wp.Value() + if t != nil { + return t, true + } + me.inner.CompareAndDelete(key, wp) + return nil, false +} + +func (me *syncMapTorrentsByShortHash) Get(ih shortInfohash) (t *Torrent, ok bool) { + key := unique.Make(ih) + v, ok := me.inner.Load(key) + if !ok { + return + } + wp := v.(syncMapTorrentsByShortHashValue) + return me.derefValueOrDelete(key, wp) +} + +// Returns true if the key was newly inserted. +func (me *syncMapTorrentsByShortHash) Set(ih shortInfohash, t *Torrent) bool { + _, loaded := me.inner.Swap(unique.Make(ih), weak.Make(t)) + return !loaded +} + +// Removes the entry for the given short infohash, if any. Returns nothing.
+func (me *syncMapTorrentsByShortHash) Delete(ih shortInfohash) { + me.inner.Delete(unique.Make(ih)) +} diff --git a/webseed-requesting.go b/webseed-requesting.go index 5f0705aeb9e10041b55672c2db5a60862035e68c..ba509be5ff51fcbf7be51a6d4aefc4f952c84d3c 100644 --- a/webseed-requesting.go +++ b/webseed-requesting.go @@ -380,7 +380,8 @@ if !requestStrategy.GetRequestablePieces( input, value.pieces, func(ih metainfo.Hash, pieceIndex int, orderState requestStrategy.PieceRequestOrderState) bool { - t := cl.torrentsByShortHash[ih] + t, ok := cl.torrentsByShortHash.Get(ih) + panicif.False(ok) if len(t.webSeeds) == 0 { return true }