From 73823ee61d34ba8846644535f83e9510a394914f Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Mon, 4 Jan 2021 10:49:28 +1100 Subject: [PATCH] Add mse.ReceiveHandshakeEx --- mse/mse.go | 18 ++++++++++++++++-- mse/mse_test.go | 11 ++++++----- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/mse/mse.go b/mse/mse.go index b51a5aaf..85d55a7d 100644 --- a/mse/mse.go +++ b/mse/mse.go @@ -542,14 +542,28 @@ func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cry return h.Do() } -func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, method CryptoMethod, err error) { +type HandshakeResult struct { + io.ReadWriter + CryptoMethod + error + SecretKey []byte +} + +func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) { + res := ReceiveHandshakeEx(rw, skeys, selectCrypto) + return res.ReadWriter, res.CryptoMethod, res.error +} + +func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret HandshakeResult) { h := handshake{ conn: rw, initer: false, skeys: skeys, chooseMethod: selectCrypto, } - return h.Do() + ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do() + ret.SecretKey = h.skey + return } // A function that given a function, calls it with secret keys until it diff --git a/mse/mse_test.go b/mse/mse_test.go index 5eeb73af..ff2cfc54 100644 --- a/mse/mse_test.go +++ b/mse/mse_test.go @@ -79,12 +79,13 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides }() go func() { defer wg.Done() - b, cm, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect) - require.NoError(t, err) - assert.Equal(t, cryptoSelect(cryptoProvides), cm) + res := ReceiveHandshakeEx(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect) + require.NoError(t, res.error) + assert.EqualValues(t, "yep", res.SecretKey) + b := res.ReadWriter + assert.Equal(t, cryptoSelect(cryptoProvides), res.CryptoMethod) go b.Write([]byte(bData)) - // Need to be exact here, as there are several reads, and net.Pipe is - // most synchronous. + // Need to be exact here, as there are several reads, and net.Pipe is most synchronous. msg := make([]byte, len(ia)+len(aData)) n, _ := io.ReadFull(b, msg[:]) if n != len(msg) { -- 2.48.1