From 4939dd4e5764a44a4eb148ae1f430787835822db Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Sun, 17 Jun 2018 16:21:57 +1000 Subject: [PATCH] Add some tests for net.Addr.Network values in various situations --- network_test.go | 75 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 network_test.go diff --git a/network_test.go b/network_test.go new file mode 100644 index 00000000..32846e4a --- /dev/null +++ b/network_test.go @@ -0,0 +1,75 @@ +package torrent + +import ( + "log" + "net" + "testing" + + "github.com/anacrolix/missinggo" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testListenerNetwork( + t *testing.T, + listenFunc func(net, addr string) (net.Listener, error), + expectedNet, givenNet, addr string, validIp4 bool, +) { + l, err := listenFunc(givenNet, addr) + require.NoError(t, err) + defer l.Close() + assert.EqualValues(t, expectedNet, l.Addr().Network()) + log.Print(missinggo.AddrIP(l.Addr())) + assert.Equal(t, validIp4, missinggo.AddrIP(l.Addr()).To4() != nil) +} + +func listenUtpListener(net, addr string) (l net.Listener, err error) { + l, err = NewUtpSocket(net, addr) + return +} + +func testAcceptedConnAddr( + t *testing.T, + network string, valid4 bool, + dial func(addr string) (net.Conn, error), + listen func() (net.Listener, error), +) { + l, err := listen() + require.NoError(t, err) + defer l.Close() + done := make(chan struct{}) + defer close(done) + go func() { + c, err := dial(l.Addr().String()) + require.NoError(t, err) + <-done + c.Close() + }() + c, err := l.Accept() + require.NoError(t, err) + defer c.Close() + assert.EqualValues(t, network, c.RemoteAddr().Network()) + assert.Equal(t, valid4, missinggo.AddrIP(c.RemoteAddr()).To4() != nil) +} + +func listenClosure(rawListenFunc func(string, string) (net.Listener, error), network, addr string) func() (net.Listener, error) { + return func() (net.Listener, error) { + return rawListenFunc(network, addr) + } +} + +func dialClosure(f func(net, addr string) (net.Conn, error), network string) func(addr string) (net.Conn, error) { + return func(addr string) (net.Conn, error) { + return f(network, addr) + } +} + +func TestListenLocalhostNetwork(t *testing.T) { + testListenerNetwork(t, net.Listen, "tcp", "tcp", "0.0.0.0:0", false) + testListenerNetwork(t, net.Listen, "tcp", "tcp", "[::1]:0", false) + testListenerNetwork(t, listenUtpListener, "udp", "udp6", "[::1]:0", false) + testListenerNetwork(t, listenUtpListener, "udp", "udp6", "[::]:0", false) + testListenerNetwork(t, listenUtpListener, "udp", "udp4", "localhost:0", true) + + testAcceptedConnAddr(t, "tcp", false, dialClosure(net.Dial, "tcp"), listenClosure(net.Listen, "tcp6", ":0")) +} -- 2.50.0