From: Matt Joiner Date: Wed, 14 Oct 2020 04:11:45 +0000 (+1100) Subject: Ensure PeerConn._close is called for incoming connections X-Git-Tag: v1.18.0~5 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=4a4cb5dc58006baa6bac05d2f0b09fadd4876de1;p=btrtrc.git Ensure PeerConn._close is called for incoming connections This fixes missing calls to PeerConnClosed callback. --- diff --git a/client.go b/client.go index 98563227..25f7a860 100644 --- a/client.go +++ b/client.go @@ -509,6 +509,7 @@ func (cl *Client) incomingConnection(nc net.Conn) { } c := cl.newConnection(nc, false, nc.RemoteAddr(), nc.RemoteAddr().Network(), regularNetConnPeerConnConnString(nc)) + defer c.close() c.Discovery = PeerSourceIncoming cl.runReceivedConn(c) } @@ -1348,6 +1349,7 @@ func (cl *Client) newConnection(nc net.Conn, outgoing bool, remoteAddr net.Addr, connString: connString, conn: nc, writeBuffer: new(bytes.Buffer), + callbacks: &cl.config.Callbacks, } c.peerImpl = c c.logger = cl.logger.WithDefaultLevel(log.Warning).WithContextValue(c) diff --git a/peerconn.go b/peerconn.go index 8ff84102..c9a050a5 100644 --- a/peerconn.go +++ b/peerconn.go @@ -136,6 +136,8 @@ type PeerConn struct { writerCond sync.Cond pex pexConnState + + callbacks *Callbacks } func (cn *PeerConn) connStatusString() string { @@ -356,7 +358,7 @@ func (cn *PeerConn) _close() { if cn.conn != nil { cn.conn.Close() } - if cb := cn.t.cl.config.Callbacks.PeerConnClosed; cb != nil { + if cb := cn.callbacks.PeerConnClosed; cb != nil { cb(cn) } } @@ -1072,7 +1074,7 @@ func (c *PeerConn) mainReadLoop() (err error) { defer cl.lock() err = decoder.Decode(&msg) }() - if cb := cl.config.Callbacks.ReadMessage; cb != nil && err == nil { + if cb := c.callbacks.ReadMessage; cb != nil && err == nil { cb(c, &msg) } if t.closed.IsSet() || c.closed.IsSet() { @@ -1209,7 +1211,7 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err c.logger.Printf("error parsing extended handshake message %q: %s", payload, err) return errors.Wrap(err, "unmarshalling extended handshake payload") } - if cb := cl.config.Callbacks.ReadExtendedHandshake; cb != nil { + if cb := c.callbacks.ReadExtendedHandshake; cb != nil { cb(c, &d) } //c.logger.WithDefaultLevel(log.Debug).Printf("received extended handshake message:\n%s", spew.Sdump(d))