]> Sergey Matveev's repositories - btrtrc.git/commitdiff
dht: Handle responses through a callback
authorMatt Joiner <anacrolix@gmail.com>
Sat, 21 Feb 2015 04:00:48 +0000 (15:00 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Sat, 21 Feb 2015 04:00:48 +0000 (15:00 +1100)
Trying to reduce number of goroutines

cmd/dht-ping/main.go
dht/dht.go
dht/dht_test.go
dht/getpeers.go

index 80cd2cc103d922350df326e55e4672728cb40045..d1d98ad7f0be7164e7b18cfcef4491f01cec2b6a 100644 (file)
@@ -47,14 +47,15 @@ func main() {
                                log.Fatal(err)
                        }
                        start := time.Now()
-                       go func(addr string) {
-                               resp := <-t.Response
-                               pingResponses <- pingResponse{
-                                       addr: addr,
-                                       krpc: resp,
-                                       rtt:  time.Now().Sub(start),
+                       t.SetResponseHandler(func(addr string) func(dht.Msg) {
+                               return func(resp dht.Msg) {
+                                       pingResponses <- pingResponse{
+                                               addr: addr,
+                                               krpc: resp,
+                                               rtt:  time.Now().Sub(start),
+                                       }
                                }
-                       }(netloc)
+                       }(netloc))
                }
                if *timeout >= 0 {
                        time.Sleep(*timeout)
index cac9fbd061ebc3c604023b5b5b1f6f81ca85d777..5b426045af02e64d4b45828f2bc333961266d1b6 100644 (file)
@@ -21,7 +21,10 @@ import (
        "github.com/anacrolix/libtorgo/bencode"
 )
 
-const maxNodes = 1000
+const (
+       maxNodes         = 320
+       queryResendEvery = 5 * time.Second
+)
 
 // Uniquely identifies a transaction to us.
 type transactionKey struct {
@@ -313,16 +316,38 @@ func (m Msg) AnnounceToken() (token string, ok bool) {
 }
 
 type transaction struct {
-       mu          sync.Mutex
-       remoteAddr  dHTAddr
-       t           string
-       Response    chan Msg
-       onResponse  func(Msg) // Called with the server locked.
-       done        chan struct{}
-       queryPacket []byte
-       timer       *time.Timer
-       s           *Server
-       retries     int
+       mu             sync.Mutex
+       remoteAddr     dHTAddr
+       t              string
+       response       chan Msg
+       onResponse     func(Msg) // Called with the server locked.
+       done           chan struct{}
+       queryPacket    []byte
+       timer          *time.Timer
+       s              *Server
+       retries        int
+       lastSend       time.Time
+       userOnResponse func(Msg)
+}
+
+func (t *transaction) SetResponseHandler(f func(Msg)) {
+       t.mu.Lock()
+       defer t.mu.Unlock()
+       t.userOnResponse = f
+       t.tryHandleResponse()
+}
+
+func (t *transaction) tryHandleResponse() {
+       if t.userOnResponse == nil {
+               return
+       }
+       select {
+       case r := <-t.response:
+               t.userOnResponse(r)
+               // Shouldn't be called more than once.
+               t.userOnResponse = nil
+       default:
+       }
 }
 
 func (t *transaction) Key() transactionKey {
@@ -337,7 +362,7 @@ func jitterDuration(average time.Duration, plusMinus time.Duration) time.Duratio
 }
 
 func (t *transaction) startTimer() {
-       t.timer = time.AfterFunc(jitterDuration(20*time.Second, time.Second), t.timerCallback)
+       t.timer = time.AfterFunc(jitterDuration(queryResendEvery, time.Second), t.timerCallback)
 }
 
 func (t *transaction) timerCallback() {
@@ -354,13 +379,18 @@ func (t *transaction) timerCallback() {
        }
        t.retries++
        t.sendQuery()
-       if t.timer.Reset(jitterDuration(20*time.Second, time.Second)) {
+       if t.timer.Reset(jitterDuration(queryResendEvery, time.Second)) {
                panic("timer should have fired to get here")
        }
 }
 
 func (t *transaction) sendQuery() error {
-       return t.s.writeToNode(t.queryPacket, t.remoteAddr)
+       err := t.s.writeToNode(t.queryPacket, t.remoteAddr)
+       if err != nil {
+               return err
+       }
+       t.lastSend = time.Now()
+       return nil
 }
 
 func (t *transaction) timeout() {
@@ -377,7 +407,8 @@ func (t *transaction) close() {
                return
        }
        t.queryPacket = nil
-       close(t.Response)
+       close(t.response)
+       t.tryHandleResponse()
        close(t.done)
        t.timer.Stop()
        go func() {
@@ -417,11 +448,12 @@ func (t *transaction) handleResponse(m Msg) {
        }
        t.queryPacket = nil
        select {
-       case t.Response <- m:
+       case t.response <- m:
        default:
                panic("blocked handling response")
        }
-       close(t.Response)
+       close(t.response)
+       t.tryHandleResponse()
 }
 
 func (s *Server) setDefaults() (err error) {
@@ -675,7 +707,7 @@ func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) {
        }
        n, err := s.socket.WriteTo(b, node.UDPAddr())
        if err != nil {
-               err = fmt.Errorf("error writing %d bytes to %s: %s", len(b), node, err)
+               err = fmt.Errorf("error writing %d bytes to %s: %#v", len(b), node, err)
                return
        }
        if n != len(b) {
@@ -716,7 +748,7 @@ func (s *Server) IDString() string {
        return s.id
 }
 
-func (s *Server) query(node dHTAddr, q string, a map[string]interface{}) (t *transaction, err error) {
+func (s *Server) query(node dHTAddr, q string, a map[string]interface{}, onResponse func(Msg)) (t *transaction, err error) {
        tid := s.nextTransactionID()
        if a == nil {
                a = make(map[string]interface{}, 1)
@@ -735,10 +767,11 @@ func (s *Server) query(node dHTAddr, q string, a map[string]interface{}) (t *tra
        t = &transaction{
                remoteAddr:  node,
                t:           tid,
-               Response:    make(chan Msg, 1),
+               response:    make(chan Msg, 1),
                done:        make(chan struct{}),
                queryPacket: b,
                s:           s,
+               onResponse:  onResponse,
        }
        err = t.sendQuery()
        if err != nil {
@@ -787,7 +820,7 @@ func (cni *NodeInfo) UnmarshalCompact(b []byte) error {
 func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) {
        s.mu.Lock()
        defer s.mu.Unlock()
-       return s.query(newDHTAddr(node), "ping", nil)
+       return s.query(newDHTAddr(node), "ping", nil, nil)
 }
 
 // Announce a local peer. This can only be done to nodes that gave us an
@@ -807,11 +840,11 @@ func (s *Server) AnnouncePeer(port int, impliedPort bool, infoHash string) (err
        return
 }
 
-func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token string, impliedPort bool) error {
+func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token string, impliedPort bool) (err error) {
        if port == 0 && !impliedPort {
                return errors.New("nothing to announce")
        }
-       t, err := s.query(node, "announce_peer", map[string]interface{}{
+       _, err = s.query(node, "announce_peer", map[string]interface{}{
                "implied_port": func() int {
                        if impliedPort {
                                return 1
@@ -822,22 +855,14 @@ func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token str
                "info_hash": infoHash,
                "port":      port,
                "token":     token,
-       })
-       t.setOnResponse(func(m Msg) {
+       }, func(m Msg) {
                if err := m.Error(); err != nil {
                        logonce.Stderr.Printf("announce_peer response: %s", err)
                        return
                }
                s.NumConfirmedAnnounces++
        })
-       return err
-}
-
-func (t *transaction) setOnResponse(f func(m Msg)) {
-       if t.onResponse != nil {
-               panic(t.onResponse)
-       }
-       t.onResponse = f
+       return
 }
 
 // Add response nodes to node table.
@@ -860,15 +885,14 @@ func (s *Server) liftNodes(d Msg) {
 
 // Sends a find_node query to addr. targetID is the node we're looking for.
 func (s *Server) findNode(addr dHTAddr, targetID string) (t *transaction, err error) {
-       t, err = s.query(addr, "find_node", map[string]interface{}{"target": targetID})
+       t, err = s.query(addr, "find_node", map[string]interface{}{"target": targetID}, func(d Msg) {
+               // Scrape peers from the response to put in the server's table before
+               // handing the response back to the caller.
+               s.liftNodes(d)
+       })
        if err != nil {
                return
        }
-       // Scrape peers from the response to put in the server's table before
-       // handing the response back to the caller.
-       t.setOnResponse(func(d Msg) {
-               s.liftNodes(d)
-       })
        return
 }
 
@@ -914,11 +938,7 @@ func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *transaction, err er
                err = fmt.Errorf("infohash has bad length")
                return
        }
-       t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash})
-       if err != nil {
-               return
-       }
-       t.setOnResponse(func(m Msg) {
+       t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}, func(m Msg) {
                s.liftNodes(m)
                at, ok := m.AnnounceToken()
                if ok {
@@ -978,10 +998,9 @@ func (s *Server) bootstrap() (err error) {
                                return
                        }
                        outstanding.Add(1)
-                       go func() {
-                               <-t.Response
+                       t.SetResponseHandler(func(Msg) {
                                outstanding.Done()
-                       }()
+                       })
                }
                noOutstanding := make(chan struct{})
                go func() {
index 8a8a1f7f46d93b4914a7daea10e528ab1cf3c9b5..a2e5d11ec597cfaafff61cca19ebd3d8ce17c6b2 100644 (file)
@@ -148,8 +148,11 @@ func TestPing(t *testing.T) {
                t.Fatal(err)
        }
        defer tn.Close()
-       msg := <-tn.Response
-       if msg.ID() != srv0.IDString() {
+       ok := make(chan bool)
+       tn.SetResponseHandler(func(msg Msg) {
+               ok <- msg.ID() == srv0.IDString()
+       })
+       if !<-ok {
                t.FailNow()
        }
 }
index 0435b59b7c1053ac8656211476d3f96bfe3e2e0d..f25603b48b82d563e04ca2c7cc8cda9bee9cb190 100644 (file)
@@ -105,7 +105,7 @@ func (me *peerDiscovery) contact(addr dHTAddr) {
        me.numContacted++
        me.triedAddrs.Add([]byte(addr.String()))
        if err := me.getPeers(addr); err != nil {
-               log.Printf("error sending get_peers request to %s: %s", addr, err)
+               log.Printf("error sending get_peers request to %s: %#v", addr, err)
                return
        }
        me.pending++
@@ -143,40 +143,36 @@ func (me *peerDiscovery) getPeers(addr dHTAddr) error {
        if err != nil {
                return err
        }
-       go func() {
-               select {
-               case m := <-t.Response:
-                       // Register suggested nodes closer to the target info-hash.
-                       me.mu.Lock()
-                       for _, n := range m.Nodes() {
-                               me.responseNode(n)
-                       }
-                       me.mu.Unlock()
+       t.SetResponseHandler(func(m Msg) {
+               // Register suggested nodes closer to the target info-hash.
+               me.mu.Lock()
+               for _, n := range m.Nodes() {
+                       me.responseNode(n)
+               }
+               me.mu.Unlock()
 
-                       if vs := m.Values(); vs != nil {
-                               nodeInfo := NodeInfo{
-                                       Addr: t.remoteAddr,
-                               }
-                               copy(nodeInfo.ID[:], m.ID())
-                               select {
-                               case me.peerStream.values <- peerStreamValue{
-                                       Peers:    vs,
-                                       NodeInfo: nodeInfo,
-                               }:
-                               case <-me.peerStream.stop:
-                               }
+               if vs := m.Values(); vs != nil {
+                       nodeInfo := NodeInfo{
+                               Addr: t.remoteAddr,
                        }
-
-                       if at, ok := m.AnnounceToken(); ok {
-                               me.announcePeer(addr, at)
+                       copy(nodeInfo.ID[:], m.ID())
+                       select {
+                       case me.peerStream.values <- peerStreamValue{
+                               Peers:    vs,
+                               NodeInfo: nodeInfo,
+                       }:
+                       case <-me.peerStream.stop:
                        }
-               case <-me.closingCh():
                }
-               t.Close()
+
+               if at, ok := m.AnnounceToken(); ok {
+                       me.announcePeer(addr, at)
+               }
+
                me.mu.Lock()
                me.transactionClosed()
                me.mu.Unlock()
-       }()
+       })
        return nil
 }