]> Sergey Matveev's repositories - btrtrc.git/commitdiff
iplist: Fail invalid IPs, they were always passing
authorMatt Joiner <anacrolix@gmail.com>
Fri, 27 Mar 2015 15:54:17 +0000 (02:54 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Fri, 27 Mar 2015 15:54:17 +0000 (02:54 +1100)
iplist/iplist.go
iplist/iplist_test.go

index 958af5a94d5fe00897373748e3164e73075cc9f8..93fb9b31ad8bfb7a63e5bcbf90eb18d4d2b04c5b 100644 (file)
@@ -45,15 +45,22 @@ func (me *IPList) Lookup(ip net.IP) (r *Range) {
        // TODO: Perhaps all addresses should be converted to IPv6, if the future
        // of IP is to always be backwards compatible. But this will cost 4x the
        // memory for IPv4 addresses?
-       if v4 := ip.To4(); v4 != nil {
+       v4 := ip.To4()
+       if v4 != nil {
                r = me.lookup(v4)
                if r != nil {
                        return
                }
        }
-       if v6 := ip.To16(); v6 != nil {
+       v6 := ip.To16()
+       if v6 != nil {
                return me.lookup(v6)
        }
+       if v4 == nil && v6 == nil {
+               return &Range{
+                       Description: fmt.Sprintf("unsupported IP: %s", ip),
+               }
+       }
        return nil
 }
 
index 09d216fc2169d66bf2d58f472d5e333958683d50..f21ecc3dbad632730b79adc69e5d4c05e27b1bd7 100644 (file)
@@ -73,6 +73,22 @@ func connRemoteAddrIP(network, laddr string, dialHost string) net.IP {
        return ret
 }
 
+func TestBadIP(t *testing.T) {
+       iplist := New(nil)
+       if iplist.Lookup(net.IP(make([]byte, 4))) != nil {
+               t.FailNow()
+       }
+       if iplist.Lookup(net.IP(make([]byte, 16))) != nil {
+               t.FailNow()
+       }
+       if iplist.Lookup(nil) == nil {
+               t.FailNow()
+       }
+       if iplist.Lookup(net.IP(make([]byte, 5))) == nil {
+               t.FailNow()
+       }
+}
+
 func TestSimple(t *testing.T) {
        ranges, err := sampleRanges(t)
        if err != nil {
@@ -90,14 +106,16 @@ func TestSimple(t *testing.T) {
                {"1.2.3.255", false, ""},
                {"1.2.8.0", true, "b"},
                {"1.2.4.255", true, "a"},
-               // Try to roll over to the next octet on the parse.
-               {"1.2.7.256", false, ""},
+               // Try to roll over to the next octet on the parse. Note the final
+               // octet is overbounds. In the next case.
+               {"1.2.7.256", true, "unsupported IP: <nil>"},
                {"1.2.8.254", true, "b"},
        } {
-               r := iplist.Lookup(net.ParseIP(_case.IP))
+               ip := net.ParseIP(_case.IP)
+               r := iplist.Lookup(ip)
                if !_case.Hit {
                        if r != nil {
-                               t.Fatalf("got hit when none was expected")
+                               t.Fatalf("got hit when none was expected: %s", ip)
                        }
                        continue
                }