From: Matt Joiner Date: Thu, 15 Feb 2018 23:59:56 +0000 (+1100) Subject: Set the connection.cryptoMethod X-Git-Tag: v1.0.0~167 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=fc03dcb859f426053431c993b59a2b345e3a0bf0;p=btrtrc.git Set the connection.cryptoMethod It was unwittingly dropped from received connections, and may never have been set for initiated connections. --- diff --git a/client.go b/client.go index 18ff016c..a05a754d 100644 --- a/client.go +++ b/client.go @@ -748,7 +748,7 @@ func (cl *Client) incomingPeerPort() int { func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err error) { if c.headerEncrypted { var rw io.ReadWriter - rw, err = mse.InitiateHandshake( + rw, c.cryptoMethod, err = mse.InitiateHandshake( struct { io.Reader io.Writer diff --git a/handshake.go b/handshake.go index b950a9b1..c3a7cc5a 100644 --- a/handshake.go +++ b/handshake.go @@ -187,7 +187,7 @@ func handleEncryption( } } headerEncrypted = true - ret, err = mse.ReceiveHandshake(rw, skeys, func(provides mse.CryptoMethod) mse.CryptoMethod { + ret, cryptoMethod, err = mse.ReceiveHandshake(rw, skeys, func(provides mse.CryptoMethod) mse.CryptoMethod { switch { case policy.ForceEncryption: return mse.CryptoMethodRC4 diff --git a/mse/mse.go b/mse/mse.go index d843d8f9..83454b82 100644 --- a/mse/mse.go +++ b/mse/mse.go @@ -367,7 +367,7 @@ func (h *handshake) newEncrypt(initer bool) *rc4.Cipher { return newEncrypt(initer, h.s[:], h.skey) } -func (h *handshake) initerSteps() (ret io.ReadWriter, err error) { +func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) { h.postWrite(hash(req1, h.s[:])) h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:]))) buf := &bytes.Buffer{} @@ -409,7 +409,8 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) { if err != nil { return } - switch method & h.cryptoProvides { + selected = method & h.cryptoProvides + switch selected { case CryptoMethodRC4: ret = readWriter{r, &cipherWriter{e, h.conn, nil}} case CryptoMethodPlaintext: @@ -422,7 +423,7 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) { var ErrNoSecretKeyMatch = errors.New("no skey matched") -func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { +func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) { // There is up to 512 bytes of padding, then the 20 byte hash. err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:])) if err != nil { @@ -460,7 +461,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { return } cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1) - chosen := h.chooseMethod(provides) + chosen = h.chooseMethod(provides) _, err = io.CopyN(ioutil.Discard, r, int64(padLen)) if err != nil { return @@ -499,7 +500,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { return } -func (h *handshake) Do() (ret io.ReadWriter, err error) { +func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) { h.writeCond.L = &h.writeMu h.writerCond.L = &h.writerMu go h.writer() @@ -521,14 +522,14 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) { return } if h.initer { - ret, err = h.initerSteps() + ret, method, err = h.initerSteps() } else { - ret, err = h.receiverSteps() + ret, method, err = h.receiverSteps() } return } -func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, err error) { +func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, method CryptoMethod, err error) { h := handshake{ conn: rw, initer: true, @@ -539,7 +540,7 @@ func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cry return h.Do() } -func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, err error) { +func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, method CryptoMethod, err error) { h := handshake{ conn: rw, initer: false, diff --git a/mse/mse_test.go b/mse/mse_test.go index 51913504..3efaf98d 100644 --- a/mse/mse_test.go +++ b/mse/mse_test.go @@ -12,6 +12,7 @@ import ( _ "github.com/anacrolix/envpprof" "github.com/bradfitz/iter" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -64,11 +65,9 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides wg.Add(2) go func() { defer wg.Done() - a, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides) - if err != nil { - t.Fatal(err) - return - } + a, cm, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides) + require.NoError(t, err) + assert.Equal(t, cryptoSelect(cryptoProvides), cm) go a.Write([]byte(aData)) var msg [20]byte @@ -80,11 +79,9 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides }() go func() { defer wg.Done() - b, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect) - if err != nil { - t.Fatal(err) - return - } + b, cm, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect) + require.NoError(t, err) + assert.Equal(t, cryptoSelect(cryptoProvides), cm) go b.Write([]byte(bData)) // Need to be exact here, as there are several reads, and net.Pipe is // most synchronous. @@ -134,7 +131,7 @@ func (tr *trackReader) Read(b []byte) (n int, err error) { func TestReceiveRandomData(t *testing.T) { tr := trackReader{rand.Reader, 0} - _, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector) + _, _, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector) // No skey matches require.Error(t, err) // Establishing S, and then reading the maximum padding for giving up on @@ -183,13 +180,13 @@ func benchmarkStream(t *testing.B, crypto CryptoMethod) { go func() { defer ac.Close() defer wg.Done() - rw, err := InitiateHandshake(ac, []byte("cats"), ia, crypto) + rw, _, err := InitiateHandshake(ac, []byte("cats"), ia, crypto) require.NoError(t, err) require.NoError(t, readAndWrite(rw, ar, a)) }() func() { defer bc.Close() - rw, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto }) + rw, _, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto }) require.NoError(t, err) require.NoError(t, readAndWrite(rw, br, b)) }()