From 13a5b8b2790a8648974a72ad80e41b0b0ee41780 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Thu, 12 Mar 2015 20:29:48 +1100 Subject: [PATCH] msg: Return usable object after handshake --- mse/mse.go | 11 +++++++++-- mse/mse_test.go | 11 +++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/mse/mse.go b/mse/mse.go index 1d53ad21..a452228b 100644 --- a/mse/mse.go +++ b/mse/mse.go @@ -339,7 +339,12 @@ func readUntil(r io.Reader, b []byte) error { return nil } -func (h *handshake) Do() (ret io.ReadWriteCloser, err error) { +type readWriter struct { + io.Reader + io.Writer +} + +func (h *handshake) Do() (ret io.ReadWriter, err error) { err = h.establishS() if err != nil { return @@ -383,6 +388,7 @@ func (h *handshake) Do() (ret io.ReadWriteCloser, err error) { err = fmt.Errorf("error reading crypto negotiation: %s", err) return } + ret = readWriter{r, &cipherWriter{bC, h.conn}} } else { err = readUntil(h.conn, hash(req1, h.s.Bytes())) if err != nil { @@ -422,6 +428,7 @@ func (h *handshake) Do() (ret io.ReadWriteCloser, err error) { if err != nil { return } + ret = readWriter{r, w} } err = h.finishWriting() if err != nil { @@ -431,7 +438,7 @@ func (h *handshake) Do() (ret io.ReadWriteCloser, err error) { return } -func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWriteCloser, err error) { +func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWriter, err error) { h := handshake{ conn: rw, initer: initer, diff --git a/mse/mse_test.go b/mse/mse_test.go index 6c20fdf7..d26d90be 100644 --- a/mse/mse_test.go +++ b/mse/mse_test.go @@ -3,6 +3,7 @@ package mse import ( "bytes" "io" + "log" "net" "sync" @@ -53,7 +54,10 @@ func TestHandshake(t *testing.T) { t.Fatal(err) return } - a.Close() + a.Write([]byte("hello world")) + var msg [20]byte + n, _ := a.Read(msg[:]) + log.Print(string(msg[:n])) }() go func() { defer wg.Done() @@ -62,7 +66,10 @@ func TestHandshake(t *testing.T) { t.Fatal(err) return } - b.Close() + var msg [20]byte + n, _ := b.Read(msg[:]) + log.Print(string(msg[:n])) + b.Write([]byte("yo dawg")) }() wg.Wait() } -- 2.44.0