"encoding/binary"
"errors"
"fmt"
- "golang.org/x/sync/errgroup"
+ g "github.com/anacrolix/generics"
"io"
"net"
"net/netip"
}
type torrentStorage struct {
- writeSem sync.Mutex
+ allChunksWritten sync.WaitGroup
}
func (me *torrentStorage) Close() error { return nil }
if len(b) != defaultChunkSize {
panic(len(b))
}
- me.writeSem.Unlock()
+ me.allChunksWritten.Done()
return len(b), nil
}
})
c.Assert(cn.bannableAddr.Ok, qt.IsTrue)
cn.setTorrent(t)
- msg := pp.Message{
- Type: pp.Piece,
- Piece: make([]byte, defaultChunkSize),
+ requestIndexBegin := t.pieceRequestIndexOffset(0)
+ requestIndexEnd := t.pieceRequestIndexOffset(1)
+ eachRequestIndex := func(f func(ri RequestIndex)) {
+ for ri := requestIndexBegin; ri < requestIndexEnd; ri++ {
+ f(ri)
+ }
}
- var errGroup errgroup.Group
- errGroup.Go(func() error {
+ const chunkSize = defaultChunkSize
+ numRequests := requestIndexEnd - requestIndexBegin
+ msgBufs := make([][]byte, 0, numRequests)
+ eachRequestIndex(func(ri RequestIndex) {
+ msgBufs = append(msgBufs, pp.Message{
+ Type: pp.Piece,
+ Piece: make([]byte, chunkSize),
+ Begin: pp.Integer(chunkSize) * pp.Integer(ri),
+ }.MustMarshalBinary())
+ })
+ // errgroup can't handle this pattern...
+ allErrors := make(chan error, 2)
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
cl.lock()
err := cn.mainReadLoop()
if errors.Is(err, io.EOF) {
err = nil
}
- return err
- })
- wb := msg.MustMarshalBinary()
- b.SetBytes(int64(len(msg.Piece)))
- errGroup.Go(func() error {
- ts.writeSem.Lock()
+ allErrors <- err
+ }()
+ b.SetBytes(chunkSize * int64(numRequests))
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
for i := 0; i < b.N; i += 1 {
cl.lock()
// The chunk must be written to storage everytime, to ensure the
// writeSem is unlocked.
t.pendAllChunkSpecs(0)
- cn.validReceiveChunks = map[RequestIndex]int{
- t.requestIndexFromRequest(newRequestFromMessage(&msg)): 1,
- }
+ g.MakeMapIfNil(&cn.validReceiveChunks)
+ eachRequestIndex(func(ri RequestIndex) {
+ cn.validReceiveChunks[ri] = 1
+ })
cl.unlock()
- n, err := w.Write(wb)
- require.NoError(b, err)
- require.EqualValues(b, len(wb), n)
+ ts.allChunksWritten.Add(int(numRequests))
+ for _, wb := range msgBufs {
+ n, err := w.Write(wb)
+ require.NoError(b, err)
+ require.EqualValues(b, len(wb), n)
+ }
// This is unlocked by a successful write to storage. So this unblocks when that is
// done.
- ts.writeSem.Lock()
+ ts.allChunksWritten.Wait()
}
if err := w.Close(); err != nil {
panic(err)
}
- return nil
- })
- err := errGroup.Wait()
+ }()
+ go func() {
+ wg.Wait()
+ close(allErrors)
+ }()
+ var err error
+ for err = range allErrors {
+ if err != nil {
+ break
+ }
+ }
c.Assert(err, qt.IsNil)
- c.Assert(cn._stats.ChunksReadUseful.Int64(), quicktest.Equals, int64(b.N))
+ c.Assert(cn._stats.ChunksReadUseful.Int64(), quicktest.Equals, int64(b.N)*int64(numRequests))
c.Assert(t.smartBanCache.HasBlocks(), qt.IsTrue)
}