]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Avoid allocating skeys for receiving encrypted handshakes
authorMatt Joiner <anacrolix@gmail.com>
Tue, 4 Apr 2017 08:41:08 +0000 (18:41 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Tue, 4 Apr 2017 08:41:08 +0000 (18:41 +1000)
client.go
mse/mse.go

index 37637244408bcc6314ad203eaca4217a562b0cf1..e49f39bb5573ee91b814dee39f0aab9edf61cf61 100644 (file)
--- a/client.go
+++ b/client.go
@@ -797,7 +797,7 @@ func (r deadlineReader) Read(b []byte) (n int, err error) {
        return
 }
 
-func maybeReceiveEncryptedHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, encrypted bool, err error) {
+func maybeReceiveEncryptedHandshake(rw io.ReadWriter, skeys mse.SecretKeyIter) (ret io.ReadWriter, encrypted bool, err error) {
        var protocol [len(pp.Protocol)]byte
        _, err = io.ReadFull(rw, protocol[:])
        if err != nil {
@@ -814,14 +814,7 @@ func maybeReceiveEncryptedHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.Re
                return
        }
        encrypted = true
-       ret, err = mse.ReceiveHandshake(ret, skeys)
-       return
-}
-
-func (cl *Client) receiveSkeys() (ret [][]byte) {
-       for ih := range cl.torrents {
-               ret = append(ret, append([]byte(nil), ih[:]...))
-       }
+       ret, err = mse.ReceiveHandshakeLazy(ret, skeys)
        return
 }
 
@@ -844,14 +837,22 @@ func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err er
        return
 }
 
+// Calls f with any secret keys.
+func (cl *Client) forSkeys(f func([]byte) bool) {
+       cl.mu.Lock()
+       defer cl.mu.Unlock()
+       for ih := range cl.torrents {
+               if !f(ih[:]) {
+                       break
+               }
+       }
+}
+
 // Do encryption and bittorrent handshakes as receiver.
 func (cl *Client) receiveHandshakes(c *connection) (t *Torrent, err error) {
-       cl.mu.Lock()
-       skeys := cl.receiveSkeys()
-       cl.mu.Unlock()
        if !cl.config.DisableEncryption {
                var rw io.ReadWriter
-               rw, c.encrypted, err = maybeReceiveEncryptedHandshake(c.rw(), skeys)
+               rw, c.encrypted, err = maybeReceiveEncryptedHandshake(c.rw(), cl.forSkeys)
                c.setRW(rw)
                if err != nil {
                        if err == mse.ErrNoSecretKeyMatch {
index 93a80a6bf5a0ab37298af5bf6e83bb2dd371801f..fbc3aa8e89da1b5b692477d606c965de0758d1ce 100644 (file)
@@ -87,6 +87,7 @@ type cipherReader struct {
 }
 
 func (cr *cipherReader) Read(b []byte) (n int, err error) {
+       // inefficient to allocate here
        be := make([]byte, len(b))
        n, err = cr.r.Read(be)
        cr.c.XORKeyStream(b[:n], be[:n])
@@ -187,10 +188,10 @@ func newPadLen() int64 {
 type handshake struct {
        conn   io.ReadWriter
        s      [96]byte
-       initer bool     // Whether we're initiating or receiving.
-       skeys  [][]byte // Skeys we'll accept if receiving.
-       skey   []byte   // Skey we're initiating with.
-       ia     []byte   // Initial payload. Only used by the initiator.
+       initer bool          // Whether we're initiating or receiving.
+       skeys  SecretKeyIter // Skeys we'll accept if receiving.
+       skey   []byte        // Skey we're initiating with.
+       ia     []byte        // Initial payload. Only used by the initiator.
 
        writeMu    sync.Mutex
        writes     [][]byte
@@ -405,13 +406,14 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
                return
        }
        err = ErrNoSecretKeyMatch
-       for _, skey := range h.skeys {
+       h.skeys(func(skey []byte) bool {
                if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) {
                        h.skey = skey
                        err = nil
-                       break
+                       return false
                }
-       }
+               return true
+       })
        if err != nil {
                return
        }
@@ -494,7 +496,33 @@ func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte) (re
        }
        return h.Do()
 }
+
 func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, err error) {
+       h := handshake{
+               conn:   rw,
+               initer: false,
+               skeys:  sliceIter(skeys),
+       }
+       return h.Do()
+}
+
+func sliceIter(skeys [][]byte) SecretKeyIter {
+       return func(callback func([]byte) bool) {
+               for _, sk := range skeys {
+                       if !callback(sk) {
+                               break
+                       }
+               }
+       }
+}
+
+// A function that given a function, calls it with secret keys until it
+// returns false or exhausted.
+type SecretKeyIter func(callback func(skey []byte) (more bool))
+
+// Doesn't unpack the secret keys until it needs to, and through the passed
+// function.
+func ReceiveHandshakeLazy(rw io.ReadWriter, skeys SecretKeyIter) (ret io.ReadWriter, err error) {
        h := handshake{
                conn:   rw,
                initer: false,