From: Matt Joiner Date: Thu, 12 Mar 2015 19:16:49 +0000 (+1100) Subject: mse: Tons of fixes and improvements X-Git-Tag: v1.0.0~1281 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=d57f5896d403030b64631be36fff209272610962;p=btrtrc.git mse: Tons of fixes and improvements --- diff --git a/mse/mse.go b/mse/mse.go index a452228b..e9ea9a67 100644 --- a/mse/mse.go +++ b/mse/mse.go @@ -106,10 +106,6 @@ func (me *cipherWriter) Write(b []byte) (n int, err error) { return } -func newCipherWriter(c *rc4.Cipher, w io.Writer) io.Writer { - return &cipherWriter{c, w} -} - func readY(r io.Reader) (y big.Int, err error) { var b [96]byte _, err = io.ReadFull(r, b[:]) @@ -133,12 +129,17 @@ func newX() big.Int { return X } +// Calculate, and send Y, our public key. func (h *handshake) postY(x *big.Int) error { var y big.Int y.Exp(&g, x, &p) b := y.Bytes() if len(b) != 96 { - panic(len(b)) + b1 := make([]byte, 96) + if n := copy(b1[96-len(b):], b); n != len(b) { + panic(n) + } + b = b1 } return h.postWrite(b) } @@ -173,6 +174,7 @@ type handshake struct { conn io.ReadWriteCloser s big.Int initer bool + skeys [][]byte skey []byte writeMu sync.Mutex @@ -257,6 +259,26 @@ func xor(dst, src []byte) (ret []byte) { return } +func marshal(w io.Writer, data ...interface{}) (err error) { + for _, data := range data { + err = binary.Write(w, binary.BigEndian, data) + if err != nil { + break + } + } + return +} + +func unmarshal(r io.Reader, data ...interface{}) (err error) { + for _, data := range data { + err = binary.Read(r, binary.BigEndian, data) + if err != nil { + break + } + } + return +} + type cryptoNegotiation struct { VC [8]byte Method uint32 @@ -265,7 +287,8 @@ type cryptoNegotiation struct { } func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) { - _, err = io.ReadFull(r, me.VC[:]) + err = binary.Read(r, binary.BigEndian, me.VC[:]) + // _, err = io.ReadFull(r, me.VC[:]) if err != nil { return } @@ -283,7 +306,8 @@ func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) { } func (me *cryptoNegotiation) MarshalWriter(w io.Writer) (err error) { - _, err = w.Write(me.VC[:]) + // _, err = w.Write(me.VC[:]) + err = binary.Write(w, binary.BigEndian, me.VC[:]) if err != nil { return } @@ -344,9 +368,101 @@ type readWriter struct { io.Writer } +func (h *handshake) initerSteps() (ret io.ReadWriter, err error) { + h.postWrite(hash(req1, h.s.Bytes())) + h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes()))) + buf := &bytes.Buffer{} + err = (&cryptoNegotiation{ + Method: cryptoMethodRC4, + PadLen: uint16(newPadLen()), + }).MarshalWriter(buf) + if err != nil { + return + } + err = marshal(buf, uint16(0)) + if err != nil { + return + } + e := newEncrypt(true, h.s.Bytes(), h.skey) + be := make([]byte, buf.Len()) + e.XORKeyStream(be, buf.Bytes()) + h.postWrite(be) + bC := newEncrypt(false, h.s.Bytes(), h.skey) + var eVC [8]byte + bC.XORKeyStream(eVC[:], make([]byte, 8)) + log.Print(eVC) + // Read until the all zero VC. + err = readUntil(h.conn, eVC[:]) + if err != nil { + err = fmt.Errorf("error reading until VC: %s", err) + return + } + var cn cryptoNegotiation + r := &cipherReader{bC, h.conn} + err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r)) + log.Printf("initer got %v", cn) + if err != nil { + err = fmt.Errorf("error reading crypto negotiation: %s", err) + return + } + ret = readWriter{r, &cipherWriter{e, h.conn}} + return +} + +func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { + err = readUntil(h.conn, hash(req1, h.s.Bytes())) + if err != nil { + return + } + var b [20]byte + _, err = io.ReadFull(h.conn, b[:]) + if err != nil { + return + } + err = errors.New("skey doesn't match") + for _, skey := range h.skeys { + if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s.Bytes())), b[:]) { + h.skey = skey + err = nil + break + } + } + if err != nil { + return + } + var cn cryptoNegotiation + r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn) + err = cn.UnmarshalReader(r) + if err != nil { + return + } + log.Printf("receiver got %v", cn) + if cn.Method&cryptoMethodRC4 == 0 { + err = errors.New("no supported crypto methods were provided") + return + } + unmarshal(r, new(uint16)) + buf := &bytes.Buffer{} + w := cipherWriter{newEncrypt(false, h.s.Bytes(), h.skey), buf} + err = (&cryptoNegotiation{ + Method: cryptoMethodRC4, + PadLen: uint16(newPadLen()), + }).MarshalWriter(&w) + if err != nil { + return + } + err = h.postWrite(buf.Bytes()) + if err != nil { + return + } + ret = readWriter{r, &cipherWriter{w.c, h.conn}} + return +} + func (h *handshake) Do() (ret io.ReadWriter, err error) { err = h.establishS() if err != nil { + err = fmt.Errorf("error while establishing secret: %s", err) return } pad := make([]byte, newPadLen()) @@ -356,92 +472,25 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) { return } if h.initer { - h.postWrite(hash(req1, h.s.Bytes())) - h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes()))) - buf := &bytes.Buffer{} - err = (&cryptoNegotiation{ - Method: cryptoMethodRC4, - PadLen: uint16(newPadLen()), - }).MarshalWriter(buf) - if err != nil { - return - } - e := newEncrypt(true, h.s.Bytes(), h.skey) - be := make([]byte, buf.Len()) - e.XORKeyStream(be, buf.Bytes()) - h.postWrite(be) - bC := newEncrypt(false, h.s.Bytes(), h.skey) - var eVC [8]byte - bC.XORKeyStream(eVC[:], make([]byte, 8)) - log.Print(eVC) - // Read until the all zero VC. - err = readUntil(h.conn, eVC[:]) - if err != nil { - err = fmt.Errorf("error reading until VC: %s", err) - return - } - var cn cryptoNegotiation - r := &cipherReader{bC, h.conn} - err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r)) - log.Printf("initer got %v", cn) - if err != nil { - err = fmt.Errorf("error reading crypto negotiation: %s", err) - return - } - ret = readWriter{r, &cipherWriter{bC, h.conn}} + ret, err = h.initerSteps() } else { - err = readUntil(h.conn, hash(req1, h.s.Bytes())) - if err != nil { - return - } - var b [20]byte - _, err = io.ReadFull(h.conn, b[:]) - if err != nil { - return - } - if !bytes.Equal(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())), b[:]) { - err = errors.New("skey doesn't match") - return - } - var cn cryptoNegotiation - r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn) - err = cn.UnmarshalReader(r) - if err != nil { - return - } - log.Printf("receiver got %v", cn) - if cn.Method&cryptoMethodRC4 == 0 { - err = errors.New("no supported crypto methods were provided") - return - } - buf := &bytes.Buffer{} - w := newCipherWriter(newEncrypt(false, h.s.Bytes(), h.skey), buf) - err = (&cryptoNegotiation{ - Method: cryptoMethodRC4, - PadLen: uint16(newPadLen()), - }).MarshalWriter(w) - if err != nil { - return - } - log.Println("encrypted VC", buf.Bytes()[:8]) - err = h.postWrite(buf.Bytes()) - if err != nil { - return - } - ret = readWriter{r, w} + ret, err = h.receiverSteps() + } + if err != nil { + return } err = h.finishWriting() if err != nil { return } - ret = h.conn + log.Print("ermahgerd, finished MSE handshake") return } -func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWriter, err error) { +func InitiateHandshake(rw io.ReadWriteCloser, skey []byte) (ret io.ReadWriter, err error) { h := handshake{ conn: rw, - initer: initer, + initer: true, skey: skey, } h.writeCond.L = &h.writeMu @@ -449,3 +498,14 @@ func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWrit go h.writer() return h.Do() } +func ReceiveHandshake(rw io.ReadWriteCloser, skeys [][]byte) (ret io.ReadWriter, err error) { + h := handshake{ + conn: rw, + initer: false, + skeys: skeys, + } + h.writeCond.L = &h.writeMu + h.writerCond.L = &h.writerMu + go h.writer() + return h.Do() +} diff --git a/mse/mse_test.go b/mse/mse_test.go index d26d90be..644a4cdf 100644 --- a/mse/mse_test.go +++ b/mse/mse_test.go @@ -49,7 +49,7 @@ func TestHandshake(t *testing.T) { wg.Add(2) go func() { defer wg.Done() - a, err := Handshake(a, true, []byte("yep")) + a, err := InitiateHandshake(a, []byte("yep")) if err != nil { t.Fatal(err) return @@ -61,7 +61,7 @@ func TestHandshake(t *testing.T) { }() go func() { defer wg.Done() - b, err := Handshake(b, false, []byte("yep")) + b, err := ReceiveHandshake(b, [][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}) if err != nil { t.Fatal(err) return