]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse.go
mse: Clean-up
[btrtrc.git] / mse / mse.go
1 // https://wiki.vuze.com/w/Message_Stream_Encryption
2
3 package mse
4
5 import (
6         "bytes"
7         "crypto/rand"
8         "crypto/rc4"
9         "crypto/sha1"
10         "encoding/binary"
11         "errors"
12         "expvar"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "math/big"
17         "strconv"
18         "sync"
19
20         "bitbucket.org/anacrolix/go.torrent/util"
21
22         "github.com/bradfitz/iter"
23 )
24
25 const (
26         maxPadLen = 512
27
28         cryptoMethodPlaintext = 1
29         cryptoMethodRC4       = 2
30 )
31
32 var (
33         // Prime P according to the spec, and G, the generator.
34         p, g big.Int
35         // The rand.Int max arg for use in newPadLen()
36         newPadLenMax big.Int
37         // For use in initer's hashes
38         req1 = []byte("req1")
39         req2 = []byte("req2")
40         req3 = []byte("req3")
41         // Verification constant "VC" which is all zeroes in the bittorrent
42         // implementation.
43         vc [8]byte
44         // Zero padding
45         zeroPad [512]byte
46         // Tracks counts of received crypto_provides
47         cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
48 )
49
50 func init() {
51         p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
52         g.SetInt64(2)
53         newPadLenMax.SetInt64(maxPadLen + 1)
54 }
55
56 func hash(parts ...[]byte) []byte {
57         h := sha1.New()
58         for _, p := range parts {
59                 n, err := h.Write(p)
60                 if err != nil {
61                         panic(err)
62                 }
63                 if n != len(p) {
64                         panic(n)
65                 }
66         }
67         return h.Sum(nil)
68 }
69
70 func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
71         c, err := rc4.NewCipher(hash([]byte(func() string {
72                 if initer {
73                         return "keyA"
74                 } else {
75                         return "keyB"
76                 }
77         }()), s, skey))
78         if err != nil {
79                 panic(err)
80         }
81         var burnSrc, burnDst [1024]byte
82         c.XORKeyStream(burnDst[:], burnSrc[:])
83         return
84 }
85
86 type cipherReader struct {
87         c *rc4.Cipher
88         r io.Reader
89 }
90
91 func (me *cipherReader) Read(b []byte) (n int, err error) {
92         be := make([]byte, len(b))
93         n, err = me.r.Read(be)
94         me.c.XORKeyStream(b[:n], be[:n])
95         return
96 }
97
98 func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
99         return &cipherReader{c, r}
100 }
101
102 type cipherWriter struct {
103         c *rc4.Cipher
104         w io.Writer
105 }
106
107 func (me *cipherWriter) Write(b []byte) (n int, err error) {
108         be := make([]byte, len(b))
109         me.c.XORKeyStream(be, b)
110         n, err = me.w.Write(be)
111         if n != len(be) {
112                 // The cipher will have advanced beyond the callers stream position.
113                 // We can't use the cipher anymore.
114                 me.c = nil
115         }
116         return
117 }
118
119 func readY(r io.Reader) (y big.Int, err error) {
120         var b [96]byte
121         _, err = io.ReadFull(r, b[:])
122         if err != nil {
123                 return
124         }
125         y.SetBytes(b[:])
126         return
127 }
128
129 func newX() big.Int {
130         var X big.Int
131         X.SetBytes(func() []byte {
132                 var b [20]byte
133                 _, err := rand.Read(b[:])
134                 if err != nil {
135                         panic(err)
136                 }
137                 return b[:]
138         }())
139         return X
140 }
141
142 func paddedLeft(b []byte, _len int) []byte {
143         if len(b) == _len {
144                 return b
145         }
146         ret := make([]byte, _len)
147         if n := copy(ret[_len-len(b):], b); n != len(b) {
148                 panic(n)
149         }
150         return ret
151 }
152
153 // Calculate, and send Y, our public key.
154 func (h *handshake) postY(x *big.Int) error {
155         var y big.Int
156         y.Exp(&g, x, &p)
157         return h.postWrite(paddedLeft(y.Bytes(), 96))
158 }
159
160 func (h *handshake) establishS() (err error) {
161         x := newX()
162         h.postY(&x)
163         var b [96]byte
164         _, err = io.ReadFull(h.conn, b[:])
165         if err != nil {
166                 return
167         }
168         var Y, S big.Int
169         Y.SetBytes(b[:])
170         S.Exp(&Y, &x, &p)
171         util.CopyExact(&h.s, paddedLeft(S.Bytes(), 96))
172         return
173 }
174
175 func newPadLen() int64 {
176         i, err := rand.Int(rand.Reader, &newPadLenMax)
177         if err != nil {
178                 panic(err)
179         }
180         ret := i.Int64()
181         if ret < 0 || ret > maxPadLen {
182                 panic(ret)
183         }
184         return ret
185 }
186
187 type handshake struct {
188         conn   io.ReadWriter
189         s      [96]byte
190         initer bool
191         skeys  [][]byte
192         skey   []byte
193         ia     []byte // Initial payload. Only used by the initiator.
194
195         writeMu    sync.Mutex
196         writes     [][]byte
197         writeErr   error
198         writeCond  sync.Cond
199         writeClose bool
200
201         writerMu   sync.Mutex
202         writerCond sync.Cond
203         writerDone bool
204 }
205
206 func (h *handshake) finishWriting() {
207         h.writeMu.Lock()
208         h.writeClose = true
209         h.writeCond.Broadcast()
210         h.writeMu.Unlock()
211
212         h.writerMu.Lock()
213         for !h.writerDone {
214                 h.writerCond.Wait()
215         }
216         h.writerMu.Unlock()
217         return
218 }
219
220 func (h *handshake) writer() {
221         defer func() {
222                 h.writerMu.Lock()
223                 h.writerDone = true
224                 h.writerCond.Broadcast()
225                 h.writerMu.Unlock()
226         }()
227         for {
228                 h.writeMu.Lock()
229                 for {
230                         if len(h.writes) != 0 {
231                                 break
232                         }
233                         if h.writeClose {
234                                 h.writeMu.Unlock()
235                                 return
236                         }
237                         h.writeCond.Wait()
238                 }
239                 b := h.writes[0]
240                 h.writes = h.writes[1:]
241                 h.writeMu.Unlock()
242                 _, err := h.conn.Write(b)
243                 if err != nil {
244                         h.writeMu.Lock()
245                         h.writeErr = err
246                         h.writeMu.Unlock()
247                         return
248                 }
249         }
250 }
251
252 func (h *handshake) postWrite(b []byte) error {
253         h.writeMu.Lock()
254         defer h.writeMu.Unlock()
255         if h.writeErr != nil {
256                 return h.writeErr
257         }
258         h.writes = append(h.writes, b)
259         h.writeCond.Signal()
260         return nil
261 }
262
263 func xor(dst, src []byte) (ret []byte) {
264         max := len(dst)
265         if max > len(src) {
266                 max = len(src)
267         }
268         ret = make([]byte, 0, max)
269         for i := range iter.N(max) {
270                 ret = append(ret, dst[i]^src[i])
271         }
272         return
273 }
274
275 func marshal(w io.Writer, data ...interface{}) (err error) {
276         for _, data := range data {
277                 err = binary.Write(w, binary.BigEndian, data)
278                 if err != nil {
279                         break
280                 }
281         }
282         return
283 }
284
285 func unmarshal(r io.Reader, data ...interface{}) (err error) {
286         for _, data := range data {
287                 err = binary.Read(r, binary.BigEndian, data)
288                 if err != nil {
289                         break
290                 }
291         }
292         return
293 }
294
295 // Looking for b at the end of a.
296 func suffixMatchLen(a, b []byte) int {
297         if len(b) > len(a) {
298                 b = b[:len(a)]
299         }
300         // i is how much of b to try to match
301         for i := len(b); i > 0; i-- {
302                 // j is how many chars we've compared
303                 j := 0
304                 for ; j < i; j++ {
305                         if b[i-1-j] != a[len(a)-1-j] {
306                                 goto shorter
307                         }
308                 }
309                 return j
310         shorter:
311         }
312         return 0
313 }
314
315 func readUntil(r io.Reader, b []byte) error {
316         b1 := make([]byte, len(b))
317         i := 0
318         for {
319                 _, err := io.ReadFull(r, b1[i:])
320                 if err != nil {
321                         return err
322                 }
323                 i = suffixMatchLen(b1, b)
324                 if i == len(b) {
325                         break
326                 }
327                 if copy(b1, b1[len(b1)-i:]) != i {
328                         panic("wat")
329                 }
330         }
331         return nil
332 }
333
334 type readWriter struct {
335         io.Reader
336         io.Writer
337 }
338
339 func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
340         return newEncrypt(initer, h.s[:], h.skey)
341 }
342
343 func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
344         h.postWrite(hash(req1, h.s[:]))
345         h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
346         buf := &bytes.Buffer{}
347         padLen := uint16(newPadLen())
348         err = marshal(buf, vc[:], uint32(cryptoMethodRC4), padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
349         if err != nil {
350                 return
351         }
352         e := h.newEncrypt(true)
353         be := make([]byte, buf.Len())
354         e.XORKeyStream(be, buf.Bytes())
355         h.postWrite(be)
356         bC := h.newEncrypt(false)
357         var eVC [8]byte
358         bC.XORKeyStream(eVC[:], vc[:])
359         // Read until the all zero VC. At this point we've only read the 96 byte
360         // public key, Y. There is potentially 512 byte padding, between us and
361         // the 8 byte verification constant.
362         err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
363         if err != nil {
364                 if err == io.EOF {
365                         err = errors.New("failed to synchronize on VC")
366                 } else {
367                         err = fmt.Errorf("error reading until VC: %s", err)
368                 }
369                 return
370         }
371         r := &cipherReader{bC, h.conn}
372         var method uint32
373         err = unmarshal(r, &method, &padLen)
374         if err != nil {
375                 return
376         }
377         if method != cryptoMethodRC4 {
378                 err = fmt.Errorf("receiver chose unsupported method: %x", method)
379                 return
380         }
381         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
382         if err != nil {
383                 return
384         }
385         ret = readWriter{r, &cipherWriter{e, h.conn}}
386         return
387 }
388
389 var ErrNoSecretKeyMatch = errors.New("no skey matched")
390
391 func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
392         // There is up to 512 bytes of padding, then the 20 byte hash.
393         err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
394         if err != nil {
395                 if err == io.EOF {
396                         err = errors.New("failed to synchronize on S hash")
397                 }
398                 return
399         }
400         var b [20]byte
401         _, err = io.ReadFull(h.conn, b[:])
402         if err != nil {
403                 return
404         }
405         err = ErrNoSecretKeyMatch
406         for _, skey := range h.skeys {
407                 if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) {
408                         h.skey = skey
409                         err = nil
410                         break
411                 }
412         }
413         if err != nil {
414                 return
415         }
416         r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
417         var (
418                 vc     [8]byte
419                 method uint32
420                 padLen uint16
421         )
422
423         err = unmarshal(r, vc[:], &method, &padLen)
424         if err != nil {
425                 return
426         }
427         cryptoProvidesCount.Add(strconv.FormatUint(uint64(method), 16), 1)
428         if method&cryptoMethodRC4 == 0 {
429                 err = errors.New("no supported crypto methods were provided")
430                 return
431         }
432         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
433         if err != nil {
434                 return
435         }
436         var lenIA uint16
437         unmarshal(r, &lenIA)
438         if lenIA != 0 {
439                 h.ia = make([]byte, lenIA)
440                 unmarshal(r, h.ia)
441         }
442         buf := &bytes.Buffer{}
443         w := cipherWriter{h.newEncrypt(false), buf}
444         padLen = uint16(newPadLen())
445         err = marshal(&w, &vc, uint32(cryptoMethodRC4), padLen, zeroPad[:padLen])
446         if err != nil {
447                 return
448         }
449         err = h.postWrite(buf.Bytes())
450         if err != nil {
451                 return
452         }
453         ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn}}
454         return
455 }
456
457 func (h *handshake) Do() (ret io.ReadWriter, err error) {
458         h.writeCond.L = &h.writeMu
459         h.writerCond.L = &h.writerMu
460         go h.writer()
461         defer func() {
462                 h.finishWriting()
463                 if err == nil {
464                         err = h.writeErr
465                 }
466         }()
467         err = h.establishS()
468         if err != nil {
469                 err = fmt.Errorf("error while establishing secret: %s", err)
470                 return
471         }
472         pad := make([]byte, newPadLen())
473         io.ReadFull(rand.Reader, pad)
474         err = h.postWrite(pad)
475         if err != nil {
476                 return
477         }
478         if h.initer {
479                 ret, err = h.initerSteps()
480         } else {
481                 ret, err = h.receiverSteps()
482         }
483         return
484 }
485
486 func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
487         h := handshake{
488                 conn:   rw,
489                 initer: true,
490                 skey:   skey,
491                 ia:     initialPayload,
492         }
493         return h.Do()
494 }
495 func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, err error) {
496         h := handshake{
497                 conn:   rw,
498                 initer: false,
499                 skeys:  skeys,
500         }
501         return h.Do()
502 }