client.go | 28 +++++++++++++++------------- connection.go | 31 +++++++++++++++++++++++++------ connection_test.go | 18 ++++-------------- diff --git a/client.go b/client.go index 712e557d6740d3f08b4013404ebbad3e42a76798..12f1997326880e09c9bb78d0cbf08fd4ca6e5626 100644 --- a/client.go +++ b/client.go @@ -809,18 +809,16 @@ // } return } -type readWriter struct { - io.Reader - io.Writer -} - func maybeReceiveEncryptedHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, encrypted bool, err error) { var protocol [len(pp.Protocol)]byte _, err = io.ReadFull(rw, protocol[:]) if err != nil { return } - ret = readWriter{ + ret = struct { + io.Reader + io.Writer + }{ io.MultiReader(bytes.NewReader(protocol[:]), rw), rw, } @@ -841,7 +839,12 @@ } func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err error) { if c.encrypted { - c.rw, err = mse.InitiateHandshake(c.rw, t.infoHash[:], nil) + var rw io.ReadWriter + rw, err = mse.InitiateHandshake(struct { + io.Reader + io.Writer + }{c.r, c.w}, t.infoHash[:], nil) + c.setRW(rw) if err != nil { return } @@ -859,7 +862,9 @@ cl.mu.Lock() skeys := cl.receiveSkeys() cl.mu.Unlock() if !cl.config.DisableEncryption { - c.rw, c.encrypted, err = maybeReceiveEncryptedHandshake(c.rw, skeys) + var rw io.ReadWriter + rw, c.encrypted, err = maybeReceiveEncryptedHandshake(c.rw(), skeys) + c.setRW(rw) if err != nil { if err == mse.ErrNoSecretKeyMatch { err = nil @@ -887,7 +892,7 @@ } // Returns !ok if handshake failed for valid reasons. func (cl *Client) connBTHandshake(c *connection, ih *metainfo.Hash) (ret metainfo.Hash, ok bool, err error) { - res, ok, err := handshake(c.rw, ih, cl.peerID, cl.extensionBytes) + res, ok, err := handshake(c.rw(), ih, cl.peerID, cl.extensionBytes) if err != nil || !ok { return } @@ -937,10 +942,7 @@ } func (cl *Client) runHandshookConn(c *connection, t *Torrent) { c.conn.SetWriteDeadline(time.Time{}) - c.rw = readWriter{ - deadlineReader{c.conn, c.rw}, - c.rw, - } + c.r = deadlineReader{c.conn, c.r} completedHandshakeConnectionFlags.Add(c.connectionFlags(), 1) if !t.addConnection(c) { return diff --git a/connection.go b/connection.go index 742b90af795d13de38fa314d6f8e5173e3c2ef3f..a0b20f44b4cc1188aa9352f399161558bfc5a6ee 100644 --- a/connection.go +++ b/connection.go @@ -38,9 +38,14 @@ ) // Maintains the state of a connection with a peer. type connection struct { - t *Torrent - conn net.Conn - rw io.ReadWriter // The real slim shady + t *Torrent + // The actual Conn, used for closing, and setting socket options. + conn net.Conn + // The Reader and Writer for this Conn, with hooks installed for stats, + // limiting, deadlines etc. + w io.Writer + r io.Reader + // True if the connection is operating over MSE obfuscation. encrypted bool Discovery peerSource uTP bool @@ -109,7 +114,7 @@ Choked: true, PeerChoked: true, PeerMaxRequests: 250, } - c.rw = connStatsReadWriter{nc, l, c} + c.setRW(connStatsReadWriter{nc, l, c}) return } @@ -407,7 +412,7 @@ defer cn.mu().Unlock() cn.Close() }() // Reduce write syscalls. - buf := bufio.NewWriter(cn.rw) + buf := bufio.NewWriter(cn.w) keepAliveTimer := time.NewTimer(keepAliveTimeout) for { cn.mu().Lock() @@ -700,7 +705,7 @@ t := c.t cl := t.cl decoder := pp.Decoder{ - R: bufio.NewReader(c.rw), + R: bufio.NewReader(c.r), MaxLength: 256 * 1024, Pool: t.chunkPool, } @@ -907,3 +912,17 @@ return err } } } + +// Set both the Reader and Writer for the connection from a single ReadWriter. +func (cn *connection) setRW(rw io.ReadWriter) { + cn.r = rw + cn.w = rw +} + +// Returns the Reader and Writer as a combined ReadWriter. +func (cn *connection) rw() io.ReadWriter { + return struct { + io.Reader + io.Writer + }{cn.r, cn.w} +} diff --git a/connection_test.go b/connection_test.go index 52f269b325a33d3d3c5a780c5f6f40283fdacb5e..31ddbe141822f8bb7407e83530e6f43e1c72ead9 100644 --- a/connection_test.go +++ b/connection_test.go @@ -29,12 +29,7 @@ var bm bitmap.Bitmap bm.Set(1, true) return bm }(), - rw: struct { - io.Reader - io.Writer - }{ - Writer: w, - }, + w: w, conn: new(net.TCPConn), // For the locks t: &Torrent{cl: &Client{}}, @@ -74,10 +69,8 @@ c := &connection{ t: &Torrent{ cl: &Client{}, }, - rw: struct { - io.Reader - io.Writer - }{r, w}, + r: r, + w: w, outgoingUnbufferedMessages: list.New(), } go c.writer(time.Minute) @@ -153,10 +146,7 @@ t.pendingPieces.Add(0) r, w := io.Pipe() cn := &connection{ t: t, - rw: struct { - io.Reader - io.Writer - }{r, nil}, + r: r, } mrlErr := make(chan error) cl.mu.Lock()