]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Avoid allocation in iplist.Ranger.Lookup
authorMatt Joiner <anacrolix@gmail.com>
Sun, 18 Oct 2015 13:00:26 +0000 (00:00 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Sun, 18 Oct 2015 13:00:26 +0000 (00:00 +1100)
It was very expensive for DHT packets.

client.go
dht/dht.go
iplist/iplist.go
iplist/iplist_test.go
iplist/packed.go

index 121b0849ab65e542d8d07ff6069bd5652445e722..0385b4fa5b8d0d4b1ac0f751a075cb32fe97eab5 100644 (file)
--- a/client.go
+++ b/client.go
@@ -581,18 +581,20 @@ func (me *Client) Close() {
 
 var ipv6BlockRange = iplist.Range{Description: "non-IPv4 address"}
 
-func (cl *Client) ipBlockRange(ip net.IP) (r *iplist.Range) {
+func (cl *Client) ipBlockRange(ip net.IP) (r iplist.Range, blocked bool) {
        if cl.ipBlockList == nil {
                return
        }
        ip4 := ip.To4()
+       // If blocklists are enabled, then block non-IPv4 addresses, because
+       // blocklists do not yet support IPv6.
        if ip4 == nil {
                log.Printf("blocking non-IPv4 address: %s", ip)
-               r = &ipv6BlockRange
+               r = ipv6BlockRange
+               blocked = true
                return
        }
-       r = cl.ipBlockList.Lookup(ip4)
-       return
+       return cl.ipBlockList.Lookup(ip4)
 }
 
 func (cl *Client) waitAccept() {
@@ -638,9 +640,9 @@ func (cl *Client) acceptConnections(l net.Listener, utp bool) {
                }
                cl.mu.RLock()
                doppleganger := cl.dopplegangerAddr(conn.RemoteAddr().String())
-               blockRange := cl.ipBlockRange(AddrIP(conn.RemoteAddr()))
+               _, blocked := cl.ipBlockRange(AddrIP(conn.RemoteAddr()))
                cl.mu.RUnlock()
-               if blockRange != nil || doppleganger {
+               if blocked || doppleganger {
                        acceptReject.Add(1)
                        // log.Printf("inbound connection from %s blocked by %s", conn.RemoteAddr(), blockRange)
                        conn.Close()
@@ -728,7 +730,7 @@ func (me *Client) initiateConn(peer Peer, t *torrent) {
                duplicateConnsAvoided.Add(1)
                return
        }
-       if r := me.ipBlockRange(peer.IP); r != nil {
+       if r, ok := me.ipBlockRange(peer.IP); ok {
                log.Printf("outbound connect to %s blocked by IP blocklist rule %s", peer.IP, r)
                return
        }
@@ -1846,7 +1848,7 @@ func (me *Client) addPeers(t *torrent, peers []Peer) {
                )) {
                        continue
                }
-               if me.ipBlockRange(p.IP) != nil {
+               if _, ok := me.ipBlockRange(p.IP); ok {
                        continue
                }
                if p.Port == 0 {
@@ -2337,13 +2339,9 @@ func (cl *Client) trackerBlockedUnlocked(tr tracker.Client) (blocked bool, err e
        if err != nil {
                return
        }
-       cl.mu.Lock()
-       if cl.ipBlockList != nil {
-               if cl.ipBlockRange(addr.IP) != nil {
-                       blocked = true
-               }
-       }
-       cl.mu.Unlock()
+       cl.mu.RLock()
+       _, blocked = cl.ipBlockRange(addr.IP)
+       cl.mu.RUnlock()
        return
 }
 
index d7b6cc171cd414b3172e20c1db45f9d3b885707e..5d3789bfc60d5e0dd5e77d2580ceec9079318605 100644 (file)
@@ -688,11 +688,12 @@ func (s *Server) serve() error {
        }
 }
 
-func (s *Server) ipBlocked(ip net.IP) bool {
+func (s *Server) ipBlocked(ip net.IP) (blocked bool) {
        if s.ipBlockList == nil {
-               return false
+               return
        }
-       return s.ipBlockList.Lookup(ip) != nil
+       _, blocked = s.ipBlockList.Lookup(ip)
+       return
 }
 
 // Adds directly to the node table.
index cd97b080312f77dfa7e372a214edbf38a994fff5..9e402daed3e6571ce85fbf28db454da0b6c6b2f8 100644 (file)
@@ -15,7 +15,7 @@ import (
 // An abstraction of IP list implementations.
 type Ranger interface {
        // Return a Range containing the IP.
-       Lookup(net.IP) *Range
+       Lookup(net.IP) (r Range, ok bool)
        // If your ranges hurt, use this.
        NumRanges() int
 }
@@ -50,17 +50,17 @@ func (me *IPList) NumRanges() int {
 }
 
 // Return the range the given IP is in. Returns nil if no range is found.
-func (me *IPList) Lookup(ip net.IP) (r *Range) {
+func (me *IPList) Lookup(ip net.IP) (r Range, ok bool) {
        if me == nil {
-               return nil
+               return
        }
        // TODO: Perhaps all addresses should be converted to IPv6, if the future
        // of IP is to always be backwards compatible. But this will cost 4x the
        // memory for IPv4 addresses?
        v4 := ip.To4()
        if v4 != nil {
-               r = me.lookup(v4)
-               if r != nil {
+               r, ok = me.lookup(v4)
+               if ok {
                        return
                }
        }
@@ -69,37 +69,44 @@ func (me *IPList) Lookup(ip net.IP) (r *Range) {
                return me.lookup(v6)
        }
        if v4 == nil && v6 == nil {
-               return &Range{
+               r = Range{
                        Description: "bad IP",
                }
+               ok = true
        }
-       return nil
+       return
 }
 
 // Return a range that contains ip, or nil.
-func lookup(f func(i int) Range, n int, ip net.IP) *Range {
+func lookup(
+       first func(i int) net.IP,
+       full func(i int) Range,
+       n int,
+       ip net.IP,
+) (
+       r Range, ok bool,
+) {
        // Find the index of the first range for which the following range exceeds
        // it.
        i := sort.Search(n, func(i int) bool {
                if i+1 >= n {
                        return true
                }
-               r := f(i + 1)
-               return bytes.Compare(ip, r.First) < 0
+               return bytes.Compare(ip, first(i+1)) < 0
        })
        if i == n {
-               return nil
-       }
-       r := f(i)
-       if bytes.Compare(ip, r.First) < 0 || bytes.Compare(ip, r.Last) > 0 {
-               return nil
+               return
        }
-       return &r
+       r = full(i)
+       ok = bytes.Compare(r.First, ip) <= 0 && bytes.Compare(ip, r.Last) <= 0
+       return
 }
 
 // Return the range the given IP is in. Returns nil if no range is found.
-func (me *IPList) lookup(ip net.IP) (r *Range) {
-       return lookup(func(i int) Range {
+func (me *IPList) lookup(ip net.IP) (Range, bool) {
+       return lookup(func(i int) net.IP {
+               return me.ranges[i].First
+       }, func(i int) Range {
                return me.ranges[i]
        }, len(me.ranges), ip)
 }
index 2b24e7e644939cabe570930778a4a066bf337ed4..e8d2cbe46fcd93e651ef425f42e0ca1cedcb2443 100644 (file)
@@ -95,15 +95,21 @@ func connRemoteAddrIP(network, laddr string, dialHost string) net.IP {
        return ret
 }
 
+func lookupOk(r Range, ok bool) bool {
+       return ok
+}
+
 func TestBadIP(t *testing.T) {
        for _, iplist := range []Ranger{
                New(nil),
                NewFromPacked([]byte("\x00\x00\x00\x00\x00\x00\x00\x00")),
        } {
-               assert.Nil(t, iplist.Lookup(net.IP(make([]byte, 4))), "%v", iplist)
-               assert.Nil(t, iplist.Lookup(net.IP(make([]byte, 16))))
-               assert.Equal(t, iplist.Lookup(nil).Description, "bad IP")
-               assert.NotNil(t, iplist.Lookup(net.IP(make([]byte, 5))))
+               assert.False(t, lookupOk(iplist.Lookup(net.IP(make([]byte, 4)))), "%v", iplist)
+               assert.False(t, lookupOk(iplist.Lookup(net.IP(make([]byte, 16)))))
+               r, ok := iplist.Lookup(nil)
+               assert.True(t, ok)
+               assert.Equal(t, r.Description, "bad IP")
+               assert.True(t, lookupOk(iplist.Lookup(net.IP(make([]byte, 5)))))
        }
 }
 
@@ -123,16 +129,11 @@ func testLookuperSimple(t *testing.T, iplist Ranger) {
                {"1.2.8.2", true, "eff"},
        } {
                ip := net.ParseIP(_case.IP)
-               r := iplist.Lookup(ip)
+               r, ok := iplist.Lookup(ip)
+               assert.Equal(t, _case.Hit, ok, "%s", _case)
                if !_case.Hit {
-                       if r != nil {
-                               t.Fatalf("got hit when none was expected: %s", ip)
-                       }
                        continue
                }
-               if r == nil {
-                       t.Fatalf("expected hit for %q", _case.IP)
-               }
                assert.Equal(t, _case.Desc, r.Description, "%T", iplist)
        }
 }
index 152c680a18590ff9cc7abac251b5743cdd482b05..e6ab41ad23506afb95a9b270df6693e465a30753 100644 (file)
@@ -76,28 +76,37 @@ func (me PackedIPList) NumRanges() int {
        return me.len()
 }
 
+func (me PackedIPList) getFirst(i int) net.IP {
+       off := packedRangesOffset + packedRangeLen*i
+       return net.IP(me[off : off+4])
+}
+
 func (me PackedIPList) getRange(i int) (ret Range) {
        rOff := packedRangesOffset + packedRangeLen*i
-       first := me[rOff : rOff+4]
        last := me[rOff+4 : rOff+8]
        descOff := int(binary.LittleEndian.Uint64(me[rOff+8:]))
        descLen := int(binary.LittleEndian.Uint32(me[rOff+16:]))
        descOff += packedRangesOffset + packedRangeLen*me.len()
-       ret = Range{net.IP(first), net.IP(last), string(me[descOff : descOff+descLen])}
+       ret = Range{
+               me.getFirst(i),
+               net.IP(last),
+               string(me[descOff : descOff+descLen]),
+       }
        return
 }
 
-func (me PackedIPList) Lookup(ip net.IP) (r *Range) {
+func (me PackedIPList) Lookup(ip net.IP) (r Range, ok bool) {
        ip4 := ip.To4()
        if ip4 == nil {
                // If the IP list was built successfully, then it only contained IPv4
                // ranges. Therefore no IPv6 ranges are blocked.
                if ip.To16() == nil {
-                       r = &Range{
+                       r = Range{
                                Description: "bad IP",
                        }
+                       ok = true
                }
                return
        }
-       return lookup(me.getRange, me.len(), ip4)
+       return lookup(me.getFirst, me.getRange, me.len(), ip4)
 }