mse/mse.go      | 32 ++++++++++++++++++++++++++++++--
 mse/mse_test.go | 11 ++++++-----
diff --git a/mse/mse.go b/mse/mse.go
--- a/mse/mse.go
+++ b/mse/mse.go
@@ -542,14 +542,42 @@ defer perf.ScopeTimerErr(&err)()
 	return h.Do()
 }
 
-func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, method CryptoMethod, err error) {
+// HandshakeResult is the outcome of the receiving side of a handshake, as
+// returned by ReceiveHandshakeEx.
+type HandshakeResult struct {
+	io.ReadWriter
+	CryptoMethod
+
+	// Err is the error returned by the handshake, if any. It is a named,
+	// exported field (not an embedded error) so that callers outside this
+	// package can read it, and so that HandshakeResult does not itself
+	// satisfy the error interface.
+	Err error
+
+	// SecretKey is the handshake's skey once it returns: presumably the entry
+	// of skeys that matched the peer (TODO: confirm its value when Err != nil).
+	SecretKey []byte
+}
+
+// ReceiveHandshake performs the non-initiating side of the handshake over rw.
+// It is ReceiveHandshakeEx without the secret key.
+func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) {
+	res := ReceiveHandshakeEx(rw, skeys, selectCrypto)
+	return res.ReadWriter, res.CryptoMethod, res.Err
+}
+
+// ReceiveHandshakeEx is like ReceiveHandshake, but also returns the
+// handshake's secret key in HandshakeResult.SecretKey.
+func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret HandshakeResult) {
 	h := handshake{
 		conn:         rw,
 		initer:       false,
 		skeys:        skeys,
 		chooseMethod: selectCrypto,
 	}
-	return h.Do()
+	ret.ReadWriter, ret.CryptoMethod, ret.Err = h.Do()
+	ret.SecretKey = h.skey
+	return
 }
 
 // A function that given a function, calls it with secret keys until it
diff --git a/mse/mse_test.go b/mse/mse_test.go
--- a/mse/mse_test.go
+++ b/mse/mse_test.go
@@ -79,12 +79,13 @@ // t.Log(string(msg[:n]))
 	}()
 	go func() {
 		defer wg.Done()
-		b, cm, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
-		require.NoError(t, err)
-		assert.Equal(t, cryptoSelect(cryptoProvides), cm)
+		res := ReceiveHandshakeEx(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
+		require.NoError(t, res.Err)
+		assert.EqualValues(t, "yep", res.SecretKey)
+		b := res.ReadWriter
+		assert.Equal(t, cryptoSelect(cryptoProvides), res.CryptoMethod)
 		go b.Write([]byte(bData))
-		// Need to be exact here, as there are several reads, and net.Pipe is
-		// most synchronous.
+		// Need to be exact here, as there are several reads, and net.Pipe is synchronous.
 		msg := make([]byte, len(ia)+len(aData))
 		n, _ := io.ReadFull(b, msg[:])
 		if n != len(msg) {