"crypto/sha1"
"encoding/binary"
"errors"
+ "expvar"
"fmt"
"io"
"io/ioutil"
- "log"
"math/big"
+ "strconv"
"sync"
"github.com/bradfitz/iter"
req1 = []byte("req1")
req2 = []byte("req2")
req3 = []byte("req3")
+
+ cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
)
func init() {
initer bool
skeys [][]byte
skey []byte
+ ia []byte // Initial payload. Only used by the initiator.
writeMu sync.Mutex
writes [][]byte
func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
err = binary.Read(r, binary.BigEndian, me.VC[:])
- // _, err = io.ReadFull(r, me.VC[:])
if err != nil {
return
}
if err != nil {
return
}
- log.Print(me.PadLen)
_, err = io.CopyN(ioutil.Discard, r, int64(me.PadLen))
return
}
}
func readUntil(r io.Reader, b []byte) error {
- log.Println("read until", b)
b1 := make([]byte, len(b))
i := 0
for {
if err != nil {
return
}
- err = marshal(buf, uint16(0))
+ err = marshal(buf, uint16(len(h.ia)), h.ia)
if err != nil {
return
}
bC := newEncrypt(false, h.s.Bytes(), h.skey)
var eVC [8]byte
bC.XORKeyStream(eVC[:], make([]byte, 8))
- log.Print(eVC)
// Read until the all zero VC.
err = readUntil(h.conn, eVC[:])
if err != nil {
var cn cryptoNegotiation
r := &cipherReader{bC, h.conn}
err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
- log.Printf("initer got %v", cn)
if err != nil {
err = fmt.Errorf("error reading crypto negotiation: %s", err)
return
if err != nil {
return
}
- log.Printf("receiver got %v", cn)
+ cryptoProvidesCount.Add(strconv.FormatUint(uint64(cn.Method), 16), 1)
if cn.Method&cryptoMethodRC4 == 0 {
err = errors.New("no supported crypto methods were provided")
return
}
- unmarshal(r, new(uint16))
+ var lenIA uint16
+ unmarshal(r, &lenIA)
+ if lenIA != 0 {
+ h.ia = make([]byte, lenIA)
+ unmarshal(r, h.ia)
+ }
buf := &bytes.Buffer{}
w := cipherWriter{newEncrypt(false, h.s.Bytes(), h.skey), buf}
err = (&cryptoNegotiation{
if err != nil {
return
}
- ret = readWriter{r, &cipherWriter{w.c, h.conn}}
+ ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn}}
return
}
if err != nil {
return
}
- log.Print("ermahgerd, finished MSE handshake")
return
}
-func InitiateHandshake(rw io.ReadWriteCloser, skey []byte) (ret io.ReadWriter, err error) {
+func InitiateHandshake(rw io.ReadWriteCloser, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
h := handshake{
conn: rw,
initer: true,
skey: skey,
+ ia: initialPayload,
}
h.writeCond.L = &h.writeMu
h.writerCond.L = &h.writerMu
import (
"bytes"
"io"
- "log"
"net"
"sync"
+ "github.com/bradfitz/iter"
+
"testing"
)
test("sup", "person", 1)
}
-func TestHandshake(t *testing.T) {
+func handshakeTest(t testing.TB, ia []byte, aData, bData string) {
a, b := net.Pipe()
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
- a, err := InitiateHandshake(a, []byte("yep"))
+ a, err := InitiateHandshake(a, []byte("yep"), ia)
if err != nil {
t.Fatal(err)
return
}
- a.Write([]byte("hello world"))
+ go a.Write([]byte(aData))
+
var msg [20]byte
n, _ := a.Read(msg[:])
- log.Print(string(msg[:n]))
+ if n != len(bData) {
+ t.FailNow()
+ }
+ // t.Log(string(msg[:n]))
}()
go func() {
defer wg.Done()
t.Fatal(err)
return
}
- var msg [20]byte
- n, _ := b.Read(msg[:])
- log.Print(string(msg[:n]))
- b.Write([]byte("yo dawg"))
+ go b.Write([]byte(bData))
+ // Need to be exact here, as there are several reads, and net.Pipe is
+ // most synchronous.
+ msg := make([]byte, len(ia)+len(aData))
+ n, _ := io.ReadFull(b, msg[:])
+ if n != len(msg) {
+ t.FailNow()
+ }
+ // t.Log(string(msg[:n]))
}()
wg.Wait()
+ a.Close()
+ b.Close()
+}
+
+func allHandshakeTests(t testing.TB) {
+ handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg")
+ handshakeTest(t, nil, "hello world", "yo dawg")
+ handshakeTest(t, []byte{}, "hello world", "yo dawg")
+}
+
+func TestHandshake(t *testing.T) {
+ allHandshakeTests(t)
+ t.Logf("crypto provides encountered: %s", cryptoProvidesCount)
+}
+
+func BenchmarkHandshake(b *testing.B) {
+ for range iter.N(b.N) {
+ allHandshakeTests(b)
+ }
}