]> Sergey Matveev's repositories - btrtrc.git/commitdiff
mse: Clean-up
authorMatt Joiner <anacrolix@gmail.com>
Wed, 18 Mar 2015 07:14:57 +0000 (18:14 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Wed, 18 Mar 2015 07:14:57 +0000 (18:14 +1100)
mse/mse.go

index f675b40be6a2b262047bb2e8214fca415dcd7269..91f9fd888c0a6d6211f0a19b9a16c6330ac7fd5b 100644 (file)
@@ -17,6 +17,8 @@ import (
        "strconv"
        "sync"
 
+       "bitbucket.org/anacrolix/go.torrent/util"
+
        "github.com/bradfitz/iter"
 )
 
@@ -36,7 +38,12 @@ var (
        req1 = []byte("req1")
        req2 = []byte("req2")
        req3 = []byte("req3")
-
+       // Verification constant "VC" which is all zeroes in the bittorrent
+       // implementation.
+       vc [8]byte
+       // Zero padding
+       zeroPad [512]byte
+       // Tracks counts of received crypto_provides
        cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
 )
 
@@ -132,19 +139,22 @@ func newX() big.Int {
        return X
 }
 
+func paddedLeft(b []byte, _len int) []byte {
+       if len(b) == _len {
+               return b
+       }
+       ret := make([]byte, _len)
+       if n := copy(ret[_len-len(b):], b); n != len(b) {
+               panic(n)
+       }
+       return ret
+}
+
 // Calculate, and send Y, our public key.
 func (h *handshake) postY(x *big.Int) error {
        var y big.Int
        y.Exp(&g, x, &p)
-       b := y.Bytes()
-       if len(b) != 96 {
-               b1 := make([]byte, 96)
-               if n := copy(b1[96-len(b):], b); n != len(b) {
-                       panic(n)
-               }
-               b = b1
-       }
-       return h.postWrite(b)
+       return h.postWrite(paddedLeft(y.Bytes(), 96))
 }
 
 func (h *handshake) establishS() (err error) {
@@ -155,9 +165,10 @@ func (h *handshake) establishS() (err error) {
        if err != nil {
                return
        }
-       var Y big.Int
+       var Y, S big.Int
        Y.SetBytes(b[:])
-       h.s.Exp(&Y, &x, &p)
+       S.Exp(&Y, &x, &p)
+       util.CopyExact(&h.s, paddedLeft(S.Bytes(), 96))
        return
 }
 
@@ -174,8 +185,8 @@ func newPadLen() int64 {
 }
 
 type handshake struct {
-       conn   io.ReadWriteCloser
-       s      big.Int
+       conn   io.ReadWriter
+       s      [96]byte
        initer bool
        skeys  [][]byte
        skey   []byte
@@ -192,11 +203,10 @@ type handshake struct {
        writerDone bool
 }
 
-func (h *handshake) finishWriting() (err error) {
+func (h *handshake) finishWriting() {
        h.writeMu.Lock()
        h.writeClose = true
        h.writeCond.Broadcast()
-       err = h.writeErr
        h.writeMu.Unlock()
 
        h.writerMu.Lock()
@@ -205,7 +215,6 @@ func (h *handshake) finishWriting() (err error) {
        }
        h.writerMu.Unlock()
        return
-
 }
 
 func (h *handshake) writer() {
@@ -283,48 +292,6 @@ func unmarshal(r io.Reader, data ...interface{}) (err error) {
        return
 }
 
-type cryptoNegotiation struct {
-       VC     [8]byte
-       Method uint32
-       PadLen uint16
-       IA     []byte
-}
-
-func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
-       err = binary.Read(r, binary.BigEndian, me.VC[:])
-       if err != nil {
-               return
-       }
-       err = binary.Read(r, binary.BigEndian, &me.Method)
-       if err != nil {
-               return
-       }
-       err = binary.Read(r, binary.BigEndian, &me.PadLen)
-       if err != nil {
-               return
-       }
-       _, err = io.CopyN(ioutil.Discard, r, int64(me.PadLen))
-       return
-}
-
-func (me *cryptoNegotiation) MarshalWriter(w io.Writer) (err error) {
-       // _, err = w.Write(me.VC[:])
-       err = binary.Write(w, binary.BigEndian, me.VC[:])
-       if err != nil {
-               return
-       }
-       err = binary.Write(w, binary.BigEndian, me.Method)
-       if err != nil {
-               return
-       }
-       err = binary.Write(w, binary.BigEndian, me.PadLen)
-       if err != nil {
-               return
-       }
-       _, err = w.Write(make([]byte, me.PadLen))
-       return
-}
-
 // Looking for b at the end of a.
 func suffixMatchLen(a, b []byte) int {
        if len(b) > len(a) {
@@ -369,48 +336,65 @@ type readWriter struct {
        io.Writer
 }
 
+func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
+       return newEncrypt(initer, h.s[:], h.skey)
+}
+
 func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
-       h.postWrite(hash(req1, h.s.Bytes()))
-       h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())))
+       h.postWrite(hash(req1, h.s[:]))
+       h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
        buf := &bytes.Buffer{}
-       err = (&cryptoNegotiation{
-               Method: cryptoMethodRC4,
-               PadLen: uint16(newPadLen()),
-       }).MarshalWriter(buf)
-       if err != nil {
-               return
-       }
-       err = marshal(buf, uint16(len(h.ia)), h.ia)
+       padLen := uint16(newPadLen())
+       err = marshal(buf, vc[:], uint32(cryptoMethodRC4), padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
        if err != nil {
                return
        }
-       e := newEncrypt(true, h.s.Bytes(), h.skey)
+       e := h.newEncrypt(true)
        be := make([]byte, buf.Len())
        e.XORKeyStream(be, buf.Bytes())
        h.postWrite(be)
-       bC := newEncrypt(false, h.s.Bytes(), h.skey)
+       bC := h.newEncrypt(false)
        var eVC [8]byte
-       bC.XORKeyStream(eVC[:], make([]byte, 8))
-       // Read until the all zero VC.
-       err = readUntil(h.conn, eVC[:])
+       bC.XORKeyStream(eVC[:], vc[:])
+       // Read until the all zero VC. At this point we've only read the 96 byte
+       // public key, Y. There is potentially 512 byte padding, between us and
+       // the 8 byte verification constant.
+       err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
        if err != nil {
-               err = fmt.Errorf("error reading until VC: %s", err)
+               if err == io.EOF {
+                       err = errors.New("failed to synchronize on VC")
+               } else {
+                       err = fmt.Errorf("error reading until VC: %s", err)
+               }
                return
        }
-       var cn cryptoNegotiation
        r := &cipherReader{bC, h.conn}
-       err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
+       var method uint32
+       err = unmarshal(r, &method, &padLen)
+       if err != nil {
+               return
+       }
+       if method != cryptoMethodRC4 {
+               err = fmt.Errorf("receiver chose unsupported method: %x", method)
+               return
+       }
+       _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
        if err != nil {
-               err = fmt.Errorf("error reading crypto negotiation: %s", err)
                return
        }
        ret = readWriter{r, &cipherWriter{e, h.conn}}
        return
 }
 
+var ErrNoSecretKeyMatch = errors.New("no skey matched")
+
 func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
-       err = readUntil(h.conn, hash(req1, h.s.Bytes()))
+       // There is up to 512 bytes of padding, then the 20 byte hash.
+       err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
        if err != nil {
+               if err == io.EOF {
+                       err = errors.New("failed to synchronize on S hash")
+               }
                return
        }
        var b [20]byte
@@ -418,9 +402,9 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
        if err != nil {
                return
        }
-       err = errors.New("skey doesn't match")
+       err = ErrNoSecretKeyMatch
        for _, skey := range h.skeys {
-               if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s.Bytes())), b[:]) {
+               if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) {
                        h.skey = skey
                        err = nil
                        break
@@ -429,17 +413,26 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
        if err != nil {
                return
        }
-       var cn cryptoNegotiation
-       r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn)
-       err = cn.UnmarshalReader(r)
+       r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
+       var (
+               vc     [8]byte
+               method uint32
+               padLen uint16
+       )
+
+       err = unmarshal(r, vc[:], &method, &padLen)
        if err != nil {
                return
        }
-       cryptoProvidesCount.Add(strconv.FormatUint(uint64(cn.Method), 16), 1)
-       if cn.Method&cryptoMethodRC4 == 0 {
+       cryptoProvidesCount.Add(strconv.FormatUint(uint64(method), 16), 1)
+       if method&cryptoMethodRC4 == 0 {
                err = errors.New("no supported crypto methods were provided")
                return
        }
+       _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
+       if err != nil {
+               return
+       }
        var lenIA uint16
        unmarshal(r, &lenIA)
        if lenIA != 0 {
@@ -447,11 +440,9 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
                unmarshal(r, h.ia)
        }
        buf := &bytes.Buffer{}
-       w := cipherWriter{newEncrypt(false, h.s.Bytes(), h.skey), buf}
-       err = (&cryptoNegotiation{
-               Method: cryptoMethodRC4,
-               PadLen: uint16(newPadLen()),
-       }).MarshalWriter(&w)
+       w := cipherWriter{h.newEncrypt(false), buf}
+       padLen = uint16(newPadLen())
+       err = marshal(&w, &vc, uint32(cryptoMethodRC4), padLen, zeroPad[:padLen])
        if err != nil {
                return
        }
@@ -464,6 +455,15 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
 }
 
 func (h *handshake) Do() (ret io.ReadWriter, err error) {
+       h.writeCond.L = &h.writeMu
+       h.writerCond.L = &h.writerMu
+       go h.writer()
+       defer func() {
+               h.finishWriting()
+               if err == nil {
+                       err = h.writeErr
+               }
+       }()
        err = h.establishS()
        if err != nil {
                err = fmt.Errorf("error while establishing secret: %s", err)
@@ -480,36 +480,23 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
        } else {
                ret, err = h.receiverSteps()
        }
-       if err != nil {
-               return
-       }
-       err = h.finishWriting()
-       if err != nil {
-               return
-       }
        return
 }
 
-func InitiateHandshake(rw io.ReadWriteCloser, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
+func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
        h := handshake{
                conn:   rw,
                initer: true,
                skey:   skey,
                ia:     initialPayload,
        }
-       h.writeCond.L = &h.writeMu
-       h.writerCond.L = &h.writerMu
-       go h.writer()
        return h.Do()
 }
-func ReceiveHandshake(rw io.ReadWriteCloser, skeys [][]byte) (ret io.ReadWriter, err error) {
+func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, err error) {
        h := handshake{
                conn:   rw,
                initer: false,
                skeys:  skeys,
        }
-       h.writeCond.L = &h.writeMu
-       h.writerCond.L = &h.writerMu
-       go h.writer()
        return h.Do()
 }