]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Add context to mse handshakes
authorMatt Joiner <anacrolix@gmail.com>
Sat, 10 Aug 2024 04:41:02 +0000 (14:41 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Sat, 10 Aug 2024 04:41:31 +0000 (14:41 +1000)
client.go
handshake.go
mse/cmd/mse/main.go
mse/ctxrw.go [new file with mode: 0644]
mse/mse.go
mse/mse_test.go

index c87defa8b24651540c533ff9d248badbba0d6eb1..a65a6613e95d1b9e657b749134f37faf39104587 100644 (file)
--- a/client.go
+++ b/client.go
@@ -728,7 +728,7 @@ func (cl *Client) initiateProtocolHandshakes(
        if err != nil {
                panic(err)
        }
-       err = cl.initiateHandshakes(c, t)
+       err = cl.initiateHandshakes(ctx, c, t)
        return
 }
 
@@ -914,10 +914,11 @@ func (cl *Client) incomingPeerPort() int {
        return cl.LocalPort()
 }
 
-func (cl *Client) initiateHandshakes(c *PeerConn, t *Torrent) (err error) {
+func (cl *Client) initiateHandshakes(ctx context.Context, c *PeerConn, t *Torrent) (err error) {
        if c.headerEncrypted {
                var rw io.ReadWriter
-               rw, c.cryptoMethod, err = mse.InitiateHandshake(
+               rw, c.cryptoMethod, err = mse.InitiateHandshakeContext(
+                       ctx,
                        struct {
                                io.Reader
                                io.Writer
index b38a7086e83227003b7c5f7b5ef0af120dc76986..95239665e748d85f7aa1d9890069d71246ad3616 100644 (file)
@@ -2,6 +2,7 @@ package torrent
 
 import (
        "bytes"
+       "context"
        "fmt"
        "io"
        "net"
@@ -65,7 +66,7 @@ func handleEncryption(
                }
        }
        headerEncrypted = true
-       ret, cryptoMethod, err = mse.ReceiveHandshake(rw, skeys, selector)
+       ret, cryptoMethod, err = mse.ReceiveHandshake(context.TODO(), rw, skeys, selector)
        return
 }
 
index 7d10a26d9d386914435cce1846df18ce5faec899..c96af097175bc9062c49dad442e5e3064c6d0395 100644 (file)
@@ -1,6 +1,7 @@
 package main
 
 import (
+       "context"
        "fmt"
        "io"
        "log"
@@ -62,7 +63,7 @@ func mainErr() error {
                        return fmt.Errorf("accepting: %w", err)
                }
                defer cn.Close()
-               rw, _, err := mse.ReceiveHandshake(cn, func(f func([]byte) bool) {
+               rw, _, err := mse.ReceiveHandshake(context.TODO(), cn, func(f func([]byte) bool) {
                        for _, sk := range args.Listen.SecretKeys {
                                f([]byte(sk))
                        }
diff --git a/mse/ctxrw.go b/mse/ctxrw.go
new file mode 100644 (file)
index 0000000..933c187
--- /dev/null
@@ -0,0 +1,58 @@
+package mse
+
+import (
+       "context"
+       g "github.com/anacrolix/generics"
+       "io"
+)
+
+type contextedReader struct {
+       ctx context.Context
+       r   io.Reader
+}
+
+func (me contextedReader) Read(p []byte) (n int, err error) {
+       return contextedReadOrWrite(me.ctx, me.r.Read, p)
+}
+
+type contextedWriter struct {
+       ctx context.Context
+       w   io.Writer
+}
+
+// This is problematic. If you return with a context error, a read or write is still pending, and
+// could mess up the stream.
+func contextedReadOrWrite(ctx context.Context, method func(b []byte) (int, error), b []byte) (_ int, err error) {
+       asyncCh := make(chan g.Result[int], 1)
+       go func() {
+               asyncCh <- g.ResultFromTuple(method(b))
+       }()
+       select {
+       case <-ctx.Done():
+               err = context.Cause(ctx)
+               return
+       case res := <-asyncCh:
+               return res.AsTuple()
+       }
+
+}
+
+func (me contextedWriter) Write(p []byte) (n int, err error) {
+       return contextedReadOrWrite(me.ctx, me.w.Write, p)
+}
+
+func contextedReadWriter(ctx context.Context, rw io.ReadWriter) io.ReadWriter {
+       return struct {
+               io.Reader
+               io.Writer
+       }{
+               contextedReader{
+                       ctx: ctx,
+                       r:   rw,
+               },
+               contextedWriter{
+                       ctx: ctx,
+                       w:   rw,
+               },
+       }
+}
index 582a451f98137ce5ed43e1bf5d0217ba09ff2535..6ab6f2236f68ea261c8cd0ef14d9fed13abd4e1a 100644 (file)
@@ -4,6 +4,7 @@ package mse
 
 import (
        "bytes"
+       "context"
        "crypto/rand"
        "crypto/rc4"
        "crypto/sha1"
@@ -32,7 +33,7 @@ type CryptoMethod uint32
 
 var (
        // Prime P according to the spec, and G, the generator.
-       p, g big.Int
+       p, specG big.Int
        // The rand.Int max arg for use in newPadLen()
        newPadLenMax big.Int
        // For use in initer's hashes
@@ -50,7 +51,7 @@ var (
 
 func init() {
        p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
-       g.SetInt64(2)
+       specG.SetInt64(2)
        newPadLenMax.SetInt64(maxPadLen + 1)
 }
 
@@ -159,7 +160,7 @@ func paddedLeft(b []byte, _len int) []byte {
 // Calculate, and send Y, our public key.
 func (h *handshake) postY(x *big.Int) error {
        var y big.Int
-       y.Exp(&g, x, &p)
+       y.Exp(&specG, x, &p)
        return h.postWrite(paddedLeft(y.Bytes(), 96))
 }
 
@@ -167,7 +168,7 @@ func (h *handshake) establishS() error {
        x := newX()
        h.postY(&x)
        var b [96]byte
-       _, err := io.ReadFull(h.conn, b[:])
+       _, err := io.ReadFull(h.ctxConn, b[:])
        if err != nil {
                return fmt.Errorf("error reading Y: %w", err)
        }
@@ -193,12 +194,14 @@ func newPadLen() int64 {
 
 // Manages state for both initiating and receiving handshakes.
 type handshake struct {
-       conn   io.ReadWriter
-       s      [96]byte
-       initer bool          // Whether we're initiating or receiving.
-       skeys  SecretKeyIter // Skeys we'll accept if receiving.
-       skey   []byte        // Skey we're initiating with.
-       ia     []byte        // Initial payload. Only used by the initiator.
+       conn io.ReadWriter
+       // The conn with Reads and Writes wrapped to the context given in handshake.Do.
+       ctxConn io.ReadWriter
+       s       [96]byte
+       initer  bool          // Whether we're initiating or receiving.
+       skeys   SecretKeyIter // Skeys we'll accept if receiving.
+       skey    []byte        // Skey we're initiating with.
+       ia      []byte        // Initial payload. Only used by the initiator.
        // Return the bit for the crypto method the receiver wants to use.
        chooseMethod CryptoSelector
        // Sent to the receiver.
@@ -250,7 +253,7 @@ func (h *handshake) writer() {
                b := h.writes[0]
                h.writes = h.writes[1:]
                h.writeMu.Unlock()
-               _, err := h.conn.Write(b)
+               _, err := h.ctxConn.Write(b)
                if err != nil {
                        h.writeMu.Lock()
                        h.writeErr = err
@@ -357,7 +360,7 @@ func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
        return newEncrypt(initer, h.s[:], h.skey)
 }
 
-func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
+func (h *handshake) initerSteps(ctx context.Context) (ret io.ReadWriter, selected CryptoMethod, err error) {
        h.postWrite(hash(req1, h.s[:]))
        h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
        buf := &bytes.Buffer{}
@@ -380,29 +383,32 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err
        // Read until the all zero VC. At this point we've only read the 96 byte
        // public key, Y. There is potentially 512 byte padding, between us and
        // the 8 byte verification constant.
-       err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
+       err = readUntil(io.LimitReader(h.ctxConn, 520), eVC[:])
        if err != nil {
                if err == io.EOF {
                        err = errors.New("failed to synchronize on VC")
                } else {
-                       err = fmt.Errorf("error reading until VC: %s", err)
+                       err = fmt.Errorf("error reading until VC: %w", err)
                }
                return
        }
-       r := newCipherReader(bC, h.conn)
+       ctxReader := newCipherReader(bC, h.ctxConn)
        var method CryptoMethod
-       err = unmarshal(r, &method, &padLen)
+       err = unmarshal(ctxReader, &method, &padLen)
        if err != nil {
                return
        }
-       _, err = io.CopyN(io.Discard, r, int64(padLen))
+       _, err = io.CopyN(io.Discard, ctxReader, int64(padLen))
        if err != nil {
                return
        }
        selected = method & h.cryptoProvides
        switch selected {
        case CryptoMethodRC4:
-               ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
+               ret = readWriter{
+                       newCipherReader(bC, h.conn),
+                       &cipherWriter{e, h.conn, nil},
+               }
        case CryptoMethodPlaintext:
                ret = h.conn
        default:
@@ -413,9 +419,9 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err
 
 var ErrNoSecretKeyMatch = errors.New("no skey matched")
 
-func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
+func (h *handshake) receiverSteps(ctx context.Context) (ret io.ReadWriter, chosen CryptoMethod, err error) {
        // There is up to 512 bytes of padding, then the 20 byte hash.
-       err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
+       err = readUntil(io.LimitReader(h.ctxConn, 532), hash(req1, h.s[:]))
        if err != nil {
                if err == io.EOF {
                        err = errors.New("failed to synchronize on S hash")
@@ -423,7 +429,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err
                return
        }
        var b [20]byte
-       _, err = io.ReadFull(h.conn, b[:])
+       _, err = io.ReadFull(h.ctxConn, b[:])
        if err != nil {
                return
        }
@@ -447,28 +453,29 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err
        if err != nil {
                return
        }
-       r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
+       cipher := newEncrypt(true, h.s[:], h.skey)
+       ctxReader := newCipherReader(cipher, h.ctxConn)
        var (
                vc       [8]byte
                provides CryptoMethod
                padLen   uint16
        )
 
-       err = unmarshal(r, vc[:], &provides, &padLen)
+       err = unmarshal(ctxReader, vc[:], &provides, &padLen)
        if err != nil {
                return
        }
        cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
        chosen = h.chooseMethod(provides)
-       _, err = io.CopyN(io.Discard, r, int64(padLen))
+       _, err = io.CopyN(io.Discard, ctxReader, int64(padLen))
        if err != nil {
                return
        }
        var lenIA uint16
-       unmarshal(r, &lenIA)
+       unmarshal(ctxReader, &lenIA)
        if lenIA != 0 {
                h.ia = make([]byte, lenIA)
-               unmarshal(r, h.ia)
+               unmarshal(ctxReader, h.ia)
        }
        buf := &bytes.Buffer{}
        w := cipherWriter{h.newEncrypt(false), buf, nil}
@@ -484,7 +491,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err
        switch chosen {
        case CryptoMethodRC4:
                ret = readWriter{
-                       io.MultiReader(bytes.NewReader(h.ia), r),
+                       io.MultiReader(bytes.NewReader(h.ia), newCipherReader(cipher, h.conn)),
                        &cipherWriter{w.c, h.conn, nil},
                }
        case CryptoMethodPlaintext:
@@ -498,7 +505,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err
        return
 }
 
-func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
+func (h *handshake) Do(ctx context.Context) (ret io.ReadWriter, method CryptoMethod, err error) {
        h.writeCond.L = &h.writeMu
        h.writerCond.L = &h.writerMu
        go h.writer()
@@ -520,27 +527,41 @@ func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
                return
        }
        if h.initer {
-               ret, method, err = h.initerSteps()
+               ret, method, err = h.initerSteps(ctx)
        } else {
-               ret, method, err = h.receiverSteps()
+               ret, method, err = h.receiverSteps(ctx)
        }
        return
 }
 
 func InitiateHandshake(
-       rw io.ReadWriter, skey, initialPayload []byte, cryptoProvides CryptoMethod,
+       rw io.ReadWriter,
+       skey, initialPayload []byte,
+       cryptoProvides CryptoMethod,
+) (
+       ret io.ReadWriter, method CryptoMethod, err error,
+) {
+       return InitiateHandshakeContext(context.TODO(), rw, skey, initialPayload, cryptoProvides)
+}
+
+func InitiateHandshakeContext(
+       ctx context.Context,
+       rw io.ReadWriter,
+       skey, initialPayload []byte,
+       cryptoProvides CryptoMethod,
 ) (
        ret io.ReadWriter, method CryptoMethod, err error,
 ) {
        h := handshake{
                conn:           rw,
+               ctxConn:        contextedReadWriter(ctx, rw),
                initer:         true,
                skey:           skey,
                ia:             initialPayload,
                cryptoProvides: cryptoProvides,
        }
        defer perf.ScopeTimerErr(&err)()
-       return h.Do()
+       return h.Do(ctx)
 }
 
 type HandshakeResult struct {
@@ -550,19 +571,30 @@ type HandshakeResult struct {
        SecretKey []byte
 }
 
-func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) {
-       res := ReceiveHandshakeEx(rw, skeys, selectCrypto)
+func ReceiveHandshake(
+       ctx context.Context,
+       rw io.ReadWriter,
+       skeys SecretKeyIter,
+       selectCrypto CryptoSelector,
+) (io.ReadWriter, CryptoMethod, error) {
+       res := ReceiveHandshakeEx(ctx, rw, skeys, selectCrypto)
        return res.ReadWriter, res.CryptoMethod, res.error
 }
 
-func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret HandshakeResult) {
+func ReceiveHandshakeEx(
+       ctx context.Context,
+       rw io.ReadWriter,
+       skeys SecretKeyIter,
+       selectCrypto CryptoSelector,
+) (ret HandshakeResult) {
        h := handshake{
                conn:         rw,
+               ctxConn:      contextedReadWriter(ctx, rw),
                initer:       false,
                skeys:        skeys,
                chooseMethod: selectCrypto,
        }
-       ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do()
+       ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do(ctx)
        ret.SecretKey = h.skey
        return
 }
index f7f7fe7ab9c7f2090f2a30467cc5635414dbdc8d..d97323f1c83b328093302ed6c5e99880b15ac776 100644 (file)
@@ -2,6 +2,7 @@ package mse
 
 import (
        "bytes"
+       "context"
        "crypto/rand"
        "crypto/rc4"
        "io"
@@ -77,7 +78,12 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides
        }()
        go func() {
                defer wg.Done()
-               res := ReceiveHandshakeEx(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
+               res := ReceiveHandshakeEx(
+                       context.Background(),
+                       b,
+                       sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}),
+                       cryptoSelect,
+               )
                require.NoError(t, res.error)
                assert.EqualValues(t, "yep", res.SecretKey)
                b := res.ReadWriter
@@ -130,7 +136,7 @@ func (tr *trackReader) Read(b []byte) (n int, err error) {
 
 func TestReceiveRandomData(t *testing.T) {
        tr := trackReader{rand.Reader, 0}
-       _, _, err := ReceiveHandshake(readWriter{&tr, io.Discard}, nil, DefaultCryptoSelector)
+       _, _, err := ReceiveHandshake(context.Background(), readWriter{&tr, io.Discard}, nil, DefaultCryptoSelector)
        // No skey matches
        require.Error(t, err)
        // Establishing S, and then reading the maximum padding for giving up on
@@ -185,7 +191,12 @@ func benchmarkStream(t *testing.B, crypto CryptoMethod) {
                }()
                func() {
                        defer bc.Close()
-                       rw, _, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto })
+                       rw, _, err := ReceiveHandshake(
+                               context.Background(),
+                               bc,
+                               sliceIter([][]byte{[]byte("cats")}),
+                               func(CryptoMethod) CryptoMethod { return crypto },
+                       )
                        require.NoError(t, err)
                        require.NoError(t, readAndWrite(rw, br, b))
                }()
@@ -270,7 +281,7 @@ func BenchmarkSkeysReceive(b *testing.B) {
                                panic(err)
                        }
                }()
-               res := ReceiveHandshakeEx(receiver, sliceIter(skeys), DefaultCryptoSelector)
+               res := ReceiveHandshakeEx(context.Background(), receiver, sliceIter(skeys), DefaultCryptoSelector)
                if res.error != nil {
                        panic(res.error)
                }