Sergey Matveev's repositories - btrtrc.git/blob - mse/mse_test.go
fixed code quality issues using DeepSource
[btrtrc.git] / mse / mse_test.go
1 package mse
2
3 import (
4         "bytes"
5         "crypto/rand"
6         "crypto/rc4"
7         "io"
8         "io/ioutil"
9         "net"
10         "sync"
11         "testing"
12
13         _ "github.com/anacrolix/envpprof"
14         "github.com/bradfitz/iter"
15         "github.com/stretchr/testify/assert"
16         "github.com/stretchr/testify/require"
17 )
18
19 func sliceIter(skeys [][]byte) SecretKeyIter {
20         return func(callback func([]byte) bool) {
21                 for _, sk := range skeys {
22                         if !callback(sk) {
23                                 break
24                         }
25                 }
26         }
27 }
28
29 func TestReadUntil(t *testing.T) {
30         test := func(data, until string, leftover int, expectedErr error) {
31                 r := bytes.NewReader([]byte(data))
32                 err := readUntil(r, []byte(until))
33                 if err != expectedErr {
34                         t.Fatal(err)
35                 }
36                 if r.Len() != leftover {
37                         t.Fatal(r.Len())
38                 }
39         }
40         test("feakjfeafeafegbaabc00", "abc", 2, nil)
41         test("feakjfeafeafegbaadc00", "abc", 0, io.EOF)
42 }
43
44 func TestSuffixMatchLen(t *testing.T) {
45         test := func(a, b string, expected int) {
46                 actual := suffixMatchLen([]byte(a), []byte(b))
47                 if actual != expected {
48                         t.Fatalf("expected %d, got %d for %q and %q", expected, actual, a, b)
49                 }
50         }
51         test("hello", "world", 0)
52         test("hello", "lo", 2)
53         test("hello", "llo", 3)
54         test("hello", "hell", 0)
55         test("hello", "helloooo!", 5)
56         test("hello", "lol!", 2)
57         test("hello", "mondo", 0)
58         test("mongo", "webscale", 0)
59         test("sup", "person", 1)
60 }
61
62 func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides CryptoMethod, cryptoSelect CryptoSelector) {
63         a, b := net.Pipe()
64         wg := sync.WaitGroup{}
65         wg.Add(2)
66         go func() {
67                 defer wg.Done()
68                 a, cm, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
69                 require.NoError(t, err)
70                 assert.Equal(t, cryptoSelect(cryptoProvides), cm)
71                 go a.Write([]byte(aData))
72
73                 var msg [20]byte
74                 n, _ := a.Read(msg[:])
75                 if n != len(bData) {
76                         t.FailNow()
77                 }
78                 // t.Log(string(msg[:n]))
79         }()
80         go func() {
81                 defer wg.Done()
82                 res := ReceiveHandshakeEx(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
83                 require.NoError(t, res.error)
84                 assert.EqualValues(t, "yep", res.SecretKey)
85                 b := res.ReadWriter
86                 assert.Equal(t, cryptoSelect(cryptoProvides), res.CryptoMethod)
87                 go b.Write([]byte(bData))
88                 // Need to be exact here, as there are several reads, and net.Pipe is most synchronous.
89                 msg := make([]byte, len(ia)+len(aData))
90                 n, _ := io.ReadFull(b, msg)
91                 if n != len(msg) {
92                         t.FailNow()
93                 }
94                 // t.Log(string(msg[:n]))
95         }()
96         wg.Wait()
97         a.Close()
98         b.Close()
99 }
100
101 func allHandshakeTests(t testing.TB, provides CryptoMethod, selector CryptoSelector) {
102         handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg", provides, selector)
103         handshakeTest(t, nil, "hello world", "yo dawg", provides, selector)
104         handshakeTest(t, []byte{}, "hello world", "yo dawg", provides, selector)
105 }
106
107 func TestHandshakeDefault(t *testing.T) {
108         allHandshakeTests(t, AllSupportedCrypto, DefaultCryptoSelector)
109         t.Logf("crypto provides encountered: %s", cryptoProvidesCount)
110 }
111
112 func TestHandshakeSelectPlaintext(t *testing.T) {
113         allHandshakeTests(t, AllSupportedCrypto, func(CryptoMethod) CryptoMethod { return CryptoMethodPlaintext })
114 }
115
116 func BenchmarkHandshakeDefault(b *testing.B) {
117         for range iter.N(b.N) {
118                 allHandshakeTests(b, AllSupportedCrypto, DefaultCryptoSelector)
119         }
120 }
121
// trackReader wraps a reader and keeps a running count of the bytes returned
// by all of its Read calls.
type trackReader struct {
	r io.Reader // wrapped source
	n int64     // total bytes returned by Read so far
}

// Read forwards to the wrapped reader, adds the number of bytes it produced to
// tr.n, and returns the wrapped reader's results unchanged.
func (tr *trackReader) Read(p []byte) (int, error) {
	count, err := tr.r.Read(p)
	tr.n += int64(count)
	return count, err
}
132
133 func TestReceiveRandomData(t *testing.T) {
134         tr := trackReader{rand.Reader, 0}
135         _, _, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
136         // No skey matches
137         require.Error(t, err)
138         // Establishing S, and then reading the maximum padding for giving up on
139         // synchronizing.
140         require.EqualValues(t, 96+532, tr.n)
141 }
142
143 func fillRand(t testing.TB, bs ...[]byte) {
144         for _, b := range bs {
145                 _, err := rand.Read(b)
146                 require.NoError(t, err)
147         }
148 }
149
// readAndWrite writes w to rw on a separate goroutine while filling r
// completely from rw. Concurrency is needed because a synchronous transport
// such as net.Pipe would otherwise deadlock when both peers write first.
//
// A read error is returned immediately, without waiting for the write. The
// write may be blocked on a peer that will never read, and it only unblocks
// once the caller closes the connection. On a successful read, the write's
// error (nil on success) is returned after the write finishes.
func readAndWrite(rw io.ReadWriter, r []byte, w []byte) error {
	var (
		writeErr error
		writing  sync.WaitGroup
	)
	writing.Add(1)
	go func() {
		defer writing.Done()
		_, writeErr = rw.Write(w)
	}()
	if _, readErr := io.ReadFull(rw, r); readErr != nil {
		return readErr
	}
	writing.Wait()
	return writeErr
}
165
166 func benchmarkStream(t *testing.B, crypto CryptoMethod) {
167         ia := make([]byte, 0x1000)
168         a := make([]byte, 1<<20)
169         b := make([]byte, 1<<20)
170         fillRand(t, ia, a, b)
171         t.StopTimer()
172         t.SetBytes(int64(len(ia) + len(a) + len(b)))
173         t.ResetTimer()
174         for range iter.N(t.N) {
175                 ac, bc := net.Pipe()
176                 ar := make([]byte, len(b))
177                 br := make([]byte, len(ia)+len(a))
178                 t.StartTimer()
179                 var wg sync.WaitGroup
180                 wg.Add(1)
181                 go func() {
182                         defer ac.Close()
183                         defer wg.Done()
184                         rw, _, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
185                         require.NoError(t, err)
186                         require.NoError(t, readAndWrite(rw, ar, a))
187                 }()
188                 func() {
189                         defer bc.Close()
190                         rw, _, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto })
191                         require.NoError(t, err)
192                         require.NoError(t, readAndWrite(rw, br, b))
193                 }()
194                 wg.Wait()
195                 t.StopTimer()
196                 if !bytes.Equal(ar, b) {
197                         t.Fatalf("A read the wrong bytes")
198                 }
199                 if !bytes.Equal(br[:len(ia)], ia) {
200                         t.Fatalf("B read the wrong IA")
201                 }
202                 if !bytes.Equal(br[len(ia):], a) {
203                         t.Fatalf("B read the wrong A")
204                 }
205                 // require.Equal(t, b, ar)
206                 // require.Equal(t, ia, br[:len(ia)])
207                 // require.Equal(t, a, br[len(ia):])
208         }
209 }
210
211 func BenchmarkStreamRC4(t *testing.B) {
212         benchmarkStream(t, CryptoMethodRC4)
213 }
214
215 func BenchmarkStreamPlaintext(t *testing.B) {
216         benchmarkStream(t, CryptoMethodPlaintext)
217 }
218
219 func BenchmarkPipeRC4(t *testing.B) {
220         key := make([]byte, 20)
221         n, _ := rand.Read(key)
222         require.Equal(t, len(key), n)
223         var buf bytes.Buffer
224         c, err := rc4.NewCipher(key)
225         require.NoError(t, err)
226         r := cipherReader{
227                 c: c,
228                 r: &buf,
229         }
230         c, err = rc4.NewCipher(key)
231         require.NoError(t, err)
232         w := cipherWriter{
233                 c: c,
234                 w: &buf,
235         }
236         a := make([]byte, 0x1000)
237         n, _ = io.ReadFull(rand.Reader, a)
238         require.Equal(t, len(a), n)
239         b := make([]byte, len(a))
240         t.SetBytes(int64(len(a)))
241         t.ResetTimer()
242         for range iter.N(t.N) {
243                 n, _ = w.Write(a)
244                 if n != len(a) {
245                         t.FailNow()
246                 }
247                 n, _ = r.Read(b)
248                 if n != len(b) {
249                         t.FailNow()
250                 }
251                 if !bytes.Equal(a, b) {
252                         t.FailNow()
253                 }
254         }
255 }
256
257 func BenchmarkSkeysReceive(b *testing.B) {
258         var skeys [][]byte
259         for range iter.N(100000) {
260                 skeys = append(skeys, make([]byte, 20))
261         }
262         fillRand(b, skeys...)
263         initSkey := skeys[len(skeys)/2]
264         //c := qt.New(b)
265         b.ReportAllocs()
266         b.ResetTimer()
267         for range iter.N(b.N) {
268                 initiator, receiver := net.Pipe()
269                 go func() {
270                         _, _, err := InitiateHandshake(initiator, initSkey, nil, AllSupportedCrypto)
271                         if err != nil {
272                                 panic(err)
273                         }
274                 }()
275                 res := ReceiveHandshakeEx(receiver, sliceIter(skeys), DefaultCryptoSelector)
276                 if res.error != nil {
277                         panic(res.error)
278                 }
279         }
280 }