From 71b971834723ccb440c65f96cb2e40fa2542339c Mon Sep 17 00:00:00 2001
From: Yaroslav Kolomiiets <yarikos@gmail.com>
Date: Thu, 12 Nov 2020 21:24:33 +0000
Subject: [PATCH] optimise PEX by avoiding intermediate storage while preparing
 PEX messages

---
 pex.go | 124 +++++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 81 insertions(+), 43 deletions(-)

diff --git a/pex.go b/pex.go
index 7d2ba208..4771e700 100644
--- a/pex.go
+++ b/pex.go
@@ -31,8 +31,9 @@ type pexEvent struct {
 
 // facilitates efficient de-duplication while generating PEX messages
 type pexMsgFactory struct {
-	added   map[addrKey]pexEvent
-	dropped map[addrKey]pexEvent
+	msg     pp.PexMsg
+	added   map[addrKey]struct{}
+	dropped map[addrKey]struct{}
 }
 
 func (me *pexMsgFactory) DeltaLen() int {
@@ -48,32 +49,99 @@ func (me *pexMsgFactory) addrKey(addr net.Addr) addrKey {
 	return addrKey(addr.String())
 }
 
+func addrEqual(a, b *krpc.NodeAddr) bool {
+	return a.IP.Equal(b.IP) && a.Port == b.Port
+}
+
+func addrIndex(v []krpc.NodeAddr, a *krpc.NodeAddr) int {
+	for i := range v {
+		if addrEqual(&v[i], a) {
+			return i
+		}
+	}
+	return -1
+}
+
 // Returns whether the entry was added (we can check if we're cancelling out another entry and so
 // won't hit the limit consuming this event).
 func (me *pexMsgFactory) add(e pexEvent) {
 	key := me.addrKey(e.addr)
-	if _, ok := me.dropped[key]; ok {
-		delete(me.dropped, key)
+	if _, ok := me.added[key]; ok {
 		return
 	}
 	if me.added == nil {
-		me.added = make(map[addrKey]pexEvent, pexMaxDelta)
+		me.added = make(map[addrKey]struct{}, pexMaxDelta)
 	}
-	me.added[key] = e
+	addr, ok := nodeAddr(e.addr)
+	if !ok {
+		return
+	}
+	m := &me.msg
+	switch {
+	case addr.IP.To4() != nil:
+		if _, ok := me.dropped[key]; ok {
+			if i := addrIndex(m.Dropped.NodeAddrs(), &addr); i >= 0 {
+				m.Dropped = append(m.Dropped[:i], m.Dropped[i+1:]...)
+			}
+			delete(me.dropped, key)
+			return
+		}
+		m.Added = append(m.Added, addr)
+		m.AddedFlags = append(m.AddedFlags, e.f)
+	case len(addr.IP) == net.IPv6len:
+		if _, ok := me.dropped[key]; ok {
+			if i := addrIndex(m.Dropped6.NodeAddrs(), &addr); i >= 0 {
+				m.Dropped6 = append(m.Dropped6[:i], m.Dropped6[i+1:]...)
+			}
+			delete(me.dropped, key)
+			return
+		}
+		m.Added6 = append(m.Added6, addr)
+		m.Added6Flags = append(m.Added6Flags, e.f)
+	default:
+		panic(addr)
+	}
+	me.added[key] = struct{}{}
 }
 
 // Returns whether the entry was added (we can check if we're cancelling out another entry and so
 // won't hit the limit consuming this event).
 func (me *pexMsgFactory) drop(e pexEvent) {
-	key := me.addrKey(e.addr)
-	if _, ok := me.added[key]; ok {
-		delete(me.added, key)
+	addr, ok := nodeAddr(e.addr)
+	if !ok {
 		return
 	}
+	key := me.addrKey(e.addr)
 	if me.dropped == nil {
-		me.dropped = make(map[addrKey]pexEvent, pexMaxDelta)
+		me.dropped = make(map[addrKey]struct{}, pexMaxDelta)
+	}
+	if _, ok := me.dropped[key]; ok {
+		return
 	}
-	me.dropped[key] = e
+	m := &me.msg
+	switch {
+	case addr.IP.To4() != nil:
+		if _, ok := me.added[key]; ok {
+			if i := addrIndex(m.Added.NodeAddrs(), &addr); i >= 0 {
+				m.Added = append(m.Added[:i], m.Added[i+1:]...)
+				m.AddedFlags = append(m.AddedFlags[:i], m.AddedFlags[i+1:]...)
+			}
+			delete(me.added, key)
+			return
+		}
+		m.Dropped = append(m.Dropped, addr)
+	case len(addr.IP) == net.IPv6len:
+		if _, ok := me.added[key]; ok {
+			if i := addrIndex(m.Added6.NodeAddrs(), &addr); i >= 0 {
+				m.Added6 = append(m.Added6[:i], m.Added6[i+1:]...)
+				m.Added6Flags = append(m.Added6Flags[:i], m.Added6Flags[i+1:]...)
+			}
+			delete(me.added, key)
+			return
+		}
+		m.Dropped6 = append(m.Dropped6, addr)
+	}
+	me.dropped[key] = struct{}{}
 }
 
 func (me *pexMsgFactory) addEvent(event pexEvent) {
@@ -87,38 +155,8 @@ func (me *pexMsgFactory) addEvent(event pexEvent) {
 	}
 }
 
-func (me *pexMsgFactory) PexMsg() (ret pp.PexMsg) {
-	for key, added := range me.added {
-		addr, ok := nodeAddr(added.addr)
-		if !ok {
-			continue
-		}
-		switch len(addr.IP) {
-		case net.IPv4len:
-			ret.Added = append(ret.Added, addr)
-			ret.AddedFlags = append(ret.AddedFlags, added.f)
-		case net.IPv6len:
-			ret.Added6 = append(ret.Added6, addr)
-			ret.Added6Flags = append(ret.Added6Flags, added.f)
-		default:
-			panic(key)
-		}
-	}
-	for key, dropped := range me.dropped {
-		addr, ok := nodeAddr(dropped.addr)
-		if !ok {
-			continue
-		}
-		switch len(addr.IP) {
-		case net.IPv4len:
-			ret.Dropped = append(ret.Dropped, addr)
-		case net.IPv6len:
-			ret.Dropped6 = append(ret.Dropped6, addr)
-		default:
-			panic(key)
-		}
-	}
-	return
+func (me *pexMsgFactory) PexMsg() pp.PexMsg {
+	return me.msg
 }
 
 // Convert an arbitrary torrent peer Addr into one that can be represented by the compact addr
-- 
2.51.0