]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse.go
mse: Move sliceIter into test file
[btrtrc.git] / mse / mse.go
1 // https://wiki.vuze.com/w/Message_Stream_Encryption
2
3 package mse
4
5 import (
6         "bytes"
7         "crypto/rand"
8         "crypto/rc4"
9         "crypto/sha1"
10         "encoding/binary"
11         "errors"
12         "expvar"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "math"
17         "math/big"
18         "strconv"
19         "sync"
20
21         "github.com/bradfitz/iter"
22 )
23
24 const (
25         maxPadLen = 512
26
27         CryptoMethodPlaintext = 1
28         CryptoMethodRC4       = 2
29         AllSupportedCrypto    = CryptoMethodPlaintext | CryptoMethodRC4
30 )
31
32 var (
33         // Prime P according to the spec, and G, the generator.
34         p, g big.Int
35         // The rand.Int max arg for use in newPadLen()
36         newPadLenMax big.Int
37         // For use in initer's hashes
38         req1 = []byte("req1")
39         req2 = []byte("req2")
40         req3 = []byte("req3")
41         // Verification constant "VC" which is all zeroes in the bittorrent
42         // implementation.
43         vc [8]byte
44         // Zero padding
45         zeroPad [512]byte
46         // Tracks counts of received crypto_provides
47         cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
48 )
49
50 func init() {
51         p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
52         g.SetInt64(2)
53         newPadLenMax.SetInt64(maxPadLen + 1)
54 }
55
56 func hash(parts ...[]byte) []byte {
57         h := sha1.New()
58         for _, p := range parts {
59                 n, err := h.Write(p)
60                 if err != nil {
61                         panic(err)
62                 }
63                 if n != len(p) {
64                         panic(n)
65                 }
66         }
67         return h.Sum(nil)
68 }
69
70 func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
71         c, err := rc4.NewCipher(hash([]byte(func() string {
72                 if initer {
73                         return "keyA"
74                 } else {
75                         return "keyB"
76                 }
77         }()), s, skey))
78         if err != nil {
79                 panic(err)
80         }
81         var burnSrc, burnDst [1024]byte
82         c.XORKeyStream(burnDst[:], burnSrc[:])
83         return
84 }
85
86 type cipherReader struct {
87         c  *rc4.Cipher
88         r  io.Reader
89         mu sync.Mutex
90         be []byte
91 }
92
93 func (cr *cipherReader) Read(b []byte) (n int, err error) {
94         var be []byte
95         cr.mu.Lock()
96         if len(cr.be) >= len(b) {
97                 be = cr.be
98                 cr.be = nil
99                 cr.mu.Unlock()
100         } else {
101                 cr.mu.Unlock()
102                 be = make([]byte, len(b))
103         }
104         n, err = cr.r.Read(be[:len(b)])
105         cr.c.XORKeyStream(b[:n], be[:n])
106         cr.mu.Lock()
107         if len(be) > len(cr.be) {
108                 cr.be = be
109         }
110         cr.mu.Unlock()
111         return
112 }
113
114 func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
115         return &cipherReader{c: c, r: r}
116 }
117
118 type cipherWriter struct {
119         c *rc4.Cipher
120         w io.Writer
121         b []byte
122 }
123
124 func (cr *cipherWriter) Write(b []byte) (n int, err error) {
125         be := func() []byte {
126                 if len(cr.b) < len(b) {
127                         return make([]byte, len(b))
128                 } else {
129                         ret := cr.b
130                         cr.b = nil
131                         return ret
132                 }
133         }()
134         cr.c.XORKeyStream(be[:], b)
135         n, err = cr.w.Write(be[:len(b)])
136         if n != len(b) {
137                 // The cipher will have advanced beyond the callers stream position.
138                 // We can't use the cipher anymore.
139                 cr.c = nil
140         }
141         if len(be) > len(cr.b) {
142                 cr.b = be
143         }
144         return
145 }
146
147 func newX() big.Int {
148         var X big.Int
149         X.SetBytes(func() []byte {
150                 var b [20]byte
151                 _, err := rand.Read(b[:])
152                 if err != nil {
153                         panic(err)
154                 }
155                 return b[:]
156         }())
157         return X
158 }
159
160 func paddedLeft(b []byte, _len int) []byte {
161         if len(b) == _len {
162                 return b
163         }
164         ret := make([]byte, _len)
165         if n := copy(ret[_len-len(b):], b); n != len(b) {
166                 panic(n)
167         }
168         return ret
169 }
170
171 // Calculate, and send Y, our public key.
172 func (h *handshake) postY(x *big.Int) error {
173         var y big.Int
174         y.Exp(&g, x, &p)
175         return h.postWrite(paddedLeft(y.Bytes(), 96))
176 }
177
178 func (h *handshake) establishS() (err error) {
179         x := newX()
180         h.postY(&x)
181         var b [96]byte
182         _, err = io.ReadFull(h.conn, b[:])
183         if err != nil {
184                 return
185         }
186         var Y, S big.Int
187         Y.SetBytes(b[:])
188         S.Exp(&Y, &x, &p)
189         sBytes := S.Bytes()
190         copy(h.s[96-len(sBytes):96], sBytes)
191         return
192 }
193
194 func newPadLen() int64 {
195         i, err := rand.Int(rand.Reader, &newPadLenMax)
196         if err != nil {
197                 panic(err)
198         }
199         ret := i.Int64()
200         if ret < 0 || ret > maxPadLen {
201                 panic(ret)
202         }
203         return ret
204 }
205
206 // Manages state for both initiating and receiving handshakes.
207 type handshake struct {
208         conn   io.ReadWriter
209         s      [96]byte
210         initer bool          // Whether we're initiating or receiving.
211         skeys  SecretKeyIter // Skeys we'll accept if receiving.
212         skey   []byte        // Skey we're initiating with.
213         ia     []byte        // Initial payload. Only used by the initiator.
214         // Return the bit for the crypto method the receiver wants to use.
215         chooseMethod func(supported uint32) uint32
216         // Sent to the receiver.
217         cryptoProvides uint32
218
219         writeMu    sync.Mutex
220         writes     [][]byte
221         writeErr   error
222         writeCond  sync.Cond
223         writeClose bool
224
225         writerMu   sync.Mutex
226         writerCond sync.Cond
227         writerDone bool
228 }
229
230 func (h *handshake) finishWriting() {
231         h.writeMu.Lock()
232         h.writeClose = true
233         h.writeCond.Broadcast()
234         h.writeMu.Unlock()
235
236         h.writerMu.Lock()
237         for !h.writerDone {
238                 h.writerCond.Wait()
239         }
240         h.writerMu.Unlock()
241         return
242 }
243
244 func (h *handshake) writer() {
245         defer func() {
246                 h.writerMu.Lock()
247                 h.writerDone = true
248                 h.writerCond.Broadcast()
249                 h.writerMu.Unlock()
250         }()
251         for {
252                 h.writeMu.Lock()
253                 for {
254                         if len(h.writes) != 0 {
255                                 break
256                         }
257                         if h.writeClose {
258                                 h.writeMu.Unlock()
259                                 return
260                         }
261                         h.writeCond.Wait()
262                 }
263                 b := h.writes[0]
264                 h.writes = h.writes[1:]
265                 h.writeMu.Unlock()
266                 _, err := h.conn.Write(b)
267                 if err != nil {
268                         h.writeMu.Lock()
269                         h.writeErr = err
270                         h.writeMu.Unlock()
271                         return
272                 }
273         }
274 }
275
276 func (h *handshake) postWrite(b []byte) error {
277         h.writeMu.Lock()
278         defer h.writeMu.Unlock()
279         if h.writeErr != nil {
280                 return h.writeErr
281         }
282         h.writes = append(h.writes, b)
283         h.writeCond.Signal()
284         return nil
285 }
286
287 func xor(dst, src []byte) (ret []byte) {
288         max := len(dst)
289         if max > len(src) {
290                 max = len(src)
291         }
292         ret = make([]byte, 0, max)
293         for i := range iter.N(max) {
294                 ret = append(ret, dst[i]^src[i])
295         }
296         return
297 }
298
299 func marshal(w io.Writer, data ...interface{}) (err error) {
300         for _, data := range data {
301                 err = binary.Write(w, binary.BigEndian, data)
302                 if err != nil {
303                         break
304                 }
305         }
306         return
307 }
308
309 func unmarshal(r io.Reader, data ...interface{}) (err error) {
310         for _, data := range data {
311                 err = binary.Read(r, binary.BigEndian, data)
312                 if err != nil {
313                         break
314                 }
315         }
316         return
317 }
318
319 // Looking for b at the end of a.
320 func suffixMatchLen(a, b []byte) int {
321         if len(b) > len(a) {
322                 b = b[:len(a)]
323         }
324         // i is how much of b to try to match
325         for i := len(b); i > 0; i-- {
326                 // j is how many chars we've compared
327                 j := 0
328                 for ; j < i; j++ {
329                         if b[i-1-j] != a[len(a)-1-j] {
330                                 goto shorter
331                         }
332                 }
333                 return j
334         shorter:
335         }
336         return 0
337 }
338
339 // Reads from r until b has been seen. Keeps the minimum amount of data in
340 // memory.
341 func readUntil(r io.Reader, b []byte) error {
342         b1 := make([]byte, len(b))
343         i := 0
344         for {
345                 _, err := io.ReadFull(r, b1[i:])
346                 if err != nil {
347                         return err
348                 }
349                 i = suffixMatchLen(b1, b)
350                 if i == len(b) {
351                         break
352                 }
353                 if copy(b1, b1[len(b1)-i:]) != i {
354                         panic("wat")
355                 }
356         }
357         return nil
358 }
359
360 type readWriter struct {
361         io.Reader
362         io.Writer
363 }
364
365 func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
366         return newEncrypt(initer, h.s[:], h.skey)
367 }
368
369 func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
370         h.postWrite(hash(req1, h.s[:]))
371         h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
372         buf := &bytes.Buffer{}
373         padLen := uint16(newPadLen())
374         if len(h.ia) > math.MaxUint16 {
375                 err = errors.New("initial payload too large")
376                 return
377         }
378         err = marshal(buf, vc[:], h.cryptoProvides, padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
379         if err != nil {
380                 return
381         }
382         e := h.newEncrypt(true)
383         be := make([]byte, buf.Len())
384         e.XORKeyStream(be, buf.Bytes())
385         h.postWrite(be)
386         bC := h.newEncrypt(false)
387         var eVC [8]byte
388         bC.XORKeyStream(eVC[:], vc[:])
389         // Read until the all zero VC. At this point we've only read the 96 byte
390         // public key, Y. There is potentially 512 byte padding, between us and
391         // the 8 byte verification constant.
392         err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
393         if err != nil {
394                 if err == io.EOF {
395                         err = errors.New("failed to synchronize on VC")
396                 } else {
397                         err = fmt.Errorf("error reading until VC: %s", err)
398                 }
399                 return
400         }
401         r := newCipherReader(bC, h.conn)
402         var method uint32
403         err = unmarshal(r, &method, &padLen)
404         if err != nil {
405                 return
406         }
407         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
408         if err != nil {
409                 return
410         }
411         switch method & h.cryptoProvides {
412         case CryptoMethodRC4:
413                 ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
414         case CryptoMethodPlaintext:
415                 ret = h.conn
416         default:
417                 err = fmt.Errorf("receiver chose unsupported method: %x", method)
418         }
419         return
420 }
421
422 var ErrNoSecretKeyMatch = errors.New("no skey matched")
423
424 func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
425         // There is up to 512 bytes of padding, then the 20 byte hash.
426         err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
427         if err != nil {
428                 if err == io.EOF {
429                         err = errors.New("failed to synchronize on S hash")
430                 }
431                 return
432         }
433         var b [20]byte
434         _, err = io.ReadFull(h.conn, b[:])
435         if err != nil {
436                 return
437         }
438         err = ErrNoSecretKeyMatch
439         h.skeys(func(skey []byte) bool {
440                 if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) {
441                         h.skey = skey
442                         err = nil
443                         return false
444                 }
445                 return true
446         })
447         if err != nil {
448                 return
449         }
450         r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
451         var (
452                 vc       [8]byte
453                 provides uint32
454                 padLen   uint16
455         )
456
457         err = unmarshal(r, vc[:], &provides, &padLen)
458         if err != nil {
459                 return
460         }
461         cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
462         chosen := h.chooseMethod(provides)
463         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
464         if err != nil {
465                 return
466         }
467         var lenIA uint16
468         unmarshal(r, &lenIA)
469         if lenIA != 0 {
470                 h.ia = make([]byte, lenIA)
471                 unmarshal(r, h.ia)
472         }
473         buf := &bytes.Buffer{}
474         w := cipherWriter{h.newEncrypt(false), buf, nil}
475         padLen = uint16(newPadLen())
476         err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen])
477         if err != nil {
478                 return
479         }
480         err = h.postWrite(buf.Bytes())
481         if err != nil {
482                 return
483         }
484         switch chosen {
485         case CryptoMethodRC4:
486                 ret = readWriter{
487                         io.MultiReader(bytes.NewReader(h.ia), r),
488                         &cipherWriter{w.c, h.conn, nil},
489                 }
490         case CryptoMethodPlaintext:
491                 ret = readWriter{
492                         io.MultiReader(bytes.NewReader(h.ia), h.conn),
493                         h.conn,
494                 }
495         default:
496                 err = errors.New("chosen crypto method is not supported")
497         }
498         return
499 }
500
501 func (h *handshake) Do() (ret io.ReadWriter, err error) {
502         h.writeCond.L = &h.writeMu
503         h.writerCond.L = &h.writerMu
504         go h.writer()
505         defer func() {
506                 h.finishWriting()
507                 if err == nil {
508                         err = h.writeErr
509                 }
510         }()
511         err = h.establishS()
512         if err != nil {
513                 err = fmt.Errorf("error while establishing secret: %s", err)
514                 return
515         }
516         pad := make([]byte, newPadLen())
517         io.ReadFull(rand.Reader, pad)
518         err = h.postWrite(pad)
519         if err != nil {
520                 return
521         }
522         if h.initer {
523                 ret, err = h.initerSteps()
524         } else {
525                 ret, err = h.receiverSteps()
526         }
527         return
528 }
529
530 func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides uint32) (ret io.ReadWriter, err error) {
531         h := handshake{
532                 conn:           rw,
533                 initer:         true,
534                 skey:           skey,
535                 ia:             initialPayload,
536                 cryptoProvides: cryptoProvides,
537         }
538         return h.Do()
539 }
540
541 func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto func(uint32) uint32) (ret io.ReadWriter, err error) {
542         h := handshake{
543                 conn:         rw,
544                 initer:       false,
545                 skeys:        skeys,
546                 chooseMethod: selectCrypto,
547         }
548         return h.Do()
549 }
550
551 // A function that given a function, calls it with secret keys until it
552 // returns false or exhausted.
553 type SecretKeyIter func(callback func(skey []byte) (more bool))
554
555 func DefaultCryptoSelector(provided uint32) uint32 {
556         if provided&CryptoMethodPlaintext != 0 {
557                 return CryptoMethodPlaintext
558         }
559         return CryptoMethodRC4
560 }
561
562 type CryptoSelector func(uint32) uint32