]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Support initial payload, and improve tests
authorMatt Joiner <anacrolix@gmail.com>
Fri, 13 Mar 2015 03:30:48 +0000 (14:30 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Fri, 13 Mar 2015 03:30:48 +0000 (14:30 +1100)
mse/mse.go
mse/mse_test.go

index e9ea9a6772d4aad213c415f5e3b880288f460146..f675b40be6a2b262047bb2e8214fca415dcd7269 100644 (file)
@@ -9,11 +9,12 @@ import (
        "crypto/sha1"
        "encoding/binary"
        "errors"
+       "expvar"
        "fmt"
        "io"
        "io/ioutil"
-       "log"
        "math/big"
+       "strconv"
        "sync"
 
        "github.com/bradfitz/iter"
@@ -35,6 +36,8 @@ var (
        req1 = []byte("req1")
        req2 = []byte("req2")
        req3 = []byte("req3")
+
+       cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
 )
 
 func init() {
@@ -176,6 +179,7 @@ type handshake struct {
        initer bool
        skeys  [][]byte
        skey   []byte
+       ia     []byte // Initial payload. Only used by the initiator.
 
        writeMu    sync.Mutex
        writes     [][]byte
@@ -288,7 +292,6 @@ type cryptoNegotiation struct {
 
 func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
        err = binary.Read(r, binary.BigEndian, me.VC[:])
-       // _, err = io.ReadFull(r, me.VC[:])
        if err != nil {
                return
        }
@@ -300,7 +303,6 @@ func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
        if err != nil {
                return
        }
-       log.Print(me.PadLen)
        _, err = io.CopyN(ioutil.Discard, r, int64(me.PadLen))
        return
 }
@@ -344,7 +346,6 @@ func suffixMatchLen(a, b []byte) int {
 }
 
 func readUntil(r io.Reader, b []byte) error {
-       log.Println("read until", b)
        b1 := make([]byte, len(b))
        i := 0
        for {
@@ -379,7 +380,7 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
        if err != nil {
                return
        }
-       err = marshal(buf, uint16(0))
+       err = marshal(buf, uint16(len(h.ia)), h.ia)
        if err != nil {
                return
        }
@@ -390,7 +391,6 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
        bC := newEncrypt(false, h.s.Bytes(), h.skey)
        var eVC [8]byte
        bC.XORKeyStream(eVC[:], make([]byte, 8))
-       log.Print(eVC)
        // Read until the all zero VC.
        err = readUntil(h.conn, eVC[:])
        if err != nil {
@@ -400,7 +400,6 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
        var cn cryptoNegotiation
        r := &cipherReader{bC, h.conn}
        err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
-       log.Printf("initer got %v", cn)
        if err != nil {
                err = fmt.Errorf("error reading crypto negotiation: %s", err)
                return
@@ -436,12 +435,17 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
        if err != nil {
                return
        }
-       log.Printf("receiver got %v", cn)
+       cryptoProvidesCount.Add(strconv.FormatUint(uint64(cn.Method), 16), 1)
        if cn.Method&cryptoMethodRC4 == 0 {
                err = errors.New("no supported crypto methods were provided")
                return
        }
-       unmarshal(r, new(uint16))
+       var lenIA uint16
+       unmarshal(r, &lenIA)
+       if lenIA != 0 {
+               h.ia = make([]byte, lenIA)
+               unmarshal(r, h.ia)
+       }
        buf := &bytes.Buffer{}
        w := cipherWriter{newEncrypt(false, h.s.Bytes(), h.skey), buf}
        err = (&cryptoNegotiation{
@@ -455,7 +459,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
        if err != nil {
                return
        }
-       ret = readWriter{r, &cipherWriter{w.c, h.conn}}
+       ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn}}
        return
 }
 
@@ -483,15 +487,15 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
        if err != nil {
                return
        }
-       log.Print("ermahgerd, finished MSE handshake")
        return
 }
 
-func InitiateHandshake(rw io.ReadWriteCloser, skey []byte) (ret io.ReadWriter, err error) {
+func InitiateHandshake(rw io.ReadWriteCloser, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
        h := handshake{
                conn:   rw,
                initer: true,
                skey:   skey,
+               ia:     initialPayload,
        }
        h.writeCond.L = &h.writeMu
        h.writerCond.L = &h.writerMu
index 644a4cdfffab7efcc1859d7b6d5e3264eb550cf6..6b5f82c39fba997bdfb68c536c2c5194b5ec2649 100644 (file)
@@ -3,10 +3,11 @@ package mse
 import (
        "bytes"
        "io"
-       "log"
        "net"
        "sync"
 
+       "github.com/bradfitz/iter"
+
        "testing"
 )
 
@@ -43,21 +44,25 @@ func TestSuffixMatchLen(t *testing.T) {
        test("sup", "person", 1)
 }
 
-func TestHandshake(t *testing.T) {
+func handshakeTest(t testing.TB, ia []byte, aData, bData string) {
        a, b := net.Pipe()
        wg := sync.WaitGroup{}
        wg.Add(2)
        go func() {
                defer wg.Done()
-               a, err := InitiateHandshake(a, []byte("yep"))
+               a, err := InitiateHandshake(a, []byte("yep"), ia)
                if err != nil {
                        t.Fatal(err)
                        return
                }
-               a.Write([]byte("hello world"))
+               go a.Write([]byte(aData))
+
                var msg [20]byte
                n, _ := a.Read(msg[:])
-               log.Print(string(msg[:n]))
+               if n != len(bData) {
+                       t.FailNow()
+               }
+               // t.Log(string(msg[:n]))
        }()
        go func() {
                defer wg.Done()
@@ -66,10 +71,34 @@ func TestHandshake(t *testing.T) {
                        t.Fatal(err)
                        return
                }
-               var msg [20]byte
-               n, _ := b.Read(msg[:])
-               log.Print(string(msg[:n]))
-               b.Write([]byte("yo dawg"))
+               go b.Write([]byte(bData))
+               // Need to be exact here, as there are several reads, and net.Pipe is
+               // most synchronous.
+               msg := make([]byte, len(ia)+len(aData))
+               n, _ := io.ReadFull(b, msg[:])
+               if n != len(msg) {
+                       t.FailNow()
+               }
+               // t.Log(string(msg[:n]))
        }()
        wg.Wait()
+       a.Close()
+       b.Close()
+}
+
+func allHandshakeTests(t testing.TB) {
+       handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg")
+       handshakeTest(t, nil, "hello world", "yo dawg")
+       handshakeTest(t, []byte{}, "hello world", "yo dawg")
+}
+
+func TestHandshake(t *testing.T) {
+       allHandshakeTests(t)
+       t.Logf("crypto provides encountered: %s", cryptoProvidesCount)
+}
+
+func BenchmarkHandshake(b *testing.B) {
+       for range iter.N(b.N) {
+               allHandshakeTests(b)
+       }
 }