]> Sergey Matveev's repositories - btrtrc.git/blobdiff - peerconn.go
Check that incoming peer request chunk lengths don't exceed the upload rate limiter...
[btrtrc.git] / peerconn.go
index a344934ebae9cf6119e7788423c16e350bd6cc8d..bd6b376b508d7b9606e6b4ad5e61e354a472e7c9 100644 (file)
@@ -5,10 +5,10 @@ import (
        "bytes"
        "errors"
        "fmt"
+       "golang.org/x/time/rate"
        "io"
        "math/rand"
        "net"
-       "sort"
        "strconv"
        "strings"
        "sync/atomic"
@@ -49,11 +49,10 @@ type PeerRemoteAddr interface {
        String() string
 }
 
-// Since we have to store all the requests in memory, we can't reasonably exceed what would be
-// indexable with the memory space available.
 type (
-       maxRequests  = int
-       requestState = request_strategy.PeerRequestState
+       // Since we have to store all the requests in memory, we can't reasonably exceed what could be
+       // indexed with the memory space available.
+       maxRequests = int
 )
 
 type Peer struct {
@@ -86,7 +85,7 @@ type Peer struct {
 
        // Stuff controlled by the local peer.
        needRequestUpdate    string
-       requestState         requestState
+       requestState         request_strategy.PeerRequestState
        updateRequestsTimer  *time.Timer
        lastRequestUpdate    time.Time
        peakRequests         maxRequests
@@ -130,13 +129,7 @@ type Peer struct {
        logger log.Logger
 }
 
-type peerRequests struct {
-       typedRoaring.Bitmap[RequestIndex]
-}
-
-func (p *peerRequests) IterateSnapshot(f func(request_strategy.RequestIndex) bool) {
-       p.Clone().Iterate(f)
-}
+type peerRequests = orderedBitmap[RequestIndex]
 
 func (p *Peer) initRequestState() {
        p.requestState.Requests = &peerRequests{}
@@ -359,13 +352,32 @@ func (cn *Peer) downloadRate() float64 {
        return float64(num) / cn.totalExpectingTime().Seconds()
 }
 
-func (cn *Peer) numRequestsByPiece() (ret map[pieceIndex]int) {
-       ret = make(map[pieceIndex]int)
-       cn.requestState.Requests.Iterate(func(x RequestIndex) bool {
-               ret[cn.t.pieceIndexOfRequestIndex(x)]++
+func (cn *Peer) DownloadRate() float64 {
+       cn.locker().Lock()
+       defer cn.locker().Unlock()
+
+       return cn.downloadRate()
+}
+
+func (cn *Peer) iterContiguousPieceRequests(f func(piece pieceIndex, count int)) {
+       var last Option[pieceIndex]
+       var count int
+       next := func(item Option[pieceIndex]) {
+               if item == last {
+                       count++
+               } else {
+                       if count != 0 {
+                               f(last.Value, count)
+                       }
+                       last = item
+                       count = 1
+               }
+       }
+       cn.requestState.Requests.Iterate(func(requestIndex request_strategy.RequestIndex) bool {
+               next(Some(cn.t.pieceIndexOfRequestIndex(requestIndex)))
                return true
        })
-       return
+       next(None[pieceIndex]())
 }
 
 func (cn *Peer) writeStatus(w io.Writer, t *Torrent) {
@@ -404,20 +416,9 @@ func (cn *Peer) writeStatus(w io.Writer, t *Torrent) {
                cn.downloadRate()/(1<<10),
        )
        fmt.Fprintf(w, "    requested pieces:")
-       type pieceNumRequestsType struct {
-               piece       pieceIndex
-               numRequests int
-       }
-       var pieceNumRequests []pieceNumRequestsType
-       for piece, count := range cn.numRequestsByPiece() {
-               pieceNumRequests = append(pieceNumRequests, pieceNumRequestsType{piece, count})
-       }
-       sort.Slice(pieceNumRequests, func(i, j int) bool {
-               return pieceNumRequests[i].piece < pieceNumRequests[j].piece
+       cn.iterContiguousPieceRequests(func(piece pieceIndex, count int) {
+               fmt.Fprintf(w, " %v(%v)", piece, count)
        })
-       for _, elem := range pieceNumRequests {
-               fmt.Fprintf(w, " %v(%v)", elem.piece, elem.numRequests)
-       }
        fmt.Fprintf(w, "\n")
 }
 
@@ -510,7 +511,7 @@ var (
 
 // The actual value to use as the maximum outbound requests.
 func (cn *Peer) nominalMaxRequests() maxRequests {
-       return maxRequests(maxInt(1, minInt(cn.PeerMaxRequests, cn.peakRequests*2, maxLocalToRemoteRequests)))
+       return maxInt(1, minInt(cn.PeerMaxRequests, cn.peakRequests*2, maxLocalToRemoteRequests))
 }
 
 func (cn *Peer) totalExpectingTime() (ret time.Duration) {
@@ -643,8 +644,10 @@ func (cn *Peer) request(r RequestIndex) (more bool, err error) {
                cn.validReceiveChunks = make(map[RequestIndex]int)
        }
        cn.validReceiveChunks[r]++
-       cn.t.pendingRequests[r] = cn
-       cn.t.lastRequested[r] = time.Now()
+       cn.t.requestState[r] = requestState{
+               peer: cn,
+               when: time.Now(),
+       }
        cn.updateExpectingChunks()
        ppReq := cn.t.requestIndexToRequest(r)
        for _, f := range cn.callbacks.SentRequest {
@@ -984,10 +987,22 @@ func (c *PeerConn) reject(r Request) {
        delete(c.peerRequests, r)
 }
 
-func (c *PeerConn) onReadRequest(r Request) error {
+func (c *PeerConn) maximumPeerRequestChunkLength() (_ Option[int]) {
+       uploadRateLimiter := c.t.cl.config.UploadRateLimiter
+       if uploadRateLimiter.Limit() == rate.Inf {
+               return
+       }
+       return Some(uploadRateLimiter.Burst())
+}
+
+// startFetch is for testing purposes currently.
+func (c *PeerConn) onReadRequest(r Request, startFetch bool) error {
        requestedChunkLengths.Add(strconv.FormatUint(r.Length.Uint64(), 10), 1)
        if _, ok := c.peerRequests[r]; ok {
                torrent.Add("duplicate requests received", 1)
+               if c.fastEnabled() {
+                       return errors.New("received duplicate request with fast enabled")
+               }
                return nil
        }
        if c.choking {
@@ -1007,10 +1022,18 @@ func (c *PeerConn) onReadRequest(r Request) error {
                // BEP 6 says we may close here if we choose.
                return nil
        }
+       if opt := c.maximumPeerRequestChunkLength(); opt.Ok && int(r.Length) > opt.Value {
+               err := fmt.Errorf("peer requested chunk too long (%v)", r.Length)
+               c.logger.Levelf(log.Warning, err.Error())
+               if c.fastEnabled() {
+                       c.reject(r)
+                       return nil
+               } else {
+                       return err
+               }
+       }
        if !c.t.havePiece(pieceIndex(r.Index)) {
-               // This isn't necessarily them screwing up. We can drop pieces
-               // from our storage, and can't communicate this to peers
-               // except by reconnecting.
+               // TODO: Tell the peer we don't have the piece, and reject this request.
                requestsReceivedForMissingPieces.Add(1)
                return fmt.Errorf("peer requested piece we don't have: %v", r.Index.Int())
        }
@@ -1024,7 +1047,10 @@ func (c *PeerConn) onReadRequest(r Request) error {
        }
        value := &peerRequestState{}
        c.peerRequests[r] = value
-       go c.peerRequestDataReader(r, value)
+       if startFetch {
+               // TODO: Limit peer request data read concurrency.
+               go c.peerRequestDataReader(r, value)
+       }
        return nil
 }
 
@@ -1040,6 +1066,7 @@ func (c *PeerConn) peerRequestDataReader(r Request, prs *peerRequestState) {
                }
                torrent.Add("peer request data read successes", 1)
                prs.data = b
+               // This might be required for the error case too (#752 and #753).
                c.tickleWriter()
        }
 }
@@ -1219,7 +1246,7 @@ func (c *PeerConn) mainReadLoop() (err error) {
                        err = c.peerSentBitfield(msg.Bitfield)
                case pp.Request:
                        r := newRequestFromMessage(&msg)
-                       err = c.onReadRequest(r)
+                       err = c.onReadRequest(r, true)
                case pp.Piece:
                        c.doChunkReadStats(int64(len(msg.Piece)))
                        err = c.receiveChunk(&msg)
@@ -1397,8 +1424,8 @@ func (c *Peer) receiveChunk(msg *pp.Message) error {
        req := c.t.requestIndexFromRequest(ppReq)
        t := c.t
 
-       if c.bannableAddr.Ok() {
-               t.smartBanCache.RecordBlock(c.bannableAddr.Value(), req, msg.Piece)
+       if c.bannableAddr.Ok {
+               t.smartBanCache.RecordBlock(c.bannableAddr.Value, req, msg.Piece)
        }
 
        if c.peerChoking {
@@ -1470,7 +1497,7 @@ func (c *Peer) receiveChunk(msg *pp.Message) error {
        piece.unpendChunkIndex(chunkIndexFromChunkSpec(ppReq.ChunkSpec, t.chunkSize))
 
        // Cancel pending requests for this chunk from *other* peers.
-       if p := t.pendingRequests[req]; p != nil {
+       if p := t.requestingPeer(req); p != nil {
                if p == c {
                        panic("should not be pending request from conn that just received it")
                }
@@ -1633,8 +1660,7 @@ func (c *Peer) deleteRequest(r RequestIndex) bool {
        if c.t.requestingPeer(r) != c {
                panic("only one peer should have a given request at a time")
        }
-       delete(c.t.pendingRequests, r)
-       delete(c.t.lastRequested, r)
+       delete(c.t.requestState, r)
        // c.t.iterPeers(func(p *Peer) {
        //      if p.isLowOnRequests() {
        //              p.updateRequests("Peer.deleteRequest")