]> Sergey Matveev's repositories - btrtrc.git/blobdiff - peerconn_test.go
Write whole piece in BenchmarkConnectionMainReadLoop
[btrtrc.git] / peerconn_test.go
index 8a640544fc8ca3788f7112d8101c33d4a618a872..36327800164388946d1dbcf990754b6d095c41ee 100644 (file)
@@ -4,7 +4,7 @@ import (
        "encoding/binary"
        "errors"
        "fmt"
-       "golang.org/x/sync/errgroup"
+       g "github.com/anacrolix/generics"
        "io"
        "net"
        "net/netip"
@@ -57,7 +57,7 @@ func TestSendBitfieldThenHave(t *testing.T) {
 }
 
 type torrentStorage struct {
-       writeSem sync.Mutex
+       allChunksWritten sync.WaitGroup
 }
 
 func (me *torrentStorage) Close() error { return nil }
@@ -86,7 +86,7 @@ func (me *torrentStorage) WriteAt(b []byte, _ int64) (int, error) {
        if len(b) != defaultChunkSize {
                panic(len(b))
        }
-       me.writeSem.Unlock()
+       me.allChunksWritten.Done()
        return len(b), nil
 }
 
@@ -119,47 +119,76 @@ func BenchmarkConnectionMainReadLoop(b *testing.B) {
        })
        c.Assert(cn.bannableAddr.Ok, qt.IsTrue)
        cn.setTorrent(t)
-       msg := pp.Message{
-               Type:  pp.Piece,
-               Piece: make([]byte, defaultChunkSize),
+       requestIndexBegin := t.pieceRequestIndexOffset(0)
+       requestIndexEnd := t.pieceRequestIndexOffset(1)
+       eachRequestIndex := func(f func(ri RequestIndex)) {
+               for ri := requestIndexBegin; ri < requestIndexEnd; ri++ {
+                       f(ri)
+               }
        }
-       var errGroup errgroup.Group
-       errGroup.Go(func() error {
+       const chunkSize = defaultChunkSize
+       numRequests := requestIndexEnd - requestIndexBegin
+       msgBufs := make([][]byte, 0, numRequests)
+       eachRequestIndex(func(ri RequestIndex) {
+               msgBufs = append(msgBufs, pp.Message{
+                       Type:  pp.Piece,
+                       Piece: make([]byte, chunkSize),
+                       Begin: pp.Integer(chunkSize) * pp.Integer(ri),
+               }.MustMarshalBinary())
+       })
+       // errgroup can't handle this pattern...
+       allErrors := make(chan error, 2)
+       var wg sync.WaitGroup
+       wg.Add(1)
+       go func() {
+               defer wg.Done()
                cl.lock()
                err := cn.mainReadLoop()
                if errors.Is(err, io.EOF) {
                        err = nil
                }
-               return err
-       })
-       wb := msg.MustMarshalBinary()
-       b.SetBytes(int64(len(msg.Piece)))
-       errGroup.Go(func() error {
-               ts.writeSem.Lock()
+               allErrors <- err
+       }()
+       b.SetBytes(chunkSize * int64(numRequests))
+       wg.Add(1)
+       go func() {
+               defer wg.Done()
                for i := 0; i < b.N; i += 1 {
                        cl.lock()
                        // The chunk must be written to storage everytime, to ensure the
                        // writeSem is unlocked.
                        t.pendAllChunkSpecs(0)
-                       cn.validReceiveChunks = map[RequestIndex]int{
-                               t.requestIndexFromRequest(newRequestFromMessage(&msg)): 1,
-                       }
+                       g.MakeMapIfNil(&cn.validReceiveChunks)
+                       eachRequestIndex(func(ri RequestIndex) {
+                               cn.validReceiveChunks[ri] = 1
+                       })
                        cl.unlock()
-                       n, err := w.Write(wb)
-                       require.NoError(b, err)
-                       require.EqualValues(b, len(wb), n)
+                       ts.allChunksWritten.Add(int(numRequests))
+                       for _, wb := range msgBufs {
+                               n, err := w.Write(wb)
+                               require.NoError(b, err)
+                               require.EqualValues(b, len(wb), n)
+                       }
                        // This is unlocked by a successful write to storage. So this unblocks when that is
                        // done.
-                       ts.writeSem.Lock()
+                       ts.allChunksWritten.Wait()
                }
                if err := w.Close(); err != nil {
                        panic(err)
                }
-               return nil
-       })
-       err := errGroup.Wait()
+       }()
+       go func() {
+               wg.Wait()
+               close(allErrors)
+       }()
+       var err error
+       for err = range allErrors {
+               if err != nil {
+                       break
+               }
+       }
        c.Assert(err, qt.IsNil)
-       c.Assert(cn._stats.ChunksReadUseful.Int64(), quicktest.Equals, int64(b.N))
+       c.Assert(cn._stats.ChunksReadUseful.Int64(), quicktest.Equals, int64(b.N)*int64(numRequests))
        c.Assert(t.smartBanCache.HasBlocks(), qt.IsTrue)
 }