Should fix #759.
return fmt.Errorf("adding connection: %w", err)
}
defer t.dropConnection(c)
- c.startWriter()
+ c.startMessageWriter()
cl.sendInitialMessages(c, t)
c.initUpdateRequestsTimer()
err := c.mainReadLoop()
pp "github.com/anacrolix/torrent/peer_protocol"
)
-func (pc *PeerConn) startWriter() {
+func (pc *PeerConn) initMessageWriter() {
w := &pc.messageWriter
*w = peerConnMsgWriter{
fillWriteBuffer: func() {
},
writeBuffer: new(bytes.Buffer),
}
- go func() {
- defer pc.locker().Unlock()
- defer pc.close()
- defer pc.locker().Lock()
- pc.messageWriter.run(pc.t.cl.config.KeepAliveTimeout)
- }()
+}
+
+func (pc *PeerConn) startMessageWriter() {
+ pc.initMessageWriter()
+ go pc.messageWriterRunner()
+}
+
+func (pc *PeerConn) messageWriterRunner() {
+ defer pc.locker().Unlock()
+ defer pc.close()
+ defer pc.locker().Lock()
+ pc.messageWriter.run(pc.t.cl.config.KeepAliveTimeout)
}
type peerConnMsgWriter struct {
"bytes"
"errors"
"fmt"
+ "golang.org/x/time/rate"
"io"
"math/rand"
"net"
delete(c.peerRequests, r)
}
-func (c *PeerConn) onReadRequest(r Request) error {
+func (c *PeerConn) maximumPeerRequestChunkLength() (_ Option[int]) {
+ uploadRateLimiter := c.t.cl.config.UploadRateLimiter
+ if uploadRateLimiter.Limit() == rate.Inf {
+ return
+ }
+ return Some(uploadRateLimiter.Burst())
+}
+
+// startFetch is for testing purposes currently.
+func (c *PeerConn) onReadRequest(r Request, startFetch bool) error {
requestedChunkLengths.Add(strconv.FormatUint(r.Length.Uint64(), 10), 1)
if _, ok := c.peerRequests[r]; ok {
torrent.Add("duplicate requests received", 1)
+ if c.fastEnabled() {
+ return errors.New("received duplicate request with fast enabled")
+ }
return nil
}
if c.choking {
// BEP 6 says we may close here if we choose.
return nil
}
+ if opt := c.maximumPeerRequestChunkLength(); opt.Ok && int(r.Length) > opt.Value {
+ err := fmt.Errorf("peer requested chunk too long (%v)", r.Length)
+ c.logger.Levelf(log.Warning, err.Error())
+ if c.fastEnabled() {
+ c.reject(r)
+ return nil
+ } else {
+ return err
+ }
+ }
if !c.t.havePiece(pieceIndex(r.Index)) {
- // This isn't necessarily them screwing up. We can drop pieces
- // from our storage, and can't communicate this to peers
- // except by reconnecting.
+ // TODO: Tell the peer we don't have the piece, and reject this request.
requestsReceivedForMissingPieces.Add(1)
return fmt.Errorf("peer requested piece we don't have: %v", r.Index.Int())
}
}
value := &peerRequestState{}
c.peerRequests[r] = value
- go c.peerRequestDataReader(r, value)
+ if startFetch {
+ // TODO: Limit peer request data read concurrency.
+ go c.peerRequestDataReader(r, value)
+ }
return nil
}
err = c.peerSentBitfield(msg.Bitfield)
case pp.Request:
r := newRequestFromMessage(&msg)
- err = c.onReadRequest(r)
+ err = c.onReadRequest(r, true)
case pp.Piece:
c.doChunkReadStats(int64(len(msg.Piece)))
err = c.receiveChunk(&msg)
"encoding/binary"
"errors"
"fmt"
+ "golang.org/x/time/rate"
"io"
"net"
"sync"
r, w := io.Pipe()
// c.r = r
c.w = w
- c.startWriter()
+ c.startMessageWriter()
c.locker().Lock()
c.t._completedPieces.Add(1)
c.postBitfield( /*[]bool{false, true, false}*/ )
// No difference
c.Assert(pc(1, 2, false, false, false).hasPreferredNetworkOver(pc(1, 2, false, false, false)), qt.IsFalse)
}
+
+func TestReceiveLargeRequest(t *testing.T) {
+ c := qt.New(t)
+ cl := newTestingClient(t)
+ pc := cl.newConnection(nil, false, nil, "test", "")
+ tor := cl.newTorrentForTesting()
+ tor.info = &metainfo.Info{PieceLength: 3 << 20}
+ pc.setTorrent(tor)
+ tor._completedPieces.Add(0)
+ pc.PeerExtensionBytes.SetBit(pp.ExtensionBitFast, true)
+ pc.choking = false
+ pc.initMessageWriter()
+ req := Request{}
+ req.Length = defaultChunkSize
+ c.Assert(pc.fastEnabled(), qt.IsTrue)
+ c.Check(pc.onReadRequest(req, false), qt.IsNil)
+ c.Check(pc.peerRequests, qt.HasLen, 1)
+ req.Length = 2 << 20
+ c.Check(pc.onReadRequest(req, false), qt.IsNil)
+ c.Check(pc.peerRequests, qt.HasLen, 2)
+ pc.peerRequests = nil
+ pc.t.cl.config.UploadRateLimiter = rate.NewLimiter(1, defaultChunkSize)
+ req.Length = defaultChunkSize
+ c.Check(pc.onReadRequest(req, false), qt.IsNil)
+ c.Check(pc.peerRequests, qt.HasLen, 1)
+ req.Length = 2 << 20
+ c.Check(pc.onReadRequest(req, false), qt.IsNil)
+ c.Check(pc.messageWriter.writeBuffer.Len(), qt.Equals, 17)
+}