]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Add support for non-IP-based networks
authorMatt Joiner <anacrolix@gmail.com>
Thu, 20 Feb 2020 05:47:37 +0000 (16:47 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Thu, 20 Feb 2020 05:47:37 +0000 (16:47 +1100)
Includes a test with unix sockets. Exposes AddDialer, AddListener, and reworks Peer.

20 files changed:
Peer.go
Peers.go
client.go
client_test.go
cmd/torrent-pick/main.go
cmd/torrent/main.go
cmd/torrentfs/main.go
config.go
connection.go
connection_test.go
dialer.go [new file with mode: 0644]
go.mod
ipport.go [new file with mode: 0644]
misc.go
prioritized_peers_test.go
socket.go
test/init_test.go [new file with mode: 0644]
test/transfer_test.go
test/unix_test.go [new file with mode: 0644]
torrent.go

diff --git a/Peer.go b/Peer.go
index 02b2c31aec76ef98af52bc8aa4cc6aa095d2b4d2..c71aff211bd79431cf519845b1a2b71b0ccad05f 100644 (file)
--- a/Peer.go
+++ b/Peer.go
@@ -11,8 +11,7 @@ import (
 // Peer connection info, handed about publicly.
 type Peer struct {
        Id     [20]byte
-       IP     net.IP
-       Port   int
+       Addr   net.Addr
        Source peerSource
        // Peer is known to support encryption.
        SupportsEncryption bool
@@ -23,8 +22,7 @@ type Peer struct {
 
 // FromPex generate Peer from peer exchange
 func (me *Peer) FromPex(na krpc.NodeAddr, fs peer_protocol.PexPeerFlags) {
-       me.IP = append([]byte(nil), na.IP...)
-       me.Port = na.Port
+       me.Addr = ipPortAddr{append([]byte(nil), na.IP...), na.Port}
        me.Source = peerSourcePex
        // If they prefer encryption, they must support it.
        if fs.Get(peer_protocol.PexPrefersEncryption) {
@@ -34,5 +32,5 @@ func (me *Peer) FromPex(na krpc.NodeAddr, fs peer_protocol.PexPeerFlags) {
 }
 
 func (me Peer) addr() IpPort {
-       return IpPort{IP: me.IP, Port: uint16(me.Port)}
+       return IpPort{IP: addrIpOrNil(me.Addr), Port: uint16(addrPortOrZero(me.Addr))}
 }
index be70203268675cb6cd417c5ae55ee4724808a096..0c2d726f70bb05f9cb5b4bb13fbd9083e3f8c59c 100644 (file)
--- a/Peers.go
+++ b/Peers.go
@@ -24,8 +24,7 @@ func (me *Peers) AppendFromPex(nas []krpc.NodeAddr, fs []peer_protocol.PexPeerFl
 func (ret Peers) AppendFromTracker(ps []tracker.Peer) Peers {
        for _, p := range ps {
                _p := Peer{
-                       IP:     p.IP,
-                       Port:   p.Port,
+                       Addr:   ipPortAddr{p.IP, p.Port},
                        Source: peerSourceTracker,
                }
                copy(_p.Id[:], p.ID)
index 60bd27e88b535189147ec5997a281c7b051d89ed..451788c55289f7f5fbc2379f2681e136c693e422 100644 (file)
--- a/client.go
+++ b/client.go
@@ -58,8 +58,8 @@ type Client struct {
        peerID         PeerID
        defaultStorage *storage.Client
        onClose        []func()
-       dialers        []dialer
-       listeners      []listener
+       dialers        []Dialer
+       listeners      []Listener
        dhtServers     []*dht.Server
        ipBlockList    iplist.Ranger
        // Our BitTorrent protocol extension bytes, sent in our BT handshakes.
@@ -92,18 +92,13 @@ func (cl *Client) PeerID() PeerID {
        return cl.peerID
 }
 
+// Returns the port number for the first listener that has one. No longer assumes that all port
+// numbers are the same, due to support for custom listeners. Returns zero if no port number is
+// found.
 func (cl *Client) LocalPort() (port int) {
-       cl.eachListener(func(l listener) bool {
-               _port := missinggo.AddrPort(l.Addr())
-               if _port == 0 {
-                       panic(l)
-               }
-               if port == 0 {
-                       port = _port
-               } else if port != _port {
-                       panic("mismatched ports")
-               }
-               return true
+       cl.eachListener(func(l Listener) bool {
+               port = addrPortOrZero(l.Addr())
+               return port == 0
        })
        return
 }
@@ -263,6 +258,19 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
        return
 }
 
+// Adds a Dialer for outgoing connections. All Dialers are used when attempting to connect to a
+// given address for any Torrent.
+func (cl *Client) AddDialer(d Dialer) {
+       cl.dialers = append(cl.dialers, d)
+}
+
+// Registers a Listener, and starts Accepting on it. You must Close Listeners provided this way
+// yourself.
+func (cl *Client) AddListener(l Listener) {
+       cl.listeners = append(cl.listeners, l)
+       go cl.acceptConnections(l)
+}
+
 func (cl *Client) firewallCallback(net.Addr) bool {
        cl.rLock()
        block := !cl.wantConns()
@@ -389,24 +397,26 @@ func (cl *Client) waitAccept() {
        }
 }
 
+// TODO: Apply filters for non-standard networks, particularly rate-limiting.
 func (cl *Client) rejectAccepted(conn net.Conn) error {
        ra := conn.RemoteAddr()
-       rip := missinggo.AddrIP(ra)
-       if cl.config.DisableIPv4Peers && rip.To4() != nil {
-               return errors.New("ipv4 peers disabled")
-       }
-       if cl.config.DisableIPv4 && len(rip) == net.IPv4len {
-               return errors.New("ipv4 disabled")
+       if rip := addrIpOrNil(ra); rip != nil {
+               if cl.config.DisableIPv4Peers && rip.To4() != nil {
+                       return errors.New("ipv4 peers disabled")
+               }
+               if cl.config.DisableIPv4 && len(rip) == net.IPv4len {
+                       return errors.New("ipv4 disabled")
 
-       }
-       if cl.config.DisableIPv6 && len(rip) == net.IPv6len && rip.To4() == nil {
-               return errors.New("ipv6 disabled")
-       }
-       if cl.rateLimitAccept(rip) {
-               return errors.New("source IP accepted rate limited")
-       }
-       if cl.badPeerIPPort(rip, missinggo.AddrPort(ra)) {
-               return errors.New("bad source addr")
+               }
+               if cl.config.DisableIPv6 && len(rip) == net.IPv6len && rip.To4() == nil {
+                       return errors.New("ipv6 disabled")
+               }
+               if cl.rateLimitAccept(rip) {
+                       return errors.New("source IP accepted rate limited")
+               }
+               if cl.badPeerIPPort(rip, missinggo.AddrPort(ra)) {
+                       return errors.New("bad source addr")
+               }
        }
        return nil
 }
@@ -441,8 +451,12 @@ func (cl *Client) acceptConnections(l net.Listener) {
                        } else {
                                go cl.incomingConnection(conn)
                        }
-                       log.Fmsg("accepted %s connection from %s", conn.RemoteAddr().Network(), conn.RemoteAddr()).AddValue(debugLogValue).Log(cl.logger)
-                       torrent.Add(fmt.Sprintf("accepted conn remote IP len=%d", len(missinggo.AddrIP(conn.RemoteAddr()))), 1)
+                       log.Fmsg("accepted %q connection at %q from %q",
+                               l.Addr().Network(),
+                               conn.LocalAddr(),
+                               conn.RemoteAddr(),
+                       ).AddValue(debugLogValue).Log(cl.logger)
+                       torrent.Add(fmt.Sprintf("accepted conn remote IP len=%d", len(addrIpOrNil(conn.RemoteAddr()))), 1)
                        torrent.Add(fmt.Sprintf("accepted conn network=%s", conn.RemoteAddr().Network()), 1)
                        torrent.Add(fmt.Sprintf("accepted on %s listener", l.Addr().Network()), 1)
                }()
@@ -454,7 +468,7 @@ func (cl *Client) incomingConnection(nc net.Conn) {
        if tc, ok := nc.(*net.TCPConn); ok {
                tc.SetLinger(0)
        }
-       c := cl.newConnection(nc, false, missinggo.IpPortFromNetAddr(nc.RemoteAddr()), nc.RemoteAddr().Network())
+       c := cl.newConnection(nc, false, nc.RemoteAddr(), nc.RemoteAddr().Network())
        c.Discovery = peerSourceIncoming
        cl.runReceivedConn(c)
 }
@@ -518,7 +532,7 @@ func (cl *Client) dialFirst(ctx context.Context, addr string) (res dialResult) {
        func() {
                cl.lock()
                defer cl.unlock()
-               cl.eachDialer(func(s dialer) bool {
+               cl.eachDialer(func(s Dialer) bool {
                        func() {
                                left++
                                //cl.logger.Printf("dialing %s on %s/%s", addr, s.Addr().Network(), s.Addr())
@@ -559,7 +573,7 @@ func (cl *Client) dialFirst(ctx context.Context, addr string) (res dialResult) {
        return res
 }
 
-func (cl *Client) dialFromSocket(ctx context.Context, s dialer, addr string) net.Conn {
+func (cl *Client) dialFromSocket(ctx context.Context, s Dialer, addr string) net.Conn {
        network := s.LocalAddr().Network()
        cte := cl.config.ConnTracker.Wait(
                ctx,
@@ -575,7 +589,7 @@ func (cl *Client) dialFromSocket(ctx context.Context, s dialer, addr string) net
                }
                return nil
        }
-       c, err := s.dial(ctx, addr)
+       c, err := s.Dial(ctx, addr)
        // This is a bit optimistic, but it looks non-trivial to thread this through the proxy code. Set
        // it now in case we close the connection forthwith.
        if tc, ok := c.(*net.TCPConn); ok {
@@ -611,7 +625,7 @@ func (cl *Client) noLongerHalfOpen(t *Torrent, addr string) {
 
 // Performs initiator handshakes and returns a connection. Returns nil
 // *connection if no connection for valid reasons.
-func (cl *Client) handshakesConnection(ctx context.Context, nc net.Conn, t *Torrent, encryptHeader bool, remoteAddr IpPort, network string) (c *connection, err error) {
+func (cl *Client) handshakesConnection(ctx context.Context, nc net.Conn, t *Torrent, encryptHeader bool, remoteAddr net.Addr, network string) (c *connection, err error) {
        c = cl.newConnection(nc, true, remoteAddr, network)
        c.headerEncrypted = encryptHeader
        ctx, cancel := context.WithTimeout(ctx, cl.config.HandshakesTimeout)
@@ -630,7 +644,7 @@ func (cl *Client) handshakesConnection(ctx context.Context, nc net.Conn, t *Torr
 
 // Returns nil connection and nil error if no connection could be established
 // for valid reasons.
-func (cl *Client) establishOutgoingConnEx(t *Torrent, addr IpPort, obfuscatedHeader bool) (*connection, error) {
+func (cl *Client) establishOutgoingConnEx(t *Torrent, addr net.Addr, obfuscatedHeader bool) (*connection, error) {
        dialCtx, cancel := context.WithTimeout(context.Background(), func() time.Duration {
                cl.rLock()
                defer cl.rUnlock()
@@ -654,7 +668,7 @@ func (cl *Client) establishOutgoingConnEx(t *Torrent, addr IpPort, obfuscatedHea
 
 // Returns nil connection and nil error if no connection could be established
 // for valid reasons.
-func (cl *Client) establishOutgoingConn(t *Torrent, addr IpPort) (c *connection, err error) {
+func (cl *Client) establishOutgoingConn(t *Torrent, addr net.Addr) (c *connection, err error) {
        torrent.Add("establish outgoing connection", 1)
        obfuscatedHeaderFirst := cl.config.HeaderObfuscationPolicy.Preferred
        c, err = cl.establishOutgoingConnEx(t, addr, obfuscatedHeaderFirst)
@@ -679,7 +693,7 @@ func (cl *Client) establishOutgoingConn(t *Torrent, addr IpPort) (c *connection,
 
 // Called to dial out and run a connection. The addr we're given is already
 // considered half-open.
-func (cl *Client) outgoingConnection(t *Torrent, addr IpPort, ps peerSource, trusted bool) {
+func (cl *Client) outgoingConnection(t *Torrent, addr net.Addr, ps peerSource, trusted bool) {
        cl.dialRateLimiter.Wait(context.Background())
        c, err := cl.establishOutgoingConn(t, addr)
        cl.lock()
@@ -699,8 +713,7 @@ func (cl *Client) outgoingConnection(t *Torrent, addr IpPort, ps peerSource, tru
        cl.runHandshookConn(c, t)
 }
 
-// The port number for incoming peer connections. 0 if the client isn't
-// listening.
+// The port number for incoming peer connections. 0 if the client isn't listening.
 func (cl *Client) incomingPeerPort() int {
        return cl.LocalPort()
 }
@@ -885,7 +898,7 @@ func (cl *Client) sendInitialMessages(conn *connection, torrent *Torrent) {
                                        },
                                        V:            cl.config.ExtendedHandshakeClientVersion,
                                        Reqq:         64, // TODO: Really?
-                                       YourIp:       pp.CompactIp(conn.remoteAddr.IP),
+                                       YourIp:       pp.CompactIp(addrIpOrNil(conn.remoteAddr)),
                                        Encryption:   cl.config.HeaderObfuscationPolicy.Preferred || !cl.config.HeaderObfuscationPolicy.RequirePreferred,
                                        Port:         cl.incomingPeerPort(),
                                        MetadataSize: torrent.metadataSize(),
@@ -980,6 +993,13 @@ func (cl *Client) gotMetadataExtensionMsg(payload []byte, t *Torrent, c *connect
        }
 }
 
+func (cl *Client) badPeerAddr(addr net.Addr) bool {
+       if ipa, ok := tryIpPortFromNetAddr(addr); ok {
+               return cl.badPeerIPPort(ipa.IP, ipa.Port)
+       }
+       return false
+}
+
 func (cl *Client) badPeerIPPort(ip net.IP, port int) bool {
        if port == 0 {
                return true
@@ -1010,7 +1030,7 @@ func (cl *Client) newTorrent(ih metainfo.Hash, specStorage storage.ClientImpl) (
                peers: prioritizedPeers{
                        om: btree.New(32),
                        getPrio: func(p Peer) peerPriority {
-                               return bep40PriorityIgnoreError(cl.publicAddr(p.IP), p.addr())
+                               return bep40PriorityIgnoreError(cl.publicAddr(addrIpOrNil(p.Addr)), p.addr())
                        },
                },
                conns: make(map[*connection]struct{}, 2*cl.config.EstablishedConnsPerTorrent),
@@ -1208,7 +1228,7 @@ func (cl *Client) banPeerIP(ip net.IP) {
        cl.badPeerIPs[ip.String()] = struct{}{}
 }
 
-func (cl *Client) newConnection(nc net.Conn, outgoing bool, remoteAddr IpPort, network string) (c *connection) {
+func (cl *Client) newConnection(nc net.Conn, outgoing bool, remoteAddr net.Addr, network string) (c *connection) {
        c = &connection{
                conn:            nc,
                outgoing:        outgoing,
@@ -1242,8 +1262,7 @@ func (cl *Client) onDHTAnnouncePeer(ih metainfo.Hash, ip net.IP, port int, portO
                return
        }
        t.addPeers([]Peer{{
-               IP:     ip,
-               Port:   port,
+               Addr:   ipPortAddr{ip, port},
                Source: peerSourceDhtAnnouncePeer,
        }})
 }
@@ -1257,7 +1276,7 @@ func firstNotNil(ips ...net.IP) net.IP {
        return nil
 }
 
-func (cl *Client) eachDialer(f func(dialer) bool) {
+func (cl *Client) eachDialer(f func(Dialer) bool) {
        for _, s := range cl.dialers {
                if !f(s) {
                        break
@@ -1265,7 +1284,7 @@ func (cl *Client) eachDialer(f func(dialer) bool) {
        }
 }
 
-func (cl *Client) eachListener(f func(listener) bool) {
+func (cl *Client) eachListener(f func(Listener) bool) {
        for _, s := range cl.listeners {
                if !f(s) {
                        break
@@ -1274,7 +1293,7 @@ func (cl *Client) eachListener(f func(listener) bool) {
 }
 
 func (cl *Client) findListener(f func(net.Listener) bool) (ret net.Listener) {
-       cl.eachListener(func(l listener) bool {
+       cl.eachListener(func(l Listener) bool {
                ret = l
                return !f(l)
        })
@@ -1297,9 +1316,13 @@ func (cl *Client) publicIp(peer net.IP) net.IP {
 }
 
 func (cl *Client) findListenerIp(f func(net.IP) bool) net.IP {
-       return missinggo.AddrIP(cl.findListener(func(l net.Listener) bool {
-               return f(missinggo.AddrIP(l.Addr()))
-       }).Addr())
+       return addrIpOrNil(
+               cl.findListener(
+                       func(l net.Listener) bool {
+                               return f(addrIpOrNil(l.Addr()))
+                       },
+               ).Addr(),
+       )
 }
 
 // Our IP as a peer should see it.
@@ -1311,15 +1334,19 @@ func (cl *Client) publicAddr(peer net.IP) IpPort {
 func (cl *Client) ListenAddrs() (ret []net.Addr) {
        cl.lock()
        defer cl.unlock()
-       cl.eachListener(func(l listener) bool {
+       cl.eachListener(func(l Listener) bool {
                ret = append(ret, l.Addr())
                return true
        })
        return
 }
 
-func (cl *Client) onBadAccept(addr IpPort) {
-       ip := maskIpForAcceptLimiting(addr.IP)
+func (cl *Client) onBadAccept(addr net.Addr) {
+       ipa, ok := tryIpPortFromNetAddr(addr)
+       if !ok {
+               return
+       }
+       ip := maskIpForAcceptLimiting(ipa.IP)
        if cl.acceptLimiter == nil {
                cl.acceptLimiter = make(map[ipStr]int)
        }
index 42c10b572d3affaf32fc538718a2512d360b051f..09d9637abedbece0853cc20e63a2a56e9a03803f 100644 (file)
@@ -574,7 +574,7 @@ func TestClientDynamicListenPortAllProtocols(t *testing.T) {
        defer cl.Close()
        port := cl.LocalPort()
        assert.NotEqual(t, 0, port)
-       cl.eachListener(func(s listener) bool {
+       cl.eachListener(func(s Listener) bool {
                assert.Equal(t, port, missinggo.AddrPort(s.Addr()))
                return true
        })
index eada56876454a24bbe8db65da74623037eaa933c..4c01baee457f16499a88b8f95bd23292e68fb9fe 100644 (file)
@@ -32,8 +32,7 @@ func resolvedPeerAddrs(ss []string) (ret []torrent.Peer, err error) {
                        return
                }
                ret = append(ret, torrent.Peer{
-                       IP:   addr.IP,
-                       Port: addr.Port,
+                       Addr: addr,
                })
        }
        return
index e21c3d2bff1b4b159914dc2b73e08f9551baba47..1e7feba0eec6c325e0ef93f75498effe4740c2c9 100644 (file)
@@ -114,8 +114,7 @@ func addTorrents(client *torrent.Client) error {
                t.AddPeers(func() (ret []torrent.Peer) {
                        for _, ta := range flags.TestPeer {
                                ret = append(ret, torrent.Peer{
-                                       IP:   ta.IP,
-                                       Port: ta.Port,
+                                       Addr: ta,
                                })
                        }
                        return
index b2d34efe53b3ed45a365655d8a02744fb5fe1b29..172f61029259e78e977d5893381f31d2639f4a5f 100644 (file)
@@ -62,8 +62,7 @@ func exitSignalHandlers(fs *torrentfs.TorrentFS) {
 func addTestPeer(client *torrent.Client) {
        for _, t := range client.Torrents() {
                t.AddPeers([]torrent.Peer{{
-                       IP:   args.TestPeer.IP,
-                       Port: args.TestPeer.Port,
+                       Addr: args.TestPeer,
                }})
        }
 }
index 569b4b2ae778b19ea16663f021828f0154246a72..692b0751fc6b75804e5c7661c33305f8f578a4a0 100644 (file)
--- a/config.go
+++ b/config.go
@@ -26,9 +26,8 @@ type ClientConfig struct {
        // Store torrent file data in this directory unless .DefaultStorage is
        // specified.
        DataDir string `long:"data-dir" description:"directory to store downloaded torrent data"`
-       // The address to listen for new uTP and TCP bittorrent protocol
-       // connections. DHT shares a UDP socket with uTP unless configured
-       // otherwise.
+       // The address to listen for new uTP and TCP BitTorrent protocol connections. DHT shares a UDP
+       // socket with uTP unless configured otherwise.
        ListenHost              func(network string) string
        ListenPort              int
        NoDefaultPortForwarding bool
index f5a586275ff158edb6eb78a07c8e96ad99f192c7..a7f8edf2141824a92e1176500f46323cef4a8660 100644 (file)
@@ -46,7 +46,7 @@ type connection struct {
        conn       net.Conn
        outgoing   bool
        network    string
-       remoteAddr IpPort
+       remoteAddr net.Addr
        // The Reader and Writer for this Conn, with hooks installed for stats,
        // limiting, deadlines etc.
        w io.Writer
@@ -139,14 +139,14 @@ func (cn *connection) expectingChunks() bool {
 
 // Returns true if the connection is over IPv6.
 func (cn *connection) ipv6() bool {
-       ip := cn.remoteAddr.IP
+       ip := addrIpOrNil(cn.remoteAddr)
        if ip.To4() != nil {
                return false
        }
        return len(ip) == net.IPv6len
 }
 
-// Returns true the dialer has the lower client peer ID. TODO: Find the
+// Returns true the if the dialer/initiator has the lower client peer ID. TODO: Find the
 // specification for this.
 func (cn *connection) isPreferredDirection() bool {
        return bytes.Compare(cn.t.cl.peerID[:], cn.PeerID[:]) < 0 == cn.outgoing
@@ -1049,9 +1049,13 @@ func (c *connection) mainReadLoop() (err error) {
                        req := newRequestFromMessage(&msg)
                        c.onPeerSentCancel(req)
                case pp.Port:
+                       ipa, ok := tryIpPortFromNetAddr(c.remoteAddr)
+                       if !ok {
+                               break
+                       }
                        pingAddr := net.UDPAddr{
-                               IP:   c.remoteAddr.IP,
-                               Port: int(c.remoteAddr.Port),
+                               IP:   ipa.IP,
+                               Port: ipa.Port,
                        }
                        if msg.Port != 0 {
                                pingAddr.Port = int(msg.Port)
@@ -1458,11 +1462,12 @@ func (c *connection) peerPriority() peerPriority {
 }
 
 func (c *connection) remoteIp() net.IP {
-       return c.remoteAddr.IP
+       return addrIpOrNil(c.remoteAddr)
 }
 
 func (c *connection) remoteIpPort() IpPort {
-       return c.remoteAddr
+       ipa, _ := tryIpPortFromNetAddr(c.remoteAddr)
+       return IpPort{ipa.IP, uint16(ipa.Port)}
 }
 
 func (c *connection) String() string {
index 58153fa5481c0330663da0110f810c6a486f3438..2f11fc2b7e5745367e452a1abf8e5601d0a2122c 100644 (file)
@@ -23,7 +23,7 @@ func TestSendBitfieldThenHave(t *testing.T) {
                config: TestingConfig(),
        }
        cl.initLogger()
-       c := cl.newConnection(nil, false, IpPort{}, "")
+       c := cl.newConnection(nil, false, nil, "")
        c.setTorrent(cl.newTorrent(metainfo.Hash{}, nil))
        c.t.setInfo(&metainfo.Info{
                Pieces: make([]byte, metainfo.HashSize*3),
@@ -107,7 +107,7 @@ func BenchmarkConnectionMainReadLoop(b *testing.B) {
        t.setChunkSize(defaultChunkSize)
        t._pendingPieces.Set(0, PiecePriorityNormal.BitmapPriority())
        r, w := net.Pipe()
-       cn := cl.newConnection(r, true, IpPort{}, "")
+       cn := cl.newConnection(r, true, nil, "")
        cn.setTorrent(t)
        mrlErr := make(chan error)
        msg := pp.Message{
diff --git a/dialer.go b/dialer.go
new file mode 100644 (file)
index 0000000..32ab91f
--- /dev/null
+++ b/dialer.go
@@ -0,0 +1,45 @@
+package torrent
+
+import (
+       "context"
+       "net"
+
+       "github.com/anacrolix/missinggo/perf"
+)
+
+type Dialer interface {
+       // The network is implied by the instance.
+       Dial(_ context.Context, addr string) (net.Conn, error)
+       // This is required for registering with the connection tracker (router connection table
+       // emulating rate-limiter) before dialing. TODO: What about connections that wouldn't infringe
+       // on routers, like localhost or unix sockets.
+       LocalAddr() net.Addr
+}
+
+type NetDialer struct {
+       Network string
+       Dialer  net.Dialer
+}
+
+func (me NetDialer) Dial(ctx context.Context, addr string) (_ net.Conn, err error) {
+       defer perf.ScopeTimerErr(&err)()
+       return me.Dialer.DialContext(ctx, me.Network, addr)
+}
+
+func (me NetDialer) LocalAddr() net.Addr {
+       return netDialerLocalAddr{me.Network, me.Dialer.LocalAddr}
+}
+
+type netDialerLocalAddr struct {
+       network string
+       addr    net.Addr
+}
+
+func (me netDialerLocalAddr) Network() string { return me.network }
+
+func (me netDialerLocalAddr) String() string {
+       if me.addr == nil {
+               return ""
+       }
+       return me.addr.String()
+}
diff --git a/go.mod b/go.mod
index 42a7164326f57a2fd94ed1bbbf2773e5942db083..612cff5f76d8ec78c3b25815e42a623c91daadf9 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -11,7 +11,6 @@ require (
        github.com/anacrolix/missinggo/perf v1.0.0
        github.com/anacrolix/missinggo/v2 v2.3.2-0.20200110051601-fc3212fb3984
        github.com/anacrolix/multiless v0.0.0-20191223025854-070b7994e841
-       github.com/anacrolix/stm v0.2.0
        github.com/anacrolix/sync v0.2.0
        github.com/anacrolix/tagflag v1.0.1
        github.com/anacrolix/upnp v0.1.1
@@ -29,7 +28,7 @@ require (
        github.com/pkg/errors v0.9.1
        github.com/stretchr/testify v1.4.0
        github.com/tinylib/msgp v1.1.1 // indirect
-       golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa
+       golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa // indirect
        golang.org/x/time v0.0.0-20191024005414-555d28b269f0
        golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543
 )
diff --git a/ipport.go b/ipport.go
new file mode 100644 (file)
index 0000000..d65624c
--- /dev/null
+++ b/ipport.go
@@ -0,0 +1,67 @@
+package torrent
+
+import (
+       "net"
+       "strconv"
+)
+
+// Extracts the port as an integer from an address string.
+func addrPortOrZero(addr net.Addr) int {
+       switch raw := addr.(type) {
+       case *net.UDPAddr:
+               return raw.Port
+       case *net.TCPAddr:
+               return raw.Port
+       default:
+               _, port, err := net.SplitHostPort(addr.String())
+               if err != nil {
+                       return 0
+               }
+               i64, err := strconv.ParseInt(port, 0, 0)
+               if err != nil {
+                       panic(err)
+               }
+               return int(i64)
+       }
+}
+
+func addrIpOrNil(addr net.Addr) net.IP {
+       if addr == nil {
+               return nil
+       }
+       switch raw := addr.(type) {
+       case *net.UDPAddr:
+               return raw.IP
+       case *net.TCPAddr:
+               return raw.IP
+       default:
+               host, _, err := net.SplitHostPort(addr.String())
+               if err != nil {
+                       return nil
+               }
+               return net.ParseIP(host)
+       }
+}
+
+type ipPortAddr struct {
+       IP   net.IP
+       Port int
+}
+
+func (ipPortAddr) Network() string {
+       return ""
+}
+
+func (me ipPortAddr) String() string {
+       return net.JoinHostPort(me.IP.String(), strconv.FormatInt(int64(me.Port), 10))
+}
+
+func tryIpPortFromNetAddr(na net.Addr) (ret ipPortAddr, ok bool) {
+       ret.IP = addrIpOrNil(na)
+       if ret.IP == nil {
+               return
+       }
+       ret.Port = addrPortOrZero(na)
+       ok = true
+       return
+}
diff --git a/misc.go b/misc.go
index 9e1628a9c99babed8c29065e5422eb1e1c3a152a..d1f179a3fe0a158b308620d93da72aeea1c24215 100644 (file)
--- a/misc.go
+++ b/misc.go
@@ -114,7 +114,7 @@ func connIsIpv6(nc interface {
        LocalAddr() net.Addr
 }) bool {
        ra := nc.LocalAddr()
-       rip := missinggo.AddrIP(ra)
+       rip := addrIpOrNil(ra)
        return rip.To4() == nil && rip.To16() != nil
 }
 
index 364178c177e64cee735322b88791f6e2f23bfa9c..c08c5a33e83d8cebe7a367dcc3cb26b2703398d4 100644 (file)
@@ -19,10 +19,10 @@ func TestPrioritizedPeers(t *testing.T) {
        assert.Panics(t, func() { pp.PopMax() })
        assert.False(t, ok)
        ps := []Peer{
-               {IP: net.ParseIP("1.2.3.4")},
-               {IP: net.ParseIP("1::2")},
-               {IP: net.ParseIP("")},
-               {IP: net.ParseIP(""), Trusted: true},
+               {Addr: ipPortAddr{IP: net.ParseIP("1.2.3.4")}},
+               {Addr: ipPortAddr{IP: net.ParseIP("1::2")}},
+               {Addr: ipPortAddr{IP: net.ParseIP("")}},
+               {Addr: ipPortAddr{IP: net.ParseIP("")}, Trusted: true},
        }
        for i, p := range ps {
                t.Logf("peer %d priority: %08x trusted: %t\n", i, pp.getPrio(p), p.Trusted)
index c5f7dbcd96209a124f38a473a0c3e7f5168b38a8..1ae4b42f3af72780ad6769f709602644193a6629 100644 (file)
--- a/socket.go
+++ b/socket.go
@@ -10,18 +10,13 @@ import (
        "github.com/pkg/errors"
 )
 
-type dialer interface {
-       dial(_ context.Context, addr string) (net.Conn, error)
-       LocalAddr() net.Addr
-}
-
-type listener interface {
+type Listener interface {
        net.Listener
 }
 
 type socket interface {
-       listener
-       dialer
+       Listener
+       Dialer
 }
 
 func listen(n network, addr string, f firewallCallback) (socket, error) {
@@ -39,34 +34,17 @@ func listenTcp(network, address string) (s socket, err error) {
        l, err := net.Listen(network, address)
        return tcpSocket{
                Listener: l,
-               network:  network,
+               NetDialer: NetDialer{
+                       Network: network,
+               },
        }, err
 }
 
 type tcpSocket struct {
        net.Listener
-       network string
-       dialer  net.Dialer
+       NetDialer
 }
 
-func (me tcpSocket) dial(ctx context.Context, addr string) (_ net.Conn, err error) {
-       defer perf.ScopeTimerErr(&err)()
-       return me.dialer.DialContext(ctx, me.network, addr)
-}
-
-func (me tcpSocket) LocalAddr() net.Addr {
-       return tcpSocketLocalAddr{me.network, me.Listener.Addr().String()}
-}
-
-type tcpSocketLocalAddr struct {
-       network string
-       s       string
-}
-
-func (me tcpSocketLocalAddr) Network() string { return me.network }
-
-func (me tcpSocketLocalAddr) String() string { return "" }
-
 func listenAll(networks []network, getHost func(string) string, port int, f firewallCallback) ([]socket, error) {
        if len(networks) == 0 {
                return nil, nil
@@ -128,7 +106,7 @@ type utpSocketSocket struct {
        network string
 }
 
-func (me utpSocketSocket) dial(ctx context.Context, addr string) (conn net.Conn, err error) {
+func (me utpSocketSocket) Dial(ctx context.Context, addr string) (conn net.Conn, err error) {
        defer perf.ScopeTimerErr(&err)()
        return me.utpSocket.DialContext(ctx, me.network, addr)
 }
diff --git a/test/init_test.go b/test/init_test.go
new file mode 100644 (file)
index 0000000..3aa4069
--- /dev/null
@@ -0,0 +1,9 @@
+package test
+
+import (
+       "github.com/anacrolix/torrent"
+)
+
+func init() {
+       torrent.TestingTempDir.Init("torrent-test.test")
+}
index c36ccc98ce695b0dd893a23626cdfef630874452..8bd604c3b4a4c4f77e40fad3e7d72c390d1c0257 100644 (file)
@@ -28,6 +28,8 @@ type testClientTransferParams struct {
        SeederStorage              func(string) storage.ClientImpl
        SeederUploadRateLimiter    *rate.Limiter
        LeecherDownloadRateLimiter *rate.Limiter
+       ConfigureSeeder            ConfigureClient
+       ConfigureLeecher           ConfigureClient
 }
 
 func assertReadAllGreeting(t *testing.T, r io.ReadSeeker) {
@@ -57,8 +59,14 @@ func testClientTransfer(t *testing.T, ps testClientTransferParams) {
        } else {
                cfg.DataDir = greetingTempDir
        }
+       if ps.ConfigureSeeder.Config != nil {
+               ps.ConfigureSeeder.Config(cfg)
+       }
        seeder, err := torrent.NewClient(cfg)
        require.NoError(t, err)
+       if ps.ConfigureSeeder.Client != nil {
+               ps.ConfigureSeeder.Client(seeder)
+       }
        if ps.ExportClientStatus {
                defer testutil.ExportStatusWriter(seeder, "s")()
        }
@@ -83,9 +91,15 @@ func testClientTransfer(t *testing.T, ps testClientTransferParams) {
        }
        cfg.Seed = false
        //cfg.Debug = true
+       if ps.ConfigureLeecher.Config != nil {
+               ps.ConfigureLeecher.Config(cfg)
+       }
        leecher, err := torrent.NewClient(cfg)
        require.NoError(t, err)
        defer leecher.Close()
+       if ps.ConfigureLeecher.Client != nil {
+               ps.ConfigureLeecher.Client(leecher)
+       }
        if ps.ExportClientStatus {
                defer testutil.ExportStatusWriter(leecher, "l")()
        }
@@ -335,3 +349,8 @@ func TestSeedAfterDownloading(t *testing.T) {
        }()
        wg.Wait()
 }
+
+type ConfigureClient struct {
+       Config func(*torrent.ClientConfig)
+       Client func(*torrent.Client)
+}
diff --git a/test/unix_test.go b/test/unix_test.go
new file mode 100644 (file)
index 0000000..d021b98
--- /dev/null
@@ -0,0 +1,41 @@
+package test
+
+import (
+       "io"
+       "log"
+       "net"
+       "path/filepath"
+       "testing"
+
+       "github.com/anacrolix/torrent"
+)
+
+func TestUnixConns(t *testing.T) {
+       var closers []io.Closer
+       defer func() {
+               for _, c := range closers {
+                       c.Close()
+               }
+       }()
+       configure := ConfigureClient{
+               Config: func(cfg *torrent.ClientConfig) {
+                       cfg.DisableUTP = true
+                       cfg.DisableTCP = true
+                       cfg.Debug = true
+               },
+               Client: func(cl *torrent.Client) {
+                       cl.AddDialer(torrent.NetDialer{Network: "unix"})
+                       l, err := net.Listen("unix", filepath.Join(torrent.TestingTempDir.NewSub(), "socket"))
+                       if err != nil {
+                               panic(err)
+                       }
+                       log.Printf("created listener %q", l)
+                       closers = append(closers, l)
+                       cl.AddListener(l)
+               },
+       }
+       testClientTransfer(t, testClientTransferParams{
+               ConfigureSeeder:  configure,
+               ConfigureLeecher: configure,
+       })
+}
index bb10264f78ecb0968c672c0e02b5231b774db323..474acfd1d1d334ace0d89287fccaa8a4d3bf22fc 100644 (file)
@@ -189,8 +189,7 @@ func (t *Torrent) KnownSwarm() (ks []Peer) {
 
                ks = append(ks, Peer{
                        Id:     conn.PeerID,
-                       IP:     conn.remoteAddr.IP,
-                       Port:   int(conn.remoteAddr.Port),
+                       Addr:   conn.remoteAddr,
                        Source: conn.Discovery,
                        // > If the connection is encrypted, that's certainly enough to set SupportsEncryption.
                        // > But if we're not connected to them with an encrypted connection, I couldn't say
@@ -253,9 +252,11 @@ func (t *Torrent) addPeer(p Peer) {
        if t.closed.IsSet() {
                return
        }
-       if cl.badPeerIPPort(p.IP, p.Port) {
-               torrent.Add("peers not added because of bad addr", 1)
-               return
+       if ipAddr, ok := tryIpPortFromNetAddr(p.Addr); ok {
+               if cl.badPeerIPPort(ipAddr.IP, ipAddr.Port) {
+                       torrent.Add("peers not added because of bad addr", 1)
+                       return
+               }
        }
        if t.peers.Add(p) {
                torrent.Add("peers replaced", 1)
@@ -1350,8 +1351,7 @@ func (t *Torrent) consumeDhtAnnouncePeers(pvs <-chan dht.PeersValues) {
                                continue
                        }
                        t.addPeer(Peer{
-                               IP:     cp.IP[:],
-                               Port:   cp.Port,
+                               Addr:   ipPortAddr{cp.IP, cp.Port},
                                Source: peerSourceDhtGetPeers,
                        })
                }
@@ -1433,7 +1433,7 @@ func (t *Torrent) numTotalPeers() int {
                peers[addr] = struct{}{}
        }
        t.peers.Each(func(peer Peer) {
-               peers[fmt.Sprintf("%s:%d", peer.IP, peer.Port)] = struct{}{}
+               peers[peer.Addr.String()] = struct{}{}
        })
        return len(peers)
 }
@@ -1592,7 +1592,7 @@ func (t *Torrent) pieceHashed(piece pieceIndex, passed bool, hashIoErr error) {
 
                        if len(bannableTouchers) >= 1 {
                                c := bannableTouchers[0]
-                               t.cl.banPeerIP(c.remoteAddr.IP)
+                               t.cl.banPeerIP(c.remoteIp())
                                c.Drop()
                        }
                }
@@ -1738,10 +1738,11 @@ func (t *Torrent) initiateConn(peer Peer) {
        if peer.Id == t.cl.peerID {
                return
        }
-       if t.cl.badPeerIPPort(peer.IP, peer.Port) && !peer.Trusted {
+
+       if t.cl.badPeerAddr(peer.Addr) && !peer.Trusted {
                return
        }
-       addr := IpPort{peer.IP, uint16(peer.Port)}
+       addr := peer.Addr
        if t.addrActive(addr.String()) {
                return
        }
@@ -1754,8 +1755,7 @@ func (t *Torrent) AddClientPeer(cl *Client) {
        t.AddPeers(func() (ps []Peer) {
                for _, la := range cl.ListenAddrs() {
                        ps = append(ps, Peer{
-                               IP:      missinggo.AddrIP(la),
-                               Port:    missinggo.AddrPort(la),
+                               Addr:    la,
                                Trusted: true,
                        })
                }