import (
"bitbucket.org/anacrolix/go.torrent/dht"
+ "flag"
+ "fmt"
+ "io"
"log"
"net"
"os"
+ "os/signal"
)
type pingResponse struct {
krpc dht.Msg
}
-func main() {
+var (
+ tableFileName = flag.String("tableFile", "", "name of file for storing node info")
+ serveAddr = flag.String("serveAddr", ":0", "local UDP address")
+
+ s dht.Server
+)
+
+func loadTable() error {
+ if *tableFileName == "" {
+ return nil
+ }
+ f, err := os.Open(*tableFileName)
+ if os.IsNotExist(err) {
+ return nil
+ }
+ if err != nil {
+ return fmt.Errorf("error opening table file: %s", err)
+ }
+ defer f.Close()
+ added := 0
+ for {
+ b := make([]byte, dht.CompactNodeInfoLen)
+ _, err := io.ReadFull(f, b)
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return fmt.Errorf("error reading table file: %s", err)
+ }
+ var ni dht.NodeInfo
+ err = ni.UnmarshalCompact(b)
+ if err != nil {
+ return fmt.Errorf("error unmarshaling compact node info: %s", err)
+ }
+ s.AddNode(ni)
+ added++
+ }
+ log.Printf("loaded %d nodes from table file", added)
+ return nil
+}
+
+func init() {
log.SetFlags(log.LstdFlags | log.Lshortfile)
- s := dht.Server{}
- var err error
- s.Socket, err = net.ListenUDP("udp4", nil)
+ flag.Parse()
+ err := loadTable()
+ if err != nil {
+ log.Fatalf("error loading table: %s", err)
+ }
+ s.Socket, err = net.ListenUDP("udp4", func() *net.UDPAddr {
+ addr, err := net.ResolveUDPAddr("udp4", *serveAddr)
+ if err != nil {
+ log.Fatalf("error resolving serve addr: %s", err)
+ }
+ return addr
+ }())
if err != nil {
log.Fatal(err)
}
+ log.Printf("dht server on %s", s.Socket.LocalAddr())
s.Init()
- func() {
- f, err := os.Open("nodes")
- if os.IsNotExist(err) {
- return
+ setupSignals()
+}
+
+func saveTable() error {
+ goodNodes := s.GoodNodes()
+ if *tableFileName == "" {
+ if len(goodNodes) != 0 {
+ log.Printf("discarding %d good nodes!", len(goodNodes))
}
+ return nil
+ }
+ f, err := os.OpenFile(*tableFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
+ if err != nil {
+ return fmt.Errorf("error opening table file: %s", err)
+ }
+ defer f.Close()
+ for _, nodeInfo := range goodNodes {
+ var b [dht.CompactNodeInfoLen]byte
+ err := nodeInfo.PutCompact(b[:])
if err != nil {
- log.Fatal(err)
+ return fmt.Errorf("error compacting node info: %s", err)
}
- defer f.Close()
- err = s.ReadNodes(f)
+ _, err = f.Write(b[:])
if err != nil {
- log.Fatal(err)
+ return fmt.Errorf("error writing compact node info: %s", err)
}
- }()
- log.Printf("dht server on %s", s.Socket.LocalAddr())
+ }
+ log.Printf("saved %d nodes to table file", len(goodNodes))
+ return nil
+}
+
+func setupSignals() {
+ ch := make(chan os.Signal)
+ signal.Notify(ch)
go func() {
- err := s.Serve()
- if err != nil {
- log.Fatal(err)
- }
+ <-ch
+ s.StopServing()
}()
- err = s.Bootstrap()
- func() {
- f, err := os.OpenFile("nodes", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
+}
+
+func main() {
+ go func() {
+ err := s.Bootstrap()
if err != nil {
- log.Print(err)
- return
+ log.Printf("error bootstrapping: %s", err)
+ s.StopServing()
}
- defer f.Close()
- s.WriteNodes(f)
}()
+ err := s.Serve()
+ if err := saveTable(); err != nil {
+ log.Printf("error saving node table: %s", err)
+ }
if err != nil {
- log.Fatal(err)
+ log.Fatalf("error serving dht: %s", err)
}
}
"io"
"log"
"net"
+ "sync"
"time"
)
transactions []*transaction
transactionIDInt uint64
nodes map[string]*Node
+ mu sync.Mutex
}
type Node struct {
lastSentTo time.Time
}
+func (n *Node) Good() bool {
+ if len(n.id) != 20 {
+ return false
+ }
+ if time.Now().Sub(n.lastHeardFrom) >= 15*time.Minute {
+ return false
+ }
+ return true
+}
+
type Msg map[string]interface{}
var _ fmt.Stringer = Msg{}
response chan Msg
}
-func (s *Server) ReadNodes(r io.Reader) error {
- for {
- var b [compactNodeInfoLen]byte
- _, err := io.ReadFull(r, b[:])
- if err == io.EOF {
- return nil
- }
- if err != nil {
- return err
- }
- var cni compactNodeInfo
- err = cni.UnmarshalBinary(b[:])
- if err != nil {
- return err
- }
- n := s.getNode(cni.Addr)
- n.id = string(cni.ID[:])
- }
-}
-
-func (s *Server) WriteNodes(w io.Writer) (n int, err error) {
- for _, node := range s.nodes {
- cni := compactNodeInfo{
- Addr: node.addr,
- }
- if n := copy(cni.ID[:], node.id); n != 20 {
- panic(n)
- }
- var b [26]byte
- cni.PutBinary(b[:])
- var nn int
- nn, err = w.Write(b[:])
- if err != nil {
- return
- }
- n += nn
- }
- return
-}
-
func (s *Server) setDefaults() {
if s.ID == "" {
var id [20]byte
func (s *Server) Init() {
s.setDefaults()
- s.nodes = make(map[string]*Node, 1000)
}
func (s *Server) Serve() error {
log.Printf("bad krpc message: %s", err)
continue
}
+ s.mu.Lock()
if d["y"] == "q" {
s.handleQuery(addr, d)
+ s.mu.Unlock()
continue
}
t := s.findResponseTransaction(d["t"].(string), addr)
if t == nil {
log.Printf("unexpected message: %#v", d)
+ s.mu.Unlock()
continue
}
t.response <- d
id = d["r"].(map[string]interface{})["id"].(string)
}
s.heardFromNode(addr, id)
+ s.mu.Unlock()
+ }
+}
+
+func (s *Server) AddNode(ni NodeInfo) {
+ if s.nodes == nil {
+ s.nodes = make(map[string]*Node)
+ }
+ n := s.getNode(ni.Addr)
+ if n.id == "" {
+ n.id = string(ni.ID[:])
}
}
func (s *Server) handleQuery(source *net.UDPAddr, m Msg) {
+ log.Print(m["q"])
if m["q"] != "ping" {
return
}
- s.heardFromNode(source, m["a"].(map[string]string)["id"])
+ s.heardFromNode(source, m["a"].(map[string]interface{})["id"].(string))
s.reply(source, m["t"].(string))
}
return
}
-const compactNodeInfoLen = 26
-
-type compactAddrInfo *net.UDPAddr
+const CompactNodeInfoLen = 26
-type compactNodeInfo struct {
+type NodeInfo struct {
ID [20]byte
- Addr compactAddrInfo
+ Addr *net.UDPAddr
}
-func (cni *compactNodeInfo) PutBinary(b []byte) {
- if n := copy(b[:], cni.ID[:]); n != 20 {
+func (ni *NodeInfo) PutCompact(b []byte) error {
+ if n := copy(b[:], ni.ID[:]); n != 20 {
panic(n)
}
- ip := cni.Addr.IP.To4()
+ ip := ni.Addr.IP.To4()
if len(ip) != 4 {
panic(ip)
}
if n := copy(b[20:], ip); n != 4 {
panic(n)
}
- binary.BigEndian.PutUint16(b[24:], uint16(cni.Addr.Port))
+ binary.BigEndian.PutUint16(b[24:], uint16(ni.Addr.Port))
+ return nil
}
-func (cni *compactNodeInfo) UnmarshalBinary(b []byte) error {
+func (cni *NodeInfo) UnmarshalCompact(b []byte) error {
if len(b) != 26 {
return errors.New("expected 26 bytes")
}
}
type findNodeResponse struct {
- Nodes []compactNodeInfo
+ Nodes []NodeInfo
}
func getResponseNodes(m Msg) (s string, err error) {
return err
}
for i := 0; i < len(b); i += 26 {
- var n compactNodeInfo
- err := n.UnmarshalBinary([]byte(b[i : i+26]))
+ var n NodeInfo
+ err := n.UnmarshalCompact([]byte(b[i : i+26]))
if err != nil {
return err
}
}
func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
- // log.Print(addr)
t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
if err != nil {
return
if err != nil {
log.Print(err)
} else {
+ s.mu.Lock()
for _, cni := range r.Nodes {
n := s.getNode(cni.Addr)
n.id = string(cni.ID[:])
}
+ s.mu.Unlock()
}
}
t.Response <- d
return
}
-func (s *Server) Bootstrap() error {
+func (s *Server) addRootNode() error {
+ addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881")
+ if err != nil {
+ return err
+ }
+ s.nodes[addr.String()] = &Node{
+ addr: addr,
+ }
+ return nil
+}
+
+// Populates the node table.
+func (s *Server) Bootstrap() (err error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
if len(s.nodes) == 0 {
- addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881")
+ err = s.addRootNode()
if err != nil {
- return err
+ return
}
- s.nodes[addr.String()] = &Node{
- addr: addr,
+ }
+ for _, node := range s.nodes {
+ var t *transaction
+ s.mu.Unlock()
+ t, err = s.FindNode(node.addr, s.ID)
+ s.mu.Lock()
+ if err != nil {
+ return
}
+ go func() {
+ <-t.Response
+ }()
}
- queriedNodes := make(map[string]bool, 1000)
- for i := 0; i < 3; i++ {
- log.Printf("node table length: %d", len(s.nodes))
- for _, node := range s.nodes {
- if queriedNodes[node.addr.String()] {
- continue
- }
- t, err := s.FindNode(node.addr, s.ID)
- if err != nil {
- return err
- }
- queriedNodes[node.addr.String()] = true
- go func() {
- <-t.Response
- }()
+ return
+}
+
+func (s *Server) GoodNodes() (nis []NodeInfo) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for _, node := range s.nodes {
+ if !node.Good() {
+ continue
}
- time.Sleep(3 * time.Second)
+ ni := NodeInfo{
+ Addr: node.addr,
+ }
+ if n := copy(ni.ID[:], node.id); n != 20 {
+ panic(n)
+ }
+ nis = append(nis, ni)
}
- return nil
+ return
+}
+
+func (s *Server) StopServing() {
+ s.Socket.Close()
}