]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Add the ReceiveEncryptedHandshakeSkeys callback
authorMatt Joiner <anacrolix@gmail.com>
Thu, 5 Nov 2020 02:28:45 +0000 (13:28 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Thu, 5 Nov 2020 02:28:45 +0000 (13:28 +1100)
callbacks.go
client.go
handshake.go

index 83e248154a9d4c7878dc96c6e38f0ae0f9d8860d..fa9fea5bea3d3c824ee04467da51c1ee781453b8 100644 (file)
@@ -1,14 +1,21 @@
 package torrent
 
 import (
+       "github.com/anacrolix/torrent/mse"
        pp "github.com/anacrolix/torrent/peer_protocol"
 )
 
-// These are called synchronously, and do not pass ownership. The Client and other locks may still
-// be held. nil functions are not called.
+// These are called synchronously, and do not pass ownership of arguments (do not expect to retain
+// data after returning from the callback). The Client and other locks may still be held. nil
+// functions are not called.
 type Callbacks struct {
-       CompletedHandshake    func(_ *PeerConn, infoHash InfoHash)
+       // Called after a peer connection completes the BitTorrent handshake. The Client lock is not
+       // held.
+       CompletedHandshake    func(*PeerConn, InfoHash)
        ReadMessage           func(*PeerConn, *pp.Message)
        ReadExtendedHandshake func(*PeerConn, *pp.ExtendedHandshakeMessage)
        PeerConnClosed        func(*PeerConn)
+
+       // Provides secret keys to be tried against incoming encrypted connections.
+       ReceiveEncryptedHandshakeSkeys mse.SecretKeyIter
 }
index 111812c3a07bbefd1141f1bd7126a6335a839383..c4e3e6b0ee09eb73b4d0e91ebb2ae809bd181434 100644 (file)
--- a/client.go
+++ b/client.go
@@ -798,10 +798,11 @@ func (cl *Client) initiateHandshakes(c *PeerConn, t *Torrent) error {
        return nil
 }
 
-// Calls f with any secret keys.
+// Calls f with any secret keys. Note that it takes the Client lock, and so must be used from code
+// that won't also try to take the lock. This saves us copying all the infohashes everytime.
 func (cl *Client) forSkeys(f func([]byte) bool) {
-       cl.lock()
-       defer cl.unlock()
+       cl.rLock()
+       defer cl.rUnlock()
        if false { // Emulate the bug from #114
                var firstIh InfoHash
                for ih := range cl.torrents {
@@ -822,11 +823,18 @@ func (cl *Client) forSkeys(f func([]byte) bool) {
        }
 }
 
+func (cl *Client) handshakeReceiverSecretKeys() mse.SecretKeyIter {
+       if ret := cl.config.Callbacks.ReceiveEncryptedHandshakeSkeys; ret != nil {
+               return ret
+       }
+       return cl.forSkeys
+}
+
 // Do encryption and bittorrent handshakes as receiver.
 func (cl *Client) receiveHandshakes(c *PeerConn) (t *Torrent, err error) {
        defer perf.ScopeTimerErr(&err)()
        var rw io.ReadWriter
-       rw, c.headerEncrypted, c.cryptoMethod, err = handleEncryption(c.rw(), cl.forSkeys, cl.config.HeaderObfuscationPolicy, cl.config.CryptoSelector)
+       rw, c.headerEncrypted, c.cryptoMethod, err = handleEncryption(c.rw(), cl.handshakeReceiverSecretKeys(), cl.config.HeaderObfuscationPolicy, cl.config.CryptoSelector)
        c.setRW(rw)
        if err == nil || err == mse.ErrNoSecretKeyMatch {
                if c.headerEncrypted {
@@ -844,7 +852,7 @@ func (cl *Client) receiveHandshakes(c *PeerConn) (t *Torrent, err error) {
                return
        }
        if cl.config.HeaderObfuscationPolicy.RequirePreferred && c.headerEncrypted != cl.config.HeaderObfuscationPolicy.Preferred {
-               err = errors.New("connection not have required header obfuscation")
+               err = errors.New("connection does not have required header obfuscation")
                return
        }
        ih, err := cl.connBtHandshake(c, nil)
index 83d322bf40a67de9cbd615306e5aef5cde8fac2e..b38a7086e83227003b7c5f7b5ef0af120dc76986 100644 (file)
@@ -27,6 +27,7 @@ func (r deadlineReader) Read(b []byte) (int, error) {
        return r.r.Read(b)
 }
 
+// Handles stream encryption for inbound connections.
 func handleEncryption(
        rw io.ReadWriter,
        skeys mse.SecretKeyIter,
@@ -38,12 +39,14 @@ func handleEncryption(
        cryptoMethod mse.CryptoMethod,
        err error,
 ) {
+       // Tries to start an unencrypted stream.
        if !policy.RequirePreferred || !policy.Preferred {
                var protocol [len(pp.Protocol)]byte
                _, err = io.ReadFull(rw, protocol[:])
                if err != nil {
                        return
                }
+               // Put the protocol back into the stream.
                rw = struct {
                        io.Reader
                        io.Writer
@@ -56,6 +59,7 @@ func handleEncryption(
                        return
                }
                if policy.RequirePreferred {
+                       // We are here because we require unencrypted connections.
                        err = fmt.Errorf("unexpected protocol string %q and header obfuscation disabled", protocol)
                        return
                }