"fmt"
"io"
"io/ioutil"
+ "math"
"math/big"
"strconv"
"sync"
cryptoMethodPlaintext = 1
cryptoMethodRC4 = 2
+ AllSupportedCrypto = cryptoMethodPlaintext | cryptoMethodRC4
)
var (
skeys SecretKeyIter // Skeys we'll accept if receiving.
skey []byte // Skey we're initiating with.
ia []byte // Initial payload. Only used by the initiator.
+ // Return the bit for the crypto method the receiver wants to use.
+ chooseMethod func(supported uint32) uint32
+ // Sent to the receiver.
+ cryptoProvides uint32
writeMu sync.Mutex
writes [][]byte
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
buf := &bytes.Buffer{}
padLen := uint16(newPadLen())
- err = marshal(buf, vc[:], uint32(cryptoMethodRC4), padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
if len(h.ia) > math.MaxUint16 {
err = errors.New("initial payload too large")
return
}
+ err = marshal(buf, vc[:], h.cryptoProvides, padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
if err != nil {
return
}
if err != nil {
return
}
- if method != cryptoMethodRC4 {
- err = fmt.Errorf("receiver chose unsupported method: %x", method)
- return
- }
_, err = io.CopyN(ioutil.Discard, r, int64(padLen))
if err != nil {
return
}
- ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
+ switch method & h.cryptoProvides {
+ case cryptoMethodRC4:
+ ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
+ case cryptoMethodPlaintext:
+ ret = h.conn
+ default:
+ err = fmt.Errorf("receiver chose unsupported method: %x", method)
+ }
return
}
}
r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
var (
- vc [8]byte
- method uint32
- padLen uint16
+ vc [8]byte
+ provides uint32
+ padLen uint16
)
- err = unmarshal(r, vc[:], &method, &padLen)
+ err = unmarshal(r, vc[:], &provides, &padLen)
if err != nil {
return
}
- cryptoProvidesCount.Add(strconv.FormatUint(uint64(method), 16), 1)
- if method&cryptoMethodRC4 == 0 {
- err = errors.New("no supported crypto methods were provided")
- return
- }
+ cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
+ chosen := h.chooseMethod(provides)
_, err = io.CopyN(ioutil.Discard, r, int64(padLen))
if err != nil {
return
buf := &bytes.Buffer{}
w := cipherWriter{h.newEncrypt(false), buf, nil}
padLen = uint16(newPadLen())
- err = marshal(&w, &vc, uint32(cryptoMethodRC4), padLen, zeroPad[:padLen])
+ err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen])
if err != nil {
return
}
if err != nil {
return
}
- ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn, nil}}
+ switch chosen {
+ case cryptoMethodRC4:
+ ret = readWriter{
+ io.MultiReader(bytes.NewReader(h.ia), r),
+ &cipherWriter{w.c, h.conn, nil},
+ }
+ case cryptoMethodPlaintext:
+ ret = readWriter{
+ io.MultiReader(bytes.NewReader(h.ia), h.conn),
+ h.conn,
+ }
+ default:
+ err = errors.New("chosen crypto method is not supported")
+ }
return
}
return
}
-func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
+func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides uint32) (ret io.ReadWriter, err error) {
h := handshake{
- conn: rw,
- initer: true,
- skey: skey,
- ia: initialPayload,
+ conn: rw,
+ initer: true,
+ skey: skey,
+ ia: initialPayload,
+ cryptoProvides: cryptoProvides,
}
return h.Do()
}
-func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, err error) {
+func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte, selectCrypto func(uint32) uint32) (ret io.ReadWriter, err error) {
h := handshake{
- conn: rw,
- initer: false,
- skeys: sliceIter(skeys),
+ conn: rw,
+ initer: false,
+ skeys: sliceIter(skeys),
+ chooseMethod: selectCrypto,
}
return h.Do()
}
}
return h.Do()
}
+
+func DefaultCryptoSelector(provided uint32) uint32 {
+ if provided&cryptoMethodRC4 != 0 {
+ return cryptoMethodRC4
+ }
+ return cryptoMethodPlaintext
+}
+
+type CryptoSelector func(uint32) uint32
"sync"
"testing"
+ _ "github.com/anacrolix/envpprof"
+
"github.com/bradfitz/iter"
"github.com/stretchr/testify/require"
)
test("sup", "person", 1)
}
-func handshakeTest(t testing.TB, ia []byte, aData, bData string) {
+func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides uint32, cryptoSelect func(uint32) uint32) {
a, b := net.Pipe()
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
- a, err := InitiateHandshake(a, []byte("yep"), ia)
+ a, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
if err != nil {
t.Fatal(err)
return
}()
go func() {
defer wg.Done()
- b, err := ReceiveHandshake(b, [][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")})
+ b, err := ReceiveHandshake(b, [][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}, cryptoSelect)
if err != nil {
t.Fatal(err)
return
b.Close()
}
-func allHandshakeTests(t testing.TB) {
- handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg")
- handshakeTest(t, nil, "hello world", "yo dawg")
- handshakeTest(t, []byte{}, "hello world", "yo dawg")
+func allHandshakeTests(t testing.TB, provides uint32, selector CryptoSelector) {
+ handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg", provides, selector)
+ handshakeTest(t, nil, "hello world", "yo dawg", provides, selector)
+ handshakeTest(t, []byte{}, "hello world", "yo dawg", provides, selector)
}
-func TestHandshake(t *testing.T) {
- allHandshakeTests(t)
+func TestHandshakeDefault(t *testing.T) {
+ allHandshakeTests(t, AllSupportedCrypto, DefaultCryptoSelector)
t.Logf("crypto provides encountered: %s", cryptoProvidesCount)
}
-func BenchmarkHandshake(b *testing.B) {
+func TestHandshakeSelectPlaintext(t *testing.T) {
+ allHandshakeTests(t, AllSupportedCrypto, func(uint32) uint32 { return cryptoMethodPlaintext })
+}
+
+func BenchmarkHandshakeDefault(b *testing.B) {
for range iter.N(b.N) {
- allHandshakeTests(b)
+ allHandshakeTests(b, AllSupportedCrypto, DefaultCryptoSelector)
}
}
func TestReceiveRandomData(t *testing.T) {
tr := trackReader{rand.Reader, 0}
- _, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil)
+ _, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
// No skey matches
require.Error(t, err)
// Establishing S, and then reading the maximum padding for giving up on
require.EqualValues(t, 96+532, tr.n)
}
-func BenchmarkPipe(t *testing.B) {
+func fillRand(t testing.TB, bs ...[]byte) {
+ for _, b := range bs {
+ _, err := rand.Read(b)
+ require.NoError(t, err)
+ }
+}
+
+func readAndWrite(rw io.ReadWriter, r []byte, w []byte) error {
+ var wg sync.WaitGroup
+ wg.Add(1)
+ var wErr error
+ go func() {
+ defer wg.Done()
+ _, wErr = rw.Write(w)
+ }()
+ _, err := io.ReadFull(rw, r)
+ if err != nil {
+ return err
+ }
+ wg.Wait()
+ return wErr
+}
+
+func benchmarkStream(t *testing.B, crypto uint32) {
+ ia := make([]byte, 0x1000)
+ a := make([]byte, 1<<20)
+ b := make([]byte, 1<<20)
+ fillRand(t, ia, a, b)
+ t.StopTimer()
+ t.SetBytes(int64(len(ia) + len(a) + len(b)))
+ t.ResetTimer()
+ for range iter.N(t.N) {
+ ac, bc := net.Pipe()
+ ar := make([]byte, len(b))
+ br := make([]byte, len(ia)+len(a))
+ t.StartTimer()
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer ac.Close()
+ defer wg.Done()
+ rw, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
+ require.NoError(t, err)
+ require.NoError(t, readAndWrite(rw, ar, a))
+ }()
+ func() {
+ defer bc.Close()
+ rw, err := ReceiveHandshake(bc, [][]byte{[]byte("cats")}, func(uint32) uint32 { return crypto })
+ require.NoError(t, err)
+ require.NoError(t, readAndWrite(rw, br, b))
+ }()
+ t.StopTimer()
+ if !bytes.Equal(ar, b) {
+ t.Fatalf("A read the wrong bytes")
+ }
+ if !bytes.Equal(br[:len(ia)], ia) {
+ t.Fatalf("B read the wrong IA")
+ }
+ if !bytes.Equal(br[len(ia):], a) {
+ t.Fatalf("B read the wrong A")
+ }
+ // require.Equal(t, b, ar)
+ // require.Equal(t, ia, br[:len(ia)])
+ // require.Equal(t, a, br[len(ia):])
+ }
+}
+
+func BenchmarkStreamRC4(t *testing.B) {
+ benchmarkStream(t, cryptoMethodRC4)
+}
+
+func BenchmarkStreamPlaintext(t *testing.B) {
+ benchmarkStream(t, cryptoMethodPlaintext)
+}
+
+func BenchmarkPipeRC4(t *testing.B) {
key := make([]byte, 20)
n, _ := rand.Read(key)
require.Equal(t, len(key), n)