From: Matt Joiner Date: Thu, 12 Mar 2015 09:29:48 +0000 (+1100) Subject: msg: Return usable object after handshake X-Git-Tag: v1.0.0~1282 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=13a5b8b2790a8648974a72ad80e41b0b0ee41780;p=btrtrc.git msg: Return usable object after handshake --- diff --git a/mse/mse.go b/mse/mse.go index 1d53ad21..a452228b 100644 --- a/mse/mse.go +++ b/mse/mse.go @@ -339,7 +339,12 @@ func readUntil(r io.Reader, b []byte) error { return nil } -func (h *handshake) Do() (ret io.ReadWriteCloser, err error) { +type readWriter struct { + io.Reader + io.Writer +} + +func (h *handshake) Do() (ret io.ReadWriter, err error) { err = h.establishS() if err != nil { return @@ -383,6 +388,7 @@ func (h *handshake) Do() (ret io.ReadWriteCloser, err error) { err = fmt.Errorf("error reading crypto negotiation: %s", err) return } + ret = readWriter{r, &cipherWriter{bC, h.conn}} } else { err = readUntil(h.conn, hash(req1, h.s.Bytes())) if err != nil { @@ -422,6 +428,7 @@ func (h *handshake) Do() (ret io.ReadWriteCloser, err error) { if err != nil { return } + ret = readWriter{r, w} } err = h.finishWriting() if err != nil { @@ -431,7 +438,7 @@ func (h *handshake) Do() (ret io.ReadWriteCloser, err error) { return } -func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWriteCloser, err error) { +func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWriter, err error) { h := handshake{ conn: rw, initer: initer, diff --git a/mse/mse_test.go b/mse/mse_test.go index 6c20fdf7..d26d90be 100644 --- a/mse/mse_test.go +++ b/mse/mse_test.go @@ -3,6 +3,7 @@ package mse import ( "bytes" "io" + "log" "net" "sync" @@ -53,7 +54,10 @@ func TestHandshake(t *testing.T) { t.Fatal(err) return } - a.Close() + a.Write([]byte("hello world")) + var msg [20]byte + n, _ := a.Read(msg[:]) + log.Print(string(msg[:n])) }() go func() { defer wg.Done() @@ -62,7 +66,10 @@ func TestHandshake(t *testing.T) { t.Fatal(err) return } - b.Close() + var msg [20]byte + n, _ := b.Read(msg[:]) + log.Print(string(msg[:n])) + b.Write([]byte("yo dawg")) }() wg.Wait() }