]> Sergey Matveev's repositories - btrtrc.git/blobdiff - iplist/iplist.go
Avoid allocation in iplist.Ranger.Lookup
[btrtrc.git] / iplist / iplist.go
index cd97b080312f77dfa7e372a214edbf38a994fff5..9e402daed3e6571ce85fbf28db454da0b6c6b2f8 100644 (file)
@@ -15,7 +15,7 @@ import (
 // An abstraction of IP list implementations.
 type Ranger interface {
        // Return a Range containing the IP.
-       Lookup(net.IP) *Range
+       Lookup(net.IP) (r Range, ok bool)
        // If your ranges hurt, use this.
        NumRanges() int
 }
@@ -50,17 +50,17 @@ func (me *IPList) NumRanges() int {
 }
 
 // Return the range the given IP is in. Returns nil if no range is found.
-func (me *IPList) Lookup(ip net.IP) (r *Range) {
+func (me *IPList) Lookup(ip net.IP) (r Range, ok bool) {
        if me == nil {
-               return nil
+               return
        }
        // TODO: Perhaps all addresses should be converted to IPv6, if the future
        // of IP is to always be backwards compatible. But this will cost 4x the
        // memory for IPv4 addresses?
        v4 := ip.To4()
        if v4 != nil {
-               r = me.lookup(v4)
-               if r != nil {
+               r, ok = me.lookup(v4)
+               if ok {
                        return
                }
        }
@@ -69,37 +69,44 @@ func (me *IPList) Lookup(ip net.IP) (r *Range) {
                return me.lookup(v6)
        }
        if v4 == nil && v6 == nil {
-               return &Range{
+               r = Range{
                        Description: "bad IP",
                }
+               ok = true
        }
-       return nil
+       return
 }
 
 // Return a range that contains ip, or nil.
-func lookup(f func(i int) Range, n int, ip net.IP) *Range {
+func lookup(
+       first func(i int) net.IP,
+       full func(i int) Range,
+       n int,
+       ip net.IP,
+) (
+       r Range, ok bool,
+) {
        // Find the index of the first range for which the following range exceeds
        // it.
        i := sort.Search(n, func(i int) bool {
                if i+1 >= n {
                        return true
                }
-               r := f(i + 1)
-               return bytes.Compare(ip, r.First) < 0
+               return bytes.Compare(ip, first(i+1)) < 0
        })
        if i == n {
-               return nil
-       }
-       r := f(i)
-       if bytes.Compare(ip, r.First) < 0 || bytes.Compare(ip, r.Last) > 0 {
-               return nil
+               return
        }
-       return &r
+       r = full(i)
+       ok = bytes.Compare(r.First, ip) <= 0 && bytes.Compare(ip, r.Last) <= 0
+       return
 }
 
 // Return the range the given IP is in. Returns nil if no range is found.
-func (me *IPList) lookup(ip net.IP) (r *Range) {
-       return lookup(func(i int) Range {
+func (me *IPList) lookup(ip net.IP) (Range, bool) {
+       return lookup(func(i int) net.IP {
+               return me.ranges[i].First
+       }, func(i int) Range {
                return me.ranges[i]
        }, len(me.ranges), ip)
 }