"crypto/sha1"
"encoding/binary"
"errors"
+ "expvar"
"fmt"
"io"
- "io/ioutil"
- "log"
+ "math"
"math/big"
+ "strconv"
"sync"
- "github.com/bradfitz/iter"
+ "github.com/anacrolix/missinggo/perf"
)
const (
maxPadLen = 512
- cryptoMethodPlaintext = 1
- cryptoMethodRC4 = 2
+ CryptoMethodPlaintext CryptoMethod = 1 // After header obfuscation, drop into plaintext
+ CryptoMethodRC4 CryptoMethod = 2 // After header obfuscation, use RC4 for the rest of the stream
+ AllSupportedCrypto = CryptoMethodPlaintext | CryptoMethodRC4
)
+type CryptoMethod uint32
+
var (
// Prime P according to the spec, and G, the generator.
p, g big.Int
req1 = []byte("req1")
req2 = []byte("req2")
req3 = []byte("req3")
+ // Verification constant "VC" which is all zeroes in the bittorrent
+ // implementation.
+ vc [8]byte
+ // Zero padding
+ zeroPad [512]byte
+ // Tracks counts of received crypto_provides
+ cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
)
func init() {
return h.Sum(nil)
}
-func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
+func newEncrypt(initer bool, s, skey []byte) (c *rc4.Cipher) {
c, err := rc4.NewCipher(hash([]byte(func() string {
if initer {
return "keyA"
}
type cipherReader struct {
- c *rc4.Cipher
- r io.Reader
+ c *rc4.Cipher
+ r io.Reader
+ mu sync.Mutex
+ be []byte
}
-func (me *cipherReader) Read(b []byte) (n int, err error) {
- be := make([]byte, len(b))
- n, err = me.r.Read(be)
- me.c.XORKeyStream(b[:n], be[:n])
+func (cr *cipherReader) Read(b []byte) (n int, err error) {
+ var be []byte
+ cr.mu.Lock()
+ if len(cr.be) >= len(b) {
+ be = cr.be
+ cr.be = nil
+ cr.mu.Unlock()
+ } else {
+ cr.mu.Unlock()
+ be = make([]byte, len(b))
+ }
+ n, err = cr.r.Read(be[:len(b)])
+ cr.c.XORKeyStream(b[:n], be[:n])
+ cr.mu.Lock()
+ if len(be) > len(cr.be) {
+ cr.be = be
+ }
+ cr.mu.Unlock()
return
}
func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
- return &cipherReader{c, r}
+ return &cipherReader{c: c, r: r}
}
type cipherWriter struct {
c *rc4.Cipher
w io.Writer
+ b []byte
}
-func (me *cipherWriter) Write(b []byte) (n int, err error) {
- be := make([]byte, len(b))
- me.c.XORKeyStream(be, b)
- n, err = me.w.Write(be)
- if n != len(be) {
+func (cr *cipherWriter) Write(b []byte) (n int, err error) {
+ be := func() []byte {
+ if len(cr.b) < len(b) {
+ return make([]byte, len(b))
+ } else {
+ ret := cr.b
+ cr.b = nil
+ return ret
+ }
+ }()
+ cr.c.XORKeyStream(be, b)
+ n, err = cr.w.Write(be[:len(b)])
+ if n != len(b) {
// The cipher will have advanced beyond the callers stream position.
// We can't use the cipher anymore.
- me.c = nil
+ cr.c = nil
}
- return
-}
-
-func newCipherWriter(c *rc4.Cipher, w io.Writer) io.Writer {
- return &cipherWriter{c, w}
-}
-
-func readY(r io.Reader) (y big.Int, err error) {
- var b [96]byte
- _, err = io.ReadFull(r, b[:])
- if err != nil {
- return
+ if len(be) > len(cr.b) {
+ cr.b = be
}
- y.SetBytes(b[:])
return
}
return X
}
+func paddedLeft(b []byte, _len int) []byte {
+ if len(b) == _len {
+ return b
+ }
+ ret := make([]byte, _len)
+ if n := copy(ret[_len-len(b):], b); n != len(b) {
+ panic(n)
+ }
+ return ret
+}
+
+// Calculate, and send Y, our public key.
func (h *handshake) postY(x *big.Int) error {
var y big.Int
y.Exp(&g, x, &p)
- b := y.Bytes()
- if len(b) != 96 {
- panic(len(b))
- }
- return h.postWrite(b)
+ return h.postWrite(paddedLeft(y.Bytes(), 96))
}
-func (h *handshake) establishS() (err error) {
+func (h *handshake) establishS() error {
x := newX()
h.postY(&x)
var b [96]byte
- _, err = io.ReadFull(h.conn, b[:])
+ _, err := io.ReadFull(h.conn, b[:])
if err != nil {
- return
+ return fmt.Errorf("error reading Y: %w", err)
}
- var Y big.Int
+ var Y, S big.Int
Y.SetBytes(b[:])
- h.s.Exp(&Y, &x, &p)
- return
+ S.Exp(&Y, &x, &p)
+ sBytes := S.Bytes()
+ copy(h.s[96-len(sBytes):96], sBytes)
+ return nil
}
func newPadLen() int64 {
return ret
}
+// Manages state for both initiating and receiving handshakes.
type handshake struct {
- conn io.ReadWriteCloser
- s big.Int
- initer bool
- skey []byte
+ conn io.ReadWriter
+ s [96]byte
+ initer bool // Whether we're initiating or receiving.
+ skeys SecretKeyIter // Skeys we'll accept if receiving.
+ skey []byte // Skey we're initiating with.
+ ia []byte // Initial payload. Only used by the initiator.
+ // Return the bit for the crypto method the receiver wants to use.
+ chooseMethod CryptoSelector
+ // Sent to the receiver.
+ cryptoProvides CryptoMethod
writeMu sync.Mutex
writes [][]byte
writerDone bool
}
-func (h *handshake) finishWriting() (err error) {
+func (h *handshake) finishWriting() {
h.writeMu.Lock()
h.writeClose = true
h.writeCond.Broadcast()
- err = h.writeErr
h.writeMu.Unlock()
h.writerMu.Lock()
h.writerCond.Wait()
}
h.writerMu.Unlock()
- return
-
}
func (h *handshake) writer() {
return nil
}
-func xor(dst, src []byte) (ret []byte) {
- max := len(dst)
- if max > len(src) {
- max = len(src)
- }
- ret = make([]byte, 0, max)
- for i := range iter.N(max) {
- ret = append(ret, dst[i]^src[i])
+func xor(a, b []byte) (ret []byte) {
+ max := len(a)
+ if max > len(b) {
+ max = len(b)
}
+ ret = make([]byte, max)
+ xorInPlace(ret, a, b)
return
}
-type cryptoNegotiation struct {
- VC [8]byte
- Method uint32
- PadLen uint16
- IA []byte
+func xorInPlace(dst, a, b []byte) {
+ for i := range dst {
+ dst[i] = a[i] ^ b[i]
+ }
}
-func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
- _, err = io.ReadFull(r, me.VC[:])
- if err != nil {
- return
- }
- err = binary.Read(r, binary.BigEndian, &me.Method)
- if err != nil {
- return
- }
- err = binary.Read(r, binary.BigEndian, &me.PadLen)
- if err != nil {
- return
+func marshal(w io.Writer, data ...interface{}) (err error) {
+ for _, data := range data {
+ err = binary.Write(w, binary.BigEndian, data)
+ if err != nil {
+ break
+ }
}
- log.Print(me.PadLen)
- _, err = io.CopyN(ioutil.Discard, r, int64(me.PadLen))
return
}
-func (me *cryptoNegotiation) MarshalWriter(w io.Writer) (err error) {
- _, err = w.Write(me.VC[:])
- if err != nil {
- return
- }
- err = binary.Write(w, binary.BigEndian, me.Method)
- if err != nil {
- return
- }
- err = binary.Write(w, binary.BigEndian, me.PadLen)
- if err != nil {
- return
+func unmarshal(r io.Reader, data ...interface{}) (err error) {
+ for _, data := range data {
+ err = binary.Read(r, binary.BigEndian, data)
+ if err != nil {
+ break
+ }
}
- _, err = w.Write(make([]byte, me.PadLen))
return
}
return 0
}
+// Reads from r until b has been seen. Keeps the minimum amount of data in
+// memory.
func readUntil(r io.Reader, b []byte) error {
- log.Println("read until", b)
b1 := make([]byte, len(b))
i := 0
for {
io.Writer
}
-func (h *handshake) Do() (ret io.ReadWriter, err error) {
- err = h.establishS()
- if err != nil {
+func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
+ return newEncrypt(initer, h.s[:], h.skey)
+}
+
+func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
+ h.postWrite(hash(req1, h.s[:]))
+ h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
+ buf := &bytes.Buffer{}
+ padLen := uint16(newPadLen())
+ if len(h.ia) > math.MaxUint16 {
+ err = errors.New("initial payload too large")
return
}
- pad := make([]byte, newPadLen())
- io.ReadFull(rand.Reader, pad)
- err = h.postWrite(pad)
+ err = marshal(buf, vc[:], h.cryptoProvides, padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
if err != nil {
return
}
- if h.initer {
- h.postWrite(hash(req1, h.s.Bytes()))
- h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())))
- buf := &bytes.Buffer{}
- err = (&cryptoNegotiation{
- Method: cryptoMethodRC4,
- PadLen: uint16(newPadLen()),
- }).MarshalWriter(buf)
- if err != nil {
- return
- }
- e := newEncrypt(true, h.s.Bytes(), h.skey)
- be := make([]byte, buf.Len())
- e.XORKeyStream(be, buf.Bytes())
- h.postWrite(be)
- bC := newEncrypt(false, h.s.Bytes(), h.skey)
- var eVC [8]byte
- bC.XORKeyStream(eVC[:], make([]byte, 8))
- log.Print(eVC)
- // Read until the all zero VC.
- err = readUntil(h.conn, eVC[:])
- if err != nil {
+ e := h.newEncrypt(true)
+ be := make([]byte, buf.Len())
+ e.XORKeyStream(be, buf.Bytes())
+ h.postWrite(be)
+ bC := h.newEncrypt(false)
+ var eVC [8]byte
+ bC.XORKeyStream(eVC[:], vc[:])
+ // Read until the all zero VC. At this point we've only read the 96 byte
+ // public key, Y. There is potentially 512 byte padding, between us and
+ // the 8 byte verification constant.
+ err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
+ if err != nil {
+ if err == io.EOF {
+ err = errors.New("failed to synchronize on VC")
+ } else {
err = fmt.Errorf("error reading until VC: %s", err)
- return
- }
- var cn cryptoNegotiation
- r := &cipherReader{bC, h.conn}
- err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
- log.Printf("initer got %v", cn)
- if err != nil {
- err = fmt.Errorf("error reading crypto negotiation: %s", err)
- return
}
- ret = readWriter{r, &cipherWriter{bC, h.conn}}
- } else {
- err = readUntil(h.conn, hash(req1, h.s.Bytes()))
- if err != nil {
- return
- }
- var b [20]byte
- _, err = io.ReadFull(h.conn, b[:])
- if err != nil {
- return
- }
- if !bytes.Equal(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())), b[:]) {
- err = errors.New("skey doesn't match")
- return
+ return
+ }
+ r := newCipherReader(bC, h.conn)
+ var method CryptoMethod
+ err = unmarshal(r, &method, &padLen)
+ if err != nil {
+ return
+ }
+ _, err = io.CopyN(io.Discard, r, int64(padLen))
+ if err != nil {
+ return
+ }
+ selected = method & h.cryptoProvides
+ switch selected {
+ case CryptoMethodRC4:
+ ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
+ case CryptoMethodPlaintext:
+ ret = h.conn
+ default:
+ err = fmt.Errorf("receiver chose unsupported method: %x", method)
+ }
+ return
+}
+
+var ErrNoSecretKeyMatch = errors.New("no skey matched")
+
+func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
+ // There is up to 512 bytes of padding, then the 20 byte hash.
+ err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
+ if err != nil {
+ if err == io.EOF {
+ err = errors.New("failed to synchronize on S hash")
}
- var cn cryptoNegotiation
- r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn)
- err = cn.UnmarshalReader(r)
- if err != nil {
- return
+ return
+ }
+ var b [20]byte
+ _, err = io.ReadFull(h.conn, b[:])
+ if err != nil {
+ return
+ }
+ expectedHash := hash(req3, h.s[:])
+ eachHash := sha1.New()
+ var sum, xored [sha1.Size]byte
+ err = ErrNoSecretKeyMatch
+ h.skeys(func(skey []byte) bool {
+ eachHash.Reset()
+ eachHash.Write(req2)
+ eachHash.Write(skey)
+ eachHash.Sum(sum[:0])
+ xorInPlace(xored[:], sum[:], expectedHash)
+ if bytes.Equal(xored[:], b[:]) {
+ h.skey = skey
+ err = nil
+ return false
}
- log.Printf("receiver got %v", cn)
- if cn.Method&cryptoMethodRC4 == 0 {
- err = errors.New("no supported crypto methods were provided")
- return
+ return true
+ })
+ if err != nil {
+ return
+ }
+ r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
+ var (
+ vc [8]byte
+ provides CryptoMethod
+ padLen uint16
+ )
+
+ err = unmarshal(r, vc[:], &provides, &padLen)
+ if err != nil {
+ return
+ }
+ cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
+ chosen = h.chooseMethod(provides)
+ _, err = io.CopyN(io.Discard, r, int64(padLen))
+ if err != nil {
+ return
+ }
+ var lenIA uint16
+ unmarshal(r, &lenIA)
+ if lenIA != 0 {
+ h.ia = make([]byte, lenIA)
+ unmarshal(r, h.ia)
+ }
+ buf := &bytes.Buffer{}
+ w := cipherWriter{h.newEncrypt(false), buf, nil}
+ padLen = uint16(newPadLen())
+ err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen])
+ if err != nil {
+ return
+ }
+ err = h.postWrite(buf.Bytes())
+ if err != nil {
+ return
+ }
+ switch chosen {
+ case CryptoMethodRC4:
+ ret = readWriter{
+ io.MultiReader(bytes.NewReader(h.ia), r),
+ &cipherWriter{w.c, h.conn, nil},
}
- buf := &bytes.Buffer{}
- w := newCipherWriter(newEncrypt(false, h.s.Bytes(), h.skey), buf)
- err = (&cryptoNegotiation{
- Method: cryptoMethodRC4,
- PadLen: uint16(newPadLen()),
- }).MarshalWriter(w)
- if err != nil {
- return
+ case CryptoMethodPlaintext:
+ ret = readWriter{
+ io.MultiReader(bytes.NewReader(h.ia), h.conn),
+ h.conn,
}
- log.Println("encrypted VC", buf.Bytes()[:8])
- err = h.postWrite(buf.Bytes())
- if err != nil {
- return
+ default:
+ err = errors.New("chosen crypto method is not supported")
+ }
+ return
+}
+
+func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
+ h.writeCond.L = &h.writeMu
+ h.writerCond.L = &h.writerMu
+ go h.writer()
+ defer func() {
+ h.finishWriting()
+ if err == nil {
+ err = h.writeErr
}
- ret = readWriter{r, w}
+ }()
+ err = h.establishS()
+ if err != nil {
+ err = fmt.Errorf("error while establishing secret: %w", err)
+ return
}
- err = h.finishWriting()
+ pad := make([]byte, newPadLen())
+ io.ReadFull(rand.Reader, pad)
+ err = h.postWrite(pad)
if err != nil {
return
}
- ret = h.conn
+ if h.initer {
+ ret, method, err = h.initerSteps()
+ } else {
+ ret, method, err = h.receiverSteps()
+ }
return
}
-func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWriter, err error) {
+func InitiateHandshake(
+ rw io.ReadWriter, skey, initialPayload []byte, cryptoProvides CryptoMethod,
+) (
+ ret io.ReadWriter, method CryptoMethod, err error,
+) {
h := handshake{
- conn: rw,
- initer: initer,
- skey: skey,
+ conn: rw,
+ initer: true,
+ skey: skey,
+ ia: initialPayload,
+ cryptoProvides: cryptoProvides,
}
- h.writeCond.L = &h.writeMu
- h.writerCond.L = &h.writerMu
- go h.writer()
+ defer perf.ScopeTimerErr(&err)()
return h.Do()
}
+
+type HandshakeResult struct {
+ io.ReadWriter
+ CryptoMethod
+ error
+ SecretKey []byte
+}
+
+func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) {
+ res := ReceiveHandshakeEx(rw, skeys, selectCrypto)
+ return res.ReadWriter, res.CryptoMethod, res.error
+}
+
+func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret HandshakeResult) {
+ h := handshake{
+ conn: rw,
+ initer: false,
+ skeys: skeys,
+ chooseMethod: selectCrypto,
+ }
+ ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do()
+ ret.SecretKey = h.skey
+ return
+}
+
+// A function that given a function, calls it with secret keys until it
+// returns false or exhausted.
+type SecretKeyIter func(callback func(skey []byte) (more bool))
+
+func DefaultCryptoSelector(provided CryptoMethod) CryptoMethod {
+ // We prefer plaintext for performance reasons.
+ if provided&CryptoMethodPlaintext != 0 {
+ return CryptoMethodPlaintext
+ }
+ return CryptoMethodRC4
+}
+
+type CryptoSelector func(CryptoMethod) CryptoMethod