"expvar"
"fmt"
"io"
- "io/ioutil"
+ "math"
"math/big"
"strconv"
"sync"
- "github.com/bradfitz/iter"
+ "github.com/anacrolix/missinggo/perf"
)
const (
maxPadLen = 512
- cryptoMethodPlaintext = 1
- cryptoMethodRC4 = 2
+ CryptoMethodPlaintext CryptoMethod = 1 // After header obfuscation, drop into plaintext
+ CryptoMethodRC4 CryptoMethod = 2 // After header obfuscation, use RC4 for the rest of the stream
+ AllSupportedCrypto = CryptoMethodPlaintext | CryptoMethodRC4
)
+type CryptoMethod uint32
+
var (
// Prime P according to the spec, and G, the generator.
p, g big.Int
return h.Sum(nil)
}
-func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
+func newEncrypt(initer bool, s, skey []byte) (c *rc4.Cipher) {
c, err := rc4.NewCipher(hash([]byte(func() string {
if initer {
return "keyA"
return ret
}
}()
- cr.c.XORKeyStream(be[:], b)
+ cr.c.XORKeyStream(be, b)
n, err = cr.w.Write(be[:len(b)])
if n != len(b) {
// The cipher will have advanced beyond the callers stream position.
return h.postWrite(paddedLeft(y.Bytes(), 96))
}
-func (h *handshake) establishS() (err error) {
+func (h *handshake) establishS() error {
x := newX()
h.postY(&x)
var b [96]byte
- _, err = io.ReadFull(h.conn, b[:])
+ _, err := io.ReadFull(h.conn, b[:])
if err != nil {
- return
+ return fmt.Errorf("error reading Y: %w", err)
}
var Y, S big.Int
Y.SetBytes(b[:])
S.Exp(&Y, &x, &p)
sBytes := S.Bytes()
copy(h.s[96-len(sBytes):96], sBytes)
- return
+ return nil
}
func newPadLen() int64 {
skeys SecretKeyIter // Skeys we'll accept if receiving.
skey []byte // Skey we're initiating with.
ia []byte // Initial payload. Only used by the initiator.
+ // Return the bit for the crypto method the receiver wants to use.
+ chooseMethod CryptoSelector
+ // Sent to the receiver.
+ cryptoProvides CryptoMethod
writeMu sync.Mutex
writes [][]byte
h.writerCond.Wait()
}
h.writerMu.Unlock()
- return
}
func (h *handshake) writer() {
return nil
}
-func xor(dst, src []byte) (ret []byte) {
- max := len(dst)
- if max > len(src) {
- max = len(src)
- }
- ret = make([]byte, 0, max)
- for i := range iter.N(max) {
- ret = append(ret, dst[i]^src[i])
+func xor(a, b []byte) (ret []byte) {
+ max := len(a)
+ if max > len(b) {
+ max = len(b)
}
+ ret = make([]byte, max)
+ xorInPlace(ret, a, b)
return
}
+func xorInPlace(dst, a, b []byte) {
+ for i := range dst {
+ dst[i] = a[i] ^ b[i]
+ }
+}
+
func marshal(w io.Writer, data ...interface{}) (err error) {
for _, data := range data {
err = binary.Write(w, binary.BigEndian, data)
return newEncrypt(initer, h.s[:], h.skey)
}
-func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
+func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
h.postWrite(hash(req1, h.s[:]))
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
buf := &bytes.Buffer{}
padLen := uint16(newPadLen())
- err = marshal(buf, vc[:], uint32(cryptoMethodRC4), padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
+ if len(h.ia) > math.MaxUint16 {
+ err = errors.New("initial payload too large")
+ return
+ }
+ err = marshal(buf, vc[:], h.cryptoProvides, padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
if err != nil {
return
}
return
}
r := newCipherReader(bC, h.conn)
- var method uint32
+ var method CryptoMethod
err = unmarshal(r, &method, &padLen)
if err != nil {
return
}
- if method != cryptoMethodRC4 {
- err = fmt.Errorf("receiver chose unsupported method: %x", method)
- return
- }
- _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
+ _, err = io.CopyN(io.Discard, r, int64(padLen))
if err != nil {
return
}
- ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
+ selected = method & h.cryptoProvides
+ switch selected {
+ case CryptoMethodRC4:
+ ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
+ case CryptoMethodPlaintext:
+ ret = h.conn
+ default:
+ err = fmt.Errorf("receiver chose unsupported method: %x", method)
+ }
return
}
var ErrNoSecretKeyMatch = errors.New("no skey matched")
-func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
+func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
// There is up to 512 bytes of padding, then the 20 byte hash.
err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
if err != nil {
if err != nil {
return
}
+ expectedHash := hash(req3, h.s[:])
+ eachHash := sha1.New()
+ var sum, xored [sha1.Size]byte
err = ErrNoSecretKeyMatch
h.skeys(func(skey []byte) bool {
- if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) {
+ eachHash.Reset()
+ eachHash.Write(req2)
+ eachHash.Write(skey)
+ eachHash.Sum(sum[:0])
+ xorInPlace(xored[:], sum[:], expectedHash)
+ if bytes.Equal(xored[:], b[:]) {
h.skey = skey
err = nil
return false
}
r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
var (
- vc [8]byte
- method uint32
- padLen uint16
+ vc [8]byte
+ provides CryptoMethod
+ padLen uint16
)
- err = unmarshal(r, vc[:], &method, &padLen)
+ err = unmarshal(r, vc[:], &provides, &padLen)
if err != nil {
return
}
- cryptoProvidesCount.Add(strconv.FormatUint(uint64(method), 16), 1)
- if method&cryptoMethodRC4 == 0 {
- err = errors.New("no supported crypto methods were provided")
- return
- }
- _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
+ cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
+ chosen = h.chooseMethod(provides)
+ _, err = io.CopyN(io.Discard, r, int64(padLen))
if err != nil {
return
}
buf := &bytes.Buffer{}
w := cipherWriter{h.newEncrypt(false), buf, nil}
padLen = uint16(newPadLen())
- err = marshal(&w, &vc, uint32(cryptoMethodRC4), padLen, zeroPad[:padLen])
+ err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen])
if err != nil {
return
}
if err != nil {
return
}
- ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn, nil}}
+ switch chosen {
+ case CryptoMethodRC4:
+ ret = readWriter{
+ io.MultiReader(bytes.NewReader(h.ia), r),
+ &cipherWriter{w.c, h.conn, nil},
+ }
+ case CryptoMethodPlaintext:
+ ret = readWriter{
+ io.MultiReader(bytes.NewReader(h.ia), h.conn),
+ h.conn,
+ }
+ default:
+ err = errors.New("chosen crypto method is not supported")
+ }
return
}
-func (h *handshake) Do() (ret io.ReadWriter, err error) {
+func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
h.writeCond.L = &h.writeMu
h.writerCond.L = &h.writerMu
go h.writer()
}()
err = h.establishS()
if err != nil {
- err = fmt.Errorf("error while establishing secret: %s", err)
+ err = fmt.Errorf("error while establishing secret: %w", err)
return
}
pad := make([]byte, newPadLen())
return
}
if h.initer {
- ret, err = h.initerSteps()
+ ret, method, err = h.initerSteps()
} else {
- ret, err = h.receiverSteps()
+ ret, method, err = h.receiverSteps()
}
return
}
-func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
+func InitiateHandshake(
+ rw io.ReadWriter, skey, initialPayload []byte, cryptoProvides CryptoMethod,
+) (
+ ret io.ReadWriter, method CryptoMethod, err error,
+) {
h := handshake{
- conn: rw,
- initer: true,
- skey: skey,
- ia: initialPayload,
+ conn: rw,
+ initer: true,
+ skey: skey,
+ ia: initialPayload,
+ cryptoProvides: cryptoProvides,
}
+ defer perf.ScopeTimerErr(&err)()
return h.Do()
}
-func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, err error) {
- h := handshake{
- conn: rw,
- initer: false,
- skeys: sliceIter(skeys),
- }
- return h.Do()
+type HandshakeResult struct {
+ io.ReadWriter
+ CryptoMethod
+ error
+ SecretKey []byte
}
-func sliceIter(skeys [][]byte) SecretKeyIter {
- return func(callback func([]byte) bool) {
- for _, sk := range skeys {
- if !callback(sk) {
- break
- }
- }
+func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) {
+ res := ReceiveHandshakeEx(rw, skeys, selectCrypto)
+ return res.ReadWriter, res.CryptoMethod, res.error
+}
+
+func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret HandshakeResult) {
+ h := handshake{
+ conn: rw,
+ initer: false,
+ skeys: skeys,
+ chooseMethod: selectCrypto,
}
+ ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do()
+ ret.SecretKey = h.skey
+ return
}
// A function that given a function, calls it with secret keys until it
// returns false or exhausted.
type SecretKeyIter func(callback func(skey []byte) (more bool))
-// Doesn't unpack the secret keys until it needs to, and through the passed
-// function.
-func ReceiveHandshakeLazy(rw io.ReadWriter, skeys SecretKeyIter) (ret io.ReadWriter, err error) {
- h := handshake{
- conn: rw,
- initer: false,
- skeys: skeys,
+func DefaultCryptoSelector(provided CryptoMethod) CryptoMethod {
+ // We prefer plaintext for performance reasons.
+ if provided&CryptoMethodPlaintext != 0 {
+ return CryptoMethodPlaintext
}
- return h.Do()
+ return CryptoMethodRC4
}
+
+type CryptoSelector func(CryptoMethod) CryptoMethod