From: Matt Joiner Date: Wed, 18 Mar 2015 07:14:57 +0000 (+1100) Subject: mse: Clean-up X-Git-Tag: v1.0.0~1276 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=eb29dcec80393b048ee4ce92cff2756e5af92c42;p=btrtrc.git mse: Clean-up --- diff --git a/mse/mse.go b/mse/mse.go index f675b40b..91f9fd88 100644 --- a/mse/mse.go +++ b/mse/mse.go @@ -17,6 +17,8 @@ import ( "strconv" "sync" + "bitbucket.org/anacrolix/go.torrent/util" + "github.com/bradfitz/iter" ) @@ -36,7 +38,12 @@ var ( req1 = []byte("req1") req2 = []byte("req2") req3 = []byte("req3") - + // Verification constant "VC" which is all zeroes in the bittorrent + // implementation. + vc [8]byte + // Zero padding + zeroPad [512]byte + // Tracks counts of received crypto_provides cryptoProvidesCount = expvar.NewMap("mseCryptoProvides") ) @@ -132,19 +139,22 @@ func newX() big.Int { return X } +func paddedLeft(b []byte, _len int) []byte { + if len(b) == _len { + return b + } + ret := make([]byte, _len) + if n := copy(ret[_len-len(b):], b); n != len(b) { + panic(n) + } + return ret +} + // Calculate, and send Y, our public key. func (h *handshake) postY(x *big.Int) error { var y big.Int y.Exp(&g, x, &p) - b := y.Bytes() - if len(b) != 96 { - b1 := make([]byte, 96) - if n := copy(b1[96-len(b):], b); n != len(b) { - panic(n) - } - b = b1 - } - return h.postWrite(b) + return h.postWrite(paddedLeft(y.Bytes(), 96)) } func (h *handshake) establishS() (err error) { @@ -155,9 +165,10 @@ func (h *handshake) establishS() (err error) { if err != nil { return } - var Y big.Int + var Y, S big.Int Y.SetBytes(b[:]) - h.s.Exp(&Y, &x, &p) + S.Exp(&Y, &x, &p) + util.CopyExact(&h.s, paddedLeft(S.Bytes(), 96)) return } @@ -174,8 +185,8 @@ func newPadLen() int64 { } type handshake struct { - conn io.ReadWriteCloser - s big.Int + conn io.ReadWriter + s [96]byte initer bool skeys [][]byte skey []byte @@ -192,11 +203,10 @@ type handshake struct { writerDone bool } -func (h *handshake) finishWriting() (err error) { +func (h *handshake) finishWriting() { h.writeMu.Lock() h.writeClose = true h.writeCond.Broadcast() - err = h.writeErr h.writeMu.Unlock() h.writerMu.Lock() @@ -205,7 +215,6 @@ func (h *handshake) finishWriting() (err error) { } h.writerMu.Unlock() return - } func (h *handshake) writer() { @@ -283,48 +292,6 @@ func unmarshal(r io.Reader, data ...interface{}) (err error) { return } -type cryptoNegotiation struct { - VC [8]byte - Method uint32 - PadLen uint16 - IA []byte -} - -func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) { - err = binary.Read(r, binary.BigEndian, me.VC[:]) - if err != nil { - return - } - err = binary.Read(r, binary.BigEndian, &me.Method) - if err != nil { - return - } - err = binary.Read(r, binary.BigEndian, &me.PadLen) - if err != nil { - return - } - _, err = io.CopyN(ioutil.Discard, r, int64(me.PadLen)) - return -} - -func (me *cryptoNegotiation) MarshalWriter(w io.Writer) (err error) { - // _, err = w.Write(me.VC[:]) - err = binary.Write(w, binary.BigEndian, me.VC[:]) - if err != nil { - return - } - err = binary.Write(w, binary.BigEndian, me.Method) - if err != nil { - return - } - err = binary.Write(w, binary.BigEndian, me.PadLen) - if err != nil { - return - } - _, err = w.Write(make([]byte, me.PadLen)) - return -} - // Looking for b at the end of a. func suffixMatchLen(a, b []byte) int { if len(b) > len(a) { @@ -369,48 +336,65 @@ type readWriter struct { io.Writer } +func (h *handshake) newEncrypt(initer bool) *rc4.Cipher { + return newEncrypt(initer, h.s[:], h.skey) +} + func (h *handshake) initerSteps() (ret io.ReadWriter, err error) { - h.postWrite(hash(req1, h.s.Bytes())) - h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes()))) + h.postWrite(hash(req1, h.s[:])) + h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:]))) buf := &bytes.Buffer{} - err = (&cryptoNegotiation{ - Method: cryptoMethodRC4, - PadLen: uint16(newPadLen()), - }).MarshalWriter(buf) - if err != nil { - return - } - err = marshal(buf, uint16(len(h.ia)), h.ia) + padLen := uint16(newPadLen()) + err = marshal(buf, vc[:], uint32(cryptoMethodRC4), padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia) if err != nil { return } - e := newEncrypt(true, h.s.Bytes(), h.skey) + e := h.newEncrypt(true) be := make([]byte, buf.Len()) e.XORKeyStream(be, buf.Bytes()) h.postWrite(be) - bC := newEncrypt(false, h.s.Bytes(), h.skey) + bC := h.newEncrypt(false) var eVC [8]byte - bC.XORKeyStream(eVC[:], make([]byte, 8)) - // Read until the all zero VC. - err = readUntil(h.conn, eVC[:]) + bC.XORKeyStream(eVC[:], vc[:]) + // Read until the all zero VC. At this point we've only read the 96 byte + // public key, Y. There is potentially 512 byte padding, between us and + // the 8 byte verification constant. + err = readUntil(io.LimitReader(h.conn, 520), eVC[:]) if err != nil { - err = fmt.Errorf("error reading until VC: %s", err) + if err == io.EOF { + err = errors.New("failed to synchronize on VC") + } else { + err = fmt.Errorf("error reading until VC: %s", err) + } return } - var cn cryptoNegotiation r := &cipherReader{bC, h.conn} - err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r)) + var method uint32 + err = unmarshal(r, &method, &padLen) + if err != nil { + return + } + if method != cryptoMethodRC4 { + err = fmt.Errorf("receiver chose unsupported method: %x", method) + return + } + _, err = io.CopyN(ioutil.Discard, r, int64(padLen)) if err != nil { - err = fmt.Errorf("error reading crypto negotiation: %s", err) return } ret = readWriter{r, &cipherWriter{e, h.conn}} return } +var ErrNoSecretKeyMatch = errors.New("no skey matched") + func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { - err = readUntil(h.conn, hash(req1, h.s.Bytes())) + // There is up to 512 bytes of padding, then the 20 byte hash. + err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:])) if err != nil { + if err == io.EOF { + err = errors.New("failed to synchronize on S hash") + } return } var b [20]byte @@ -418,9 +402,9 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { if err != nil { return } - err = errors.New("skey doesn't match") + err = ErrNoSecretKeyMatch for _, skey := range h.skeys { - if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s.Bytes())), b[:]) { + if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) { h.skey = skey err = nil break @@ -429,17 +413,26 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { if err != nil { return } - var cn cryptoNegotiation - r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn) - err = cn.UnmarshalReader(r) + r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn) + var ( + vc [8]byte + method uint32 + padLen uint16 + ) + + err = unmarshal(r, vc[:], &method, &padLen) if err != nil { return } - cryptoProvidesCount.Add(strconv.FormatUint(uint64(cn.Method), 16), 1) - if cn.Method&cryptoMethodRC4 == 0 { + cryptoProvidesCount.Add(strconv.FormatUint(uint64(method), 16), 1) + if method&cryptoMethodRC4 == 0 { err = errors.New("no supported crypto methods were provided") return } + _, err = io.CopyN(ioutil.Discard, r, int64(padLen)) + if err != nil { + return + } var lenIA uint16 unmarshal(r, &lenIA) if lenIA != 0 { @@ -447,11 +440,9 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { unmarshal(r, h.ia) } buf := &bytes.Buffer{} - w := cipherWriter{newEncrypt(false, h.s.Bytes(), h.skey), buf} - err = (&cryptoNegotiation{ - Method: cryptoMethodRC4, - PadLen: uint16(newPadLen()), - }).MarshalWriter(&w) + w := cipherWriter{h.newEncrypt(false), buf} + padLen = uint16(newPadLen()) + err = marshal(&w, &vc, uint32(cryptoMethodRC4), padLen, zeroPad[:padLen]) if err != nil { return } @@ -464,6 +455,15 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { } func (h *handshake) Do() (ret io.ReadWriter, err error) { + h.writeCond.L = &h.writeMu + h.writerCond.L = &h.writerMu + go h.writer() + defer func() { + h.finishWriting() + if err == nil { + err = h.writeErr + } + }() err = h.establishS() if err != nil { err = fmt.Errorf("error while establishing secret: %s", err) @@ -480,36 +480,23 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) { } else { ret, err = h.receiverSteps() } - if err != nil { - return - } - err = h.finishWriting() - if err != nil { - return - } return } -func InitiateHandshake(rw io.ReadWriteCloser, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) { +func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) { h := handshake{ conn: rw, initer: true, skey: skey, ia: initialPayload, } - h.writeCond.L = &h.writeMu - h.writerCond.L = &h.writerMu - go h.writer() return h.Do() } -func ReceiveHandshake(rw io.ReadWriteCloser, skeys [][]byte) (ret io.ReadWriter, err error) { +func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, err error) { h := handshake{ conn: rw, initer: false, skeys: skeys, } - h.writeCond.L = &h.writeMu - h.writerCond.L = &h.writerMu - go h.writer() return h.Do() }