]> Sergey Matveev's repositories - btrtrc.git/blob - mse/mse.go
Speed up mse.handshake.establishS
[btrtrc.git] / mse / mse.go
1 // https://wiki.vuze.com/w/Message_Stream_Encryption
2
3 package mse
4
5 import (
6         "bytes"
7         "crypto/rand"
8         "crypto/rc4"
9         "crypto/sha1"
10         "encoding/binary"
11         "errors"
12         "expvar"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "math/big"
17         "strconv"
18         "sync"
19
20         "github.com/bradfitz/iter"
21 )
22
23 const (
24         maxPadLen = 512
25
26         cryptoMethodPlaintext = 1
27         cryptoMethodRC4       = 2
28 )
29
30 var (
31         // Prime P according to the spec, and G, the generator.
32         p, g big.Int
33         // The rand.Int max arg for use in newPadLen()
34         newPadLenMax big.Int
35         // For use in initer's hashes
36         req1 = []byte("req1")
37         req2 = []byte("req2")
38         req3 = []byte("req3")
39         // Verification constant "VC" which is all zeroes in the bittorrent
40         // implementation.
41         vc [8]byte
42         // Zero padding
43         zeroPad [512]byte
44         // Tracks counts of received crypto_provides
45         cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
46 )
47
48 func init() {
49         p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
50         g.SetInt64(2)
51         newPadLenMax.SetInt64(maxPadLen + 1)
52 }
53
54 func hash(parts ...[]byte) []byte {
55         h := sha1.New()
56         for _, p := range parts {
57                 n, err := h.Write(p)
58                 if err != nil {
59                         panic(err)
60                 }
61                 if n != len(p) {
62                         panic(n)
63                 }
64         }
65         return h.Sum(nil)
66 }
67
68 func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) {
69         c, err := rc4.NewCipher(hash([]byte(func() string {
70                 if initer {
71                         return "keyA"
72                 } else {
73                         return "keyB"
74                 }
75         }()), s, skey))
76         if err != nil {
77                 panic(err)
78         }
79         var burnSrc, burnDst [1024]byte
80         c.XORKeyStream(burnDst[:], burnSrc[:])
81         return
82 }
83
84 type cipherReader struct {
85         c *rc4.Cipher
86         r io.Reader
87 }
88
89 func (cr *cipherReader) Read(b []byte) (n int, err error) {
90         be := make([]byte, len(b))
91         n, err = cr.r.Read(be)
92         cr.c.XORKeyStream(b[:n], be[:n])
93         return
94 }
95
96 func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
97         return &cipherReader{c, r}
98 }
99
100 type cipherWriter struct {
101         c *rc4.Cipher
102         w io.Writer
103 }
104
105 func (cr *cipherWriter) Write(b []byte) (n int, err error) {
106         be := make([]byte, len(b))
107         cr.c.XORKeyStream(be, b)
108         n, err = cr.w.Write(be)
109         if n != len(be) {
110                 // The cipher will have advanced beyond the callers stream position.
111                 // We can't use the cipher anymore.
112                 cr.c = nil
113         }
114         return
115 }
116
117 func readY(r io.Reader) (y big.Int, err error) {
118         var b [96]byte
119         _, err = io.ReadFull(r, b[:])
120         if err != nil {
121                 return
122         }
123         y.SetBytes(b[:])
124         return
125 }
126
127 func newX() big.Int {
128         var X big.Int
129         X.SetBytes(func() []byte {
130                 var b [20]byte
131                 _, err := rand.Read(b[:])
132                 if err != nil {
133                         panic(err)
134                 }
135                 return b[:]
136         }())
137         return X
138 }
139
140 func paddedLeft(b []byte, _len int) []byte {
141         if len(b) == _len {
142                 return b
143         }
144         ret := make([]byte, _len)
145         if n := copy(ret[_len-len(b):], b); n != len(b) {
146                 panic(n)
147         }
148         return ret
149 }
150
151 // Calculate, and send Y, our public key.
152 func (h *handshake) postY(x *big.Int) error {
153         var y big.Int
154         y.Exp(&g, x, &p)
155         return h.postWrite(paddedLeft(y.Bytes(), 96))
156 }
157
158 func (h *handshake) establishS() (err error) {
159         x := newX()
160         h.postY(&x)
161         var b [96]byte
162         _, err = io.ReadFull(h.conn, b[:])
163         if err != nil {
164                 return
165         }
166         var Y, S big.Int
167         Y.SetBytes(b[:])
168         S.Exp(&Y, &x, &p)
169         sBytes := S.Bytes()
170         copy(h.s[96-len(sBytes):96], sBytes)
171         return
172 }
173
174 func newPadLen() int64 {
175         i, err := rand.Int(rand.Reader, &newPadLenMax)
176         if err != nil {
177                 panic(err)
178         }
179         ret := i.Int64()
180         if ret < 0 || ret > maxPadLen {
181                 panic(ret)
182         }
183         return ret
184 }
185
186 // Manages state for both initiating and receiving handshakes.
187 type handshake struct {
188         conn   io.ReadWriter
189         s      [96]byte
190         initer bool     // Whether we're initiating or receiving.
191         skeys  [][]byte // Skeys we'll accept if receiving.
192         skey   []byte   // Skey we're initiating with.
193         ia     []byte   // Initial payload. Only used by the initiator.
194
195         writeMu    sync.Mutex
196         writes     [][]byte
197         writeErr   error
198         writeCond  sync.Cond
199         writeClose bool
200
201         writerMu   sync.Mutex
202         writerCond sync.Cond
203         writerDone bool
204 }
205
206 func (h *handshake) finishWriting() {
207         h.writeMu.Lock()
208         h.writeClose = true
209         h.writeCond.Broadcast()
210         h.writeMu.Unlock()
211
212         h.writerMu.Lock()
213         for !h.writerDone {
214                 h.writerCond.Wait()
215         }
216         h.writerMu.Unlock()
217         return
218 }
219
220 func (h *handshake) writer() {
221         defer func() {
222                 h.writerMu.Lock()
223                 h.writerDone = true
224                 h.writerCond.Broadcast()
225                 h.writerMu.Unlock()
226         }()
227         for {
228                 h.writeMu.Lock()
229                 for {
230                         if len(h.writes) != 0 {
231                                 break
232                         }
233                         if h.writeClose {
234                                 h.writeMu.Unlock()
235                                 return
236                         }
237                         h.writeCond.Wait()
238                 }
239                 b := h.writes[0]
240                 h.writes = h.writes[1:]
241                 h.writeMu.Unlock()
242                 _, err := h.conn.Write(b)
243                 if err != nil {
244                         h.writeMu.Lock()
245                         h.writeErr = err
246                         h.writeMu.Unlock()
247                         return
248                 }
249         }
250 }
251
252 func (h *handshake) postWrite(b []byte) error {
253         h.writeMu.Lock()
254         defer h.writeMu.Unlock()
255         if h.writeErr != nil {
256                 return h.writeErr
257         }
258         h.writes = append(h.writes, b)
259         h.writeCond.Signal()
260         return nil
261 }
262
263 func xor(dst, src []byte) (ret []byte) {
264         max := len(dst)
265         if max > len(src) {
266                 max = len(src)
267         }
268         ret = make([]byte, 0, max)
269         for i := range iter.N(max) {
270                 ret = append(ret, dst[i]^src[i])
271         }
272         return
273 }
274
275 func marshal(w io.Writer, data ...interface{}) (err error) {
276         for _, data := range data {
277                 err = binary.Write(w, binary.BigEndian, data)
278                 if err != nil {
279                         break
280                 }
281         }
282         return
283 }
284
285 func unmarshal(r io.Reader, data ...interface{}) (err error) {
286         for _, data := range data {
287                 err = binary.Read(r, binary.BigEndian, data)
288                 if err != nil {
289                         break
290                 }
291         }
292         return
293 }
294
295 // Looking for b at the end of a.
296 func suffixMatchLen(a, b []byte) int {
297         if len(b) > len(a) {
298                 b = b[:len(a)]
299         }
300         // i is how much of b to try to match
301         for i := len(b); i > 0; i-- {
302                 // j is how many chars we've compared
303                 j := 0
304                 for ; j < i; j++ {
305                         if b[i-1-j] != a[len(a)-1-j] {
306                                 goto shorter
307                         }
308                 }
309                 return j
310         shorter:
311         }
312         return 0
313 }
314
315 // Reads from r until b has been seen. Keeps the minimum amount of data in
316 // memory.
317 func readUntil(r io.Reader, b []byte) error {
318         b1 := make([]byte, len(b))
319         i := 0
320         for {
321                 _, err := io.ReadFull(r, b1[i:])
322                 if err != nil {
323                         return err
324                 }
325                 i = suffixMatchLen(b1, b)
326                 if i == len(b) {
327                         break
328                 }
329                 if copy(b1, b1[len(b1)-i:]) != i {
330                         panic("wat")
331                 }
332         }
333         return nil
334 }
335
336 type readWriter struct {
337         io.Reader
338         io.Writer
339 }
340
341 func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
342         return newEncrypt(initer, h.s[:], h.skey)
343 }
344
345 func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
346         h.postWrite(hash(req1, h.s[:]))
347         h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
348         buf := &bytes.Buffer{}
349         padLen := uint16(newPadLen())
350         err = marshal(buf, vc[:], uint32(cryptoMethodRC4), padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
351         if err != nil {
352                 return
353         }
354         e := h.newEncrypt(true)
355         be := make([]byte, buf.Len())
356         e.XORKeyStream(be, buf.Bytes())
357         h.postWrite(be)
358         bC := h.newEncrypt(false)
359         var eVC [8]byte
360         bC.XORKeyStream(eVC[:], vc[:])
361         // Read until the all zero VC. At this point we've only read the 96 byte
362         // public key, Y. There is potentially 512 byte padding, between us and
363         // the 8 byte verification constant.
364         err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
365         if err != nil {
366                 if err == io.EOF {
367                         err = errors.New("failed to synchronize on VC")
368                 } else {
369                         err = fmt.Errorf("error reading until VC: %s", err)
370                 }
371                 return
372         }
373         r := &cipherReader{bC, h.conn}
374         var method uint32
375         err = unmarshal(r, &method, &padLen)
376         if err != nil {
377                 return
378         }
379         if method != cryptoMethodRC4 {
380                 err = fmt.Errorf("receiver chose unsupported method: %x", method)
381                 return
382         }
383         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
384         if err != nil {
385                 return
386         }
387         ret = readWriter{r, &cipherWriter{e, h.conn}}
388         return
389 }
390
391 var ErrNoSecretKeyMatch = errors.New("no skey matched")
392
393 func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
394         // There is up to 512 bytes of padding, then the 20 byte hash.
395         err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
396         if err != nil {
397                 if err == io.EOF {
398                         err = errors.New("failed to synchronize on S hash")
399                 }
400                 return
401         }
402         var b [20]byte
403         _, err = io.ReadFull(h.conn, b[:])
404         if err != nil {
405                 return
406         }
407         err = ErrNoSecretKeyMatch
408         for _, skey := range h.skeys {
409                 if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) {
410                         h.skey = skey
411                         err = nil
412                         break
413                 }
414         }
415         if err != nil {
416                 return
417         }
418         r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
419         var (
420                 vc     [8]byte
421                 method uint32
422                 padLen uint16
423         )
424
425         err = unmarshal(r, vc[:], &method, &padLen)
426         if err != nil {
427                 return
428         }
429         cryptoProvidesCount.Add(strconv.FormatUint(uint64(method), 16), 1)
430         if method&cryptoMethodRC4 == 0 {
431                 err = errors.New("no supported crypto methods were provided")
432                 return
433         }
434         _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
435         if err != nil {
436                 return
437         }
438         var lenIA uint16
439         unmarshal(r, &lenIA)
440         if lenIA != 0 {
441                 h.ia = make([]byte, lenIA)
442                 unmarshal(r, h.ia)
443         }
444         buf := &bytes.Buffer{}
445         w := cipherWriter{h.newEncrypt(false), buf}
446         padLen = uint16(newPadLen())
447         err = marshal(&w, &vc, uint32(cryptoMethodRC4), padLen, zeroPad[:padLen])
448         if err != nil {
449                 return
450         }
451         err = h.postWrite(buf.Bytes())
452         if err != nil {
453                 return
454         }
455         ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn}}
456         return
457 }
458
459 func (h *handshake) Do() (ret io.ReadWriter, err error) {
460         h.writeCond.L = &h.writeMu
461         h.writerCond.L = &h.writerMu
462         go h.writer()
463         defer func() {
464                 h.finishWriting()
465                 if err == nil {
466                         err = h.writeErr
467                 }
468         }()
469         err = h.establishS()
470         if err != nil {
471                 err = fmt.Errorf("error while establishing secret: %s", err)
472                 return
473         }
474         pad := make([]byte, newPadLen())
475         io.ReadFull(rand.Reader, pad)
476         err = h.postWrite(pad)
477         if err != nil {
478                 return
479         }
480         if h.initer {
481                 ret, err = h.initerSteps()
482         } else {
483                 ret, err = h.receiverSteps()
484         }
485         return
486 }
487
488 func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
489         h := handshake{
490                 conn:   rw,
491                 initer: true,
492                 skey:   skey,
493                 ia:     initialPayload,
494         }
495         return h.Do()
496 }
497 func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, err error) {
498         h := handshake{
499                 conn:   rw,
500                 initer: false,
501                 skeys:  skeys,
502         }
503         return h.Do()
504 }