]> Sergey Matveev's repositories - btrtrc.git/blob - ltep_test.go
cmd/btrtrc client
[btrtrc.git] / ltep_test.go
1 package torrent_test
2
3 import (
4         "math/rand"
5         "strconv"
6         "testing"
7
8         "github.com/anacrolix/sync"
9         qt "github.com/frankban/quicktest"
10
11         . "github.com/anacrolix/torrent"
12         "github.com/anacrolix/torrent/internal/testutil"
13         pp "github.com/anacrolix/torrent/peer_protocol"
14 )
15
// testRepliesToOddsExtensionName is the user LTEP extension registered by the
// "odds" client; messages sent under this name carry odd counter values.
const testRepliesToOddsExtensionName = "pm_me_odds"

// testRepliesToEvensExtensionName is the user LTEP extension registered by the
// "evens" client; messages sent under this name carry even counter values.
const testRepliesToEvensExtensionName = "pm_me_evens"
20
21 func countHandler(
22         c *qt.C,
23         wg *sync.WaitGroup,
24         // Name of the endpoint that this handler is for, for logging.
25         handlerName string,
26         // Whether we expect evens or odds
27         expectedMod2 uint,
28         // Extension name of messages we expect to handle.
29         answerToName pp.ExtensionName,
30         // Extension name of messages we expect to send.
31         replyToName pp.ExtensionName,
32         // Signal done when this value is seen.
33         doneValue uint,
34 ) func(event PeerConnReadExtensionMessageEvent) {
35         return func(event PeerConnReadExtensionMessageEvent) {
36                 // Read handshake, don't look it up.
37                 if event.ExtensionNumber == 0 {
38                         return
39                 }
40                 name, builtin, err := event.PeerConn.LocalLtepProtocolMap.LookupId(event.ExtensionNumber)
41                 c.Assert(err, qt.IsNil)
42                 // Not a user protocol.
43                 if builtin {
44                         return
45                 }
46                 switch name {
47                 case answerToName:
48                         u64, err := strconv.ParseUint(string(event.Payload), 10, 0)
49                         c.Assert(err, qt.IsNil)
50                         i := uint(u64)
51                         c.Logf("%v got %d", handlerName, i)
52                         if i == doneValue {
53                                 wg.Done()
54                                 return
55                         }
56                         c.Assert(i%2, qt.Equals, expectedMod2)
57                         go func() {
58                                 c.Assert(
59                                         event.PeerConn.WriteExtendedMessage(
60                                                 replyToName,
61                                                 []byte(strconv.FormatUint(uint64(i+1), 10))),
62                                         qt.IsNil)
63                         }()
64                 default:
65                         c.Fatalf("got unexpected extension name %q", name)
66                 }
67         }
68 }
69
// TestUserLtep checks that user-defined LTEP extension protocols can be registered per
// connection and used to exchange messages between two clients.
//
// The "evens" client kicks off the exchange by sending 1 under the odds extension once it reads
// the peer's extended handshake. Each side then replies with value+1 under the other side's
// extension until the evens handler sees 100 and releases the WaitGroup.
//
// The client callbacks below run on client goroutines rather than the test goroutine, so they
// report failures with c.Check (safe from any goroutine) instead of c.Assert, whose FailNow must
// only be called from the goroutine running the test.
func TestUserLtep(t *testing.T) {
	c := qt.New(t)
	var wg sync.WaitGroup
	// Expect one PeerConn to see the value. Add before any client exists so that a Done from a
	// handler can never race ahead of the Add.
	wg.Add(1)

	makeCfg := func() *ClientConfig {
		cfg := TestingConfig(t)
		// Only want a single connection between the clients.
		cfg.DisableUTP = true
		cfg.DisableIPv6 = true
		return cfg
	}

	evensCfg := makeCfg()
	evensCfg.Callbacks.ReadExtendedHandshake = func(pc *PeerConn, msg *pp.ExtendedHandshakeMessage) {
		// The client lock is held while handling this event, so we have to do synchronous work in a
		// separate goroutine.
		go func() {
			// Check sending an extended message for a protocol the peer doesn't support is an error.
			c.Check(pc.WriteExtendedMessage("pm_me_floats", []byte("3.142")), qt.IsNotNil)
			// Kick things off by sending a 1.
			c.Check(pc.WriteExtendedMessage(testRepliesToOddsExtensionName, []byte("1")), qt.IsNil)
		}()
	}
	evensCfg.Callbacks.PeerConnReadExtensionMessage = append(
		evensCfg.Callbacks.PeerConnReadExtensionMessage,
		countHandler(c, &wg, "evens", 0, testRepliesToEvensExtensionName, testRepliesToOddsExtensionName, 100))
	evensCfg.Callbacks.PeerConnAdded = append(evensCfg.Callbacks.PeerConnAdded, func(conn *PeerConn) {
		conn.LocalLtepProtocolMap.AddUserProtocol(testRepliesToEvensExtensionName)
		// Exactly one user protocol should follow the builtin ones.
		c.Check(conn.LocalLtepProtocolMap.Index[conn.LocalLtepProtocolMap.NumBuiltin:], qt.HasLen, 1)
	})

	oddsCfg := makeCfg()
	oddsCfg.Callbacks.PeerConnAdded = append(oddsCfg.Callbacks.PeerConnAdded, func(conn *PeerConn) {
		conn.LocalLtepProtocolMap.AddUserProtocol(testRepliesToOddsExtensionName)
		// Exactly one user protocol should follow the builtin ones.
		c.Check(conn.LocalLtepProtocolMap.Index[conn.LocalLtepProtocolMap.NumBuiltin:], qt.HasLen, 1)
	})
	oddsCfg.Callbacks.PeerConnReadExtensionMessage = append(
		oddsCfg.Callbacks.PeerConnReadExtensionMessage,
		countHandler(c, &wg, "odds", 1, testRepliesToOddsExtensionName, testRepliesToEvensExtensionName, 100))

	cl1, err := NewClient(oddsCfg)
	c.Assert(err, qt.IsNil)
	defer cl1.Close()
	cl2, err := NewClient(evensCfg)
	c.Assert(err, qt.IsNil)
	defer cl2.Close()
	addOpts := AddTorrentOpts{}
	rand.Read(addOpts.InfoHash[:])
	t1, _ := cl1.AddTorrentOpt(addOpts)
	// Both clients get the same torrent; only cl1's handle is used below.
	cl2.AddTorrentOpt(addOpts)
	defer testutil.ExportStatusWriter(cl1, "cl1", t)()
	defer testutil.ExportStatusWriter(cl2, "cl2", t)()
	added := t1.AddClientPeer(cl2)
	// Ensure some addresses for the other client were added.
	c.Assert(added, qt.Not(qt.Equals), 0)
	wg.Wait()
}