]> Sergey Matveev's repositories - btrtrc.git/blobdiff - mse/mse.go
chore: remove refs to deprecated io/ioutil
[btrtrc.git] / mse / mse.go
index 85d55a7d0a50e5f35d7635c0536e47f97a3b18ae..c3a9f3d3ae840c77eed1c37a881111194ae71c9f 100644 (file)
@@ -12,14 +12,12 @@ import (
        "expvar"
        "fmt"
        "io"
-       "io/ioutil"
        "math"
        "math/big"
        "strconv"
        "sync"
 
        "github.com/anacrolix/missinggo/perf"
-       "github.com/bradfitz/iter"
 )
 
 const (
@@ -70,7 +68,7 @@ func hash(parts ...[]byte) []byte {
        return h.Sum(nil)
 }
 
-func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
+func newEncrypt(initer bool, s, skey []byte) (c *rc4.Cipher) {
        c, err := rc4.NewCipher(hash([]byte(func() string {
                if initer {
                        return "keyA"
@@ -134,7 +132,7 @@ func (cr *cipherWriter) Write(b []byte) (n int, err error) {
                        return ret
                }
        }()
-       cr.c.XORKeyStream(be[:], b)
+       cr.c.XORKeyStream(be, b)
        n, err = cr.w.Write(be[:len(b)])
        if n != len(b) {
                // The cipher will have advanced beyond the callers stream position.
@@ -184,7 +182,7 @@ func (h *handshake) establishS() error {
        var b [96]byte
        _, err := io.ReadFull(h.conn, b[:])
        if err != nil {
-               return fmt.Errorf("error reading Y: %s", err)
+               return fmt.Errorf("error reading Y: %w", err)
        }
        var Y, S big.Int
        Y.SetBytes(b[:])
@@ -286,18 +284,22 @@ func (h *handshake) postWrite(b []byte) error {
        return nil
 }
 
-func xor(dst, src []byte) (ret []byte) {
-       max := len(dst)
-       if max > len(src) {
-               max = len(src)
-       }
-       ret = make([]byte, 0, max)
-       for i := range iter.N(max) {
-               ret = append(ret, dst[i]^src[i])
+func xor(a, b []byte) (ret []byte) {
+       max := len(a)
+       if max > len(b) {
+               max = len(b)
        }
+       ret = make([]byte, max)
+       xorInPlace(ret, a, b)
        return
 }
 
+func xorInPlace(dst, a, b []byte) {
+       for i := range dst {
+               dst[i] = a[i] ^ b[i]
+       }
+}
+
 func marshal(w io.Writer, data ...interface{}) (err error) {
        for _, data := range data {
                err = binary.Write(w, binary.BigEndian, data)
@@ -406,7 +408,7 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err
        if err != nil {
                return
        }
-       _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
+       _, err = io.CopyN(io.Discard, r, int64(padLen))
        if err != nil {
                return
        }
@@ -438,9 +440,17 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err
        if err != nil {
                return
        }
+       expectedHash := hash(req3, h.s[:])
+       eachHash := sha1.New()
+       var sum, xored [sha1.Size]byte
        err = ErrNoSecretKeyMatch
        h.skeys(func(skey []byte) bool {
-               if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) {
+               eachHash.Reset()
+               eachHash.Write(req2)
+               eachHash.Write(skey)
+               eachHash.Sum(sum[:0])
+               xorInPlace(xored[:], sum[:], expectedHash)
+               if bytes.Equal(xored[:], b[:]) {
                        h.skey = skey
                        err = nil
                        return false
@@ -463,7 +473,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err
        }
        cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
        chosen = h.chooseMethod(provides)
-       _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
+       _, err = io.CopyN(io.Discard, r, int64(padLen))
        if err != nil {
                return
        }
@@ -513,7 +523,7 @@ func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
        }()
        err = h.establishS()
        if err != nil {
-               err = fmt.Errorf("error while establishing secret: %s", err)
+               err = fmt.Errorf("error while establishing secret: %w", err)
                return
        }
        pad := make([]byte, newPadLen())
@@ -530,7 +540,11 @@ func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
        return
 }
 
-func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, method CryptoMethod, err error) {
+func InitiateHandshake(
+       rw io.ReadWriter, skey, initialPayload []byte, cryptoProvides CryptoMethod,
+) (
+       ret io.ReadWriter, method CryptoMethod, err error,
+) {
        h := handshake{
                conn:           rw,
                initer:         true,