]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse.go
msg: Return usable object after handshake
[btrtrc.git] / mse / mse.go
1 // https://wiki.vuze.com/w/Message_Stream_Encryption
2
3 package mse
4
5 import (
6         "bytes"
7         "crypto/rand"
8         "crypto/rc4"
9         "crypto/sha1"
10         "encoding/binary"
11         "errors"
12         "fmt"
13         "io"
14         "io/ioutil"
15         "log"
16         "math/big"
17         "sync"
18
19         "github.com/bradfitz/iter"
20 )
21
22 const (
23         maxPadLen = 512
24
25         cryptoMethodPlaintext = 1
26         cryptoMethodRC4       = 2
27 )
28
29 var (
30         // Prime P according to the spec, and G, the generator.
31         p, g big.Int
32         // The rand.Int max arg for use in newPadLen()
33         newPadLenMax big.Int
34         // For use in initer's hashes
35         req1 = []byte("req1")
36         req2 = []byte("req2")
37         req3 = []byte("req3")
38 )
39
40 func init() {
41         p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
42         g.SetInt64(2)
43         newPadLenMax.SetInt64(maxPadLen + 1)
44 }
45
46 func hash(parts ...[]byte) []byte {
47         h := sha1.New()
48         for _, p := range parts {
49                 n, err := h.Write(p)
50                 if err != nil {
51                         panic(err)
52                 }
53                 if n != len(p) {
54                         panic(n)
55                 }
56         }
57         return h.Sum(nil)
58 }
59
60 func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
61         c, err := rc4.NewCipher(hash([]byte(func() string {
62                 if initer {
63                         return "keyA"
64                 } else {
65                         return "keyB"
66                 }
67         }()), s, skey))
68         if err != nil {
69                 panic(err)
70         }
71         var burnSrc, burnDst [1024]byte
72         c.XORKeyStream(burnDst[:], burnSrc[:])
73         return
74 }
75
76 type cipherReader struct {
77         c *rc4.Cipher
78         r io.Reader
79 }
80
81 func (me *cipherReader) Read(b []byte) (n int, err error) {
82         be := make([]byte, len(b))
83         n, err = me.r.Read(be)
84         me.c.XORKeyStream(b[:n], be[:n])
85         return
86 }
87
88 func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
89         return &cipherReader{c, r}
90 }
91
92 type cipherWriter struct {
93         c *rc4.Cipher
94         w io.Writer
95 }
96
97 func (me *cipherWriter) Write(b []byte) (n int, err error) {
98         be := make([]byte, len(b))
99         me.c.XORKeyStream(be, b)
100         n, err = me.w.Write(be)
101         if n != len(be) {
102                 // The cipher will have advanced beyond the callers stream position.
103                 // We can't use the cipher anymore.
104                 me.c = nil
105         }
106         return
107 }
108
109 func newCipherWriter(c *rc4.Cipher, w io.Writer) io.Writer {
110         return &cipherWriter{c, w}
111 }
112
113 func readY(r io.Reader) (y big.Int, err error) {
114         var b [96]byte
115         _, err = io.ReadFull(r, b[:])
116         if err != nil {
117                 return
118         }
119         y.SetBytes(b[:])
120         return
121 }
122
123 func newX() big.Int {
124         var X big.Int
125         X.SetBytes(func() []byte {
126                 var b [20]byte
127                 _, err := rand.Read(b[:])
128                 if err != nil {
129                         panic(err)
130                 }
131                 return b[:]
132         }())
133         return X
134 }
135
136 func (h *handshake) postY(x *big.Int) error {
137         var y big.Int
138         y.Exp(&g, x, &p)
139         b := y.Bytes()
140         if len(b) != 96 {
141                 panic(len(b))
142         }
143         return h.postWrite(b)
144 }
145
146 func (h *handshake) establishS() (err error) {
147         x := newX()
148         h.postY(&x)
149         var b [96]byte
150         _, err = io.ReadFull(h.conn, b[:])
151         if err != nil {
152                 return
153         }
154         var Y big.Int
155         Y.SetBytes(b[:])
156         h.s.Exp(&Y, &x, &p)
157         return
158 }
159
160 func newPadLen() int64 {
161         i, err := rand.Int(rand.Reader, &newPadLenMax)
162         if err != nil {
163                 panic(err)
164         }
165         ret := i.Int64()
166         if ret < 0 || ret > maxPadLen {
167                 panic(ret)
168         }
169         return ret
170 }
171
172 type handshake struct {
173         conn   io.ReadWriteCloser
174         s      big.Int
175         initer bool
176         skey   []byte
177
178         writeMu    sync.Mutex
179         writes     [][]byte
180         writeErr   error
181         writeCond  sync.Cond
182         writeClose bool
183
184         writerMu   sync.Mutex
185         writerCond sync.Cond
186         writerDone bool
187 }
188
189 func (h *handshake) finishWriting() (err error) {
190         h.writeMu.Lock()
191         h.writeClose = true
192         h.writeCond.Broadcast()
193         err = h.writeErr
194         h.writeMu.Unlock()
195
196         h.writerMu.Lock()
197         for !h.writerDone {
198                 h.writerCond.Wait()
199         }
200         h.writerMu.Unlock()
201         return
202
203 }
204
205 func (h *handshake) writer() {
206         defer func() {
207                 h.writerMu.Lock()
208                 h.writerDone = true
209                 h.writerCond.Broadcast()
210                 h.writerMu.Unlock()
211         }()
212         for {
213                 h.writeMu.Lock()
214                 for {
215                         if len(h.writes) != 0 {
216                                 break
217                         }
218                         if h.writeClose {
219                                 h.writeMu.Unlock()
220                                 return
221                         }
222                         h.writeCond.Wait()
223                 }
224                 b := h.writes[0]
225                 h.writes = h.writes[1:]
226                 h.writeMu.Unlock()
227                 _, err := h.conn.Write(b)
228                 if err != nil {
229                         h.writeMu.Lock()
230                         h.writeErr = err
231                         h.writeMu.Unlock()
232                         return
233                 }
234         }
235 }
236
237 func (h *handshake) postWrite(b []byte) error {
238         h.writeMu.Lock()
239         defer h.writeMu.Unlock()
240         if h.writeErr != nil {
241                 return h.writeErr
242         }
243         h.writes = append(h.writes, b)
244         h.writeCond.Signal()
245         return nil
246 }
247
248 func xor(dst, src []byte) (ret []byte) {
249         max := len(dst)
250         if max > len(src) {
251                 max = len(src)
252         }
253         ret = make([]byte, 0, max)
254         for i := range iter.N(max) {
255                 ret = append(ret, dst[i]^src[i])
256         }
257         return
258 }
259
260 type cryptoNegotiation struct {
261         VC     [8]byte
262         Method uint32
263         PadLen uint16
264         IA     []byte
265 }
266
267 func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
268         _, err = io.ReadFull(r, me.VC[:])
269         if err != nil {
270                 return
271         }
272         err = binary.Read(r, binary.BigEndian, &me.Method)
273         if err != nil {
274                 return
275         }
276         err = binary.Read(r, binary.BigEndian, &me.PadLen)
277         if err != nil {
278                 return
279         }
280         log.Print(me.PadLen)
281         _, err = io.CopyN(ioutil.Discard, r, int64(me.PadLen))
282         return
283 }
284
285 func (me *cryptoNegotiation) MarshalWriter(w io.Writer) (err error) {
286         _, err = w.Write(me.VC[:])
287         if err != nil {
288                 return
289         }
290         err = binary.Write(w, binary.BigEndian, me.Method)
291         if err != nil {
292                 return
293         }
294         err = binary.Write(w, binary.BigEndian, me.PadLen)
295         if err != nil {
296                 return
297         }
298         _, err = w.Write(make([]byte, me.PadLen))
299         return
300 }
301
302 // Looking for b at the end of a.
303 func suffixMatchLen(a, b []byte) int {
304         if len(b) > len(a) {
305                 b = b[:len(a)]
306         }
307         // i is how much of b to try to match
308         for i := len(b); i > 0; i-- {
309                 // j is how many chars we've compared
310                 j := 0
311                 for ; j < i; j++ {
312                         if b[i-1-j] != a[len(a)-1-j] {
313                                 goto shorter
314                         }
315                 }
316                 return j
317         shorter:
318         }
319         return 0
320 }
321
322 func readUntil(r io.Reader, b []byte) error {
323         log.Println("read until", b)
324         b1 := make([]byte, len(b))
325         i := 0
326         for {
327                 _, err := io.ReadFull(r, b1[i:])
328                 if err != nil {
329                         return err
330                 }
331                 i = suffixMatchLen(b1, b)
332                 if i == len(b) {
333                         break
334                 }
335                 if copy(b1, b1[len(b1)-i:]) != i {
336                         panic("wat")
337                 }
338         }
339         return nil
340 }
341
342 type readWriter struct {
343         io.Reader
344         io.Writer
345 }
346
347 func (h *handshake) Do() (ret io.ReadWriter, err error) {
348         err = h.establishS()
349         if err != nil {
350                 return
351         }
352         pad := make([]byte, newPadLen())
353         io.ReadFull(rand.Reader, pad)
354         err = h.postWrite(pad)
355         if err != nil {
356                 return
357         }
358         if h.initer {
359                 h.postWrite(hash(req1, h.s.Bytes()))
360                 h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())))
361                 buf := &bytes.Buffer{}
362                 err = (&cryptoNegotiation{
363                         Method: cryptoMethodRC4,
364                         PadLen: uint16(newPadLen()),
365                 }).MarshalWriter(buf)
366                 if err != nil {
367                         return
368                 }
369                 e := newEncrypt(true, h.s.Bytes(), h.skey)
370                 be := make([]byte, buf.Len())
371                 e.XORKeyStream(be, buf.Bytes())
372                 h.postWrite(be)
373                 bC := newEncrypt(false, h.s.Bytes(), h.skey)
374                 var eVC [8]byte
375                 bC.XORKeyStream(eVC[:], make([]byte, 8))
376                 log.Print(eVC)
377                 // Read until the all zero VC.
378                 err = readUntil(h.conn, eVC[:])
379                 if err != nil {
380                         err = fmt.Errorf("error reading until VC: %s", err)
381                         return
382                 }
383                 var cn cryptoNegotiation
384                 r := &cipherReader{bC, h.conn}
385                 err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
386                 log.Printf("initer got %v", cn)
387                 if err != nil {
388                         err = fmt.Errorf("error reading crypto negotiation: %s", err)
389                         return
390                 }
391                 ret = readWriter{r, &cipherWriter{bC, h.conn}}
392         } else {
393                 err = readUntil(h.conn, hash(req1, h.s.Bytes()))
394                 if err != nil {
395                         return
396                 }
397                 var b [20]byte
398                 _, err = io.ReadFull(h.conn, b[:])
399                 if err != nil {
400                         return
401                 }
402                 if !bytes.Equal(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())), b[:]) {
403                         err = errors.New("skey doesn't match")
404                         return
405                 }
406                 var cn cryptoNegotiation
407                 r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn)
408                 err = cn.UnmarshalReader(r)
409                 if err != nil {
410                         return
411                 }
412                 log.Printf("receiver got %v", cn)
413                 if cn.Method&cryptoMethodRC4 == 0 {
414                         err = errors.New("no supported crypto methods were provided")
415                         return
416                 }
417                 buf := &bytes.Buffer{}
418                 w := newCipherWriter(newEncrypt(false, h.s.Bytes(), h.skey), buf)
419                 err = (&cryptoNegotiation{
420                         Method: cryptoMethodRC4,
421                         PadLen: uint16(newPadLen()),
422                 }).MarshalWriter(w)
423                 if err != nil {
424                         return
425                 }
426                 log.Println("encrypted VC", buf.Bytes()[:8])
427                 err = h.postWrite(buf.Bytes())
428                 if err != nil {
429                         return
430                 }
431                 ret = readWriter{r, w}
432         }
433         err = h.finishWriting()
434         if err != nil {
435                 return
436         }
437         ret = h.conn
438         return
439 }
440
441 func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWriter, err error) {
442         h := handshake{
443                 conn:   rw,
444                 initer: initer,
445                 skey:   skey,
446         }
447         h.writeCond.L = &h.writeMu
448         h.writerCond.L = &h.writerMu
449         go h.writer()
450         return h.Do()
451 }