]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse.go
Some utils moved to missinggo
[btrtrc.git] / mse / mse.go
1 // https://wiki.vuze.com/w/Message_Stream_Encryption
2
3 package mse
4
5 import (
6         "bytes"
7         "crypto/rand"
8         "crypto/rc4"
9         "crypto/sha1"
10         "encoding/binary"
11         "errors"
12         "expvar"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "math/big"
17         "strconv"
18         "sync"
19
20         "github.com/anacrolix/missinggo"
21         "github.com/bradfitz/iter"
22 )
23
24 const (
25         maxPadLen = 512
26
27         cryptoMethodPlaintext = 1
28         cryptoMethodRC4       = 2
29 )
30
31 var (
32         // Prime P according to the spec, and G, the generator.
33         p, g big.Int
34         // The rand.Int max arg for use in newPadLen()
35         newPadLenMax big.Int
36         // For use in initer's hashes
37         req1 = []byte("req1")
38         req2 = []byte("req2")
39         req3 = []byte("req3")
40         // Verification constant "VC" which is all zeroes in the bittorrent
41         // implementation.
42         vc [8]byte
43         // Zero padding
44         zeroPad [512]byte
45         // Tracks counts of received crypto_provides
46         cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
47 )
48
49 func init() {
50         p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
51         g.SetInt64(2)
52         newPadLenMax.SetInt64(maxPadLen + 1)
53 }
54
55 func hash(parts ...[]byte) []byte {
56         h := sha1.New()
57         for _, p := range parts {
58                 n, err := h.Write(p)
59                 if err != nil {
60                         panic(err)
61                 }
62                 if n != len(p) {
63                         panic(n)
64                 }
65         }
66         return h.Sum(nil)
67 }
68
69 func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
70         c, err := rc4.NewCipher(hash([]byte(func() string {
71                 if initer {
72                         return "keyA"
73                 } else {
74                         return "keyB"
75                 }
76         }()), s, skey))
77         if err != nil {
78                 panic(err)
79         }
80         var burnSrc, burnDst [1024]byte
81         c.XORKeyStream(burnDst[:], burnSrc[:])
82         return
83 }
84
85 type cipherReader struct {
86         c *rc4.Cipher
87         r io.Reader
88 }
89
90 func (me *cipherReader) Read(b []byte) (n int, err error) {
91         be := make([]byte, len(b))
92         n, err = me.r.Read(be)
93         me.c.XORKeyStream(b[:n], be[:n])
94         return
95 }
96
97 func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
98         return &cipherReader{c, r}
99 }
100
101 type cipherWriter struct {
102         c *rc4.Cipher
103         w io.Writer
104 }
105
106 func (me *cipherWriter) Write(b []byte) (n int, err error) {
107         be := make([]byte, len(b))
108         me.c.XORKeyStream(be, b)
109         n, err = me.w.Write(be)
110         if n != len(be) {
111                 // The cipher will have advanced beyond the callers stream position.
112                 // We can't use the cipher anymore.
113                 me.c = nil
114         }
115         return
116 }
117
118 func readY(r io.Reader) (y big.Int, err error) {
119         var b [96]byte
120         _, err = io.ReadFull(r, b[:])
121         if err != nil {
122                 return
123         }
124         y.SetBytes(b[:])
125         return
126 }
127
128 func newX() big.Int {
129         var X big.Int
130         X.SetBytes(func() []byte {
131                 var b [20]byte
132                 _, err := rand.Read(b[:])
133                 if err != nil {
134                         panic(err)
135                 }
136                 return b[:]
137         }())
138         return X
139 }
140
141 func paddedLeft(b []byte, _len int) []byte {
142         if len(b) == _len {
143                 return b
144         }
145         ret := make([]byte, _len)
146         if n := copy(ret[_len-len(b):], b); n != len(b) {
147                 panic(n)
148         }
149         return ret
150 }
151
152 // Calculate, and send Y, our public key.
153 func (h *handshake) postY(x *big.Int) error {
154         var y big.Int
155         y.Exp(&g, x, &p)
156         return h.postWrite(paddedLeft(y.Bytes(), 96))
157 }
158
159 func (h *handshake) establishS() (err error) {
160         x := newX()
161         h.postY(&x)
162         var b [96]byte
163         _, err = io.ReadFull(h.conn, b[:])
164         if err != nil {
165                 return
166         }
167         var Y, S big.Int
168         Y.SetBytes(b[:])
169         S.Exp(&Y, &x, &p)
170         missinggo.CopyExact(&h.s, paddedLeft(S.Bytes(), 96))
171         return
172 }
173
174 func newPadLen() int64 {
175         i, err := rand.Int(rand.Reader, &newPadLenMax)
176         if err != nil {
177                 panic(err)
178         }
179         ret := i.Int64()
180         if ret < 0 || ret > maxPadLen {
181                 panic(ret)
182         }
183         return ret
184 }
185
186 type handshake struct {
187         conn   io.ReadWriter
188         s      [96]byte
189         initer bool
190         skeys  [][]byte
191         skey   []byte
192         ia     []byte // Initial payload. Only used by the initiator.
193
194         writeMu    sync.Mutex
195         writes     [][]byte
196         writeErr   error
197         writeCond  sync.Cond
198         writeClose bool
199
200         writerMu   sync.Mutex
201         writerCond sync.Cond
202         writerDone bool
203 }
204
205 func (h *handshake) finishWriting() {
206         h.writeMu.Lock()
207         h.writeClose = true
208         h.writeCond.Broadcast()
209         h.writeMu.Unlock()
210
211         h.writerMu.Lock()
212         for !h.writerDone {
213                 h.writerCond.Wait()
214         }
215         h.writerMu.Unlock()
216         return
217 }
218
219 func (h *handshake) writer() {
220         defer func() {
221                 h.writerMu.Lock()
222                 h.writerDone = true
223                 h.writerCond.Broadcast()
224                 h.writerMu.Unlock()
225         }()
226         for {
227                 h.writeMu.Lock()
228                 for {
229                         if len(h.writes) != 0 {
230                                 break
231                         }
232                         if h.writeClose {
233                                 h.writeMu.Unlock()
234                                 return
235                         }
236                         h.writeCond.Wait()
237                 }
238                 b := h.writes[0]
239                 h.writes = h.writes[1:]
240                 h.writeMu.Unlock()
241                 _, err := h.conn.Write(b)
242                 if err != nil {
243                         h.writeMu.Lock()
244                         h.writeErr = err
245                         h.writeMu.Unlock()
246                         return
247                 }
248         }
249 }
250
251 func (h *handshake) postWrite(b []byte) error {
252         h.writeMu.Lock()
253         defer h.writeMu.Unlock()
254         if h.writeErr != nil {
255                 return h.writeErr
256         }
257         h.writes = append(h.writes, b)
258         h.writeCond.Signal()
259         return nil
260 }
261
262 func xor(dst, src []byte) (ret []byte) {
263         max := len(dst)
264         if max > len(src) {
265                 max = len(src)
266         }
267         ret = make([]byte, 0, max)
268         for i := range iter.N(max) {
269                 ret = append(ret, dst[i]^src[i])
270         }
271         return
272 }
273
274 func marshal(w io.Writer, data ...interface{}) (err error) {
275         for _, data := range data {
276                 err = binary.Write(w, binary.BigEndian, data)
277                 if err != nil {
278                         break
279                 }
280         }
281         return
282 }
283
284 func unmarshal(r io.Reader, data ...interface{}) (err error) {
285         for _, data := range data {
286                 err = binary.Read(r, binary.BigEndian, data)
287                 if err != nil {
288                         break
289                 }
290         }
291         return
292 }
293
294 // Looking for b at the end of a.
295 func suffixMatchLen(a, b []byte) int {
296         if len(b) > len(a) {
297                 b = b[:len(a)]
298         }
299         // i is how much of b to try to match
300         for i := len(b); i > 0; i-- {
301                 // j is how many chars we've compared
302                 j := 0
303                 for ; j < i; j++ {
304                         if b[i-1-j] != a[len(a)-1-j] {
305                                 goto shorter
306                         }
307                 }
308                 return j
309         shorter:
310         }
311         return 0
312 }
313
314 func readUntil(r io.Reader, b []byte) error {
315         b1 := make([]byte, len(b))
316         i := 0
317         for {
318                 _, err := io.ReadFull(r, b1[i:])
319                 if err != nil {
320                         return err
321                 }
322                 i = suffixMatchLen(b1, b)
323                 if i == len(b) {
324                         break
325                 }
326                 if copy(b1, b1[len(b1)-i:]) != i {
327                         panic("wat")
328                 }
329         }
330         return nil
331 }
332
333 type readWriter struct {
334         io.Reader
335         io.Writer
336 }
337
338 func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
339         return newEncrypt(initer, h.s[:], h.skey)
340 }
341
342 func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
343         h.postWrite(hash(req1, h.s[:]))
344         h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
345         buf := &bytes.Buffer{}
346         padLen := uint16(newPadLen())
347         err = marshal(buf, vc[:], uint32(cryptoMethodRC4), padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
348         if err != nil {
349                 return
350         }
351         e := h.newEncrypt(true)
352         be := make([]byte, buf.Len())
353         e.XORKeyStream(be, buf.Bytes())
354         h.postWrite(be)
355         bC := h.newEncrypt(false)
356         var eVC [8]byte
357         bC.XORKeyStream(eVC[:], vc[:])
358         // Read until the all zero VC. At this point we've only read the 96 byte
359         // public key, Y. There is potentially 512 byte padding, between us and
360         // the 8 byte verification constant.
361         err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
362         if err != nil {
363                 if err == io.EOF {
364                         err = errors.New("failed to synchronize on VC")
365                 } else {
366                         err = fmt.Errorf("error reading until VC: %s", err)
367                 }
368                 return
369         }
370         r := &cipherReader{bC, h.conn}
371         var method uint32
372         err = unmarshal(r, &method, &padLen)
373         if err != nil {
374                 return
375         }
376         if method != cryptoMethodRC4 {
377                 err = fmt.Errorf("receiver chose unsupported method: %x", method)
378                 return
379         }
380         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
381         if err != nil {
382                 return
383         }
384         ret = readWriter{r, &cipherWriter{e, h.conn}}
385         return
386 }
387
388 var ErrNoSecretKeyMatch = errors.New("no skey matched")
389
390 func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
391         // There is up to 512 bytes of padding, then the 20 byte hash.
392         err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
393         if err != nil {
394                 if err == io.EOF {
395                         err = errors.New("failed to synchronize on S hash")
396                 }
397                 return
398         }
399         var b [20]byte
400         _, err = io.ReadFull(h.conn, b[:])
401         if err != nil {
402                 return
403         }
404         err = ErrNoSecretKeyMatch
405         for _, skey := range h.skeys {
406                 if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) {
407                         h.skey = skey
408                         err = nil
409                         break
410                 }
411         }
412         if err != nil {
413                 return
414         }
415         r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
416         var (
417                 vc     [8]byte
418                 method uint32
419                 padLen uint16
420         )
421
422         err = unmarshal(r, vc[:], &method, &padLen)
423         if err != nil {
424                 return
425         }
426         cryptoProvidesCount.Add(strconv.FormatUint(uint64(method), 16), 1)
427         if method&cryptoMethodRC4 == 0 {
428                 err = errors.New("no supported crypto methods were provided")
429                 return
430         }
431         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
432         if err != nil {
433                 return
434         }
435         var lenIA uint16
436         unmarshal(r, &lenIA)
437         if lenIA != 0 {
438                 h.ia = make([]byte, lenIA)
439                 unmarshal(r, h.ia)
440         }
441         buf := &bytes.Buffer{}
442         w := cipherWriter{h.newEncrypt(false), buf}
443         padLen = uint16(newPadLen())
444         err = marshal(&w, &vc, uint32(cryptoMethodRC4), padLen, zeroPad[:padLen])
445         if err != nil {
446                 return
447         }
448         err = h.postWrite(buf.Bytes())
449         if err != nil {
450                 return
451         }
452         ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn}}
453         return
454 }
455
456 func (h *handshake) Do() (ret io.ReadWriter, err error) {
457         h.writeCond.L = &h.writeMu
458         h.writerCond.L = &h.writerMu
459         go h.writer()
460         defer func() {
461                 h.finishWriting()
462                 if err == nil {
463                         err = h.writeErr
464                 }
465         }()
466         err = h.establishS()
467         if err != nil {
468                 err = fmt.Errorf("error while establishing secret: %s", err)
469                 return
470         }
471         pad := make([]byte, newPadLen())
472         io.ReadFull(rand.Reader, pad)
473         err = h.postWrite(pad)
474         if err != nil {
475                 return
476         }
477         if h.initer {
478                 ret, err = h.initerSteps()
479         } else {
480                 ret, err = h.receiverSteps()
481         }
482         return
483 }
484
485 func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
486         h := handshake{
487                 conn:   rw,
488                 initer: true,
489                 skey:   skey,
490                 ia:     initialPayload,
491         }
492         return h.Do()
493 }
494 func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, err error) {
495         h := handshake{
496                 conn:   rw,
497                 initer: false,
498                 skeys:  skeys,
499         }
500         return h.Do()
501 }