]> Sergey Matveev's repositories - btrtrc.git/commitdiff
dht: During get_peers, put discovered nodes in a backlog for later querying
authorMatt Joiner <anacrolix@gmail.com>
Tue, 9 Dec 2014 02:00:42 +0000 (20:00 -0600)
committerMatt Joiner <anacrolix@gmail.com>
Tue, 9 Dec 2014 02:00:42 +0000 (20:00 -0600)
dht/dht.go
dht/getpeers.go

index 89b9ee24ccd37b3bb4277c0412f54d3e3dc01d21..836f387fb988725ac584375fb1b8ab5ec2c90fd6 100644 (file)
@@ -232,7 +232,7 @@ type transaction struct {
        remoteAddr  dHTAddr
        t           string
        Response    chan Msg
-       onResponse  func(Msg)
+       onResponse  func(Msg) // Called with the server locked.
        done        chan struct{}
        queryPacket []byte
        timer       *time.Timer
@@ -326,7 +326,9 @@ func (t *transaction) handleResponse(m Msg) {
        close(t.done)
        t.mu.Unlock()
        if t.onResponse != nil {
+               t.s.mu.Lock()
                t.onResponse(m)
+               t.s.mu.Unlock()
        }
        t.queryPacket = nil
        select {
index 382952104b3361cb30075f07478744ff1a5ac7e6..6654636b18662065e47aa420a7fd6bbb8fc39460 100644 (file)
@@ -4,24 +4,23 @@ import (
        "log"
        "net"
        "sync"
-       "time"
 
        "bitbucket.org/anacrolix/go.torrent/util"
 )
 
 type peerDiscovery struct {
        *peerStream
-       triedAddrs        map[string]struct{}
-       contactAddrs      chan net.Addr
-       pending           int
-       transactionClosed chan struct{}
-       server            *Server
-       infoHash          string
+       triedAddrs map[string]struct{}
+       backlog    map[string]net.Addr
+       pending    int
+       server     *Server
+       infoHash   string
 }
 
+const parallelQueries = 100
+
 func (me *peerDiscovery) Close() {
        me.peerStream.Close()
-       close(me.contactAddrs)
 }
 
 func (s *Server) GetPeers(infoHash string) (*peerStream, error) {
@@ -45,65 +44,69 @@ func (s *Server) GetPeers(infoHash string) (*peerStream, error) {
                        Values: make(chan peerStreamValue),
                        stop:   make(chan struct{}),
                },
-               triedAddrs:        make(map[string]struct{}, 500),
-               contactAddrs:      make(chan net.Addr),
-               transactionClosed: make(chan struct{}),
-               server:            s,
-               infoHash:          infoHash,
+               triedAddrs: make(map[string]struct{}, 500),
+               backlog:    make(map[string]net.Addr, parallelQueries),
+               server:     s,
+               infoHash:   infoHash,
        }
-       go disc.loop()
+       disc.mu.Lock()
        for _, addr := range startAddrs {
                disc.contact(addr)
        }
+       disc.mu.Unlock()
        return disc.peerStream, nil
 }
 
-func (me *peerDiscovery) contact(addr net.Addr) {
-       select {
-       case me.contactAddrs <- addr:
-       case <-me.closingCh():
+func (me *peerDiscovery) gotNodeAddr(addr net.Addr) {
+       if util.AddrPort(addr) == 0 {
+               // Not a contactable address.
+               return
+       }
+       if me.server.ipBlocked(util.AddrIP(addr)) {
+               return
+       }
+       if _, ok := me.triedAddrs[addr.String()]; ok {
+               return
+       }
+       if _, ok := me.backlog[addr.String()]; ok {
+               return
+       }
+       if me.pending >= parallelQueries {
+               me.backlog[addr.String()] = addr
+       } else {
+               me.contact(addr)
        }
 }
 
-func (me *peerDiscovery) responseNode(node NodeInfo) {
-       if util.AddrPort(node.Addr) == 0 {
-               // Not a contactable address.
+func (me *peerDiscovery) contact(addr net.Addr) {
+       me.triedAddrs[addr.String()] = struct{}{}
+       if err := me.getPeers(addr); err != nil {
+               log.Printf("error sending get_peers request to %s: %s", addr, err)
                return
        }
-       me.contact(node.Addr)
+       me.pending++
 }
 
-func (me *peerDiscovery) loop() {
-       for {
-               select {
-               case addr := <-me.contactAddrs:
-                       if me.pending >= 1000 {
-                               break
-                       }
-                       if _, ok := me.triedAddrs[addr.String()]; ok {
-                               break
-                       }
-                       me.triedAddrs[addr.String()] = struct{}{}
-                       if me.server.ipBlocked(util.AddrIP(addr)) {
-                               break
-                       }
-                       if err := me.getPeers(addr); err != nil {
-                               log.Printf("error sending get_peers request to %s: %s", addr, err)
-                               break
-                       }
-                       // log.Printf("contacting %s", addr)
-                       me.pending++
-               case <-me.transactionClosed:
-                       me.pending--
-                       // log.Printf("pending: %d", me.pending)
-                       if me.pending == 0 {
-                               me.Close()
-                               return
-                       }
+func (me *peerDiscovery) transactionClosed() {
+       me.pending--
+       // log.Printf("pending: %d", me.pending)
+       for key, addr := range me.backlog {
+               if me.pending >= parallelQueries {
+                       break
                }
+               delete(me.backlog, key)
+               me.contact(addr)
+       }
+       if me.pending == 0 {
+               me.Close()
+               return
        }
 }
 
+func (me *peerDiscovery) responseNode(node NodeInfo) {
+       me.gotNodeAddr(node.Addr)
+}
+
 func (me *peerDiscovery) closingCh() chan struct{} {
        return me.peerStream.stop
 }
@@ -118,11 +121,13 @@ func (me *peerDiscovery) getPeers(addr net.Addr) error {
        go func() {
                select {
                case m := <-t.Response:
+                       me.mu.Lock()
                        if nodes := m.Nodes(); len(nodes) != 0 {
                                for _, n := range nodes {
                                        me.responseNode(n)
                                }
                        }
+                       me.mu.Unlock()
                        if vs := extractValues(m); vs != nil {
                                nodeInfo := NodeInfo{
                                        Addr: t.remoteAddr,
@@ -145,7 +150,9 @@ func (me *peerDiscovery) getPeers(addr net.Addr) error {
                case <-me.closingCh():
                }
                t.Close()
-               me.transactionClosed <- struct{}{}
+               me.mu.Lock()
+               me.transactionClosed()
+               me.mu.Unlock()
        }()
        return nil
 }