From: Matt Joiner <anacrolix@gmail.com>
Date: Wed, 12 Jan 2022 03:23:30 +0000 (+1100)
Subject: Do smart banning on existing badPeerIPs
X-Git-Tag: v1.42.0~8^2^2~15
X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=4a8611b23ef65d0309f96ed5c552d62e3e35a76d;p=btrtrc.git

Do smart banning on existing badPeerIPs
---

diff --git a/client.go b/client.go
index 2a208fc9..773159e3 100644
--- a/client.go
+++ b/client.go
@@ -13,6 +13,7 @@ import (
 	"math"
 	"net"
 	"net/http"
+	"net/netip"
 	"sort"
 	"strconv"
 	"strings"
@@ -28,6 +29,7 @@ import (
 	"github.com/anacrolix/missinggo/v2/bitmap"
 	"github.com/anacrolix/missinggo/v2/pproffd"
 	"github.com/anacrolix/sync"
+	"github.com/anacrolix/torrent/generics"
 	"github.com/anacrolix/torrent/option"
 	request_strategy "github.com/anacrolix/torrent/request-strategy"
 	"github.com/davecgh/go-spew/spew"
@@ -75,8 +77,7 @@ type Client struct {
 	// include ourselves if we end up trying to connect to our own address
 	// through legitimate channels.
 	dopplegangerAddrs map[string]struct{}
-	badPeerIPs        map[string]struct{}
-	bannedPrefixes    map[string]struct{}
+	badPeerIPs        map[netip.Addr]struct{}
 	torrents          map[InfoHash]*Torrent
 	pieceRequestOrder map[interface{}]*request_strategy.PieceRequestOrder
 
@@ -103,7 +104,7 @@ func (cl *Client) badPeerIPsLocked() (ips []string) {
 	ips = make([]string, len(cl.badPeerIPs))
 	i := 0
 	for k := range cl.badPeerIPs {
-		ips[i] = k
+		ips[i] = k.String()
 		i += 1
 	}
 	return
@@ -201,7 +202,7 @@ func (cl *Client) announceKey() int32 {
 // Initializes a bare minimum Client. *Client and *ClientConfig must not be nil.
 func (cl *Client) init(cfg *ClientConfig) {
 	cl.config = cfg
-	cl.dopplegangerAddrs = make(map[string]struct{})
+	generics.MakeMap(&cl.dopplegangerAddrs)
 	cl.torrents = make(map[metainfo.Hash]*Torrent)
 	cl.dialRateLimiter = rate.NewLimiter(10, 10)
 	cl.activeAnnounceLimiter.SlotsPerKey = 2
@@ -213,7 +214,6 @@ func (cl *Client) init(cfg *ClientConfig) {
 			MaxConnsPerHost: 10,
 		},
 	}
-	cl.bannedPrefixes = make(map[banPrefix]struct{})
 }
 
 func NewClient(cfg *ClientConfig) (cl *Client, err error) {
@@ -1130,12 +1130,6 @@ func (cl *Client) badPeerAddr(addr PeerRemoteAddr) bool {
 	if ipa, ok := tryIpPortFromNetAddr(addr); ok {
 		return cl.badPeerIPPort(ipa.IP, ipa.Port)
 	}
-	addrStr := addr.String()
-	for prefix := range cl.bannedPrefixes {
-		if strings.HasPrefix(addrStr, prefix) {
-			return true
-		}
-	}
 	return false
 }
 
@@ -1149,7 +1143,11 @@ func (cl *Client) badPeerIPPort(ip net.IP, port int) bool {
 	if _, ok := cl.ipBlockRange(ip); ok {
 		return true
 	}
-	if _, ok := cl.badPeerIPs[ip.String()]; ok {
+	ipAddr, ok := netip.AddrFromSlice(ip)
+	if !ok {
+		panic(ip)
+	}
+	if _, ok := cl.badPeerIPs[ipAddr]; ok {
 		return true
 	}
 	return false
@@ -1495,11 +1493,7 @@ func (cl *Client) AddDhtNodes(nodes []string) {
 }
 
 func (cl *Client) banPeerIP(ip net.IP) {
-	cl.logger.Printf("banning ip %v", ip)
-	if cl.badPeerIPs == nil {
-		cl.badPeerIPs = make(map[string]struct{})
-	}
-	cl.badPeerIPs[ip.String()] = struct{}{}
+	generics.MakeMapIfNilAndSet(&cl.badPeerIPs, netip.MustParseAddr(ip.String()), struct{}{})
 }
 
 func (cl *Client) newConnection(nc net.Conn, outgoing bool, remoteAddr PeerRemoteAddr, network, connString string) (c *PeerConn) {
@@ -1520,8 +1514,12 @@ func (cl *Client) newConnection(nc net.Conn, outgoing bool, remoteAddr PeerRemot
 		connString: connString,
 		conn:       nc,
 	}
+	// TODO: Need to be much more explicit about this, including allowing non-IP bannable addresses.
 	if remoteAddr != nil {
-		c.banPrefix = option.Some(remoteAddr.String())
+		netipAddrPort, err := netip.ParseAddrPort(remoteAddr.String())
+		if err == nil {
+			c.bannableAddr = option.Some(netipAddrPort.Addr())
+		}
 	}
 	c.peerImpl = c
 	c.logger = cl.logger.WithDefaultLevel(log.Warning).WithContextValue(c)
@@ -1691,7 +1689,3 @@ func (cl *Client) String() string {
 func (cl *Client) ConnStats() ConnStats {
 	return cl.stats.Copy()
 }
-
-func (cl *Client) banPrefix(prefix banPrefix) {
-	cl.bannedPrefixes[prefix] = struct{}{}
-}
diff --git a/peerconn.go b/peerconn.go
index 566553f9..2dd2ced4 100644
--- a/peerconn.go
+++ b/peerconn.go
@@ -65,10 +65,10 @@ type Peer struct {
 	peerImpl
 	callbacks *Callbacks
 
-	outgoing   bool
-	Network    string
-	RemoteAddr PeerRemoteAddr
-	banPrefix  option.T[string]
+	outgoing     bool
+	Network      string
+	RemoteAddr   PeerRemoteAddr
+	bannableAddr option.T[bannableAddr]
 	// True if the connection is operating over MSE obfuscation.
 	headerEncrypted bool
 	cryptoMethod    mse.CryptoMethod
@@ -1390,8 +1390,8 @@ func (c *Peer) receiveChunk(msg *pp.Message) error {
 	req := c.t.requestIndexFromRequest(ppReq)
 	t := c.t
 
-	if c.banPrefix.Ok() {
-		t.smartBanCache.RecordBlock(c.banPrefix.Value(), req, msg.Piece)
+	if c.bannableAddr.Ok() {
+		t.smartBanCache.RecordBlock(c.bannableAddr.Value(), req, msg.Piece)
 	}
 
 	if c.peerChoking {
diff --git a/smartban.go b/smartban.go
index 9f43104f..74f645e2 100644
--- a/smartban.go
+++ b/smartban.go
@@ -3,19 +3,21 @@ package torrent
 import (
 	"bytes"
 	"crypto/sha1"
+	"net/netip"
 
+	"github.com/anacrolix/torrent/generics"
 	"github.com/anacrolix/torrent/smartban"
 )
 
-type banPrefix = string
+type bannableAddr = netip.Addr
 
-type smartBanCache = smartban.Cache[banPrefix, RequestIndex, [sha1.Size]byte]
+type smartBanCache = smartban.Cache[bannableAddr, RequestIndex, [sha1.Size]byte]
 
 type blockCheckingWriter struct {
 	cache        *smartBanCache
 	requestIndex RequestIndex
 	// Peers that didn't match blocks written now.
-	badPeers    map[banPrefix]struct{}
+	badPeers    map[bannableAddr]struct{}
 	blockBuffer bytes.Buffer
 	chunkSize   int
 }
@@ -23,7 +25,7 @@ type blockCheckingWriter struct {
 func (me *blockCheckingWriter) checkBlock() {
 	b := me.blockBuffer.Next(me.chunkSize)
 	for _, peer := range me.cache.CheckBlock(me.requestIndex, b) {
-		me.badPeers[peer] = struct{}{}
+		generics.MakeMapIfNilAndSet(&me.badPeers, peer, struct{}{})
 	}
 	me.requestIndex++
 }
diff --git a/torrent.go b/torrent.go
index 02e7c399..a70bbd7d 100644
--- a/torrent.go
+++ b/torrent.go
@@ -8,6 +8,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"net/netip"
 	"net/url"
 	"sort"
 	"strings"
@@ -27,6 +28,7 @@ import (
 	"github.com/anacrolix/missinggo/v2/bitmap"
 	"github.com/anacrolix/multiless"
 	"github.com/anacrolix/sync"
+	"github.com/anacrolix/torrent/option"
 	request_strategy "github.com/anacrolix/torrent/request-strategy"
 	"github.com/davecgh/go-spew/spew"
 	"github.com/pion/datachannel"
@@ -952,7 +954,7 @@ func (t *Torrent) smartBanBlockCheckingWriter(piece pieceIndex) *blockCheckingWr
 func (t *Torrent) hashPiece(piece pieceIndex) (
 	ret metainfo.Hash,
 	// These are peers that sent us blocks that differ from what we hash here.
-	differingPeers map[banPrefix]struct{},
+	differingPeers map[bannableAddr]struct{},
 	err error,
 ) {
 	p := t.piece(piece)
@@ -2035,8 +2037,11 @@ func (t *Torrent) pieceHashed(piece pieceIndex, passed bool, hashIoErr error) {
 
 			if len(bannableTouchers) >= 1 {
 				c := bannableTouchers[0]
-				t.cl.banPeerIP(c.remoteIp())
-				c.drop()
+				log.Printf("would have banned %v for touching piece %v after failed piece check", c.remoteIp(), piece)
+				if false {
+					t.cl.banPeerIP(c.remoteIp())
+					c.drop()
+				}
 			}
 		}
 		t.onIncompletePiece(piece)
@@ -2124,15 +2129,38 @@ func (t *Torrent) getPieceToHash() (ret pieceIndex, ok bool) {
 	return
 }
 
+func (t *Torrent) dropBannedPeers() {
+	t.iterPeers(func(p *Peer) {
+		remoteIp := p.remoteIp()
+		if remoteIp == nil {
+			if p.bannableAddr.Ok() {
+				log.Printf("can't get remote ip for peer %v", p)
+			}
+			return
+		}
+		netipAddr := netip.MustParseAddr(remoteIp.String())
+		if option.Some(netipAddr) != p.bannableAddr {
+			log.Printf(
+				"peer remote ip does not match its bannable addr [peer=%v, remote ip=%v, bannable addr=%v]",
+				p, remoteIp, p.bannableAddr)
+		}
+		if _, ok := t.cl.badPeerIPs[netipAddr]; ok {
+			p.drop()
+			log.Printf("dropped %v for banned remote IP %v", p, netipAddr)
+		}
+	})
+}
+
 func (t *Torrent) pieceHasher(index pieceIndex) {
 	p := t.piece(index)
 	sum, failedPeers, copyErr := t.hashPiece(index)
 	correct := sum == *p.hash
 	if correct {
 		for peer := range failedPeers {
-			log.Printf("would smart ban %q for %v here", peer, p)
-			t.cl.banPrefix(peer)
+			t.cl.banPeerIP(peer.AsSlice())
+			log.Printf("smart banned %v for piece %v", peer, index)
 		}
+		t.dropBannedPeers()
 	}
 	switch copyErr {
 	case nil, io.EOF: