]> Sergey Matveev's repositories - btrtrc.git/blobdiff - iplist/packed.go
Buffer metainfo loads from files
[btrtrc.git] / iplist / packed.go
index b5c0a5e3429eb61ee4e5edb0f3667e053731e2ed..5ae1faea2b88a4cde89804d4e58b4dbbfe691442 100644 (file)
@@ -1,7 +1,11 @@
+//go:build !wasm
+// +build !wasm
+
 package iplist
 
 import (
        "encoding/binary"
+       "fmt"
        "io"
        "net"
        "os"
@@ -18,7 +22,7 @@ import (
 
 const (
        packedRangesOffset = 8
-       packedRangeLen     = 20
+       packedRangeLen     = 44
 )
 
 func (ipl *IPList) WritePacked(w io.Writer) (err error) {
@@ -43,8 +47,8 @@ func (ipl *IPList) WritePacked(w io.Writer) (err error) {
        binary.LittleEndian.PutUint64(b[:], uint64(len(ipl.ranges)))
        write(b[:], 8)
        for _, r := range ipl.ranges {
-               write(r.First.To4(), 4)
-               write(r.Last.To4(), 4)
+               write(r.First.To16(), 16)
+               write(r.Last.To16(), 16)
                descOff, ok := descOffsets[r.Description]
                if !ok {
                        descOff = nextOffset
@@ -64,7 +68,12 @@ func (ipl *IPList) WritePacked(w io.Writer) (err error) {
 }
 
 func NewFromPacked(b []byte) PackedIPList {
-       return PackedIPList(b)
+       ret := PackedIPList(b)
+       minLen := packedRangesOffset + ret.len()*packedRangeLen
+       if len(b) < minLen {
+               panic(fmt.Sprintf("packed len %d < %d", len(b), minLen))
+       }
+       return ret
 }
 
 type PackedIPList []byte
@@ -81,14 +90,14 @@ func (pil PackedIPList) NumRanges() int {
 
 func (pil PackedIPList) getFirst(i int) net.IP {
        off := packedRangesOffset + packedRangeLen*i
-       return net.IP(pil[off : off+4])
+       return net.IP(pil[off : off+16])
 }
 
 func (pil PackedIPList) getRange(i int) (ret Range) {
        rOff := packedRangesOffset + packedRangeLen*i
-       last := pil[rOff+4 : rOff+8]
-       descOff := int(binary.LittleEndian.Uint64(pil[rOff+8:]))
-       descLen := int(binary.LittleEndian.Uint32(pil[rOff+16:]))
+       last := pil[rOff+16 : rOff+32]
+       descOff := int(binary.LittleEndian.Uint64(pil[rOff+32:]))
+       descLen := int(binary.LittleEndian.Uint32(pil[rOff+40:]))
        descOff += packedRangesOffset + packedRangeLen*pil.len()
        ret = Range{
                pil.getFirst(i),
@@ -99,27 +108,27 @@ func (pil PackedIPList) getRange(i int) (ret Range) {
 }
 
 func (pil PackedIPList) Lookup(ip net.IP) (r Range, ok bool) {
-       ip4 := ip.To4()
-       if ip4 == nil {
-               // If the IP list was built successfully, then it only contained IPv4
-               // ranges. Therefore no IPv6 ranges are blocked.
-               if ip.To16() == nil {
-                       r = Range{
-                               Description: "bad IP",
-                       }
-                       ok = true
-               }
-               return
+       ip16 := ip.To16()
+       if ip16 == nil {
+               panic(ip)
        }
-       return lookup(pil.getFirst, pil.getRange, pil.len(), ip4)
+       return lookup(pil.getFirst, pil.getRange, pil.len(), ip16)
 }
 
-func MMapPacked(filename string) (ret Ranger, err error) {
+type closerFunc func() error
+
+func (me closerFunc) Close() error {
+       return me()
+}
+
+func MMapPackedFile(filename string) (
+       ret interface {
+               Ranger
+               io.Closer
+       },
+       err error,
+) {
        f, err := os.Open(filename)
-       if os.IsNotExist(err) {
-               err = nil
-               return
-       }
        if err != nil {
                return
        }
@@ -128,6 +137,9 @@ func MMapPacked(filename string) (ret Ranger, err error) {
        if err != nil {
                return
        }
-       ret = NewFromPacked(mm)
+       ret = struct {
+               Ranger
+               io.Closer
+       }{NewFromPacked(mm), closerFunc(mm.Unmap)}
        return
 }