]> Sergey Matveev's repositories - btrtrc.git/blobdiff - client.go
Add the ReceiveEncryptedHandshakeSkeys callback
[btrtrc.git] / client.go
index 111812c3a07bbefd1141f1bd7126a6335a839383..c4e3e6b0ee09eb73b4d0e91ebb2ae809bd181434 100644 (file)
--- a/client.go
+++ b/client.go
@@ -798,10 +798,11 @@ func (cl *Client) initiateHandshakes(c *PeerConn, t *Torrent) error {
        return nil
 }
 
-// Calls f with any secret keys.
+// Calls f with any secret keys. Note that it takes the Client lock, and so must be used from code
+// that won't also try to take the lock. This saves us copying all the infohashes everytime.
 func (cl *Client) forSkeys(f func([]byte) bool) {
-       cl.lock()
-       defer cl.unlock()
+       cl.rLock()
+       defer cl.rUnlock()
        if false { // Emulate the bug from #114
                var firstIh InfoHash
                for ih := range cl.torrents {
@@ -822,11 +823,18 @@ func (cl *Client) forSkeys(f func([]byte) bool) {
        }
 }
 
+func (cl *Client) handshakeReceiverSecretKeys() mse.SecretKeyIter {
+       if ret := cl.config.Callbacks.ReceiveEncryptedHandshakeSkeys; ret != nil {
+               return ret
+       }
+       return cl.forSkeys
+}
+
 // Do encryption and bittorrent handshakes as receiver.
 func (cl *Client) receiveHandshakes(c *PeerConn) (t *Torrent, err error) {
        defer perf.ScopeTimerErr(&err)()
        var rw io.ReadWriter
-       rw, c.headerEncrypted, c.cryptoMethod, err = handleEncryption(c.rw(), cl.forSkeys, cl.config.HeaderObfuscationPolicy, cl.config.CryptoSelector)
+       rw, c.headerEncrypted, c.cryptoMethod, err = handleEncryption(c.rw(), cl.handshakeReceiverSecretKeys(), cl.config.HeaderObfuscationPolicy, cl.config.CryptoSelector)
        c.setRW(rw)
        if err == nil || err == mse.ErrNoSecretKeyMatch {
                if c.headerEncrypted {
@@ -844,7 +852,7 @@ func (cl *Client) receiveHandshakes(c *PeerConn) (t *Torrent, err error) {
                return
        }
        if cl.config.HeaderObfuscationPolicy.RequirePreferred && c.headerEncrypted != cl.config.HeaderObfuscationPolicy.Preferred {
-               err = errors.New("connection not have required header obfuscation")
+               err = errors.New("connection does not have required header obfuscation")
                return
        }
        ih, err := cl.connBtHandshake(c, nil)