import (
"iter"
- g "github.com/anacrolix/generics"
+ "golang.org/x/exp/constraints"
)
-// Returns Some of the last item in a iter.Seq, or None if the sequence is empty.
-func seqLast[V any](seq iter.Seq[V]) (last g.Option[V]) {
- for item := range seq {
- last.Set(item)
+// Returns an iterator that yields integers from start (inclusive) to end (exclusive).
+func iterRange[T constraints.Integer](start, end T) iter.Seq[T] {
+ return func(yield func(T) bool) {
+ for i := start; i < end; i++ {
+ if !yield(i) {
+ return
+ }
+ }
}
- return
}
package smartban
import (
+ "iter"
"sync"
g "github.com/anacrolix/generics"
type Cache[Peer, BlockKey, Hash comparable] struct {
Hash func([]byte) Hash
+ // Wonder if we should make this an atomic.
lock sync.RWMutex
blocks map[BlockKey][]peerAndHash[Peer, Hash]
}
return
}
-func (me *Cache[Peer, BlockKey, Hash]) ForgetBlock(key BlockKey) {
+func (me *Cache[Peer, BlockKey, Hash]) ForgetBlockSeq(seq iter.Seq[BlockKey]) {
me.lock.Lock()
defer me.lock.Unlock()
- delete(me.blocks, key)
+ if len(me.blocks) == 0 {
+ return
+ }
+ for key := range seq {
+ delete(me.blocks, key)
+ }
+}
+
+// Returns whether any block in the sequence has at least once peer recorded.
+func (me *Cache[Peer, BlockKey, Hash]) HasPeerForBlocks(seq iter.Seq[BlockKey]) bool {
+ me.lock.RLock()
+ defer me.lock.RUnlock()
+ if len(me.blocks) == 0 {
+ return false
+ }
+ for key := range seq {
+ if len(me.blocks[key]) != 0 {
+ return true
+ }
+ }
+ return false
}
func (me *Cache[Peer, BlockKey, Hash]) HasBlocks() bool {
return pp.Integer(t.info.PieceLength)
}
-func (t *Torrent) smartBanBlockCheckingWriter(piece pieceIndex) *blockCheckingWriter {
- return &blockCheckingWriter{
+func (t *Torrent) getBlockCheckingWriterForPiece(piece pieceIndex) blockCheckingWriter {
+ return blockCheckingWriter{
cache: &t.smartBanCache,
requestIndex: t.pieceRequestIndexBegin(piece),
chunkSize: t.chunkSize.Int(),
}
}
+func (t *Torrent) hasSmartbanDataForPiece(piece pieceIndex) bool {
+ return t.smartBanCache.HasPeerForBlocks(iterRange(t.pieceRequestIndexBegin(piece), t.pieceRequestIndexBegin(piece+1)))
+}
+
func (t *Torrent) countBytesHashed(n int64) {
t.counters.BytesHashed.Add(n)
t.cl.counters.BytesHashed.Add(n)
differingPeers map[bannableAddr]struct{},
err error,
) {
+ var w io.Writer = h
+ if t.hasSmartbanDataForPiece(piece) {
+ smartBanWriter := t.getBlockCheckingWriterForPiece(piece)
+ w = io.MultiWriter(h, &smartBanWriter)
+ defer func() {
+ if err != nil {
+ // Skip smart banning since we can't blame them for storage issues. A short write would
+ // ban peers for all recorded blocks that weren't just written.
+ return
+ }
+ // Flush now, even though we may not have finished writing to the piece hash, since
+ // further data is padding only and should not have come from peers.
+ smartBanWriter.Flush()
+ differingPeers = smartBanWriter.badPeers
+ }()
+ }
p := t.piece(piece)
storagePiece := p.Storage()
-
- smartBanWriter := t.smartBanBlockCheckingWriter(piece)
- multiWriter := io.MultiWriter(h, smartBanWriter)
- {
- var written int64
- written, err = storagePiece.WriteTo(multiWriter)
- if err == nil && written != int64(p.length()) {
- err = fmt.Errorf("wrote %v bytes from storage, piece has length %v", written, p.length())
- // Skip smart banning since we can't blame them for storage issues. A short write would
- // ban peers for all recorded blocks that weren't just written.
- return
- }
- t.countBytesHashed(written)
+ var written int64
+ written, err = storagePiece.WriteTo(w)
+ if err == nil && written != int64(p.length()) {
+ err = fmt.Errorf("wrote %v bytes from storage, piece has length %v", written, p.length())
}
- // Flush before writing padding, since we would not have recorded the padding blocks.
- smartBanWriter.Flush()
- differingPeers = smartBanWriter.badPeers
+ t.countBytesHashed(written)
return
}
t.logger.WithDefaultLevel(log.Debug).Printf("smart banned %v for piece %v", peer, index)
}
t.dropBannedPeers()
- for ri := t.pieceRequestIndexBegin(index); ri < t.pieceRequestIndexBegin(index+1); ri++ {
- t.smartBanCache.ForgetBlock(ri)
- }
+ t.smartBanCache.ForgetBlockSeq(iterRange(t.pieceRequestIndexBegin(index), t.pieceRequestIndexBegin(index+1)))
}
p.hashing = false
t.pieceHashed(index, correct, copyErr)