]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse.go
Merge branch 'dev'
[btrtrc.git] / mse / mse.go
1 // https://wiki.vuze.com/w/Message_Stream_Encryption
2
3 package mse
4
5 import (
6         "bytes"
7         "crypto/rand"
8         "crypto/rc4"
9         "crypto/sha1"
10         "encoding/binary"
11         "errors"
12         "expvar"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "math"
17         "math/big"
18         "strconv"
19         "sync"
20
21         "github.com/anacrolix/missinggo/perf"
22
23         "github.com/bradfitz/iter"
24 )
25
26 const (
27         maxPadLen = 512
28
29         CryptoMethodPlaintext CryptoMethod = 1
30         CryptoMethodRC4       CryptoMethod = 2
31         AllSupportedCrypto                 = CryptoMethodPlaintext | CryptoMethodRC4
32 )
33
34 type CryptoMethod uint32
35
36 var (
37         // Prime P according to the spec, and G, the generator.
38         p, g big.Int
39         // The rand.Int max arg for use in newPadLen()
40         newPadLenMax big.Int
41         // For use in initer's hashes
42         req1 = []byte("req1")
43         req2 = []byte("req2")
44         req3 = []byte("req3")
45         // Verification constant "VC" which is all zeroes in the bittorrent
46         // implementation.
47         vc [8]byte
48         // Zero padding
49         zeroPad [512]byte
50         // Tracks counts of received crypto_provides
51         cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
52 )
53
54 func init() {
55         p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
56         g.SetInt64(2)
57         newPadLenMax.SetInt64(maxPadLen + 1)
58 }
59
60 func hash(parts ...[]byte) []byte {
61         h := sha1.New()
62         for _, p := range parts {
63                 n, err := h.Write(p)
64                 if err != nil {
65                         panic(err)
66                 }
67                 if n != len(p) {
68                         panic(n)
69                 }
70         }
71         return h.Sum(nil)
72 }
73
74 func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
75         c, err := rc4.NewCipher(hash([]byte(func() string {
76                 if initer {
77                         return "keyA"
78                 } else {
79                         return "keyB"
80                 }
81         }()), s, skey))
82         if err != nil {
83                 panic(err)
84         }
85         var burnSrc, burnDst [1024]byte
86         c.XORKeyStream(burnDst[:], burnSrc[:])
87         return
88 }
89
90 type cipherReader struct {
91         c  *rc4.Cipher
92         r  io.Reader
93         mu sync.Mutex
94         be []byte
95 }
96
97 func (cr *cipherReader) Read(b []byte) (n int, err error) {
98         var be []byte
99         cr.mu.Lock()
100         if len(cr.be) >= len(b) {
101                 be = cr.be
102                 cr.be = nil
103                 cr.mu.Unlock()
104         } else {
105                 cr.mu.Unlock()
106                 be = make([]byte, len(b))
107         }
108         n, err = cr.r.Read(be[:len(b)])
109         cr.c.XORKeyStream(b[:n], be[:n])
110         cr.mu.Lock()
111         if len(be) > len(cr.be) {
112                 cr.be = be
113         }
114         cr.mu.Unlock()
115         return
116 }
117
118 func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
119         return &cipherReader{c: c, r: r}
120 }
121
122 type cipherWriter struct {
123         c *rc4.Cipher
124         w io.Writer
125         b []byte
126 }
127
128 func (cr *cipherWriter) Write(b []byte) (n int, err error) {
129         be := func() []byte {
130                 if len(cr.b) < len(b) {
131                         return make([]byte, len(b))
132                 } else {
133                         ret := cr.b
134                         cr.b = nil
135                         return ret
136                 }
137         }()
138         cr.c.XORKeyStream(be[:], b)
139         n, err = cr.w.Write(be[:len(b)])
140         if n != len(b) {
141                 // The cipher will have advanced beyond the callers stream position.
142                 // We can't use the cipher anymore.
143                 cr.c = nil
144         }
145         if len(be) > len(cr.b) {
146                 cr.b = be
147         }
148         return
149 }
150
151 func newX() big.Int {
152         var X big.Int
153         X.SetBytes(func() []byte {
154                 var b [20]byte
155                 _, err := rand.Read(b[:])
156                 if err != nil {
157                         panic(err)
158                 }
159                 return b[:]
160         }())
161         return X
162 }
163
164 func paddedLeft(b []byte, _len int) []byte {
165         if len(b) == _len {
166                 return b
167         }
168         ret := make([]byte, _len)
169         if n := copy(ret[_len-len(b):], b); n != len(b) {
170                 panic(n)
171         }
172         return ret
173 }
174
175 // Calculate, and send Y, our public key.
176 func (h *handshake) postY(x *big.Int) error {
177         var y big.Int
178         y.Exp(&g, x, &p)
179         return h.postWrite(paddedLeft(y.Bytes(), 96))
180 }
181
182 func (h *handshake) establishS() error {
183         x := newX()
184         h.postY(&x)
185         var b [96]byte
186         _, err := io.ReadFull(h.conn, b[:])
187         if err != nil {
188                 return fmt.Errorf("error reading Y: %s", err)
189         }
190         var Y, S big.Int
191         Y.SetBytes(b[:])
192         S.Exp(&Y, &x, &p)
193         sBytes := S.Bytes()
194         copy(h.s[96-len(sBytes):96], sBytes)
195         return nil
196 }
197
198 func newPadLen() int64 {
199         i, err := rand.Int(rand.Reader, &newPadLenMax)
200         if err != nil {
201                 panic(err)
202         }
203         ret := i.Int64()
204         if ret < 0 || ret > maxPadLen {
205                 panic(ret)
206         }
207         return ret
208 }
209
210 // Manages state for both initiating and receiving handshakes.
211 type handshake struct {
212         conn   io.ReadWriter
213         s      [96]byte
214         initer bool          // Whether we're initiating or receiving.
215         skeys  SecretKeyIter // Skeys we'll accept if receiving.
216         skey   []byte        // Skey we're initiating with.
217         ia     []byte        // Initial payload. Only used by the initiator.
218         // Return the bit for the crypto method the receiver wants to use.
219         chooseMethod CryptoSelector
220         // Sent to the receiver.
221         cryptoProvides CryptoMethod
222
223         writeMu    sync.Mutex
224         writes     [][]byte
225         writeErr   error
226         writeCond  sync.Cond
227         writeClose bool
228
229         writerMu   sync.Mutex
230         writerCond sync.Cond
231         writerDone bool
232 }
233
234 func (h *handshake) finishWriting() {
235         h.writeMu.Lock()
236         h.writeClose = true
237         h.writeCond.Broadcast()
238         h.writeMu.Unlock()
239
240         h.writerMu.Lock()
241         for !h.writerDone {
242                 h.writerCond.Wait()
243         }
244         h.writerMu.Unlock()
245 }
246
247 func (h *handshake) writer() {
248         defer func() {
249                 h.writerMu.Lock()
250                 h.writerDone = true
251                 h.writerCond.Broadcast()
252                 h.writerMu.Unlock()
253         }()
254         for {
255                 h.writeMu.Lock()
256                 for {
257                         if len(h.writes) != 0 {
258                                 break
259                         }
260                         if h.writeClose {
261                                 h.writeMu.Unlock()
262                                 return
263                         }
264                         h.writeCond.Wait()
265                 }
266                 b := h.writes[0]
267                 h.writes = h.writes[1:]
268                 h.writeMu.Unlock()
269                 _, err := h.conn.Write(b)
270                 if err != nil {
271                         h.writeMu.Lock()
272                         h.writeErr = err
273                         h.writeMu.Unlock()
274                         return
275                 }
276         }
277 }
278
279 func (h *handshake) postWrite(b []byte) error {
280         h.writeMu.Lock()
281         defer h.writeMu.Unlock()
282         if h.writeErr != nil {
283                 return h.writeErr
284         }
285         h.writes = append(h.writes, b)
286         h.writeCond.Signal()
287         return nil
288 }
289
290 func xor(dst, src []byte) (ret []byte) {
291         max := len(dst)
292         if max > len(src) {
293                 max = len(src)
294         }
295         ret = make([]byte, 0, max)
296         for i := range iter.N(max) {
297                 ret = append(ret, dst[i]^src[i])
298         }
299         return
300 }
301
302 func marshal(w io.Writer, data ...interface{}) (err error) {
303         for _, data := range data {
304                 err = binary.Write(w, binary.BigEndian, data)
305                 if err != nil {
306                         break
307                 }
308         }
309         return
310 }
311
312 func unmarshal(r io.Reader, data ...interface{}) (err error) {
313         for _, data := range data {
314                 err = binary.Read(r, binary.BigEndian, data)
315                 if err != nil {
316                         break
317                 }
318         }
319         return
320 }
321
322 // Looking for b at the end of a.
323 func suffixMatchLen(a, b []byte) int {
324         if len(b) > len(a) {
325                 b = b[:len(a)]
326         }
327         // i is how much of b to try to match
328         for i := len(b); i > 0; i-- {
329                 // j is how many chars we've compared
330                 j := 0
331                 for ; j < i; j++ {
332                         if b[i-1-j] != a[len(a)-1-j] {
333                                 goto shorter
334                         }
335                 }
336                 return j
337         shorter:
338         }
339         return 0
340 }
341
342 // Reads from r until b has been seen. Keeps the minimum amount of data in
343 // memory.
344 func readUntil(r io.Reader, b []byte) error {
345         b1 := make([]byte, len(b))
346         i := 0
347         for {
348                 _, err := io.ReadFull(r, b1[i:])
349                 if err != nil {
350                         return err
351                 }
352                 i = suffixMatchLen(b1, b)
353                 if i == len(b) {
354                         break
355                 }
356                 if copy(b1, b1[len(b1)-i:]) != i {
357                         panic("wat")
358                 }
359         }
360         return nil
361 }
362
363 type readWriter struct {
364         io.Reader
365         io.Writer
366 }
367
368 func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
369         return newEncrypt(initer, h.s[:], h.skey)
370 }
371
372 func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
373         h.postWrite(hash(req1, h.s[:]))
374         h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
375         buf := &bytes.Buffer{}
376         padLen := uint16(newPadLen())
377         if len(h.ia) > math.MaxUint16 {
378                 err = errors.New("initial payload too large")
379                 return
380         }
381         err = marshal(buf, vc[:], h.cryptoProvides, padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
382         if err != nil {
383                 return
384         }
385         e := h.newEncrypt(true)
386         be := make([]byte, buf.Len())
387         e.XORKeyStream(be, buf.Bytes())
388         h.postWrite(be)
389         bC := h.newEncrypt(false)
390         var eVC [8]byte
391         bC.XORKeyStream(eVC[:], vc[:])
392         // Read until the all zero VC. At this point we've only read the 96 byte
393         // public key, Y. There is potentially 512 byte padding, between us and
394         // the 8 byte verification constant.
395         err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
396         if err != nil {
397                 if err == io.EOF {
398                         err = errors.New("failed to synchronize on VC")
399                 } else {
400                         err = fmt.Errorf("error reading until VC: %s", err)
401                 }
402                 return
403         }
404         r := newCipherReader(bC, h.conn)
405         var method CryptoMethod
406         err = unmarshal(r, &method, &padLen)
407         if err != nil {
408                 return
409         }
410         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
411         if err != nil {
412                 return
413         }
414         selected = method & h.cryptoProvides
415         switch selected {
416         case CryptoMethodRC4:
417                 ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
418         case CryptoMethodPlaintext:
419                 ret = h.conn
420         default:
421                 err = fmt.Errorf("receiver chose unsupported method: %x", method)
422         }
423         return
424 }
425
426 var ErrNoSecretKeyMatch = errors.New("no skey matched")
427
428 func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
429         // There is up to 512 bytes of padding, then the 20 byte hash.
430         err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
431         if err != nil {
432                 if err == io.EOF {
433                         err = errors.New("failed to synchronize on S hash")
434                 }
435                 return
436         }
437         var b [20]byte
438         _, err = io.ReadFull(h.conn, b[:])
439         if err != nil {
440                 return
441         }
442         err = ErrNoSecretKeyMatch
443         h.skeys(func(skey []byte) bool {
444                 if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) {
445                         h.skey = skey
446                         err = nil
447                         return false
448                 }
449                 return true
450         })
451         if err != nil {
452                 return
453         }
454         r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
455         var (
456                 vc       [8]byte
457                 provides CryptoMethod
458                 padLen   uint16
459         )
460
461         err = unmarshal(r, vc[:], &provides, &padLen)
462         if err != nil {
463                 return
464         }
465         cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
466         chosen = h.chooseMethod(provides)
467         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
468         if err != nil {
469                 return
470         }
471         var lenIA uint16
472         unmarshal(r, &lenIA)
473         if lenIA != 0 {
474                 h.ia = make([]byte, lenIA)
475                 unmarshal(r, h.ia)
476         }
477         buf := &bytes.Buffer{}
478         w := cipherWriter{h.newEncrypt(false), buf, nil}
479         padLen = uint16(newPadLen())
480         err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen])
481         if err != nil {
482                 return
483         }
484         err = h.postWrite(buf.Bytes())
485         if err != nil {
486                 return
487         }
488         switch chosen {
489         case CryptoMethodRC4:
490                 ret = readWriter{
491                         io.MultiReader(bytes.NewReader(h.ia), r),
492                         &cipherWriter{w.c, h.conn, nil},
493                 }
494         case CryptoMethodPlaintext:
495                 ret = readWriter{
496                         io.MultiReader(bytes.NewReader(h.ia), h.conn),
497                         h.conn,
498                 }
499         default:
500                 err = errors.New("chosen crypto method is not supported")
501         }
502         return
503 }
504
505 func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
506         h.writeCond.L = &h.writeMu
507         h.writerCond.L = &h.writerMu
508         go h.writer()
509         defer func() {
510                 h.finishWriting()
511                 if err == nil {
512                         err = h.writeErr
513                 }
514         }()
515         err = h.establishS()
516         if err != nil {
517                 err = fmt.Errorf("error while establishing secret: %s", err)
518                 return
519         }
520         pad := make([]byte, newPadLen())
521         io.ReadFull(rand.Reader, pad)
522         err = h.postWrite(pad)
523         if err != nil {
524                 return
525         }
526         if h.initer {
527                 ret, method, err = h.initerSteps()
528         } else {
529                 ret, method, err = h.receiverSteps()
530         }
531         return
532 }
533
534 func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, method CryptoMethod, err error) {
535         h := handshake{
536                 conn:           rw,
537                 initer:         true,
538                 skey:           skey,
539                 ia:             initialPayload,
540                 cryptoProvides: cryptoProvides,
541         }
542         defer perf.ScopeTimerErr(&err)()
543         return h.Do()
544 }
545
546 func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, method CryptoMethod, err error) {
547         h := handshake{
548                 conn:         rw,
549                 initer:       false,
550                 skeys:        skeys,
551                 chooseMethod: selectCrypto,
552         }
553         return h.Do()
554 }
555
556 // A function that given a function, calls it with secret keys until it
557 // returns false or exhausted.
558 type SecretKeyIter func(callback func(skey []byte) (more bool))
559
560 func DefaultCryptoSelector(provided CryptoMethod) CryptoMethod {
561         if provided&CryptoMethodPlaintext != 0 {
562                 return CryptoMethodPlaintext
563         }
564         return CryptoMethodRC4
565 }
566
567 type CryptoSelector func(CryptoMethod) CryptoMethod