From 3204e276f2f4127f619e6eb5d078987935695af8 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Tue, 18 Aug 2015 02:11:09 +1000 Subject: [PATCH] dht: Improve on on Msg methods --- dht/announce.go | 2 +- dht/dht.go | 21 +++++++++++++++++---- dht/dht_test.go | 2 +- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/dht/announce.go b/dht/announce.go index 84b3c7bc..a7c55868 100644 --- a/dht/announce.go +++ b/dht/announce.go @@ -178,7 +178,7 @@ func (me *Announce) getPeers(addr dHTAddr) error { nodeInfo := NodeInfo{ Addr: t.remoteAddr, } - copy(nodeInfo.ID[:], m.ID()) + copy(nodeInfo.ID[:], m.SenderID()) select { case me.values <- PeersValues{ Peers: vs, diff --git a/dht/dht.go b/dht/dht.go index ed45395c..bca1c778 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -289,11 +289,24 @@ func (m Msg) T() (t string) { return } -func (m Msg) ID() string { +func (m Msg) Args() map[string]interface{} { defer func() { recover() }() - return m[m["y"].(string)].(map[string]interface{})["id"].(string) + return m["a"].(map[string]interface{}) +} + +func (m Msg) SenderID() string { + defer func() { + recover() + }() + switch m["y"].(string) { + case "q": + return m.Args()["id"].(string) + case "r": + return m["r"].(map[string]interface{})["id"].(string) + } + return "" } // Suggested nodes in a response. @@ -647,7 +660,7 @@ func (s *Server) processPacket(b []byte, addr dHTAddr) { //log.Printf("unexpected message: %#v", d) return } - node := s.getNode(addr, d.ID()) + node := s.getNode(addr, d.SenderID()) node.lastGotResponse = time.Now() // TODO: Update node ID as this is an authoritative packet. go t.handleResponse(d) @@ -1090,7 +1103,7 @@ func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *Transaction, err er s.liftNodes(m) at, ok := m.AnnounceToken() if ok { - s.getNode(addr, m.ID()).announceToken = at + s.getNode(addr, m.SenderID()).announceToken = at } }) return diff --git a/dht/dht_test.go b/dht/dht_test.go index 6636ad6f..f84ccb68 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -145,7 +145,7 @@ func TestPing(t *testing.T) { defer tn.Close() ok := make(chan bool) tn.SetResponseHandler(func(msg Msg) { - ok <- msg.ID() == srv0.ID() + ok <- msg.SenderID() == srv0.ID() }) if !<-ok { t.FailNow() -- 2.48.1