"crypto/rand"
"crypto/rc4"
"io"
- "io/ioutil"
"net"
"sync"
"testing"
_ "github.com/anacrolix/envpprof"
-
- "github.com/bradfitz/iter"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
test("sup", "person", 1)
}
-func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides uint32, cryptoSelect func(uint32) uint32) {
+func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides CryptoMethod, cryptoSelect CryptoSelector) {
a, b := net.Pipe()
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
- a, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
- if err != nil {
- t.Fatal(err)
- return
- }
+ a, cm, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
+ require.NoError(t, err)
+ assert.Equal(t, cryptoSelect(cryptoProvides), cm)
go a.Write([]byte(aData))
var msg [20]byte
}()
go func() {
defer wg.Done()
- b, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
- if err != nil {
- t.Fatal(err)
- return
- }
+ res := ReceiveHandshakeEx(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
+ require.NoError(t, res.error)
+ assert.EqualValues(t, "yep", res.SecretKey)
+ b := res.ReadWriter
+ assert.Equal(t, cryptoSelect(cryptoProvides), res.CryptoMethod)
go b.Write([]byte(bData))
- // Need to be exact here, as there are several reads, and net.Pipe is
- // most synchronous.
+ // Need to be exact here, as there are several reads, and net.Pipe is most synchronous.
msg := make([]byte, len(ia)+len(aData))
- n, _ := io.ReadFull(b, msg[:])
+ n, _ := io.ReadFull(b, msg)
if n != len(msg) {
t.FailNow()
}
b.Close()
}
-func allHandshakeTests(t testing.TB, provides uint32, selector CryptoSelector) {
+func allHandshakeTests(t testing.TB, provides CryptoMethod, selector CryptoSelector) {
handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg", provides, selector)
handshakeTest(t, nil, "hello world", "yo dawg", provides, selector)
handshakeTest(t, []byte{}, "hello world", "yo dawg", provides, selector)
}
func TestHandshakeSelectPlaintext(t *testing.T) {
- allHandshakeTests(t, AllSupportedCrypto, func(uint32) uint32 { return CryptoMethodPlaintext })
+ allHandshakeTests(t, AllSupportedCrypto, func(CryptoMethod) CryptoMethod { return CryptoMethodPlaintext })
}
func BenchmarkHandshakeDefault(b *testing.B) {
- for range iter.N(b.N) {
+ for i := 0; i < b.N; i += 1 {
allHandshakeTests(b, AllSupportedCrypto, DefaultCryptoSelector)
}
}
func TestReceiveRandomData(t *testing.T) {
tr := trackReader{rand.Reader, 0}
- _, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
+ _, _, err := ReceiveHandshake(readWriter{&tr, io.Discard}, nil, DefaultCryptoSelector)
// No skey matches
require.Error(t, err)
// Establishing S, and then reading the maximum padding for giving up on
}
}
-func readAndWrite(rw io.ReadWriter, r []byte, w []byte) error {
+func readAndWrite(rw io.ReadWriter, r, w []byte) error {
var wg sync.WaitGroup
wg.Add(1)
var wErr error
return wErr
}
-func benchmarkStream(t *testing.B, crypto uint32) {
+func benchmarkStream(t *testing.B, crypto CryptoMethod) {
ia := make([]byte, 0x1000)
a := make([]byte, 1<<20)
b := make([]byte, 1<<20)
t.StopTimer()
t.SetBytes(int64(len(ia) + len(a) + len(b)))
t.ResetTimer()
- for range iter.N(t.N) {
+ for i := 0; i < t.N; i += 1 {
ac, bc := net.Pipe()
ar := make([]byte, len(b))
br := make([]byte, len(ia)+len(a))
go func() {
defer ac.Close()
defer wg.Done()
- rw, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
+ rw, _, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
require.NoError(t, err)
require.NoError(t, readAndWrite(rw, ar, a))
}()
func() {
defer bc.Close()
- rw, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(uint32) uint32 { return crypto })
+ rw, _, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto })
require.NoError(t, err)
require.NoError(t, readAndWrite(rw, br, b))
}()
+ wg.Wait()
t.StopTimer()
if !bytes.Equal(ar, b) {
t.Fatalf("A read the wrong bytes")
b := make([]byte, len(a))
t.SetBytes(int64(len(a)))
t.ResetTimer()
- for range iter.N(t.N) {
+ for i := 0; i < t.N; i += 1 {
n, _ = w.Write(a)
if n != len(a) {
t.FailNow()
}
}
}
+
+func BenchmarkSkeysReceive(b *testing.B) {
+ var skeys [][]byte
+ for i := 0; i < 100000; i += 1 {
+ skeys = append(skeys, make([]byte, 20))
+ }
+ fillRand(b, skeys...)
+ initSkey := skeys[len(skeys)/2]
+ // c := qt.New(b)
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i += 1 {
+ initiator, receiver := net.Pipe()
+ go func() {
+ _, _, err := InitiateHandshake(initiator, initSkey, nil, AllSupportedCrypto)
+ if err != nil {
+ panic(err)
+ }
+ }()
+ res := ReceiveHandshakeEx(receiver, sliceIter(skeys), DefaultCryptoSelector)
+ if res.error != nil {
+ panic(res.error)
+ }
+ }
+}