]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Merge branch 'issue-905': Smart ban hash performance improvements
authorMatt Joiner <anacrolix@gmail.com>
Tue, 20 Feb 2024 11:07:59 +0000 (22:07 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Thu, 22 Feb 2024 03:27:57 +0000 (14:27 +1100)
Fixes #905.

client.go
go.mod
go.sum
peer.go
peerconn_test.go
smartban.go
smartban/smartban.go
smartban_test.go [new file with mode: 0644]
torrent.go

index 62c0d2b4c52c5678990405fa0b779ad84ed041b3..7aab0402a70fc5333e44e759ac96b306ae07c696 100644 (file)
--- a/client.go
+++ b/client.go
@@ -4,12 +4,12 @@ import (
        "bufio"
        "context"
        "crypto/rand"
-       "crypto/sha1"
        "encoding/binary"
        "encoding/hex"
        "errors"
        "expvar"
        "fmt"
+       "github.com/cespare/xxhash"
        "io"
        "math"
        "net"
@@ -1301,7 +1301,14 @@ func (cl *Client) newTorrentOpt(opts AddTorrentOpts) (t *Torrent) {
                webSeeds:     make(map[string]*Peer),
                gotMetainfoC: make(chan struct{}),
        }
-       t.smartBanCache.Hash = sha1.Sum
+       var salt [8]byte
+       rand.Read(salt[:])
+       t.smartBanCache.Hash = func(b []byte) uint64 {
+               h := xxhash.New()
+               h.Write(salt[:])
+               h.Write(b)
+               return h.Sum64()
+       }
        t.smartBanCache.Init()
        t.networkingEnabled.Set()
        t.logger = cl.logger.WithDefaultLevel(log.Debug)
diff --git a/go.mod b/go.mod
index 68120007ebcff1b0ef08b38faef530fbdd64fd30..1b0f5816fee11a28622fffb7e94d20305efc2bf7 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -25,6 +25,7 @@ require (
        github.com/anacrolix/utp v0.1.0
        github.com/bahlo/generic-list-go v0.2.0
        github.com/bradfitz/iter v0.0.0-20191230175014-e8f45d346db8
+       github.com/cespare/xxhash v1.1.0
        github.com/davecgh/go-spew v1.1.1
        github.com/dustin/go-humanize v1.0.0
        github.com/edsrzf/mmap-go v1.1.0
diff --git a/go.sum b/go.sum
index ecf851997883f090e7d8e17ce9bfe8d4672e235a..c8ce30f99d6495b690a4055859456b85ea883c53 100644 (file)
--- a/go.sum
+++ b/go.sum
@@ -142,6 +142,7 @@ github.com/bradfitz/iter v0.0.0-20191230175014-e8f45d346db8/go.mod h1:spo1JLcs67
 github.com/cenkalti/backoff/v4 v4.1.3 h1:cFAlzYUlVYDysBEH2T5hyJZMh3+5+WCBvSnK6Q8UtC4=
 github.com/cenkalti/backoff/v4 v4.1.3/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw=
 github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
+github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
 github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
 github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
diff --git a/peer.go b/peer.go
index 23bf5468f08bd1ede676c55a69ea1db75346a7ab..d3ea15161ab8261d5c0df48a6eddd4a05f9a2df0 100644 (file)
--- a/peer.go
+++ b/peer.go
@@ -6,6 +6,7 @@ import (
        "io"
        "net"
        "strings"
+       "sync"
        "time"
 
        "github.com/RoaringBitmap/roaring"
@@ -600,9 +601,13 @@ func (c *Peer) receiveChunk(msg *pp.Message) error {
        }
        req := c.t.requestIndexFromRequest(ppReq)
 
-       if c.bannableAddr.Ok {
-               t.smartBanCache.RecordBlock(c.bannableAddr.Value, req, msg.Piece)
-       }
+       recordBlockForSmartBan := sync.OnceFunc(func() {
+               c.recordBlockForSmartBan(req, msg.Piece)
+       })
+       // This needs to occur before we return, but we try to do it when the client is unlocked. It
+       // can't be done before checking if chunks are valid because they won't be deallocated by piece
+       // hashing if they're out of bounds.
+       defer recordBlockForSmartBan()
 
        if c.peerChoking {
                chunksReceived.Add("while choked", 1)
@@ -683,6 +688,8 @@ func (c *Peer) receiveChunk(msg *pp.Message) error {
        err = func() error {
                cl.unlock()
                defer cl.lock()
+               // Opportunistically do this here while we aren't holding the client lock.
+               recordBlockForSmartBan()
                concurrentChunkWrites.Add(1)
                defer concurrentChunkWrites.Add(-1)
                // Write the chunk out. Note that the upper bound on chunk writing concurrency will be the
@@ -875,3 +882,9 @@ func (p *Peer) decPeakRequests() {
        // }
        p.peakRequests--
 }
+
+func (p *Peer) recordBlockForSmartBan(req RequestIndex, blockData []byte) {
+       if p.bannableAddr.Ok {
+               p.t.smartBanCache.RecordBlock(p.bannableAddr.Value, req, blockData)
+       }
+}
index 3fdcbff1c7663b76368d736cf5658c341e39ee94..36327800164388946d1dbcf990754b6d095c41ee 100644 (file)
@@ -4,8 +4,10 @@ import (
        "encoding/binary"
        "errors"
        "fmt"
+       g "github.com/anacrolix/generics"
        "io"
        "net"
+       "net/netip"
        "sync"
        "testing"
 
@@ -55,7 +57,7 @@ func TestSendBitfieldThenHave(t *testing.T) {
 }
 
 type torrentStorage struct {
-       writeSem sync.Mutex
+       allChunksWritten sync.WaitGroup
 }
 
 func (me *torrentStorage) Close() error { return nil }
@@ -84,7 +86,7 @@ func (me *torrentStorage) WriteAt(b []byte, _ int64) (int, error) {
        if len(b) != defaultChunkSize {
                panic(len(b))
        }
-       me.writeSem.Unlock()
+       me.allChunksWritten.Done()
        return len(b), nil
 }
 
@@ -107,53 +109,87 @@ func BenchmarkConnectionMainReadLoop(b *testing.B) {
        t.onSetInfo()
        t._pendingPieces.Add(0)
        r, w := net.Pipe()
+       c.Logf("pipe reader remote addr: %v", r.RemoteAddr())
        cn := cl.newConnection(r, newConnectionOpts{
-               outgoing:   true,
-               remoteAddr: r.RemoteAddr(),
+               outgoing: true,
+               // TODO: This is a hack to give the pipe a bannable remote address.
+               remoteAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 1234),
                network:    r.RemoteAddr().Network(),
                connString: regularNetConnPeerConnConnString(r),
        })
+       c.Assert(cn.bannableAddr.Ok, qt.IsTrue)
        cn.setTorrent(t)
-       mrlErrChan := make(chan error)
-       msg := pp.Message{
-               Type:  pp.Piece,
-               Piece: make([]byte, defaultChunkSize),
+       requestIndexBegin := t.pieceRequestIndexOffset(0)
+       requestIndexEnd := t.pieceRequestIndexOffset(1)
+       eachRequestIndex := func(f func(ri RequestIndex)) {
+               for ri := requestIndexBegin; ri < requestIndexEnd; ri++ {
+                       f(ri)
+               }
        }
+       const chunkSize = defaultChunkSize
+       numRequests := requestIndexEnd - requestIndexBegin
+       msgBufs := make([][]byte, 0, numRequests)
+       eachRequestIndex(func(ri RequestIndex) {
+               msgBufs = append(msgBufs, pp.Message{
+                       Type:  pp.Piece,
+                       Piece: make([]byte, chunkSize),
+                       Begin: pp.Integer(chunkSize) * pp.Integer(ri),
+               }.MustMarshalBinary())
+       })
+       // errgroup can't handle this pattern...
+       allErrors := make(chan error, 2)
+       var wg sync.WaitGroup
+       wg.Add(1)
        go func() {
+               defer wg.Done()
                cl.lock()
                err := cn.mainReadLoop()
-               if err != nil {
-                       mrlErrChan <- err
+               if errors.Is(err, io.EOF) {
+                       err = nil
                }
-               close(mrlErrChan)
+               allErrors <- err
        }()
-       wb := msg.MustMarshalBinary()
-       b.SetBytes(int64(len(msg.Piece)))
+       b.SetBytes(chunkSize * int64(numRequests))
+       wg.Add(1)
        go func() {
-               ts.writeSem.Lock()
+               defer wg.Done()
                for i := 0; i < b.N; i += 1 {
                        cl.lock()
                        // The chunk must be written to storage everytime, to ensure the
                        // writeSem is unlocked.
                        t.pendAllChunkSpecs(0)
-                       cn.validReceiveChunks = map[RequestIndex]int{
-                               t.requestIndexFromRequest(newRequestFromMessage(&msg)): 1,
-                       }
+                       g.MakeMapIfNil(&cn.validReceiveChunks)
+                       eachRequestIndex(func(ri RequestIndex) {
+                               cn.validReceiveChunks[ri] = 1
+                       })
                        cl.unlock()
-                       n, err := w.Write(wb)
-                       require.NoError(b, err)
-                       require.EqualValues(b, len(wb), n)
-                       ts.writeSem.Lock()
+                       ts.allChunksWritten.Add(int(numRequests))
+                       for _, wb := range msgBufs {
+                               n, err := w.Write(wb)
+                               require.NoError(b, err)
+                               require.EqualValues(b, len(wb), n)
+                       }
+                       // This is unlocked by a successful write to storage. So this unblocks when that is
+                       // done.
+                       ts.allChunksWritten.Wait()
                }
                if err := w.Close(); err != nil {
                        panic(err)
                }
        }()
-       mrlErr := <-mrlErrChan
-       if mrlErr != nil && !errors.Is(mrlErr, io.EOF) {
-               c.Fatal(mrlErr)
+       go func() {
+               wg.Wait()
+               close(allErrors)
+       }()
+       var err error
+       for err = range allErrors {
+               if err != nil {
+                       break
+               }
        }
-       c.Assert(cn._stats.ChunksReadUseful.Int64(), quicktest.Equals, int64(b.N))
+       c.Assert(err, qt.IsNil)
+       c.Assert(cn._stats.ChunksReadUseful.Int64(), quicktest.Equals, int64(b.N)*int64(numRequests))
+       c.Assert(t.smartBanCache.HasBlocks(), qt.IsTrue)
 }
 
 func TestConnPexPeerFlags(t *testing.T) {
index 034a702d950372400c8dddb002fa1121304073e5..5515ded51fa3ef94f145a57439ef9e877dee6719 100644 (file)
@@ -2,7 +2,6 @@ package torrent
 
 import (
        "bytes"
-       "crypto/sha1"
        "net/netip"
 
        g "github.com/anacrolix/generics"
@@ -12,7 +11,7 @@ import (
 
 type bannableAddr = netip.Addr
 
-type smartBanCache = smartban.Cache[bannableAddr, RequestIndex, [sha1.Size]byte]
+type smartBanCache = smartban.Cache[bannableAddr, RequestIndex, uint64]
 
 type blockCheckingWriter struct {
        cache        *smartBanCache
index 96e9b759a5d17aa1eba9f95530e65862237f2570..ba568c98d041b4eb38408f69f4c77cc398bbe889 100644 (file)
@@ -1,6 +1,7 @@
 package smartban
 
 import (
+       g "github.com/anacrolix/generics"
        "sync"
 )
 
@@ -8,7 +9,7 @@ type Cache[Peer, BlockKey, Hash comparable] struct {
        Hash func([]byte) Hash
 
        lock   sync.RWMutex
-       blocks map[BlockKey]map[Peer]Hash
+       blocks map[BlockKey][]peerAndHash[Peer, Hash]
 }
 
 type Block[Key any] struct {
@@ -16,8 +17,13 @@ type Block[Key any] struct {
        Data []byte
 }
 
+type peerAndHash[Peer, Hash any] struct {
+       Peer Peer
+       Hash Hash
+}
+
 func (me *Cache[Peer, BlockKey, Hash]) Init() {
-       me.blocks = make(map[BlockKey]map[Peer]Hash)
+       g.MakeMap(&me.blocks)
 }
 
 func (me *Cache[Peer, BlockKey, Hash]) RecordBlock(peer Peer, key BlockKey, data []byte) {
@@ -25,20 +31,17 @@ func (me *Cache[Peer, BlockKey, Hash]) RecordBlock(peer Peer, key BlockKey, data
        me.lock.Lock()
        defer me.lock.Unlock()
        peers := me.blocks[key]
-       if peers == nil {
-               peers = make(map[Peer]Hash)
-               me.blocks[key] = peers
-       }
-       peers[peer] = hash
+       peers = append(peers, peerAndHash[Peer, Hash]{peer, hash})
+       me.blocks[key] = peers
 }
 
 func (me *Cache[Peer, BlockKey, Hash]) CheckBlock(key BlockKey, data []byte) (bad []Peer) {
        correct := me.Hash(data)
        me.lock.RLock()
        defer me.lock.RUnlock()
-       for peer, hash := range me.blocks[key] {
-               if hash != correct {
-                       bad = append(bad, peer)
+       for _, item := range me.blocks[key] {
+               if item.Hash != correct {
+                       bad = append(bad, item.Peer)
                }
        }
        return
@@ -49,3 +52,9 @@ func (me *Cache[Peer, BlockKey, Hash]) ForgetBlock(key BlockKey) {
        defer me.lock.Unlock()
        delete(me.blocks, key)
 }
+
+func (me *Cache[Peer, BlockKey, Hash]) HasBlocks() bool {
+       me.lock.RLock()
+       defer me.lock.RUnlock()
+       return len(me.blocks) != 0
+}
diff --git a/smartban_test.go b/smartban_test.go
new file mode 100644 (file)
index 0000000..2947f52
--- /dev/null
@@ -0,0 +1,39 @@
+package torrent
+
+import (
+       "crypto/sha1"
+       "github.com/anacrolix/missinggo/v2/iter"
+       "github.com/anacrolix/torrent/smartban"
+       "github.com/cespare/xxhash"
+       "net/netip"
+       "testing"
+)
+
+func benchmarkSmartBanRecordBlock[Sum comparable](b *testing.B, hash func([]byte) Sum) {
+       var cache smartban.Cache[bannableAddr, RequestIndex, Sum]
+       cache.Hash = hash
+       cache.Init()
+       var data [defaultChunkSize]byte
+       var addr netip.Addr
+       b.SetBytes(int64(len(data)))
+       for i := range iter.N(b.N) {
+               cache.RecordBlock(addr, RequestIndex(i), data[:])
+       }
+}
+
+func BenchmarkSmartBanRecordBlock(b *testing.B) {
+       b.Run("xxHash", func(b *testing.B) {
+               var salt [8]byte
+               benchmarkSmartBanRecordBlock(b, func(block []byte) uint64 {
+                       h := xxhash.New()
+                       // xxHash is not cryptographic, and so we're salting it so attackers can't know a priori
+                       // where block data collisions are.
+                       h.Write(salt[:])
+                       h.Write(block)
+                       return h.Sum64()
+               })
+       })
+       b.Run("Sha1", func(b *testing.B) {
+               benchmarkSmartBanRecordBlock(b, sha1.Sum)
+       })
+}
index 2643dd6a69237a394196166f1f9dca22b305b56d..63bfe7442e900c80acfd3a6451eb4824b66ee92b 100644 (file)
@@ -24,7 +24,6 @@ import (
        . "github.com/anacrolix/generics"
        g "github.com/anacrolix/generics"
        "github.com/anacrolix/log"
-       "github.com/anacrolix/missinggo/perf"
        "github.com/anacrolix/missinggo/slices"
        "github.com/anacrolix/missinggo/v2"
        "github.com/anacrolix/missinggo/v2/bitmap"
@@ -949,7 +948,7 @@ func (t *Torrent) offsetRequest(off int64) (req Request, ok bool) {
 }
 
 func (t *Torrent) writeChunk(piece int, begin int64, data []byte) (err error) {
-       defer perf.ScopeTimerErr(&err)()
+       //defer perf.ScopeTimerErr(&err)()
        n, err := t.pieces[piece].Storage().WriteAt(data, begin)
        if err == nil && n != len(data) {
                err = io.ErrShortWrite