]> Sergey Matveev's repositories - btrtrc.git/blobdiff - iplist/iplist.go
Update to multiple-blobs-per-value squirrel
[btrtrc.git] / iplist / iplist.go
index 18d6b0a383c4a8455161fa355b2bcf6d28652acc..d6d70a96256653224e9227f1fbfe6684a7bab28d 100644 (file)
@@ -12,6 +12,14 @@ import (
        "sort"
 )
 
+// An abstraction of IP list implementations.
+type Ranger interface {
+       // Return a Range containing the IP.
+       Lookup(net.IP) (r Range, ok bool)
+       // If your ranges hurt, use this.
+       NumRanges() int
+}
+
 type IPList struct {
        ranges []Range
 }
@@ -21,8 +29,8 @@ type Range struct {
        Description string
 }
 
-func (r *Range) String() string {
-       return fmt.Sprintf("%s-%s (%s)", r.First, r.Last, r.Description)
+func (r Range) String() string {
+       return fmt.Sprintf("%s-%s: %s", r.First, r.Last, r.Description)
 }
 
 // Create a new IP list. The given ranges must already sorted by the lower
@@ -34,60 +42,75 @@ func New(initSorted []Range) *IPList {
        }
 }
 
-func (me *IPList) NumRanges() int {
-       if me == nil {
+func (ipl *IPList) NumRanges() int {
+       if ipl == nil {
                return 0
        }
-       return len(me.ranges)
+       return len(ipl.ranges)
 }
 
-// Return the range the given IP is in. Returns nil if no range is found.
-func (me *IPList) Lookup(ip net.IP) (r *Range) {
-       if me == nil {
-               return nil
+// Return the range the given IP is in. ok if false if no range is found.
+func (ipl *IPList) Lookup(ip net.IP) (r Range, ok bool) {
+       if ipl == nil {
+               return
        }
        // TODO: Perhaps all addresses should be converted to IPv6, if the future
        // of IP is to always be backwards compatible. But this will cost 4x the
        // memory for IPv4 addresses?
        v4 := ip.To4()
        if v4 != nil {
-               r = me.lookup(v4)
-               if r != nil {
+               r, ok = ipl.lookup(v4)
+               if ok {
                        return
                }
        }
        v6 := ip.To16()
        if v6 != nil {
-               return me.lookup(v6)
+               return ipl.lookup(v6)
        }
        if v4 == nil && v6 == nil {
-               return &Range{
-                       Description: fmt.Sprintf("unsupported IP: %s", ip),
+               r = Range{
+                       Description: "bad IP",
                }
+               ok = true
        }
-       return nil
+       return
 }
 
-// Return the range the given IP is in. Returns nil if no range is found.
-func (me *IPList) lookup(ip net.IP) (r *Range) {
+// Return a range that contains ip, or nil.
+func lookup(
+       first func(i int) net.IP,
+       full func(i int) Range,
+       n int,
+       ip net.IP,
+) (
+       r Range, ok bool,
+) {
        // Find the index of the first range for which the following range exceeds
        // it.
-       i := sort.Search(len(me.ranges), func(i int) bool {
-               if i+1 >= len(me.ranges) {
+       i := sort.Search(n, func(i int) bool {
+               if i+1 >= n {
                        return true
                }
-               return bytes.Compare(ip, me.ranges[i+1].First) < 0
+               return bytes.Compare(ip, first(i+1)) < 0
        })
-       if i == len(me.ranges) {
+       if i == n {
                return
        }
-       r = &me.ranges[i]
-       if bytes.Compare(ip, r.First) < 0 || bytes.Compare(ip, r.Last) > 0 {
-               r = nil
-       }
+       r = full(i)
+       ok = bytes.Compare(r.First, ip) <= 0 && bytes.Compare(ip, r.Last) <= 0
        return
 }
 
+// Return the range the given IP is in. Returns nil if no range is found.
+func (ipl *IPList) lookup(ip net.IP) (Range, bool) {
+       return lookup(func(i int) net.IP {
+               return ipl.ranges[i].First
+       }, func(i int) Range {
+               return ipl.ranges[i]
+       }, len(ipl.ranges), ip)
+}
+
 func minifyIP(ip *net.IP) {
        v4 := ip.To4()
        if v4 != nil {