]> Sergey Matveev's repositories - btrtrc.git/blobdiff - mse/mse.go
chore: remove refs to deprecated io/ioutil
[btrtrc.git] / mse / mse.go
index 1d53ad21ebb201b71b590b2b04e6640f4d6ef200..c3a9f3d3ae840c77eed1c37a881111194ae71c9f 100644 (file)
@@ -9,23 +9,27 @@ import (
        "crypto/sha1"
        "encoding/binary"
        "errors"
+       "expvar"
        "fmt"
        "io"
-       "io/ioutil"
-       "log"
+       "math"
        "math/big"
+       "strconv"
        "sync"
 
-       "github.com/bradfitz/iter"
+       "github.com/anacrolix/missinggo/perf"
 )
 
 const (
        maxPadLen = 512
 
-       cryptoMethodPlaintext = 1
-       cryptoMethodRC4       = 2
+       CryptoMethodPlaintext CryptoMethod = 1 // After header obfuscation, drop into plaintext
+       CryptoMethodRC4       CryptoMethod = 2 // After header obfuscation, use RC4 for the rest of the stream
+       AllSupportedCrypto                 = CryptoMethodPlaintext | CryptoMethodRC4
 )
 
+type CryptoMethod uint32
+
 var (
        // Prime P according to the spec, and G, the generator.
        p, g big.Int
@@ -35,6 +39,13 @@ var (
        req1 = []byte("req1")
        req2 = []byte("req2")
        req3 = []byte("req3")
+       // Verification constant "VC" which is all zeroes in the bittorrent
+       // implementation.
+       vc [8]byte
+       // Zero padding
+       zeroPad [512]byte
+       // Tracks counts of received crypto_provides
+       cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
 )
 
 func init() {
@@ -57,7 +68,7 @@ func hash(parts ...[]byte) []byte {
        return h.Sum(nil)
 }
 
-func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
+func newEncrypt(initer bool, s, skey []byte) (c *rc4.Cipher) {
        c, err := rc4.NewCipher(hash([]byte(func() string {
                if initer {
                        return "keyA"
@@ -74,49 +85,63 @@ func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
 }
 
 type cipherReader struct {
-       c *rc4.Cipher
-       r io.Reader
+       c  *rc4.Cipher
+       r  io.Reader
+       mu sync.Mutex
+       be []byte
 }
 
-func (me *cipherReader) Read(b []byte) (n int, err error) {
-       be := make([]byte, len(b))
-       n, err = me.r.Read(be)
-       me.c.XORKeyStream(b[:n], be[:n])
+func (cr *cipherReader) Read(b []byte) (n int, err error) {
+       var be []byte
+       cr.mu.Lock()
+       if len(cr.be) >= len(b) {
+               be = cr.be
+               cr.be = nil
+               cr.mu.Unlock()
+       } else {
+               cr.mu.Unlock()
+               be = make([]byte, len(b))
+       }
+       n, err = cr.r.Read(be[:len(b)])
+       cr.c.XORKeyStream(b[:n], be[:n])
+       cr.mu.Lock()
+       if len(be) > len(cr.be) {
+               cr.be = be
+       }
+       cr.mu.Unlock()
        return
 }
 
 func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
-       return &cipherReader{c, r}
+       return &cipherReader{c: c, r: r}
 }
 
 type cipherWriter struct {
        c *rc4.Cipher
        w io.Writer
+       b []byte
 }
 
-func (me *cipherWriter) Write(b []byte) (n int, err error) {
-       be := make([]byte, len(b))
-       me.c.XORKeyStream(be, b)
-       n, err = me.w.Write(be)
-       if n != len(be) {
+func (cr *cipherWriter) Write(b []byte) (n int, err error) {
+       be := func() []byte {
+               if len(cr.b) < len(b) {
+                       return make([]byte, len(b))
+               } else {
+                       ret := cr.b
+                       cr.b = nil
+                       return ret
+               }
+       }()
+       cr.c.XORKeyStream(be, b)
+       n, err = cr.w.Write(be[:len(b)])
+       if n != len(b) {
                // The cipher will have advanced beyond the callers stream position.
                // We can't use the cipher anymore.
-               me.c = nil
+               cr.c = nil
        }
-       return
-}
-
-func newCipherWriter(c *rc4.Cipher, w io.Writer) io.Writer {
-       return &cipherWriter{c, w}
-}
-
-func readY(r io.Reader) (y big.Int, err error) {
-       var b [96]byte
-       _, err = io.ReadFull(r, b[:])
-       if err != nil {
-               return
+       if len(be) > len(cr.b) {
+               cr.b = be
        }
-       y.SetBytes(b[:])
        return
 }
 
@@ -133,28 +158,38 @@ func newX() big.Int {
        return X
 }
 
+func paddedLeft(b []byte, _len int) []byte {
+       if len(b) == _len {
+               return b
+       }
+       ret := make([]byte, _len)
+       if n := copy(ret[_len-len(b):], b); n != len(b) {
+               panic(n)
+       }
+       return ret
+}
+
+// Calculate, and send Y, our public key.
 func (h *handshake) postY(x *big.Int) error {
        var y big.Int
        y.Exp(&g, x, &p)
-       b := y.Bytes()
-       if len(b) != 96 {
-               panic(len(b))
-       }
-       return h.postWrite(b)
+       return h.postWrite(paddedLeft(y.Bytes(), 96))
 }
 
-func (h *handshake) establishS() (err error) {
+func (h *handshake) establishS() error {
        x := newX()
        h.postY(&x)
        var b [96]byte
-       _, err = io.ReadFull(h.conn, b[:])
+       _, err := io.ReadFull(h.conn, b[:])
        if err != nil {
-               return
+               return fmt.Errorf("error reading Y: %w", err)
        }
-       var Y big.Int
+       var Y, S big.Int
        Y.SetBytes(b[:])
-       h.s.Exp(&Y, &x, &p)
-       return
+       S.Exp(&Y, &x, &p)
+       sBytes := S.Bytes()
+       copy(h.s[96-len(sBytes):96], sBytes)
+       return nil
 }
 
 func newPadLen() int64 {
@@ -169,11 +204,18 @@ func newPadLen() int64 {
        return ret
 }
 
+// Manages state for both initiating and receiving handshakes.
 type handshake struct {
-       conn   io.ReadWriteCloser
-       s      big.Int
-       initer bool
-       skey   []byte
+       conn   io.ReadWriter
+       s      [96]byte
+       initer bool          // Whether we're initiating or receiving.
+       skeys  SecretKeyIter // Skeys we'll accept if receiving.
+       skey   []byte        // Skey we're initiating with.
+       ia     []byte        // Initial payload. Only used by the initiator.
+       // Return the bit for the crypto method the receiver wants to use.
+       chooseMethod CryptoSelector
+       // Sent to the receiver.
+       cryptoProvides CryptoMethod
 
        writeMu    sync.Mutex
        writes     [][]byte
@@ -186,11 +228,10 @@ type handshake struct {
        writerDone bool
 }
 
-func (h *handshake) finishWriting() (err error) {
+func (h *handshake) finishWriting() {
        h.writeMu.Lock()
        h.writeClose = true
        h.writeCond.Broadcast()
-       err = h.writeErr
        h.writeMu.Unlock()
 
        h.writerMu.Lock()
@@ -198,8 +239,6 @@ func (h *handshake) finishWriting() (err error) {
                h.writerCond.Wait()
        }
        h.writerMu.Unlock()
-       return
-
 }
 
 func (h *handshake) writer() {
@@ -245,57 +284,39 @@ func (h *handshake) postWrite(b []byte) error {
        return nil
 }
 
-func xor(dst, src []byte) (ret []byte) {
-       max := len(dst)
-       if max > len(src) {
-               max = len(src)
-       }
-       ret = make([]byte, 0, max)
-       for i := range iter.N(max) {
-               ret = append(ret, dst[i]^src[i])
+func xor(a, b []byte) (ret []byte) {
+       max := len(a)
+       if max > len(b) {
+               max = len(b)
        }
+       ret = make([]byte, max)
+       xorInPlace(ret, a, b)
        return
 }
 
-type cryptoNegotiation struct {
-       VC     [8]byte
-       Method uint32
-       PadLen uint16
-       IA     []byte
+func xorInPlace(dst, a, b []byte) {
+       for i := range dst {
+               dst[i] = a[i] ^ b[i]
+       }
 }
 
-func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
-       _, err = io.ReadFull(r, me.VC[:])
-       if err != nil {
-               return
-       }
-       err = binary.Read(r, binary.BigEndian, &me.Method)
-       if err != nil {
-               return
-       }
-       err = binary.Read(r, binary.BigEndian, &me.PadLen)
-       if err != nil {
-               return
+func marshal(w io.Writer, data ...interface{}) (err error) {
+       for _, data := range data {
+               err = binary.Write(w, binary.BigEndian, data)
+               if err != nil {
+                       break
+               }
        }
-       log.Print(me.PadLen)
-       _, err = io.CopyN(ioutil.Discard, r, int64(me.PadLen))
        return
 }
 
-func (me *cryptoNegotiation) MarshalWriter(w io.Writer) (err error) {
-       _, err = w.Write(me.VC[:])
-       if err != nil {
-               return
-       }
-       err = binary.Write(w, binary.BigEndian, me.Method)
-       if err != nil {
-               return
-       }
-       err = binary.Write(w, binary.BigEndian, me.PadLen)
-       if err != nil {
-               return
+func unmarshal(r io.Reader, data ...interface{}) (err error) {
+       for _, data := range data {
+               err = binary.Read(r, binary.BigEndian, data)
+               if err != nil {
+                       break
+               }
        }
-       _, err = w.Write(make([]byte, me.PadLen))
        return
 }
 
@@ -319,8 +340,9 @@ func suffixMatchLen(a, b []byte) int {
        return 0
 }
 
+// Reads from r until b has been seen. Keeps the minimum amount of data in
+// memory.
 func readUntil(r io.Reader, b []byte) error {
-       log.Println("read until", b)
        b1 := make([]byte, len(b))
        i := 0
        for {
@@ -339,106 +361,235 @@ func readUntil(r io.Reader, b []byte) error {
        return nil
 }
 
-func (h *handshake) Do() (ret io.ReadWriteCloser, err error) {
-       err = h.establishS()
-       if err != nil {
+type readWriter struct {
+       io.Reader
+       io.Writer
+}
+
+func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
+       return newEncrypt(initer, h.s[:], h.skey)
+}
+
+func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
+       h.postWrite(hash(req1, h.s[:]))
+       h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
+       buf := &bytes.Buffer{}
+       padLen := uint16(newPadLen())
+       if len(h.ia) > math.MaxUint16 {
+               err = errors.New("initial payload too large")
                return
        }
-       pad := make([]byte, newPadLen())
-       io.ReadFull(rand.Reader, pad)
-       err = h.postWrite(pad)
+       err = marshal(buf, vc[:], h.cryptoProvides, padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
        if err != nil {
                return
        }
-       if h.initer {
-               h.postWrite(hash(req1, h.s.Bytes()))
-               h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())))
-               buf := &bytes.Buffer{}
-               err = (&cryptoNegotiation{
-                       Method: cryptoMethodRC4,
-                       PadLen: uint16(newPadLen()),
-               }).MarshalWriter(buf)
-               if err != nil {
-                       return
-               }
-               e := newEncrypt(true, h.s.Bytes(), h.skey)
-               be := make([]byte, buf.Len())
-               e.XORKeyStream(be, buf.Bytes())
-               h.postWrite(be)
-               bC := newEncrypt(false, h.s.Bytes(), h.skey)
-               var eVC [8]byte
-               bC.XORKeyStream(eVC[:], make([]byte, 8))
-               log.Print(eVC)
-               // Read until the all zero VC.
-               err = readUntil(h.conn, eVC[:])
-               if err != nil {
+       e := h.newEncrypt(true)
+       be := make([]byte, buf.Len())
+       e.XORKeyStream(be, buf.Bytes())
+       h.postWrite(be)
+       bC := h.newEncrypt(false)
+       var eVC [8]byte
+       bC.XORKeyStream(eVC[:], vc[:])
+       // Read until the all zero VC. At this point we've only read the 96 byte
+       // public key, Y. There is potentially 512 byte padding, between us and
+       // the 8 byte verification constant.
+       err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
+       if err != nil {
+               if err == io.EOF {
+                       err = errors.New("failed to synchronize on VC")
+               } else {
                        err = fmt.Errorf("error reading until VC: %s", err)
-                       return
-               }
-               var cn cryptoNegotiation
-               r := &cipherReader{bC, h.conn}
-               err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
-               log.Printf("initer got %v", cn)
-               if err != nil {
-                       err = fmt.Errorf("error reading crypto negotiation: %s", err)
-                       return
-               }
-       } else {
-               err = readUntil(h.conn, hash(req1, h.s.Bytes()))
-               if err != nil {
-                       return
-               }
-               var b [20]byte
-               _, err = io.ReadFull(h.conn, b[:])
-               if err != nil {
-                       return
                }
-               if !bytes.Equal(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())), b[:]) {
-                       err = errors.New("skey doesn't match")
-                       return
+               return
+       }
+       r := newCipherReader(bC, h.conn)
+       var method CryptoMethod
+       err = unmarshal(r, &method, &padLen)
+       if err != nil {
+               return
+       }
+       _, err = io.CopyN(io.Discard, r, int64(padLen))
+       if err != nil {
+               return
+       }
+       selected = method & h.cryptoProvides
+       switch selected {
+       case CryptoMethodRC4:
+               ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
+       case CryptoMethodPlaintext:
+               ret = h.conn
+       default:
+               err = fmt.Errorf("receiver chose unsupported method: %x", method)
+       }
+       return
+}
+
+var ErrNoSecretKeyMatch = errors.New("no skey matched")
+
+func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
+       // There is up to 512 bytes of padding, then the 20 byte hash.
+       err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
+       if err != nil {
+               if err == io.EOF {
+                       err = errors.New("failed to synchronize on S hash")
                }
-               var cn cryptoNegotiation
-               r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn)
-               err = cn.UnmarshalReader(r)
-               if err != nil {
-                       return
+               return
+       }
+       var b [20]byte
+       _, err = io.ReadFull(h.conn, b[:])
+       if err != nil {
+               return
+       }
+       expectedHash := hash(req3, h.s[:])
+       eachHash := sha1.New()
+       var sum, xored [sha1.Size]byte
+       err = ErrNoSecretKeyMatch
+       h.skeys(func(skey []byte) bool {
+               eachHash.Reset()
+               eachHash.Write(req2)
+               eachHash.Write(skey)
+               eachHash.Sum(sum[:0])
+               xorInPlace(xored[:], sum[:], expectedHash)
+               if bytes.Equal(xored[:], b[:]) {
+                       h.skey = skey
+                       err = nil
+                       return false
                }
-               log.Printf("receiver got %v", cn)
-               if cn.Method&cryptoMethodRC4 == 0 {
-                       err = errors.New("no supported crypto methods were provided")
-                       return
+               return true
+       })
+       if err != nil {
+               return
+       }
+       r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
+       var (
+               vc       [8]byte
+               provides CryptoMethod
+               padLen   uint16
+       )
+
+       err = unmarshal(r, vc[:], &provides, &padLen)
+       if err != nil {
+               return
+       }
+       cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
+       chosen = h.chooseMethod(provides)
+       _, err = io.CopyN(io.Discard, r, int64(padLen))
+       if err != nil {
+               return
+       }
+       var lenIA uint16
+       unmarshal(r, &lenIA)
+       if lenIA != 0 {
+               h.ia = make([]byte, lenIA)
+               unmarshal(r, h.ia)
+       }
+       buf := &bytes.Buffer{}
+       w := cipherWriter{h.newEncrypt(false), buf, nil}
+       padLen = uint16(newPadLen())
+       err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen])
+       if err != nil {
+               return
+       }
+       err = h.postWrite(buf.Bytes())
+       if err != nil {
+               return
+       }
+       switch chosen {
+       case CryptoMethodRC4:
+               ret = readWriter{
+                       io.MultiReader(bytes.NewReader(h.ia), r),
+                       &cipherWriter{w.c, h.conn, nil},
                }
-               buf := &bytes.Buffer{}
-               w := newCipherWriter(newEncrypt(false, h.s.Bytes(), h.skey), buf)
-               err = (&cryptoNegotiation{
-                       Method: cryptoMethodRC4,
-                       PadLen: uint16(newPadLen()),
-               }).MarshalWriter(w)
-               if err != nil {
-                       return
+       case CryptoMethodPlaintext:
+               ret = readWriter{
+                       io.MultiReader(bytes.NewReader(h.ia), h.conn),
+                       h.conn,
                }
-               log.Println("encrypted VC", buf.Bytes()[:8])
-               err = h.postWrite(buf.Bytes())
-               if err != nil {
-                       return
+       default:
+               err = errors.New("chosen crypto method is not supported")
+       }
+       return
+}
+
+func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
+       h.writeCond.L = &h.writeMu
+       h.writerCond.L = &h.writerMu
+       go h.writer()
+       defer func() {
+               h.finishWriting()
+               if err == nil {
+                       err = h.writeErr
                }
+       }()
+       err = h.establishS()
+       if err != nil {
+               err = fmt.Errorf("error while establishing secret: %w", err)
+               return
        }
-       err = h.finishWriting()
+       pad := make([]byte, newPadLen())
+       io.ReadFull(rand.Reader, pad)
+       err = h.postWrite(pad)
        if err != nil {
                return
        }
-       ret = h.conn
+       if h.initer {
+               ret, method, err = h.initerSteps()
+       } else {
+               ret, method, err = h.receiverSteps()
+       }
        return
 }
 
-func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWriteCloser, err error) {
+func InitiateHandshake(
+       rw io.ReadWriter, skey, initialPayload []byte, cryptoProvides CryptoMethod,
+) (
+       ret io.ReadWriter, method CryptoMethod, err error,
+) {
        h := handshake{
-               conn:   rw,
-               initer: initer,
-               skey:   skey,
+               conn:           rw,
+               initer:         true,
+               skey:           skey,
+               ia:             initialPayload,
+               cryptoProvides: cryptoProvides,
        }
-       h.writeCond.L = &h.writeMu
-       h.writerCond.L = &h.writerMu
-       go h.writer()
+       defer perf.ScopeTimerErr(&err)()
        return h.Do()
 }
+
+type HandshakeResult struct {
+       io.ReadWriter
+       CryptoMethod
+       error
+       SecretKey []byte
+}
+
+func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) {
+       res := ReceiveHandshakeEx(rw, skeys, selectCrypto)
+       return res.ReadWriter, res.CryptoMethod, res.error
+}
+
+func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret HandshakeResult) {
+       h := handshake{
+               conn:         rw,
+               initer:       false,
+               skeys:        skeys,
+               chooseMethod: selectCrypto,
+       }
+       ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do()
+       ret.SecretKey = h.skey
+       return
+}
+
+// A function that given a function, calls it with secret keys until it
+// returns false or exhausted.
+type SecretKeyIter func(callback func(skey []byte) (more bool))
+
+func DefaultCryptoSelector(provided CryptoMethod) CryptoMethod {
+       // We prefer plaintext for performance reasons.
+       if provided&CryptoMethodPlaintext != 0 {
+               return CryptoMethodPlaintext
+       }
+       return CryptoMethodRC4
+}
+
+type CryptoSelector func(CryptoMethod) CryptoMethod