From: Matt Joiner Date: Sun, 3 Jan 2021 23:49:28 +0000 (+1100) Subject: Add mse.ReceiveHandshakeEx X-Git-Tag: v1.20.0~4 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=73823ee61d34ba8846644535f83e9510a394914f;p=btrtrc.git Add mse.ReceiveHandshakeEx --- diff --git a/mse/mse.go b/mse/mse.go index b51a5aaf..85d55a7d 100644 --- a/mse/mse.go +++ b/mse/mse.go @@ -542,14 +542,28 @@ func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cry return h.Do() } -func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, method CryptoMethod, err error) { +type HandshakeResult struct { + io.ReadWriter + CryptoMethod + error + SecretKey []byte +} + +func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) { + res := ReceiveHandshakeEx(rw, skeys, selectCrypto) + return res.ReadWriter, res.CryptoMethod, res.error +} + +func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret HandshakeResult) { h := handshake{ conn: rw, initer: false, skeys: skeys, chooseMethod: selectCrypto, } - return h.Do() + ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do() + ret.SecretKey = h.skey + return } // A function that given a function, calls it with secret keys until it diff --git a/mse/mse_test.go b/mse/mse_test.go index 5eeb73af..ff2cfc54 100644 --- a/mse/mse_test.go +++ b/mse/mse_test.go @@ -79,12 +79,13 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides }() go func() { defer wg.Done() - b, cm, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect) - require.NoError(t, err) - assert.Equal(t, cryptoSelect(cryptoProvides), cm) + res := ReceiveHandshakeEx(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect) + require.NoError(t, res.error) + assert.EqualValues(t, "yep", res.SecretKey) + b := res.ReadWriter + assert.Equal(t, cryptoSelect(cryptoProvides), res.CryptoMethod) go b.Write([]byte(bData)) - // Need to be exact here, as there are several reads, and net.Pipe is - // most synchronous. + // Need to be exact here, as there are several reads, and net.Pipe is most synchronous. msg := make([]byte, len(ia)+len(aData)) n, _ := io.ReadFull(b, msg[:]) if n != len(msg) {