]> Sergey Matveev's repositories - btrtrc.git/blobdiff - mse/mse.go
chore: remove refs to deprecated io/ioutil
[btrtrc.git] / mse / mse.go
index 8127b815052734f6600dca0cc7d87e611c0cbcb1..c3a9f3d3ae840c77eed1c37a881111194ae71c9f 100644 (file)
@@ -12,23 +12,24 @@ import (
        "expvar"
        "fmt"
        "io"
-       "io/ioutil"
        "math"
        "math/big"
        "strconv"
        "sync"
 
-       "github.com/bradfitz/iter"
+       "github.com/anacrolix/missinggo/perf"
 )
 
 const (
        maxPadLen = 512
 
-       cryptoMethodPlaintext = 1
-       cryptoMethodRC4       = 2
-       AllSupportedCrypto    = cryptoMethodPlaintext | cryptoMethodRC4
+       CryptoMethodPlaintext CryptoMethod = 1 // After header obfuscation, drop into plaintext
+       CryptoMethodRC4       CryptoMethod = 2 // After header obfuscation, use RC4 for the rest of the stream
+       AllSupportedCrypto                 = CryptoMethodPlaintext | CryptoMethodRC4
 )
 
+type CryptoMethod uint32
+
 var (
        // Prime P according to the spec, and G, the generator.
        p, g big.Int
@@ -67,7 +68,7 @@ func hash(parts ...[]byte) []byte {
        return h.Sum(nil)
 }
 
-func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
+func newEncrypt(initer bool, s, skey []byte) (c *rc4.Cipher) {
        c, err := rc4.NewCipher(hash([]byte(func() string {
                if initer {
                        return "keyA"
@@ -131,7 +132,7 @@ func (cr *cipherWriter) Write(b []byte) (n int, err error) {
                        return ret
                }
        }()
-       cr.c.XORKeyStream(be[:], b)
+       cr.c.XORKeyStream(be, b)
        n, err = cr.w.Write(be[:len(b)])
        if n != len(b) {
                // The cipher will have advanced beyond the callers stream position.
@@ -175,20 +176,20 @@ func (h *handshake) postY(x *big.Int) error {
        return h.postWrite(paddedLeft(y.Bytes(), 96))
 }
 
-func (h *handshake) establishS() (err error) {
+func (h *handshake) establishS() error {
        x := newX()
        h.postY(&x)
        var b [96]byte
-       _, err = io.ReadFull(h.conn, b[:])
+       _, err := io.ReadFull(h.conn, b[:])
        if err != nil {
-               return
+               return fmt.Errorf("error reading Y: %w", err)
        }
        var Y, S big.Int
        Y.SetBytes(b[:])
        S.Exp(&Y, &x, &p)
        sBytes := S.Bytes()
        copy(h.s[96-len(sBytes):96], sBytes)
-       return
+       return nil
 }
 
 func newPadLen() int64 {
@@ -212,9 +213,9 @@ type handshake struct {
        skey   []byte        // Skey we're initiating with.
        ia     []byte        // Initial payload. Only used by the initiator.
        // Return the bit for the crypto method the receiver wants to use.
-       chooseMethod func(supported uint32) uint32
+       chooseMethod CryptoSelector
        // Sent to the receiver.
-       cryptoProvides uint32
+       cryptoProvides CryptoMethod
 
        writeMu    sync.Mutex
        writes     [][]byte
@@ -238,7 +239,6 @@ func (h *handshake) finishWriting() {
                h.writerCond.Wait()
        }
        h.writerMu.Unlock()
-       return
 }
 
 func (h *handshake) writer() {
@@ -284,18 +284,22 @@ func (h *handshake) postWrite(b []byte) error {
        return nil
 }
 
-func xor(dst, src []byte) (ret []byte) {
-       max := len(dst)
-       if max > len(src) {
-               max = len(src)
-       }
-       ret = make([]byte, 0, max)
-       for i := range iter.N(max) {
-               ret = append(ret, dst[i]^src[i])
+func xor(a, b []byte) (ret []byte) {
+       max := len(a)
+       if max > len(b) {
+               max = len(b)
        }
+       ret = make([]byte, max)
+       xorInPlace(ret, a, b)
        return
 }
 
+func xorInPlace(dst, a, b []byte) {
+       for i := range dst {
+               dst[i] = a[i] ^ b[i]
+       }
+}
+
 func marshal(w io.Writer, data ...interface{}) (err error) {
        for _, data := range data {
                err = binary.Write(w, binary.BigEndian, data)
@@ -366,7 +370,7 @@ func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
        return newEncrypt(initer, h.s[:], h.skey)
 }
 
-func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
+func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
        h.postWrite(hash(req1, h.s[:]))
        h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
        buf := &bytes.Buffer{}
@@ -399,19 +403,20 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
                return
        }
        r := newCipherReader(bC, h.conn)
-       var method uint32
+       var method CryptoMethod
        err = unmarshal(r, &method, &padLen)
        if err != nil {
                return
        }
-       _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
+       _, err = io.CopyN(io.Discard, r, int64(padLen))
        if err != nil {
                return
        }
-       switch method & h.cryptoProvides {
-       case cryptoMethodRC4:
+       selected = method & h.cryptoProvides
+       switch selected {
+       case CryptoMethodRC4:
                ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
-       case cryptoMethodPlaintext:
+       case CryptoMethodPlaintext:
                ret = h.conn
        default:
                err = fmt.Errorf("receiver chose unsupported method: %x", method)
@@ -421,7 +426,7 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
 
 var ErrNoSecretKeyMatch = errors.New("no skey matched")
 
-func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
+func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
        // There is up to 512 bytes of padding, then the 20 byte hash.
        err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
        if err != nil {
@@ -435,9 +440,17 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
        if err != nil {
                return
        }
+       expectedHash := hash(req3, h.s[:])
+       eachHash := sha1.New()
+       var sum, xored [sha1.Size]byte
        err = ErrNoSecretKeyMatch
        h.skeys(func(skey []byte) bool {
-               if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) {
+               eachHash.Reset()
+               eachHash.Write(req2)
+               eachHash.Write(skey)
+               eachHash.Sum(sum[:0])
+               xorInPlace(xored[:], sum[:], expectedHash)
+               if bytes.Equal(xored[:], b[:]) {
                        h.skey = skey
                        err = nil
                        return false
@@ -450,7 +463,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
        r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
        var (
                vc       [8]byte
-               provides uint32
+               provides CryptoMethod
                padLen   uint16
        )
 
@@ -459,8 +472,8 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
                return
        }
        cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
-       chosen := h.chooseMethod(provides)
-       _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
+       chosen = h.chooseMethod(provides)
+       _, err = io.CopyN(io.Discard, r, int64(padLen))
        if err != nil {
                return
        }
@@ -482,12 +495,12 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
                return
        }
        switch chosen {
-       case cryptoMethodRC4:
+       case CryptoMethodRC4:
                ret = readWriter{
                        io.MultiReader(bytes.NewReader(h.ia), r),
                        &cipherWriter{w.c, h.conn, nil},
                }
-       case cryptoMethodPlaintext:
+       case CryptoMethodPlaintext:
                ret = readWriter{
                        io.MultiReader(bytes.NewReader(h.ia), h.conn),
                        h.conn,
@@ -498,7 +511,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
        return
 }
 
-func (h *handshake) Do() (ret io.ReadWriter, err error) {
+func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
        h.writeCond.L = &h.writeMu
        h.writerCond.L = &h.writerMu
        go h.writer()
@@ -510,7 +523,7 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
        }()
        err = h.establishS()
        if err != nil {
-               err = fmt.Errorf("error while establishing secret: %s", err)
+               err = fmt.Errorf("error while establishing secret: %w", err)
                return
        }
        pad := make([]byte, newPadLen())
@@ -520,14 +533,18 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
                return
        }
        if h.initer {
-               ret, err = h.initerSteps()
+               ret, method, err = h.initerSteps()
        } else {
-               ret, err = h.receiverSteps()
+               ret, method, err = h.receiverSteps()
        }
        return
 }
 
-func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides uint32) (ret io.ReadWriter, err error) {
+func InitiateHandshake(
+       rw io.ReadWriter, skey, initialPayload []byte, cryptoProvides CryptoMethod,
+) (
+       ret io.ReadWriter, method CryptoMethod, err error,
+) {
        h := handshake{
                conn:           rw,
                initer:         true,
@@ -535,49 +552,44 @@ func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cry
                ia:             initialPayload,
                cryptoProvides: cryptoProvides,
        }
+       defer perf.ScopeTimerErr(&err)()
        return h.Do()
 }
 
-func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte, selectCrypto func(uint32) uint32) (ret io.ReadWriter, err error) {
+type HandshakeResult struct {
+       io.ReadWriter
+       CryptoMethod
+       error
+       SecretKey []byte
+}
+
+func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) {
+       res := ReceiveHandshakeEx(rw, skeys, selectCrypto)
+       return res.ReadWriter, res.CryptoMethod, res.error
+}
+
+func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret HandshakeResult) {
        h := handshake{
                conn:         rw,
                initer:       false,
-               skeys:        sliceIter(skeys),
+               skeys:        skeys,
                chooseMethod: selectCrypto,
        }
-       return h.Do()
-}
-
-func sliceIter(skeys [][]byte) SecretKeyIter {
-       return func(callback func([]byte) bool) {
-               for _, sk := range skeys {
-                       if !callback(sk) {
-                               break
-                       }
-               }
-       }
+       ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do()
+       ret.SecretKey = h.skey
+       return
 }
 
 // A function that given a function, calls it with secret keys until it
 // returns false or exhausted.
 type SecretKeyIter func(callback func(skey []byte) (more bool))
 
-// Doesn't unpack the secret keys until it needs to, and through the passed
-// function.
-func ReceiveHandshakeLazy(rw io.ReadWriter, skeys SecretKeyIter) (ret io.ReadWriter, err error) {
-       h := handshake{
-               conn:   rw,
-               initer: false,
-               skeys:  skeys,
-       }
-       return h.Do()
-}
-
-func DefaultCryptoSelector(provided uint32) uint32 {
-       if provided&cryptoMethodRC4 != 0 {
-               return cryptoMethodRC4
+func DefaultCryptoSelector(provided CryptoMethod) CryptoMethod {
+       // We prefer plaintext for performance reasons.
+       if provided&CryptoMethodPlaintext != 0 {
+               return CryptoMethodPlaintext
        }
-       return cryptoMethodPlaintext
+       return CryptoMethodRC4
 }
 
-type CryptoSelector func(uint32) uint32
+type CryptoSelector func(CryptoMethod) CryptoMethod