import (
"bytes"
+ "context"
"crypto/rand"
"crypto/rc4"
"crypto/sha1"
var (
// Prime P according to the spec, and G, the generator.
- p, g big.Int
+ p, specG big.Int
// The rand.Int max arg for use in newPadLen()
newPadLenMax big.Int
// For use in initer's hashes
func init() {
p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
- g.SetInt64(2)
+ specG.SetInt64(2)
newPadLenMax.SetInt64(maxPadLen + 1)
}
// Calculate, and send Y, our public key.
func (h *handshake) postY(x *big.Int) error {
var y big.Int
- y.Exp(&g, x, &p)
+ y.Exp(&specG, x, &p)
return h.postWrite(paddedLeft(y.Bytes(), 96))
}
x := newX()
h.postY(&x)
var b [96]byte
- _, err := io.ReadFull(h.conn, b[:])
+ _, err := io.ReadFull(h.ctxConn, b[:])
if err != nil {
return fmt.Errorf("error reading Y: %w", err)
}
// Manages state for both initiating and receiving handshakes.
type handshake struct {
- conn io.ReadWriter
- s [96]byte
- initer bool // Whether we're initiating or receiving.
- skeys SecretKeyIter // Skeys we'll accept if receiving.
- skey []byte // Skey we're initiating with.
- ia []byte // Initial payload. Only used by the initiator.
+ conn io.ReadWriter
+ // The conn with Reads and Writes wrapped to the context given in handshake.Do.
+ ctxConn io.ReadWriter
+ s [96]byte
+ initer bool // Whether we're initiating or receiving.
+ skeys SecretKeyIter // Skeys we'll accept if receiving.
+ skey []byte // Skey we're initiating with.
+ ia []byte // Initial payload. Only used by the initiator.
// Return the bit for the crypto method the receiver wants to use.
chooseMethod CryptoSelector
// Sent to the receiver.
b := h.writes[0]
h.writes = h.writes[1:]
h.writeMu.Unlock()
- _, err := h.conn.Write(b)
+ _, err := h.ctxConn.Write(b)
if err != nil {
h.writeMu.Lock()
h.writeErr = err
return newEncrypt(initer, h.s[:], h.skey)
}
-func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
+func (h *handshake) initerSteps(ctx context.Context) (ret io.ReadWriter, selected CryptoMethod, err error) {
h.postWrite(hash(req1, h.s[:]))
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
buf := &bytes.Buffer{}
// Read until the all zero VC. At this point we've only read the 96 byte
// public key, Y. There is potentially 512 byte padding, between us and
// the 8 byte verification constant.
- err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
+ err = readUntil(io.LimitReader(h.ctxConn, 520), eVC[:])
if err != nil {
if err == io.EOF {
err = errors.New("failed to synchronize on VC")
} else {
- err = fmt.Errorf("error reading until VC: %s", err)
+ err = fmt.Errorf("error reading until VC: %w", err)
}
return
}
- r := newCipherReader(bC, h.conn)
+ ctxReader := newCipherReader(bC, h.ctxConn)
var method CryptoMethod
- err = unmarshal(r, &method, &padLen)
+ err = unmarshal(ctxReader, &method, &padLen)
if err != nil {
return
}
- _, err = io.CopyN(io.Discard, r, int64(padLen))
+ _, err = io.CopyN(io.Discard, ctxReader, int64(padLen))
if err != nil {
return
}
selected = method & h.cryptoProvides
switch selected {
case CryptoMethodRC4:
- ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
+ ret = readWriter{
+ newCipherReader(bC, h.conn),
+ &cipherWriter{e, h.conn, nil},
+ }
case CryptoMethodPlaintext:
ret = h.conn
default:
var ErrNoSecretKeyMatch = errors.New("no skey matched")
-func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
+func (h *handshake) receiverSteps(ctx context.Context) (ret io.ReadWriter, chosen CryptoMethod, err error) {
// There is up to 512 bytes of padding, then the 20 byte hash.
- err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
+ err = readUntil(io.LimitReader(h.ctxConn, 532), hash(req1, h.s[:]))
if err != nil {
if err == io.EOF {
err = errors.New("failed to synchronize on S hash")
return
}
var b [20]byte
- _, err = io.ReadFull(h.conn, b[:])
+ _, err = io.ReadFull(h.ctxConn, b[:])
if err != nil {
return
}
if err != nil {
return
}
- r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
+ cipher := newEncrypt(true, h.s[:], h.skey)
+ ctxReader := newCipherReader(cipher, h.ctxConn)
var (
vc [8]byte
provides CryptoMethod
padLen uint16
)
- err = unmarshal(r, vc[:], &provides, &padLen)
+ err = unmarshal(ctxReader, vc[:], &provides, &padLen)
if err != nil {
return
}
cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
chosen = h.chooseMethod(provides)
- _, err = io.CopyN(io.Discard, r, int64(padLen))
+ _, err = io.CopyN(io.Discard, ctxReader, int64(padLen))
if err != nil {
return
}
var lenIA uint16
- unmarshal(r, &lenIA)
+ unmarshal(ctxReader, &lenIA)
if lenIA != 0 {
h.ia = make([]byte, lenIA)
- unmarshal(r, h.ia)
+ unmarshal(ctxReader, h.ia)
}
buf := &bytes.Buffer{}
w := cipherWriter{h.newEncrypt(false), buf, nil}
switch chosen {
case CryptoMethodRC4:
ret = readWriter{
- io.MultiReader(bytes.NewReader(h.ia), r),
+ io.MultiReader(bytes.NewReader(h.ia), newCipherReader(cipher, h.conn)),
&cipherWriter{w.c, h.conn, nil},
}
case CryptoMethodPlaintext:
return
}
-func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
+func (h *handshake) Do(ctx context.Context) (ret io.ReadWriter, method CryptoMethod, err error) {
h.writeCond.L = &h.writeMu
h.writerCond.L = &h.writerMu
go h.writer()
return
}
if h.initer {
- ret, method, err = h.initerSteps()
+ ret, method, err = h.initerSteps(ctx)
} else {
- ret, method, err = h.receiverSteps()
+ ret, method, err = h.receiverSteps(ctx)
}
return
}
func InitiateHandshake(
- rw io.ReadWriter, skey, initialPayload []byte, cryptoProvides CryptoMethod,
+ rw io.ReadWriter,
+ skey, initialPayload []byte,
+ cryptoProvides CryptoMethod,
+) (
+ ret io.ReadWriter, method CryptoMethod, err error,
+) {
+ return InitiateHandshakeContext(context.TODO(), rw, skey, initialPayload, cryptoProvides)
+}
+
+func InitiateHandshakeContext(
+ ctx context.Context,
+ rw io.ReadWriter,
+ skey, initialPayload []byte,
+ cryptoProvides CryptoMethod,
) (
ret io.ReadWriter, method CryptoMethod, err error,
) {
h := handshake{
conn: rw,
+ ctxConn: contextedReadWriter(ctx, rw),
initer: true,
skey: skey,
ia: initialPayload,
cryptoProvides: cryptoProvides,
}
defer perf.ScopeTimerErr(&err)()
- return h.Do()
+ return h.Do(ctx)
}
type HandshakeResult struct {
SecretKey []byte
}
-func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) {
- res := ReceiveHandshakeEx(rw, skeys, selectCrypto)
+func ReceiveHandshake(
+ ctx context.Context,
+ rw io.ReadWriter,
+ skeys SecretKeyIter,
+ selectCrypto CryptoSelector,
+) (io.ReadWriter, CryptoMethod, error) {
+ res := ReceiveHandshakeEx(ctx, rw, skeys, selectCrypto)
return res.ReadWriter, res.CryptoMethod, res.error
}
-func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret HandshakeResult) {
+func ReceiveHandshakeEx(
+ ctx context.Context,
+ rw io.ReadWriter,
+ skeys SecretKeyIter,
+ selectCrypto CryptoSelector,
+) (ret HandshakeResult) {
h := handshake{
conn: rw,
+ ctxConn: contextedReadWriter(ctx, rw),
initer: false,
skeys: skeys,
chooseMethod: selectCrypto,
}
- ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do()
+ ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do(ctx)
ret.SecretKey = h.skey
return
}