]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Add low level support for BEP 10 user protocols
authorMatt Joiner <anacrolix@gmail.com>
Thu, 22 Feb 2024 03:36:47 +0000 (14:36 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Thu, 22 Feb 2024 03:36:47 +0000 (14:36 +1100)
callbacks.go
client.go
global.go
ltep.go [new file with mode: 0644]
ltep_test.go [new file with mode: 0644]
peer.go
peer_protocol/extended.go
peerconn.go
pexconn_test.go

index f9ba131b136e8859a801837c275f4c247ae60904..0c66bc50f78c7ca1e59f9721bc4a2edbd7b37dab 100644 (file)
@@ -11,10 +11,14 @@ import (
 type Callbacks struct {
        // Called after a peer connection completes the BitTorrent handshake. The Client lock is not
        // held.
-       CompletedHandshake    func(*PeerConn, InfoHash)
-       ReadMessage           func(*PeerConn, *pp.Message)
+       CompletedHandshake func(*PeerConn, InfoHash)
+       ReadMessage        func(*PeerConn, *pp.Message)
+       // This can be folded into the general case below.
        ReadExtendedHandshake func(*PeerConn, *pp.ExtendedHandshakeMessage)
        PeerConnClosed        func(*PeerConn)
+       // BEP 10 message. Not sure if I should call this Ltep universally. Each handler here is called
+       // in order.
+       PeerConnReadExtensionMessage []func(PeerConnReadExtensionMessageEvent)
 
        // Provides secret keys to be tried against incoming encrypted connections.
        ReceiveEncryptedHandshakeSkeys mse.SecretKeyIter
@@ -25,6 +29,11 @@ type Callbacks struct {
        SentRequest        []func(PeerRequestEvent)
        PeerClosed         []func(*Peer)
        NewPeer            []func(*Peer)
+       // Called when a PeerConn has been added to a Torrent. It's finished all BitTorrent protocol
+       // handshakes, and is about to start sending and receiving BitTorrent messages. The extended
+       // handshake has not yet occurred. This is a good time to alter the supported extension
+       // protocols.
+       PeerConnAdded []func(*PeerConn)
 }
 
 type ReceivedUsefulDataEvent = PeerMessageEvent
@@ -38,3 +47,10 @@ type PeerRequestEvent struct {
        Peer *Peer
        Request
 }
+
+type PeerConnReadExtensionMessageEvent struct {
+       PeerConn *PeerConn
+       // You can look up what protocol this corresponds to using the PeerConn.LocalLtepProtocolMap.
+       ExtensionNumber pp.ExtensionNumber
+       Payload         []byte
+}
index 7e3fa0f13425162595c5b3aeb5bd3b553123cde9..e2f1de31aa032061c14c1b6b23cc5ea287842c89 100644 (file)
--- a/client.go
+++ b/client.go
@@ -43,7 +43,6 @@ import (
        "github.com/anacrolix/torrent/metainfo"
        "github.com/anacrolix/torrent/mse"
        pp "github.com/anacrolix/torrent/peer_protocol"
-       utHolepunch "github.com/anacrolix/torrent/peer_protocol/ut-holepunch"
        request_strategy "github.com/anacrolix/torrent/request-strategy"
        "github.com/anacrolix/torrent/storage"
        "github.com/anacrolix/torrent/tracker"
@@ -90,6 +89,8 @@ type Client struct {
        httpClient            *http.Client
 
        clientHolepunchAddrSets
+
+       defaultLocalLtepProtocolMap LocalLtepProtocolMap
 }
 
 type ipStr string
@@ -214,6 +215,7 @@ func (cl *Client) init(cfg *ClientConfig) {
                        MaxConnsPerHost: 10,
                }
        }
+       cl.defaultLocalLtepProtocolMap = makeBuiltinLtepProtocols(!cfg.DisablePEX)
 }
 
 func NewClient(cfg *ClientConfig) (cl *Client, err error) {
@@ -221,9 +223,8 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
                cfg = NewDefaultClientConfig()
                cfg.ListenPort = 0
        }
-       var client Client
-       client.init(cfg)
-       cl = &client
+       cl = &Client{}
+       cl.init(cfg)
        go cl.acceptLimitClearer()
        cl.initLogger()
        defer func() {
@@ -1089,6 +1090,10 @@ func (t *Torrent) runHandshookConn(pc *PeerConn) error {
                return fmt.Errorf("adding connection: %w", err)
        }
        defer t.dropConnection(pc)
+       pc.addBuiltinLtepProtocols(!cl.config.DisablePEX)
+       for _, cb := range pc.callbacks.PeerConnAdded {
+               cb(pc)
+       }
        pc.startMessageWriter()
        pc.sendInitialMessages()
        pc.initUpdateRequestsTimer()
@@ -1146,10 +1151,6 @@ func (pc *PeerConn) sendInitialMessages() {
                        ExtendedID: pp.HandshakeExtendedID,
                        ExtendedPayload: func() []byte {
                                msg := pp.ExtendedHandshakeMessage{
-                                       M: map[pp.ExtensionName]pp.ExtensionNumber{
-                                               pp.ExtensionNameMetadata:  metadataExtendedId,
-                                               utHolepunch.ExtensionName: utHolepunchExtendedId,
-                                       },
                                        V:            cl.config.ExtendedHandshakeClientVersion,
                                        Reqq:         localClientReqq,
                                        YourIp:       pp.CompactIp(pc.remoteIp()),
@@ -1160,9 +1161,7 @@ func (pc *PeerConn) sendInitialMessages() {
                                        Ipv4: pp.CompactIp(cl.config.PublicIp4.To4()),
                                        Ipv6: cl.config.PublicIp6.To16(),
                                }
-                               if !cl.config.DisablePEX {
-                                       msg.M[pp.ExtensionNamePex] = pexExtendedId
-                               }
+                               msg.M = pc.LocalLtepProtocolMap.toSupportedExtensionDict()
                                return bencode.MustMarshal(msg)
                        }(),
                })
index 5a5bddba0a6b3acfebea35d8f21d908a38ce3be0..585bbeafaa576e7663e1604f47ab7657ae45a8e3 100644 (file)
--- a/global.go
+++ b/global.go
@@ -19,14 +19,6 @@ const (
        maxMetadataSize uint32 = 16 * 1024 * 1024
 )
 
-// These are our extended message IDs. Peers will use these values to
-// select which extension a message is intended for.
-const (
-       metadataExtendedId = iota + 1 // 0 is reserved for deleting keys
-       pexExtendedId
-       utHolepunchExtendedId
-)
-
 func defaultPeerExtensionBytes() PeerExtensionBits {
        return pp.NewPeerExtensionBytes(pp.ExtensionBitDht, pp.ExtensionBitLtep, pp.ExtensionBitFast)
 }
diff --git a/ltep.go b/ltep.go
new file mode 100644 (file)
index 0000000..747d808
--- /dev/null
+++ b/ltep.go
@@ -0,0 +1,69 @@
+package torrent
+
+import (
+       "fmt"
+       "slices"
+
+       g "github.com/anacrolix/generics"
+       pp "github.com/anacrolix/torrent/peer_protocol"
+)
+
+type LocalLtepProtocolMap struct {
+       // 1-based mapping from extension number to extension name (subtract one from the extension ID
+       // to find the corresponding protocol name). The first LocalLtepProtocolBuiltinCount of these
+       // are use builtin handlers. If you want to handle builtin protocols yourself, you would move
+       // them above the threshold. You can disable them by removing them entirely, and add your own.
+       // These changes should be done in the PeerConnAdded callback.
+       Index []pp.ExtensionName
+       // How many of the protocols are using the builtin handlers.
+       NumBuiltin int
+}
+
+func (me *LocalLtepProtocolMap) toSupportedExtensionDict() (m map[pp.ExtensionName]pp.ExtensionNumber) {
+       g.MakeMapWithCap(&m, len(me.Index))
+       for i, name := range me.Index {
+               old := g.MapInsert(m, name, pp.ExtensionNumber(i+1))
+               if old.Ok {
+                       panic(fmt.Sprintf("extension %q already defined with id %v", name, old.Value))
+               }
+       }
+       return
+}
+
+// Returns the local extension name for the given ID. If builtin is true, the implementation intends
+// to handle it itself. For incoming messages with extension ID 0, the message is a handshake, and
+// should be treated specially.
+func (me *LocalLtepProtocolMap) LookupId(id pp.ExtensionNumber) (name pp.ExtensionName, builtin bool, err error) {
+       if id == 0 {
+               err = fmt.Errorf("extension ID 0 is handshake")
+               builtin = true
+               return
+       }
+       protocolIndex := int(id - 1)
+       if protocolIndex >= len(me.Index) {
+               err = fmt.Errorf("unexpected extended message ID: %v", id)
+               return
+       }
+       builtin = protocolIndex < me.NumBuiltin
+       name = me.Index[protocolIndex]
+       return
+}
+
+func (me *LocalLtepProtocolMap) builtin() []pp.ExtensionName {
+       return me.Index[:me.NumBuiltin]
+}
+
+func (me *LocalLtepProtocolMap) user() []pp.ExtensionName {
+       return me.Index[me.NumBuiltin:]
+}
+
+func (me *LocalLtepProtocolMap) AddUserProtocol(name pp.ExtensionName) {
+       builtin := slices.DeleteFunc(me.builtin(), func(delName pp.ExtensionName) bool {
+               return delName == name
+       })
+       user := slices.DeleteFunc(me.user(), func(delName pp.ExtensionName) bool {
+               return delName == name
+       })
+       me.Index = append(append(builtin, user...), name)
+       me.NumBuiltin = len(builtin)
+}
diff --git a/ltep_test.go b/ltep_test.go
new file mode 100644 (file)
index 0000000..c3c874c
--- /dev/null
@@ -0,0 +1,130 @@
+package torrent_test
+
+import (
+       "strconv"
+       "testing"
+
+       pp "github.com/anacrolix/torrent/peer_protocol"
+
+       qt "github.com/frankban/quicktest"
+
+       "github.com/anacrolix/torrent/internal/testutil"
+
+       "github.com/anacrolix/sync"
+
+       . "github.com/anacrolix/torrent"
+)
+
+const (
+       testRepliesToOddsExtensionName  = "pm_me_odds"
+       testRepliesToEvensExtensionName = "pm_me_evens"
+)
+
+func countHandler(
+       c *qt.C,
+       wg *sync.WaitGroup,
+       // Name of the endpoint that this handler is for, for logging.
+       handlerName string,
+       // Whether we expect evens or odds
+       expectedMod2 uint,
+       // Extension name of messages we expect to handle.
+       answerToName pp.ExtensionName,
+       // Extension name of messages we expect to send.
+       replyToName pp.ExtensionName,
+       // Signal done when this value is seen.
+       doneValue uint,
+) func(event PeerConnReadExtensionMessageEvent) {
+       return func(event PeerConnReadExtensionMessageEvent) {
+               // Read handshake, don't look it up.
+               if event.ExtensionNumber == 0 {
+                       return
+               }
+               name, builtin, err := event.PeerConn.LocalLtepProtocolMap.LookupId(event.ExtensionNumber)
+               c.Assert(err, qt.IsNil)
+               // Not a user protocol.
+               if builtin {
+                       return
+               }
+               switch name {
+               case answerToName:
+                       u64, err := strconv.ParseUint(string(event.Payload), 10, 0)
+                       c.Assert(err, qt.IsNil)
+                       i := uint(u64)
+                       c.Logf("%v got %d", handlerName, i)
+                       if i == doneValue {
+                               wg.Done()
+                               return
+                       }
+                       c.Assert(i%2, qt.Equals, expectedMod2)
+                       go func() {
+                               c.Assert(
+                                       event.PeerConn.WriteExtendedMessage(
+                                               replyToName,
+                                               []byte(strconv.FormatUint(uint64(i+1), 10))),
+                                       qt.IsNil)
+                       }()
+               default:
+                       c.Fatalf("got unexpected extension name %q", name)
+               }
+       }
+}
+
+func TestUserLtep(t *testing.T) {
+       c := qt.New(t)
+       var wg sync.WaitGroup
+
+       makeCfg := func() *ClientConfig {
+               cfg := TestingConfig(t)
+               // Only want a single connection to between the clients.
+               cfg.DisableUTP = true
+               cfg.DisableIPv6 = true
+               return cfg
+       }
+
+       evensCfg := makeCfg()
+       evensCfg.Callbacks.ReadExtendedHandshake = func(pc *PeerConn, msg *pp.ExtendedHandshakeMessage) {
+               // The client lock is held while handling this event, so we have to do synchronous work in a
+               // separate goroutine.
+               go func() {
+                       // Check sending an extended message for a protocol the peer doesn't support is an error.
+                       c.Check(pc.WriteExtendedMessage("pm_me_floats", []byte("3.142")), qt.IsNotNil)
+                       // Kick things off by sending a 1.
+                       c.Check(pc.WriteExtendedMessage(testRepliesToOddsExtensionName, []byte("1")), qt.IsNil)
+               }()
+       }
+       evensCfg.Callbacks.PeerConnReadExtensionMessage = append(
+               evensCfg.Callbacks.PeerConnReadExtensionMessage,
+               countHandler(c, &wg, "evens", 0, testRepliesToEvensExtensionName, testRepliesToOddsExtensionName, 100))
+       evensCfg.Callbacks.PeerConnAdded = append(evensCfg.Callbacks.PeerConnAdded, func(conn *PeerConn) {
+               conn.LocalLtepProtocolMap.AddUserProtocol(testRepliesToEvensExtensionName)
+               c.Assert(conn.LocalLtepProtocolMap.Index[conn.LocalLtepProtocolMap.NumBuiltin:], qt.HasLen, 1)
+       })
+
+       oddsCfg := makeCfg()
+       oddsCfg.Callbacks.PeerConnAdded = append(oddsCfg.Callbacks.PeerConnAdded, func(conn *PeerConn) {
+               conn.LocalLtepProtocolMap.AddUserProtocol(testRepliesToOddsExtensionName)
+               c.Assert(conn.LocalLtepProtocolMap.Index[conn.LocalLtepProtocolMap.NumBuiltin:], qt.HasLen, 1)
+       })
+       oddsCfg.Callbacks.PeerConnReadExtensionMessage = append(
+               oddsCfg.Callbacks.PeerConnReadExtensionMessage,
+               countHandler(c, &wg, "odds", 1, testRepliesToOddsExtensionName, testRepliesToEvensExtensionName, 100))
+
+       cl1, err := NewClient(oddsCfg)
+       c.Assert(err, qt.IsNil)
+       defer cl1.Close()
+       cl2, err := NewClient(evensCfg)
+       c.Assert(err, qt.IsNil)
+       defer cl2.Close()
+       addOpts := AddTorrentOpts{}
+       t1, _ := cl1.AddTorrentOpt(addOpts)
+       t2, _ := cl2.AddTorrentOpt(addOpts)
+       defer testutil.ExportStatusWriter(cl1, "cl1", t)()
+       defer testutil.ExportStatusWriter(cl2, "cl2", t)()
+       // Expect one PeerConn to see the value.
+       wg.Add(1)
+       added := t1.AddClientPeer(cl2)
+       // Ensure some addresses for the other client were added.
+       c.Assert(added, qt.Not(qt.Equals), 0)
+       wg.Wait()
+       _ = t2
+}
diff --git a/peer.go b/peer.go
index d3ea15161ab8261d5c0df48a6eddd4a05f9a2df0..37d7ca57f02914ccf59059ecff8a0fe570a91b2d 100644 (file)
--- a/peer.go
+++ b/peer.go
@@ -334,6 +334,13 @@ func (p *Peer) close() {
        }
 }
 
+func (p *Peer) Close() error {
+       p.locker().Lock()
+       defer p.locker().Unlock()
+       p.close()
+       return nil
+}
+
 // Peer definitely has a piece, for purposes of requesting. So it's not sufficient that we think
 // they do (known=true).
 func (cn *Peer) peerHasPiece(piece pieceIndex) bool {
index 8bc518163394ec8bbf619b135edf055c6f715b87..019590e40af215cb7d9938d6b6da03c5087756ea 100644 (file)
@@ -24,7 +24,7 @@ type (
        }
 
        ExtensionName   string
-       ExtensionNumber int
+       ExtensionNumber uint8
 )
 
 const (
index 00fffc93ed95ce9c90cad2d0b28bf47bb7d12b3b..a1e98b1282ac040780c8df05aadaf3bb42226ccc 100644 (file)
@@ -45,6 +45,12 @@ type PeerConn struct {
        PeerExtensionBytes pp.PeerExtensionBits
        PeerListenPort     int
 
+       // The local extended protocols to advertise in the extended handshake, and to support receiving
+       // from the peer. This will point to the Client default when the PeerConnAdded callback is
+       // invoked. Do not modify this, point it to your own instance. Do not modify the destination
+       // after returning from the callback.
+       LocalLtepProtocolMap *LocalLtepProtocolMap
+
        // The actual Conn, used for closing, and setting socket options. Do not use methods on this
        // while holding any mutexes.
        conn net.Conn
@@ -55,6 +61,7 @@ type PeerConn struct {
 
        messageWriter peerConnMsgWriter
 
+       // The peer's extension map, as sent in their extended handshake.
        PeerExtensionIDs map[pp.ExtensionName]pp.ExtensionNumber
        PeerClientName   atomic.Value
        uploadTimer      *time.Timer
@@ -877,17 +884,27 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err
        }()
        t := c.t
        cl := t.cl
-       switch id {
-       case pp.HandshakeExtendedID:
+       {
+               event := PeerConnReadExtensionMessageEvent{
+                       PeerConn:        c,
+                       ExtensionNumber: id,
+                       Payload:         payload,
+               }
+               for _, cb := range c.callbacks.PeerConnReadExtensionMessage {
+                       cb(event)
+               }
+       }
+       if id == pp.HandshakeExtendedID {
                var d pp.ExtendedHandshakeMessage
                if err := bencode.Unmarshal(payload, &d); err != nil {
                        c.logger.Printf("error parsing extended handshake message %q: %s", payload, err)
                        return fmt.Errorf("unmarshalling extended handshake payload: %w", err)
                }
+               // Trigger this callback after it's been processed. If you want to handle it yourself, you
+               // should hook PeerConnReadExtensionMessage.
                if cb := c.callbacks.ReadExtendedHandshake; cb != nil {
                        cb(c, &d)
                }
-               // c.logger.WithDefaultLevel(log.Debug).Printf("received extended handshake message:\n%s", spew.Sdump(d))
                if d.Reqq != 0 {
                        c.PeerMaxRequests = d.Reqq
                }
@@ -919,13 +936,23 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err
                        c.pex.Init(c)
                }
                return nil
-       case metadataExtendedId:
+       }
+       extensionName, builtin, err := c.LocalLtepProtocolMap.LookupId(id)
+       if err != nil {
+               return
+       }
+       if !builtin {
+               // User should have taken care of this in PeerConnReadExtensionMessage callback.
+               return nil
+       }
+       switch extensionName {
+       case pp.ExtensionNameMetadata:
                err := cl.gotMetadataExtensionMsg(payload, t, c)
                if err != nil {
                        return fmt.Errorf("handling metadata extension message: %w", err)
                }
                return nil
-       case pexExtendedId:
+       case pp.ExtensionNamePex:
                if !c.pex.IsEnabled() {
                        return nil // or hang-up maybe?
                }
@@ -934,7 +961,7 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err
                        err = fmt.Errorf("receiving pex message: %w", err)
                }
                return
-       case utHolepunchExtendedId:
+       case utHolepunch.ExtensionName:
                var msg utHolepunch.Msg
                err = msg.UnmarshalBinary(payload)
                if err != nil {
@@ -944,7 +971,7 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err
                err = c.t.handleReceivedUtHolepunchMsg(msg, c)
                return
        default:
-               return fmt.Errorf("unexpected extended message ID: %v", id)
+               panic(fmt.Sprintf("unhandled builtin extension protocol %q", extensionName))
        }
 }
 
@@ -1152,3 +1179,33 @@ func (c *PeerConn) useful() bool {
        }
        return false
 }
+
+func makeBuiltinLtepProtocols(pex bool) LocalLtepProtocolMap {
+       ps := []pp.ExtensionName{pp.ExtensionNameMetadata, utHolepunch.ExtensionName}
+       if pex {
+               ps = append(ps, pp.ExtensionNamePex)
+       }
+       return LocalLtepProtocolMap{
+               Index:      ps,
+               NumBuiltin: len(ps),
+       }
+}
+
+func (c *PeerConn) addBuiltinLtepProtocols(pex bool) {
+       c.LocalLtepProtocolMap = &c.t.cl.defaultLocalLtepProtocolMap
+}
+
+func (pc *PeerConn) WriteExtendedMessage(extName pp.ExtensionName, payload []byte) error {
+       pc.locker().Lock()
+       defer pc.locker().Unlock()
+       id := pc.PeerExtensionIDs[extName]
+       if id == 0 {
+               return fmt.Errorf("peer does not support or has disabled extension %q", extName)
+       }
+       pc.write(pp.Message{
+               Type:            pp.Extended,
+               ExtendedID:      id,
+               ExtendedPayload: payload,
+       })
+       return nil
+}
index f8b9c9e07ae162a7e384855fd96669976f4fa446..b8be73e8877676e3f8be78602921dbfb748b3f42 100644 (file)
@@ -22,7 +22,7 @@ func TestPexConnState(t *testing.T) {
                network:    addr.Network(),
        })
        c.PeerExtensionIDs = make(map[pp.ExtensionName]pp.ExtensionNumber)
-       c.PeerExtensionIDs[pp.ExtensionNamePex] = pexExtendedId
+       c.PeerExtensionIDs[pp.ExtensionNamePex] = 1
        c.messageWriter.mu.Lock()
        c.setTorrent(torrent)
        if err := torrent.addPeerConn(c); err != nil {
@@ -45,7 +45,8 @@ func TestPexConnState(t *testing.T) {
        c.pex.Share(testWriter)
        require.True(t, writerCalled)
        require.EqualValues(t, pp.Extended, out.Type)
-       require.EqualValues(t, pexExtendedId, out.ExtendedID)
+       require.NotEqualValues(t, out.ExtendedID, 0)
+       require.EqualValues(t, c.PeerExtensionIDs[pp.ExtensionNamePex], out.ExtendedID)
 
        x, err := pp.LoadPexMsg(out.ExtendedPayload)
        require.NoError(t, err)