17 "github.com/anacrolix/torrent/iplist"
18 "github.com/anacrolix/torrent/logonce"
19 "github.com/anacrolix/torrent/util"
20 "github.com/anacrolix/sync"
21 "github.com/anacrolix/libtorgo/bencode"
26 queryResendEvery = 5 * time.Second
29 // Uniquely identifies a transaction to us.
30 type transactionKey struct {
31 RemoteAddr string // host:port
32 T string // The KRPC transaction ID.
38 transactions map[transactionKey]*transaction
39 transactionIDInt uint64
40 nodes map[string]*Node // Keyed by dHTAddr.String().
43 passive bool // Don't respond to queries.
44 ipBlockList *iplist.IPList
46 NumConfirmedAnnounces int
49 type dHTAddr interface {
51 UDPAddr() *net.UDPAddr
54 type cachedAddr struct {
59 func (ca cachedAddr) Network() string {
63 func (ca cachedAddr) String() string {
67 func (ca cachedAddr) UDPAddr() *net.UDPAddr {
68 return ca.a.(*net.UDPAddr)
71 func newDHTAddr(addr net.Addr) dHTAddr {
72 return cachedAddr{addr, addr.String()}
75 type ServerConfig struct {
78 Passive bool // Don't respond to queries.
81 type serverStats struct {
84 NumOutstandingTransactions int
87 func (s *Server) Stats() (ss serverStats) {
90 for _, n := range s.nodes {
91 if n.DefinitelyGood() {
95 ss.NumNodes = len(s.nodes)
96 ss.NumOutstandingTransactions = len(s.transactions)
100 func (s *Server) LocalAddr() net.Addr {
101 return s.socket.LocalAddr()
104 func makeSocket(addr string) (socket *net.UDPConn, err error) {
105 addr_, err := net.ResolveUDPAddr("", addr)
109 socket, err = net.ListenUDP("udp", addr_)
113 func NewServer(c *ServerConfig) (s *Server, err error) {
121 s.socket, err = makeSocket(c.Addr)
126 s.passive = c.Passive
145 log.Printf("error bootstrapping DHT: %s", err)
151 func (s *Server) String() string {
152 return fmt.Sprintf("dht server on %s", s.socket.LocalAddr())
160 func (nid *nodeID) IsUnset() bool {
164 func nodeIDFromString(s string) (ret nodeID) {
168 ret.i.SetBytes([]byte(s))
173 func (nid0 *nodeID) Distance(nid1 *nodeID) (ret big.Int) {
174 if nid0.IsUnset() != nid1.IsUnset() {
178 ret.Xor(&nid0.i, &nid1.i)
182 func (nid *nodeID) String() string {
183 return string(nid.i.Bytes())
191 lastGotQuery time.Time
192 lastGotResponse time.Time
193 lastSentQuery time.Time
196 func (n *Node) idString() string {
200 func (n *Node) SetIDFromBytes(b []byte) {
205 func (n *Node) SetIDFromString(s string) {
206 n.id.i.SetBytes([]byte(s))
209 func (n *Node) IDNotSet() bool {
210 return n.id.i.Int64() == 0
213 func (n *Node) NodeInfo() (ret NodeInfo) {
215 if n := copy(ret.ID[:], n.idString()); n != 20 {
221 func (n *Node) DefinitelyGood() bool {
222 if len(n.idString()) != 20 {
225 // No reason to think ill of them if they've never been queried.
226 if n.lastSentQuery.IsZero() {
229 // They answered our last query.
230 if n.lastSentQuery.Before(n.lastGotResponse) {
236 type Msg map[string]interface{}
238 var _ fmt.Stringer = Msg{}
240 func (m Msg) String() string {
241 return fmt.Sprintf("%#v", m)
244 func (m Msg) T() (t string) {
253 func (m Msg) ID() string {
257 return m[m["y"].(string)].(map[string]interface{})["id"].(string)
260 // Suggested nodes in a response.
261 func (m Msg) Nodes() (nodes []NodeInfo) {
266 return m["r"].(map[string]interface{})["nodes"].(string)
271 for i := 0; i < len(b); i += 26 {
273 err := n.UnmarshalCompact([]byte(b[i : i+26]))
277 nodes = append(nodes, n)
282 type KRPCError struct {
287 func (me KRPCError) Error() string {
288 return fmt.Sprintf("KRPC error %d: %s", me.Code, me.Msg)
291 var _ error = KRPCError{}
293 func (m Msg) Error() (ret *KRPCError) {
298 switch e := m["e"].(type) {
300 ret.Code = int(e[0].(int64))
301 ret.Msg = e[1].(string)
305 logonce.Stderr.Printf(`KRPC error "e" value has unexpected type: %T`, e)
310 // Returns the token given in response to a get_peers request for future
311 // announce_peer requests to that node.
312 func (m Msg) AnnounceToken() (token string, ok bool) {
313 defer func() { recover() }()
314 token, ok = m["r"].(map[string]interface{})["token"].(string)
318 type transaction struct {
323 onResponse func(Msg) // Called with the server locked.
330 userOnResponse func(Msg)
333 func (t *transaction) SetResponseHandler(f func(Msg)) {
337 t.tryHandleResponse()
340 func (t *transaction) tryHandleResponse() {
341 if t.userOnResponse == nil {
345 case r := <-t.response:
347 // Shouldn't be called more than once.
348 t.userOnResponse = nil
353 func (t *transaction) Key() transactionKey {
354 return transactionKey{
355 t.remoteAddr.String(),
360 func jitterDuration(average time.Duration, plusMinus time.Duration) time.Duration {
361 return average - plusMinus/2 + time.Duration(rand.Int63n(int64(plusMinus)))
364 func (t *transaction) startTimer() {
365 t.timer = time.AfterFunc(jitterDuration(queryResendEvery, time.Second), t.timerCallback)
368 func (t *transaction) timerCallback() {
382 if t.timer.Reset(jitterDuration(queryResendEvery, time.Second)) {
383 panic("timer should have fired to get here")
387 func (t *transaction) sendQuery() error {
388 err := t.s.writeToNode(t.queryPacket, t.remoteAddr)
392 t.lastSend = time.Now()
396 func (t *transaction) timeout() {
399 defer t.s.mu.Unlock()
400 t.s.nodeTimedOut(t.remoteAddr)
405 func (t *transaction) close() {
411 t.tryHandleResponse()
416 defer t.s.mu.Unlock()
417 t.s.deleteTransaction(t)
421 func (t *transaction) closing() bool {
430 func (t *transaction) Close() {
436 func (t *transaction) handleResponse(m Msg) {
444 if t.onResponse != nil {
451 case t.response <- m:
453 panic("blocked handling response")
456 t.tryHandleResponse()
459 func (s *Server) setDefaults() (err error) {
462 h := crypto.SHA1.New()
463 ss, err := os.Hostname()
467 ss += s.socket.LocalAddr().String()
469 if b := h.Sum(id[:0:20]); len(b) != 20 {
477 s.nodes = make(map[string]*Node, 10000)
481 func (s *Server) SetIPBlockList(list *iplist.IPList) {
487 func (s *Server) init() (err error) {
488 err = s.setDefaults()
492 s.closed = make(chan struct{})
493 s.transactions = make(map[transactionKey]*transaction)
497 func (s *Server) processPacket(b []byte, addr dHTAddr) {
499 err := bencode.Unmarshal(b, &d)
502 if se, ok := err.(*bencode.SyntaxError); ok {
503 // The message was truncated.
504 if int(se.Offset) == len(b) {
507 // Some messages seem to drop to nul chars abrubtly.
508 if int(se.Offset) < len(b) && b[se.Offset] == 0 {
511 // The message isn't bencode from the first.
516 log.Printf("%s: received bad krpc message from %s: %s: %q", s, addr, err, b)
523 s.handleQuery(addr, d)
526 t := s.findResponseTransaction(d.T(), addr)
528 //log.Printf("unexpected message: %#v", d)
531 node := s.getNode(addr)
532 node.lastGotResponse = time.Now()
533 // TODO: Update node ID as this is an authoritative packet.
534 go t.handleResponse(d)
535 s.deleteTransaction(t)
538 func (s *Server) serve() error {
541 n, addr, err := s.socket.ReadFrom(b[:])
546 logonce.Stderr.Printf("received dht packet exceeds buffer size")
549 s.processPacket(b[:n], newDHTAddr(addr))
553 func (s *Server) ipBlocked(ip net.IP) bool {
554 if s.ipBlockList == nil {
557 return s.ipBlockList.Lookup(ip) != nil
560 func (s *Server) AddNode(ni NodeInfo) {
564 s.nodes = make(map[string]*Node)
566 n := s.getNode(ni.Addr)
568 n.SetIDFromBytes(ni.ID[:])
572 func (s *Server) nodeByID(id string) *Node {
573 for _, node := range s.nodes {
574 if node.idString() == id {
581 func (s *Server) handleQuery(source dHTAddr, m Msg) {
582 args := m["a"].(map[string]interface{})
583 node := s.getNode(source)
584 node.SetIDFromString(args["id"].(string))
585 node.lastGotQuery = time.Now()
592 s.reply(source, m["t"].(string), nil)
593 case "get_peers": // TODO: Extract common behaviour with find_node.
594 targetID := args["info_hash"].(string)
595 if len(targetID) != 20 {
598 var rNodes []NodeInfo
599 // TODO: Reply with "values" list if we have peers instead.
600 for _, node := range s.closestGoodNodes(8, targetID) {
601 rNodes = append(rNodes, node.NodeInfo())
603 nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
604 for i, ni := range rNodes {
605 err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
610 s.reply(source, m["t"].(string), map[string]interface{}{
611 "nodes": string(nodesBytes),
614 case "find_node": // TODO: Extract common behaviour with get_peers.
615 targetID := args["target"].(string)
616 if len(targetID) != 20 {
617 log.Printf("bad DHT query: %v", m)
620 var rNodes []NodeInfo
621 if node := s.nodeByID(targetID); node != nil {
622 rNodes = append(rNodes, node.NodeInfo())
624 for _, node := range s.closestGoodNodes(8, targetID) {
625 rNodes = append(rNodes, node.NodeInfo())
628 nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
629 for i, ni := range rNodes {
630 // TODO: Put IPv6 nodes into the correct dict element.
631 if ni.Addr.UDPAddr().IP.To4() == nil {
634 err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
636 log.Printf("error compacting %#v: %s", ni, err)
640 s.reply(source, m["t"].(string), map[string]interface{}{
641 "nodes": string(nodesBytes),
643 case "announce_peer":
644 // TODO(anacrolix): Implement this lolz.
647 // TODO(anacrolix): Or reject, I don't think I want this.
649 log.Printf("%s: not handling received query: q=%s", s, m["q"])
654 func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) {
656 r = make(map[string]interface{}, 1)
658 r["id"] = s.IDString()
659 m := map[string]interface{}{
664 b, err := bencode.Marshal(m)
668 err = s.writeToNode(b, addr)
670 log.Printf("error replying to %s: %s", addr, err)
674 func (s *Server) getNode(addr dHTAddr) (n *Node) {
675 addrStr := addr.String()
681 if len(s.nodes) < maxNodes {
687 func (s *Server) nodeTimedOut(addr dHTAddr) {
688 node, ok := s.nodes[addr.String()]
692 if node.DefinitelyGood() {
695 if len(s.nodes) < maxNodes {
698 delete(s.nodes, addr.String())
701 func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) {
702 if list := s.ipBlockList; list != nil {
703 if r := list.Lookup(util.AddrIP(node.UDPAddr())); r != nil {
704 err = fmt.Errorf("write to %s blocked: %s", node, r.Description)
708 n, err := s.socket.WriteTo(b, node.UDPAddr())
710 err = fmt.Errorf("error writing %d bytes to %s: %#v", len(b), node, err)
714 err = io.ErrShortWrite
720 func (s *Server) findResponseTransaction(transactionID string, sourceNode dHTAddr) *transaction {
721 return s.transactions[transactionKey{
726 func (s *Server) nextTransactionID() string {
727 var b [binary.MaxVarintLen64]byte
728 n := binary.PutUvarint(b[:], s.transactionIDInt)
733 func (s *Server) deleteTransaction(t *transaction) {
734 delete(s.transactions, t.Key())
737 func (s *Server) addTransaction(t *transaction) {
738 if _, ok := s.transactions[t.Key()]; ok {
739 panic("transaction not unique")
741 s.transactions[t.Key()] = t
744 func (s *Server) IDString() string {
751 func (s *Server) query(node dHTAddr, q string, a map[string]interface{}, onResponse func(Msg)) (t *transaction, err error) {
752 tid := s.nextTransactionID()
754 a = make(map[string]interface{}, 1)
756 a["id"] = s.IDString()
757 d := map[string]interface{}{
763 b, err := bencode.Marshal(d)
770 response: make(chan Msg, 1),
771 done: make(chan struct{}),
774 onResponse: onResponse,
780 s.getNode(node).lastSentQuery = time.Now()
786 const CompactNodeInfoLen = 26
788 type NodeInfo struct {
793 func (ni *NodeInfo) PutCompact(b []byte) error {
794 if n := copy(b[:], ni.ID[:]); n != 20 {
797 ip := util.AddrIP(ni.Addr).To4()
799 return errors.New("expected ipv4 address")
801 if n := copy(b[20:], ip); n != 4 {
804 binary.BigEndian.PutUint16(b[24:], uint16(util.AddrPort(ni.Addr)))
808 func (cni *NodeInfo) UnmarshalCompact(b []byte) error {
810 return errors.New("expected 26 bytes")
812 util.CopyExact(cni.ID[:], b[:20])
813 cni.Addr = newDHTAddr(&net.UDPAddr{
814 IP: net.IPv4(b[20], b[21], b[22], b[23]),
815 Port: int(binary.BigEndian.Uint16(b[24:26])),
820 func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) {
823 return s.query(newDHTAddr(node), "ping", nil, nil)
826 // Announce a local peer. This can only be done to nodes that gave us an
827 // announce token, which is received in responses during GetPeers. It's
828 // recommended then that GetPeers is called before this method.
829 func (s *Server) AnnouncePeer(port int, impliedPort bool, infoHash string) (err error) {
832 for _, node := range s.closestNodes(160, nodeIDFromString(infoHash), func(n *Node) bool {
833 return n.announceToken != ""
835 err = s.announcePeer(node.addr, infoHash, port, node.announceToken, impliedPort)
843 func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token string, impliedPort bool) (err error) {
844 if port == 0 && !impliedPort {
845 return errors.New("nothing to announce")
847 _, err = s.query(node, "announce_peer", map[string]interface{}{
848 "implied_port": func() int {
855 "info_hash": infoHash,
859 if err := m.Error(); err != nil {
860 logonce.Stderr.Printf("announce_peer response: %s", err)
863 s.NumConfirmedAnnounces++
868 // Add response nodes to node table.
869 func (s *Server) liftNodes(d Msg) {
873 for _, cni := range d.Nodes() {
874 if util.AddrPort(cni.Addr) == 0 {
875 // TODO: Why would people even do this?
878 if s.ipBlocked(util.AddrIP(cni.Addr)) {
881 n := s.getNode(cni.Addr)
882 n.SetIDFromBytes(cni.ID[:])
886 // Sends a find_node query to addr. targetID is the node we're looking for.
887 func (s *Server) findNode(addr dHTAddr, targetID string) (t *transaction, err error) {
888 t, err = s.query(addr, "find_node", map[string]interface{}{"target": targetID}, func(d Msg) {
889 // Scrape peers from the response to put in the server's table before
890 // handing the response back to the caller.
899 // In a get_peers response, the addresses of torrent clients involved with the
900 // queried info-hash.
901 func (m Msg) Values() (vs []util.CompactPeer) {
906 rd, ok := r.(map[string]interface{})
910 v, ok := rd["values"]
914 vl, ok := v.([]interface{})
916 log.Printf("unexpected krpc values type: %T", v)
919 vs = make([]util.CompactPeer, 0, len(vl))
920 for _, i := range vl {
925 var cp util.CompactPeer
926 err := cp.UnmarshalBinary([]byte(s))
928 log.Printf("error decoding values list element: %s", err)
936 func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *transaction, err error) {
937 if len(infoHash) != 20 {
938 err = fmt.Errorf("infohash has bad length")
941 t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}, func(m Msg) {
943 at, ok := m.AnnounceToken()
945 s.getNode(addr).announceToken = at
951 func bootstrapAddrs() (addrs []*net.UDPAddr, err error) {
952 for _, addrStr := range []string{
953 "router.utorrent.com:6881",
954 "router.bittorrent.com:6881",
956 udpAddr, err := net.ResolveUDPAddr("udp4", addrStr)
960 addrs = append(addrs, udpAddr)
963 err = errors.New("nothing resolved")
968 func (s *Server) addRootNodes() error {
969 addrs, err := bootstrapAddrs()
973 for _, addr := range addrs {
974 s.nodes[addr.String()] = &Node{
975 addr: newDHTAddr(addr),
981 // Populates the node table.
982 func (s *Server) bootstrap() (err error) {
985 if len(s.nodes) == 0 {
986 err = s.addRootNodes()
992 var outstanding sync.WaitGroup
993 for _, node := range s.nodes {
995 t, err = s.findNode(node.addr, s.id)
997 err = fmt.Errorf("error sending find_node: %s", err)
1001 t.SetResponseHandler(func(Msg) {
1005 noOutstanding := make(chan struct{})
1008 close(noOutstanding)
1015 case <-time.After(15 * time.Second):
1016 case <-noOutstanding:
1019 // log.Printf("now have %d nodes", len(s.nodes))
1020 if s.numGoodNodes() >= 160 {
1027 func (s *Server) numGoodNodes() (num int) {
1028 for _, n := range s.nodes {
1029 if n.DefinitelyGood() {
1036 func (s *Server) NumNodes() int {
1042 func (s *Server) Nodes() (nis []NodeInfo) {
1045 for _, node := range s.nodes {
1046 // if !node.Good() {
1052 if n := copy(ni.ID[:], node.idString()); n != 20 && n != 0 {
1055 nis = append(nis, ni)
1060 func (s *Server) Close() {
1071 var maxDistance big.Int
1075 maxDistance.SetBit(&zero, 160, 1)
1078 func (s *Server) closestGoodNodes(k int, targetID string) []*Node {
1079 return s.closestNodes(k, nodeIDFromString(targetID), func(n *Node) bool { return n.DefinitelyGood() })
1082 func (s *Server) closestNodes(k int, target nodeID, filter func(*Node) bool) []*Node {
1083 sel := newKClosestNodesSelector(k, target)
1084 idNodes := make(map[string]*Node, len(s.nodes))
1085 for _, node := range s.nodes {
1090 idNodes[node.idString()] = node
1093 ret := make([]*Node, 0, len(ids))
1094 for _, id := range ids {
1095 ret = append(ret, idNodes[id.String()])