From 4a4cb5dc58006baa6bac05d2f0b09fadd4876de1 Mon Sep 17 00:00:00 2001
From: Matt Joiner <anacrolix@gmail.com>
Date: Wed, 14 Oct 2020 15:11:45 +1100
Subject: [PATCH] Ensure PeerConn._close is called for incoming connections

This fixes missing calls to PeerConnClosed callback.
---
 client.go   | 2 ++
 peerconn.go | 8 +++++---
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/client.go b/client.go
index 98563227..25f7a860 100644
--- a/client.go
+++ b/client.go
@@ -509,6 +509,7 @@ func (cl *Client) incomingConnection(nc net.Conn) {
 	}
 	c := cl.newConnection(nc, false, nc.RemoteAddr(), nc.RemoteAddr().Network(),
 		regularNetConnPeerConnConnString(nc))
+	defer c.close()
 	c.Discovery = PeerSourceIncoming
 	cl.runReceivedConn(c)
 }
@@ -1348,6 +1349,7 @@ func (cl *Client) newConnection(nc net.Conn, outgoing bool, remoteAddr net.Addr,
 		connString:  connString,
 		conn:        nc,
 		writeBuffer: new(bytes.Buffer),
+		callbacks:   &cl.config.Callbacks,
 	}
 	c.peerImpl = c
 	c.logger = cl.logger.WithDefaultLevel(log.Warning).WithContextValue(c)
diff --git a/peerconn.go b/peerconn.go
index 8ff84102..c9a050a5 100644
--- a/peerconn.go
+++ b/peerconn.go
@@ -136,6 +136,8 @@ type PeerConn struct {
 	writerCond  sync.Cond
 
 	pex pexConnState
+
+	callbacks *Callbacks
 }
 
 func (cn *PeerConn) connStatusString() string {
@@ -356,7 +358,7 @@ func (cn *PeerConn) _close() {
 	if cn.conn != nil {
 		cn.conn.Close()
 	}
-	if cb := cn.t.cl.config.Callbacks.PeerConnClosed; cb != nil {
+	if cb := cn.callbacks.PeerConnClosed; cb != nil {
 		cb(cn)
 	}
 }
@@ -1072,7 +1074,7 @@ func (c *PeerConn) mainReadLoop() (err error) {
 			defer cl.lock()
 			err = decoder.Decode(&msg)
 		}()
-		if cb := cl.config.Callbacks.ReadMessage; cb != nil && err == nil {
+		if cb := c.callbacks.ReadMessage; cb != nil && err == nil {
 			cb(c, &msg)
 		}
 		if t.closed.IsSet() || c.closed.IsSet() {
@@ -1209,7 +1211,7 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err
 			c.logger.Printf("error parsing extended handshake message %q: %s", payload, err)
 			return errors.Wrap(err, "unmarshalling extended handshake payload")
 		}
-		if cb := cl.config.Callbacks.ReadExtendedHandshake; cb != nil {
+		if cb := c.callbacks.ReadExtendedHandshake; cb != nil {
 			cb(c, &d)
 		}
 		//c.logger.WithDefaultLevel(log.Debug).Printf("received extended handshake message:\n%s", spew.Sdump(d))
-- 
2.51.0