]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse_test.go
519135043a17b7993f88c8cf49d4502c4ff4d6cd
[btrtrc.git] / mse / mse_test.go
1 package mse
2
3 import (
4         "bytes"
5         "crypto/rand"
6         "crypto/rc4"
7         "io"
8         "io/ioutil"
9         "net"
10         "sync"
11         "testing"
12
13         _ "github.com/anacrolix/envpprof"
14         "github.com/bradfitz/iter"
15         "github.com/stretchr/testify/require"
16 )
17
18 func sliceIter(skeys [][]byte) SecretKeyIter {
19         return func(callback func([]byte) bool) {
20                 for _, sk := range skeys {
21                         if !callback(sk) {
22                                 break
23                         }
24                 }
25         }
26 }
27
28 func TestReadUntil(t *testing.T) {
29         test := func(data, until string, leftover int, expectedErr error) {
30                 r := bytes.NewReader([]byte(data))
31                 err := readUntil(r, []byte(until))
32                 if err != expectedErr {
33                         t.Fatal(err)
34                 }
35                 if r.Len() != leftover {
36                         t.Fatal(r.Len())
37                 }
38         }
39         test("feakjfeafeafegbaabc00", "abc", 2, nil)
40         test("feakjfeafeafegbaadc00", "abc", 0, io.EOF)
41 }
42
43 func TestSuffixMatchLen(t *testing.T) {
44         test := func(a, b string, expected int) {
45                 actual := suffixMatchLen([]byte(a), []byte(b))
46                 if actual != expected {
47                         t.Fatalf("expected %d, got %d for %q and %q", expected, actual, a, b)
48                 }
49         }
50         test("hello", "world", 0)
51         test("hello", "lo", 2)
52         test("hello", "llo", 3)
53         test("hello", "hell", 0)
54         test("hello", "helloooo!", 5)
55         test("hello", "lol!", 2)
56         test("hello", "mondo", 0)
57         test("mongo", "webscale", 0)
58         test("sup", "person", 1)
59 }
60
61 func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides CryptoMethod, cryptoSelect CryptoSelector) {
62         a, b := net.Pipe()
63         wg := sync.WaitGroup{}
64         wg.Add(2)
65         go func() {
66                 defer wg.Done()
67                 a, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
68                 if err != nil {
69                         t.Fatal(err)
70                         return
71                 }
72                 go a.Write([]byte(aData))
73
74                 var msg [20]byte
75                 n, _ := a.Read(msg[:])
76                 if n != len(bData) {
77                         t.FailNow()
78                 }
79                 // t.Log(string(msg[:n]))
80         }()
81         go func() {
82                 defer wg.Done()
83                 b, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
84                 if err != nil {
85                         t.Fatal(err)
86                         return
87                 }
88                 go b.Write([]byte(bData))
89                 // Need to be exact here, as there are several reads, and net.Pipe is
90                 // most synchronous.
91                 msg := make([]byte, len(ia)+len(aData))
92                 n, _ := io.ReadFull(b, msg[:])
93                 if n != len(msg) {
94                         t.FailNow()
95                 }
96                 // t.Log(string(msg[:n]))
97         }()
98         wg.Wait()
99         a.Close()
100         b.Close()
101 }
102
103 func allHandshakeTests(t testing.TB, provides CryptoMethod, selector CryptoSelector) {
104         handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg", provides, selector)
105         handshakeTest(t, nil, "hello world", "yo dawg", provides, selector)
106         handshakeTest(t, []byte{}, "hello world", "yo dawg", provides, selector)
107 }
108
109 func TestHandshakeDefault(t *testing.T) {
110         allHandshakeTests(t, AllSupportedCrypto, DefaultCryptoSelector)
111         t.Logf("crypto provides encountered: %s", cryptoProvidesCount)
112 }
113
114 func TestHandshakeSelectPlaintext(t *testing.T) {
115         allHandshakeTests(t, AllSupportedCrypto, func(CryptoMethod) CryptoMethod { return CryptoMethodPlaintext })
116 }
117
118 func BenchmarkHandshakeDefault(b *testing.B) {
119         for range iter.N(b.N) {
120                 allHandshakeTests(b, AllSupportedCrypto, DefaultCryptoSelector)
121         }
122 }
123
124 type trackReader struct {
125         r io.Reader
126         n int64
127 }
128
129 func (tr *trackReader) Read(b []byte) (n int, err error) {
130         n, err = tr.r.Read(b)
131         tr.n += int64(n)
132         return
133 }
134
135 func TestReceiveRandomData(t *testing.T) {
136         tr := trackReader{rand.Reader, 0}
137         _, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
138         // No skey matches
139         require.Error(t, err)
140         // Establishing S, and then reading the maximum padding for giving up on
141         // synchronizing.
142         require.EqualValues(t, 96+532, tr.n)
143 }
144
145 func fillRand(t testing.TB, bs ...[]byte) {
146         for _, b := range bs {
147                 _, err := rand.Read(b)
148                 require.NoError(t, err)
149         }
150 }
151
152 func readAndWrite(rw io.ReadWriter, r []byte, w []byte) error {
153         var wg sync.WaitGroup
154         wg.Add(1)
155         var wErr error
156         go func() {
157                 defer wg.Done()
158                 _, wErr = rw.Write(w)
159         }()
160         _, err := io.ReadFull(rw, r)
161         if err != nil {
162                 return err
163         }
164         wg.Wait()
165         return wErr
166 }
167
168 func benchmarkStream(t *testing.B, crypto CryptoMethod) {
169         ia := make([]byte, 0x1000)
170         a := make([]byte, 1<<20)
171         b := make([]byte, 1<<20)
172         fillRand(t, ia, a, b)
173         t.StopTimer()
174         t.SetBytes(int64(len(ia) + len(a) + len(b)))
175         t.ResetTimer()
176         for range iter.N(t.N) {
177                 ac, bc := net.Pipe()
178                 ar := make([]byte, len(b))
179                 br := make([]byte, len(ia)+len(a))
180                 t.StartTimer()
181                 var wg sync.WaitGroup
182                 wg.Add(1)
183                 go func() {
184                         defer ac.Close()
185                         defer wg.Done()
186                         rw, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
187                         require.NoError(t, err)
188                         require.NoError(t, readAndWrite(rw, ar, a))
189                 }()
190                 func() {
191                         defer bc.Close()
192                         rw, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto })
193                         require.NoError(t, err)
194                         require.NoError(t, readAndWrite(rw, br, b))
195                 }()
196                 t.StopTimer()
197                 if !bytes.Equal(ar, b) {
198                         t.Fatalf("A read the wrong bytes")
199                 }
200                 if !bytes.Equal(br[:len(ia)], ia) {
201                         t.Fatalf("B read the wrong IA")
202                 }
203                 if !bytes.Equal(br[len(ia):], a) {
204                         t.Fatalf("B read the wrong A")
205                 }
206                 // require.Equal(t, b, ar)
207                 // require.Equal(t, ia, br[:len(ia)])
208                 // require.Equal(t, a, br[len(ia):])
209         }
210 }
211
212 func BenchmarkStreamRC4(t *testing.B) {
213         benchmarkStream(t, CryptoMethodRC4)
214 }
215
216 func BenchmarkStreamPlaintext(t *testing.B) {
217         benchmarkStream(t, CryptoMethodPlaintext)
218 }
219
220 func BenchmarkPipeRC4(t *testing.B) {
221         key := make([]byte, 20)
222         n, _ := rand.Read(key)
223         require.Equal(t, len(key), n)
224         var buf bytes.Buffer
225         c, err := rc4.NewCipher(key)
226         require.NoError(t, err)
227         r := cipherReader{
228                 c: c,
229                 r: &buf,
230         }
231         c, err = rc4.NewCipher(key)
232         require.NoError(t, err)
233         w := cipherWriter{
234                 c: c,
235                 w: &buf,
236         }
237         a := make([]byte, 0x1000)
238         n, _ = io.ReadFull(rand.Reader, a)
239         require.Equal(t, len(a), n)
240         b := make([]byte, len(a))
241         t.SetBytes(int64(len(a)))
242         t.ResetTimer()
243         for range iter.N(t.N) {
244                 n, _ = w.Write(a)
245                 if n != len(a) {
246                         t.FailNow()
247                 }
248                 n, _ = r.Read(b)
249                 if n != len(b) {
250                         t.FailNow()
251                 }
252                 if !bytes.Equal(a, b) {
253                         t.FailNow()
254                 }
255         }
256 }