type Callbacks struct {
// Called after a peer connection completes the BitTorrent handshake. The Client lock is not
// held.
- CompletedHandshake func(*PeerConn, InfoHash)
- ReadMessage func(*PeerConn, *pp.Message)
+ CompletedHandshake func(*PeerConn, InfoHash)
+ ReadMessage func(*PeerConn, *pp.Message)
+ // This can be folded into the general case below.
ReadExtendedHandshake func(*PeerConn, *pp.ExtendedHandshakeMessage)
PeerConnClosed func(*PeerConn)
+ // BEP 10 message. Not sure if I should call this Ltep universally. Each handler here is called
+ // in order.
+ PeerConnReadExtensionMessage []func(PeerConnReadExtensionMessageEvent)
// Provides secret keys to be tried against incoming encrypted connections.
ReceiveEncryptedHandshakeSkeys mse.SecretKeyIter
SentRequest []func(PeerRequestEvent)
PeerClosed []func(*Peer)
NewPeer []func(*Peer)
+ // Called when a PeerConn has been added to a Torrent. It's finished all BitTorrent protocol
+ // handshakes, and is about to start sending and receiving BitTorrent messages. The extended
+ // handshake has not yet occurred. This is a good time to alter the supported extension
+ // protocols.
+ PeerConnAdded []func(*PeerConn)
}
type ReceivedUsefulDataEvent = PeerMessageEvent
Peer *Peer
Request
}
+
+type PeerConnReadExtensionMessageEvent struct {
+ PeerConn *PeerConn
+ // You can look up what protocol this corresponds to using the PeerConn.LocalLtepProtocolMap.
+ ExtensionNumber pp.ExtensionNumber
+ Payload []byte
+}
"github.com/anacrolix/torrent/metainfo"
"github.com/anacrolix/torrent/mse"
pp "github.com/anacrolix/torrent/peer_protocol"
- utHolepunch "github.com/anacrolix/torrent/peer_protocol/ut-holepunch"
request_strategy "github.com/anacrolix/torrent/request-strategy"
"github.com/anacrolix/torrent/storage"
"github.com/anacrolix/torrent/tracker"
httpClient *http.Client
clientHolepunchAddrSets
+
+ defaultLocalLtepProtocolMap LocalLtepProtocolMap
}
type ipStr string
MaxConnsPerHost: 10,
}
}
+ cl.defaultLocalLtepProtocolMap = makeBuiltinLtepProtocols(!cfg.DisablePEX)
}
func NewClient(cfg *ClientConfig) (cl *Client, err error) {
cfg = NewDefaultClientConfig()
cfg.ListenPort = 0
}
- var client Client
- client.init(cfg)
- cl = &client
+ cl = &Client{}
+ cl.init(cfg)
go cl.acceptLimitClearer()
cl.initLogger()
defer func() {
return fmt.Errorf("adding connection: %w", err)
}
defer t.dropConnection(pc)
+ pc.addBuiltinLtepProtocols(!cl.config.DisablePEX)
+ for _, cb := range pc.callbacks.PeerConnAdded {
+ cb(pc)
+ }
pc.startMessageWriter()
pc.sendInitialMessages()
pc.initUpdateRequestsTimer()
ExtendedID: pp.HandshakeExtendedID,
ExtendedPayload: func() []byte {
msg := pp.ExtendedHandshakeMessage{
- M: map[pp.ExtensionName]pp.ExtensionNumber{
- pp.ExtensionNameMetadata: metadataExtendedId,
- utHolepunch.ExtensionName: utHolepunchExtendedId,
- },
V: cl.config.ExtendedHandshakeClientVersion,
Reqq: localClientReqq,
YourIp: pp.CompactIp(pc.remoteIp()),
Ipv4: pp.CompactIp(cl.config.PublicIp4.To4()),
Ipv6: cl.config.PublicIp6.To16(),
}
- if !cl.config.DisablePEX {
- msg.M[pp.ExtensionNamePex] = pexExtendedId
- }
+ msg.M = pc.LocalLtepProtocolMap.toSupportedExtensionDict()
return bencode.MustMarshal(msg)
}(),
})
maxMetadataSize uint32 = 16 * 1024 * 1024
)
-// These are our extended message IDs. Peers will use these values to
-// select which extension a message is intended for.
-const (
- metadataExtendedId = iota + 1 // 0 is reserved for deleting keys
- pexExtendedId
- utHolepunchExtendedId
-)
-
func defaultPeerExtensionBytes() PeerExtensionBits {
return pp.NewPeerExtensionBytes(pp.ExtensionBitDht, pp.ExtensionBitLtep, pp.ExtensionBitFast)
}
--- /dev/null
+package torrent
+
+import (
+ "fmt"
+ "slices"
+
+ g "github.com/anacrolix/generics"
+ pp "github.com/anacrolix/torrent/peer_protocol"
+)
+
+type LocalLtepProtocolMap struct {
+ // 1-based mapping from extension number to extension name (subtract one from the extension ID
+ // to find the corresponding protocol name). The first LocalLtepProtocolBuiltinCount of these
+ // are use builtin handlers. If you want to handle builtin protocols yourself, you would move
+ // them above the threshold. You can disable them by removing them entirely, and add your own.
+ // These changes should be done in the PeerConnAdded callback.
+ Index []pp.ExtensionName
+ // How many of the protocols are using the builtin handlers.
+ NumBuiltin int
+}
+
+func (me *LocalLtepProtocolMap) toSupportedExtensionDict() (m map[pp.ExtensionName]pp.ExtensionNumber) {
+ g.MakeMapWithCap(&m, len(me.Index))
+ for i, name := range me.Index {
+ old := g.MapInsert(m, name, pp.ExtensionNumber(i+1))
+ if old.Ok {
+ panic(fmt.Sprintf("extension %q already defined with id %v", name, old.Value))
+ }
+ }
+ return
+}
+
+// Returns the local extension name for the given ID. If builtin is true, the implementation intends
+// to handle it itself. For incoming messages with extension ID 0, the message is a handshake, and
+// should be treated specially.
+func (me *LocalLtepProtocolMap) LookupId(id pp.ExtensionNumber) (name pp.ExtensionName, builtin bool, err error) {
+ if id == 0 {
+ err = fmt.Errorf("extension ID 0 is handshake")
+ builtin = true
+ return
+ }
+ protocolIndex := int(id - 1)
+ if protocolIndex >= len(me.Index) {
+ err = fmt.Errorf("unexpected extended message ID: %v", id)
+ return
+ }
+ builtin = protocolIndex < me.NumBuiltin
+ name = me.Index[protocolIndex]
+ return
+}
+
+func (me *LocalLtepProtocolMap) builtin() []pp.ExtensionName {
+ return me.Index[:me.NumBuiltin]
+}
+
+func (me *LocalLtepProtocolMap) user() []pp.ExtensionName {
+ return me.Index[me.NumBuiltin:]
+}
+
+func (me *LocalLtepProtocolMap) AddUserProtocol(name pp.ExtensionName) {
+ builtin := slices.DeleteFunc(me.builtin(), func(delName pp.ExtensionName) bool {
+ return delName == name
+ })
+ user := slices.DeleteFunc(me.user(), func(delName pp.ExtensionName) bool {
+ return delName == name
+ })
+ me.Index = append(append(builtin, user...), name)
+ me.NumBuiltin = len(builtin)
+}
--- /dev/null
+package torrent_test
+
+import (
+ "strconv"
+ "testing"
+
+ pp "github.com/anacrolix/torrent/peer_protocol"
+
+ qt "github.com/frankban/quicktest"
+
+ "github.com/anacrolix/torrent/internal/testutil"
+
+ "github.com/anacrolix/sync"
+
+ . "github.com/anacrolix/torrent"
+)
+
+const (
+ testRepliesToOddsExtensionName = "pm_me_odds"
+ testRepliesToEvensExtensionName = "pm_me_evens"
+)
+
+func countHandler(
+ c *qt.C,
+ wg *sync.WaitGroup,
+ // Name of the endpoint that this handler is for, for logging.
+ handlerName string,
+ // Whether we expect evens or odds
+ expectedMod2 uint,
+ // Extension name of messages we expect to handle.
+ answerToName pp.ExtensionName,
+ // Extension name of messages we expect to send.
+ replyToName pp.ExtensionName,
+ // Signal done when this value is seen.
+ doneValue uint,
+) func(event PeerConnReadExtensionMessageEvent) {
+ return func(event PeerConnReadExtensionMessageEvent) {
+ // Read handshake, don't look it up.
+ if event.ExtensionNumber == 0 {
+ return
+ }
+ name, builtin, err := event.PeerConn.LocalLtepProtocolMap.LookupId(event.ExtensionNumber)
+ c.Assert(err, qt.IsNil)
+ // Not a user protocol.
+ if builtin {
+ return
+ }
+ switch name {
+ case answerToName:
+ u64, err := strconv.ParseUint(string(event.Payload), 10, 0)
+ c.Assert(err, qt.IsNil)
+ i := uint(u64)
+ c.Logf("%v got %d", handlerName, i)
+ if i == doneValue {
+ wg.Done()
+ return
+ }
+ c.Assert(i%2, qt.Equals, expectedMod2)
+ go func() {
+ c.Assert(
+ event.PeerConn.WriteExtendedMessage(
+ replyToName,
+ []byte(strconv.FormatUint(uint64(i+1), 10))),
+ qt.IsNil)
+ }()
+ default:
+ c.Fatalf("got unexpected extension name %q", name)
+ }
+ }
+}
+
+func TestUserLtep(t *testing.T) {
+ c := qt.New(t)
+ var wg sync.WaitGroup
+
+ makeCfg := func() *ClientConfig {
+ cfg := TestingConfig(t)
+ // Only want a single connection to between the clients.
+ cfg.DisableUTP = true
+ cfg.DisableIPv6 = true
+ return cfg
+ }
+
+ evensCfg := makeCfg()
+ evensCfg.Callbacks.ReadExtendedHandshake = func(pc *PeerConn, msg *pp.ExtendedHandshakeMessage) {
+ // The client lock is held while handling this event, so we have to do synchronous work in a
+ // separate goroutine.
+ go func() {
+ // Check sending an extended message for a protocol the peer doesn't support is an error.
+ c.Check(pc.WriteExtendedMessage("pm_me_floats", []byte("3.142")), qt.IsNotNil)
+ // Kick things off by sending a 1.
+ c.Check(pc.WriteExtendedMessage(testRepliesToOddsExtensionName, []byte("1")), qt.IsNil)
+ }()
+ }
+ evensCfg.Callbacks.PeerConnReadExtensionMessage = append(
+ evensCfg.Callbacks.PeerConnReadExtensionMessage,
+ countHandler(c, &wg, "evens", 0, testRepliesToEvensExtensionName, testRepliesToOddsExtensionName, 100))
+ evensCfg.Callbacks.PeerConnAdded = append(evensCfg.Callbacks.PeerConnAdded, func(conn *PeerConn) {
+ conn.LocalLtepProtocolMap.AddUserProtocol(testRepliesToEvensExtensionName)
+ c.Assert(conn.LocalLtepProtocolMap.Index[conn.LocalLtepProtocolMap.NumBuiltin:], qt.HasLen, 1)
+ })
+
+ oddsCfg := makeCfg()
+ oddsCfg.Callbacks.PeerConnAdded = append(oddsCfg.Callbacks.PeerConnAdded, func(conn *PeerConn) {
+ conn.LocalLtepProtocolMap.AddUserProtocol(testRepliesToOddsExtensionName)
+ c.Assert(conn.LocalLtepProtocolMap.Index[conn.LocalLtepProtocolMap.NumBuiltin:], qt.HasLen, 1)
+ })
+ oddsCfg.Callbacks.PeerConnReadExtensionMessage = append(
+ oddsCfg.Callbacks.PeerConnReadExtensionMessage,
+ countHandler(c, &wg, "odds", 1, testRepliesToOddsExtensionName, testRepliesToEvensExtensionName, 100))
+
+ cl1, err := NewClient(oddsCfg)
+ c.Assert(err, qt.IsNil)
+ defer cl1.Close()
+ cl2, err := NewClient(evensCfg)
+ c.Assert(err, qt.IsNil)
+ defer cl2.Close()
+ addOpts := AddTorrentOpts{}
+ t1, _ := cl1.AddTorrentOpt(addOpts)
+ t2, _ := cl2.AddTorrentOpt(addOpts)
+ defer testutil.ExportStatusWriter(cl1, "cl1", t)()
+ defer testutil.ExportStatusWriter(cl2, "cl2", t)()
+ // Expect one PeerConn to see the value.
+ wg.Add(1)
+ added := t1.AddClientPeer(cl2)
+ // Ensure some addresses for the other client were added.
+ c.Assert(added, qt.Not(qt.Equals), 0)
+ wg.Wait()
+ _ = t2
+}
}
}
+func (p *Peer) Close() error {
+ p.locker().Lock()
+ defer p.locker().Unlock()
+ p.close()
+ return nil
+}
+
// Peer definitely has a piece, for purposes of requesting. So it's not sufficient that we think
// they do (known=true).
func (cn *Peer) peerHasPiece(piece pieceIndex) bool {
}
ExtensionName string
- ExtensionNumber int
+ ExtensionNumber uint8
)
const (
PeerExtensionBytes pp.PeerExtensionBits
PeerListenPort int
+ // The local extended protocols to advertise in the extended handshake, and to support receiving
+ // from the peer. This will point to the Client default when the PeerConnAdded callback is
+ // invoked. Do not modify this, point it to your own instance. Do not modify the destination
+ // after returning from the callback.
+ LocalLtepProtocolMap *LocalLtepProtocolMap
+
// The actual Conn, used for closing, and setting socket options. Do not use methods on this
// while holding any mutexes.
conn net.Conn
messageWriter peerConnMsgWriter
+ // The peer's extension map, as sent in their extended handshake.
PeerExtensionIDs map[pp.ExtensionName]pp.ExtensionNumber
PeerClientName atomic.Value
uploadTimer *time.Timer
}()
t := c.t
cl := t.cl
- switch id {
- case pp.HandshakeExtendedID:
+ {
+ event := PeerConnReadExtensionMessageEvent{
+ PeerConn: c,
+ ExtensionNumber: id,
+ Payload: payload,
+ }
+ for _, cb := range c.callbacks.PeerConnReadExtensionMessage {
+ cb(event)
+ }
+ }
+ if id == pp.HandshakeExtendedID {
var d pp.ExtendedHandshakeMessage
if err := bencode.Unmarshal(payload, &d); err != nil {
c.logger.Printf("error parsing extended handshake message %q: %s", payload, err)
return fmt.Errorf("unmarshalling extended handshake payload: %w", err)
}
+ // Trigger this callback after it's been processed. If you want to handle it yourself, you
+ // should hook PeerConnReadExtensionMessage.
if cb := c.callbacks.ReadExtendedHandshake; cb != nil {
cb(c, &d)
}
- // c.logger.WithDefaultLevel(log.Debug).Printf("received extended handshake message:\n%s", spew.Sdump(d))
if d.Reqq != 0 {
c.PeerMaxRequests = d.Reqq
}
c.pex.Init(c)
}
return nil
- case metadataExtendedId:
+ }
+ extensionName, builtin, err := c.LocalLtepProtocolMap.LookupId(id)
+ if err != nil {
+ return
+ }
+ if !builtin {
+ // User should have taken care of this in PeerConnReadExtensionMessage callback.
+ return nil
+ }
+ switch extensionName {
+ case pp.ExtensionNameMetadata:
err := cl.gotMetadataExtensionMsg(payload, t, c)
if err != nil {
return fmt.Errorf("handling metadata extension message: %w", err)
}
return nil
- case pexExtendedId:
+ case pp.ExtensionNamePex:
if !c.pex.IsEnabled() {
return nil // or hang-up maybe?
}
err = fmt.Errorf("receiving pex message: %w", err)
}
return
- case utHolepunchExtendedId:
+ case utHolepunch.ExtensionName:
var msg utHolepunch.Msg
err = msg.UnmarshalBinary(payload)
if err != nil {
err = c.t.handleReceivedUtHolepunchMsg(msg, c)
return
default:
- return fmt.Errorf("unexpected extended message ID: %v", id)
+ panic(fmt.Sprintf("unhandled builtin extension protocol %q", extensionName))
}
}
}
return false
}
+
+func makeBuiltinLtepProtocols(pex bool) LocalLtepProtocolMap {
+ ps := []pp.ExtensionName{pp.ExtensionNameMetadata, utHolepunch.ExtensionName}
+ if pex {
+ ps = append(ps, pp.ExtensionNamePex)
+ }
+ return LocalLtepProtocolMap{
+ Index: ps,
+ NumBuiltin: len(ps),
+ }
+}
+
+func (c *PeerConn) addBuiltinLtepProtocols(pex bool) {
+ c.LocalLtepProtocolMap = &c.t.cl.defaultLocalLtepProtocolMap
+}
+
+func (pc *PeerConn) WriteExtendedMessage(extName pp.ExtensionName, payload []byte) error {
+ pc.locker().Lock()
+ defer pc.locker().Unlock()
+ id := pc.PeerExtensionIDs[extName]
+ if id == 0 {
+ return fmt.Errorf("peer does not support or has disabled extension %q", extName)
+ }
+ pc.write(pp.Message{
+ Type: pp.Extended,
+ ExtendedID: id,
+ ExtendedPayload: payload,
+ })
+ return nil
+}
network: addr.Network(),
})
c.PeerExtensionIDs = make(map[pp.ExtensionName]pp.ExtensionNumber)
- c.PeerExtensionIDs[pp.ExtensionNamePex] = pexExtendedId
+ c.PeerExtensionIDs[pp.ExtensionNamePex] = 1
c.messageWriter.mu.Lock()
c.setTorrent(torrent)
if err := torrent.addPeerConn(c); err != nil {
c.pex.Share(testWriter)
require.True(t, writerCalled)
require.EqualValues(t, pp.Extended, out.Type)
- require.EqualValues(t, pexExtendedId, out.ExtendedID)
+ require.NotEqualValues(t, out.ExtendedID, 0)
+ require.EqualValues(t, c.PeerExtensionIDs[pp.ExtensionNamePex], out.ExtendedID)
x, err := pp.LoadPexMsg(out.ExtendedPayload)
require.NoError(t, err)