From a4049e179cdfbd596090d8c47b8434783d0b1b62 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Mon, 19 Oct 2015 00:00:26 +1100 Subject: [PATCH] Avoid allocation in iplist.Ranger.Lookup It was very expensive for DHT packets. --- client.go | 28 +++++++++++++--------------- dht/dht.go | 7 ++++--- iplist/iplist.go | 43 +++++++++++++++++++++++++------------------ iplist/iplist_test.go | 23 ++++++++++++----------- iplist/packed.go | 19 ++++++++++++++----- 5 files changed, 68 insertions(+), 52 deletions(-) diff --git a/client.go b/client.go index 121b0849..0385b4fa 100644 --- a/client.go +++ b/client.go @@ -581,18 +581,20 @@ func (me *Client) Close() { var ipv6BlockRange = iplist.Range{Description: "non-IPv4 address"} -func (cl *Client) ipBlockRange(ip net.IP) (r *iplist.Range) { +func (cl *Client) ipBlockRange(ip net.IP) (r iplist.Range, blocked bool) { if cl.ipBlockList == nil { return } ip4 := ip.To4() + // If blocklists are enabled, then block non-IPv4 addresses, because + // blocklists do not yet support IPv6. if ip4 == nil { log.Printf("blocking non-IPv4 address: %s", ip) - r = &ipv6BlockRange + r = ipv6BlockRange + blocked = true return } - r = cl.ipBlockList.Lookup(ip4) - return + return cl.ipBlockList.Lookup(ip4) } func (cl *Client) waitAccept() { @@ -638,9 +640,9 @@ func (cl *Client) acceptConnections(l net.Listener, utp bool) { } cl.mu.RLock() doppleganger := cl.dopplegangerAddr(conn.RemoteAddr().String()) - blockRange := cl.ipBlockRange(AddrIP(conn.RemoteAddr())) + _, blocked := cl.ipBlockRange(AddrIP(conn.RemoteAddr())) cl.mu.RUnlock() - if blockRange != nil || doppleganger { + if blocked || doppleganger { acceptReject.Add(1) // log.Printf("inbound connection from %s blocked by %s", conn.RemoteAddr(), blockRange) conn.Close() @@ -728,7 +730,7 @@ func (me *Client) initiateConn(peer Peer, t *torrent) { duplicateConnsAvoided.Add(1) return } - if r := me.ipBlockRange(peer.IP); r != nil { + if r, ok := me.ipBlockRange(peer.IP); ok { log.Printf("outbound connect to %s blocked by IP blocklist rule %s", peer.IP, r) return } @@ -1846,7 +1848,7 @@ func (me *Client) addPeers(t *torrent, peers []Peer) { )) { continue } - if me.ipBlockRange(p.IP) != nil { + if _, ok := me.ipBlockRange(p.IP); ok { continue } if p.Port == 0 { @@ -2337,13 +2339,9 @@ func (cl *Client) trackerBlockedUnlocked(tr tracker.Client) (blocked bool, err e if err != nil { return } - cl.mu.Lock() - if cl.ipBlockList != nil { - if cl.ipBlockRange(addr.IP) != nil { - blocked = true - } - } - cl.mu.Unlock() + cl.mu.RLock() + _, blocked = cl.ipBlockRange(addr.IP) + cl.mu.RUnlock() return } diff --git a/dht/dht.go b/dht/dht.go index d7b6cc17..5d3789bf 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -688,11 +688,12 @@ func (s *Server) serve() error { } } -func (s *Server) ipBlocked(ip net.IP) bool { +func (s *Server) ipBlocked(ip net.IP) (blocked bool) { if s.ipBlockList == nil { - return false + return } - return s.ipBlockList.Lookup(ip) != nil + _, blocked = s.ipBlockList.Lookup(ip) + return } // Adds directly to the node table. diff --git a/iplist/iplist.go b/iplist/iplist.go index cd97b080..9e402dae 100644 --- a/iplist/iplist.go +++ b/iplist/iplist.go @@ -15,7 +15,7 @@ import ( // An abstraction of IP list implementations. type Ranger interface { // Return a Range containing the IP. - Lookup(net.IP) *Range + Lookup(net.IP) (r Range, ok bool) // If your ranges hurt, use this. NumRanges() int } @@ -50,17 +50,17 @@ func (me *IPList) NumRanges() int { } // Return the range the given IP is in. Returns nil if no range is found. -func (me *IPList) Lookup(ip net.IP) (r *Range) { +func (me *IPList) Lookup(ip net.IP) (r Range, ok bool) { if me == nil { - return nil + return } // TODO: Perhaps all addresses should be converted to IPv6, if the future // of IP is to always be backwards compatible. But this will cost 4x the // memory for IPv4 addresses? v4 := ip.To4() if v4 != nil { - r = me.lookup(v4) - if r != nil { + r, ok = me.lookup(v4) + if ok { return } } @@ -69,37 +69,44 @@ func (me *IPList) Lookup(ip net.IP) (r *Range) { return me.lookup(v6) } if v4 == nil && v6 == nil { - return &Range{ + r = Range{ Description: "bad IP", } + ok = true } - return nil + return } // Return a range that contains ip, or nil. -func lookup(f func(i int) Range, n int, ip net.IP) *Range { +func lookup( + first func(i int) net.IP, + full func(i int) Range, + n int, + ip net.IP, +) ( + r Range, ok bool, +) { // Find the index of the first range for which the following range exceeds // it. i := sort.Search(n, func(i int) bool { if i+1 >= n { return true } - r := f(i + 1) - return bytes.Compare(ip, r.First) < 0 + return bytes.Compare(ip, first(i+1)) < 0 }) if i == n { - return nil - } - r := f(i) - if bytes.Compare(ip, r.First) < 0 || bytes.Compare(ip, r.Last) > 0 { - return nil + return } - return &r + r = full(i) + ok = bytes.Compare(r.First, ip) <= 0 && bytes.Compare(ip, r.Last) <= 0 + return } // Return the range the given IP is in. Returns nil if no range is found. -func (me *IPList) lookup(ip net.IP) (r *Range) { - return lookup(func(i int) Range { +func (me *IPList) lookup(ip net.IP) (Range, bool) { + return lookup(func(i int) net.IP { + return me.ranges[i].First + }, func(i int) Range { return me.ranges[i] }, len(me.ranges), ip) } diff --git a/iplist/iplist_test.go b/iplist/iplist_test.go index 2b24e7e6..e8d2cbe4 100644 --- a/iplist/iplist_test.go +++ b/iplist/iplist_test.go @@ -95,15 +95,21 @@ func connRemoteAddrIP(network, laddr string, dialHost string) net.IP { return ret } +func lookupOk(r Range, ok bool) bool { + return ok +} + func TestBadIP(t *testing.T) { for _, iplist := range []Ranger{ New(nil), NewFromPacked([]byte("\x00\x00\x00\x00\x00\x00\x00\x00")), } { - assert.Nil(t, iplist.Lookup(net.IP(make([]byte, 4))), "%v", iplist) - assert.Nil(t, iplist.Lookup(net.IP(make([]byte, 16)))) - assert.Equal(t, iplist.Lookup(nil).Description, "bad IP") - assert.NotNil(t, iplist.Lookup(net.IP(make([]byte, 5)))) + assert.False(t, lookupOk(iplist.Lookup(net.IP(make([]byte, 4)))), "%v", iplist) + assert.False(t, lookupOk(iplist.Lookup(net.IP(make([]byte, 16))))) + r, ok := iplist.Lookup(nil) + assert.True(t, ok) + assert.Equal(t, r.Description, "bad IP") + assert.True(t, lookupOk(iplist.Lookup(net.IP(make([]byte, 5))))) } } @@ -123,16 +129,11 @@ func testLookuperSimple(t *testing.T, iplist Ranger) { {"1.2.8.2", true, "eff"}, } { ip := net.ParseIP(_case.IP) - r := iplist.Lookup(ip) + r, ok := iplist.Lookup(ip) + assert.Equal(t, _case.Hit, ok, "%s", _case) if !_case.Hit { - if r != nil { - t.Fatalf("got hit when none was expected: %s", ip) - } continue } - if r == nil { - t.Fatalf("expected hit for %q", _case.IP) - } assert.Equal(t, _case.Desc, r.Description, "%T", iplist) } } diff --git a/iplist/packed.go b/iplist/packed.go index 152c680a..e6ab41ad 100644 --- a/iplist/packed.go +++ b/iplist/packed.go @@ -76,28 +76,37 @@ func (me PackedIPList) NumRanges() int { return me.len() } +func (me PackedIPList) getFirst(i int) net.IP { + off := packedRangesOffset + packedRangeLen*i + return net.IP(me[off : off+4]) +} + func (me PackedIPList) getRange(i int) (ret Range) { rOff := packedRangesOffset + packedRangeLen*i - first := me[rOff : rOff+4] last := me[rOff+4 : rOff+8] descOff := int(binary.LittleEndian.Uint64(me[rOff+8:])) descLen := int(binary.LittleEndian.Uint32(me[rOff+16:])) descOff += packedRangesOffset + packedRangeLen*me.len() - ret = Range{net.IP(first), net.IP(last), string(me[descOff : descOff+descLen])} + ret = Range{ + me.getFirst(i), + net.IP(last), + string(me[descOff : descOff+descLen]), + } return } -func (me PackedIPList) Lookup(ip net.IP) (r *Range) { +func (me PackedIPList) Lookup(ip net.IP) (r Range, ok bool) { ip4 := ip.To4() if ip4 == nil { // If the IP list was built successfully, then it only contained IPv4 // ranges. Therefore no IPv6 ranges are blocked. if ip.To16() == nil { - r = &Range{ + r = Range{ Description: "bad IP", } + ok = true } return } - return lookup(me.getRange, me.len(), ip4) + return lookup(me.getFirst, me.getRange, me.len(), ip4) } -- 2.48.1