From 7c3d919cfb03f5e59012eaadfaecaf97ce8bd2e9 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Sat, 24 May 2014 16:51:56 +1000 Subject: [PATCH] Begin implementing DHT --- cmd/dht-server/main.go | 53 ++++++++++++ dht/dht.go | 187 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 240 insertions(+) create mode 100644 cmd/dht-server/main.go create mode 100644 dht/dht.go diff --git a/cmd/dht-server/main.go b/cmd/dht-server/main.go new file mode 100644 index 00000000..c539b2ee --- /dev/null +++ b/cmd/dht-server/main.go @@ -0,0 +1,53 @@ +package main + +import ( + "bitbucket.org/anacrolix/go.torrent/dht" + "log" + "net" +) + +type pingResponse struct { + addr string + krpc dht.Msg +} + +func main() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + s := dht.Server{} + var err error + s.Socket, err = net.ListenPacket("udp4", "") + if err != nil { + log.Fatal(err) + } + log.Printf("dht server on %s", s.Socket.LocalAddr()) + go func() { + err := s.Serve() + if err != nil { + log.Fatal(err) + } + }() + pingResponses := make(chan pingResponse) + pingStrAddrs := []string{ + "router.utorrent.com:6881", + "router.bittorrent.com:6881", + } + for _, netloc := range pingStrAddrs { + addr, err := net.ResolveUDPAddr("udp4", netloc) + if err != nil { + log.Fatal(err) + } + t, err := s.Ping(addr) + if err != nil { + log.Fatal(err) + } + go func(addr string) { + pingResponses <- pingResponse{ + addr: addr, + krpc: <-t.Response, + } + }(netloc) + } + for _ = range pingStrAddrs { + log.Print(<-pingResponses) + } +} diff --git a/dht/dht.go b/dht/dht.go new file mode 100644 index 00000000..c6b5634f --- /dev/null +++ b/dht/dht.go @@ -0,0 +1,187 @@ +package dht + +import ( + "crypto/rand" + "encoding/binary" + "fmt" + "github.com/nsf/libtorgo/bencode" + "io" + "log" + "net" + "time" +) + +type Server struct { + ID string + Socket net.PacketConn + transactions []*transaction + transactionIDInt uint64 + nodes map[string]*Node +} + +type Node struct { + addr net.Addr + id string + lastHeardFrom time.Time + lastSentTo time.Time +} + +type Msg map[string]interface{} + +var _ fmt.Stringer = Msg{} + +func (m Msg) String() string { + return fmt.Sprintf("%#v", m) +} + +type transaction struct { + remoteAddr net.Addr + t string + Response chan Msg +} + +func (s *Server) setDefaults() { + if s.ID == "" { + var id [20]byte + _, err := rand.Read(id[:]) + if err != nil { + panic(err) + } + s.ID = string(id[:]) + } +} + +func (s *Server) init() { + s.nodes = make(map[string]*Node, 1000) +} + +func (s *Server) Serve() error { + s.setDefaults() + s.init() + for { + var b [1500]byte + n, addr, err := s.Socket.ReadFrom(b[:]) + if err != nil { + return err + } + var d map[string]interface{} + err = bencode.Unmarshal(b[:n], &d) + if err != nil { + log.Printf("bad krpc message: %s", err) + continue + } + t := s.findResponseTransaction(d["t"].(string), addr) + t.Response <- d + s.removeTransaction(t) + id := "" + if d["y"] == "r" { + id = d["r"].(map[string]interface{})["id"].(string) + } + s.heardFromNode(addr, id) + } +} + +func (s *Server) heardFromNode(addr net.Addr, id string) { + n := s.getNode(addr) + n.id = id + n.lastHeardFrom = time.Now() +} + +func (s *Server) getNode(addr net.Addr) (n *Node) { + n = s.nodes[addr.String()] + if n == nil { + n = &Node{ + addr: addr, + } + s.nodes[addr.String()] = n + } + return +} + +func (s *Server) sentToNode(addr net.Addr) { + n := s.getNode(addr) + n.lastSentTo = time.Now() +} + +func (s *Server) findResponseTransaction(transactionID string, sourceNode net.Addr) *transaction { + for _, t := range s.transactions { + if t.t == transactionID && t.remoteAddr.String() == sourceNode.String() { + return t + } + } + return nil +} + +func (s *Server) nextTransactionID() string { + var b [binary.MaxVarintLen64]byte + n := binary.PutUvarint(b[:], s.transactionIDInt) + s.transactionIDInt++ + return string(b[:n]) +} + +func (s *Server) removeTransaction(t *transaction) { + for i, tt := range s.transactions { + if t == tt { + last := len(s.transactions) - 1 + s.transactions[i] = s.transactions[last] + s.transactions = s.transactions[:last] + return + } + } + panic("transaction not found") +} + +func (s *Server) addTransaction(t *transaction) { + s.transactions = append(s.transactions, t) +} + +func (s *Server) IDString() string { + if len(s.ID) != 20 { + panic("bad node id") + } + return s.ID +} + +func (s *Server) query(node net.Addr, q string, a map[string]string) (t *transaction, err error) { + tid := s.nextTransactionID() + if a == nil { + a = make(map[string]string, 1) + } + a["id"] = s.IDString() + d := map[string]interface{}{ + "t": tid, + "y": "q", + "q": q, + "a": a, + } + b, err := bencode.Marshal(d) + if err != nil { + return + } + t = &transaction{ + remoteAddr: node, + t: tid, + Response: make(chan Msg, 1), + } + s.addTransaction(t) + n, err := s.Socket.WriteTo(b, node) + if err != nil { + s.removeTransaction(t) + return + } + if n != len(b) { + err = io.ErrShortWrite + s.removeTransaction(t) + return + } + s.sentToNode(node) + return +} + +func (s *Server) GetPeers(node *net.UDPAddr, targetInfoHash [20]byte) { + +} + +func (s *Server) Ping(node net.Addr) (*transaction, error) { + return s.query(node, "ping", nil) +} -- 2.48.1