]> Sergey Matveev's repositories - btrtrc.git/blobdiff - peerconn.go
Ability to override fifos/
[btrtrc.git] / peerconn.go
index 41b3f113f297e2c23d797f0b7e533003732a694a..322e7625ba37b59884dbc7051a23b897893dc1e4 100644 (file)
@@ -8,7 +8,6 @@ import (
        "io"
        "math/rand"
        "net"
-       "sort"
        "strconv"
        "strings"
        "sync/atomic"
@@ -21,11 +20,14 @@ import (
        "github.com/anacrolix/missinggo/iter"
        "github.com/anacrolix/missinggo/v2/bitmap"
        "github.com/anacrolix/multiless"
+       "golang.org/x/time/rate"
+
        "github.com/anacrolix/torrent/bencode"
        "github.com/anacrolix/torrent/metainfo"
        "github.com/anacrolix/torrent/mse"
        pp "github.com/anacrolix/torrent/peer_protocol"
        request_strategy "github.com/anacrolix/torrent/request-strategy"
+       "github.com/anacrolix/torrent/typed-roaring"
 )
 
 type PeerSource string
@@ -48,11 +50,10 @@ type PeerRemoteAddr interface {
        String() string
 }
 
-// Since we have to store all the requests in memory, we can't reasonably exceed what would be
-// indexable with the memory space available.
 type (
-       maxRequests  = int
-       requestState = request_strategy.PeerRequestState
+       // Since we have to store all the requests in memory, we can't reasonably exceed what could be
+       // indexed with the memory space available.
+       maxRequests = int
 )
 
 type Peer struct {
@@ -64,10 +65,13 @@ type Peer struct {
        peerImpl
        callbacks *Callbacks
 
-       outgoing     bool
-       Network      string
-       RemoteAddr   PeerRemoteAddr
-       bannableAddr Option[bannableAddr]
+       outgoing   bool
+       Network    string
+       RemoteAddr PeerRemoteAddr
+       // The local address as observed by the remote peer. WebRTC seems to get this right without needing hints from the
+       // config.
+       localPublicAddr peerLocalPublicAddr
+       bannableAddr    Option[bannableAddr]
        // True if the connection is operating over MSE obfuscation.
        headerEncrypted bool
        cryptoMethod    mse.CryptoMethod
@@ -85,7 +89,7 @@ type Peer struct {
 
        // Stuff controlled by the local peer.
        needRequestUpdate    string
-       requestState         requestState
+       requestState         request_strategy.PeerRequestState
        updateRequestsTimer  *time.Timer
        lastRequestUpdate    time.Time
        peakRequests         maxRequests
@@ -120,7 +124,7 @@ type Peer struct {
        peerMinPieces pieceIndex
        // Pieces we've accepted chunks for from the peer.
        peerTouchedPieces map[pieceIndex]struct{}
-       peerAllowedFast   roaring.Bitmap
+       peerAllowedFast   typedRoaring.Bitmap[pieceIndex]
 
        PeerMaxRequests  maxRequests // Maximum pending requests the peer allows.
        PeerExtensionIDs map[pp.ExtensionName]pp.ExtensionNumber
@@ -129,6 +133,12 @@ type Peer struct {
        logger log.Logger
 }
 
+type peerRequests = orderedBitmap[RequestIndex]
+
+func (p *Peer) initRequestState() {
+       p.requestState.Requests = &peerRequests{}
+}
+
 // Maintains the state of a BitTorrent-protocol based connection with a peer.
 type PeerConn struct {
        Peer
@@ -161,8 +171,8 @@ type PeerConn struct {
        peerSentHaveAll bool
 }
 
-func (cn *PeerConn) connStatusString() string {
-       return fmt.Sprintf("%+-55q %s %s", cn.PeerID, cn.PeerExtensionBytes, cn.connString)
+func (cn *PeerConn) peerImplStatusLines() []string {
+       return []string{fmt.Sprintf("%+-55q %s %s", cn.PeerID, cn.PeerExtensionBytes, cn.connString)}
 }
 
 func (cn *Peer) updateExpectingChunks() {
@@ -189,11 +199,11 @@ func (cn *Peer) expectingChunks() bool {
                return true
        }
        haveAllowedFastRequests := false
-       cn.peerAllowedFast.Iterate(func(i uint32) bool {
-               haveAllowedFastRequests = roaringBitmapRangeCardinality(
-                       &cn.requestState.Requests,
-                       cn.t.pieceRequestIndexOffset(pieceIndex(i)),
-                       cn.t.pieceRequestIndexOffset(pieceIndex(i+1)),
+       cn.peerAllowedFast.Iterate(func(i pieceIndex) bool {
+               haveAllowedFastRequests = roaringBitmapRangeCardinality[RequestIndex](
+                       cn.requestState.Requests,
+                       cn.t.pieceRequestIndexOffset(i),
+                       cn.t.pieceRequestIndexOffset(i+1),
                ) == 0
                return !haveAllowedFastRequests
        })
@@ -201,7 +211,7 @@ func (cn *Peer) expectingChunks() bool {
 }
 
 func (cn *Peer) remoteChokingPiece(piece pieceIndex) bool {
-       return cn.peerChoking && !cn.peerAllowedFast.Contains(bitmap.BitIndex(piece))
+       return cn.peerChoking && !cn.peerAllowedFast.Contains(piece)
 }
 
 // Returns true if the connection is over IPv6.
@@ -273,6 +283,10 @@ func (cn *Peer) completedString() string {
        return fmt.Sprintf("%d/%d", have, cn.bestPeerNumPieces())
 }
 
+func (cn *Peer) CompletedString() string {
+       return cn.completedString()
+}
+
 func (cn *PeerConn) onGotInfo(info *metainfo.Info) {
        cn.setNumPieces(info.NumPieces())
 }
@@ -338,6 +352,10 @@ func (cn *Peer) statusFlags() (ret string) {
        return
 }
 
+func (cn *Peer) StatusFlags() string {
+       return cn.statusFlags()
+}
+
 func (cn *Peer) downloadRate() float64 {
        num := cn._stats.BytesReadUsefulData.Int64()
        if num == 0 {
@@ -346,13 +364,43 @@ func (cn *Peer) downloadRate() float64 {
        return float64(num) / cn.totalExpectingTime().Seconds()
 }
 
-func (cn *Peer) numRequestsByPiece() (ret map[pieceIndex]int) {
-       ret = make(map[pieceIndex]int)
-       cn.requestState.Requests.Iterate(func(x uint32) bool {
-               ret[pieceIndex(x/cn.t.chunksPerRegularPiece())]++
+func (cn *Peer) DownloadRate() float64 {
+       cn.locker().RLock()
+       defer cn.locker().RUnlock()
+
+       return cn.downloadRate()
+}
+
+func (cn *Peer) UploadRate() float64 {
+       cn.locker().RLock()
+       defer cn.locker().RUnlock()
+       num := cn._stats.BytesWrittenData.Int64()
+       if num == 0 {
+               return 0
+       }
+       return float64(num) / time.Now().Sub(cn.completedHandshake).Seconds()
+}
+
+
+func (cn *Peer) iterContiguousPieceRequests(f func(piece pieceIndex, count int)) {
+       var last Option[pieceIndex]
+       var count int
+       next := func(item Option[pieceIndex]) {
+               if item == last {
+                       count++
+               } else {
+                       if count != 0 {
+                               f(last.Value, count)
+                       }
+                       last = item
+                       count = 1
+               }
+       }
+       cn.requestState.Requests.Iterate(func(requestIndex request_strategy.RequestIndex) bool {
+               next(Some(cn.t.pieceIndexOfRequestIndex(requestIndex)))
                return true
        })
-       return
+       next(None[pieceIndex]())
 }
 
 func (cn *Peer) writeStatus(w io.Writer, t *Torrent) {
@@ -360,14 +408,14 @@ func (cn *Peer) writeStatus(w io.Writer, t *Torrent) {
        if cn.closed.IsSet() {
                fmt.Fprint(w, "CLOSED: ")
        }
-       fmt.Fprintln(w, cn.connStatusString())
+       fmt.Fprintln(w, strings.Join(cn.peerImplStatusLines(), "\n"))
        prio, err := cn.peerPriority()
        prioStr := fmt.Sprintf("%08x", prio)
        if err != nil {
                prioStr += ": " + err.Error()
        }
-       fmt.Fprintf(w, "    bep40-prio: %v\n", prioStr)
-       fmt.Fprintf(w, "    last msg: %s, connected: %s, last helpful: %s, itime: %s, etime: %s\n",
+       fmt.Fprintf(w, "bep40-prio: %v\n", prioStr)
+       fmt.Fprintf(w, "last msg: %s, connected: %s, last helpful: %s, itime: %s, etime: %s\n",
                eventAgeString(cn.lastMessageReceived),
                eventAgeString(cn.completedHandshake),
                eventAgeString(cn.lastHelpful()),
@@ -375,7 +423,7 @@ func (cn *Peer) writeStatus(w io.Writer, t *Torrent) {
                cn.totalExpectingTime(),
        )
        fmt.Fprintf(w,
-               "    %s completed, %d pieces touched, good chunks: %v/%v:%v reqq: %d+%v/(%d/%d):%d/%d, flags: %s, dr: %.1f KiB/s\n",
+               "%s completed, %d pieces touched, good chunks: %v/%v:%v reqq: %d+%v/(%d/%d):%d/%d, flags: %s, dr: %.1f KiB/s\n",
                cn.completedString(),
                len(cn.peerTouchedPieces),
                &cn._stats.ChunksReadUseful,
@@ -390,21 +438,10 @@ func (cn *Peer) writeStatus(w io.Writer, t *Torrent) {
                cn.statusFlags(),
                cn.downloadRate()/(1<<10),
        )
-       fmt.Fprintf(w, "    requested pieces:")
-       type pieceNumRequestsType struct {
-               piece       pieceIndex
-               numRequests int
-       }
-       var pieceNumRequests []pieceNumRequestsType
-       for piece, count := range cn.numRequestsByPiece() {
-               pieceNumRequests = append(pieceNumRequests, pieceNumRequestsType{piece, count})
-       }
-       sort.Slice(pieceNumRequests, func(i, j int) bool {
-               return pieceNumRequests[i].piece < pieceNumRequests[j].piece
+       fmt.Fprintf(w, "requested pieces:")
+       cn.iterContiguousPieceRequests(func(piece pieceIndex, count int) {
+               fmt.Fprintf(w, " %v(%v)", piece, count)
        })
-       for _, elem := range pieceNumRequests {
-               fmt.Fprintf(w, " %v(%v)", elem.piece, elem.numRequests)
-       }
        fmt.Fprintf(w, "\n")
 }
 
@@ -497,7 +534,7 @@ var (
 
 // The actual value to use as the maximum outbound requests.
 func (cn *Peer) nominalMaxRequests() maxRequests {
-       return maxRequests(maxInt(1, minInt(cn.PeerMaxRequests, cn.peakRequests*2, maxLocalToRemoteRequests)))
+       return maxInt(1, minInt(cn.PeerMaxRequests, cn.peakRequests*2, maxLocalToRemoteRequests))
 }
 
 func (cn *Peer) totalExpectingTime() (ret time.Duration) {
@@ -578,7 +615,11 @@ type messageWriter func(pp.Message) bool
 // This function seems to only used by Peer.request. It's all logic checks, so maybe we can no-op it
 // when we want to go fast.
 func (cn *Peer) shouldRequest(r RequestIndex) error {
-       pi := pieceIndex(r / cn.t.chunksPerRegularPiece())
+       err := cn.t.checkValidReceiveChunk(cn.t.requestIndexToRequest(r))
+       if err != nil {
+               return err
+       }
+       pi := cn.t.pieceIndexOfRequestIndex(r)
        if cn.requestState.Cancelled.Contains(r) {
                return errors.New("request is cancelled and waiting acknowledgement")
        }
@@ -597,7 +638,7 @@ func (cn *Peer) shouldRequest(r RequestIndex) error {
        if cn.t.pieceQueuedForHash(pi) {
                panic("piece is queued for hash")
        }
-       if cn.peerChoking && !cn.peerAllowedFast.Contains(bitmap.BitIndex(pi)) {
+       if cn.peerChoking && !cn.peerAllowedFast.Contains(pi) {
                // This could occur if we made a request with the fast extension, and then got choked and
                // haven't had the request rejected yet.
                if !cn.requestState.Requests.Contains(r) {
@@ -630,8 +671,10 @@ func (cn *Peer) request(r RequestIndex) (more bool, err error) {
                cn.validReceiveChunks = make(map[RequestIndex]int)
        }
        cn.validReceiveChunks[r]++
-       cn.t.pendingRequests[r] = cn
-       cn.t.lastRequested[r] = time.Now()
+       cn.t.requestState[r] = requestState{
+               peer: cn,
+               when: time.Now(),
+       }
        cn.updateExpectingChunks()
        ppReq := cn.t.requestIndexToRequest(r)
        for _, f := range cn.callbacks.SentRequest {
@@ -971,10 +1014,22 @@ func (c *PeerConn) reject(r Request) {
        delete(c.peerRequests, r)
 }
 
-func (c *PeerConn) onReadRequest(r Request) error {
+func (c *PeerConn) maximumPeerRequestChunkLength() (_ Option[int]) {
+       uploadRateLimiter := c.t.cl.config.UploadRateLimiter
+       if uploadRateLimiter.Limit() == rate.Inf {
+               return
+       }
+       return Some(uploadRateLimiter.Burst())
+}
+
+// startFetch is for testing purposes currently.
+func (c *PeerConn) onReadRequest(r Request, startFetch bool) error {
        requestedChunkLengths.Add(strconv.FormatUint(r.Length.Uint64(), 10), 1)
        if _, ok := c.peerRequests[r]; ok {
                torrent.Add("duplicate requests received", 1)
+               if c.fastEnabled() {
+                       return errors.New("received duplicate request with fast enabled")
+               }
                return nil
        }
        if c.choking {
@@ -994,10 +1049,18 @@ func (c *PeerConn) onReadRequest(r Request) error {
                // BEP 6 says we may close here if we choose.
                return nil
        }
+       if opt := c.maximumPeerRequestChunkLength(); opt.Ok && int(r.Length) > opt.Value {
+               err := fmt.Errorf("peer requested chunk too long (%v)", r.Length)
+               c.logger.Levelf(log.Warning, err.Error())
+               if c.fastEnabled() {
+                       c.reject(r)
+                       return nil
+               } else {
+                       return err
+               }
+       }
        if !c.t.havePiece(pieceIndex(r.Index)) {
-               // This isn't necessarily them screwing up. We can drop pieces
-               // from our storage, and can't communicate this to peers
-               // except by reconnecting.
+               // TODO: Tell the peer we don't have the piece, and reject this request.
                requestsReceivedForMissingPieces.Add(1)
                return fmt.Errorf("peer requested piece we don't have: %v", r.Index.Int())
        }
@@ -1011,7 +1074,10 @@ func (c *PeerConn) onReadRequest(r Request) error {
        }
        value := &peerRequestState{}
        c.peerRequests[r] = value
-       go c.peerRequestDataReader(r, value)
+       if startFetch {
+               // TODO: Limit peer request data read concurrency.
+               go c.peerRequestDataReader(r, value)
+       }
        return nil
 }
 
@@ -1027,6 +1093,7 @@ func (c *PeerConn) peerRequestDataReader(r Request, prs *peerRequestState) {
                }
                torrent.Add("peer request data read successes", 1)
                prs.data = b
+               // This might be required for the error case too (#752 and #753).
                c.tickleWriter()
        }
 }
@@ -1152,13 +1219,7 @@ func (c *PeerConn) mainReadLoop() (err error) {
                                break
                        }
                        if !c.fastEnabled() {
-                               if !c.deleteAllRequests().IsEmpty() {
-                                       c.t.iterPeers(func(p *Peer) {
-                                               if p.isLowOnRequests() {
-                                                       p.updateRequests("choked by non-fast PeerConn")
-                                               }
-                                       })
-                               }
+                               c.deleteAllRequests("choked by non-fast PeerConn")
                        } else {
                                // We don't decrement pending requests here, let's wait for the peer to either
                                // reject or satisfy the outstanding requests. Additionally, some peers may unchoke
@@ -1178,8 +1239,8 @@ func (c *PeerConn) mainReadLoop() (err error) {
                        }
                        c.peerChoking = false
                        preservedCount := 0
-                       c.requestState.Requests.Iterate(func(x uint32) bool {
-                               if !c.peerAllowedFast.Contains(x / c.t.chunksPerRegularPiece()) {
+                       c.requestState.Requests.Iterate(func(x RequestIndex) bool {
+                               if !c.peerAllowedFast.Contains(c.t.pieceIndexOfRequestIndex(x)) {
                                        preservedCount++
                                }
                                return true
@@ -1212,7 +1273,7 @@ func (c *PeerConn) mainReadLoop() (err error) {
                        err = c.peerSentBitfield(msg.Bitfield)
                case pp.Request:
                        r := newRequestFromMessage(&msg)
-                       err = c.onReadRequest(r)
+                       err = c.onReadRequest(r, true)
                case pp.Piece:
                        c.doChunkReadStats(int64(len(msg.Piece)))
                        err = c.receiveChunk(&msg)
@@ -1387,11 +1448,16 @@ func (c *Peer) receiveChunk(msg *pp.Message) error {
        chunksReceived.Add("total", 1)
 
        ppReq := newRequestFromMessage(msg)
-       req := c.t.requestIndexFromRequest(ppReq)
        t := c.t
+       err := t.checkValidReceiveChunk(ppReq)
+       if err != nil {
+               err = log.WithLevel(log.Warning, err)
+               return err
+       }
+       req := c.t.requestIndexFromRequest(ppReq)
 
-       if c.bannableAddr.Ok() {
-               t.smartBanCache.RecordBlock(c.bannableAddr.Value(), req, msg.Piece)
+       if c.bannableAddr.Ok {
+               t.smartBanCache.RecordBlock(c.bannableAddr.Value, req, msg.Piece)
        }
 
        if c.peerChoking {
@@ -1404,7 +1470,7 @@ func (c *Peer) receiveChunk(msg *pp.Message) error {
        }
        c.decExpectedChunkReceive(req)
 
-       if c.peerChoking && c.peerAllowedFast.Contains(bitmap.BitIndex(ppReq.Index)) {
+       if c.peerChoking && c.peerAllowedFast.Contains(pieceIndex(ppReq.Index)) {
                chunksReceived.Add("due to allowed fast", 1)
        }
 
@@ -1463,14 +1529,14 @@ func (c *Peer) receiveChunk(msg *pp.Message) error {
        piece.unpendChunkIndex(chunkIndexFromChunkSpec(ppReq.ChunkSpec, t.chunkSize))
 
        // Cancel pending requests for this chunk from *other* peers.
-       if p := t.pendingRequests[req]; p != nil {
+       if p := t.requestingPeer(req); p != nil {
                if p == c {
                        panic("should not be pending request from conn that just received it")
                }
                p.cancel(req)
        }
 
-       err := func() error {
+       err = func() error {
                cl.unlock()
                defer cl.lock()
                concurrentChunkWrites.Add(1)
@@ -1626,8 +1692,7 @@ func (c *Peer) deleteRequest(r RequestIndex) bool {
        if c.t.requestingPeer(r) != c {
                panic("only one peer should have a given request at a time")
        }
-       delete(c.t.pendingRequests, r)
-       delete(c.t.lastRequested, r)
+       delete(c.t.requestState, r)
        // c.t.iterPeers(func(p *Peer) {
        //      if p.isLowOnRequests() {
        //              p.updateRequests("Peer.deleteRequest")
@@ -1636,15 +1701,22 @@ func (c *Peer) deleteRequest(r RequestIndex) bool {
        return true
 }
 
-func (c *Peer) deleteAllRequests() (deleted *roaring.Bitmap) {
-       deleted = c.requestState.Requests.Clone()
-       deleted.Iterate(func(x uint32) bool {
+func (c *Peer) deleteAllRequests(reason string) {
+       if c.requestState.Requests.IsEmpty() {
+               return
+       }
+       c.requestState.Requests.IterateSnapshot(func(x RequestIndex) bool {
                if !c.deleteRequest(x) {
                        panic("request should exist")
                }
                return true
        })
        c.assertNoRequests()
+       c.t.iterPeers(func(p *Peer) {
+               if p.isLowOnRequests() {
+                       p.updateRequests(reason)
+               }
+       })
        return
 }
 
@@ -1654,9 +1726,8 @@ func (c *Peer) assertNoRequests() {
        }
 }
 
-func (c *Peer) cancelAllRequests() (cancelled *roaring.Bitmap) {
-       cancelled = c.requestState.Requests.Clone()
-       cancelled.Iterate(func(x uint32) bool {
+func (c *Peer) cancelAllRequests() {
+       c.requestState.Requests.IterateSnapshot(func(x RequestIndex) bool {
                c.cancel(x)
                return true
        })
@@ -1690,7 +1761,7 @@ func (c *PeerConn) setTorrent(t *Torrent) {
 }
 
 func (c *Peer) peerPriority() (peerPriority, error) {
-       return bep40Priority(c.remoteIpPort(), c.t.cl.publicAddr(c.remoteIp()))
+       return bep40Priority(c.remoteIpPort(), c.localPublicAddr)
 }
 
 func (c *Peer) remoteIp() net.IP {
@@ -1784,6 +1855,10 @@ func (cn *Peer) stats() *ConnStats {
        return &cn._stats
 }
 
+func (cn *Peer) Stats() *ConnStats {
+       return cn.stats()
+}
+
 func (p *Peer) TryAsPeerConn() (*PeerConn, bool) {
        pc, ok := p.peerImpl.(*PeerConn)
        return pc, ok