From: Matt Joiner <anacrolix@gmail.com>
Date: Wed, 13 Sep 2017 06:19:14 +0000 (+1000)
Subject: mse: Support plaintext crypto mode
X-Git-Tag: v1.0.0~401
X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=29e06fb83c8a82632ce8ad181ceec99da7181a92;p=btrtrc.git

mse: Support plaintext crypto mode
---

diff --git a/mse/mse.go b/mse/mse.go
index eb2480aa..8127b815 100644
--- a/mse/mse.go
+++ b/mse/mse.go
@@ -13,6 +13,7 @@ import (
 	"fmt"
 	"io"
 	"io/ioutil"
+	"math"
 	"math/big"
 	"strconv"
 	"sync"
@@ -25,6 +26,7 @@ const (
 
 	cryptoMethodPlaintext = 1
 	cryptoMethodRC4       = 2
+	AllSupportedCrypto    = cryptoMethodPlaintext | cryptoMethodRC4
 )
 
 var (
@@ -209,6 +211,10 @@ type handshake struct {
 	skeys  SecretKeyIter // Skeys we'll accept if receiving.
 	skey   []byte        // Skey we're initiating with.
 	ia     []byte        // Initial payload. Only used by the initiator.
+	// Return the bit for the crypto method the receiver wants to use.
+	chooseMethod func(supported uint32) uint32
+	// Sent to the receiver.
+	cryptoProvides uint32
 
 	writeMu    sync.Mutex
 	writes     [][]byte
@@ -365,11 +371,11 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
 	h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
 	buf := &bytes.Buffer{}
 	padLen := uint16(newPadLen())
-	err = marshal(buf, vc[:], uint32(cryptoMethodRC4), padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
 	if len(h.ia) > math.MaxUint16 {
 		err = errors.New("initial payload too large")
 		return
 	}
+	err = marshal(buf, vc[:], h.cryptoProvides, padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
 	if err != nil {
 		return
 	}
@@ -398,15 +404,18 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
 	if err != nil {
 		return
 	}
-	if method != cryptoMethodRC4 {
-		err = fmt.Errorf("receiver chose unsupported method: %x", method)
-		return
-	}
 	_, err = io.CopyN(ioutil.Discard, r, int64(padLen))
 	if err != nil {
 		return
 	}
-	ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
+	switch method & h.cryptoProvides {
+	case cryptoMethodRC4:
+		ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
+	case cryptoMethodPlaintext:
+		ret = h.conn
+	default:
+		err = fmt.Errorf("receiver chose unsupported method: %x", method)
+	}
 	return
 }
 
@@ -440,20 +449,17 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
 	}
 	r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
 	var (
-		vc     [8]byte
-		method uint32
-		padLen uint16
+		vc       [8]byte
+		provides uint32
+		padLen   uint16
 	)
 
-	err = unmarshal(r, vc[:], &method, &padLen)
+	err = unmarshal(r, vc[:], &provides, &padLen)
 	if err != nil {
 		return
 	}
-	cryptoProvidesCount.Add(strconv.FormatUint(uint64(method), 16), 1)
-	if method&cryptoMethodRC4 == 0 {
-		err = errors.New("no supported crypto methods were provided")
-		return
-	}
+	cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
+	chosen := h.chooseMethod(provides)
 	_, err = io.CopyN(ioutil.Discard, r, int64(padLen))
 	if err != nil {
 		return
@@ -467,7 +473,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
 	buf := &bytes.Buffer{}
 	w := cipherWriter{h.newEncrypt(false), buf, nil}
 	padLen = uint16(newPadLen())
-	err = marshal(&w, &vc, uint32(cryptoMethodRC4), padLen, zeroPad[:padLen])
+	err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen])
 	if err != nil {
 		return
 	}
@@ -475,7 +481,20 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
 	if err != nil {
 		return
 	}
-	ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn, nil}}
+	switch chosen {
+	case cryptoMethodRC4:
+		ret = readWriter{
+			io.MultiReader(bytes.NewReader(h.ia), r),
+			&cipherWriter{w.c, h.conn, nil},
+		}
+	case cryptoMethodPlaintext:
+		ret = readWriter{
+			io.MultiReader(bytes.NewReader(h.ia), h.conn),
+			h.conn,
+		}
+	default:
+		err = errors.New("chosen crypto method is not supported")
+	}
 	return
 }
 
@@ -508,21 +527,23 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
 	return
 }
 
-func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
+func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides uint32) (ret io.ReadWriter, err error) {
 	h := handshake{
-		conn:   rw,
-		initer: true,
-		skey:   skey,
-		ia:     initialPayload,
+		conn:           rw,
+		initer:         true,
+		skey:           skey,
+		ia:             initialPayload,
+		cryptoProvides: cryptoProvides,
 	}
 	return h.Do()
 }
 
-func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, err error) {
+func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte, selectCrypto func(uint32) uint32) (ret io.ReadWriter, err error) {
 	h := handshake{
-		conn:   rw,
-		initer: false,
-		skeys:  sliceIter(skeys),
+		conn:         rw,
+		initer:       false,
+		skeys:        sliceIter(skeys),
+		chooseMethod: selectCrypto,
 	}
 	return h.Do()
 }
@@ -551,3 +572,12 @@ func ReceiveHandshakeLazy(rw io.ReadWriter, skeys SecretKeyIter) (ret io.ReadWri
 	}
 	return h.Do()
 }
+
+func DefaultCryptoSelector(provided uint32) uint32 {
+	if provided&cryptoMethodRC4 != 0 {
+		return cryptoMethodRC4
+	}
+	return cryptoMethodPlaintext
+}
+
+type CryptoSelector func(uint32) uint32
diff --git a/mse/mse_test.go b/mse/mse_test.go
index 8c283e56..d9ef7b68 100644
--- a/mse/mse_test.go
+++ b/mse/mse_test.go
@@ -10,6 +10,8 @@ import (
 	"sync"
 	"testing"
 
+	_ "github.com/anacrolix/envpprof"
+
 	"github.com/bradfitz/iter"
 	"github.com/stretchr/testify/require"
 )
@@ -47,13 +49,13 @@ func TestSuffixMatchLen(t *testing.T) {
 	test("sup", "person", 1)
 }
 
-func handshakeTest(t testing.TB, ia []byte, aData, bData string) {
+func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides uint32, cryptoSelect func(uint32) uint32) {
 	a, b := net.Pipe()
 	wg := sync.WaitGroup{}
 	wg.Add(2)
 	go func() {
 		defer wg.Done()
-		a, err := InitiateHandshake(a, []byte("yep"), ia)
+		a, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
 		if err != nil {
 			t.Fatal(err)
 			return
@@ -69,7 +71,7 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string) {
 	}()
 	go func() {
 		defer wg.Done()
-		b, err := ReceiveHandshake(b, [][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")})
+		b, err := ReceiveHandshake(b, [][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}, cryptoSelect)
 		if err != nil {
 			t.Fatal(err)
 			return
@@ -89,20 +91,24 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string) {
 	b.Close()
 }
 
-func allHandshakeTests(t testing.TB) {
-	handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg")
-	handshakeTest(t, nil, "hello world", "yo dawg")
-	handshakeTest(t, []byte{}, "hello world", "yo dawg")
+func allHandshakeTests(t testing.TB, provides uint32, selector CryptoSelector) {
+	handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg", provides, selector)
+	handshakeTest(t, nil, "hello world", "yo dawg", provides, selector)
+	handshakeTest(t, []byte{}, "hello world", "yo dawg", provides, selector)
 }
 
-func TestHandshake(t *testing.T) {
-	allHandshakeTests(t)
+func TestHandshakeDefault(t *testing.T) {
+	allHandshakeTests(t, AllSupportedCrypto, DefaultCryptoSelector)
 	t.Logf("crypto provides encountered: %s", cryptoProvidesCount)
 }
 
-func BenchmarkHandshake(b *testing.B) {
+func TestHandshakeSelectPlaintext(t *testing.T) {
+	allHandshakeTests(t, AllSupportedCrypto, func(uint32) uint32 { return cryptoMethodPlaintext })
+}
+
+func BenchmarkHandshakeDefault(b *testing.B) {
 	for range iter.N(b.N) {
-		allHandshakeTests(b)
+		allHandshakeTests(b, AllSupportedCrypto, DefaultCryptoSelector)
 	}
 }
 
@@ -119,7 +125,7 @@ func (tr *trackReader) Read(b []byte) (n int, err error) {
 
 func TestReceiveRandomData(t *testing.T) {
 	tr := trackReader{rand.Reader, 0}
-	_, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil)
+	_, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
 	// No skey matches
 	require.Error(t, err)
 	// Establishing S, and then reading the maximum padding for giving up on
@@ -127,7 +133,82 @@ func TestReceiveRandomData(t *testing.T) {
 	require.EqualValues(t, 96+532, tr.n)
 }
 
-func BenchmarkPipe(t *testing.B) {
+func fillRand(t testing.TB, bs ...[]byte) {
+	for _, b := range bs {
+		_, err := rand.Read(b)
+		require.NoError(t, err)
+	}
+}
+
+func readAndWrite(rw io.ReadWriter, r []byte, w []byte) error {
+	var wg sync.WaitGroup
+	wg.Add(1)
+	var wErr error
+	go func() {
+		defer wg.Done()
+		_, wErr = rw.Write(w)
+	}()
+	_, err := io.ReadFull(rw, r)
+	if err != nil {
+		return err
+	}
+	wg.Wait()
+	return wErr
+}
+
+func benchmarkStream(t *testing.B, crypto uint32) {
+	ia := make([]byte, 0x1000)
+	a := make([]byte, 1<<20)
+	b := make([]byte, 1<<20)
+	fillRand(t, ia, a, b)
+	t.StopTimer()
+	t.SetBytes(int64(len(ia) + len(a) + len(b)))
+	t.ResetTimer()
+	for range iter.N(t.N) {
+		ac, bc := net.Pipe()
+		ar := make([]byte, len(b))
+		br := make([]byte, len(ia)+len(a))
+		t.StartTimer()
+		var wg sync.WaitGroup
+		wg.Add(1)
+		go func() {
+			defer ac.Close()
+			defer wg.Done()
+			rw, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
+			require.NoError(t, err)
+			require.NoError(t, readAndWrite(rw, ar, a))
+		}()
+		func() {
+			defer bc.Close()
+			rw, err := ReceiveHandshake(bc, [][]byte{[]byte("cats")}, func(uint32) uint32 { return crypto })
+			require.NoError(t, err)
+			require.NoError(t, readAndWrite(rw, br, b))
+		}()
+		t.StopTimer()
+		if !bytes.Equal(ar, b) {
+			t.Fatalf("A read the wrong bytes")
+		}
+		if !bytes.Equal(br[:len(ia)], ia) {
+			t.Fatalf("B read the wrong IA")
+		}
+		if !bytes.Equal(br[len(ia):], a) {
+			t.Fatalf("B read the wrong A")
+		}
+		// require.Equal(t, b, ar)
+		// require.Equal(t, ia, br[:len(ia)])
+		// require.Equal(t, a, br[len(ia):])
+	}
+}
+
+func BenchmarkStreamRC4(t *testing.B) {
+	benchmarkStream(t, cryptoMethodRC4)
+}
+
+func BenchmarkStreamPlaintext(t *testing.B) {
+	benchmarkStream(t, cryptoMethodPlaintext)
+}
+
+func BenchmarkPipeRC4(t *testing.B) {
 	key := make([]byte, 20)
 	n, _ := rand.Read(key)
 	require.Equal(t, len(key), n)