Sergey Matveev's repositories - btrtrc.git/blob - ltep_test.go
Add low level support for BEP 10 user protocols
[btrtrc.git] / ltep_test.go
1 package torrent_test
2
3 import (
4         "strconv"
5         "testing"
6
7         pp "github.com/anacrolix/torrent/peer_protocol"
8
9         qt "github.com/frankban/quicktest"
10
11         "github.com/anacrolix/torrent/internal/testutil"
12
13         "github.com/anacrolix/sync"
14
15         . "github.com/anacrolix/torrent"
16 )
17
// Names of the user LTEP extension protocols exercised by TestUserLtep. Each
// client registers one of them on its connections and receives on it.
const (
	// Registered by the evens client; messages on it carry even numbers.
	testRepliesToEvensExtensionName = "pm_me_evens"

	// Registered by the odds client; messages on it carry odd numbers.
	testRepliesToOddsExtensionName = "pm_me_odds"
)
22
23 func countHandler(
24         c *qt.C,
25         wg *sync.WaitGroup,
26         // Name of the endpoint that this handler is for, for logging.
27         handlerName string,
28         // Whether we expect evens or odds
29         expectedMod2 uint,
30         // Extension name of messages we expect to handle.
31         answerToName pp.ExtensionName,
32         // Extension name of messages we expect to send.
33         replyToName pp.ExtensionName,
34         // Signal done when this value is seen.
35         doneValue uint,
36 ) func(event PeerConnReadExtensionMessageEvent) {
37         return func(event PeerConnReadExtensionMessageEvent) {
38                 // Read handshake, don't look it up.
39                 if event.ExtensionNumber == 0 {
40                         return
41                 }
42                 name, builtin, err := event.PeerConn.LocalLtepProtocolMap.LookupId(event.ExtensionNumber)
43                 c.Assert(err, qt.IsNil)
44                 // Not a user protocol.
45                 if builtin {
46                         return
47                 }
48                 switch name {
49                 case answerToName:
50                         u64, err := strconv.ParseUint(string(event.Payload), 10, 0)
51                         c.Assert(err, qt.IsNil)
52                         i := uint(u64)
53                         c.Logf("%v got %d", handlerName, i)
54                         if i == doneValue {
55                                 wg.Done()
56                                 return
57                         }
58                         c.Assert(i%2, qt.Equals, expectedMod2)
59                         go func() {
60                                 c.Assert(
61                                         event.PeerConn.WriteExtendedMessage(
62                                                 replyToName,
63                                                 []byte(strconv.FormatUint(uint64(i+1), 10))),
64                                         qt.IsNil)
65                         }()
66                 default:
67                         c.Fatalf("got unexpected extension name %q", name)
68                 }
69         }
70 }
71
72 func TestUserLtep(t *testing.T) {
73         c := qt.New(t)
74         var wg sync.WaitGroup
75
76         makeCfg := func() *ClientConfig {
77                 cfg := TestingConfig(t)
78                 // Only want a single connection to between the clients.
79                 cfg.DisableUTP = true
80                 cfg.DisableIPv6 = true
81                 return cfg
82         }
83
84         evensCfg := makeCfg()
85         evensCfg.Callbacks.ReadExtendedHandshake = func(pc *PeerConn, msg *pp.ExtendedHandshakeMessage) {
86                 // The client lock is held while handling this event, so we have to do synchronous work in a
87                 // separate goroutine.
88                 go func() {
89                         // Check sending an extended message for a protocol the peer doesn't support is an error.
90                         c.Check(pc.WriteExtendedMessage("pm_me_floats", []byte("3.142")), qt.IsNotNil)
91                         // Kick things off by sending a 1.
92                         c.Check(pc.WriteExtendedMessage(testRepliesToOddsExtensionName, []byte("1")), qt.IsNil)
93                 }()
94         }
95         evensCfg.Callbacks.PeerConnReadExtensionMessage = append(
96                 evensCfg.Callbacks.PeerConnReadExtensionMessage,
97                 countHandler(c, &wg, "evens", 0, testRepliesToEvensExtensionName, testRepliesToOddsExtensionName, 100))
98         evensCfg.Callbacks.PeerConnAdded = append(evensCfg.Callbacks.PeerConnAdded, func(conn *PeerConn) {
99                 conn.LocalLtepProtocolMap.AddUserProtocol(testRepliesToEvensExtensionName)
100                 c.Assert(conn.LocalLtepProtocolMap.Index[conn.LocalLtepProtocolMap.NumBuiltin:], qt.HasLen, 1)
101         })
102
103         oddsCfg := makeCfg()
104         oddsCfg.Callbacks.PeerConnAdded = append(oddsCfg.Callbacks.PeerConnAdded, func(conn *PeerConn) {
105                 conn.LocalLtepProtocolMap.AddUserProtocol(testRepliesToOddsExtensionName)
106                 c.Assert(conn.LocalLtepProtocolMap.Index[conn.LocalLtepProtocolMap.NumBuiltin:], qt.HasLen, 1)
107         })
108         oddsCfg.Callbacks.PeerConnReadExtensionMessage = append(
109                 oddsCfg.Callbacks.PeerConnReadExtensionMessage,
110                 countHandler(c, &wg, "odds", 1, testRepliesToOddsExtensionName, testRepliesToEvensExtensionName, 100))
111
112         cl1, err := NewClient(oddsCfg)
113         c.Assert(err, qt.IsNil)
114         defer cl1.Close()
115         cl2, err := NewClient(evensCfg)
116         c.Assert(err, qt.IsNil)
117         defer cl2.Close()
118         addOpts := AddTorrentOpts{}
119         t1, _ := cl1.AddTorrentOpt(addOpts)
120         t2, _ := cl2.AddTorrentOpt(addOpts)
121         defer testutil.ExportStatusWriter(cl1, "cl1", t)()
122         defer testutil.ExportStatusWriter(cl2, "cl2", t)()
123         // Expect one PeerConn to see the value.
124         wg.Add(1)
125         added := t1.AddClientPeer(cl2)
126         // Ensure some addresses for the other client were added.
127         c.Assert(added, qt.Not(qt.Equals), 0)
128         wg.Wait()
129         _ = t2
130 }