]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse.go
Note that torrent.Reader is not concurrent-safe
[btrtrc.git] / mse / mse.go
1 // https://wiki.vuze.com/w/Message_Stream_Encryption
2
3 package mse
4
5 import (
6         "bytes"
7         "crypto/rand"
8         "crypto/rc4"
9         "crypto/sha1"
10         "encoding/binary"
11         "errors"
12         "expvar"
13         "fmt"
14         "io"
15         "math"
16         "math/big"
17         "strconv"
18         "sync"
19
20         "github.com/anacrolix/missinggo/perf"
21 )
22
23 const (
24         maxPadLen = 512
25
26         CryptoMethodPlaintext CryptoMethod = 1 // After header obfuscation, drop into plaintext
27         CryptoMethodRC4       CryptoMethod = 2 // After header obfuscation, use RC4 for the rest of the stream
28         AllSupportedCrypto                 = CryptoMethodPlaintext | CryptoMethodRC4
29 )
30
31 type CryptoMethod uint32
32
33 var (
34         // Prime P according to the spec, and G, the generator.
35         p, g big.Int
36         // The rand.Int max arg for use in newPadLen()
37         newPadLenMax big.Int
38         // For use in initer's hashes
39         req1 = []byte("req1")
40         req2 = []byte("req2")
41         req3 = []byte("req3")
42         // Verification constant "VC" which is all zeroes in the bittorrent
43         // implementation.
44         vc [8]byte
45         // Zero padding
46         zeroPad [512]byte
47         // Tracks counts of received crypto_provides
48         cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
49 )
50
51 func init() {
52         p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
53         g.SetInt64(2)
54         newPadLenMax.SetInt64(maxPadLen + 1)
55 }
56
57 func hash(parts ...[]byte) []byte {
58         h := sha1.New()
59         for _, p := range parts {
60                 n, err := h.Write(p)
61                 if err != nil {
62                         panic(err)
63                 }
64                 if n != len(p) {
65                         panic(n)
66                 }
67         }
68         return h.Sum(nil)
69 }
70
71 func newEncrypt(initer bool, s, skey []byte) (c *rc4.Cipher) {
72         c, err := rc4.NewCipher(hash([]byte(func() string {
73                 if initer {
74                         return "keyA"
75                 } else {
76                         return "keyB"
77                 }
78         }()), s, skey))
79         if err != nil {
80                 panic(err)
81         }
82         var burnSrc, burnDst [1024]byte
83         c.XORKeyStream(burnDst[:], burnSrc[:])
84         return
85 }
86
87 type cipherReader struct {
88         c  *rc4.Cipher
89         r  io.Reader
90         mu sync.Mutex
91         be []byte
92 }
93
94 func (cr *cipherReader) Read(b []byte) (n int, err error) {
95         var be []byte
96         cr.mu.Lock()
97         if len(cr.be) >= len(b) {
98                 be = cr.be
99                 cr.be = nil
100                 cr.mu.Unlock()
101         } else {
102                 cr.mu.Unlock()
103                 be = make([]byte, len(b))
104         }
105         n, err = cr.r.Read(be[:len(b)])
106         cr.c.XORKeyStream(b[:n], be[:n])
107         cr.mu.Lock()
108         if len(be) > len(cr.be) {
109                 cr.be = be
110         }
111         cr.mu.Unlock()
112         return
113 }
114
115 func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
116         return &cipherReader{c: c, r: r}
117 }
118
119 type cipherWriter struct {
120         c *rc4.Cipher
121         w io.Writer
122         b []byte
123 }
124
125 func (cr *cipherWriter) Write(b []byte) (n int, err error) {
126         be := func() []byte {
127                 if len(cr.b) < len(b) {
128                         return make([]byte, len(b))
129                 } else {
130                         ret := cr.b
131                         cr.b = nil
132                         return ret
133                 }
134         }()
135         cr.c.XORKeyStream(be, b)
136         n, err = cr.w.Write(be[:len(b)])
137         if n != len(b) {
138                 // The cipher will have advanced beyond the callers stream position.
139                 // We can't use the cipher anymore.
140                 cr.c = nil
141         }
142         if len(be) > len(cr.b) {
143                 cr.b = be
144         }
145         return
146 }
147
148 func newX() big.Int {
149         var X big.Int
150         X.SetBytes(func() []byte {
151                 var b [20]byte
152                 _, err := rand.Read(b[:])
153                 if err != nil {
154                         panic(err)
155                 }
156                 return b[:]
157         }())
158         return X
159 }
160
161 func paddedLeft(b []byte, _len int) []byte {
162         if len(b) == _len {
163                 return b
164         }
165         ret := make([]byte, _len)
166         if n := copy(ret[_len-len(b):], b); n != len(b) {
167                 panic(n)
168         }
169         return ret
170 }
171
172 // Calculate, and send Y, our public key.
173 func (h *handshake) postY(x *big.Int) error {
174         var y big.Int
175         y.Exp(&g, x, &p)
176         return h.postWrite(paddedLeft(y.Bytes(), 96))
177 }
178
179 func (h *handshake) establishS() error {
180         x := newX()
181         h.postY(&x)
182         var b [96]byte
183         _, err := io.ReadFull(h.conn, b[:])
184         if err != nil {
185                 return fmt.Errorf("error reading Y: %w", err)
186         }
187         var Y, S big.Int
188         Y.SetBytes(b[:])
189         S.Exp(&Y, &x, &p)
190         sBytes := S.Bytes()
191         copy(h.s[96-len(sBytes):96], sBytes)
192         return nil
193 }
194
195 func newPadLen() int64 {
196         i, err := rand.Int(rand.Reader, &newPadLenMax)
197         if err != nil {
198                 panic(err)
199         }
200         ret := i.Int64()
201         if ret < 0 || ret > maxPadLen {
202                 panic(ret)
203         }
204         return ret
205 }
206
207 // Manages state for both initiating and receiving handshakes.
208 type handshake struct {
209         conn   io.ReadWriter
210         s      [96]byte
211         initer bool          // Whether we're initiating or receiving.
212         skeys  SecretKeyIter // Skeys we'll accept if receiving.
213         skey   []byte        // Skey we're initiating with.
214         ia     []byte        // Initial payload. Only used by the initiator.
215         // Return the bit for the crypto method the receiver wants to use.
216         chooseMethod CryptoSelector
217         // Sent to the receiver.
218         cryptoProvides CryptoMethod
219
220         writeMu    sync.Mutex
221         writes     [][]byte
222         writeErr   error
223         writeCond  sync.Cond
224         writeClose bool
225
226         writerMu   sync.Mutex
227         writerCond sync.Cond
228         writerDone bool
229 }
230
231 func (h *handshake) finishWriting() {
232         h.writeMu.Lock()
233         h.writeClose = true
234         h.writeCond.Broadcast()
235         h.writeMu.Unlock()
236
237         h.writerMu.Lock()
238         for !h.writerDone {
239                 h.writerCond.Wait()
240         }
241         h.writerMu.Unlock()
242 }
243
244 func (h *handshake) writer() {
245         defer func() {
246                 h.writerMu.Lock()
247                 h.writerDone = true
248                 h.writerCond.Broadcast()
249                 h.writerMu.Unlock()
250         }()
251         for {
252                 h.writeMu.Lock()
253                 for {
254                         if len(h.writes) != 0 {
255                                 break
256                         }
257                         if h.writeClose {
258                                 h.writeMu.Unlock()
259                                 return
260                         }
261                         h.writeCond.Wait()
262                 }
263                 b := h.writes[0]
264                 h.writes = h.writes[1:]
265                 h.writeMu.Unlock()
266                 _, err := h.conn.Write(b)
267                 if err != nil {
268                         h.writeMu.Lock()
269                         h.writeErr = err
270                         h.writeMu.Unlock()
271                         return
272                 }
273         }
274 }
275
276 func (h *handshake) postWrite(b []byte) error {
277         h.writeMu.Lock()
278         defer h.writeMu.Unlock()
279         if h.writeErr != nil {
280                 return h.writeErr
281         }
282         h.writes = append(h.writes, b)
283         h.writeCond.Signal()
284         return nil
285 }
286
287 func xor(a, b []byte) (ret []byte) {
288         max := len(a)
289         if max > len(b) {
290                 max = len(b)
291         }
292         ret = make([]byte, max)
293         xorInPlace(ret, a, b)
294         return
295 }
296
297 func xorInPlace(dst, a, b []byte) {
298         for i := range dst {
299                 dst[i] = a[i] ^ b[i]
300         }
301 }
302
303 func marshal(w io.Writer, data ...interface{}) (err error) {
304         for _, data := range data {
305                 err = binary.Write(w, binary.BigEndian, data)
306                 if err != nil {
307                         break
308                 }
309         }
310         return
311 }
312
313 func unmarshal(r io.Reader, data ...interface{}) (err error) {
314         for _, data := range data {
315                 err = binary.Read(r, binary.BigEndian, data)
316                 if err != nil {
317                         break
318                 }
319         }
320         return
321 }
322
323 // Looking for b at the end of a.
324 func suffixMatchLen(a, b []byte) int {
325         if len(b) > len(a) {
326                 b = b[:len(a)]
327         }
328         // i is how much of b to try to match
329         for i := len(b); i > 0; i-- {
330                 // j is how many chars we've compared
331                 j := 0
332                 for ; j < i; j++ {
333                         if b[i-1-j] != a[len(a)-1-j] {
334                                 goto shorter
335                         }
336                 }
337                 return j
338         shorter:
339         }
340         return 0
341 }
342
343 // Reads from r until b has been seen. Keeps the minimum amount of data in
344 // memory.
345 func readUntil(r io.Reader, b []byte) error {
346         b1 := make([]byte, len(b))
347         i := 0
348         for {
349                 _, err := io.ReadFull(r, b1[i:])
350                 if err != nil {
351                         return err
352                 }
353                 i = suffixMatchLen(b1, b)
354                 if i == len(b) {
355                         break
356                 }
357                 if copy(b1, b1[len(b1)-i:]) != i {
358                         panic("wat")
359                 }
360         }
361         return nil
362 }
363
364 type readWriter struct {
365         io.Reader
366         io.Writer
367 }
368
369 func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
370         return newEncrypt(initer, h.s[:], h.skey)
371 }
372
373 func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
374         h.postWrite(hash(req1, h.s[:]))
375         h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
376         buf := &bytes.Buffer{}
377         padLen := uint16(newPadLen())
378         if len(h.ia) > math.MaxUint16 {
379                 err = errors.New("initial payload too large")
380                 return
381         }
382         err = marshal(buf, vc[:], h.cryptoProvides, padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
383         if err != nil {
384                 return
385         }
386         e := h.newEncrypt(true)
387         be := make([]byte, buf.Len())
388         e.XORKeyStream(be, buf.Bytes())
389         h.postWrite(be)
390         bC := h.newEncrypt(false)
391         var eVC [8]byte
392         bC.XORKeyStream(eVC[:], vc[:])
393         // Read until the all zero VC. At this point we've only read the 96 byte
394         // public key, Y. There is potentially 512 byte padding, between us and
395         // the 8 byte verification constant.
396         err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
397         if err != nil {
398                 if err == io.EOF {
399                         err = errors.New("failed to synchronize on VC")
400                 } else {
401                         err = fmt.Errorf("error reading until VC: %s", err)
402                 }
403                 return
404         }
405         r := newCipherReader(bC, h.conn)
406         var method CryptoMethod
407         err = unmarshal(r, &method, &padLen)
408         if err != nil {
409                 return
410         }
411         _, err = io.CopyN(io.Discard, r, int64(padLen))
412         if err != nil {
413                 return
414         }
415         selected = method & h.cryptoProvides
416         switch selected {
417         case CryptoMethodRC4:
418                 ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
419         case CryptoMethodPlaintext:
420                 ret = h.conn
421         default:
422                 err = fmt.Errorf("receiver chose unsupported method: %x", method)
423         }
424         return
425 }
426
427 var ErrNoSecretKeyMatch = errors.New("no skey matched")
428
429 func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
430         // There is up to 512 bytes of padding, then the 20 byte hash.
431         err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
432         if err != nil {
433                 if err == io.EOF {
434                         err = errors.New("failed to synchronize on S hash")
435                 }
436                 return
437         }
438         var b [20]byte
439         _, err = io.ReadFull(h.conn, b[:])
440         if err != nil {
441                 return
442         }
443         expectedHash := hash(req3, h.s[:])
444         eachHash := sha1.New()
445         var sum, xored [sha1.Size]byte
446         err = ErrNoSecretKeyMatch
447         h.skeys(func(skey []byte) bool {
448                 eachHash.Reset()
449                 eachHash.Write(req2)
450                 eachHash.Write(skey)
451                 eachHash.Sum(sum[:0])
452                 xorInPlace(xored[:], sum[:], expectedHash)
453                 if bytes.Equal(xored[:], b[:]) {
454                         h.skey = skey
455                         err = nil
456                         return false
457                 }
458                 return true
459         })
460         if err != nil {
461                 return
462         }
463         r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
464         var (
465                 vc       [8]byte
466                 provides CryptoMethod
467                 padLen   uint16
468         )
469
470         err = unmarshal(r, vc[:], &provides, &padLen)
471         if err != nil {
472                 return
473         }
474         cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
475         chosen = h.chooseMethod(provides)
476         _, err = io.CopyN(io.Discard, r, int64(padLen))
477         if err != nil {
478                 return
479         }
480         var lenIA uint16
481         unmarshal(r, &lenIA)
482         if lenIA != 0 {
483                 h.ia = make([]byte, lenIA)
484                 unmarshal(r, h.ia)
485         }
486         buf := &bytes.Buffer{}
487         w := cipherWriter{h.newEncrypt(false), buf, nil}
488         padLen = uint16(newPadLen())
489         err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen])
490         if err != nil {
491                 return
492         }
493         err = h.postWrite(buf.Bytes())
494         if err != nil {
495                 return
496         }
497         switch chosen {
498         case CryptoMethodRC4:
499                 ret = readWriter{
500                         io.MultiReader(bytes.NewReader(h.ia), r),
501                         &cipherWriter{w.c, h.conn, nil},
502                 }
503         case CryptoMethodPlaintext:
504                 ret = readWriter{
505                         io.MultiReader(bytes.NewReader(h.ia), h.conn),
506                         h.conn,
507                 }
508         default:
509                 err = errors.New("chosen crypto method is not supported")
510         }
511         return
512 }
513
514 func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
515         h.writeCond.L = &h.writeMu
516         h.writerCond.L = &h.writerMu
517         go h.writer()
518         defer func() {
519                 h.finishWriting()
520                 if err == nil {
521                         err = h.writeErr
522                 }
523         }()
524         err = h.establishS()
525         if err != nil {
526                 err = fmt.Errorf("error while establishing secret: %w", err)
527                 return
528         }
529         pad := make([]byte, newPadLen())
530         io.ReadFull(rand.Reader, pad)
531         err = h.postWrite(pad)
532         if err != nil {
533                 return
534         }
535         if h.initer {
536                 ret, method, err = h.initerSteps()
537         } else {
538                 ret, method, err = h.receiverSteps()
539         }
540         return
541 }
542
543 func InitiateHandshake(
544         rw io.ReadWriter, skey, initialPayload []byte, cryptoProvides CryptoMethod,
545 ) (
546         ret io.ReadWriter, method CryptoMethod, err error,
547 ) {
548         h := handshake{
549                 conn:           rw,
550                 initer:         true,
551                 skey:           skey,
552                 ia:             initialPayload,
553                 cryptoProvides: cryptoProvides,
554         }
555         defer perf.ScopeTimerErr(&err)()
556         return h.Do()
557 }
558
559 type HandshakeResult struct {
560         io.ReadWriter
561         CryptoMethod
562         error
563         SecretKey []byte
564 }
565
566 func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) {
567         res := ReceiveHandshakeEx(rw, skeys, selectCrypto)
568         return res.ReadWriter, res.CryptoMethod, res.error
569 }
570
571 func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret HandshakeResult) {
572         h := handshake{
573                 conn:         rw,
574                 initer:       false,
575                 skeys:        skeys,
576                 chooseMethod: selectCrypto,
577         }
578         ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do()
579         ret.SecretKey = h.skey
580         return
581 }
582
583 // A function that given a function, calls it with secret keys until it
584 // returns false or exhausted.
585 type SecretKeyIter func(callback func(skey []byte) (more bool))
586
587 func DefaultCryptoSelector(provided CryptoMethod) CryptoMethod {
588         // We prefer plaintext for performance reasons.
589         if provided&CryptoMethodPlaintext != 0 {
590                 return CryptoMethodPlaintext
591         }
592         return CryptoMethodRC4
593 }
594
595 type CryptoSelector func(CryptoMethod) CryptoMethod