From: Matt Joiner Date: Thu, 22 Feb 2024 03:36:47 +0000 (+1100) Subject: Add low level support for BEP 10 user protocols X-Git-Tag: v1.55.0~19 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=8605abc771421206a4cae6bd39f6ca049030d19a;p=btrtrc.git Add low level support for BEP 10 user protocols --- diff --git a/callbacks.go b/callbacks.go index f9ba131b..0c66bc50 100644 --- a/callbacks.go +++ b/callbacks.go @@ -11,10 +11,14 @@ import ( type Callbacks struct { // Called after a peer connection completes the BitTorrent handshake. The Client lock is not // held. - CompletedHandshake func(*PeerConn, InfoHash) - ReadMessage func(*PeerConn, *pp.Message) + CompletedHandshake func(*PeerConn, InfoHash) + ReadMessage func(*PeerConn, *pp.Message) + // This can be folded into the general case below. ReadExtendedHandshake func(*PeerConn, *pp.ExtendedHandshakeMessage) PeerConnClosed func(*PeerConn) + // BEP 10 message. Not sure if I should call this Ltep universally. Each handler here is called + // in order. + PeerConnReadExtensionMessage []func(PeerConnReadExtensionMessageEvent) // Provides secret keys to be tried against incoming encrypted connections. ReceiveEncryptedHandshakeSkeys mse.SecretKeyIter @@ -25,6 +29,11 @@ type Callbacks struct { SentRequest []func(PeerRequestEvent) PeerClosed []func(*Peer) NewPeer []func(*Peer) + // Called when a PeerConn has been added to a Torrent. It's finished all BitTorrent protocol + // handshakes, and is about to start sending and receiving BitTorrent messages. The extended + // handshake has not yet occurred. This is a good time to alter the supported extension + // protocols. + PeerConnAdded []func(*PeerConn) } type ReceivedUsefulDataEvent = PeerMessageEvent @@ -38,3 +47,10 @@ type PeerRequestEvent struct { Peer *Peer Request } + +type PeerConnReadExtensionMessageEvent struct { + PeerConn *PeerConn + // You can look up what protocol this corresponds to using the PeerConn.LocalLtepProtocolMap. + ExtensionNumber pp.ExtensionNumber + Payload []byte +} diff --git a/client.go b/client.go index 7e3fa0f1..e2f1de31 100644 --- a/client.go +++ b/client.go @@ -43,7 +43,6 @@ import ( "github.com/anacrolix/torrent/metainfo" "github.com/anacrolix/torrent/mse" pp "github.com/anacrolix/torrent/peer_protocol" - utHolepunch "github.com/anacrolix/torrent/peer_protocol/ut-holepunch" request_strategy "github.com/anacrolix/torrent/request-strategy" "github.com/anacrolix/torrent/storage" "github.com/anacrolix/torrent/tracker" @@ -90,6 +89,8 @@ type Client struct { httpClient *http.Client clientHolepunchAddrSets + + defaultLocalLtepProtocolMap LocalLtepProtocolMap } type ipStr string @@ -214,6 +215,7 @@ func (cl *Client) init(cfg *ClientConfig) { MaxConnsPerHost: 10, } } + cl.defaultLocalLtepProtocolMap = makeBuiltinLtepProtocols(!cfg.DisablePEX) } func NewClient(cfg *ClientConfig) (cl *Client, err error) { @@ -221,9 +223,8 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) { cfg = NewDefaultClientConfig() cfg.ListenPort = 0 } - var client Client - client.init(cfg) - cl = &client + cl = &Client{} + cl.init(cfg) go cl.acceptLimitClearer() cl.initLogger() defer func() { @@ -1089,6 +1090,10 @@ func (t *Torrent) runHandshookConn(pc *PeerConn) error { return fmt.Errorf("adding connection: %w", err) } defer t.dropConnection(pc) + pc.addBuiltinLtepProtocols(!cl.config.DisablePEX) + for _, cb := range pc.callbacks.PeerConnAdded { + cb(pc) + } pc.startMessageWriter() pc.sendInitialMessages() pc.initUpdateRequestsTimer() @@ -1146,10 +1151,6 @@ func (pc *PeerConn) sendInitialMessages() { ExtendedID: pp.HandshakeExtendedID, ExtendedPayload: func() []byte { msg := pp.ExtendedHandshakeMessage{ - M: map[pp.ExtensionName]pp.ExtensionNumber{ - pp.ExtensionNameMetadata: metadataExtendedId, - utHolepunch.ExtensionName: utHolepunchExtendedId, - }, V: cl.config.ExtendedHandshakeClientVersion, Reqq: localClientReqq, YourIp: pp.CompactIp(pc.remoteIp()), @@ -1160,9 +1161,7 @@ func (pc *PeerConn) sendInitialMessages() { Ipv4: pp.CompactIp(cl.config.PublicIp4.To4()), Ipv6: cl.config.PublicIp6.To16(), } - if !cl.config.DisablePEX { - msg.M[pp.ExtensionNamePex] = pexExtendedId - } + msg.M = pc.LocalLtepProtocolMap.toSupportedExtensionDict() return bencode.MustMarshal(msg) }(), }) diff --git a/global.go b/global.go index 5a5bddba..585bbeaf 100644 --- a/global.go +++ b/global.go @@ -19,14 +19,6 @@ const ( maxMetadataSize uint32 = 16 * 1024 * 1024 ) -// These are our extended message IDs. Peers will use these values to -// select which extension a message is intended for. -const ( - metadataExtendedId = iota + 1 // 0 is reserved for deleting keys - pexExtendedId - utHolepunchExtendedId -) - func defaultPeerExtensionBytes() PeerExtensionBits { return pp.NewPeerExtensionBytes(pp.ExtensionBitDht, pp.ExtensionBitLtep, pp.ExtensionBitFast) } diff --git a/ltep.go b/ltep.go new file mode 100644 index 00000000..747d8086 --- /dev/null +++ b/ltep.go @@ -0,0 +1,69 @@ +package torrent + +import ( + "fmt" + "slices" + + g "github.com/anacrolix/generics" + pp "github.com/anacrolix/torrent/peer_protocol" +) + +type LocalLtepProtocolMap struct { + // 1-based mapping from extension number to extension name (subtract one from the extension ID + // to find the corresponding protocol name). The first LocalLtepProtocolBuiltinCount of these + // are use builtin handlers. If you want to handle builtin protocols yourself, you would move + // them above the threshold. You can disable them by removing them entirely, and add your own. + // These changes should be done in the PeerConnAdded callback. + Index []pp.ExtensionName + // How many of the protocols are using the builtin handlers. + NumBuiltin int +} + +func (me *LocalLtepProtocolMap) toSupportedExtensionDict() (m map[pp.ExtensionName]pp.ExtensionNumber) { + g.MakeMapWithCap(&m, len(me.Index)) + for i, name := range me.Index { + old := g.MapInsert(m, name, pp.ExtensionNumber(i+1)) + if old.Ok { + panic(fmt.Sprintf("extension %q already defined with id %v", name, old.Value)) + } + } + return +} + +// Returns the local extension name for the given ID. If builtin is true, the implementation intends +// to handle it itself. For incoming messages with extension ID 0, the message is a handshake, and +// should be treated specially. +func (me *LocalLtepProtocolMap) LookupId(id pp.ExtensionNumber) (name pp.ExtensionName, builtin bool, err error) { + if id == 0 { + err = fmt.Errorf("extension ID 0 is handshake") + builtin = true + return + } + protocolIndex := int(id - 1) + if protocolIndex >= len(me.Index) { + err = fmt.Errorf("unexpected extended message ID: %v", id) + return + } + builtin = protocolIndex < me.NumBuiltin + name = me.Index[protocolIndex] + return +} + +func (me *LocalLtepProtocolMap) builtin() []pp.ExtensionName { + return me.Index[:me.NumBuiltin] +} + +func (me *LocalLtepProtocolMap) user() []pp.ExtensionName { + return me.Index[me.NumBuiltin:] +} + +func (me *LocalLtepProtocolMap) AddUserProtocol(name pp.ExtensionName) { + builtin := slices.DeleteFunc(me.builtin(), func(delName pp.ExtensionName) bool { + return delName == name + }) + user := slices.DeleteFunc(me.user(), func(delName pp.ExtensionName) bool { + return delName == name + }) + me.Index = append(append(builtin, user...), name) + me.NumBuiltin = len(builtin) +} diff --git a/ltep_test.go b/ltep_test.go new file mode 100644 index 00000000..c3c874ce --- /dev/null +++ b/ltep_test.go @@ -0,0 +1,130 @@ +package torrent_test + +import ( + "strconv" + "testing" + + pp "github.com/anacrolix/torrent/peer_protocol" + + qt "github.com/frankban/quicktest" + + "github.com/anacrolix/torrent/internal/testutil" + + "github.com/anacrolix/sync" + + . "github.com/anacrolix/torrent" +) + +const ( + testRepliesToOddsExtensionName = "pm_me_odds" + testRepliesToEvensExtensionName = "pm_me_evens" +) + +func countHandler( + c *qt.C, + wg *sync.WaitGroup, + // Name of the endpoint that this handler is for, for logging. + handlerName string, + // Whether we expect evens or odds + expectedMod2 uint, + // Extension name of messages we expect to handle. + answerToName pp.ExtensionName, + // Extension name of messages we expect to send. + replyToName pp.ExtensionName, + // Signal done when this value is seen. + doneValue uint, +) func(event PeerConnReadExtensionMessageEvent) { + return func(event PeerConnReadExtensionMessageEvent) { + // Read handshake, don't look it up. + if event.ExtensionNumber == 0 { + return + } + name, builtin, err := event.PeerConn.LocalLtepProtocolMap.LookupId(event.ExtensionNumber) + c.Assert(err, qt.IsNil) + // Not a user protocol. + if builtin { + return + } + switch name { + case answerToName: + u64, err := strconv.ParseUint(string(event.Payload), 10, 0) + c.Assert(err, qt.IsNil) + i := uint(u64) + c.Logf("%v got %d", handlerName, i) + if i == doneValue { + wg.Done() + return + } + c.Assert(i%2, qt.Equals, expectedMod2) + go func() { + c.Assert( + event.PeerConn.WriteExtendedMessage( + replyToName, + []byte(strconv.FormatUint(uint64(i+1), 10))), + qt.IsNil) + }() + default: + c.Fatalf("got unexpected extension name %q", name) + } + } +} + +func TestUserLtep(t *testing.T) { + c := qt.New(t) + var wg sync.WaitGroup + + makeCfg := func() *ClientConfig { + cfg := TestingConfig(t) + // Only want a single connection to between the clients. + cfg.DisableUTP = true + cfg.DisableIPv6 = true + return cfg + } + + evensCfg := makeCfg() + evensCfg.Callbacks.ReadExtendedHandshake = func(pc *PeerConn, msg *pp.ExtendedHandshakeMessage) { + // The client lock is held while handling this event, so we have to do synchronous work in a + // separate goroutine. + go func() { + // Check sending an extended message for a protocol the peer doesn't support is an error. + c.Check(pc.WriteExtendedMessage("pm_me_floats", []byte("3.142")), qt.IsNotNil) + // Kick things off by sending a 1. + c.Check(pc.WriteExtendedMessage(testRepliesToOddsExtensionName, []byte("1")), qt.IsNil) + }() + } + evensCfg.Callbacks.PeerConnReadExtensionMessage = append( + evensCfg.Callbacks.PeerConnReadExtensionMessage, + countHandler(c, &wg, "evens", 0, testRepliesToEvensExtensionName, testRepliesToOddsExtensionName, 100)) + evensCfg.Callbacks.PeerConnAdded = append(evensCfg.Callbacks.PeerConnAdded, func(conn *PeerConn) { + conn.LocalLtepProtocolMap.AddUserProtocol(testRepliesToEvensExtensionName) + c.Assert(conn.LocalLtepProtocolMap.Index[conn.LocalLtepProtocolMap.NumBuiltin:], qt.HasLen, 1) + }) + + oddsCfg := makeCfg() + oddsCfg.Callbacks.PeerConnAdded = append(oddsCfg.Callbacks.PeerConnAdded, func(conn *PeerConn) { + conn.LocalLtepProtocolMap.AddUserProtocol(testRepliesToOddsExtensionName) + c.Assert(conn.LocalLtepProtocolMap.Index[conn.LocalLtepProtocolMap.NumBuiltin:], qt.HasLen, 1) + }) + oddsCfg.Callbacks.PeerConnReadExtensionMessage = append( + oddsCfg.Callbacks.PeerConnReadExtensionMessage, + countHandler(c, &wg, "odds", 1, testRepliesToOddsExtensionName, testRepliesToEvensExtensionName, 100)) + + cl1, err := NewClient(oddsCfg) + c.Assert(err, qt.IsNil) + defer cl1.Close() + cl2, err := NewClient(evensCfg) + c.Assert(err, qt.IsNil) + defer cl2.Close() + addOpts := AddTorrentOpts{} + t1, _ := cl1.AddTorrentOpt(addOpts) + t2, _ := cl2.AddTorrentOpt(addOpts) + defer testutil.ExportStatusWriter(cl1, "cl1", t)() + defer testutil.ExportStatusWriter(cl2, "cl2", t)() + // Expect one PeerConn to see the value. + wg.Add(1) + added := t1.AddClientPeer(cl2) + // Ensure some addresses for the other client were added. + c.Assert(added, qt.Not(qt.Equals), 0) + wg.Wait() + _ = t2 +} diff --git a/peer.go b/peer.go index d3ea1516..37d7ca57 100644 --- a/peer.go +++ b/peer.go @@ -334,6 +334,13 @@ func (p *Peer) close() { } } +func (p *Peer) Close() error { + p.locker().Lock() + defer p.locker().Unlock() + p.close() + return nil +} + // Peer definitely has a piece, for purposes of requesting. So it's not sufficient that we think // they do (known=true). func (cn *Peer) peerHasPiece(piece pieceIndex) bool { diff --git a/peer_protocol/extended.go b/peer_protocol/extended.go index 8bc51816..019590e4 100644 --- a/peer_protocol/extended.go +++ b/peer_protocol/extended.go @@ -24,7 +24,7 @@ type ( } ExtensionName string - ExtensionNumber int + ExtensionNumber uint8 ) const ( diff --git a/peerconn.go b/peerconn.go index 00fffc93..a1e98b12 100644 --- a/peerconn.go +++ b/peerconn.go @@ -45,6 +45,12 @@ type PeerConn struct { PeerExtensionBytes pp.PeerExtensionBits PeerListenPort int + // The local extended protocols to advertise in the extended handshake, and to support receiving + // from the peer. This will point to the Client default when the PeerConnAdded callback is + // invoked. Do not modify this, point it to your own instance. Do not modify the destination + // after returning from the callback. + LocalLtepProtocolMap *LocalLtepProtocolMap + // The actual Conn, used for closing, and setting socket options. Do not use methods on this // while holding any mutexes. conn net.Conn @@ -55,6 +61,7 @@ type PeerConn struct { messageWriter peerConnMsgWriter + // The peer's extension map, as sent in their extended handshake. PeerExtensionIDs map[pp.ExtensionName]pp.ExtensionNumber PeerClientName atomic.Value uploadTimer *time.Timer @@ -877,17 +884,27 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err }() t := c.t cl := t.cl - switch id { - case pp.HandshakeExtendedID: + { + event := PeerConnReadExtensionMessageEvent{ + PeerConn: c, + ExtensionNumber: id, + Payload: payload, + } + for _, cb := range c.callbacks.PeerConnReadExtensionMessage { + cb(event) + } + } + if id == pp.HandshakeExtendedID { var d pp.ExtendedHandshakeMessage if err := bencode.Unmarshal(payload, &d); err != nil { c.logger.Printf("error parsing extended handshake message %q: %s", payload, err) return fmt.Errorf("unmarshalling extended handshake payload: %w", err) } + // Trigger this callback after it's been processed. If you want to handle it yourself, you + // should hook PeerConnReadExtensionMessage. if cb := c.callbacks.ReadExtendedHandshake; cb != nil { cb(c, &d) } - // c.logger.WithDefaultLevel(log.Debug).Printf("received extended handshake message:\n%s", spew.Sdump(d)) if d.Reqq != 0 { c.PeerMaxRequests = d.Reqq } @@ -919,13 +936,23 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err c.pex.Init(c) } return nil - case metadataExtendedId: + } + extensionName, builtin, err := c.LocalLtepProtocolMap.LookupId(id) + if err != nil { + return + } + if !builtin { + // User should have taken care of this in PeerConnReadExtensionMessage callback. + return nil + } + switch extensionName { + case pp.ExtensionNameMetadata: err := cl.gotMetadataExtensionMsg(payload, t, c) if err != nil { return fmt.Errorf("handling metadata extension message: %w", err) } return nil - case pexExtendedId: + case pp.ExtensionNamePex: if !c.pex.IsEnabled() { return nil // or hang-up maybe? } @@ -934,7 +961,7 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err err = fmt.Errorf("receiving pex message: %w", err) } return - case utHolepunchExtendedId: + case utHolepunch.ExtensionName: var msg utHolepunch.Msg err = msg.UnmarshalBinary(payload) if err != nil { @@ -944,7 +971,7 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err err = c.t.handleReceivedUtHolepunchMsg(msg, c) return default: - return fmt.Errorf("unexpected extended message ID: %v", id) + panic(fmt.Sprintf("unhandled builtin extension protocol %q", extensionName)) } } @@ -1152,3 +1179,33 @@ func (c *PeerConn) useful() bool { } return false } + +func makeBuiltinLtepProtocols(pex bool) LocalLtepProtocolMap { + ps := []pp.ExtensionName{pp.ExtensionNameMetadata, utHolepunch.ExtensionName} + if pex { + ps = append(ps, pp.ExtensionNamePex) + } + return LocalLtepProtocolMap{ + Index: ps, + NumBuiltin: len(ps), + } +} + +func (c *PeerConn) addBuiltinLtepProtocols(pex bool) { + c.LocalLtepProtocolMap = &c.t.cl.defaultLocalLtepProtocolMap +} + +func (pc *PeerConn) WriteExtendedMessage(extName pp.ExtensionName, payload []byte) error { + pc.locker().Lock() + defer pc.locker().Unlock() + id := pc.PeerExtensionIDs[extName] + if id == 0 { + return fmt.Errorf("peer does not support or has disabled extension %q", extName) + } + pc.write(pp.Message{ + Type: pp.Extended, + ExtendedID: id, + ExtendedPayload: payload, + }) + return nil +} diff --git a/pexconn_test.go b/pexconn_test.go index f8b9c9e0..b8be73e8 100644 --- a/pexconn_test.go +++ b/pexconn_test.go @@ -22,7 +22,7 @@ func TestPexConnState(t *testing.T) { network: addr.Network(), }) c.PeerExtensionIDs = make(map[pp.ExtensionName]pp.ExtensionNumber) - c.PeerExtensionIDs[pp.ExtensionNamePex] = pexExtendedId + c.PeerExtensionIDs[pp.ExtensionNamePex] = 1 c.messageWriter.mu.Lock() c.setTorrent(torrent) if err := torrent.addPeerConn(c); err != nil { @@ -45,7 +45,8 @@ func TestPexConnState(t *testing.T) { c.pex.Share(testWriter) require.True(t, writerCalled) require.EqualValues(t, pp.Extended, out.Type) - require.EqualValues(t, pexExtendedId, out.ExtendedID) + require.NotEqualValues(t, out.ExtendedID, 0) + require.EqualValues(t, c.PeerExtensionIDs[pp.ExtensionNamePex], out.ExtendedID) x, err := pp.LoadPexMsg(out.ExtendedPayload) require.NoError(t, err)