// Performs initiator handshakes and returns a connection. Returns nil
// *connection if no connection for valid reasons.
-func (cl *Client) handshakesConnection(ctx context.Context, nc net.Conn, t *Torrent, encrypted, utp bool) (c *connection, err error) {
+func (cl *Client) handshakesConnection(ctx context.Context, nc net.Conn, t *Torrent, encryptHeader, utp bool) (c *connection, err error) {
c = cl.newConnection(nc)
- c.encrypted = encrypted
+ c.headerEncrypted = encryptHeader
c.uTP = utp
ctx, cancel := context.WithTimeout(ctx, handshakesTimeout)
defer cancel()
if nc == nil {
return
}
- encryptFirst := !cl.config.DisableEncryption && !cl.config.PreferNoEncryption
- c, err = cl.handshakesConnection(ctx, nc, t, encryptFirst, utp)
+ obfuscatedHeaderFirst := !cl.config.DisableEncryption && !cl.config.PreferNoEncryption
+ c, err = cl.handshakesConnection(ctx, nc, t, obfuscatedHeaderFirst, utp)
if err != nil {
nc.Close()
return
return
}
nc.Close()
- if cl.config.DisableEncryption || cl.config.ForceEncryption {
- // There's no alternate encryption case to try.
+ if cl.config.ForceEncryption {
+ // We should have just tried with an obfuscated header. A plaintext
+ // header can't result in an encrypted connection, so we're done.
+ if !obfuscatedHeaderFirst {
+ panic(cl.config.EncryptionPolicy)
+ }
return
}
// Try again with encryption if we didn't earlier, or without if we did,
err = fmt.Errorf("error dialing for unencrypted connection: %s", err)
return
}
- c, err = cl.handshakesConnection(ctx, nc, t, !encryptFirst, utp)
+ c, err = cl.handshakesConnection(ctx, nc, t, !obfuscatedHeaderFirst, utp)
if err != nil || c == nil {
nc.Close()
}
return
}
-func maybeReceiveEncryptedHandshake(rw io.ReadWriter, skeys mse.SecretKeyIter) (ret io.ReadWriter, encrypted bool, err error) {
- var protocol [len(pp.Protocol)]byte
- _, err = io.ReadFull(rw, protocol[:])
- if err != nil {
- return
- }
- ret = struct {
- io.Reader
- io.Writer
- }{
- io.MultiReader(bytes.NewReader(protocol[:]), rw),
- rw,
- }
- if string(protocol[:]) == pp.Protocol {
- return
+func handleEncryption(
+ rw io.ReadWriter,
+ skeys mse.SecretKeyIter,
+ policy EncryptionPolicy,
+) (
+ ret io.ReadWriter,
+ headerEncrypted bool,
+ cryptoMethod uint32,
+ err error,
+) {
+ if !policy.ForceEncryption {
+ var protocol [len(pp.Protocol)]byte
+ _, err = io.ReadFull(rw, protocol[:])
+ if err != nil {
+ return
+ }
+ rw = struct {
+ io.Reader
+ io.Writer
+ }{
+ io.MultiReader(bytes.NewReader(protocol[:]), rw),
+ rw,
+ }
+ if string(protocol[:]) == pp.Protocol {
+ ret = rw
+ return
+ }
}
- encrypted = true
- ret, err = mse.ReceiveHandshakeLazy(ret, skeys)
+ headerEncrypted = true
+ ret, err = mse.ReceiveHandshake(rw, skeys, func(provides uint32) uint32 {
+ cryptoMethod = func() uint32 {
+ switch {
+ case policy.ForceEncryption:
+ return mse.CryptoMethodRC4
+ case policy.DisableEncryption:
+ return mse.CryptoMethodPlaintext
+ case policy.PreferNoEncryption && provides&mse.CryptoMethodPlaintext != 0:
+ return mse.CryptoMethodPlaintext
+ default:
+ return mse.DefaultCryptoSelector(provides)
+ }
+ }()
+ return cryptoMethod
+ })
return
}
func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err error) {
- if c.encrypted {
+ if c.headerEncrypted {
var rw io.ReadWriter
- rw, err = mse.InitiateHandshake(struct {
- io.Reader
- io.Writer
- }{c.r, c.w}, t.infoHash[:], nil)
+ rw, err = mse.InitiateHandshake(
+ struct {
+ io.Reader
+ io.Writer
+ }{c.r, c.w},
+ t.infoHash[:],
+ nil,
+ func() uint32 {
+ switch {
+ case cl.config.ForceEncryption:
+ return mse.CryptoMethodRC4
+ case cl.config.DisableEncryption:
+ return mse.CryptoMethodPlaintext
+ default:
+ return mse.AllSupportedCrypto
+ }
+ }(),
+ )
c.setRW(rw)
if err != nil {
return
// Do encryption and bittorrent handshakes as receiver.
func (cl *Client) receiveHandshakes(c *connection) (t *Torrent, err error) {
- if !cl.config.DisableEncryption {
- var rw io.ReadWriter
- rw, c.encrypted, err = maybeReceiveEncryptedHandshake(c.rw(), cl.forSkeys)
- c.setRW(rw)
- if err != nil {
- if err == mse.ErrNoSecretKeyMatch {
- err = nil
- }
- return
+ var rw io.ReadWriter
+ rw, c.headerEncrypted, c.cryptoMethod, err = handleEncryption(c.rw(), cl.forSkeys, cl.config.EncryptionPolicy)
+ c.setRW(rw)
+ if err != nil {
+ if err == mse.ErrNoSecretKeyMatch {
+ err = nil
}
+ return
}
- if cl.config.ForceEncryption && !c.encrypted {
+ if cl.config.ForceEncryption && !c.headerEncrypted {
err = errors.New("connection not encrypted")
return
}
const (
maxPadLen = 512
- cryptoMethodPlaintext = 1
- cryptoMethodRC4 = 2
- AllSupportedCrypto = cryptoMethodPlaintext | cryptoMethodRC4
+ CryptoMethodPlaintext = 1
+ CryptoMethodRC4 = 2
+ AllSupportedCrypto = CryptoMethodPlaintext | CryptoMethodRC4
)
var (
return
}
switch method & h.cryptoProvides {
- case cryptoMethodRC4:
+ case CryptoMethodRC4:
ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
- case cryptoMethodPlaintext:
+ case CryptoMethodPlaintext:
ret = h.conn
default:
err = fmt.Errorf("receiver chose unsupported method: %x", method)
return
}
switch chosen {
- case cryptoMethodRC4:
+ case CryptoMethodRC4:
ret = readWriter{
io.MultiReader(bytes.NewReader(h.ia), r),
&cipherWriter{w.c, h.conn, nil},
}
- case cryptoMethodPlaintext:
+ case CryptoMethodPlaintext:
ret = readWriter{
io.MultiReader(bytes.NewReader(h.ia), h.conn),
h.conn,
return h.Do()
}
-func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte, selectCrypto func(uint32) uint32) (ret io.ReadWriter, err error) {
+func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto func(uint32) uint32) (ret io.ReadWriter, err error) {
h := handshake{
conn: rw,
initer: false,
- skeys: sliceIter(skeys),
+ skeys: skeys,
chooseMethod: selectCrypto,
}
return h.Do()
// returns false or exhausted.
type SecretKeyIter func(callback func(skey []byte) (more bool))
-// Doesn't unpack the secret keys until it needs to, and through the passed
-// function.
-func ReceiveHandshakeLazy(rw io.ReadWriter, skeys SecretKeyIter) (ret io.ReadWriter, err error) {
- h := handshake{
- conn: rw,
- initer: false,
- skeys: skeys,
- }
- return h.Do()
-}
-
func DefaultCryptoSelector(provided uint32) uint32 {
- if provided&cryptoMethodRC4 != 0 {
- return cryptoMethodRC4
+ if provided&CryptoMethodRC4 != 0 {
+ return CryptoMethodRC4
}
- return cryptoMethodPlaintext
+ return CryptoMethodPlaintext
}
type CryptoSelector func(uint32) uint32
}()
go func() {
defer wg.Done()
- b, err := ReceiveHandshake(b, [][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}, cryptoSelect)
+ b, err := ReceiveHandshake(b, sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), cryptoSelect)
if err != nil {
t.Fatal(err)
return
}
func TestHandshakeSelectPlaintext(t *testing.T) {
- allHandshakeTests(t, AllSupportedCrypto, func(uint32) uint32 { return cryptoMethodPlaintext })
+ allHandshakeTests(t, AllSupportedCrypto, func(uint32) uint32 { return CryptoMethodPlaintext })
}
func BenchmarkHandshakeDefault(b *testing.B) {
}()
func() {
defer bc.Close()
- rw, err := ReceiveHandshake(bc, [][]byte{[]byte("cats")}, func(uint32) uint32 { return crypto })
+ rw, err := ReceiveHandshake(bc, sliceIter([][]byte{[]byte("cats")}), func(uint32) uint32 { return crypto })
require.NoError(t, err)
require.NoError(t, readAndWrite(rw, br, b))
}()
}
func BenchmarkStreamRC4(t *testing.B) {
- benchmarkStream(t, cryptoMethodRC4)
+ benchmarkStream(t, CryptoMethodRC4)
}
func BenchmarkStreamPlaintext(t *testing.B) {
- benchmarkStream(t, cryptoMethodPlaintext)
+ benchmarkStream(t, CryptoMethodPlaintext)
}
func BenchmarkPipeRC4(t *testing.B) {