]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse.go
mse: Reuse writer write buffer
[btrtrc.git] / mse / mse.go
1 // https://wiki.vuze.com/w/Message_Stream_Encryption
2
3 package mse
4
5 import (
6         "bytes"
7         "crypto/rand"
8         "crypto/rc4"
9         "crypto/sha1"
10         "encoding/binary"
11         "errors"
12         "expvar"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "math/big"
17         "strconv"
18         "sync"
19
20         "github.com/bradfitz/iter"
21 )
22
23 const (
24         maxPadLen = 512
25
26         cryptoMethodPlaintext = 1
27         cryptoMethodRC4       = 2
28 )
29
30 var (
31         // Prime P according to the spec, and G, the generator.
32         p, g big.Int
33         // The rand.Int max arg for use in newPadLen()
34         newPadLenMax big.Int
35         // For use in initer's hashes
36         req1 = []byte("req1")
37         req2 = []byte("req2")
38         req3 = []byte("req3")
39         // Verification constant "VC" which is all zeroes in the bittorrent
40         // implementation.
41         vc [8]byte
42         // Zero padding
43         zeroPad [512]byte
44         // Tracks counts of received crypto_provides
45         cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
46 )
47
48 func init() {
49         p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
50         g.SetInt64(2)
51         newPadLenMax.SetInt64(maxPadLen + 1)
52 }
53
54 func hash(parts ...[]byte) []byte {
55         h := sha1.New()
56         for _, p := range parts {
57                 n, err := h.Write(p)
58                 if err != nil {
59                         panic(err)
60                 }
61                 if n != len(p) {
62                         panic(n)
63                 }
64         }
65         return h.Sum(nil)
66 }
67
68 func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
69         c, err := rc4.NewCipher(hash([]byte(func() string {
70                 if initer {
71                         return "keyA"
72                 } else {
73                         return "keyB"
74                 }
75         }()), s, skey))
76         if err != nil {
77                 panic(err)
78         }
79         var burnSrc, burnDst [1024]byte
80         c.XORKeyStream(burnDst[:], burnSrc[:])
81         return
82 }
83
84 type cipherReader struct {
85         c  *rc4.Cipher
86         r  io.Reader
87         mu sync.Mutex
88         be []byte
89 }
90
91 func (cr *cipherReader) Read(b []byte) (n int, err error) {
92         var be []byte
93         cr.mu.Lock()
94         if len(cr.be) >= len(b) {
95                 be = cr.be
96                 cr.be = nil
97                 cr.mu.Unlock()
98         } else {
99                 cr.mu.Unlock()
100                 be = make([]byte, len(b))
101         }
102         n, err = cr.r.Read(be[:len(b)])
103         cr.c.XORKeyStream(b[:n], be[:n])
104         cr.mu.Lock()
105         if len(be) > len(cr.be) {
106                 cr.be = be
107         }
108         cr.mu.Unlock()
109         return
110 }
111
112 func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
113         return &cipherReader{c: c, r: r}
114 }
115
116 type cipherWriter struct {
117         c *rc4.Cipher
118         w io.Writer
119         b []byte
120 }
121
122 func (cr *cipherWriter) Write(b []byte) (n int, err error) {
123         be := func() []byte {
124                 if len(cr.b) < len(b) {
125                         return make([]byte, len(b))
126                 } else {
127                         ret := cr.b
128                         cr.b = nil
129                         return ret
130                 }
131         }()
132         cr.c.XORKeyStream(be[:], b)
133         n, err = cr.w.Write(be[:len(b)])
134         if n != len(b) {
135                 // The cipher will have advanced beyond the callers stream position.
136                 // We can't use the cipher anymore.
137                 cr.c = nil
138         }
139         if len(be) > len(cr.b) {
140                 cr.b = be
141         }
142         return
143 }
144
145 func newX() big.Int {
146         var X big.Int
147         X.SetBytes(func() []byte {
148                 var b [20]byte
149                 _, err := rand.Read(b[:])
150                 if err != nil {
151                         panic(err)
152                 }
153                 return b[:]
154         }())
155         return X
156 }
157
158 func paddedLeft(b []byte, _len int) []byte {
159         if len(b) == _len {
160                 return b
161         }
162         ret := make([]byte, _len)
163         if n := copy(ret[_len-len(b):], b); n != len(b) {
164                 panic(n)
165         }
166         return ret
167 }
168
169 // Calculate, and send Y, our public key.
170 func (h *handshake) postY(x *big.Int) error {
171         var y big.Int
172         y.Exp(&g, x, &p)
173         return h.postWrite(paddedLeft(y.Bytes(), 96))
174 }
175
176 func (h *handshake) establishS() (err error) {
177         x := newX()
178         h.postY(&x)
179         var b [96]byte
180         _, err = io.ReadFull(h.conn, b[:])
181         if err != nil {
182                 return
183         }
184         var Y, S big.Int
185         Y.SetBytes(b[:])
186         S.Exp(&Y, &x, &p)
187         sBytes := S.Bytes()
188         copy(h.s[96-len(sBytes):96], sBytes)
189         return
190 }
191
192 func newPadLen() int64 {
193         i, err := rand.Int(rand.Reader, &newPadLenMax)
194         if err != nil {
195                 panic(err)
196         }
197         ret := i.Int64()
198         if ret < 0 || ret > maxPadLen {
199                 panic(ret)
200         }
201         return ret
202 }
203
204 // Manages state for both initiating and receiving handshakes.
205 type handshake struct {
206         conn   io.ReadWriter
207         s      [96]byte
208         initer bool          // Whether we're initiating or receiving.
209         skeys  SecretKeyIter // Skeys we'll accept if receiving.
210         skey   []byte        // Skey we're initiating with.
211         ia     []byte        // Initial payload. Only used by the initiator.
212
213         writeMu    sync.Mutex
214         writes     [][]byte
215         writeErr   error
216         writeCond  sync.Cond
217         writeClose bool
218
219         writerMu   sync.Mutex
220         writerCond sync.Cond
221         writerDone bool
222 }
223
224 func (h *handshake) finishWriting() {
225         h.writeMu.Lock()
226         h.writeClose = true
227         h.writeCond.Broadcast()
228         h.writeMu.Unlock()
229
230         h.writerMu.Lock()
231         for !h.writerDone {
232                 h.writerCond.Wait()
233         }
234         h.writerMu.Unlock()
235         return
236 }
237
238 func (h *handshake) writer() {
239         defer func() {
240                 h.writerMu.Lock()
241                 h.writerDone = true
242                 h.writerCond.Broadcast()
243                 h.writerMu.Unlock()
244         }()
245         for {
246                 h.writeMu.Lock()
247                 for {
248                         if len(h.writes) != 0 {
249                                 break
250                         }
251                         if h.writeClose {
252                                 h.writeMu.Unlock()
253                                 return
254                         }
255                         h.writeCond.Wait()
256                 }
257                 b := h.writes[0]
258                 h.writes = h.writes[1:]
259                 h.writeMu.Unlock()
260                 _, err := h.conn.Write(b)
261                 if err != nil {
262                         h.writeMu.Lock()
263                         h.writeErr = err
264                         h.writeMu.Unlock()
265                         return
266                 }
267         }
268 }
269
270 func (h *handshake) postWrite(b []byte) error {
271         h.writeMu.Lock()
272         defer h.writeMu.Unlock()
273         if h.writeErr != nil {
274                 return h.writeErr
275         }
276         h.writes = append(h.writes, b)
277         h.writeCond.Signal()
278         return nil
279 }
280
281 func xor(dst, src []byte) (ret []byte) {
282         max := len(dst)
283         if max > len(src) {
284                 max = len(src)
285         }
286         ret = make([]byte, 0, max)
287         for i := range iter.N(max) {
288                 ret = append(ret, dst[i]^src[i])
289         }
290         return
291 }
292
293 func marshal(w io.Writer, data ...interface{}) (err error) {
294         for _, data := range data {
295                 err = binary.Write(w, binary.BigEndian, data)
296                 if err != nil {
297                         break
298                 }
299         }
300         return
301 }
302
303 func unmarshal(r io.Reader, data ...interface{}) (err error) {
304         for _, data := range data {
305                 err = binary.Read(r, binary.BigEndian, data)
306                 if err != nil {
307                         break
308                 }
309         }
310         return
311 }
312
313 // Looking for b at the end of a.
314 func suffixMatchLen(a, b []byte) int {
315         if len(b) > len(a) {
316                 b = b[:len(a)]
317         }
318         // i is how much of b to try to match
319         for i := len(b); i > 0; i-- {
320                 // j is how many chars we've compared
321                 j := 0
322                 for ; j < i; j++ {
323                         if b[i-1-j] != a[len(a)-1-j] {
324                                 goto shorter
325                         }
326                 }
327                 return j
328         shorter:
329         }
330         return 0
331 }
332
333 // Reads from r until b has been seen. Keeps the minimum amount of data in
334 // memory.
335 func readUntil(r io.Reader, b []byte) error {
336         b1 := make([]byte, len(b))
337         i := 0
338         for {
339                 _, err := io.ReadFull(r, b1[i:])
340                 if err != nil {
341                         return err
342                 }
343                 i = suffixMatchLen(b1, b)
344                 if i == len(b) {
345                         break
346                 }
347                 if copy(b1, b1[len(b1)-i:]) != i {
348                         panic("wat")
349                 }
350         }
351         return nil
352 }
353
354 type readWriter struct {
355         io.Reader
356         io.Writer
357 }
358
359 func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
360         return newEncrypt(initer, h.s[:], h.skey)
361 }
362
363 func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
364         h.postWrite(hash(req1, h.s[:]))
365         h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
366         buf := &bytes.Buffer{}
367         padLen := uint16(newPadLen())
368         err = marshal(buf, vc[:], uint32(cryptoMethodRC4), padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
369         if err != nil {
370                 return
371         }
372         e := h.newEncrypt(true)
373         be := make([]byte, buf.Len())
374         e.XORKeyStream(be, buf.Bytes())
375         h.postWrite(be)
376         bC := h.newEncrypt(false)
377         var eVC [8]byte
378         bC.XORKeyStream(eVC[:], vc[:])
379         // Read until the all zero VC. At this point we've only read the 96 byte
380         // public key, Y. There is potentially 512 byte padding, between us and
381         // the 8 byte verification constant.
382         err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
383         if err != nil {
384                 if err == io.EOF {
385                         err = errors.New("failed to synchronize on VC")
386                 } else {
387                         err = fmt.Errorf("error reading until VC: %s", err)
388                 }
389                 return
390         }
391         r := newCipherReader(bC, h.conn)
392         var method uint32
393         err = unmarshal(r, &method, &padLen)
394         if err != nil {
395                 return
396         }
397         if method != cryptoMethodRC4 {
398                 err = fmt.Errorf("receiver chose unsupported method: %x", method)
399                 return
400         }
401         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
402         if err != nil {
403                 return
404         }
405         ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
406         return
407 }
408
409 var ErrNoSecretKeyMatch = errors.New("no skey matched")
410
411 func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
412         // There is up to 512 bytes of padding, then the 20 byte hash.
413         err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
414         if err != nil {
415                 if err == io.EOF {
416                         err = errors.New("failed to synchronize on S hash")
417                 }
418                 return
419         }
420         var b [20]byte
421         _, err = io.ReadFull(h.conn, b[:])
422         if err != nil {
423                 return
424         }
425         err = ErrNoSecretKeyMatch
426         h.skeys(func(skey []byte) bool {
427                 if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) {
428                         h.skey = skey
429                         err = nil
430                         return false
431                 }
432                 return true
433         })
434         if err != nil {
435                 return
436         }
437         r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
438         var (
439                 vc     [8]byte
440                 method uint32
441                 padLen uint16
442         )
443
444         err = unmarshal(r, vc[:], &method, &padLen)
445         if err != nil {
446                 return
447         }
448         cryptoProvidesCount.Add(strconv.FormatUint(uint64(method), 16), 1)
449         if method&cryptoMethodRC4 == 0 {
450                 err = errors.New("no supported crypto methods were provided")
451                 return
452         }
453         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
454         if err != nil {
455                 return
456         }
457         var lenIA uint16
458         unmarshal(r, &lenIA)
459         if lenIA != 0 {
460                 h.ia = make([]byte, lenIA)
461                 unmarshal(r, h.ia)
462         }
463         buf := &bytes.Buffer{}
464         w := cipherWriter{h.newEncrypt(false), buf, nil}
465         padLen = uint16(newPadLen())
466         err = marshal(&w, &vc, uint32(cryptoMethodRC4), padLen, zeroPad[:padLen])
467         if err != nil {
468                 return
469         }
470         err = h.postWrite(buf.Bytes())
471         if err != nil {
472                 return
473         }
474         ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn, nil}}
475         return
476 }
477
478 func (h *handshake) Do() (ret io.ReadWriter, err error) {
479         h.writeCond.L = &h.writeMu
480         h.writerCond.L = &h.writerMu
481         go h.writer()
482         defer func() {
483                 h.finishWriting()
484                 if err == nil {
485                         err = h.writeErr
486                 }
487         }()
488         err = h.establishS()
489         if err != nil {
490                 err = fmt.Errorf("error while establishing secret: %s", err)
491                 return
492         }
493         pad := make([]byte, newPadLen())
494         io.ReadFull(rand.Reader, pad)
495         err = h.postWrite(pad)
496         if err != nil {
497                 return
498         }
499         if h.initer {
500                 ret, err = h.initerSteps()
501         } else {
502                 ret, err = h.receiverSteps()
503         }
504         return
505 }
506
507 func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
508         h := handshake{
509                 conn:   rw,
510                 initer: true,
511                 skey:   skey,
512                 ia:     initialPayload,
513         }
514         return h.Do()
515 }
516
517 func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, err error) {
518         h := handshake{
519                 conn:   rw,
520                 initer: false,
521                 skeys:  sliceIter(skeys),
522         }
523         return h.Do()
524 }
525
526 func sliceIter(skeys [][]byte) SecretKeyIter {
527         return func(callback func([]byte) bool) {
528                 for _, sk := range skeys {
529                         if !callback(sk) {
530                                 break
531                         }
532                 }
533         }
534 }
535
536 // A function that given a function, calls it with secret keys until it
537 // returns false or exhausted.
538 type SecretKeyIter func(callback func(skey []byte) (more bool))
539
540 // Doesn't unpack the secret keys until it needs to, and through the passed
541 // function.
542 func ReceiveHandshakeLazy(rw io.ReadWriter, skeys SecretKeyIter) (ret io.ReadWriter, err error) {
543         h := handshake{
544                 conn:   rw,
545                 initer: false,
546                 skeys:  skeys,
547         }
548         return h.Do()
549 }