return Some(uploadRateLimiter.Burst())
}
+// Returns whether any part of the chunk would lie outside a piece of the given length.
+func chunkOverflowsPiece(cs ChunkSpec, pieceLength pp.Integer) bool {
+ switch {
+ default:
+ return false
+ case cs.Begin+cs.Length > pieceLength:
+ // Check for integer overflow
+ case cs.Begin > pp.IntegerMax-cs.Length:
+ }
+ return true
+}
+
// startFetch is for testing purposes currently.
func (c *PeerConn) onReadRequest(r Request, startFetch bool) error {
requestedChunkLengths.Add(strconv.FormatUint(r.Length.Uint64(), 10), 1)
requestsReceivedForMissingPieces.Add(1)
return fmt.Errorf("peer requested piece we don't have: %v", r.Index.Int())
}
+ pieceLength := c.t.pieceLength(pieceIndex(r.Index))
// Check this after we know we have the piece, so that the piece length will be known.
- if r.Begin+r.Length > c.t.pieceLength(pieceIndex(r.Index)) {
+ if chunkOverflowsPiece(r.ChunkSpec, pieceLength) {
torrent.Add("bad requests received", 1)
- return errors.New("bad Request")
+ return errors.New("chunk overflows piece")
}
if c.peerRequests == nil {
c.peerRequests = make(map[Request]*peerRequestState, localClientReqq)
case pp.Request:
r := newRequestFromMessage(&msg)
err = c.onReadRequest(r, true)
+ if err != nil {
+ err = fmt.Errorf("on reading request %v: %w", r, err)
+ }
case pp.Piece:
c.doChunkReadStats(int64(len(msg.Piece)))
err = c.receiveChunk(&msg)
"encoding/binary"
"errors"
"fmt"
- "golang.org/x/time/rate"
"io"
"net"
"sync"
"testing"
+ "golang.org/x/time/rate"
+
"github.com/frankban/quicktest"
qt "github.com/frankban/quicktest"
"github.com/stretchr/testify/require"
c.Check(pc.onReadRequest(req, false), qt.IsNil)
c.Check(pc.messageWriter.writeBuffer.Len(), qt.Equals, 17)
}
+
+func TestChunkOverflowsPiece(t *testing.T) {
+ c := qt.New(t)
+ check := func(begin, length, limit pp.Integer, expected bool) {
+ c.Check(chunkOverflowsPiece(ChunkSpec{begin, length}, limit), qt.Equals, expected)
+ }
+ check(2, 3, 1, true)
+ check(2, pp.IntegerMax, 1, true)
+ check(2, pp.IntegerMax, 3, true)
+ check(2, pp.IntegerMax, pp.IntegerMax, true)
+ check(2, pp.IntegerMax-2, pp.IntegerMax, false)
+}