util/addr.go | 34 ++++++++++++++++++++++------------ diff --git a/util/addr.go b/util/addr.go index a91dec3fd6d95f90529544b0af17b42f12ca04fe..b6ec49e13bf1bb8cc3e6b85a184a88f634d5f622 100644 --- a/util/addr.go +++ b/util/addr.go @@ -7,21 +7,31 @@ ) // Extracts the port as an integer from an address string. func AddrPort(addr net.Addr) int { - _, port, err := net.SplitHostPort(addr.String()) - if err != nil { - panic(err) - } - i64, err := strconv.ParseInt(port, 0, 0) - if err != nil { - panic(err) + switch raw := addr.(type) { + case *net.UDPAddr: + return raw.Port + default: + _, port, err := net.SplitHostPort(addr.String()) + if err != nil { + panic(err) + } + i64, err := strconv.ParseInt(port, 0, 0) + if err != nil { + panic(err) + } + return int(i64) } - return int(i64) } func AddrIP(addr net.Addr) net.IP { - host, _, err := net.SplitHostPort(addr.String()) - if err != nil { - panic(err) + switch raw := addr.(type) { + case *net.UDPAddr: + return raw.IP + default: + host, _, err := net.SplitHostPort(addr.String()) + if err != nil { + panic(err) + } + return net.ParseIP(host) } - return net.ParseIP(host) }