peer_protocol/int.go | 10 +++++++++-
peerconn.go          | 17 +++++++++++++++--
peerconn_test.go     | 22 +++++++++++++++++++++-

diff --git a/peer_protocol/int.go b/peer_protocol/int.go
index 13bd1ca9c2027e6449a9b07d15b1a70bb2b8d791..ebcf60355fd122fd0cc55444dcd2254dc88ec310 100644
--- a/peer_protocol/int.go
+++ b/peer_protocol/int.go
@@ -3,11 +3,19 @@
 import (
 	"encoding/binary"
 	"io"
+	"math"
 
 	"github.com/pkg/errors"
 )
 
-type Integer uint32
+type (
+	// IntegerKind aliases the underlying type of Integer. This is needed for fuzzing.
+	IntegerKind = uint32
+	Integer     IntegerKind
+)
+
+// IntegerMax is the largest value an Integer can hold.
+const IntegerMax = math.MaxUint32
 
 func (i *Integer) UnmarshalBinary(b []byte) error {
 	if len(b) != 4 {
diff --git a/peerconn.go b/peerconn.go
index 3c515aaf989d130e88670a89215abb14b1364e53..fdc8236d600f726fef653f051c27a8160859410d 100644
--- a/peerconn.go
+++ b/peerconn.go
@@ -1003,6 +1003,15 @@ }
 	return Some(uploadRateLimiter.Burst())
 }
 
+// chunkOverflowsPiece reports whether any part of the chunk would lie outside a
+// piece of the given length. The chunk may come from an untrusted peer request,
+// so the check never computes cs.Begin+cs.Length, which could wrap around.
+func chunkOverflowsPiece(cs ChunkSpec, pieceLength pp.Integer) bool {
+	// Same as cs.Begin+cs.Length > pieceLength in unbounded arithmetic. The
+	// subtraction is only reached when cs.Begin <= pieceLength, so it cannot wrap.
+	return cs.Begin > pieceLength || cs.Length > pieceLength-cs.Begin
+}
+
 // startFetch is for testing purposes currently.
 func (c *PeerConn) onReadRequest(r Request, startFetch bool) error {
 	requestedChunkLengths.Add(strconv.FormatUint(r.Length.Uint64(), 10), 1)
@@ -1045,10 +1054,11 @@ // TODO: Tell the peer we don't have the piece, and reject this request.
 		requestsReceivedForMissingPieces.Add(1)
 		return fmt.Errorf("peer requested piece we don't have: %v", r.Index.Int())
 	}
+	pieceLength := c.t.pieceLength(pieceIndex(r.Index))
 	// Check this after we know we have the piece, so that the piece length will be known.
-	if r.Begin+r.Length > c.t.pieceLength(pieceIndex(r.Index)) {
+	if chunkOverflowsPiece(r.ChunkSpec, pieceLength) {
 		torrent.Add("bad requests received", 1)
-		return errors.New("bad Request")
+		return errors.New("chunk overflows piece")
 	}
 	if c.peerRequests == nil {
 		c.peerRequests = make(map[Request]*peerRequestState, localClientReqq)
@@ -1255,6 +1265,9 @@ err = c.peerSentBitfield(msg.Bitfield)
 		case pp.Request:
 			r := newRequestFromMessage(&msg)
 			err = c.onReadRequest(r, true)
+			if err != nil {
+				err = fmt.Errorf("on reading request %v: %w", r, err)
+			}
 		case pp.Piece:
 			c.doChunkReadStats(int64(len(msg.Piece)))
 			err = c.receiveChunk(&msg)
diff --git a/peerconn_test.go b/peerconn_test.go
index 23d32286cd66b82ca86cd3d04b51d6e6f77fca5c..42f8fe273131b29dda29d4d7e666975d2858e023 100644
--- a/peerconn_test.go
+++ b/peerconn_test.go
@@ -4,11 +4,12 @@ import (
 	"encoding/binary"
 	"errors"
 	"fmt"
-	"golang.org/x/time/rate"
 	"io"
 	"net"
 	"sync"
 	"testing"
+
+	"golang.org/x/time/rate"
 
 	"github.com/frankban/quicktest"
 	qt "github.com/frankban/quicktest"
@@ -317,3 +318,22 @@ req.Length = 2 << 20
 	c.Check(pc.onReadRequest(req, false), qt.IsNil)
 	c.Check(pc.messageWriter.writeBuffer.Len(), qt.Equals, 17)
 }
+
+func TestChunkOverflowsPiece(t *testing.T) {
+	c := qt.New(t)
+	check := func(begin, length, limit pp.Integer, expected bool) {
+		c.Check(chunkOverflowsPiece(ChunkSpec{Begin: begin, Length: length}, limit), qt.Equals, expected)
+	}
+	check(2, 3, 1, true)
+	check(2, pp.IntegerMax, 1, true)
+	check(2, pp.IntegerMax, 3, true)
+	check(2, pp.IntegerMax, pp.IntegerMax, true)
+	check(2, pp.IntegerMax-2, pp.IntegerMax, false)
+	// Boundaries: empty chunks and pieces, and chunks touching or crossing the
+	// end of the piece.
+	check(0, 0, 0, false)
+	check(0, 1, 0, true)
+	check(pp.IntegerMax, 0, pp.IntegerMax, false)
+	check(pp.IntegerMax, 1, pp.IntegerMax, true)
+	check(3, 0, 2, true)
+}