]> Sergey Matveev's repositories - btrtrc.git/commitdiff
dht: Improve Server.GetPeers so new nodes are fed directly back into the current...
authorMatt Joiner <anacrolix@gmail.com>
Tue, 18 Nov 2014 18:38:13 +0000 (12:38 -0600)
committerMatt Joiner <anacrolix@gmail.com>
Tue, 18 Nov 2014 18:38:13 +0000 (12:38 -0600)
dht/dht.go
dht/getpeers.go [new file with mode: 0644]

index 14a0af0ffa4e89552b4ed390acef9d816bc23bfe..fc703085f202839516f08257c308095111bbbdf7 100644 (file)
@@ -148,6 +148,14 @@ func (m Msg) T() (t string) {
        return
 }
 
+func (m Msg) Nodes() []NodeInfo {
+       var r findNodeResponse
+       if err := r.UnmarshalKRPCMsg(m); err != nil {
+               return nil
+       }
+       return r.Nodes
+}
+
 type KRPCError struct {
        Code int
        Msg  string
@@ -159,12 +167,21 @@ func (me KRPCError) Error() string {
 
 var _ error = KRPCError{}
 
-func (m Msg) Error() *KRPCError {
+func (m Msg) Error() (ret *KRPCError) {
        if m["y"] != "e" {
-               return nil
+               return
        }
-       l := m["e"].([]interface{})
-       return &KRPCError{int(l[0].(int64)), l[1].(string)}
+       ret = &KRPCError{}
+       switch e := m["e"].(type) {
+       case []interface{}:
+               ret.Code = int(e[0].(int64))
+               ret.Msg = e[1].(string)
+       case string:
+               ret.Msg = e
+       default:
+               logonce.Stderr.Printf(`KRPC error "e" value has unexpected type: %T`, e)
+       }
+       return
 }
 
 // Returns the token given in response to a get_peers request for future
@@ -175,6 +192,7 @@ func (m Msg) AnnounceToken() string {
 }
 
 type transaction struct {
+       mu         sync.Mutex
        remoteAddr dHTAddr
        t          string
        Response   chan Msg
@@ -183,12 +201,36 @@ type transaction struct {
 }
 
 func (t *transaction) timeout() {
+       t.Close()
+}
+
+func (t *transaction) closing() bool {
+       select {
+       case <-t.done:
+               return true
+       default:
+               return false
+       }
+}
+
+func (t *transaction) Close() {
+       t.mu.Lock()
+       defer t.mu.Unlock()
+       if t.closing() {
+               return
+       }
        close(t.Response)
        close(t.done)
 }
 
 func (t *transaction) handleResponse(m Msg) {
+       t.mu.Lock()
+       if t.closing() {
+               t.mu.Unlock()
+               return
+       }
        close(t.done)
+       t.mu.Unlock()
        if t.onResponse != nil {
                t.onResponse(m)
        }
@@ -272,6 +314,8 @@ func (s *Server) serve() error {
 }
 
 func (s *Server) AddNode(ni NodeInfo) {
+       s.mu.Lock()
+       defer s.mu.Unlock()
        if s.nodes == nil {
                s.nodes = make(map[string]*Node)
        }
@@ -697,27 +741,6 @@ func (s *Server) findNode(addr dHTAddr, targetID string) (t *transaction, err er
        return
 }
 
-type peerStreamValue struct {
-       Peers    []util.CompactPeer // Peers given in get_peers response.
-       NodeInfo                    // The node that gave the response.
-}
-
-type peerStream struct {
-       mu     sync.Mutex
-       Values chan peerStreamValue
-       stop   chan struct{}
-}
-
-func (ps *peerStream) Close() {
-       ps.mu.Lock()
-       select {
-       case <-ps.stop:
-       default:
-               close(ps.stop)
-       }
-       ps.mu.Unlock()
-}
-
 func extractValues(m Msg) (vs []util.CompactPeer) {
        r, ok := m["r"]
        if !ok {
@@ -752,63 +775,6 @@ func extractValues(m Msg) (vs []util.CompactPeer) {
        return
 }
 
-func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) {
-       ps = &peerStream{
-               Values: make(chan peerStreamValue),
-               stop:   make(chan struct{}),
-       }
-       done := make(chan struct{})
-       pending := 0
-       s.mu.Lock()
-       for _, n := range s.closestGoodNodes(160, infoHash) {
-               var t *transaction
-               t, err = s.getPeers(n.addr, infoHash)
-               if err != nil {
-                       ps.Close()
-                       break
-               }
-               go func() {
-                       select {
-                       case m := <-t.Response:
-                               vs := extractValues(m)
-                               if vs != nil {
-                                       nodeInfo := NodeInfo{
-                                               Addr: t.remoteAddr,
-                                       }
-                                       id := func() string {
-                                               defer func() {
-                                                       recover()
-                                               }()
-                                               return m["r"].(map[string]interface{})["id"].(string)
-                                       }()
-                                       copy(nodeInfo.ID[:], id)
-                                       select {
-                                       case ps.Values <- peerStreamValue{
-                                               Peers:    vs,
-                                               NodeInfo: nodeInfo,
-                                       }:
-                                       case <-ps.stop:
-                                       }
-                               }
-                       case <-ps.stop:
-                       }
-                       done <- struct{}{}
-               }()
-               pending++
-       }
-       s.mu.Unlock()
-       go func() {
-               for ; pending > 0; pending-- {
-                       select {
-                       case <-done:
-                       case <-s.closed:
-                       }
-               }
-               close(ps.Values)
-       }()
-       return
-}
-
 func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *transaction, err error) {
        if len(infoHash) != 20 {
                err = fmt.Errorf("infohash has bad length")
@@ -825,6 +791,10 @@ func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *transaction, err er
        return
 }
 
+func bootstrapAddr() (net.Addr, error) {
+       return net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881")
+}
+
 func (s *Server) addRootNode() error {
        addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881")
        if err != nil {
diff --git a/dht/getpeers.go b/dht/getpeers.go
new file mode 100644 (file)
index 0000000..1f136c9
--- /dev/null
@@ -0,0 +1,170 @@
+package dht
+
+import (
+       "log"
+       "net"
+       "sync"
+
+       "bitbucket.org/anacrolix/go.torrent/util"
+)
+
+type peerDiscovery struct {
+       *peerStream
+       triedAddrs        map[string]struct{}
+       contactAddrs      chan net.Addr
+       pending           int
+       transactionClosed chan struct{}
+       server            *Server
+       infoHash          string
+}
+
+func (me *peerDiscovery) Close() {
+       me.peerStream.Close()
+       close(me.contactAddrs)
+}
+
+func (s *Server) GetPeers(infoHash string) (*peerStream, error) {
+       disc := &peerDiscovery{
+               peerStream: &peerStream{
+                       Values: make(chan peerStreamValue),
+                       stop:   make(chan struct{}),
+               },
+               triedAddrs:        make(map[string]struct{}, 500),
+               contactAddrs:      make(chan net.Addr),
+               transactionClosed: make(chan struct{}),
+               server:            s,
+               infoHash:          infoHash,
+       }
+       go disc.loop()
+       s.mu.Lock()
+       startAddrs := func() (ret []net.Addr) {
+               for _, n := range s.closestGoodNodes(160, infoHash) {
+                       ret = append(ret, n.addr)
+               }
+               return
+       }()
+       s.mu.Unlock()
+       for _, addr := range startAddrs {
+               disc.contact(addr)
+       }
+       if len(startAddrs) == 0 {
+               addr, err := bootstrapAddr()
+               if err != nil {
+                       disc.Close()
+                       return nil, err
+               }
+               disc.contact(addr)
+       }
+       return disc.peerStream, nil
+}
+
+func (me *peerDiscovery) contact(addr net.Addr) {
+       select {
+       case me.contactAddrs <- addr:
+       case <-me.closingCh():
+       }
+}
+
+func (me *peerDiscovery) responseNode(node NodeInfo) {
+       me.contact(node.Addr)
+}
+
+func (me *peerDiscovery) loop() {
+       for {
+               select {
+               case addr := <-me.contactAddrs:
+                       if me.pending >= 160 {
+                               break
+                       }
+                       if _, ok := me.triedAddrs[addr.String()]; ok {
+                               break
+                       }
+                       me.triedAddrs[addr.String()] = struct{}{}
+                       if err := me.getPeers(addr); err != nil {
+                               log.Printf("error sending get_peers request to %s: %s", addr, err)
+                               break
+                       }
+                       // log.Printf("contacting %s", addr)
+                       me.pending++
+               case <-me.transactionClosed:
+                       me.pending--
+                       // log.Printf("pending: %d", me.pending)
+                       if me.pending == 0 {
+                               me.Close()
+                               return
+                       }
+               }
+       }
+}
+
+func (me *peerDiscovery) closingCh() chan struct{} {
+       return me.peerStream.stop
+}
+
+func (me *peerDiscovery) getPeers(addr net.Addr) error {
+       me.server.mu.Lock()
+       defer me.server.mu.Unlock()
+       t, err := me.server.getPeers(addr, me.infoHash)
+       if err != nil {
+               return err
+       }
+       go func() {
+               select {
+               case m := <-t.Response:
+                       if nodes := m.Nodes(); len(nodes) != 0 {
+                               for _, n := range nodes {
+                                       me.responseNode(n)
+                               }
+                       }
+                       if vs := extractValues(m); vs != nil {
+                               nodeInfo := NodeInfo{
+                                       Addr: t.remoteAddr,
+                               }
+                               id := func() string {
+                                       defer func() {
+                                               recover()
+                                       }()
+                                       return m["r"].(map[string]interface{})["id"].(string)
+                               }()
+                               copy(nodeInfo.ID[:], id)
+                               select {
+                               case me.peerStream.Values <- peerStreamValue{
+                                       Peers:    vs,
+                                       NodeInfo: nodeInfo,
+                               }:
+                               case <-me.peerStream.stop:
+                               }
+                       }
+               case <-me.closingCh():
+               }
+               t.Close()
+               me.transactionClosed <- struct{}{}
+       }()
+       return nil
+}
+
+func (me *peerDiscovery) streamValue(psv peerStreamValue) {
+       me.peerStream.Values <- psv
+}
+
+type peerStreamValue struct {
+       Peers    []util.CompactPeer // Peers given in get_peers response.
+       NodeInfo                    // The node that gave the response.
+}
+
+type peerStream struct {
+       mu     sync.Mutex
+       Values chan peerStreamValue
+       stop   chan struct{}
+}
+
+func (ps *peerStream) Close() {
+       ps.mu.Lock()
+       select {
+       case <-ps.stop:
+       default:
+               close(ps.stop)
+               close(ps.Values)
+       }
+       ps.mu.Unlock()
+}