From 066cdd520b85393ba94fbd03ebb999ae89ff4d2a Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Fri, 16 Feb 2018 10:36:29 +1100 Subject: [PATCH] Add mse.CryptoMethod type --- client.go | 2 +- connection.go | 2 +- handshake.go | 4 ++-- mse/mse.go | 24 +++++++++++++----------- mse/mse_test.go | 10 +++++----- 5 files changed, 22 insertions(+), 20 deletions(-) diff --git a/client.go b/client.go index 55f89902..6eb1e5fa 100644 --- a/client.go +++ b/client.go @@ -726,7 +726,7 @@ func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err er }{c.r, c.w}, t.infoHash[:], nil, - func() uint32 { + func() mse.CryptoMethod { switch { case cl.config.ForceEncryption: return mse.CryptoMethodRC4 diff --git a/connection.go b/connection.go index 67b32bce..a12d8e65 100644 --- a/connection.go +++ b/connection.go @@ -46,7 +46,7 @@ type connection struct { r io.Reader // True if the connection is operating over MSE obfuscation. headerEncrypted bool - cryptoMethod uint32 + cryptoMethod mse.CryptoMethod Discovery peerSource uTP bool closed missinggo.Event diff --git a/handshake.go b/handshake.go index 260c5420..b950a9b1 100644 --- a/handshake.go +++ b/handshake.go @@ -165,7 +165,7 @@ func handleEncryption( ) ( ret io.ReadWriter, headerEncrypted bool, - cryptoMethod uint32, + cryptoMethod mse.CryptoMethod, err error, ) { if !policy.ForceEncryption { @@ -187,7 +187,7 @@ func handleEncryption( } } headerEncrypted = true - ret, err = mse.ReceiveHandshake(rw, skeys, func(provides uint32) uint32 { + ret, err = mse.ReceiveHandshake(rw, skeys, func(provides mse.CryptoMethod) mse.CryptoMethod { switch { case policy.ForceEncryption: return mse.CryptoMethodRC4 diff --git a/mse/mse.go b/mse/mse.go index 898cd3d7..d843d8f9 100644 --- a/mse/mse.go +++ b/mse/mse.go @@ -24,11 +24,13 @@ import ( const ( maxPadLen = 512 - CryptoMethodPlaintext = 1 - CryptoMethodRC4 = 2 - AllSupportedCrypto = CryptoMethodPlaintext | CryptoMethodRC4 + CryptoMethodPlaintext CryptoMethod = 1 + CryptoMethodRC4 CryptoMethod = 2 + AllSupportedCrypto = CryptoMethodPlaintext | CryptoMethodRC4 ) +type CryptoMethod uint32 + var ( // Prime P according to the spec, and G, the generator. p, g big.Int @@ -212,9 +214,9 @@ type handshake struct { skey []byte // Skey we're initiating with. ia []byte // Initial payload. Only used by the initiator. // Return the bit for the crypto method the receiver wants to use. - chooseMethod func(supported uint32) uint32 + chooseMethod CryptoSelector // Sent to the receiver. - cryptoProvides uint32 + cryptoProvides CryptoMethod writeMu sync.Mutex writes [][]byte @@ -398,7 +400,7 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) { return } r := newCipherReader(bC, h.conn) - var method uint32 + var method CryptoMethod err = unmarshal(r, &method, &padLen) if err != nil { return @@ -449,7 +451,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn) var ( vc [8]byte - provides uint32 + provides CryptoMethod padLen uint16 ) @@ -526,7 +528,7 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) { return } -func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides uint32) (ret io.ReadWriter, err error) { +func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, err error) { h := handshake{ conn: rw, initer: true, @@ -537,7 +539,7 @@ func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cry return h.Do() } -func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto func(uint32) uint32) (ret io.ReadWriter, err error) { +func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, err error) { h := handshake{ conn: rw, initer: false, @@ -551,11 +553,11 @@ func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto func(u // returns false or exhausted. type SecretKeyIter func(callback func(skey []byte) (more bool)) -func DefaultCryptoSelector(provided uint32) uint32 { +func DefaultCryptoSelector(provided CryptoMethod) CryptoMethod { if provided&CryptoMethodPlaintext != 0 { return CryptoMethodPlaintext } return CryptoMethodRC4 } -type CryptoSelector func(uint32) uint32 +type CryptoSelector func(CryptoMethod) CryptoMethod diff --git a/mse/mse_test.go b/mse/mse_test.go index b754e2b4..51913504 100644 --- a/mse/mse_test.go +++ b/mse/mse_test.go @@ -58,7 +58,7 @@ func TestSuffixMatchLen(t *testing.T) { test("sup", "person", 1) } -func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides uint32, cryptoSelect func(uint32) uint32) { +func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides CryptoMethod, cryptoSelect CryptoSelector) { a, b := net.Pipe() wg := sync.WaitGroup{} wg.Add(2) @@ -100,7 +100,7 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides b.Close() } -func allHandshakeTests(t testing.TB, provides uint32, selector CryptoSelector) { +func allHandshakeTests(t testing.TB, provides CryptoMethod, selector CryptoSelector) { handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg", provides, selector) handshakeTest(t, nil, "hello world", "yo dawg", provides, selector) handshakeTest(t, []byte{}, "hello world", "yo dawg", provides, selector) @@ -112,7 +112,7 @@ func TestHandshakeDefault(t *testing.T) { } func TestHandshakeSelectPlaintext(t *testing.T) { - allHandshakeTests(t, AllSupportedCrypto, func(uint32) uint32 { return CryptoMethodPlaintext }) + allHandshakeTests(t, AllSupportedCrypto, func(CryptoMethod) CryptoMethod { return CryptoMethodPlaintext }) } func BenchmarkHandshakeDefault(b *testing.B) { @@ -165,7 +165,7 @@ func readAndWrite(rw io.ReadWriter, r []byte, w []byte) error { return wErr } -func benchmarkStream(t *testing.B, crypto uint32) { +func benchmarkStream(t *testing.B, crypto CryptoMethod) { ia := make([]byte, 0x1000) a := make([]byte, 1<<20) b := make([]byte, 1<<20) @@ -189,7 +189,7 @@ func benchmarkStream(t *testing.B, crypto uint32) { }() func() { defer bc.Close() - rw, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(uint32) uint32 { return crypto }) + rw, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(CryptoMethod) CryptoMethod { return crypto }) require.NoError(t, err) require.NoError(t, readAndWrite(rw, br, b)) }() -- 2.44.0