]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Set the connection.cryptoMethod
authorMatt Joiner <anacrolix@gmail.com>
Thu, 15 Feb 2018 23:59:56 +0000 (10:59 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Thu, 15 Feb 2018 23:59:56 +0000 (10:59 +1100)
It was unwittingly dropped from received connections, and may never have been set for initiated connections.

client.go
handshake.go
mse/mse.go
mse/mse_test.go

index 18ff016c7848bb98fd6f9fafa65222e88dac6273..a05a754d0d2a8e9008e9e14fffb9299bb974363f 100644 (file)
--- a/client.go
+++ b/client.go
@@ -748,7 +748,7 @@ func (cl *Client) incomingPeerPort() int {
 func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err error) {
        if c.headerEncrypted {
                var rw io.ReadWriter
-               rw, err = mse.InitiateHandshake(
+               rw, c.cryptoMethod, err = mse.InitiateHandshake(
                        struct {
                                io.Reader
                                io.Writer
index b950a9b19a02064e46cff53d082af8a925c013dc..c3a7cc5a1fe173586af500f5daad03e943e829d2 100644 (file)
@@ -187,7 +187,7 @@ func handleEncryption(
                }
        }
        headerEncrypted = true
-       ret, err = mse.ReceiveHandshake(rw, skeys, func(provides mse.CryptoMethod) mse.CryptoMethod {
+       ret, cryptoMethod, err = mse.ReceiveHandshake(rw, skeys, func(provides mse.CryptoMethod) mse.CryptoMethod {
                switch {
                case policy.ForceEncryption:
                        return mse.CryptoMethodRC4
index d843d8f93f12b89f87e492d38caef03a4d9d0e64..83454b82e510cb22a03245330393eb94f2c9ed09 100644 (file)
@@ -367,7 +367,7 @@ func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
        return newEncrypt(initer, h.s[:], h.skey)
 }
 
-func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
+func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
        h.postWrite(hash(req1, h.s[:]))
        h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
        buf := &bytes.Buffer{}
@@ -409,7 +409,8 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
        if err != nil {
                return
        }
-       switch method & h.cryptoProvides {
+       selected = method & h.cryptoProvides
+       switch selected {
        case CryptoMethodRC4:
                ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
        case CryptoMethodPlaintext:
@@ -422,7 +423,7 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
 
 var ErrNoSecretKeyMatch = errors.New("no skey matched")
 
-func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
+func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
        // There is up to 512 bytes of padding, then the 20 byte hash.
        err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
        if err != nil {
@@ -460,7 +461,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
                return
        }
        cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
-       chosen := h.chooseMethod(provides)
+       chosen = h.chooseMethod(provides)
        _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
        if err != nil {
                return
@@ -499,7 +500,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
        return
 }
 
-func (h *handshake) Do() (ret io.ReadWriter, err error) {
+func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
        h.writeCond.L = &h.writeMu
        h.writerCond.L = &h.writerMu
        go h.writer()
@@ -521,14 +522,14 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
                return
        }
        if h.initer {
-               ret, err = h.initerSteps()
+               ret, method, err = h.initerSteps()
        } else {
-               ret, err = h.receiverSteps()
+               ret, method, err = h.receiverSteps()
        }
        return
 }
 
-func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, err error) {
+func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, method CryptoMethod, err error) {
        h := handshake{
                conn:           rw,
                initer:         true,
@@ -539,7 +540,7 @@ func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cry
        return h.Do()
 }
 
-func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, err error) {
+func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, method CryptoMethod, err error) {
        h := handshake{
                conn:         rw,
                initer:       false,
index 519135043a17b7993f88c8cf49d4502c4ff4d6cd..3efaf98d7f59be189f81a9e34c65ad4a651a22ad 100644 (file)
@@ -12,6 +12,7 @@ import (
 
        _ "github.com/anacrolix/envpprof"
        "github.com/bradfitz/iter"
+       "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/require"
 )
 
@@ -64,11 +65,9 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides
        wg.Add(2)
        go func() {
                defer wg.Done()
-               a, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
-               if err != nil {
-                       t.Fatal(err)
-                       return
-               }
+               a, cm, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
+               require.NoError(t, err)
+               assert.Equal(t, cryptoSelect(cryptoProvides), cm)
                go a.Write([]byte(aData))
 
                var msg [20]byte
@@ -80,11 +79,9 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides
        }()
        go func() {
                defer wg.Done()
-               b, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
-               if err != nil {
-                       t.Fatal(err)
-                       return
-               }
+               b, cm, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
+               require.NoError(t, err)
+               assert.Equal(t, cryptoSelect(cryptoProvides), cm)
                go b.Write([]byte(bData))
                // Need to be exact here, as there are several reads, and net.Pipe is
                // most synchronous.
@@ -134,7 +131,7 @@ func (tr *trackReader) Read(b []byte) (n int, err error) {
 
 func TestReceiveRandomData(t *testing.T) {
        tr := trackReader{rand.Reader, 0}
-       _, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
+       _, _, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
        // No skey matches
        require.Error(t, err)
        // Establishing S, and then reading the maximum padding for giving up on
@@ -183,13 +180,13 @@ func benchmarkStream(t *testing.B, crypto CryptoMethod) {
                go func() {
                        defer ac.Close()
                        defer wg.Done()
-                       rw, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
+                       rw, _, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
                        require.NoError(t, err)
                        require.NoError(t, readAndWrite(rw, ar, a))
                }()
                func() {
                        defer bc.Close()
-                       rw, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto })
+                       rw, _, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto })
                        require.NoError(t, err)
                        require.NoError(t, readAndWrite(rw, br, b))
                }()