From: Matt Joiner Date: Sat, 21 Feb 2015 04:00:48 +0000 (+1100) Subject: dht: Handle responses through a callback X-Git-Tag: v1.0.0~1328 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=a372b68abbb1d1690c8a3387b8c113b979bef0b4;p=btrtrc.git dht: Handle responses through a callback Trying to reduce number of goroutines --- diff --git a/cmd/dht-ping/main.go b/cmd/dht-ping/main.go index 80cd2cc1..d1d98ad7 100644 --- a/cmd/dht-ping/main.go +++ b/cmd/dht-ping/main.go @@ -47,14 +47,15 @@ func main() { log.Fatal(err) } start := time.Now() - go func(addr string) { - resp := <-t.Response - pingResponses <- pingResponse{ - addr: addr, - krpc: resp, - rtt: time.Now().Sub(start), + t.SetResponseHandler(func(addr string) func(dht.Msg) { + return func(resp dht.Msg) { + pingResponses <- pingResponse{ + addr: addr, + krpc: resp, + rtt: time.Now().Sub(start), + } } - }(netloc) + }(netloc)) } if *timeout >= 0 { time.Sleep(*timeout) diff --git a/dht/dht.go b/dht/dht.go index cac9fbd0..5b426045 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -21,7 +21,10 @@ import ( "github.com/anacrolix/libtorgo/bencode" ) -const maxNodes = 1000 +const ( + maxNodes = 320 + queryResendEvery = 5 * time.Second +) // Uniquely identifies a transaction to us. type transactionKey struct { @@ -313,16 +316,38 @@ func (m Msg) AnnounceToken() (token string, ok bool) { } type transaction struct { - mu sync.Mutex - remoteAddr dHTAddr - t string - Response chan Msg - onResponse func(Msg) // Called with the server locked. - done chan struct{} - queryPacket []byte - timer *time.Timer - s *Server - retries int + mu sync.Mutex + remoteAddr dHTAddr + t string + response chan Msg + onResponse func(Msg) // Called with the server locked. + done chan struct{} + queryPacket []byte + timer *time.Timer + s *Server + retries int + lastSend time.Time + userOnResponse func(Msg) +} + +func (t *transaction) SetResponseHandler(f func(Msg)) { + t.mu.Lock() + defer t.mu.Unlock() + t.userOnResponse = f + t.tryHandleResponse() +} + +func (t *transaction) tryHandleResponse() { + if t.userOnResponse == nil { + return + } + select { + case r := <-t.response: + t.userOnResponse(r) + // Shouldn't be called more than once. + t.userOnResponse = nil + default: + } } func (t *transaction) Key() transactionKey { @@ -337,7 +362,7 @@ func jitterDuration(average time.Duration, plusMinus time.Duration) time.Duratio } func (t *transaction) startTimer() { - t.timer = time.AfterFunc(jitterDuration(20*time.Second, time.Second), t.timerCallback) + t.timer = time.AfterFunc(jitterDuration(queryResendEvery, time.Second), t.timerCallback) } func (t *transaction) timerCallback() { @@ -354,13 +379,18 @@ func (t *transaction) timerCallback() { } t.retries++ t.sendQuery() - if t.timer.Reset(jitterDuration(20*time.Second, time.Second)) { + if t.timer.Reset(jitterDuration(queryResendEvery, time.Second)) { panic("timer should have fired to get here") } } func (t *transaction) sendQuery() error { - return t.s.writeToNode(t.queryPacket, t.remoteAddr) + err := t.s.writeToNode(t.queryPacket, t.remoteAddr) + if err != nil { + return err + } + t.lastSend = time.Now() + return nil } func (t *transaction) timeout() { @@ -377,7 +407,8 @@ func (t *transaction) close() { return } t.queryPacket = nil - close(t.Response) + close(t.response) + t.tryHandleResponse() close(t.done) t.timer.Stop() go func() { @@ -417,11 +448,12 @@ func (t *transaction) handleResponse(m Msg) { } t.queryPacket = nil select { - case t.Response <- m: + case t.response <- m: default: panic("blocked handling response") } - close(t.Response) + close(t.response) + t.tryHandleResponse() } func (s *Server) setDefaults() (err error) { @@ -675,7 +707,7 @@ func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) { } n, err := s.socket.WriteTo(b, node.UDPAddr()) if err != nil { - err = fmt.Errorf("error writing %d bytes to %s: %s", len(b), node, err) + err = fmt.Errorf("error writing %d bytes to %s: %#v", len(b), node, err) return } if n != len(b) { @@ -716,7 +748,7 @@ func (s *Server) IDString() string { return s.id } -func (s *Server) query(node dHTAddr, q string, a map[string]interface{}) (t *transaction, err error) { +func (s *Server) query(node dHTAddr, q string, a map[string]interface{}, onResponse func(Msg)) (t *transaction, err error) { tid := s.nextTransactionID() if a == nil { a = make(map[string]interface{}, 1) @@ -735,10 +767,11 @@ func (s *Server) query(node dHTAddr, q string, a map[string]interface{}) (t *tra t = &transaction{ remoteAddr: node, t: tid, - Response: make(chan Msg, 1), + response: make(chan Msg, 1), done: make(chan struct{}), queryPacket: b, s: s, + onResponse: onResponse, } err = t.sendQuery() if err != nil { @@ -787,7 +820,7 @@ func (cni *NodeInfo) UnmarshalCompact(b []byte) error { func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) { s.mu.Lock() defer s.mu.Unlock() - return s.query(newDHTAddr(node), "ping", nil) + return s.query(newDHTAddr(node), "ping", nil, nil) } // Announce a local peer. This can only be done to nodes that gave us an @@ -807,11 +840,11 @@ func (s *Server) AnnouncePeer(port int, impliedPort bool, infoHash string) (err return } -func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token string, impliedPort bool) error { +func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token string, impliedPort bool) (err error) { if port == 0 && !impliedPort { return errors.New("nothing to announce") } - t, err := s.query(node, "announce_peer", map[string]interface{}{ + _, err = s.query(node, "announce_peer", map[string]interface{}{ "implied_port": func() int { if impliedPort { return 1 @@ -822,22 +855,14 @@ func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token str "info_hash": infoHash, "port": port, "token": token, - }) - t.setOnResponse(func(m Msg) { + }, func(m Msg) { if err := m.Error(); err != nil { logonce.Stderr.Printf("announce_peer response: %s", err) return } s.NumConfirmedAnnounces++ }) - return err -} - -func (t *transaction) setOnResponse(f func(m Msg)) { - if t.onResponse != nil { - panic(t.onResponse) - } - t.onResponse = f + return } // Add response nodes to node table. @@ -860,15 +885,14 @@ func (s *Server) liftNodes(d Msg) { // Sends a find_node query to addr. targetID is the node we're looking for. func (s *Server) findNode(addr dHTAddr, targetID string) (t *transaction, err error) { - t, err = s.query(addr, "find_node", map[string]interface{}{"target": targetID}) + t, err = s.query(addr, "find_node", map[string]interface{}{"target": targetID}, func(d Msg) { + // Scrape peers from the response to put in the server's table before + // handing the response back to the caller. + s.liftNodes(d) + }) if err != nil { return } - // Scrape peers from the response to put in the server's table before - // handing the response back to the caller. - t.setOnResponse(func(d Msg) { - s.liftNodes(d) - }) return } @@ -914,11 +938,7 @@ func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *transaction, err er err = fmt.Errorf("infohash has bad length") return } - t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}) - if err != nil { - return - } - t.setOnResponse(func(m Msg) { + t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}, func(m Msg) { s.liftNodes(m) at, ok := m.AnnounceToken() if ok { @@ -978,10 +998,9 @@ func (s *Server) bootstrap() (err error) { return } outstanding.Add(1) - go func() { - <-t.Response + t.SetResponseHandler(func(Msg) { outstanding.Done() - }() + }) } noOutstanding := make(chan struct{}) go func() { diff --git a/dht/dht_test.go b/dht/dht_test.go index 8a8a1f7f..a2e5d11e 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -148,8 +148,11 @@ func TestPing(t *testing.T) { t.Fatal(err) } defer tn.Close() - msg := <-tn.Response - if msg.ID() != srv0.IDString() { + ok := make(chan bool) + tn.SetResponseHandler(func(msg Msg) { + ok <- msg.ID() == srv0.IDString() + }) + if !<-ok { t.FailNow() } } diff --git a/dht/getpeers.go b/dht/getpeers.go index 0435b59b..f25603b4 100644 --- a/dht/getpeers.go +++ b/dht/getpeers.go @@ -105,7 +105,7 @@ func (me *peerDiscovery) contact(addr dHTAddr) { me.numContacted++ me.triedAddrs.Add([]byte(addr.String())) if err := me.getPeers(addr); err != nil { - log.Printf("error sending get_peers request to %s: %s", addr, err) + log.Printf("error sending get_peers request to %s: %#v", addr, err) return } me.pending++ @@ -143,40 +143,36 @@ func (me *peerDiscovery) getPeers(addr dHTAddr) error { if err != nil { return err } - go func() { - select { - case m := <-t.Response: - // Register suggested nodes closer to the target info-hash. - me.mu.Lock() - for _, n := range m.Nodes() { - me.responseNode(n) - } - me.mu.Unlock() + t.SetResponseHandler(func(m Msg) { + // Register suggested nodes closer to the target info-hash. + me.mu.Lock() + for _, n := range m.Nodes() { + me.responseNode(n) + } + me.mu.Unlock() - if vs := m.Values(); vs != nil { - nodeInfo := NodeInfo{ - Addr: t.remoteAddr, - } - copy(nodeInfo.ID[:], m.ID()) - select { - case me.peerStream.values <- peerStreamValue{ - Peers: vs, - NodeInfo: nodeInfo, - }: - case <-me.peerStream.stop: - } + if vs := m.Values(); vs != nil { + nodeInfo := NodeInfo{ + Addr: t.remoteAddr, } - - if at, ok := m.AnnounceToken(); ok { - me.announcePeer(addr, at) + copy(nodeInfo.ID[:], m.ID()) + select { + case me.peerStream.values <- peerStreamValue{ + Peers: vs, + NodeInfo: nodeInfo, + }: + case <-me.peerStream.stop: } - case <-me.closingCh(): } - t.Close() + + if at, ok := m.AnnounceToken(); ok { + me.announcePeer(addr, at) + } + me.mu.Lock() me.transactionClosed() me.mu.Unlock() - }() + }) return nil }