transactionIDInt uint64
nodes map[string]*Node
mu sync.Mutex
+ closed chan struct{}
+}
+
+func (s *Server) String() string {
+ return fmt.Sprintf("dht server on %s", s.Socket.LocalAddr())
}
type Node struct {
remoteAddr net.Addr
t string
Response chan Msg
- response chan Msg
+ onResponse func(Msg)
+}
+
+func (t *transaction) handleResponse(m Msg) {
+ if t.onResponse != nil {
+ t.onResponse(m)
+ }
+ t.Response <- m
+ close(t.Response)
}
func (s *Server) setDefaults() (err error) {
return
}
-func (s *Server) Init() error {
- return s.setDefaults()
- //s.nodes = make(map[string]*Node)
+func (s *Server) Init() (err error) {
+ err = s.setDefaults()
+ if err != nil {
+ return
+ }
+ s.closed = make(chan struct{})
+ return
}
func (s *Server) Serve() error {
var d map[string]interface{}
err = bencode.Unmarshal(b[:n], &d)
if err != nil {
- log.Printf("bad krpc message: %s: %q", err, b[:n])
+ log.Printf("%s: received bad krpc message: %s: %q", s, err, b[:n])
continue
}
s.mu.Lock()
s.mu.Unlock()
continue
}
- t.response <- d
+ t.handleResponse(d)
s.removeTransaction(t)
id := ""
if d["y"] == "r" {
}
func (s *Server) handleQuery(source *net.UDPAddr, m Msg) {
- log.Print(m["q"])
if m["q"] != "ping" {
+ log.Printf("%s: not handling received query: q=%s", s, m["q"])
return
}
s.heardFromNode(source, m["a"].(map[string]interface{})["id"].(string))
t: tid,
Response: make(chan Msg, 1),
}
- t.response = t.Response
s.addTransaction(t)
err = s.writeToNode(b, node)
if err != nil {
return nil
}
-func (t *transaction) onResponse(f func(m Msg)) {
- ch := make(chan Msg)
- t.response = ch
- go func() {
- d, ok := <-t.response
- if !ok {
- close(t.Response)
+func (t *transaction) setOnResponse(f func(m Msg)) {
+ if t.onResponse != nil {
+ panic(t.onResponse)
+ }
+ t.onResponse = f
+}
+
+func unmarshalNodeInfoBinary(b []byte) (ret []NodeInfo, err error) {
+ if len(b)%26 != 0 {
+ err = errors.New("bad buffer length")
+ return
+ }
+ ret = make([]NodeInfo, 0, len(b)/26)
+ for i := 0; i < len(b); i += 26 {
+ var ni NodeInfo
+ err = ni.UnmarshalCompact(b[i : i+26])
+ if err != nil {
return
}
- f(d)
- t.Response <- d
- }()
+ ret = append(ret, ni)
+ }
+ return
+}
+
+func extractNodes(d Msg) (nodes []NodeInfo, err error) {
+ if d["y"] != "r" {
+ return
+ }
+ r, ok := d["r"]
+ if !ok {
+ err = errors.New("missing r dict")
+ return
+ }
+ rd, ok := r.(map[string]interface{})
+ if !ok {
+ err = errors.New("bad r value type")
+ return
+ }
+ n, ok := rd["nodes"]
+ if !ok {
+ return
+ }
+ ns, ok := n.(string)
+ if !ok {
+ err = errors.New("bad nodes value type")
+ return
+ }
+ return unmarshalNodeInfoBinary([]byte(ns))
}
func (s *Server) liftNodes(d Msg) {
if err != nil {
// log.Print(err)
} else {
- s.mu.Lock()
for _, cni := range r.Nodes {
n := s.getNode(cni.Addr)
n.id = string(cni.ID[:])
}
- s.mu.Unlock()
// log.Printf("lifted %d nodes", len(r.Nodes))
}
}
// Sends a find_node query to addr. targetID is the node we're looking for.
-func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
+func (s *Server) findNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
if err != nil {
return
}
// Scrape peers from the response to put in the server's table before
// handing the response back to the caller.
- t.onResponse(func(d Msg) {
+ t.setOnResponse(func(d Msg) {
s.liftNodes(d)
})
return
case m := <-t.Response:
vs := extractValues(m)
if vs != nil {
- ps.Values <- vs
- // } else {
- // log.Print("get_peers response had no values")
+ select {
+ case ps.Values <- vs:
+ case <-ps.stop:
+ }
}
case <-ps.stop:
}
s.mu.Unlock()
go func() {
for ; pending > 0; pending-- {
- <-done
+ select {
+ case <-done:
+ case <-s.closed:
+ }
}
ps.Close()
}()
if err != nil {
return
}
- t.onResponse(func(m Msg) {
+ t.setOnResponse(func(m Msg) {
s.liftNodes(m)
})
return
defer s.mu.Unlock()
if len(s.nodes) == 0 {
err = s.addRootNode()
- if err != nil {
- return
- }
+ }
+ if err != nil {
+ return
}
for {
+ var outstanding sync.WaitGroup
for _, node := range s.nodes {
var t *transaction
- s.mu.Unlock()
- t, err = s.FindNode(node.addr, s.ID)
- s.mu.Lock()
+ t, err = s.findNode(node.addr, s.ID)
if err != nil {
return
}
+ outstanding.Add(1)
go func() {
<-t.Response
+ outstanding.Done()
}()
}
- time.Sleep(5 * time.Second)
+ noOutstanding := make(chan struct{})
+ go func() {
+ outstanding.Wait()
+ close(noOutstanding)
+ }()
+ s.mu.Unlock()
+ select {
+ case <-s.closed:
+ s.mu.Lock()
+ return
+ case <-time.After(15 * time.Second):
+ case <-noOutstanding:
+ }
+ s.mu.Lock()
log.Printf("now have %d nodes", len(s.nodes))
if len(s.nodes) >= 8*160 {
break
func (s *Server) StopServing() {
s.Socket.Close()
+ s.mu.Lock()
+ select {
+ case <-s.closed:
+ default:
+ close(s.closed)
+ }
+ s.mu.Unlock()
}
func idDistance(a, b string) (ret int) {