From: Matt Joiner Date: Sat, 10 Aug 2024 04:41:02 +0000 (+1000) Subject: Add context to mse handshakes X-Git-Tag: v1.57.0~14 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=b7b97a6666eb2450b09d735718f3d6721d2d8d30;p=btrtrc.git Add context to mse handshakes --- diff --git a/client.go b/client.go index c87defa8..a65a6613 100644 --- a/client.go +++ b/client.go @@ -728,7 +728,7 @@ func (cl *Client) initiateProtocolHandshakes( if err != nil { panic(err) } - err = cl.initiateHandshakes(c, t) + err = cl.initiateHandshakes(ctx, c, t) return } @@ -914,10 +914,11 @@ func (cl *Client) incomingPeerPort() int { return cl.LocalPort() } -func (cl *Client) initiateHandshakes(c *PeerConn, t *Torrent) (err error) { +func (cl *Client) initiateHandshakes(ctx context.Context, c *PeerConn, t *Torrent) (err error) { if c.headerEncrypted { var rw io.ReadWriter - rw, c.cryptoMethod, err = mse.InitiateHandshake( + rw, c.cryptoMethod, err = mse.InitiateHandshakeContext( + ctx, struct { io.Reader io.Writer diff --git a/handshake.go b/handshake.go index b38a7086..95239665 100644 --- a/handshake.go +++ b/handshake.go @@ -2,6 +2,7 @@ package torrent import ( "bytes" + "context" "fmt" "io" "net" @@ -65,7 +66,7 @@ func handleEncryption( } } headerEncrypted = true - ret, cryptoMethod, err = mse.ReceiveHandshake(rw, skeys, selector) + ret, cryptoMethod, err = mse.ReceiveHandshake(context.TODO(), rw, skeys, selector) return } diff --git a/mse/cmd/mse/main.go b/mse/cmd/mse/main.go index 7d10a26d..c96af097 100644 --- a/mse/cmd/mse/main.go +++ b/mse/cmd/mse/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "io" "log" @@ -62,7 +63,7 @@ func mainErr() error { return fmt.Errorf("accepting: %w", err) } defer cn.Close() - rw, _, err := mse.ReceiveHandshake(cn, func(f func([]byte) bool) { + rw, _, err := mse.ReceiveHandshake(context.TODO(), cn, func(f func([]byte) bool) { for _, sk := range args.Listen.SecretKeys { f([]byte(sk)) } diff --git a/mse/ctxrw.go b/mse/ctxrw.go new file mode 100644 index 00000000..933c1871 --- /dev/null +++ b/mse/ctxrw.go @@ -0,0 +1,58 @@ +package mse + +import ( + "context" + g "github.com/anacrolix/generics" + "io" +) + +type contextedReader struct { + ctx context.Context + r io.Reader +} + +func (me contextedReader) Read(p []byte) (n int, err error) { + return contextedReadOrWrite(me.ctx, me.r.Read, p) +} + +type contextedWriter struct { + ctx context.Context + w io.Writer +} + +// This is problematic. If you return with a context error, a read or write is still pending, and +// could mess up the stream. +func contextedReadOrWrite(ctx context.Context, method func(b []byte) (int, error), b []byte) (_ int, err error) { + asyncCh := make(chan g.Result[int], 1) + go func() { + asyncCh <- g.ResultFromTuple(method(b)) + }() + select { + case <-ctx.Done(): + err = context.Cause(ctx) + return + case res := <-asyncCh: + return res.AsTuple() + } + +} + +func (me contextedWriter) Write(p []byte) (n int, err error) { + return contextedReadOrWrite(me.ctx, me.w.Write, p) +} + +func contextedReadWriter(ctx context.Context, rw io.ReadWriter) io.ReadWriter { + return struct { + io.Reader + io.Writer + }{ + contextedReader{ + ctx: ctx, + r: rw, + }, + contextedWriter{ + ctx: ctx, + w: rw, + }, + } +} diff --git a/mse/mse.go b/mse/mse.go index 582a451f..6ab6f223 100644 --- a/mse/mse.go +++ b/mse/mse.go @@ -4,6 +4,7 @@ package mse import ( "bytes" + "context" "crypto/rand" "crypto/rc4" "crypto/sha1" @@ -32,7 +33,7 @@ type CryptoMethod uint32 var ( // Prime P according to the spec, and G, the generator. - p, g big.Int + p, specG big.Int // The rand.Int max arg for use in newPadLen() newPadLenMax big.Int // For use in initer's hashes @@ -50,7 +51,7 @@ var ( func init() { p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0) - g.SetInt64(2) + specG.SetInt64(2) newPadLenMax.SetInt64(maxPadLen + 1) } @@ -159,7 +160,7 @@ func paddedLeft(b []byte, _len int) []byte { // Calculate, and send Y, our public key. func (h *handshake) postY(x *big.Int) error { var y big.Int - y.Exp(&g, x, &p) + y.Exp(&specG, x, &p) return h.postWrite(paddedLeft(y.Bytes(), 96)) } @@ -167,7 +168,7 @@ func (h *handshake) establishS() error { x := newX() h.postY(&x) var b [96]byte - _, err := io.ReadFull(h.conn, b[:]) + _, err := io.ReadFull(h.ctxConn, b[:]) if err != nil { return fmt.Errorf("error reading Y: %w", err) } @@ -193,12 +194,14 @@ func newPadLen() int64 { // Manages state for both initiating and receiving handshakes. type handshake struct { - conn io.ReadWriter - s [96]byte - initer bool // Whether we're initiating or receiving. - skeys SecretKeyIter // Skeys we'll accept if receiving. - skey []byte // Skey we're initiating with. - ia []byte // Initial payload. Only used by the initiator. + conn io.ReadWriter + // The conn with Reads and Writes wrapped to the context given in handshake.Do. + ctxConn io.ReadWriter + s [96]byte + initer bool // Whether we're initiating or receiving. + skeys SecretKeyIter // Skeys we'll accept if receiving. + skey []byte // Skey we're initiating with. + ia []byte // Initial payload. Only used by the initiator. // Return the bit for the crypto method the receiver wants to use. chooseMethod CryptoSelector // Sent to the receiver. @@ -250,7 +253,7 @@ func (h *handshake) writer() { b := h.writes[0] h.writes = h.writes[1:] h.writeMu.Unlock() - _, err := h.conn.Write(b) + _, err := h.ctxConn.Write(b) if err != nil { h.writeMu.Lock() h.writeErr = err @@ -357,7 +360,7 @@ func (h *handshake) newEncrypt(initer bool) *rc4.Cipher { return newEncrypt(initer, h.s[:], h.skey) } -func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) { +func (h *handshake) initerSteps(ctx context.Context) (ret io.ReadWriter, selected CryptoMethod, err error) { h.postWrite(hash(req1, h.s[:])) h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:]))) buf := &bytes.Buffer{} @@ -380,29 +383,32 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err // Read until the all zero VC. At this point we've only read the 96 byte // public key, Y. There is potentially 512 byte padding, between us and // the 8 byte verification constant. - err = readUntil(io.LimitReader(h.conn, 520), eVC[:]) + err = readUntil(io.LimitReader(h.ctxConn, 520), eVC[:]) if err != nil { if err == io.EOF { err = errors.New("failed to synchronize on VC") } else { - err = fmt.Errorf("error reading until VC: %s", err) + err = fmt.Errorf("error reading until VC: %w", err) } return } - r := newCipherReader(bC, h.conn) + ctxReader := newCipherReader(bC, h.ctxConn) var method CryptoMethod - err = unmarshal(r, &method, &padLen) + err = unmarshal(ctxReader, &method, &padLen) if err != nil { return } - _, err = io.CopyN(io.Discard, r, int64(padLen)) + _, err = io.CopyN(io.Discard, ctxReader, int64(padLen)) if err != nil { return } selected = method & h.cryptoProvides switch selected { case CryptoMethodRC4: - ret = readWriter{r, &cipherWriter{e, h.conn, nil}} + ret = readWriter{ + newCipherReader(bC, h.conn), + &cipherWriter{e, h.conn, nil}, + } case CryptoMethodPlaintext: ret = h.conn default: @@ -413,9 +419,9 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err var ErrNoSecretKeyMatch = errors.New("no skey matched") -func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) { +func (h *handshake) receiverSteps(ctx context.Context) (ret io.ReadWriter, chosen CryptoMethod, err error) { // There is up to 512 bytes of padding, then the 20 byte hash. - err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:])) + err = readUntil(io.LimitReader(h.ctxConn, 532), hash(req1, h.s[:])) if err != nil { if err == io.EOF { err = errors.New("failed to synchronize on S hash") @@ -423,7 +429,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err return } var b [20]byte - _, err = io.ReadFull(h.conn, b[:]) + _, err = io.ReadFull(h.ctxConn, b[:]) if err != nil { return } @@ -447,28 +453,29 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err if err != nil { return } - r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn) + cipher := newEncrypt(true, h.s[:], h.skey) + ctxReader := newCipherReader(cipher, h.ctxConn) var ( vc [8]byte provides CryptoMethod padLen uint16 ) - err = unmarshal(r, vc[:], &provides, &padLen) + err = unmarshal(ctxReader, vc[:], &provides, &padLen) if err != nil { return } cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1) chosen = h.chooseMethod(provides) - _, err = io.CopyN(io.Discard, r, int64(padLen)) + _, err = io.CopyN(io.Discard, ctxReader, int64(padLen)) if err != nil { return } var lenIA uint16 - unmarshal(r, &lenIA) + unmarshal(ctxReader, &lenIA) if lenIA != 0 { h.ia = make([]byte, lenIA) - unmarshal(r, h.ia) + unmarshal(ctxReader, h.ia) } buf := &bytes.Buffer{} w := cipherWriter{h.newEncrypt(false), buf, nil} @@ -484,7 +491,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err switch chosen { case CryptoMethodRC4: ret = readWriter{ - io.MultiReader(bytes.NewReader(h.ia), r), + io.MultiReader(bytes.NewReader(h.ia), newCipherReader(cipher, h.conn)), &cipherWriter{w.c, h.conn, nil}, } case CryptoMethodPlaintext: @@ -498,7 +505,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err return } -func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) { +func (h *handshake) Do(ctx context.Context) (ret io.ReadWriter, method CryptoMethod, err error) { h.writeCond.L = &h.writeMu h.writerCond.L = &h.writerMu go h.writer() @@ -520,27 +527,41 @@ func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) { return } if h.initer { - ret, method, err = h.initerSteps() + ret, method, err = h.initerSteps(ctx) } else { - ret, method, err = h.receiverSteps() + ret, method, err = h.receiverSteps(ctx) } return } func InitiateHandshake( - rw io.ReadWriter, skey, initialPayload []byte, cryptoProvides CryptoMethod, + rw io.ReadWriter, + skey, initialPayload []byte, + cryptoProvides CryptoMethod, +) ( + ret io.ReadWriter, method CryptoMethod, err error, +) { + return InitiateHandshakeContext(context.TODO(), rw, skey, initialPayload, cryptoProvides) +} + +func InitiateHandshakeContext( + ctx context.Context, + rw io.ReadWriter, + skey, initialPayload []byte, + cryptoProvides CryptoMethod, ) ( ret io.ReadWriter, method CryptoMethod, err error, ) { h := handshake{ conn: rw, + ctxConn: contextedReadWriter(ctx, rw), initer: true, skey: skey, ia: initialPayload, cryptoProvides: cryptoProvides, } defer perf.ScopeTimerErr(&err)() - return h.Do() + return h.Do(ctx) } type HandshakeResult struct { @@ -550,19 +571,30 @@ type HandshakeResult struct { SecretKey []byte } -func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) { - res := ReceiveHandshakeEx(rw, skeys, selectCrypto) +func ReceiveHandshake( + ctx context.Context, + rw io.ReadWriter, + skeys SecretKeyIter, + selectCrypto CryptoSelector, +) (io.ReadWriter, CryptoMethod, error) { + res := ReceiveHandshakeEx(ctx, rw, skeys, selectCrypto) return res.ReadWriter, res.CryptoMethod, res.error } -func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret HandshakeResult) { +func ReceiveHandshakeEx( + ctx context.Context, + rw io.ReadWriter, + skeys SecretKeyIter, + selectCrypto CryptoSelector, +) (ret HandshakeResult) { h := handshake{ conn: rw, + ctxConn: contextedReadWriter(ctx, rw), initer: false, skeys: skeys, chooseMethod: selectCrypto, } - ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do() + ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do(ctx) ret.SecretKey = h.skey return } diff --git a/mse/mse_test.go b/mse/mse_test.go index f7f7fe7a..d97323f1 100644 --- a/mse/mse_test.go +++ b/mse/mse_test.go @@ -2,6 +2,7 @@ package mse import ( "bytes" + "context" "crypto/rand" "crypto/rc4" "io" @@ -77,7 +78,12 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides }() go func() { defer wg.Done() - res := ReceiveHandshakeEx(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect) + res := ReceiveHandshakeEx( + context.Background(), + b, + sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), + cryptoSelect, + ) require.NoError(t, res.error) assert.EqualValues(t, "yep", res.SecretKey) b := res.ReadWriter @@ -130,7 +136,7 @@ func (tr *trackReader) Read(b []byte) (n int, err error) { func TestReceiveRandomData(t *testing.T) { tr := trackReader{rand.Reader, 0} - _, _, err := ReceiveHandshake(readWriter{&tr, io.Discard}, nil, DefaultCryptoSelector) + _, _, err := ReceiveHandshake(context.Background(), readWriter{&tr, io.Discard}, nil, DefaultCryptoSelector) // No skey matches require.Error(t, err) // Establishing S, and then reading the maximum padding for giving up on @@ -185,7 +191,12 @@ func benchmarkStream(t *testing.B, crypto CryptoMethod) { }() func() { defer bc.Close() - rw, _, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto }) + rw, _, err := ReceiveHandshake( + context.Background(), + bc, + sliceIter([][]byte{[]byte("cats")}), + func(CryptoMethod) CryptoMethod { return crypto }, + ) require.NoError(t, err) require.NoError(t, readAndWrite(rw, br, b)) }() @@ -270,7 +281,7 @@ func BenchmarkSkeysReceive(b *testing.B) { panic(err) } }() - res := ReceiveHandshakeEx(receiver, sliceIter(skeys), DefaultCryptoSelector) + res := ReceiveHandshakeEx(context.Background(), receiver, sliceIter(skeys), DefaultCryptoSelector) if res.error != nil { panic(res.error) }