]> Sergey Matveev's repositories - btrtrc.git/commitdiff
mse: Tons of fixes and improvements
authorMatt Joiner <anacrolix@gmail.com>
Thu, 12 Mar 2015 19:16:49 +0000 (06:16 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Thu, 12 Mar 2015 19:16:49 +0000 (06:16 +1100)
mse/mse.go
mse/mse_test.go

index a452228b47ade929516b09d704c1578cde744cff..e9ea9a6772d4aad213c415f5e3b880288f460146 100644 (file)
@@ -106,10 +106,6 @@ func (me *cipherWriter) Write(b []byte) (n int, err error) {
        return
 }
 
-func newCipherWriter(c *rc4.Cipher, w io.Writer) io.Writer {
-       return &cipherWriter{c, w}
-}
-
 func readY(r io.Reader) (y big.Int, err error) {
        var b [96]byte
        _, err = io.ReadFull(r, b[:])
@@ -133,12 +129,17 @@ func newX() big.Int {
        return X
 }
 
+// Calculate, and send Y, our public key.
 func (h *handshake) postY(x *big.Int) error {
        var y big.Int
        y.Exp(&g, x, &p)
        b := y.Bytes()
        if len(b) != 96 {
-               panic(len(b))
+               b1 := make([]byte, 96)
+               if n := copy(b1[96-len(b):], b); n != len(b) {
+                       panic(n)
+               }
+               b = b1
        }
        return h.postWrite(b)
 }
@@ -173,6 +174,7 @@ type handshake struct {
        conn   io.ReadWriteCloser
        s      big.Int
        initer bool
+       skeys  [][]byte
        skey   []byte
 
        writeMu    sync.Mutex
@@ -257,6 +259,26 @@ func xor(dst, src []byte) (ret []byte) {
        return
 }
 
+func marshal(w io.Writer, data ...interface{}) (err error) {
+       for _, data := range data {
+               err = binary.Write(w, binary.BigEndian, data)
+               if err != nil {
+                       break
+               }
+       }
+       return
+}
+
+func unmarshal(r io.Reader, data ...interface{}) (err error) {
+       for _, data := range data {
+               err = binary.Read(r, binary.BigEndian, data)
+               if err != nil {
+                       break
+               }
+       }
+       return
+}
+
 type cryptoNegotiation struct {
        VC     [8]byte
        Method uint32
@@ -265,7 +287,8 @@ type cryptoNegotiation struct {
 }
 
 func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
-       _, err = io.ReadFull(r, me.VC[:])
+       err = binary.Read(r, binary.BigEndian, me.VC[:])
+       // _, err = io.ReadFull(r, me.VC[:])
        if err != nil {
                return
        }
@@ -283,7 +306,8 @@ func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
 }
 
 func (me *cryptoNegotiation) MarshalWriter(w io.Writer) (err error) {
-       _, err = w.Write(me.VC[:])
+       // _, err = w.Write(me.VC[:])
+       err = binary.Write(w, binary.BigEndian, me.VC[:])
        if err != nil {
                return
        }
@@ -344,9 +368,101 @@ type readWriter struct {
        io.Writer
 }
 
+func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
+       h.postWrite(hash(req1, h.s.Bytes()))
+       h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())))
+       buf := &bytes.Buffer{}
+       err = (&cryptoNegotiation{
+               Method: cryptoMethodRC4,
+               PadLen: uint16(newPadLen()),
+       }).MarshalWriter(buf)
+       if err != nil {
+               return
+       }
+       err = marshal(buf, uint16(0))
+       if err != nil {
+               return
+       }
+       e := newEncrypt(true, h.s.Bytes(), h.skey)
+       be := make([]byte, buf.Len())
+       e.XORKeyStream(be, buf.Bytes())
+       h.postWrite(be)
+       bC := newEncrypt(false, h.s.Bytes(), h.skey)
+       var eVC [8]byte
+       bC.XORKeyStream(eVC[:], make([]byte, 8))
+       log.Print(eVC)
+       // Read until the all zero VC.
+       err = readUntil(h.conn, eVC[:])
+       if err != nil {
+               err = fmt.Errorf("error reading until VC: %s", err)
+               return
+       }
+       var cn cryptoNegotiation
+       r := &cipherReader{bC, h.conn}
+       err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
+       log.Printf("initer got %v", cn)
+       if err != nil {
+               err = fmt.Errorf("error reading crypto negotiation: %s", err)
+               return
+       }
+       ret = readWriter{r, &cipherWriter{e, h.conn}}
+       return
+}
+
+func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
+       err = readUntil(h.conn, hash(req1, h.s.Bytes()))
+       if err != nil {
+               return
+       }
+       var b [20]byte
+       _, err = io.ReadFull(h.conn, b[:])
+       if err != nil {
+               return
+       }
+       err = errors.New("skey doesn't match")
+       for _, skey := range h.skeys {
+               if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s.Bytes())), b[:]) {
+                       h.skey = skey
+                       err = nil
+                       break
+               }
+       }
+       if err != nil {
+               return
+       }
+       var cn cryptoNegotiation
+       r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn)
+       err = cn.UnmarshalReader(r)
+       if err != nil {
+               return
+       }
+       log.Printf("receiver got %v", cn)
+       if cn.Method&cryptoMethodRC4 == 0 {
+               err = errors.New("no supported crypto methods were provided")
+               return
+       }
+       unmarshal(r, new(uint16))
+       buf := &bytes.Buffer{}
+       w := cipherWriter{newEncrypt(false, h.s.Bytes(), h.skey), buf}
+       err = (&cryptoNegotiation{
+               Method: cryptoMethodRC4,
+               PadLen: uint16(newPadLen()),
+       }).MarshalWriter(&w)
+       if err != nil {
+               return
+       }
+       err = h.postWrite(buf.Bytes())
+       if err != nil {
+               return
+       }
+       ret = readWriter{r, &cipherWriter{w.c, h.conn}}
+       return
+}
+
 func (h *handshake) Do() (ret io.ReadWriter, err error) {
        err = h.establishS()
        if err != nil {
+               err = fmt.Errorf("error while establishing secret: %s", err)
                return
        }
        pad := make([]byte, newPadLen())
@@ -356,92 +472,25 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
                return
        }
        if h.initer {
-               h.postWrite(hash(req1, h.s.Bytes()))
-               h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())))
-               buf := &bytes.Buffer{}
-               err = (&cryptoNegotiation{
-                       Method: cryptoMethodRC4,
-                       PadLen: uint16(newPadLen()),
-               }).MarshalWriter(buf)
-               if err != nil {
-                       return
-               }
-               e := newEncrypt(true, h.s.Bytes(), h.skey)
-               be := make([]byte, buf.Len())
-               e.XORKeyStream(be, buf.Bytes())
-               h.postWrite(be)
-               bC := newEncrypt(false, h.s.Bytes(), h.skey)
-               var eVC [8]byte
-               bC.XORKeyStream(eVC[:], make([]byte, 8))
-               log.Print(eVC)
-               // Read until the all zero VC.
-               err = readUntil(h.conn, eVC[:])
-               if err != nil {
-                       err = fmt.Errorf("error reading until VC: %s", err)
-                       return
-               }
-               var cn cryptoNegotiation
-               r := &cipherReader{bC, h.conn}
-               err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
-               log.Printf("initer got %v", cn)
-               if err != nil {
-                       err = fmt.Errorf("error reading crypto negotiation: %s", err)
-                       return
-               }
-               ret = readWriter{r, &cipherWriter{bC, h.conn}}
+               ret, err = h.initerSteps()
        } else {
-               err = readUntil(h.conn, hash(req1, h.s.Bytes()))
-               if err != nil {
-                       return
-               }
-               var b [20]byte
-               _, err = io.ReadFull(h.conn, b[:])
-               if err != nil {
-                       return
-               }
-               if !bytes.Equal(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())), b[:]) {
-                       err = errors.New("skey doesn't match")
-                       return
-               }
-               var cn cryptoNegotiation
-               r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn)
-               err = cn.UnmarshalReader(r)
-               if err != nil {
-                       return
-               }
-               log.Printf("receiver got %v", cn)
-               if cn.Method&cryptoMethodRC4 == 0 {
-                       err = errors.New("no supported crypto methods were provided")
-                       return
-               }
-               buf := &bytes.Buffer{}
-               w := newCipherWriter(newEncrypt(false, h.s.Bytes(), h.skey), buf)
-               err = (&cryptoNegotiation{
-                       Method: cryptoMethodRC4,
-                       PadLen: uint16(newPadLen()),
-               }).MarshalWriter(w)
-               if err != nil {
-                       return
-               }
-               log.Println("encrypted VC", buf.Bytes()[:8])
-               err = h.postWrite(buf.Bytes())
-               if err != nil {
-                       return
-               }
-               ret = readWriter{r, w}
+               ret, err = h.receiverSteps()
+       }
+       if err != nil {
+               return
        }
        err = h.finishWriting()
        if err != nil {
                return
        }
-       ret = h.conn
+       log.Print("ermahgerd, finished MSE handshake")
        return
 }
 
-func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWriter, err error) {
+func InitiateHandshake(rw io.ReadWriteCloser, skey []byte) (ret io.ReadWriter, err error) {
        h := handshake{
                conn:   rw,
-               initer: initer,
+               initer: true,
                skey:   skey,
        }
        h.writeCond.L = &h.writeMu
@@ -449,3 +498,14 @@ func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWrit
        go h.writer()
        return h.Do()
 }
+func ReceiveHandshake(rw io.ReadWriteCloser, skeys [][]byte) (ret io.ReadWriter, err error) {
+       h := handshake{
+               conn:   rw,
+               initer: false,
+               skeys:  skeys,
+       }
+       h.writeCond.L = &h.writeMu
+       h.writerCond.L = &h.writerMu
+       go h.writer()
+       return h.Do()
+}
index d26d90be2916fc1f1d674424c8381bf3c8dc16b3..644a4cdfffab7efcc1859d7b6d5e3264eb550cf6 100644 (file)
@@ -49,7 +49,7 @@ func TestHandshake(t *testing.T) {
        wg.Add(2)
        go func() {
                defer wg.Done()
-               a, err := Handshake(a, true, []byte("yep"))
+               a, err := InitiateHandshake(a, []byte("yep"))
                if err != nil {
                        t.Fatal(err)
                        return
@@ -61,7 +61,7 @@ func TestHandshake(t *testing.T) {
        }()
        go func() {
                defer wg.Done()
-               b, err := Handshake(b, false, []byte("yep"))
+               b, err := ReceiveHandshake(b, [][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")})
                if err != nil {
                        t.Fatal(err)
                        return