]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse.go
More megacheck
[btrtrc.git] / mse / mse.go
1 // https://wiki.vuze.com/w/Message_Stream_Encryption
2
3 package mse
4
5 import (
6         "bytes"
7         "crypto/rand"
8         "crypto/rc4"
9         "crypto/sha1"
10         "encoding/binary"
11         "errors"
12         "expvar"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "math"
17         "math/big"
18         "strconv"
19         "sync"
20
21         "github.com/bradfitz/iter"
22 )
23
24 const (
25         maxPadLen = 512
26
27         CryptoMethodPlaintext = 1
28         CryptoMethodRC4       = 2
29         AllSupportedCrypto    = CryptoMethodPlaintext | CryptoMethodRC4
30 )
31
32 var (
33         // Prime P according to the spec, and G, the generator.
34         p, g big.Int
35         // The rand.Int max arg for use in newPadLen()
36         newPadLenMax big.Int
37         // For use in initer's hashes
38         req1 = []byte("req1")
39         req2 = []byte("req2")
40         req3 = []byte("req3")
41         // Verification constant "VC" which is all zeroes in the bittorrent
42         // implementation.
43         vc [8]byte
44         // Zero padding
45         zeroPad [512]byte
46         // Tracks counts of received crypto_provides
47         cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
48 )
49
50 func init() {
51         p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
52         g.SetInt64(2)
53         newPadLenMax.SetInt64(maxPadLen + 1)
54 }
55
56 func hash(parts ...[]byte) []byte {
57         h := sha1.New()
58         for _, p := range parts {
59                 n, err := h.Write(p)
60                 if err != nil {
61                         panic(err)
62                 }
63                 if n != len(p) {
64                         panic(n)
65                 }
66         }
67         return h.Sum(nil)
68 }
69
70 func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
71         c, err := rc4.NewCipher(hash([]byte(func() string {
72                 if initer {
73                         return "keyA"
74                 } else {
75                         return "keyB"
76                 }
77         }()), s, skey))
78         if err != nil {
79                 panic(err)
80         }
81         var burnSrc, burnDst [1024]byte
82         c.XORKeyStream(burnDst[:], burnSrc[:])
83         return
84 }
85
86 type cipherReader struct {
87         c  *rc4.Cipher
88         r  io.Reader
89         mu sync.Mutex
90         be []byte
91 }
92
93 func (cr *cipherReader) Read(b []byte) (n int, err error) {
94         var be []byte
95         cr.mu.Lock()
96         if len(cr.be) >= len(b) {
97                 be = cr.be
98                 cr.be = nil
99                 cr.mu.Unlock()
100         } else {
101                 cr.mu.Unlock()
102                 be = make([]byte, len(b))
103         }
104         n, err = cr.r.Read(be[:len(b)])
105         cr.c.XORKeyStream(b[:n], be[:n])
106         cr.mu.Lock()
107         if len(be) > len(cr.be) {
108                 cr.be = be
109         }
110         cr.mu.Unlock()
111         return
112 }
113
114 func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
115         return &cipherReader{c: c, r: r}
116 }
117
118 type cipherWriter struct {
119         c *rc4.Cipher
120         w io.Writer
121         b []byte
122 }
123
124 func (cr *cipherWriter) Write(b []byte) (n int, err error) {
125         be := func() []byte {
126                 if len(cr.b) < len(b) {
127                         return make([]byte, len(b))
128                 } else {
129                         ret := cr.b
130                         cr.b = nil
131                         return ret
132                 }
133         }()
134         cr.c.XORKeyStream(be[:], b)
135         n, err = cr.w.Write(be[:len(b)])
136         if n != len(b) {
137                 // The cipher will have advanced beyond the callers stream position.
138                 // We can't use the cipher anymore.
139                 cr.c = nil
140         }
141         if len(be) > len(cr.b) {
142                 cr.b = be
143         }
144         return
145 }
146
147 func newX() big.Int {
148         var X big.Int
149         X.SetBytes(func() []byte {
150                 var b [20]byte
151                 _, err := rand.Read(b[:])
152                 if err != nil {
153                         panic(err)
154                 }
155                 return b[:]
156         }())
157         return X
158 }
159
160 func paddedLeft(b []byte, _len int) []byte {
161         if len(b) == _len {
162                 return b
163         }
164         ret := make([]byte, _len)
165         if n := copy(ret[_len-len(b):], b); n != len(b) {
166                 panic(n)
167         }
168         return ret
169 }
170
171 // Calculate, and send Y, our public key.
172 func (h *handshake) postY(x *big.Int) error {
173         var y big.Int
174         y.Exp(&g, x, &p)
175         return h.postWrite(paddedLeft(y.Bytes(), 96))
176 }
177
178 func (h *handshake) establishS() (err error) {
179         x := newX()
180         h.postY(&x)
181         var b [96]byte
182         _, err = io.ReadFull(h.conn, b[:])
183         if err != nil {
184                 return
185         }
186         var Y, S big.Int
187         Y.SetBytes(b[:])
188         S.Exp(&Y, &x, &p)
189         sBytes := S.Bytes()
190         copy(h.s[96-len(sBytes):96], sBytes)
191         return
192 }
193
194 func newPadLen() int64 {
195         i, err := rand.Int(rand.Reader, &newPadLenMax)
196         if err != nil {
197                 panic(err)
198         }
199         ret := i.Int64()
200         if ret < 0 || ret > maxPadLen {
201                 panic(ret)
202         }
203         return ret
204 }
205
206 // Manages state for both initiating and receiving handshakes.
207 type handshake struct {
208         conn   io.ReadWriter
209         s      [96]byte
210         initer bool          // Whether we're initiating or receiving.
211         skeys  SecretKeyIter // Skeys we'll accept if receiving.
212         skey   []byte        // Skey we're initiating with.
213         ia     []byte        // Initial payload. Only used by the initiator.
214         // Return the bit for the crypto method the receiver wants to use.
215         chooseMethod func(supported uint32) uint32
216         // Sent to the receiver.
217         cryptoProvides uint32
218
219         writeMu    sync.Mutex
220         writes     [][]byte
221         writeErr   error
222         writeCond  sync.Cond
223         writeClose bool
224
225         writerMu   sync.Mutex
226         writerCond sync.Cond
227         writerDone bool
228 }
229
230 func (h *handshake) finishWriting() {
231         h.writeMu.Lock()
232         h.writeClose = true
233         h.writeCond.Broadcast()
234         h.writeMu.Unlock()
235
236         h.writerMu.Lock()
237         for !h.writerDone {
238                 h.writerCond.Wait()
239         }
240         h.writerMu.Unlock()
241 }
242
243 func (h *handshake) writer() {
244         defer func() {
245                 h.writerMu.Lock()
246                 h.writerDone = true
247                 h.writerCond.Broadcast()
248                 h.writerMu.Unlock()
249         }()
250         for {
251                 h.writeMu.Lock()
252                 for {
253                         if len(h.writes) != 0 {
254                                 break
255                         }
256                         if h.writeClose {
257                                 h.writeMu.Unlock()
258                                 return
259                         }
260                         h.writeCond.Wait()
261                 }
262                 b := h.writes[0]
263                 h.writes = h.writes[1:]
264                 h.writeMu.Unlock()
265                 _, err := h.conn.Write(b)
266                 if err != nil {
267                         h.writeMu.Lock()
268                         h.writeErr = err
269                         h.writeMu.Unlock()
270                         return
271                 }
272         }
273 }
274
275 func (h *handshake) postWrite(b []byte) error {
276         h.writeMu.Lock()
277         defer h.writeMu.Unlock()
278         if h.writeErr != nil {
279                 return h.writeErr
280         }
281         h.writes = append(h.writes, b)
282         h.writeCond.Signal()
283         return nil
284 }
285
286 func xor(dst, src []byte) (ret []byte) {
287         max := len(dst)
288         if max > len(src) {
289                 max = len(src)
290         }
291         ret = make([]byte, 0, max)
292         for i := range iter.N(max) {
293                 ret = append(ret, dst[i]^src[i])
294         }
295         return
296 }
297
298 func marshal(w io.Writer, data ...interface{}) (err error) {
299         for _, data := range data {
300                 err = binary.Write(w, binary.BigEndian, data)
301                 if err != nil {
302                         break
303                 }
304         }
305         return
306 }
307
308 func unmarshal(r io.Reader, data ...interface{}) (err error) {
309         for _, data := range data {
310                 err = binary.Read(r, binary.BigEndian, data)
311                 if err != nil {
312                         break
313                 }
314         }
315         return
316 }
317
318 // Looking for b at the end of a.
319 func suffixMatchLen(a, b []byte) int {
320         if len(b) > len(a) {
321                 b = b[:len(a)]
322         }
323         // i is how much of b to try to match
324         for i := len(b); i > 0; i-- {
325                 // j is how many chars we've compared
326                 j := 0
327                 for ; j < i; j++ {
328                         if b[i-1-j] != a[len(a)-1-j] {
329                                 goto shorter
330                         }
331                 }
332                 return j
333         shorter:
334         }
335         return 0
336 }
337
338 // Reads from r until b has been seen. Keeps the minimum amount of data in
339 // memory.
340 func readUntil(r io.Reader, b []byte) error {
341         b1 := make([]byte, len(b))
342         i := 0
343         for {
344                 _, err := io.ReadFull(r, b1[i:])
345                 if err != nil {
346                         return err
347                 }
348                 i = suffixMatchLen(b1, b)
349                 if i == len(b) {
350                         break
351                 }
352                 if copy(b1, b1[len(b1)-i:]) != i {
353                         panic("wat")
354                 }
355         }
356         return nil
357 }
358
359 type readWriter struct {
360         io.Reader
361         io.Writer
362 }
363
364 func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
365         return newEncrypt(initer, h.s[:], h.skey)
366 }
367
368 func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
369         h.postWrite(hash(req1, h.s[:]))
370         h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
371         buf := &bytes.Buffer{}
372         padLen := uint16(newPadLen())
373         if len(h.ia) > math.MaxUint16 {
374                 err = errors.New("initial payload too large")
375                 return
376         }
377         err = marshal(buf, vc[:], h.cryptoProvides, padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
378         if err != nil {
379                 return
380         }
381         e := h.newEncrypt(true)
382         be := make([]byte, buf.Len())
383         e.XORKeyStream(be, buf.Bytes())
384         h.postWrite(be)
385         bC := h.newEncrypt(false)
386         var eVC [8]byte
387         bC.XORKeyStream(eVC[:], vc[:])
388         // Read until the all zero VC. At this point we've only read the 96 byte
389         // public key, Y. There is potentially 512 byte padding, between us and
390         // the 8 byte verification constant.
391         err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
392         if err != nil {
393                 if err == io.EOF {
394                         err = errors.New("failed to synchronize on VC")
395                 } else {
396                         err = fmt.Errorf("error reading until VC: %s", err)
397                 }
398                 return
399         }
400         r := newCipherReader(bC, h.conn)
401         var method uint32
402         err = unmarshal(r, &method, &padLen)
403         if err != nil {
404                 return
405         }
406         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
407         if err != nil {
408                 return
409         }
410         switch method & h.cryptoProvides {
411         case CryptoMethodRC4:
412                 ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
413         case CryptoMethodPlaintext:
414                 ret = h.conn
415         default:
416                 err = fmt.Errorf("receiver chose unsupported method: %x", method)
417         }
418         return
419 }
420
421 var ErrNoSecretKeyMatch = errors.New("no skey matched")
422
423 func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
424         // There is up to 512 bytes of padding, then the 20 byte hash.
425         err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
426         if err != nil {
427                 if err == io.EOF {
428                         err = errors.New("failed to synchronize on S hash")
429                 }
430                 return
431         }
432         var b [20]byte
433         _, err = io.ReadFull(h.conn, b[:])
434         if err != nil {
435                 return
436         }
437         err = ErrNoSecretKeyMatch
438         h.skeys(func(skey []byte) bool {
439                 if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) {
440                         h.skey = skey
441                         err = nil
442                         return false
443                 }
444                 return true
445         })
446         if err != nil {
447                 return
448         }
449         r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
450         var (
451                 vc       [8]byte
452                 provides uint32
453                 padLen   uint16
454         )
455
456         err = unmarshal(r, vc[:], &provides, &padLen)
457         if err != nil {
458                 return
459         }
460         cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
461         chosen := h.chooseMethod(provides)
462         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
463         if err != nil {
464                 return
465         }
466         var lenIA uint16
467         unmarshal(r, &lenIA)
468         if lenIA != 0 {
469                 h.ia = make([]byte, lenIA)
470                 unmarshal(r, h.ia)
471         }
472         buf := &bytes.Buffer{}
473         w := cipherWriter{h.newEncrypt(false), buf, nil}
474         padLen = uint16(newPadLen())
475         err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen])
476         if err != nil {
477                 return
478         }
479         err = h.postWrite(buf.Bytes())
480         if err != nil {
481                 return
482         }
483         switch chosen {
484         case CryptoMethodRC4:
485                 ret = readWriter{
486                         io.MultiReader(bytes.NewReader(h.ia), r),
487                         &cipherWriter{w.c, h.conn, nil},
488                 }
489         case CryptoMethodPlaintext:
490                 ret = readWriter{
491                         io.MultiReader(bytes.NewReader(h.ia), h.conn),
492                         h.conn,
493                 }
494         default:
495                 err = errors.New("chosen crypto method is not supported")
496         }
497         return
498 }
499
500 func (h *handshake) Do() (ret io.ReadWriter, err error) {
501         h.writeCond.L = &h.writeMu
502         h.writerCond.L = &h.writerMu
503         go h.writer()
504         defer func() {
505                 h.finishWriting()
506                 if err == nil {
507                         err = h.writeErr
508                 }
509         }()
510         err = h.establishS()
511         if err != nil {
512                 err = fmt.Errorf("error while establishing secret: %s", err)
513                 return
514         }
515         pad := make([]byte, newPadLen())
516         io.ReadFull(rand.Reader, pad)
517         err = h.postWrite(pad)
518         if err != nil {
519                 return
520         }
521         if h.initer {
522                 ret, err = h.initerSteps()
523         } else {
524                 ret, err = h.receiverSteps()
525         }
526         return
527 }
528
529 func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides uint32) (ret io.ReadWriter, err error) {
530         h := handshake{
531                 conn:           rw,
532                 initer:         true,
533                 skey:           skey,
534                 ia:             initialPayload,
535                 cryptoProvides: cryptoProvides,
536         }
537         return h.Do()
538 }
539
540 func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto func(uint32) uint32) (ret io.ReadWriter, err error) {
541         h := handshake{
542                 conn:         rw,
543                 initer:       false,
544                 skeys:        skeys,
545                 chooseMethod: selectCrypto,
546         }
547         return h.Do()
548 }
549
550 // A function that given a function, calls it with secret keys until it
551 // returns false or exhausted.
552 type SecretKeyIter func(callback func(skey []byte) (more bool))
553
554 func DefaultCryptoSelector(provided uint32) uint32 {
555         if provided&CryptoMethodPlaintext != 0 {
556                 return CryptoMethodPlaintext
557         }
558         return CryptoMethodRC4
559 }
560
561 type CryptoSelector func(uint32) uint32