]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse.go
Merge branch 'fs-handle-reader'
[btrtrc.git] / mse / mse.go
1 // https://wiki.vuze.com/w/Message_Stream_Encryption
2
3 package mse
4
5 import (
6         "bytes"
7         "crypto/rand"
8         "crypto/rc4"
9         "crypto/sha1"
10         "encoding/binary"
11         "errors"
12         "expvar"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "math/big"
17         "strconv"
18         "sync"
19
20         "github.com/bradfitz/iter"
21 )
22
23 const (
24         maxPadLen = 512
25
26         cryptoMethodPlaintext = 1
27         cryptoMethodRC4       = 2
28 )
29
30 var (
31         // Prime P according to the spec, and G, the generator.
32         p, g big.Int
33         // The rand.Int max arg for use in newPadLen()
34         newPadLenMax big.Int
35         // For use in initer's hashes
36         req1 = []byte("req1")
37         req2 = []byte("req2")
38         req3 = []byte("req3")
39         // Verification constant "VC" which is all zeroes in the bittorrent
40         // implementation.
41         vc [8]byte
42         // Zero padding
43         zeroPad [512]byte
44         // Tracks counts of received crypto_provides
45         cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
46 )
47
48 func init() {
49         p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
50         g.SetInt64(2)
51         newPadLenMax.SetInt64(maxPadLen + 1)
52 }
53
54 func hash(parts ...[]byte) []byte {
55         h := sha1.New()
56         for _, p := range parts {
57                 n, err := h.Write(p)
58                 if err != nil {
59                         panic(err)
60                 }
61                 if n != len(p) {
62                         panic(n)
63                 }
64         }
65         return h.Sum(nil)
66 }
67
68 func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
69         c, err := rc4.NewCipher(hash([]byte(func() string {
70                 if initer {
71                         return "keyA"
72                 } else {
73                         return "keyB"
74                 }
75         }()), s, skey))
76         if err != nil {
77                 panic(err)
78         }
79         var burnSrc, burnDst [1024]byte
80         c.XORKeyStream(burnDst[:], burnSrc[:])
81         return
82 }
83
84 type cipherReader struct {
85         c  *rc4.Cipher
86         r  io.Reader
87         mu sync.Mutex
88         be []byte
89 }
90
91 func (cr *cipherReader) Read(b []byte) (n int, err error) {
92         var be []byte
93         cr.mu.Lock()
94         if len(cr.be) >= len(b) {
95                 be = cr.be
96                 cr.be = nil
97                 cr.mu.Unlock()
98         } else {
99                 cr.mu.Unlock()
100                 be = make([]byte, len(b))
101         }
102         n, err = cr.r.Read(be[:len(b)])
103         cr.c.XORKeyStream(b[:n], be[:n])
104         cr.mu.Lock()
105         if len(be) > len(cr.be) {
106                 cr.be = be
107         }
108         cr.mu.Unlock()
109         return
110 }
111
112 func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
113         return &cipherReader{c: c, r: r}
114 }
115
116 type cipherWriter struct {
117         c *rc4.Cipher
118         w io.Writer
119 }
120
121 func (cr *cipherWriter) Write(b []byte) (n int, err error) {
122         be := make([]byte, len(b))
123         cr.c.XORKeyStream(be, b)
124         n, err = cr.w.Write(be)
125         if n != len(be) {
126                 // The cipher will have advanced beyond the callers stream position.
127                 // We can't use the cipher anymore.
128                 cr.c = nil
129         }
130         return
131 }
132
133 func readY(r io.Reader) (y big.Int, err error) {
134         var b [96]byte
135         _, err = io.ReadFull(r, b[:])
136         if err != nil {
137                 return
138         }
139         y.SetBytes(b[:])
140         return
141 }
142
143 func newX() big.Int {
144         var X big.Int
145         X.SetBytes(func() []byte {
146                 var b [20]byte
147                 _, err := rand.Read(b[:])
148                 if err != nil {
149                         panic(err)
150                 }
151                 return b[:]
152         }())
153         return X
154 }
155
156 func paddedLeft(b []byte, _len int) []byte {
157         if len(b) == _len {
158                 return b
159         }
160         ret := make([]byte, _len)
161         if n := copy(ret[_len-len(b):], b); n != len(b) {
162                 panic(n)
163         }
164         return ret
165 }
166
167 // Calculate, and send Y, our public key.
168 func (h *handshake) postY(x *big.Int) error {
169         var y big.Int
170         y.Exp(&g, x, &p)
171         return h.postWrite(paddedLeft(y.Bytes(), 96))
172 }
173
174 func (h *handshake) establishS() (err error) {
175         x := newX()
176         h.postY(&x)
177         var b [96]byte
178         _, err = io.ReadFull(h.conn, b[:])
179         if err != nil {
180                 return
181         }
182         var Y, S big.Int
183         Y.SetBytes(b[:])
184         S.Exp(&Y, &x, &p)
185         sBytes := S.Bytes()
186         copy(h.s[96-len(sBytes):96], sBytes)
187         return
188 }
189
190 func newPadLen() int64 {
191         i, err := rand.Int(rand.Reader, &newPadLenMax)
192         if err != nil {
193                 panic(err)
194         }
195         ret := i.Int64()
196         if ret < 0 || ret > maxPadLen {
197                 panic(ret)
198         }
199         return ret
200 }
201
202 // Manages state for both initiating and receiving handshakes.
203 type handshake struct {
204         conn   io.ReadWriter
205         s      [96]byte
206         initer bool          // Whether we're initiating or receiving.
207         skeys  SecretKeyIter // Skeys we'll accept if receiving.
208         skey   []byte        // Skey we're initiating with.
209         ia     []byte        // Initial payload. Only used by the initiator.
210
211         writeMu    sync.Mutex
212         writes     [][]byte
213         writeErr   error
214         writeCond  sync.Cond
215         writeClose bool
216
217         writerMu   sync.Mutex
218         writerCond sync.Cond
219         writerDone bool
220 }
221
222 func (h *handshake) finishWriting() {
223         h.writeMu.Lock()
224         h.writeClose = true
225         h.writeCond.Broadcast()
226         h.writeMu.Unlock()
227
228         h.writerMu.Lock()
229         for !h.writerDone {
230                 h.writerCond.Wait()
231         }
232         h.writerMu.Unlock()
233         return
234 }
235
236 func (h *handshake) writer() {
237         defer func() {
238                 h.writerMu.Lock()
239                 h.writerDone = true
240                 h.writerCond.Broadcast()
241                 h.writerMu.Unlock()
242         }()
243         for {
244                 h.writeMu.Lock()
245                 for {
246                         if len(h.writes) != 0 {
247                                 break
248                         }
249                         if h.writeClose {
250                                 h.writeMu.Unlock()
251                                 return
252                         }
253                         h.writeCond.Wait()
254                 }
255                 b := h.writes[0]
256                 h.writes = h.writes[1:]
257                 h.writeMu.Unlock()
258                 _, err := h.conn.Write(b)
259                 if err != nil {
260                         h.writeMu.Lock()
261                         h.writeErr = err
262                         h.writeMu.Unlock()
263                         return
264                 }
265         }
266 }
267
268 func (h *handshake) postWrite(b []byte) error {
269         h.writeMu.Lock()
270         defer h.writeMu.Unlock()
271         if h.writeErr != nil {
272                 return h.writeErr
273         }
274         h.writes = append(h.writes, b)
275         h.writeCond.Signal()
276         return nil
277 }
278
279 func xor(dst, src []byte) (ret []byte) {
280         max := len(dst)
281         if max > len(src) {
282                 max = len(src)
283         }
284         ret = make([]byte, 0, max)
285         for i := range iter.N(max) {
286                 ret = append(ret, dst[i]^src[i])
287         }
288         return
289 }
290
291 func marshal(w io.Writer, data ...interface{}) (err error) {
292         for _, data := range data {
293                 err = binary.Write(w, binary.BigEndian, data)
294                 if err != nil {
295                         break
296                 }
297         }
298         return
299 }
300
301 func unmarshal(r io.Reader, data ...interface{}) (err error) {
302         for _, data := range data {
303                 err = binary.Read(r, binary.BigEndian, data)
304                 if err != nil {
305                         break
306                 }
307         }
308         return
309 }
310
311 // Looking for b at the end of a.
312 func suffixMatchLen(a, b []byte) int {
313         if len(b) > len(a) {
314                 b = b[:len(a)]
315         }
316         // i is how much of b to try to match
317         for i := len(b); i > 0; i-- {
318                 // j is how many chars we've compared
319                 j := 0
320                 for ; j < i; j++ {
321                         if b[i-1-j] != a[len(a)-1-j] {
322                                 goto shorter
323                         }
324                 }
325                 return j
326         shorter:
327         }
328         return 0
329 }
330
331 // Reads from r until b has been seen. Keeps the minimum amount of data in
332 // memory.
333 func readUntil(r io.Reader, b []byte) error {
334         b1 := make([]byte, len(b))
335         i := 0
336         for {
337                 _, err := io.ReadFull(r, b1[i:])
338                 if err != nil {
339                         return err
340                 }
341                 i = suffixMatchLen(b1, b)
342                 if i == len(b) {
343                         break
344                 }
345                 if copy(b1, b1[len(b1)-i:]) != i {
346                         panic("wat")
347                 }
348         }
349         return nil
350 }
351
352 type readWriter struct {
353         io.Reader
354         io.Writer
355 }
356
357 func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
358         return newEncrypt(initer, h.s[:], h.skey)
359 }
360
361 func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
362         h.postWrite(hash(req1, h.s[:]))
363         h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
364         buf := &bytes.Buffer{}
365         padLen := uint16(newPadLen())
366         err = marshal(buf, vc[:], uint32(cryptoMethodRC4), padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
367         if err != nil {
368                 return
369         }
370         e := h.newEncrypt(true)
371         be := make([]byte, buf.Len())
372         e.XORKeyStream(be, buf.Bytes())
373         h.postWrite(be)
374         bC := h.newEncrypt(false)
375         var eVC [8]byte
376         bC.XORKeyStream(eVC[:], vc[:])
377         // Read until the all zero VC. At this point we've only read the 96 byte
378         // public key, Y. There is potentially 512 byte padding, between us and
379         // the 8 byte verification constant.
380         err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
381         if err != nil {
382                 if err == io.EOF {
383                         err = errors.New("failed to synchronize on VC")
384                 } else {
385                         err = fmt.Errorf("error reading until VC: %s", err)
386                 }
387                 return
388         }
389         r := newCipherReader(bC, h.conn)
390         var method uint32
391         err = unmarshal(r, &method, &padLen)
392         if err != nil {
393                 return
394         }
395         if method != cryptoMethodRC4 {
396                 err = fmt.Errorf("receiver chose unsupported method: %x", method)
397                 return
398         }
399         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
400         if err != nil {
401                 return
402         }
403         ret = readWriter{r, &cipherWriter{e, h.conn}}
404         return
405 }
406
407 var ErrNoSecretKeyMatch = errors.New("no skey matched")
408
409 func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
410         // There is up to 512 bytes of padding, then the 20 byte hash.
411         err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
412         if err != nil {
413                 if err == io.EOF {
414                         err = errors.New("failed to synchronize on S hash")
415                 }
416                 return
417         }
418         var b [20]byte
419         _, err = io.ReadFull(h.conn, b[:])
420         if err != nil {
421                 return
422         }
423         err = ErrNoSecretKeyMatch
424         h.skeys(func(skey []byte) bool {
425                 if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) {
426                         h.skey = skey
427                         err = nil
428                         return false
429                 }
430                 return true
431         })
432         if err != nil {
433                 return
434         }
435         r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
436         var (
437                 vc     [8]byte
438                 method uint32
439                 padLen uint16
440         )
441
442         err = unmarshal(r, vc[:], &method, &padLen)
443         if err != nil {
444                 return
445         }
446         cryptoProvidesCount.Add(strconv.FormatUint(uint64(method), 16), 1)
447         if method&cryptoMethodRC4 == 0 {
448                 err = errors.New("no supported crypto methods were provided")
449                 return
450         }
451         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
452         if err != nil {
453                 return
454         }
455         var lenIA uint16
456         unmarshal(r, &lenIA)
457         if lenIA != 0 {
458                 h.ia = make([]byte, lenIA)
459                 unmarshal(r, h.ia)
460         }
461         buf := &bytes.Buffer{}
462         w := cipherWriter{h.newEncrypt(false), buf}
463         padLen = uint16(newPadLen())
464         err = marshal(&w, &vc, uint32(cryptoMethodRC4), padLen, zeroPad[:padLen])
465         if err != nil {
466                 return
467         }
468         err = h.postWrite(buf.Bytes())
469         if err != nil {
470                 return
471         }
472         ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn}}
473         return
474 }
475
476 func (h *handshake) Do() (ret io.ReadWriter, err error) {
477         h.writeCond.L = &h.writeMu
478         h.writerCond.L = &h.writerMu
479         go h.writer()
480         defer func() {
481                 h.finishWriting()
482                 if err == nil {
483                         err = h.writeErr
484                 }
485         }()
486         err = h.establishS()
487         if err != nil {
488                 err = fmt.Errorf("error while establishing secret: %s", err)
489                 return
490         }
491         pad := make([]byte, newPadLen())
492         io.ReadFull(rand.Reader, pad)
493         err = h.postWrite(pad)
494         if err != nil {
495                 return
496         }
497         if h.initer {
498                 ret, err = h.initerSteps()
499         } else {
500                 ret, err = h.receiverSteps()
501         }
502         return
503 }
504
505 func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
506         h := handshake{
507                 conn:   rw,
508                 initer: true,
509                 skey:   skey,
510                 ia:     initialPayload,
511         }
512         return h.Do()
513 }
514
515 func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, err error) {
516         h := handshake{
517                 conn:   rw,
518                 initer: false,
519                 skeys:  sliceIter(skeys),
520         }
521         return h.Do()
522 }
523
524 func sliceIter(skeys [][]byte) SecretKeyIter {
525         return func(callback func([]byte) bool) {
526                 for _, sk := range skeys {
527                         if !callback(sk) {
528                                 break
529                         }
530                 }
531         }
532 }
533
534 // A function that given a function, calls it with secret keys until it
535 // returns false or exhausted.
536 type SecretKeyIter func(callback func(skey []byte) (more bool))
537
538 // Doesn't unpack the secret keys until it needs to, and through the passed
539 // function.
540 func ReceiveHandshakeLazy(rw io.ReadWriter, skeys SecretKeyIter) (ret io.ReadWriter, err error) {
541         h := handshake{
542                 conn:   rw,
543                 initer: false,
544                 skeys:  skeys,
545         }
546         return h.Do()
547 }