]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse.go
Add mse.ReceiveHandshakeEx
[btrtrc.git] / mse / mse.go
1 // https://wiki.vuze.com/w/Message_Stream_Encryption
2
3 package mse
4
5 import (
6         "bytes"
7         "crypto/rand"
8         "crypto/rc4"
9         "crypto/sha1"
10         "encoding/binary"
11         "errors"
12         "expvar"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "math"
17         "math/big"
18         "strconv"
19         "sync"
20
21         "github.com/anacrolix/missinggo/perf"
22         "github.com/bradfitz/iter"
23 )
24
25 const (
26         maxPadLen = 512
27
28         CryptoMethodPlaintext CryptoMethod = 1 // After header obfuscation, drop into plaintext
29         CryptoMethodRC4       CryptoMethod = 2 // After header obfuscation, use RC4 for the rest of the stream
30         AllSupportedCrypto                 = CryptoMethodPlaintext | CryptoMethodRC4
31 )
32
33 type CryptoMethod uint32
34
35 var (
36         // Prime P according to the spec, and G, the generator.
37         p, g big.Int
38         // The rand.Int max arg for use in newPadLen()
39         newPadLenMax big.Int
40         // For use in initer's hashes
41         req1 = []byte("req1")
42         req2 = []byte("req2")
43         req3 = []byte("req3")
44         // Verification constant "VC" which is all zeroes in the bittorrent
45         // implementation.
46         vc [8]byte
47         // Zero padding
48         zeroPad [512]byte
49         // Tracks counts of received crypto_provides
50         cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
51 )
52
53 func init() {
54         p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
55         g.SetInt64(2)
56         newPadLenMax.SetInt64(maxPadLen + 1)
57 }
58
59 func hash(parts ...[]byte) []byte {
60         h := sha1.New()
61         for _, p := range parts {
62                 n, err := h.Write(p)
63                 if err != nil {
64                         panic(err)
65                 }
66                 if n != len(p) {
67                         panic(n)
68                 }
69         }
70         return h.Sum(nil)
71 }
72
73 func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
74         c, err := rc4.NewCipher(hash([]byte(func() string {
75                 if initer {
76                         return "keyA"
77                 } else {
78                         return "keyB"
79                 }
80         }()), s, skey))
81         if err != nil {
82                 panic(err)
83         }
84         var burnSrc, burnDst [1024]byte
85         c.XORKeyStream(burnDst[:], burnSrc[:])
86         return
87 }
88
89 type cipherReader struct {
90         c  *rc4.Cipher
91         r  io.Reader
92         mu sync.Mutex
93         be []byte
94 }
95
96 func (cr *cipherReader) Read(b []byte) (n int, err error) {
97         var be []byte
98         cr.mu.Lock()
99         if len(cr.be) >= len(b) {
100                 be = cr.be
101                 cr.be = nil
102                 cr.mu.Unlock()
103         } else {
104                 cr.mu.Unlock()
105                 be = make([]byte, len(b))
106         }
107         n, err = cr.r.Read(be[:len(b)])
108         cr.c.XORKeyStream(b[:n], be[:n])
109         cr.mu.Lock()
110         if len(be) > len(cr.be) {
111                 cr.be = be
112         }
113         cr.mu.Unlock()
114         return
115 }
116
117 func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
118         return &cipherReader{c: c, r: r}
119 }
120
121 type cipherWriter struct {
122         c *rc4.Cipher
123         w io.Writer
124         b []byte
125 }
126
127 func (cr *cipherWriter) Write(b []byte) (n int, err error) {
128         be := func() []byte {
129                 if len(cr.b) < len(b) {
130                         return make([]byte, len(b))
131                 } else {
132                         ret := cr.b
133                         cr.b = nil
134                         return ret
135                 }
136         }()
137         cr.c.XORKeyStream(be[:], b)
138         n, err = cr.w.Write(be[:len(b)])
139         if n != len(b) {
140                 // The cipher will have advanced beyond the callers stream position.
141                 // We can't use the cipher anymore.
142                 cr.c = nil
143         }
144         if len(be) > len(cr.b) {
145                 cr.b = be
146         }
147         return
148 }
149
150 func newX() big.Int {
151         var X big.Int
152         X.SetBytes(func() []byte {
153                 var b [20]byte
154                 _, err := rand.Read(b[:])
155                 if err != nil {
156                         panic(err)
157                 }
158                 return b[:]
159         }())
160         return X
161 }
162
163 func paddedLeft(b []byte, _len int) []byte {
164         if len(b) == _len {
165                 return b
166         }
167         ret := make([]byte, _len)
168         if n := copy(ret[_len-len(b):], b); n != len(b) {
169                 panic(n)
170         }
171         return ret
172 }
173
174 // Calculate, and send Y, our public key.
175 func (h *handshake) postY(x *big.Int) error {
176         var y big.Int
177         y.Exp(&g, x, &p)
178         return h.postWrite(paddedLeft(y.Bytes(), 96))
179 }
180
181 func (h *handshake) establishS() error {
182         x := newX()
183         h.postY(&x)
184         var b [96]byte
185         _, err := io.ReadFull(h.conn, b[:])
186         if err != nil {
187                 return fmt.Errorf("error reading Y: %s", err)
188         }
189         var Y, S big.Int
190         Y.SetBytes(b[:])
191         S.Exp(&Y, &x, &p)
192         sBytes := S.Bytes()
193         copy(h.s[96-len(sBytes):96], sBytes)
194         return nil
195 }
196
197 func newPadLen() int64 {
198         i, err := rand.Int(rand.Reader, &newPadLenMax)
199         if err != nil {
200                 panic(err)
201         }
202         ret := i.Int64()
203         if ret < 0 || ret > maxPadLen {
204                 panic(ret)
205         }
206         return ret
207 }
208
209 // Manages state for both initiating and receiving handshakes.
210 type handshake struct {
211         conn   io.ReadWriter
212         s      [96]byte
213         initer bool          // Whether we're initiating or receiving.
214         skeys  SecretKeyIter // Skeys we'll accept if receiving.
215         skey   []byte        // Skey we're initiating with.
216         ia     []byte        // Initial payload. Only used by the initiator.
217         // Return the bit for the crypto method the receiver wants to use.
218         chooseMethod CryptoSelector
219         // Sent to the receiver.
220         cryptoProvides CryptoMethod
221
222         writeMu    sync.Mutex
223         writes     [][]byte
224         writeErr   error
225         writeCond  sync.Cond
226         writeClose bool
227
228         writerMu   sync.Mutex
229         writerCond sync.Cond
230         writerDone bool
231 }
232
233 func (h *handshake) finishWriting() {
234         h.writeMu.Lock()
235         h.writeClose = true
236         h.writeCond.Broadcast()
237         h.writeMu.Unlock()
238
239         h.writerMu.Lock()
240         for !h.writerDone {
241                 h.writerCond.Wait()
242         }
243         h.writerMu.Unlock()
244 }
245
246 func (h *handshake) writer() {
247         defer func() {
248                 h.writerMu.Lock()
249                 h.writerDone = true
250                 h.writerCond.Broadcast()
251                 h.writerMu.Unlock()
252         }()
253         for {
254                 h.writeMu.Lock()
255                 for {
256                         if len(h.writes) != 0 {
257                                 break
258                         }
259                         if h.writeClose {
260                                 h.writeMu.Unlock()
261                                 return
262                         }
263                         h.writeCond.Wait()
264                 }
265                 b := h.writes[0]
266                 h.writes = h.writes[1:]
267                 h.writeMu.Unlock()
268                 _, err := h.conn.Write(b)
269                 if err != nil {
270                         h.writeMu.Lock()
271                         h.writeErr = err
272                         h.writeMu.Unlock()
273                         return
274                 }
275         }
276 }
277
278 func (h *handshake) postWrite(b []byte) error {
279         h.writeMu.Lock()
280         defer h.writeMu.Unlock()
281         if h.writeErr != nil {
282                 return h.writeErr
283         }
284         h.writes = append(h.writes, b)
285         h.writeCond.Signal()
286         return nil
287 }
288
289 func xor(dst, src []byte) (ret []byte) {
290         max := len(dst)
291         if max > len(src) {
292                 max = len(src)
293         }
294         ret = make([]byte, 0, max)
295         for i := range iter.N(max) {
296                 ret = append(ret, dst[i]^src[i])
297         }
298         return
299 }
300
301 func marshal(w io.Writer, data ...interface{}) (err error) {
302         for _, data := range data {
303                 err = binary.Write(w, binary.BigEndian, data)
304                 if err != nil {
305                         break
306                 }
307         }
308         return
309 }
310
311 func unmarshal(r io.Reader, data ...interface{}) (err error) {
312         for _, data := range data {
313                 err = binary.Read(r, binary.BigEndian, data)
314                 if err != nil {
315                         break
316                 }
317         }
318         return
319 }
320
321 // Looking for b at the end of a.
322 func suffixMatchLen(a, b []byte) int {
323         if len(b) > len(a) {
324                 b = b[:len(a)]
325         }
326         // i is how much of b to try to match
327         for i := len(b); i > 0; i-- {
328                 // j is how many chars we've compared
329                 j := 0
330                 for ; j < i; j++ {
331                         if b[i-1-j] != a[len(a)-1-j] {
332                                 goto shorter
333                         }
334                 }
335                 return j
336         shorter:
337         }
338         return 0
339 }
340
341 // Reads from r until b has been seen. Keeps the minimum amount of data in
342 // memory.
343 func readUntil(r io.Reader, b []byte) error {
344         b1 := make([]byte, len(b))
345         i := 0
346         for {
347                 _, err := io.ReadFull(r, b1[i:])
348                 if err != nil {
349                         return err
350                 }
351                 i = suffixMatchLen(b1, b)
352                 if i == len(b) {
353                         break
354                 }
355                 if copy(b1, b1[len(b1)-i:]) != i {
356                         panic("wat")
357                 }
358         }
359         return nil
360 }
361
362 type readWriter struct {
363         io.Reader
364         io.Writer
365 }
366
367 func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
368         return newEncrypt(initer, h.s[:], h.skey)
369 }
370
371 func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
372         h.postWrite(hash(req1, h.s[:]))
373         h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
374         buf := &bytes.Buffer{}
375         padLen := uint16(newPadLen())
376         if len(h.ia) > math.MaxUint16 {
377                 err = errors.New("initial payload too large")
378                 return
379         }
380         err = marshal(buf, vc[:], h.cryptoProvides, padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
381         if err != nil {
382                 return
383         }
384         e := h.newEncrypt(true)
385         be := make([]byte, buf.Len())
386         e.XORKeyStream(be, buf.Bytes())
387         h.postWrite(be)
388         bC := h.newEncrypt(false)
389         var eVC [8]byte
390         bC.XORKeyStream(eVC[:], vc[:])
391         // Read until the all zero VC. At this point we've only read the 96 byte
392         // public key, Y. There is potentially 512 byte padding, between us and
393         // the 8 byte verification constant.
394         err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
395         if err != nil {
396                 if err == io.EOF {
397                         err = errors.New("failed to synchronize on VC")
398                 } else {
399                         err = fmt.Errorf("error reading until VC: %s", err)
400                 }
401                 return
402         }
403         r := newCipherReader(bC, h.conn)
404         var method CryptoMethod
405         err = unmarshal(r, &method, &padLen)
406         if err != nil {
407                 return
408         }
409         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
410         if err != nil {
411                 return
412         }
413         selected = method & h.cryptoProvides
414         switch selected {
415         case CryptoMethodRC4:
416                 ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
417         case CryptoMethodPlaintext:
418                 ret = h.conn
419         default:
420                 err = fmt.Errorf("receiver chose unsupported method: %x", method)
421         }
422         return
423 }
424
425 var ErrNoSecretKeyMatch = errors.New("no skey matched")
426
427 func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
428         // There is up to 512 bytes of padding, then the 20 byte hash.
429         err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
430         if err != nil {
431                 if err == io.EOF {
432                         err = errors.New("failed to synchronize on S hash")
433                 }
434                 return
435         }
436         var b [20]byte
437         _, err = io.ReadFull(h.conn, b[:])
438         if err != nil {
439                 return
440         }
441         err = ErrNoSecretKeyMatch
442         h.skeys(func(skey []byte) bool {
443                 if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) {
444                         h.skey = skey
445                         err = nil
446                         return false
447                 }
448                 return true
449         })
450         if err != nil {
451                 return
452         }
453         r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
454         var (
455                 vc       [8]byte
456                 provides CryptoMethod
457                 padLen   uint16
458         )
459
460         err = unmarshal(r, vc[:], &provides, &padLen)
461         if err != nil {
462                 return
463         }
464         cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
465         chosen = h.chooseMethod(provides)
466         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
467         if err != nil {
468                 return
469         }
470         var lenIA uint16
471         unmarshal(r, &lenIA)
472         if lenIA != 0 {
473                 h.ia = make([]byte, lenIA)
474                 unmarshal(r, h.ia)
475         }
476         buf := &bytes.Buffer{}
477         w := cipherWriter{h.newEncrypt(false), buf, nil}
478         padLen = uint16(newPadLen())
479         err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen])
480         if err != nil {
481                 return
482         }
483         err = h.postWrite(buf.Bytes())
484         if err != nil {
485                 return
486         }
487         switch chosen {
488         case CryptoMethodRC4:
489                 ret = readWriter{
490                         io.MultiReader(bytes.NewReader(h.ia), r),
491                         &cipherWriter{w.c, h.conn, nil},
492                 }
493         case CryptoMethodPlaintext:
494                 ret = readWriter{
495                         io.MultiReader(bytes.NewReader(h.ia), h.conn),
496                         h.conn,
497                 }
498         default:
499                 err = errors.New("chosen crypto method is not supported")
500         }
501         return
502 }
503
504 func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
505         h.writeCond.L = &h.writeMu
506         h.writerCond.L = &h.writerMu
507         go h.writer()
508         defer func() {
509                 h.finishWriting()
510                 if err == nil {
511                         err = h.writeErr
512                 }
513         }()
514         err = h.establishS()
515         if err != nil {
516                 err = fmt.Errorf("error while establishing secret: %s", err)
517                 return
518         }
519         pad := make([]byte, newPadLen())
520         io.ReadFull(rand.Reader, pad)
521         err = h.postWrite(pad)
522         if err != nil {
523                 return
524         }
525         if h.initer {
526                 ret, method, err = h.initerSteps()
527         } else {
528                 ret, method, err = h.receiverSteps()
529         }
530         return
531 }
532
533 func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, method CryptoMethod, err error) {
534         h := handshake{
535                 conn:           rw,
536                 initer:         true,
537                 skey:           skey,
538                 ia:             initialPayload,
539                 cryptoProvides: cryptoProvides,
540         }
541         defer perf.ScopeTimerErr(&err)()
542         return h.Do()
543 }
544
545 type HandshakeResult struct {
546         io.ReadWriter
547         CryptoMethod
548         error
549         SecretKey []byte
550 }
551
552 func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) {
553         res := ReceiveHandshakeEx(rw, skeys, selectCrypto)
554         return res.ReadWriter, res.CryptoMethod, res.error
555 }
556
557 func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret HandshakeResult) {
558         h := handshake{
559                 conn:         rw,
560                 initer:       false,
561                 skeys:        skeys,
562                 chooseMethod: selectCrypto,
563         }
564         ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do()
565         ret.SecretKey = h.skey
566         return
567 }
568
569 // A function that given a function, calls it with secret keys until it
570 // returns false or exhausted.
571 type SecretKeyIter func(callback func(skey []byte) (more bool))
572
573 func DefaultCryptoSelector(provided CryptoMethod) CryptoMethod {
574         // We prefer plaintext for performance reasons.
575         if provided&CryptoMethodPlaintext != 0 {
576                 return CryptoMethodPlaintext
577         }
578         return CryptoMethodRC4
579 }
580
581 type CryptoSelector func(CryptoMethod) CryptoMethod