From: Matt Joiner <anacrolix@gmail.com>
Date: Sun, 23 Apr 2023 01:44:56 +0000 (+1000)
Subject: WIP support for ut_holepunch
X-Git-Tag: v1.51.0~38
X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=e86e624415db323121e8b4c0824622cdc03bc874;p=btrtrc.git

WIP support for ut_holepunch
---

diff --git a/client.go b/client.go
index 1ba37871..75960923 100644
--- a/client.go
+++ b/client.go
@@ -19,6 +19,8 @@ import (
 	"strings"
 	"time"
 
+	utHolepunch "github.com/anacrolix/torrent/peer_protocol/ut-holepunch"
+
 	"github.com/anacrolix/chansync"
 	"github.com/anacrolix/chansync/events"
 	"github.com/anacrolix/dht/v2"
@@ -1045,7 +1047,8 @@ func (cl *Client) sendInitialMessages(conn *PeerConn, torrent *Torrent) {
 			ExtendedPayload: func() []byte {
 				msg := pp.ExtendedHandshakeMessage{
 					M: map[pp.ExtensionName]pp.ExtensionNumber{
-						pp.ExtensionNameMetadata: metadataExtendedId,
+						pp.ExtensionNameMetadata:  metadataExtendedId,
+						utHolepunch.ExtensionName: utHolepunchExtendedId,
 					},
 					V:            cl.config.ExtendedHandshakeClientVersion,
 					Reqq:         localClientReqq,
diff --git a/global.go b/global.go
index 988d434a..dfb6bd4c 100644
--- a/global.go
+++ b/global.go
@@ -24,6 +24,7 @@ const (
 const (
 	metadataExtendedId = iota + 1 // 0 is reserved for deleting keys
 	pexExtendedId
+	utHolepunchExtendedId
 )
 
 func defaultPeerExtensionBytes() PeerExtensionBits {
diff --git a/peer.go b/peer.go
index e88485b2..4e2a24f6 100644
--- a/peer.go
+++ b/peer.go
@@ -21,7 +21,7 @@ import (
 	"github.com/anacrolix/torrent/mse"
 	pp "github.com/anacrolix/torrent/peer_protocol"
 	request_strategy "github.com/anacrolix/torrent/request-strategy"
-	"github.com/anacrolix/torrent/typed-roaring"
+	typedRoaring "github.com/anacrolix/torrent/typed-roaring"
 )
 
 type (
@@ -117,6 +117,7 @@ type (
 )
 
 const (
+	PeerSourceUtHolepunch     = "C"
 	PeerSourceTracker         = "Tr"
 	PeerSourceIncoming        = "I"
 	PeerSourceDhtGetPeers     = "Hg" // Peers we found by searching a DHT.
diff --git a/peer_protocol/ut-holepunch/ut-holepunch.go b/peer_protocol/ut-holepunch/ut-holepunch.go
new file mode 100644
index 00000000..f3ff0c19
--- /dev/null
+++ b/peer_protocol/ut-holepunch/ut-holepunch.go
@@ -0,0 +1,92 @@
+package utHolepunch
+
+import (
+	"bytes"
+	"encoding/binary"
+	"fmt"
+	"net/netip"
+)
+
+const ExtensionName = "ut_holepunch"
+
+type (
+	Msg struct {
+		MsgType  MsgType
+		AddrPort netip.AddrPort
+		ErrCode  ErrCode
+	}
+	MsgType  byte
+	AddrType byte
+	ErrCode  uint32
+)
+
+const (
+	Rendezvous MsgType = iota
+	Connect
+	Error
+)
+
+const (
+	Ipv4 AddrType = iota
+	Ipv6 AddrType = iota
+)
+
+const (
+	NoSuchPeer ErrCode = iota + 1
+	NotConnected
+	NoSupport
+	NoSelf
+)
+
+func (m *Msg) UnmarshalBinary(b []byte) error {
+	if len(b) < 12 {
+		return fmt.Errorf("buffer too small to be valid")
+	}
+	m.MsgType = MsgType(b[0])
+	b = b[1:]
+	addrType := AddrType(b[0])
+	b = b[1:]
+	var addr netip.Addr
+	switch addrType {
+	case Ipv4:
+		addr = netip.AddrFrom4([4]byte(b[:4]))
+		b = b[4:]
+	case Ipv6:
+		if len(b) < 22 {
+			return fmt.Errorf("not enough bytes")
+		}
+		addr = netip.AddrFrom16([16]byte(b[:16]))
+		b = b[16:]
+	default:
+		return fmt.Errorf("unhandled addr type value %v", addrType)
+	}
+	port := binary.BigEndian.Uint16(b[:])
+	b = b[2:]
+	m.AddrPort = netip.AddrPortFrom(addr, port)
+	m.ErrCode = ErrCode(binary.BigEndian.Uint32(b[:]))
+	b = b[4:]
+	if len(b) != 0 {
+		return fmt.Errorf("%v trailing unused bytes", len(b))
+	}
+	return nil
+}
+
+func (m *Msg) MarshalBinary() (_ []byte, err error) {
+	var buf bytes.Buffer
+	buf.Grow(24)
+	buf.WriteByte(byte(m.MsgType))
+	addr := m.AddrPort.Addr()
+	switch {
+	case addr.Is4():
+		buf.WriteByte(byte(Ipv4))
+	case addr.Is6():
+		buf.WriteByte(byte(Ipv6))
+	default:
+		err = fmt.Errorf("unhandled addr type: %v", addr)
+		return
+	}
+	buf.Write(addr.AsSlice())
+	binary.Write(&buf, binary.BigEndian, m.AddrPort.Port())
+	binary.Write(&buf, binary.BigEndian, m.ErrCode)
+	return buf.Bytes(), nil
+}
diff --git a/peerconn.go b/peerconn.go
index 1163ad77..c45d87ff 100644
--- a/peerconn.go
+++ b/peerconn.go
@@ -9,10 +9,13 @@ import (
 	"io"
 	"math/rand"
 	"net"
+	"net/netip"
 	"strconv"
 	"strings"
 	"time"
 
+	utHolepunch "github.com/anacrolix/torrent/peer_protocol/ut-holepunch"
+
 	"github.com/RoaringBitmap/roaring"
 	. "github.com/anacrolix/generics"
 	"github.com/anacrolix/log"
@@ -59,6 +62,8 @@ type PeerConn struct {
 	peerSentHaveAll bool
 
 	peerRequestDataAllocLimiter alloclim.Limiter
+
+	outstandingHolepunchingRendezvous map[netip.AddrPort]struct{}
 }
 
 func (cn *PeerConn) peerImplStatusLines() []string {
@@ -879,6 +884,13 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err
 			err = fmt.Errorf("receiving pex message: %w", err)
 		}
 		return
+	case utHolepunchExtendedId:
+		var msg utHolepunch.Msg
+		err = msg.UnmarshalBinary(payload)
+		if err != nil {
+			err = fmt.Errorf("unmarshalling ut_holepunch message: %w", err)
+			return
+		}
 	default:
 		return fmt.Errorf("unexpected extended message ID: %v", id)
 	}
@@ -1050,3 +1062,9 @@ func (cn *PeerConn) PeerPieces() *roaring.Bitmap {
 func (pc *PeerConn) remoteIsTransmission() bool {
 	return bytes.HasPrefix(pc.PeerID[:], []byte("-TR")) && pc.PeerID[7] == '-'
 }
+
+func (pc *PeerConn) remoteAddrPort() Option[netip.AddrPort] {
+	return Some(pc.conn.RemoteAddr().(interface {
+		AddrPort() netip.AddrPort
+	}).AddrPort())
+}
diff --git a/torrent.go b/torrent.go
index b0820e5b..285f5cc8 100644
--- a/torrent.go
+++ b/torrent.go
@@ -17,6 +17,8 @@ import (
 	"time"
 	"unsafe"
 
+	utHolepunch "github.com/anacrolix/torrent/peer_protocol/ut-holepunch"
+
 	"github.com/RoaringBitmap/roaring"
 	"github.com/anacrolix/chansync"
 	"github.com/anacrolix/chansync/events"
@@ -2357,7 +2359,8 @@ func (t *Torrent) VerifyData() {
 	}
 }
 
-// Start the process of connecting to the given peer for the given torrent if appropriate.
+// Start the process of connecting to the given peer for the given torrent if appropriate. I'm not
+// sure all the PeerInfo fields are being used.
 func (t *Torrent) initiateConn(peer PeerInfo) {
 	if peer.Id == t.cl.peerID {
 		return
@@ -2664,3 +2667,72 @@ func (t *Torrent) checkValidReceiveChunk(r Request) error {
 	// catch most of the overflow manipulation stuff by checking index and begin above.
 	return nil
 }
+
+func (t *Torrent) peerConnsWithRemoteAddrPort(addrPort netip.AddrPort) (ret []*PeerConn) {
+	for pc := range t.conns {
+		addr := pc.remoteAddrPort()
+		if !(addr.Ok && addr.Value == addrPort) {
+			continue
+		}
+		ret = append(ret, pc)
+	}
+	return
+}
+
+func makeUtHolepunchMsgForPeerConn(
+	recipient *PeerConn,
+	msgType utHolepunch.MsgType,
+	addrPort netip.AddrPort,
+	errCode utHolepunch.ErrCode,
+) pp.Message {
+	utHolepunchMsg := utHolepunch.Msg{
+		MsgType:  msgType,
+		AddrPort: addrPort,
+		ErrCode:  errCode,
+	}
+	extendedPayload, err := utHolepunchMsg.MarshalBinary()
+	if err != nil {
+		panic(err)
+	}
+	return pp.Message{
+		Type:            pp.Extended,
+		ExtendedID:      MapMustGet(recipient.PeerExtensionIDs, utHolepunch.ExtensionName),
+		ExtendedPayload: extendedPayload,
+	}
+}
+
+func (t *Torrent) handleReceivedUtHolepunchMsg(msg utHolepunch.Msg, sender *PeerConn) error {
+	switch msg.MsgType {
+	case utHolepunch.Rendezvous:
+		sendMsg := func(
+			pc *PeerConn,
+			msgType utHolepunch.MsgType,
+			addrPort netip.AddrPort,
+			errCode utHolepunch.ErrCode,
+		) {
+			pc.write(makeUtHolepunchMsgForPeerConn(pc, msgType, addrPort, errCode))
+		}
+		targets := t.peerConnsWithRemoteAddrPort(msg.AddrPort)
+		if len(targets) == 0 {
+			sendMsg(sender, utHolepunch.Error, msg.AddrPort, utHolepunch.NotConnected)
+			break
+		}
+		for _, pc := range targets {
+			if !pc.supportsExtension(utHolepunch.ExtensionName) {
+				sendMsg(sender, utHolepunch.Error, msg.AddrPort, utHolepunch.NoSupport)
+				continue
+			}
+			sendMsg(sender, utHolepunch.Connect, msg.AddrPort, 0)
+			sendMsg(pc, utHolepunch.Connect, sender.remoteAddrPort().Unwrap(), 0)
+		}
+	case utHolepunch.Connect:
+		t.initiateConn(PeerInfo{
+			Addr:   msg.AddrPort,
+			Source: PeerSourceUtHolepunch,
+		})
+	case utHolepunch.Error:
+
+	default:
+		return fmt.Errorf("unhandled msg type %v", msg.MsgType)
+	}
+}