]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse.go
Fix UseSources panicking when sqlite storage is closed
[btrtrc.git] / mse / mse.go
1 // https://wiki.vuze.com/w/Message_Stream_Encryption
2
3 package mse
4
5 import (
6         "bytes"
7         "crypto/rand"
8         "crypto/rc4"
9         "crypto/sha1"
10         "encoding/binary"
11         "errors"
12         "expvar"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "math"
17         "math/big"
18         "strconv"
19         "sync"
20
21         "github.com/anacrolix/missinggo/perf"
22 )
23
24 const (
25         maxPadLen = 512
26
27         CryptoMethodPlaintext CryptoMethod = 1 // After header obfuscation, drop into plaintext
28         CryptoMethodRC4       CryptoMethod = 2 // After header obfuscation, use RC4 for the rest of the stream
29         AllSupportedCrypto                 = CryptoMethodPlaintext | CryptoMethodRC4
30 )
31
32 type CryptoMethod uint32
33
34 var (
35         // Prime P according to the spec, and G, the generator.
36         p, g big.Int
37         // The rand.Int max arg for use in newPadLen()
38         newPadLenMax big.Int
39         // For use in initer's hashes
40         req1 = []byte("req1")
41         req2 = []byte("req2")
42         req3 = []byte("req3")
43         // Verification constant "VC" which is all zeroes in the bittorrent
44         // implementation.
45         vc [8]byte
46         // Zero padding
47         zeroPad [512]byte
48         // Tracks counts of received crypto_provides
49         cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
50 )
51
52 func init() {
53         p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
54         g.SetInt64(2)
55         newPadLenMax.SetInt64(maxPadLen + 1)
56 }
57
58 func hash(parts ...[]byte) []byte {
59         h := sha1.New()
60         for _, p := range parts {
61                 n, err := h.Write(p)
62                 if err != nil {
63                         panic(err)
64                 }
65                 if n != len(p) {
66                         panic(n)
67                 }
68         }
69         return h.Sum(nil)
70 }
71
72 func newEncrypt(initer bool, s, skey []byte) (c *rc4.Cipher) {
73         c, err := rc4.NewCipher(hash([]byte(func() string {
74                 if initer {
75                         return "keyA"
76                 } else {
77                         return "keyB"
78                 }
79         }()), s, skey))
80         if err != nil {
81                 panic(err)
82         }
83         var burnSrc, burnDst [1024]byte
84         c.XORKeyStream(burnDst[:], burnSrc[:])
85         return
86 }
87
88 type cipherReader struct {
89         c  *rc4.Cipher
90         r  io.Reader
91         mu sync.Mutex
92         be []byte
93 }
94
95 func (cr *cipherReader) Read(b []byte) (n int, err error) {
96         var be []byte
97         cr.mu.Lock()
98         if len(cr.be) >= len(b) {
99                 be = cr.be
100                 cr.be = nil
101                 cr.mu.Unlock()
102         } else {
103                 cr.mu.Unlock()
104                 be = make([]byte, len(b))
105         }
106         n, err = cr.r.Read(be[:len(b)])
107         cr.c.XORKeyStream(b[:n], be[:n])
108         cr.mu.Lock()
109         if len(be) > len(cr.be) {
110                 cr.be = be
111         }
112         cr.mu.Unlock()
113         return
114 }
115
116 func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
117         return &cipherReader{c: c, r: r}
118 }
119
120 type cipherWriter struct {
121         c *rc4.Cipher
122         w io.Writer
123         b []byte
124 }
125
126 func (cr *cipherWriter) Write(b []byte) (n int, err error) {
127         be := func() []byte {
128                 if len(cr.b) < len(b) {
129                         return make([]byte, len(b))
130                 } else {
131                         ret := cr.b
132                         cr.b = nil
133                         return ret
134                 }
135         }()
136         cr.c.XORKeyStream(be, b)
137         n, err = cr.w.Write(be[:len(b)])
138         if n != len(b) {
139                 // The cipher will have advanced beyond the callers stream position.
140                 // We can't use the cipher anymore.
141                 cr.c = nil
142         }
143         if len(be) > len(cr.b) {
144                 cr.b = be
145         }
146         return
147 }
148
149 func newX() big.Int {
150         var X big.Int
151         X.SetBytes(func() []byte {
152                 var b [20]byte
153                 _, err := rand.Read(b[:])
154                 if err != nil {
155                         panic(err)
156                 }
157                 return b[:]
158         }())
159         return X
160 }
161
162 func paddedLeft(b []byte, _len int) []byte {
163         if len(b) == _len {
164                 return b
165         }
166         ret := make([]byte, _len)
167         if n := copy(ret[_len-len(b):], b); n != len(b) {
168                 panic(n)
169         }
170         return ret
171 }
172
173 // Calculate, and send Y, our public key.
174 func (h *handshake) postY(x *big.Int) error {
175         var y big.Int
176         y.Exp(&g, x, &p)
177         return h.postWrite(paddedLeft(y.Bytes(), 96))
178 }
179
180 func (h *handshake) establishS() error {
181         x := newX()
182         h.postY(&x)
183         var b [96]byte
184         _, err := io.ReadFull(h.conn, b[:])
185         if err != nil {
186                 return fmt.Errorf("error reading Y: %w", err)
187         }
188         var Y, S big.Int
189         Y.SetBytes(b[:])
190         S.Exp(&Y, &x, &p)
191         sBytes := S.Bytes()
192         copy(h.s[96-len(sBytes):96], sBytes)
193         return nil
194 }
195
196 func newPadLen() int64 {
197         i, err := rand.Int(rand.Reader, &newPadLenMax)
198         if err != nil {
199                 panic(err)
200         }
201         ret := i.Int64()
202         if ret < 0 || ret > maxPadLen {
203                 panic(ret)
204         }
205         return ret
206 }
207
208 // Manages state for both initiating and receiving handshakes.
209 type handshake struct {
210         conn   io.ReadWriter
211         s      [96]byte
212         initer bool          // Whether we're initiating or receiving.
213         skeys  SecretKeyIter // Skeys we'll accept if receiving.
214         skey   []byte        // Skey we're initiating with.
215         ia     []byte        // Initial payload. Only used by the initiator.
216         // Return the bit for the crypto method the receiver wants to use.
217         chooseMethod CryptoSelector
218         // Sent to the receiver.
219         cryptoProvides CryptoMethod
220
221         writeMu    sync.Mutex
222         writes     [][]byte
223         writeErr   error
224         writeCond  sync.Cond
225         writeClose bool
226
227         writerMu   sync.Mutex
228         writerCond sync.Cond
229         writerDone bool
230 }
231
232 func (h *handshake) finishWriting() {
233         h.writeMu.Lock()
234         h.writeClose = true
235         h.writeCond.Broadcast()
236         h.writeMu.Unlock()
237
238         h.writerMu.Lock()
239         for !h.writerDone {
240                 h.writerCond.Wait()
241         }
242         h.writerMu.Unlock()
243 }
244
245 func (h *handshake) writer() {
246         defer func() {
247                 h.writerMu.Lock()
248                 h.writerDone = true
249                 h.writerCond.Broadcast()
250                 h.writerMu.Unlock()
251         }()
252         for {
253                 h.writeMu.Lock()
254                 for {
255                         if len(h.writes) != 0 {
256                                 break
257                         }
258                         if h.writeClose {
259                                 h.writeMu.Unlock()
260                                 return
261                         }
262                         h.writeCond.Wait()
263                 }
264                 b := h.writes[0]
265                 h.writes = h.writes[1:]
266                 h.writeMu.Unlock()
267                 _, err := h.conn.Write(b)
268                 if err != nil {
269                         h.writeMu.Lock()
270                         h.writeErr = err
271                         h.writeMu.Unlock()
272                         return
273                 }
274         }
275 }
276
277 func (h *handshake) postWrite(b []byte) error {
278         h.writeMu.Lock()
279         defer h.writeMu.Unlock()
280         if h.writeErr != nil {
281                 return h.writeErr
282         }
283         h.writes = append(h.writes, b)
284         h.writeCond.Signal()
285         return nil
286 }
287
288 func xor(a, b []byte) (ret []byte) {
289         max := len(a)
290         if max > len(b) {
291                 max = len(b)
292         }
293         ret = make([]byte, max)
294         xorInPlace(ret, a, b)
295         return
296 }
297
298 func xorInPlace(dst, a, b []byte) {
299         for i := range dst {
300                 dst[i] = a[i] ^ b[i]
301         }
302 }
303
304 func marshal(w io.Writer, data ...interface{}) (err error) {
305         for _, data := range data {
306                 err = binary.Write(w, binary.BigEndian, data)
307                 if err != nil {
308                         break
309                 }
310         }
311         return
312 }
313
314 func unmarshal(r io.Reader, data ...interface{}) (err error) {
315         for _, data := range data {
316                 err = binary.Read(r, binary.BigEndian, data)
317                 if err != nil {
318                         break
319                 }
320         }
321         return
322 }
323
324 // Looking for b at the end of a.
325 func suffixMatchLen(a, b []byte) int {
326         if len(b) > len(a) {
327                 b = b[:len(a)]
328         }
329         // i is how much of b to try to match
330         for i := len(b); i > 0; i-- {
331                 // j is how many chars we've compared
332                 j := 0
333                 for ; j < i; j++ {
334                         if b[i-1-j] != a[len(a)-1-j] {
335                                 goto shorter
336                         }
337                 }
338                 return j
339         shorter:
340         }
341         return 0
342 }
343
344 // Reads from r until b has been seen. Keeps the minimum amount of data in
345 // memory.
346 func readUntil(r io.Reader, b []byte) error {
347         b1 := make([]byte, len(b))
348         i := 0
349         for {
350                 _, err := io.ReadFull(r, b1[i:])
351                 if err != nil {
352                         return err
353                 }
354                 i = suffixMatchLen(b1, b)
355                 if i == len(b) {
356                         break
357                 }
358                 if copy(b1, b1[len(b1)-i:]) != i {
359                         panic("wat")
360                 }
361         }
362         return nil
363 }
364
365 type readWriter struct {
366         io.Reader
367         io.Writer
368 }
369
370 func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
371         return newEncrypt(initer, h.s[:], h.skey)
372 }
373
374 func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
375         h.postWrite(hash(req1, h.s[:]))
376         h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
377         buf := &bytes.Buffer{}
378         padLen := uint16(newPadLen())
379         if len(h.ia) > math.MaxUint16 {
380                 err = errors.New("initial payload too large")
381                 return
382         }
383         err = marshal(buf, vc[:], h.cryptoProvides, padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
384         if err != nil {
385                 return
386         }
387         e := h.newEncrypt(true)
388         be := make([]byte, buf.Len())
389         e.XORKeyStream(be, buf.Bytes())
390         h.postWrite(be)
391         bC := h.newEncrypt(false)
392         var eVC [8]byte
393         bC.XORKeyStream(eVC[:], vc[:])
394         // Read until the all zero VC. At this point we've only read the 96 byte
395         // public key, Y. There is potentially 512 byte padding, between us and
396         // the 8 byte verification constant.
397         err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
398         if err != nil {
399                 if err == io.EOF {
400                         err = errors.New("failed to synchronize on VC")
401                 } else {
402                         err = fmt.Errorf("error reading until VC: %s", err)
403                 }
404                 return
405         }
406         r := newCipherReader(bC, h.conn)
407         var method CryptoMethod
408         err = unmarshal(r, &method, &padLen)
409         if err != nil {
410                 return
411         }
412         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
413         if err != nil {
414                 return
415         }
416         selected = method & h.cryptoProvides
417         switch selected {
418         case CryptoMethodRC4:
419                 ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
420         case CryptoMethodPlaintext:
421                 ret = h.conn
422         default:
423                 err = fmt.Errorf("receiver chose unsupported method: %x", method)
424         }
425         return
426 }
427
428 var ErrNoSecretKeyMatch = errors.New("no skey matched")
429
430 func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
431         // There is up to 512 bytes of padding, then the 20 byte hash.
432         err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
433         if err != nil {
434                 if err == io.EOF {
435                         err = errors.New("failed to synchronize on S hash")
436                 }
437                 return
438         }
439         var b [20]byte
440         _, err = io.ReadFull(h.conn, b[:])
441         if err != nil {
442                 return
443         }
444         expectedHash := hash(req3, h.s[:])
445         eachHash := sha1.New()
446         var sum, xored [sha1.Size]byte
447         err = ErrNoSecretKeyMatch
448         h.skeys(func(skey []byte) bool {
449                 eachHash.Reset()
450                 eachHash.Write(req2)
451                 eachHash.Write(skey)
452                 eachHash.Sum(sum[:0])
453                 xorInPlace(xored[:], sum[:], expectedHash)
454                 if bytes.Equal(xored[:], b[:]) {
455                         h.skey = skey
456                         err = nil
457                         return false
458                 }
459                 return true
460         })
461         if err != nil {
462                 return
463         }
464         r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
465         var (
466                 vc       [8]byte
467                 provides CryptoMethod
468                 padLen   uint16
469         )
470
471         err = unmarshal(r, vc[:], &provides, &padLen)
472         if err != nil {
473                 return
474         }
475         cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
476         chosen = h.chooseMethod(provides)
477         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
478         if err != nil {
479                 return
480         }
481         var lenIA uint16
482         unmarshal(r, &lenIA)
483         if lenIA != 0 {
484                 h.ia = make([]byte, lenIA)
485                 unmarshal(r, h.ia)
486         }
487         buf := &bytes.Buffer{}
488         w := cipherWriter{h.newEncrypt(false), buf, nil}
489         padLen = uint16(newPadLen())
490         err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen])
491         if err != nil {
492                 return
493         }
494         err = h.postWrite(buf.Bytes())
495         if err != nil {
496                 return
497         }
498         switch chosen {
499         case CryptoMethodRC4:
500                 ret = readWriter{
501                         io.MultiReader(bytes.NewReader(h.ia), r),
502                         &cipherWriter{w.c, h.conn, nil},
503                 }
504         case CryptoMethodPlaintext:
505                 ret = readWriter{
506                         io.MultiReader(bytes.NewReader(h.ia), h.conn),
507                         h.conn,
508                 }
509         default:
510                 err = errors.New("chosen crypto method is not supported")
511         }
512         return
513 }
514
515 func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
516         h.writeCond.L = &h.writeMu
517         h.writerCond.L = &h.writerMu
518         go h.writer()
519         defer func() {
520                 h.finishWriting()
521                 if err == nil {
522                         err = h.writeErr
523                 }
524         }()
525         err = h.establishS()
526         if err != nil {
527                 err = fmt.Errorf("error while establishing secret: %w", err)
528                 return
529         }
530         pad := make([]byte, newPadLen())
531         io.ReadFull(rand.Reader, pad)
532         err = h.postWrite(pad)
533         if err != nil {
534                 return
535         }
536         if h.initer {
537                 ret, method, err = h.initerSteps()
538         } else {
539                 ret, method, err = h.receiverSteps()
540         }
541         return
542 }
543
544 func InitiateHandshake(
545         rw io.ReadWriter, skey, initialPayload []byte, cryptoProvides CryptoMethod,
546 ) (
547         ret io.ReadWriter, method CryptoMethod, err error,
548 ) {
549         h := handshake{
550                 conn:           rw,
551                 initer:         true,
552                 skey:           skey,
553                 ia:             initialPayload,
554                 cryptoProvides: cryptoProvides,
555         }
556         defer perf.ScopeTimerErr(&err)()
557         return h.Do()
558 }
559
560 type HandshakeResult struct {
561         io.ReadWriter
562         CryptoMethod
563         error
564         SecretKey []byte
565 }
566
567 func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) {
568         res := ReceiveHandshakeEx(rw, skeys, selectCrypto)
569         return res.ReadWriter, res.CryptoMethod, res.error
570 }
571
572 func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret HandshakeResult) {
573         h := handshake{
574                 conn:         rw,
575                 initer:       false,
576                 skeys:        skeys,
577                 chooseMethod: selectCrypto,
578         }
579         ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do()
580         ret.SecretKey = h.skey
581         return
582 }
583
584 // A function that given a function, calls it with secret keys until it
585 // returns false or exhausted.
586 type SecretKeyIter func(callback func(skey []byte) (more bool))
587
588 func DefaultCryptoSelector(provided CryptoMethod) CryptoMethod {
589         // We prefer plaintext for performance reasons.
590         if provided&CryptoMethodPlaintext != 0 {
591                 return CryptoMethodPlaintext
592         }
593         return CryptoMethodRC4
594 }
595
596 type CryptoSelector func(CryptoMethod) CryptoMethod