From 29e06fb83c8a82632ce8ad181ceec99da7181a92 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Wed, 13 Sep 2017 16:19:14 +1000 Subject: [PATCH] mse: Support plaintext crypto mode --- mse/mse.go | 82 +++++++++++++++++++++++++------------ mse/mse_test.go | 107 ++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 150 insertions(+), 39 deletions(-) diff --git a/mse/mse.go b/mse/mse.go index eb2480aa..8127b815 100644 --- a/mse/mse.go +++ b/mse/mse.go @@ -13,6 +13,7 @@ import ( "fmt" "io" "io/ioutil" + "math" "math/big" "strconv" "sync" @@ -25,6 +26,7 @@ const ( cryptoMethodPlaintext = 1 cryptoMethodRC4 = 2 + AllSupportedCrypto = cryptoMethodPlaintext | cryptoMethodRC4 ) var ( @@ -209,6 +211,10 @@ type handshake struct { skeys SecretKeyIter // Skeys we'll accept if receiving. skey []byte // Skey we're initiating with. ia []byte // Initial payload. Only used by the initiator. + // Return the bit for the crypto method the receiver wants to use. + chooseMethod func(supported uint32) uint32 + // Sent to the receiver. + cryptoProvides uint32 writeMu sync.Mutex writes [][]byte @@ -365,11 +371,11 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) { h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:]))) buf := &bytes.Buffer{} padLen := uint16(newPadLen()) - err = marshal(buf, vc[:], uint32(cryptoMethodRC4), padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia) if len(h.ia) > math.MaxUint16 { err = errors.New("initial payload too large") return } + err = marshal(buf, vc[:], h.cryptoProvides, padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia) if err != nil { return } @@ -398,15 +404,18 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) { if err != nil { return } - if method != cryptoMethodRC4 { - err = fmt.Errorf("receiver chose unsupported method: %x", method) - return - } _, err = io.CopyN(ioutil.Discard, r, int64(padLen)) if err != nil { return } - ret = readWriter{r, &cipherWriter{e, h.conn, nil}} + switch method & h.cryptoProvides { + case cryptoMethodRC4: + ret = readWriter{r, &cipherWriter{e, h.conn, nil}} + case cryptoMethodPlaintext: + ret = h.conn + default: + err = fmt.Errorf("receiver chose unsupported method: %x", method) + } return } @@ -440,20 +449,17 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { } r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn) var ( - vc [8]byte - method uint32 - padLen uint16 + vc [8]byte + provides uint32 + padLen uint16 ) - err = unmarshal(r, vc[:], &method, &padLen) + err = unmarshal(r, vc[:], &provides, &padLen) if err != nil { return } - cryptoProvidesCount.Add(strconv.FormatUint(uint64(method), 16), 1) - if method&cryptoMethodRC4 == 0 { - err = errors.New("no supported crypto methods were provided") - return - } + cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1) + chosen := h.chooseMethod(provides) _, err = io.CopyN(ioutil.Discard, r, int64(padLen)) if err != nil { return @@ -467,7 +473,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { buf := &bytes.Buffer{} w := cipherWriter{h.newEncrypt(false), buf, nil} padLen = uint16(newPadLen()) - err = marshal(&w, &vc, uint32(cryptoMethodRC4), padLen, zeroPad[:padLen]) + err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen]) if err != nil { return } @@ -475,7 +481,20 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { if err != nil { return } - ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn, nil}} + switch chosen { + case cryptoMethodRC4: + ret = readWriter{ + io.MultiReader(bytes.NewReader(h.ia), r), + &cipherWriter{w.c, h.conn, nil}, + } + case cryptoMethodPlaintext: + ret = readWriter{ + io.MultiReader(bytes.NewReader(h.ia), h.conn), + h.conn, + } + default: + err = errors.New("chosen crypto method is not supported") + } return } @@ -508,21 +527,23 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) { return } -func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) { +func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides uint32) (ret io.ReadWriter, err error) { h := handshake{ - conn: rw, - initer: true, - skey: skey, - ia: initialPayload, + conn: rw, + initer: true, + skey: skey, + ia: initialPayload, + cryptoProvides: cryptoProvides, } return h.Do() } -func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, err error) { +func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte, selectCrypto func(uint32) uint32) (ret io.ReadWriter, err error) { h := handshake{ - conn: rw, - initer: false, - skeys: sliceIter(skeys), + conn: rw, + initer: false, + skeys: sliceIter(skeys), + chooseMethod: selectCrypto, } return h.Do() } @@ -551,3 +572,12 @@ func ReceiveHandshakeLazy(rw io.ReadWriter, skeys SecretKeyIter) (ret io.ReadWri } return h.Do() } + +func DefaultCryptoSelector(provided uint32) uint32 { + if provided&cryptoMethodRC4 != 0 { + return cryptoMethodRC4 + } + return cryptoMethodPlaintext +} + +type CryptoSelector func(uint32) uint32 diff --git a/mse/mse_test.go b/mse/mse_test.go index 8c283e56..d9ef7b68 100644 --- a/mse/mse_test.go +++ b/mse/mse_test.go @@ -10,6 +10,8 @@ import ( "sync" "testing" + _ "github.com/anacrolix/envpprof" + "github.com/bradfitz/iter" "github.com/stretchr/testify/require" ) @@ -47,13 +49,13 @@ func TestSuffixMatchLen(t *testing.T) { test("sup", "person", 1) } -func handshakeTest(t testing.TB, ia []byte, aData, bData string) { +func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides uint32, cryptoSelect func(uint32) uint32) { a, b := net.Pipe() wg := sync.WaitGroup{} wg.Add(2) go func() { defer wg.Done() - a, err := InitiateHandshake(a, []byte("yep"), ia) + a, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides) if err != nil { t.Fatal(err) return @@ -69,7 +71,7 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string) { }() go func() { defer wg.Done() - b, err := ReceiveHandshake(b, [][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}) + b, err := ReceiveHandshake(b, [][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}, cryptoSelect) if err != nil { t.Fatal(err) return @@ -89,20 +91,24 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string) { b.Close() } -func allHandshakeTests(t testing.TB) { - handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg") - handshakeTest(t, nil, "hello world", "yo dawg") - handshakeTest(t, []byte{}, "hello world", "yo dawg") +func allHandshakeTests(t testing.TB, provides uint32, selector CryptoSelector) { + handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg", provides, selector) + handshakeTest(t, nil, "hello world", "yo dawg", provides, selector) + handshakeTest(t, []byte{}, "hello world", "yo dawg", provides, selector) } -func TestHandshake(t *testing.T) { - allHandshakeTests(t) +func TestHandshakeDefault(t *testing.T) { + allHandshakeTests(t, AllSupportedCrypto, DefaultCryptoSelector) t.Logf("crypto provides encountered: %s", cryptoProvidesCount) } -func BenchmarkHandshake(b *testing.B) { +func TestHandshakeSelectPlaintext(t *testing.T) { + allHandshakeTests(t, AllSupportedCrypto, func(uint32) uint32 { return cryptoMethodPlaintext }) +} + +func BenchmarkHandshakeDefault(b *testing.B) { for range iter.N(b.N) { - allHandshakeTests(b) + allHandshakeTests(b, AllSupportedCrypto, DefaultCryptoSelector) } } @@ -119,7 +125,7 @@ func (tr *trackReader) Read(b []byte) (n int, err error) { func TestReceiveRandomData(t *testing.T) { tr := trackReader{rand.Reader, 0} - _, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil) + _, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector) // No skey matches require.Error(t, err) // Establishing S, and then reading the maximum padding for giving up on @@ -127,7 +133,82 @@ func TestReceiveRandomData(t *testing.T) { require.EqualValues(t, 96+532, tr.n) } -func BenchmarkPipe(t *testing.B) { +func fillRand(t testing.TB, bs ...[]byte) { + for _, b := range bs { + _, err := rand.Read(b) + require.NoError(t, err) + } +} + +func readAndWrite(rw io.ReadWriter, r []byte, w []byte) error { + var wg sync.WaitGroup + wg.Add(1) + var wErr error + go func() { + defer wg.Done() + _, wErr = rw.Write(w) + }() + _, err := io.ReadFull(rw, r) + if err != nil { + return err + } + wg.Wait() + return wErr +} + +func benchmarkStream(t *testing.B, crypto uint32) { + ia := make([]byte, 0x1000) + a := make([]byte, 1<<20) + b := make([]byte, 1<<20) + fillRand(t, ia, a, b) + t.StopTimer() + t.SetBytes(int64(len(ia) + len(a) + len(b))) + t.ResetTimer() + for range iter.N(t.N) { + ac, bc := net.Pipe() + ar := make([]byte, len(b)) + br := make([]byte, len(ia)+len(a)) + t.StartTimer() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer ac.Close() + defer wg.Done() + rw, err := InitiateHandshake(ac, []byte("cats"), ia, crypto) + require.NoError(t, err) + require.NoError(t, readAndWrite(rw, ar, a)) + }() + func() { + defer bc.Close() + rw, err := ReceiveHandshake(bc, [][]byte{[]byte("cats")}, func(uint32) uint32 { return crypto }) + require.NoError(t, err) + require.NoError(t, readAndWrite(rw, br, b)) + }() + t.StopTimer() + if !bytes.Equal(ar, b) { + t.Fatalf("A read the wrong bytes") + } + if !bytes.Equal(br[:len(ia)], ia) { + t.Fatalf("B read the wrong IA") + } + if !bytes.Equal(br[len(ia):], a) { + t.Fatalf("B read the wrong A") + } + // require.Equal(t, b, ar) + // require.Equal(t, ia, br[:len(ia)]) + // require.Equal(t, a, br[len(ia):]) + } +} + +func BenchmarkStreamRC4(t *testing.B) { + benchmarkStream(t, cryptoMethodRC4) +} + +func BenchmarkStreamPlaintext(t *testing.B) { + benchmarkStream(t, cryptoMethodPlaintext) +} + +func BenchmarkPipeRC4(t *testing.B) { key := make([]byte, 20) n, _ := rand.Read(key) require.Equal(t, len(key), n) -- 2.48.1