]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse.go
Support initial payload, and improve tests
[btrtrc.git] / mse / mse.go
1 // https://wiki.vuze.com/w/Message_Stream_Encryption
2
3 package mse
4
5 import (
6         "bytes"
7         "crypto/rand"
8         "crypto/rc4"
9         "crypto/sha1"
10         "encoding/binary"
11         "errors"
12         "expvar"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "math/big"
17         "strconv"
18         "sync"
19
20         "github.com/bradfitz/iter"
21 )
22
23 const (
24         maxPadLen = 512
25
26         cryptoMethodPlaintext = 1
27         cryptoMethodRC4       = 2
28 )
29
30 var (
31         // Prime P according to the spec, and G, the generator.
32         p, g big.Int
33         // The rand.Int max arg for use in newPadLen()
34         newPadLenMax big.Int
35         // For use in initer's hashes
36         req1 = []byte("req1")
37         req2 = []byte("req2")
38         req3 = []byte("req3")
39
40         cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
41 )
42
43 func init() {
44         p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
45         g.SetInt64(2)
46         newPadLenMax.SetInt64(maxPadLen + 1)
47 }
48
49 func hash(parts ...[]byte) []byte {
50         h := sha1.New()
51         for _, p := range parts {
52                 n, err := h.Write(p)
53                 if err != nil {
54                         panic(err)
55                 }
56                 if n != len(p) {
57                         panic(n)
58                 }
59         }
60         return h.Sum(nil)
61 }
62
63 func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
64         c, err := rc4.NewCipher(hash([]byte(func() string {
65                 if initer {
66                         return "keyA"
67                 } else {
68                         return "keyB"
69                 }
70         }()), s, skey))
71         if err != nil {
72                 panic(err)
73         }
74         var burnSrc, burnDst [1024]byte
75         c.XORKeyStream(burnDst[:], burnSrc[:])
76         return
77 }
78
79 type cipherReader struct {
80         c *rc4.Cipher
81         r io.Reader
82 }
83
84 func (me *cipherReader) Read(b []byte) (n int, err error) {
85         be := make([]byte, len(b))
86         n, err = me.r.Read(be)
87         me.c.XORKeyStream(b[:n], be[:n])
88         return
89 }
90
91 func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
92         return &cipherReader{c, r}
93 }
94
95 type cipherWriter struct {
96         c *rc4.Cipher
97         w io.Writer
98 }
99
100 func (me *cipherWriter) Write(b []byte) (n int, err error) {
101         be := make([]byte, len(b))
102         me.c.XORKeyStream(be, b)
103         n, err = me.w.Write(be)
104         if n != len(be) {
105                 // The cipher will have advanced beyond the callers stream position.
106                 // We can't use the cipher anymore.
107                 me.c = nil
108         }
109         return
110 }
111
112 func readY(r io.Reader) (y big.Int, err error) {
113         var b [96]byte
114         _, err = io.ReadFull(r, b[:])
115         if err != nil {
116                 return
117         }
118         y.SetBytes(b[:])
119         return
120 }
121
122 func newX() big.Int {
123         var X big.Int
124         X.SetBytes(func() []byte {
125                 var b [20]byte
126                 _, err := rand.Read(b[:])
127                 if err != nil {
128                         panic(err)
129                 }
130                 return b[:]
131         }())
132         return X
133 }
134
135 // Calculate, and send Y, our public key.
136 func (h *handshake) postY(x *big.Int) error {
137         var y big.Int
138         y.Exp(&g, x, &p)
139         b := y.Bytes()
140         if len(b) != 96 {
141                 b1 := make([]byte, 96)
142                 if n := copy(b1[96-len(b):], b); n != len(b) {
143                         panic(n)
144                 }
145                 b = b1
146         }
147         return h.postWrite(b)
148 }
149
150 func (h *handshake) establishS() (err error) {
151         x := newX()
152         h.postY(&x)
153         var b [96]byte
154         _, err = io.ReadFull(h.conn, b[:])
155         if err != nil {
156                 return
157         }
158         var Y big.Int
159         Y.SetBytes(b[:])
160         h.s.Exp(&Y, &x, &p)
161         return
162 }
163
164 func newPadLen() int64 {
165         i, err := rand.Int(rand.Reader, &newPadLenMax)
166         if err != nil {
167                 panic(err)
168         }
169         ret := i.Int64()
170         if ret < 0 || ret > maxPadLen {
171                 panic(ret)
172         }
173         return ret
174 }
175
176 type handshake struct {
177         conn   io.ReadWriteCloser
178         s      big.Int
179         initer bool
180         skeys  [][]byte
181         skey   []byte
182         ia     []byte // Initial payload. Only used by the initiator.
183
184         writeMu    sync.Mutex
185         writes     [][]byte
186         writeErr   error
187         writeCond  sync.Cond
188         writeClose bool
189
190         writerMu   sync.Mutex
191         writerCond sync.Cond
192         writerDone bool
193 }
194
195 func (h *handshake) finishWriting() (err error) {
196         h.writeMu.Lock()
197         h.writeClose = true
198         h.writeCond.Broadcast()
199         err = h.writeErr
200         h.writeMu.Unlock()
201
202         h.writerMu.Lock()
203         for !h.writerDone {
204                 h.writerCond.Wait()
205         }
206         h.writerMu.Unlock()
207         return
208
209 }
210
211 func (h *handshake) writer() {
212         defer func() {
213                 h.writerMu.Lock()
214                 h.writerDone = true
215                 h.writerCond.Broadcast()
216                 h.writerMu.Unlock()
217         }()
218         for {
219                 h.writeMu.Lock()
220                 for {
221                         if len(h.writes) != 0 {
222                                 break
223                         }
224                         if h.writeClose {
225                                 h.writeMu.Unlock()
226                                 return
227                         }
228                         h.writeCond.Wait()
229                 }
230                 b := h.writes[0]
231                 h.writes = h.writes[1:]
232                 h.writeMu.Unlock()
233                 _, err := h.conn.Write(b)
234                 if err != nil {
235                         h.writeMu.Lock()
236                         h.writeErr = err
237                         h.writeMu.Unlock()
238                         return
239                 }
240         }
241 }
242
243 func (h *handshake) postWrite(b []byte) error {
244         h.writeMu.Lock()
245         defer h.writeMu.Unlock()
246         if h.writeErr != nil {
247                 return h.writeErr
248         }
249         h.writes = append(h.writes, b)
250         h.writeCond.Signal()
251         return nil
252 }
253
254 func xor(dst, src []byte) (ret []byte) {
255         max := len(dst)
256         if max > len(src) {
257                 max = len(src)
258         }
259         ret = make([]byte, 0, max)
260         for i := range iter.N(max) {
261                 ret = append(ret, dst[i]^src[i])
262         }
263         return
264 }
265
266 func marshal(w io.Writer, data ...interface{}) (err error) {
267         for _, data := range data {
268                 err = binary.Write(w, binary.BigEndian, data)
269                 if err != nil {
270                         break
271                 }
272         }
273         return
274 }
275
276 func unmarshal(r io.Reader, data ...interface{}) (err error) {
277         for _, data := range data {
278                 err = binary.Read(r, binary.BigEndian, data)
279                 if err != nil {
280                         break
281                 }
282         }
283         return
284 }
285
286 type cryptoNegotiation struct {
287         VC     [8]byte
288         Method uint32
289         PadLen uint16
290         IA     []byte
291 }
292
293 func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
294         err = binary.Read(r, binary.BigEndian, me.VC[:])
295         if err != nil {
296                 return
297         }
298         err = binary.Read(r, binary.BigEndian, &me.Method)
299         if err != nil {
300                 return
301         }
302         err = binary.Read(r, binary.BigEndian, &me.PadLen)
303         if err != nil {
304                 return
305         }
306         _, err = io.CopyN(ioutil.Discard, r, int64(me.PadLen))
307         return
308 }
309
310 func (me *cryptoNegotiation) MarshalWriter(w io.Writer) (err error) {
311         // _, err = w.Write(me.VC[:])
312         err = binary.Write(w, binary.BigEndian, me.VC[:])
313         if err != nil {
314                 return
315         }
316         err = binary.Write(w, binary.BigEndian, me.Method)
317         if err != nil {
318                 return
319         }
320         err = binary.Write(w, binary.BigEndian, me.PadLen)
321         if err != nil {
322                 return
323         }
324         _, err = w.Write(make([]byte, me.PadLen))
325         return
326 }
327
328 // Looking for b at the end of a.
329 func suffixMatchLen(a, b []byte) int {
330         if len(b) > len(a) {
331                 b = b[:len(a)]
332         }
333         // i is how much of b to try to match
334         for i := len(b); i > 0; i-- {
335                 // j is how many chars we've compared
336                 j := 0
337                 for ; j < i; j++ {
338                         if b[i-1-j] != a[len(a)-1-j] {
339                                 goto shorter
340                         }
341                 }
342                 return j
343         shorter:
344         }
345         return 0
346 }
347
348 func readUntil(r io.Reader, b []byte) error {
349         b1 := make([]byte, len(b))
350         i := 0
351         for {
352                 _, err := io.ReadFull(r, b1[i:])
353                 if err != nil {
354                         return err
355                 }
356                 i = suffixMatchLen(b1, b)
357                 if i == len(b) {
358                         break
359                 }
360                 if copy(b1, b1[len(b1)-i:]) != i {
361                         panic("wat")
362                 }
363         }
364         return nil
365 }
366
367 type readWriter struct {
368         io.Reader
369         io.Writer
370 }
371
372 func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
373         h.postWrite(hash(req1, h.s.Bytes()))
374         h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())))
375         buf := &bytes.Buffer{}
376         err = (&cryptoNegotiation{
377                 Method: cryptoMethodRC4,
378                 PadLen: uint16(newPadLen()),
379         }).MarshalWriter(buf)
380         if err != nil {
381                 return
382         }
383         err = marshal(buf, uint16(len(h.ia)), h.ia)
384         if err != nil {
385                 return
386         }
387         e := newEncrypt(true, h.s.Bytes(), h.skey)
388         be := make([]byte, buf.Len())
389         e.XORKeyStream(be, buf.Bytes())
390         h.postWrite(be)
391         bC := newEncrypt(false, h.s.Bytes(), h.skey)
392         var eVC [8]byte
393         bC.XORKeyStream(eVC[:], make([]byte, 8))
394         // Read until the all zero VC.
395         err = readUntil(h.conn, eVC[:])
396         if err != nil {
397                 err = fmt.Errorf("error reading until VC: %s", err)
398                 return
399         }
400         var cn cryptoNegotiation
401         r := &cipherReader{bC, h.conn}
402         err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
403         if err != nil {
404                 err = fmt.Errorf("error reading crypto negotiation: %s", err)
405                 return
406         }
407         ret = readWriter{r, &cipherWriter{e, h.conn}}
408         return
409 }
410
411 func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
412         err = readUntil(h.conn, hash(req1, h.s.Bytes()))
413         if err != nil {
414                 return
415         }
416         var b [20]byte
417         _, err = io.ReadFull(h.conn, b[:])
418         if err != nil {
419                 return
420         }
421         err = errors.New("skey doesn't match")
422         for _, skey := range h.skeys {
423                 if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s.Bytes())), b[:]) {
424                         h.skey = skey
425                         err = nil
426                         break
427                 }
428         }
429         if err != nil {
430                 return
431         }
432         var cn cryptoNegotiation
433         r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn)
434         err = cn.UnmarshalReader(r)
435         if err != nil {
436                 return
437         }
438         cryptoProvidesCount.Add(strconv.FormatUint(uint64(cn.Method), 16), 1)
439         if cn.Method&cryptoMethodRC4 == 0 {
440                 err = errors.New("no supported crypto methods were provided")
441                 return
442         }
443         var lenIA uint16
444         unmarshal(r, &lenIA)
445         if lenIA != 0 {
446                 h.ia = make([]byte, lenIA)
447                 unmarshal(r, h.ia)
448         }
449         buf := &bytes.Buffer{}
450         w := cipherWriter{newEncrypt(false, h.s.Bytes(), h.skey), buf}
451         err = (&cryptoNegotiation{
452                 Method: cryptoMethodRC4,
453                 PadLen: uint16(newPadLen()),
454         }).MarshalWriter(&w)
455         if err != nil {
456                 return
457         }
458         err = h.postWrite(buf.Bytes())
459         if err != nil {
460                 return
461         }
462         ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn}}
463         return
464 }
465
466 func (h *handshake) Do() (ret io.ReadWriter, err error) {
467         err = h.establishS()
468         if err != nil {
469                 err = fmt.Errorf("error while establishing secret: %s", err)
470                 return
471         }
472         pad := make([]byte, newPadLen())
473         io.ReadFull(rand.Reader, pad)
474         err = h.postWrite(pad)
475         if err != nil {
476                 return
477         }
478         if h.initer {
479                 ret, err = h.initerSteps()
480         } else {
481                 ret, err = h.receiverSteps()
482         }
483         if err != nil {
484                 return
485         }
486         err = h.finishWriting()
487         if err != nil {
488                 return
489         }
490         return
491 }
492
493 func InitiateHandshake(rw io.ReadWriteCloser, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
494         h := handshake{
495                 conn:   rw,
496                 initer: true,
497                 skey:   skey,
498                 ia:     initialPayload,
499         }
500         h.writeCond.L = &h.writeMu
501         h.writerCond.L = &h.writerMu
502         go h.writer()
503         return h.Do()
504 }
505 func ReceiveHandshake(rw io.ReadWriteCloser, skeys [][]byte) (ret io.ReadWriter, err error) {
506         h := handshake{
507                 conn:   rw,
508                 initer: false,
509                 skeys:  skeys,
510         }
511         h.writeCond.L = &h.writeMu
512         h.writerCond.L = &h.writerMu
513         go h.writer()
514         return h.Do()
515 }