From: Matt Joiner Date: Thu, 12 Mar 2015 09:03:29 +0000 (+1100) Subject: mse: Got basic message stream encryption working with some tests X-Git-Tag: v1.0.0~1286 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=954e03952af87533df3f5050fb24ecefadca8de3;p=btrtrc.git mse: Got basic message stream encryption working with some tests Not complete yet. --- diff --git a/mse/mse.go b/mse/mse.go new file mode 100644 index 00000000..1d53ad21 --- /dev/null +++ b/mse/mse.go @@ -0,0 +1,444 @@ +// https://wiki.vuze.com/w/Message_Stream_Encryption + +package mse + +import ( + "bytes" + "crypto/rand" + "crypto/rc4" + "crypto/sha1" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "math/big" + "sync" + + "github.com/bradfitz/iter" +) + +const ( + maxPadLen = 512 + + cryptoMethodPlaintext = 1 + cryptoMethodRC4 = 2 +) + +var ( + // Prime P according to the spec, and G, the generator. + p, g big.Int + // The rand.Int max arg for use in newPadLen() + newPadLenMax big.Int + // For use in initer's hashes + req1 = []byte("req1") + req2 = []byte("req2") + req3 = []byte("req3") +) + +func init() { + p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0) + g.SetInt64(2) + newPadLenMax.SetInt64(maxPadLen + 1) +} + +func hash(parts ...[]byte) []byte { + h := sha1.New() + for _, p := range parts { + n, err := h.Write(p) + if err != nil { + panic(err) + } + if n != len(p) { + panic(n) + } + } + return h.Sum(nil) +} + +func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) { + c, err := rc4.NewCipher(hash([]byte(func() string { + if initer { + return "keyA" + } else { + return "keyB" + } + }()), s, skey)) + if err != nil { + panic(err) + } + var burnSrc, burnDst [1024]byte + c.XORKeyStream(burnDst[:], burnSrc[:]) + return +} + +type cipherReader struct { + c *rc4.Cipher + r io.Reader +} + +func (me *cipherReader) Read(b []byte) (n int, err error) { + be := make([]byte, len(b)) + n, err = me.r.Read(be) + me.c.XORKeyStream(b[:n], be[:n]) + return +} + +func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader { + return &cipherReader{c, r} +} + +type cipherWriter struct { + c *rc4.Cipher + w io.Writer +} + +func (me *cipherWriter) Write(b []byte) (n int, err error) { + be := make([]byte, len(b)) + me.c.XORKeyStream(be, b) + n, err = me.w.Write(be) + if n != len(be) { + // The cipher will have advanced beyond the callers stream position. + // We can't use the cipher anymore. + me.c = nil + } + return +} + +func newCipherWriter(c *rc4.Cipher, w io.Writer) io.Writer { + return &cipherWriter{c, w} +} + +func readY(r io.Reader) (y big.Int, err error) { + var b [96]byte + _, err = io.ReadFull(r, b[:]) + if err != nil { + return + } + y.SetBytes(b[:]) + return +} + +func newX() big.Int { + var X big.Int + X.SetBytes(func() []byte { + var b [20]byte + _, err := rand.Read(b[:]) + if err != nil { + panic(err) + } + return b[:] + }()) + return X +} + +func (h *handshake) postY(x *big.Int) error { + var y big.Int + y.Exp(&g, x, &p) + b := y.Bytes() + if len(b) != 96 { + panic(len(b)) + } + return h.postWrite(b) +} + +func (h *handshake) establishS() (err error) { + x := newX() + h.postY(&x) + var b [96]byte + _, err = io.ReadFull(h.conn, b[:]) + if err != nil { + return + } + var Y big.Int + Y.SetBytes(b[:]) + h.s.Exp(&Y, &x, &p) + return +} + +func newPadLen() int64 { + i, err := rand.Int(rand.Reader, &newPadLenMax) + if err != nil { + panic(err) + } + ret := i.Int64() + if ret < 0 || ret > maxPadLen { + panic(ret) + } + return ret +} + +type handshake struct { + conn io.ReadWriteCloser + s big.Int + initer bool + skey []byte + + writeMu sync.Mutex + writes [][]byte + writeErr error + writeCond sync.Cond + writeClose bool + + writerMu sync.Mutex + writerCond sync.Cond + writerDone bool +} + +func (h *handshake) finishWriting() (err error) { + h.writeMu.Lock() + h.writeClose = true + h.writeCond.Broadcast() + err = h.writeErr + h.writeMu.Unlock() + + h.writerMu.Lock() + for !h.writerDone { + h.writerCond.Wait() + } + h.writerMu.Unlock() + return + +} + +func (h *handshake) writer() { + defer func() { + h.writerMu.Lock() + h.writerDone = true + h.writerCond.Broadcast() + h.writerMu.Unlock() + }() + for { + h.writeMu.Lock() + for { + if len(h.writes) != 0 { + break + } + if h.writeClose { + h.writeMu.Unlock() + return + } + h.writeCond.Wait() + } + b := h.writes[0] + h.writes = h.writes[1:] + h.writeMu.Unlock() + _, err := h.conn.Write(b) + if err != nil { + h.writeMu.Lock() + h.writeErr = err + h.writeMu.Unlock() + return + } + } +} + +func (h *handshake) postWrite(b []byte) error { + h.writeMu.Lock() + defer h.writeMu.Unlock() + if h.writeErr != nil { + return h.writeErr + } + h.writes = append(h.writes, b) + h.writeCond.Signal() + return nil +} + +func xor(dst, src []byte) (ret []byte) { + max := len(dst) + if max > len(src) { + max = len(src) + } + ret = make([]byte, 0, max) + for i := range iter.N(max) { + ret = append(ret, dst[i]^src[i]) + } + return +} + +type cryptoNegotiation struct { + VC [8]byte + Method uint32 + PadLen uint16 + IA []byte +} + +func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) { + _, err = io.ReadFull(r, me.VC[:]) + if err != nil { + return + } + err = binary.Read(r, binary.BigEndian, &me.Method) + if err != nil { + return + } + err = binary.Read(r, binary.BigEndian, &me.PadLen) + if err != nil { + return + } + log.Print(me.PadLen) + _, err = io.CopyN(ioutil.Discard, r, int64(me.PadLen)) + return +} + +func (me *cryptoNegotiation) MarshalWriter(w io.Writer) (err error) { + _, err = w.Write(me.VC[:]) + if err != nil { + return + } + err = binary.Write(w, binary.BigEndian, me.Method) + if err != nil { + return + } + err = binary.Write(w, binary.BigEndian, me.PadLen) + if err != nil { + return + } + _, err = w.Write(make([]byte, me.PadLen)) + return +} + +// Looking for b at the end of a. +func suffixMatchLen(a, b []byte) int { + if len(b) > len(a) { + b = b[:len(a)] + } + // i is how much of b to try to match + for i := len(b); i > 0; i-- { + // j is how many chars we've compared + j := 0 + for ; j < i; j++ { + if b[i-1-j] != a[len(a)-1-j] { + goto shorter + } + } + return j + shorter: + } + return 0 +} + +func readUntil(r io.Reader, b []byte) error { + log.Println("read until", b) + b1 := make([]byte, len(b)) + i := 0 + for { + _, err := io.ReadFull(r, b1[i:]) + if err != nil { + return err + } + i = suffixMatchLen(b1, b) + if i == len(b) { + break + } + if copy(b1, b1[len(b1)-i:]) != i { + panic("wat") + } + } + return nil +} + +func (h *handshake) Do() (ret io.ReadWriteCloser, err error) { + err = h.establishS() + if err != nil { + return + } + pad := make([]byte, newPadLen()) + io.ReadFull(rand.Reader, pad) + err = h.postWrite(pad) + if err != nil { + return + } + if h.initer { + h.postWrite(hash(req1, h.s.Bytes())) + h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes()))) + buf := &bytes.Buffer{} + err = (&cryptoNegotiation{ + Method: cryptoMethodRC4, + PadLen: uint16(newPadLen()), + }).MarshalWriter(buf) + if err != nil { + return + } + e := newEncrypt(true, h.s.Bytes(), h.skey) + be := make([]byte, buf.Len()) + e.XORKeyStream(be, buf.Bytes()) + h.postWrite(be) + bC := newEncrypt(false, h.s.Bytes(), h.skey) + var eVC [8]byte + bC.XORKeyStream(eVC[:], make([]byte, 8)) + log.Print(eVC) + // Read until the all zero VC. + err = readUntil(h.conn, eVC[:]) + if err != nil { + err = fmt.Errorf("error reading until VC: %s", err) + return + } + var cn cryptoNegotiation + r := &cipherReader{bC, h.conn} + err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r)) + log.Printf("initer got %v", cn) + if err != nil { + err = fmt.Errorf("error reading crypto negotiation: %s", err) + return + } + } else { + err = readUntil(h.conn, hash(req1, h.s.Bytes())) + if err != nil { + return + } + var b [20]byte + _, err = io.ReadFull(h.conn, b[:]) + if err != nil { + return + } + if !bytes.Equal(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())), b[:]) { + err = errors.New("skey doesn't match") + return + } + var cn cryptoNegotiation + r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn) + err = cn.UnmarshalReader(r) + if err != nil { + return + } + log.Printf("receiver got %v", cn) + if cn.Method&cryptoMethodRC4 == 0 { + err = errors.New("no supported crypto methods were provided") + return + } + buf := &bytes.Buffer{} + w := newCipherWriter(newEncrypt(false, h.s.Bytes(), h.skey), buf) + err = (&cryptoNegotiation{ + Method: cryptoMethodRC4, + PadLen: uint16(newPadLen()), + }).MarshalWriter(w) + if err != nil { + return + } + log.Println("encrypted VC", buf.Bytes()[:8]) + err = h.postWrite(buf.Bytes()) + if err != nil { + return + } + } + err = h.finishWriting() + if err != nil { + return + } + ret = h.conn + return +} + +func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWriteCloser, err error) { + h := handshake{ + conn: rw, + initer: initer, + skey: skey, + } + h.writeCond.L = &h.writeMu + h.writerCond.L = &h.writerMu + go h.writer() + return h.Do() +} diff --git a/mse/mse_test.go b/mse/mse_test.go new file mode 100644 index 00000000..6c20fdf7 --- /dev/null +++ b/mse/mse_test.go @@ -0,0 +1,68 @@ +package mse + +import ( + "bytes" + "io" + "net" + "sync" + + "testing" +) + +func TestReadUntil(t *testing.T) { + test := func(data, until string, leftover int, expectedErr error) { + r := bytes.NewReader([]byte(data)) + err := readUntil(r, []byte(until)) + if err != expectedErr { + t.Fatal(err) + } + if r.Len() != leftover { + t.Fatal(r.Len()) + } + } + test("feakjfeafeafegbaabc00", "abc", 2, nil) + test("feakjfeafeafegbaadc00", "abc", 0, io.EOF) +} + +func TestSuffixMatchLen(t *testing.T) { + test := func(a, b string, expected int) { + actual := suffixMatchLen([]byte(a), []byte(b)) + if actual != expected { + t.Fatalf("expected %d, got %d for %q and %q", expected, actual, a, b) + } + } + test("hello", "world", 0) + test("hello", "lo", 2) + test("hello", "llo", 3) + test("hello", "hell", 0) + test("hello", "helloooo!", 5) + test("hello", "lol!", 2) + test("hello", "mondo", 0) + test("mongo", "webscale", 0) + test("sup", "person", 1) +} + +func TestHandshake(t *testing.T) { + a, b := net.Pipe() + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + a, err := Handshake(a, true, []byte("yep")) + if err != nil { + t.Fatal(err) + return + } + a.Close() + }() + go func() { + defer wg.Done() + b, err := Handshake(b, false, []byte("yep")) + if err != nil { + t.Fatal(err) + return + } + b.Close() + }() + wg.Wait() +}