]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse.go
mse: Got basic message stream encryption working with some tests
[btrtrc.git] / mse / mse.go
1 // https://wiki.vuze.com/w/Message_Stream_Encryption
2
3 package mse
4
5 import (
6         "bytes"
7         "crypto/rand"
8         "crypto/rc4"
9         "crypto/sha1"
10         "encoding/binary"
11         "errors"
12         "fmt"
13         "io"
14         "io/ioutil"
15         "log"
16         "math/big"
17         "sync"
18
19         "github.com/bradfitz/iter"
20 )
21
22 const (
23         maxPadLen = 512
24
25         cryptoMethodPlaintext = 1
26         cryptoMethodRC4       = 2
27 )
28
29 var (
30         // Prime P according to the spec, and G, the generator.
31         p, g big.Int
32         // The rand.Int max arg for use in newPadLen()
33         newPadLenMax big.Int
34         // For use in initer's hashes
35         req1 = []byte("req1")
36         req2 = []byte("req2")
37         req3 = []byte("req3")
38 )
39
40 func init() {
41         p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
42         g.SetInt64(2)
43         newPadLenMax.SetInt64(maxPadLen + 1)
44 }
45
46 func hash(parts ...[]byte) []byte {
47         h := sha1.New()
48         for _, p := range parts {
49                 n, err := h.Write(p)
50                 if err != nil {
51                         panic(err)
52                 }
53                 if n != len(p) {
54                         panic(n)
55                 }
56         }
57         return h.Sum(nil)
58 }
59
60 func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
61         c, err := rc4.NewCipher(hash([]byte(func() string {
62                 if initer {
63                         return "keyA"
64                 } else {
65                         return "keyB"
66                 }
67         }()), s, skey))
68         if err != nil {
69                 panic(err)
70         }
71         var burnSrc, burnDst [1024]byte
72         c.XORKeyStream(burnDst[:], burnSrc[:])
73         return
74 }
75
76 type cipherReader struct {
77         c *rc4.Cipher
78         r io.Reader
79 }
80
81 func (me *cipherReader) Read(b []byte) (n int, err error) {
82         be := make([]byte, len(b))
83         n, err = me.r.Read(be)
84         me.c.XORKeyStream(b[:n], be[:n])
85         return
86 }
87
88 func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
89         return &cipherReader{c, r}
90 }
91
92 type cipherWriter struct {
93         c *rc4.Cipher
94         w io.Writer
95 }
96
97 func (me *cipherWriter) Write(b []byte) (n int, err error) {
98         be := make([]byte, len(b))
99         me.c.XORKeyStream(be, b)
100         n, err = me.w.Write(be)
101         if n != len(be) {
102                 // The cipher will have advanced beyond the callers stream position.
103                 // We can't use the cipher anymore.
104                 me.c = nil
105         }
106         return
107 }
108
109 func newCipherWriter(c *rc4.Cipher, w io.Writer) io.Writer {
110         return &cipherWriter{c, w}
111 }
112
113 func readY(r io.Reader) (y big.Int, err error) {
114         var b [96]byte
115         _, err = io.ReadFull(r, b[:])
116         if err != nil {
117                 return
118         }
119         y.SetBytes(b[:])
120         return
121 }
122
123 func newX() big.Int {
124         var X big.Int
125         X.SetBytes(func() []byte {
126                 var b [20]byte
127                 _, err := rand.Read(b[:])
128                 if err != nil {
129                         panic(err)
130                 }
131                 return b[:]
132         }())
133         return X
134 }
135
136 func (h *handshake) postY(x *big.Int) error {
137         var y big.Int
138         y.Exp(&g, x, &p)
139         b := y.Bytes()
140         if len(b) != 96 {
141                 panic(len(b))
142         }
143         return h.postWrite(b)
144 }
145
146 func (h *handshake) establishS() (err error) {
147         x := newX()
148         h.postY(&x)
149         var b [96]byte
150         _, err = io.ReadFull(h.conn, b[:])
151         if err != nil {
152                 return
153         }
154         var Y big.Int
155         Y.SetBytes(b[:])
156         h.s.Exp(&Y, &x, &p)
157         return
158 }
159
160 func newPadLen() int64 {
161         i, err := rand.Int(rand.Reader, &newPadLenMax)
162         if err != nil {
163                 panic(err)
164         }
165         ret := i.Int64()
166         if ret < 0 || ret > maxPadLen {
167                 panic(ret)
168         }
169         return ret
170 }
171
172 type handshake struct {
173         conn   io.ReadWriteCloser
174         s      big.Int
175         initer bool
176         skey   []byte
177
178         writeMu    sync.Mutex
179         writes     [][]byte
180         writeErr   error
181         writeCond  sync.Cond
182         writeClose bool
183
184         writerMu   sync.Mutex
185         writerCond sync.Cond
186         writerDone bool
187 }
188
189 func (h *handshake) finishWriting() (err error) {
190         h.writeMu.Lock()
191         h.writeClose = true
192         h.writeCond.Broadcast()
193         err = h.writeErr
194         h.writeMu.Unlock()
195
196         h.writerMu.Lock()
197         for !h.writerDone {
198                 h.writerCond.Wait()
199         }
200         h.writerMu.Unlock()
201         return
202
203 }
204
205 func (h *handshake) writer() {
206         defer func() {
207                 h.writerMu.Lock()
208                 h.writerDone = true
209                 h.writerCond.Broadcast()
210                 h.writerMu.Unlock()
211         }()
212         for {
213                 h.writeMu.Lock()
214                 for {
215                         if len(h.writes) != 0 {
216                                 break
217                         }
218                         if h.writeClose {
219                                 h.writeMu.Unlock()
220                                 return
221                         }
222                         h.writeCond.Wait()
223                 }
224                 b := h.writes[0]
225                 h.writes = h.writes[1:]
226                 h.writeMu.Unlock()
227                 _, err := h.conn.Write(b)
228                 if err != nil {
229                         h.writeMu.Lock()
230                         h.writeErr = err
231                         h.writeMu.Unlock()
232                         return
233                 }
234         }
235 }
236
237 func (h *handshake) postWrite(b []byte) error {
238         h.writeMu.Lock()
239         defer h.writeMu.Unlock()
240         if h.writeErr != nil {
241                 return h.writeErr
242         }
243         h.writes = append(h.writes, b)
244         h.writeCond.Signal()
245         return nil
246 }
247
248 func xor(dst, src []byte) (ret []byte) {
249         max := len(dst)
250         if max > len(src) {
251                 max = len(src)
252         }
253         ret = make([]byte, 0, max)
254         for i := range iter.N(max) {
255                 ret = append(ret, dst[i]^src[i])
256         }
257         return
258 }
259
260 type cryptoNegotiation struct {
261         VC     [8]byte
262         Method uint32
263         PadLen uint16
264         IA     []byte
265 }
266
267 func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
268         _, err = io.ReadFull(r, me.VC[:])
269         if err != nil {
270                 return
271         }
272         err = binary.Read(r, binary.BigEndian, &me.Method)
273         if err != nil {
274                 return
275         }
276         err = binary.Read(r, binary.BigEndian, &me.PadLen)
277         if err != nil {
278                 return
279         }
280         log.Print(me.PadLen)
281         _, err = io.CopyN(ioutil.Discard, r, int64(me.PadLen))
282         return
283 }
284
285 func (me *cryptoNegotiation) MarshalWriter(w io.Writer) (err error) {
286         _, err = w.Write(me.VC[:])
287         if err != nil {
288                 return
289         }
290         err = binary.Write(w, binary.BigEndian, me.Method)
291         if err != nil {
292                 return
293         }
294         err = binary.Write(w, binary.BigEndian, me.PadLen)
295         if err != nil {
296                 return
297         }
298         _, err = w.Write(make([]byte, me.PadLen))
299         return
300 }
301
302 // Looking for b at the end of a.
303 func suffixMatchLen(a, b []byte) int {
304         if len(b) > len(a) {
305                 b = b[:len(a)]
306         }
307         // i is how much of b to try to match
308         for i := len(b); i > 0; i-- {
309                 // j is how many chars we've compared
310                 j := 0
311                 for ; j < i; j++ {
312                         if b[i-1-j] != a[len(a)-1-j] {
313                                 goto shorter
314                         }
315                 }
316                 return j
317         shorter:
318         }
319         return 0
320 }
321
322 func readUntil(r io.Reader, b []byte) error {
323         log.Println("read until", b)
324         b1 := make([]byte, len(b))
325         i := 0
326         for {
327                 _, err := io.ReadFull(r, b1[i:])
328                 if err != nil {
329                         return err
330                 }
331                 i = suffixMatchLen(b1, b)
332                 if i == len(b) {
333                         break
334                 }
335                 if copy(b1, b1[len(b1)-i:]) != i {
336                         panic("wat")
337                 }
338         }
339         return nil
340 }
341
342 func (h *handshake) Do() (ret io.ReadWriteCloser, err error) {
343         err = h.establishS()
344         if err != nil {
345                 return
346         }
347         pad := make([]byte, newPadLen())
348         io.ReadFull(rand.Reader, pad)
349         err = h.postWrite(pad)
350         if err != nil {
351                 return
352         }
353         if h.initer {
354                 h.postWrite(hash(req1, h.s.Bytes()))
355                 h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())))
356                 buf := &bytes.Buffer{}
357                 err = (&cryptoNegotiation{
358                         Method: cryptoMethodRC4,
359                         PadLen: uint16(newPadLen()),
360                 }).MarshalWriter(buf)
361                 if err != nil {
362                         return
363                 }
364                 e := newEncrypt(true, h.s.Bytes(), h.skey)
365                 be := make([]byte, buf.Len())
366                 e.XORKeyStream(be, buf.Bytes())
367                 h.postWrite(be)
368                 bC := newEncrypt(false, h.s.Bytes(), h.skey)
369                 var eVC [8]byte
370                 bC.XORKeyStream(eVC[:], make([]byte, 8))
371                 log.Print(eVC)
372                 // Read until the all zero VC.
373                 err = readUntil(h.conn, eVC[:])
374                 if err != nil {
375                         err = fmt.Errorf("error reading until VC: %s", err)
376                         return
377                 }
378                 var cn cryptoNegotiation
379                 r := &cipherReader{bC, h.conn}
380                 err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
381                 log.Printf("initer got %v", cn)
382                 if err != nil {
383                         err = fmt.Errorf("error reading crypto negotiation: %s", err)
384                         return
385                 }
386         } else {
387                 err = readUntil(h.conn, hash(req1, h.s.Bytes()))
388                 if err != nil {
389                         return
390                 }
391                 var b [20]byte
392                 _, err = io.ReadFull(h.conn, b[:])
393                 if err != nil {
394                         return
395                 }
396                 if !bytes.Equal(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())), b[:]) {
397                         err = errors.New("skey doesn't match")
398                         return
399                 }
400                 var cn cryptoNegotiation
401                 r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn)
402                 err = cn.UnmarshalReader(r)
403                 if err != nil {
404                         return
405                 }
406                 log.Printf("receiver got %v", cn)
407                 if cn.Method&cryptoMethodRC4 == 0 {
408                         err = errors.New("no supported crypto methods were provided")
409                         return
410                 }
411                 buf := &bytes.Buffer{}
412                 w := newCipherWriter(newEncrypt(false, h.s.Bytes(), h.skey), buf)
413                 err = (&cryptoNegotiation{
414                         Method: cryptoMethodRC4,
415                         PadLen: uint16(newPadLen()),
416                 }).MarshalWriter(w)
417                 if err != nil {
418                         return
419                 }
420                 log.Println("encrypted VC", buf.Bytes()[:8])
421                 err = h.postWrite(buf.Bytes())
422                 if err != nil {
423                         return
424                 }
425         }
426         err = h.finishWriting()
427         if err != nil {
428                 return
429         }
430         ret = h.conn
431         return
432 }
433
434 func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWriteCloser, err error) {
435         h := handshake{
436                 conn:   rw,
437                 initer: initer,
438                 skey:   skey,
439         }
440         h.writeCond.L = &h.writeMu
441         h.writerCond.L = &h.writerMu
442         go h.writer()
443         return h.Do()
444 }