]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse.go
mse: Tons of fixes and improvements
[btrtrc.git] / mse / mse.go
1 // https://wiki.vuze.com/w/Message_Stream_Encryption
2
3 package mse
4
5 import (
6         "bytes"
7         "crypto/rand"
8         "crypto/rc4"
9         "crypto/sha1"
10         "encoding/binary"
11         "errors"
12         "fmt"
13         "io"
14         "io/ioutil"
15         "log"
16         "math/big"
17         "sync"
18
19         "github.com/bradfitz/iter"
20 )
21
22 const (
23         maxPadLen = 512
24
25         cryptoMethodPlaintext = 1
26         cryptoMethodRC4       = 2
27 )
28
29 var (
30         // Prime P according to the spec, and G, the generator.
31         p, g big.Int
32         // The rand.Int max arg for use in newPadLen()
33         newPadLenMax big.Int
34         // For use in initer's hashes
35         req1 = []byte("req1")
36         req2 = []byte("req2")
37         req3 = []byte("req3")
38 )
39
40 func init() {
41         p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
42         g.SetInt64(2)
43         newPadLenMax.SetInt64(maxPadLen + 1)
44 }
45
46 func hash(parts ...[]byte) []byte {
47         h := sha1.New()
48         for _, p := range parts {
49                 n, err := h.Write(p)
50                 if err != nil {
51                         panic(err)
52                 }
53                 if n != len(p) {
54                         panic(n)
55                 }
56         }
57         return h.Sum(nil)
58 }
59
60 func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
61         c, err := rc4.NewCipher(hash([]byte(func() string {
62                 if initer {
63                         return "keyA"
64                 } else {
65                         return "keyB"
66                 }
67         }()), s, skey))
68         if err != nil {
69                 panic(err)
70         }
71         var burnSrc, burnDst [1024]byte
72         c.XORKeyStream(burnDst[:], burnSrc[:])
73         return
74 }
75
76 type cipherReader struct {
77         c *rc4.Cipher
78         r io.Reader
79 }
80
81 func (me *cipherReader) Read(b []byte) (n int, err error) {
82         be := make([]byte, len(b))
83         n, err = me.r.Read(be)
84         me.c.XORKeyStream(b[:n], be[:n])
85         return
86 }
87
88 func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
89         return &cipherReader{c, r}
90 }
91
92 type cipherWriter struct {
93         c *rc4.Cipher
94         w io.Writer
95 }
96
97 func (me *cipherWriter) Write(b []byte) (n int, err error) {
98         be := make([]byte, len(b))
99         me.c.XORKeyStream(be, b)
100         n, err = me.w.Write(be)
101         if n != len(be) {
102                 // The cipher will have advanced beyond the callers stream position.
103                 // We can't use the cipher anymore.
104                 me.c = nil
105         }
106         return
107 }
108
109 func readY(r io.Reader) (y big.Int, err error) {
110         var b [96]byte
111         _, err = io.ReadFull(r, b[:])
112         if err != nil {
113                 return
114         }
115         y.SetBytes(b[:])
116         return
117 }
118
119 func newX() big.Int {
120         var X big.Int
121         X.SetBytes(func() []byte {
122                 var b [20]byte
123                 _, err := rand.Read(b[:])
124                 if err != nil {
125                         panic(err)
126                 }
127                 return b[:]
128         }())
129         return X
130 }
131
132 // Calculate, and send Y, our public key.
133 func (h *handshake) postY(x *big.Int) error {
134         var y big.Int
135         y.Exp(&g, x, &p)
136         b := y.Bytes()
137         if len(b) != 96 {
138                 b1 := make([]byte, 96)
139                 if n := copy(b1[96-len(b):], b); n != len(b) {
140                         panic(n)
141                 }
142                 b = b1
143         }
144         return h.postWrite(b)
145 }
146
147 func (h *handshake) establishS() (err error) {
148         x := newX()
149         h.postY(&x)
150         var b [96]byte
151         _, err = io.ReadFull(h.conn, b[:])
152         if err != nil {
153                 return
154         }
155         var Y big.Int
156         Y.SetBytes(b[:])
157         h.s.Exp(&Y, &x, &p)
158         return
159 }
160
161 func newPadLen() int64 {
162         i, err := rand.Int(rand.Reader, &newPadLenMax)
163         if err != nil {
164                 panic(err)
165         }
166         ret := i.Int64()
167         if ret < 0 || ret > maxPadLen {
168                 panic(ret)
169         }
170         return ret
171 }
172
173 type handshake struct {
174         conn   io.ReadWriteCloser
175         s      big.Int
176         initer bool
177         skeys  [][]byte
178         skey   []byte
179
180         writeMu    sync.Mutex
181         writes     [][]byte
182         writeErr   error
183         writeCond  sync.Cond
184         writeClose bool
185
186         writerMu   sync.Mutex
187         writerCond sync.Cond
188         writerDone bool
189 }
190
191 func (h *handshake) finishWriting() (err error) {
192         h.writeMu.Lock()
193         h.writeClose = true
194         h.writeCond.Broadcast()
195         err = h.writeErr
196         h.writeMu.Unlock()
197
198         h.writerMu.Lock()
199         for !h.writerDone {
200                 h.writerCond.Wait()
201         }
202         h.writerMu.Unlock()
203         return
204
205 }
206
207 func (h *handshake) writer() {
208         defer func() {
209                 h.writerMu.Lock()
210                 h.writerDone = true
211                 h.writerCond.Broadcast()
212                 h.writerMu.Unlock()
213         }()
214         for {
215                 h.writeMu.Lock()
216                 for {
217                         if len(h.writes) != 0 {
218                                 break
219                         }
220                         if h.writeClose {
221                                 h.writeMu.Unlock()
222                                 return
223                         }
224                         h.writeCond.Wait()
225                 }
226                 b := h.writes[0]
227                 h.writes = h.writes[1:]
228                 h.writeMu.Unlock()
229                 _, err := h.conn.Write(b)
230                 if err != nil {
231                         h.writeMu.Lock()
232                         h.writeErr = err
233                         h.writeMu.Unlock()
234                         return
235                 }
236         }
237 }
238
239 func (h *handshake) postWrite(b []byte) error {
240         h.writeMu.Lock()
241         defer h.writeMu.Unlock()
242         if h.writeErr != nil {
243                 return h.writeErr
244         }
245         h.writes = append(h.writes, b)
246         h.writeCond.Signal()
247         return nil
248 }
249
250 func xor(dst, src []byte) (ret []byte) {
251         max := len(dst)
252         if max > len(src) {
253                 max = len(src)
254         }
255         ret = make([]byte, 0, max)
256         for i := range iter.N(max) {
257                 ret = append(ret, dst[i]^src[i])
258         }
259         return
260 }
261
262 func marshal(w io.Writer, data ...interface{}) (err error) {
263         for _, data := range data {
264                 err = binary.Write(w, binary.BigEndian, data)
265                 if err != nil {
266                         break
267                 }
268         }
269         return
270 }
271
272 func unmarshal(r io.Reader, data ...interface{}) (err error) {
273         for _, data := range data {
274                 err = binary.Read(r, binary.BigEndian, data)
275                 if err != nil {
276                         break
277                 }
278         }
279         return
280 }
281
282 type cryptoNegotiation struct {
283         VC     [8]byte
284         Method uint32
285         PadLen uint16
286         IA     []byte
287 }
288
289 func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
290         err = binary.Read(r, binary.BigEndian, me.VC[:])
291         // _, err = io.ReadFull(r, me.VC[:])
292         if err != nil {
293                 return
294         }
295         err = binary.Read(r, binary.BigEndian, &me.Method)
296         if err != nil {
297                 return
298         }
299         err = binary.Read(r, binary.BigEndian, &me.PadLen)
300         if err != nil {
301                 return
302         }
303         log.Print(me.PadLen)
304         _, err = io.CopyN(ioutil.Discard, r, int64(me.PadLen))
305         return
306 }
307
308 func (me *cryptoNegotiation) MarshalWriter(w io.Writer) (err error) {
309         // _, err = w.Write(me.VC[:])
310         err = binary.Write(w, binary.BigEndian, me.VC[:])
311         if err != nil {
312                 return
313         }
314         err = binary.Write(w, binary.BigEndian, me.Method)
315         if err != nil {
316                 return
317         }
318         err = binary.Write(w, binary.BigEndian, me.PadLen)
319         if err != nil {
320                 return
321         }
322         _, err = w.Write(make([]byte, me.PadLen))
323         return
324 }
325
326 // Looking for b at the end of a.
327 func suffixMatchLen(a, b []byte) int {
328         if len(b) > len(a) {
329                 b = b[:len(a)]
330         }
331         // i is how much of b to try to match
332         for i := len(b); i > 0; i-- {
333                 // j is how many chars we've compared
334                 j := 0
335                 for ; j < i; j++ {
336                         if b[i-1-j] != a[len(a)-1-j] {
337                                 goto shorter
338                         }
339                 }
340                 return j
341         shorter:
342         }
343         return 0
344 }
345
346 func readUntil(r io.Reader, b []byte) error {
347         log.Println("read until", b)
348         b1 := make([]byte, len(b))
349         i := 0
350         for {
351                 _, err := io.ReadFull(r, b1[i:])
352                 if err != nil {
353                         return err
354                 }
355                 i = suffixMatchLen(b1, b)
356                 if i == len(b) {
357                         break
358                 }
359                 if copy(b1, b1[len(b1)-i:]) != i {
360                         panic("wat")
361                 }
362         }
363         return nil
364 }
365
366 type readWriter struct {
367         io.Reader
368         io.Writer
369 }
370
371 func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
372         h.postWrite(hash(req1, h.s.Bytes()))
373         h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())))
374         buf := &bytes.Buffer{}
375         err = (&cryptoNegotiation{
376                 Method: cryptoMethodRC4,
377                 PadLen: uint16(newPadLen()),
378         }).MarshalWriter(buf)
379         if err != nil {
380                 return
381         }
382         err = marshal(buf, uint16(0))
383         if err != nil {
384                 return
385         }
386         e := newEncrypt(true, h.s.Bytes(), h.skey)
387         be := make([]byte, buf.Len())
388         e.XORKeyStream(be, buf.Bytes())
389         h.postWrite(be)
390         bC := newEncrypt(false, h.s.Bytes(), h.skey)
391         var eVC [8]byte
392         bC.XORKeyStream(eVC[:], make([]byte, 8))
393         log.Print(eVC)
394         // Read until the all zero VC.
395         err = readUntil(h.conn, eVC[:])
396         if err != nil {
397                 err = fmt.Errorf("error reading until VC: %s", err)
398                 return
399         }
400         var cn cryptoNegotiation
401         r := &cipherReader{bC, h.conn}
402         err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
403         log.Printf("initer got %v", cn)
404         if err != nil {
405                 err = fmt.Errorf("error reading crypto negotiation: %s", err)
406                 return
407         }
408         ret = readWriter{r, &cipherWriter{e, h.conn}}
409         return
410 }
411
412 func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
413         err = readUntil(h.conn, hash(req1, h.s.Bytes()))
414         if err != nil {
415                 return
416         }
417         var b [20]byte
418         _, err = io.ReadFull(h.conn, b[:])
419         if err != nil {
420                 return
421         }
422         err = errors.New("skey doesn't match")
423         for _, skey := range h.skeys {
424                 if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s.Bytes())), b[:]) {
425                         h.skey = skey
426                         err = nil
427                         break
428                 }
429         }
430         if err != nil {
431                 return
432         }
433         var cn cryptoNegotiation
434         r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn)
435         err = cn.UnmarshalReader(r)
436         if err != nil {
437                 return
438         }
439         log.Printf("receiver got %v", cn)
440         if cn.Method&cryptoMethodRC4 == 0 {
441                 err = errors.New("no supported crypto methods were provided")
442                 return
443         }
444         unmarshal(r, new(uint16))
445         buf := &bytes.Buffer{}
446         w := cipherWriter{newEncrypt(false, h.s.Bytes(), h.skey), buf}
447         err = (&cryptoNegotiation{
448                 Method: cryptoMethodRC4,
449                 PadLen: uint16(newPadLen()),
450         }).MarshalWriter(&w)
451         if err != nil {
452                 return
453         }
454         err = h.postWrite(buf.Bytes())
455         if err != nil {
456                 return
457         }
458         ret = readWriter{r, &cipherWriter{w.c, h.conn}}
459         return
460 }
461
462 func (h *handshake) Do() (ret io.ReadWriter, err error) {
463         err = h.establishS()
464         if err != nil {
465                 err = fmt.Errorf("error while establishing secret: %s", err)
466                 return
467         }
468         pad := make([]byte, newPadLen())
469         io.ReadFull(rand.Reader, pad)
470         err = h.postWrite(pad)
471         if err != nil {
472                 return
473         }
474         if h.initer {
475                 ret, err = h.initerSteps()
476         } else {
477                 ret, err = h.receiverSteps()
478         }
479         if err != nil {
480                 return
481         }
482         err = h.finishWriting()
483         if err != nil {
484                 return
485         }
486         log.Print("ermahgerd, finished MSE handshake")
487         return
488 }
489
490 func InitiateHandshake(rw io.ReadWriteCloser, skey []byte) (ret io.ReadWriter, err error) {
491         h := handshake{
492                 conn:   rw,
493                 initer: true,
494                 skey:   skey,
495         }
496         h.writeCond.L = &h.writeMu
497         h.writerCond.L = &h.writerMu
498         go h.writer()
499         return h.Do()
500 }
501 func ReceiveHandshake(rw io.ReadWriteCloser, skeys [][]byte) (ret io.ReadWriter, err error) {
502         h := handshake{
503                 conn:   rw,
504                 initer: false,
505                 skeys:  skeys,
506         }
507         h.writeCond.L = &h.writeMu
508         h.writerCond.L = &h.writerMu
509         go h.writer()
510         return h.Do()
511 }