From: Matt Joiner <anacrolix@gmail.com>
Date: Sun, 3 Jan 2021 23:49:28 +0000 (+1100)
Subject: Add mse.ReceiveHandshakeEx
X-Git-Tag: v1.20.0~4
X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=73823ee61d34ba8846644535f83e9510a394914f;p=btrtrc.git

Add mse.ReceiveHandshakeEx
---

diff --git a/mse/mse.go b/mse/mse.go
index b51a5aaf..85d55a7d 100644
--- a/mse/mse.go
+++ b/mse/mse.go
@@ -542,14 +542,28 @@ func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cry
 	return h.Do()
 }
 
-func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, method CryptoMethod, err error) {
+type HandshakeResult struct {
+	io.ReadWriter
+	CryptoMethod
+	error
+	SecretKey []byte
+}
+
+func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) {
+	res := ReceiveHandshakeEx(rw, skeys, selectCrypto)
+	return res.ReadWriter, res.CryptoMethod, res.error
+}
+
+func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret HandshakeResult) {
 	h := handshake{
 		conn:         rw,
 		initer:       false,
 		skeys:        skeys,
 		chooseMethod: selectCrypto,
 	}
-	return h.Do()
+	ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do()
+	ret.SecretKey = h.skey
+	return
 }
 
 // A function that given a function, calls it with secret keys until it
diff --git a/mse/mse_test.go b/mse/mse_test.go
index 5eeb73af..ff2cfc54 100644
--- a/mse/mse_test.go
+++ b/mse/mse_test.go
@@ -79,12 +79,13 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides
 	}()
 	go func() {
 		defer wg.Done()
-		b, cm, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
-		require.NoError(t, err)
-		assert.Equal(t, cryptoSelect(cryptoProvides), cm)
+		res := ReceiveHandshakeEx(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
+		require.NoError(t, res.error)
+		assert.EqualValues(t, "yep", res.SecretKey)
+		b := res.ReadWriter
+		assert.Equal(t, cryptoSelect(cryptoProvides), res.CryptoMethod)
 		go b.Write([]byte(bData))
-		// Need to be exact here, as there are several reads, and net.Pipe is
-		// most synchronous.
+		// Need to be exact here, as there are several reads, and net.Pipe is most synchronous.
 		msg := make([]byte, len(ia)+len(aData))
 		n, _ := io.ReadFull(b, msg[:])
 		if n != len(msg) {