]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse_test.go
Support plaintext crypto method for protocol header encryption
[btrtrc.git] / mse / mse_test.go
1 package mse
2
3 import (
4         "bytes"
5         "crypto/rand"
6         "crypto/rc4"
7         "io"
8         "io/ioutil"
9         "net"
10         "sync"
11         "testing"
12
13         _ "github.com/anacrolix/envpprof"
14
15         "github.com/bradfitz/iter"
16         "github.com/stretchr/testify/require"
17 )
18
19 func TestReadUntil(t *testing.T) {
20         test := func(data, until string, leftover int, expectedErr error) {
21                 r := bytes.NewReader([]byte(data))
22                 err := readUntil(r, []byte(until))
23                 if err != expectedErr {
24                         t.Fatal(err)
25                 }
26                 if r.Len() != leftover {
27                         t.Fatal(r.Len())
28                 }
29         }
30         test("feakjfeafeafegbaabc00", "abc", 2, nil)
31         test("feakjfeafeafegbaadc00", "abc", 0, io.EOF)
32 }
33
34 func TestSuffixMatchLen(t *testing.T) {
35         test := func(a, b string, expected int) {
36                 actual := suffixMatchLen([]byte(a), []byte(b))
37                 if actual != expected {
38                         t.Fatalf("expected %d, got %d for %q and %q", expected, actual, a, b)
39                 }
40         }
41         test("hello", "world", 0)
42         test("hello", "lo", 2)
43         test("hello", "llo", 3)
44         test("hello", "hell", 0)
45         test("hello", "helloooo!", 5)
46         test("hello", "lol!", 2)
47         test("hello", "mondo", 0)
48         test("mongo", "webscale", 0)
49         test("sup", "person", 1)
50 }
51
52 func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides uint32, cryptoSelect func(uint32) uint32) {
53         a, b := net.Pipe()
54         wg := sync.WaitGroup{}
55         wg.Add(2)
56         go func() {
57                 defer wg.Done()
58                 a, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
59                 if err != nil {
60                         t.Fatal(err)
61                         return
62                 }
63                 go a.Write([]byte(aData))
64
65                 var msg [20]byte
66                 n, _ := a.Read(msg[:])
67                 if n != len(bData) {
68                         t.FailNow()
69                 }
70                 // t.Log(string(msg[:n]))
71         }()
72         go func() {
73                 defer wg.Done()
74                 b, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
75                 if err != nil {
76                         t.Fatal(err)
77                         return
78                 }
79                 go b.Write([]byte(bData))
80                 // Need to be exact here, as there are several reads, and net.Pipe is
81                 // most synchronous.
82                 msg := make([]byte, len(ia)+len(aData))
83                 n, _ := io.ReadFull(b, msg[:])
84                 if n != len(msg) {
85                         t.FailNow()
86                 }
87                 // t.Log(string(msg[:n]))
88         }()
89         wg.Wait()
90         a.Close()
91         b.Close()
92 }
93
94 func allHandshakeTests(t testing.TB, provides uint32, selector CryptoSelector) {
95         handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg", provides, selector)
96         handshakeTest(t, nil, "hello world", "yo dawg", provides, selector)
97         handshakeTest(t, []byte{}, "hello world", "yo dawg", provides, selector)
98 }
99
100 func TestHandshakeDefault(t *testing.T) {
101         allHandshakeTests(t, AllSupportedCrypto, DefaultCryptoSelector)
102         t.Logf("crypto provides encountered: %s", cryptoProvidesCount)
103 }
104
105 func TestHandshakeSelectPlaintext(t *testing.T) {
106         allHandshakeTests(t, AllSupportedCrypto, func(uint32) uint32 { return CryptoMethodPlaintext })
107 }
108
109 func BenchmarkHandshakeDefault(b *testing.B) {
110         for range iter.N(b.N) {
111                 allHandshakeTests(b, AllSupportedCrypto, DefaultCryptoSelector)
112         }
113 }
114
115 type trackReader struct {
116         r io.Reader
117         n int64
118 }
119
120 func (tr *trackReader) Read(b []byte) (n int, err error) {
121         n, err = tr.r.Read(b)
122         tr.n += int64(n)
123         return
124 }
125
126 func TestReceiveRandomData(t *testing.T) {
127         tr := trackReader{rand.Reader, 0}
128         _, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
129         // No skey matches
130         require.Error(t, err)
131         // Establishing S, and then reading the maximum padding for giving up on
132         // synchronizing.
133         require.EqualValues(t, 96+532, tr.n)
134 }
135
136 func fillRand(t testing.TB, bs ...[]byte) {
137         for _, b := range bs {
138                 _, err := rand.Read(b)
139                 require.NoError(t, err)
140         }
141 }
142
143 func readAndWrite(rw io.ReadWriter, r []byte, w []byte) error {
144         var wg sync.WaitGroup
145         wg.Add(1)
146         var wErr error
147         go func() {
148                 defer wg.Done()
149                 _, wErr = rw.Write(w)
150         }()
151         _, err := io.ReadFull(rw, r)
152         if err != nil {
153                 return err
154         }
155         wg.Wait()
156         return wErr
157 }
158
159 func benchmarkStream(t *testing.B, crypto uint32) {
160         ia := make([]byte, 0x1000)
161         a := make([]byte, 1<<20)
162         b := make([]byte, 1<<20)
163         fillRand(t, ia, a, b)
164         t.StopTimer()
165         t.SetBytes(int64(len(ia) + len(a) + len(b)))
166         t.ResetTimer()
167         for range iter.N(t.N) {
168                 ac, bc := net.Pipe()
169                 ar := make([]byte, len(b))
170                 br := make([]byte, len(ia)+len(a))
171                 t.StartTimer()
172                 var wg sync.WaitGroup
173                 wg.Add(1)
174                 go func() {
175                         defer ac.Close()
176                         defer wg.Done()
177                         rw, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
178                         require.NoError(t, err)
179                         require.NoError(t, readAndWrite(rw, ar, a))
180                 }()
181                 func() {
182                         defer bc.Close()
183                         rw, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(uint32) uint32 { return crypto })
184                         require.NoError(t, err)
185                         require.NoError(t, readAndWrite(rw, br, b))
186                 }()
187                 t.StopTimer()
188                 if !bytes.Equal(ar, b) {
189                         t.Fatalf("A read the wrong bytes")
190                 }
191                 if !bytes.Equal(br[:len(ia)], ia) {
192                         t.Fatalf("B read the wrong IA")
193                 }
194                 if !bytes.Equal(br[len(ia):], a) {
195                         t.Fatalf("B read the wrong A")
196                 }
197                 // require.Equal(t, b, ar)
198                 // require.Equal(t, ia, br[:len(ia)])
199                 // require.Equal(t, a, br[len(ia):])
200         }
201 }
202
203 func BenchmarkStreamRC4(t *testing.B) {
204         benchmarkStream(t, CryptoMethodRC4)
205 }
206
207 func BenchmarkStreamPlaintext(t *testing.B) {
208         benchmarkStream(t, CryptoMethodPlaintext)
209 }
210
211 func BenchmarkPipeRC4(t *testing.B) {
212         key := make([]byte, 20)
213         n, _ := rand.Read(key)
214         require.Equal(t, len(key), n)
215         var buf bytes.Buffer
216         c, err := rc4.NewCipher(key)
217         require.NoError(t, err)
218         r := cipherReader{
219                 c: c,
220                 r: &buf,
221         }
222         c, err = rc4.NewCipher(key)
223         require.NoError(t, err)
224         w := cipherWriter{
225                 c: c,
226                 w: &buf,
227         }
228         a := make([]byte, 0x1000)
229         n, _ = io.ReadFull(rand.Reader, a)
230         require.Equal(t, len(a), n)
231         b := make([]byte, len(a))
232         t.SetBytes(int64(len(a)))
233         t.ResetTimer()
234         for range iter.N(t.N) {
235                 n, _ = w.Write(a)
236                 if n != len(a) {
237                         t.FailNow()
238                 }
239                 n, _ = r.Read(b)
240                 if n != len(b) {
241                         t.FailNow()
242                 }
243                 if !bytes.Equal(a, b) {
244                         t.FailNow()
245                 }
246         }
247 }