From 3d043ffc3984c9cfc62abb3e864ccc69052414b1 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Sat, 26 Nov 2016 16:14:37 +1100 Subject: [PATCH] dht.Server: Return valid token from get_peers, and handle incoming announce_peer Addresses #133. --- dht/dht.go | 7 +++++- dht/expvar.go | 3 +++ dht/krpc/msg.go | 17 +++++++++------ dht/server.go | 38 ++++++++++++++++++++++++++++---- dht/tokens.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++ dht/tokens_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 161 insertions(+), 12 deletions(-) create mode 100644 dht/tokens.go create mode 100644 dht/tokens_test.go diff --git a/dht/dht.go b/dht/dht.go index 3b2eb858..edb19395 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -11,6 +11,7 @@ import ( "github.com/anacrolix/torrent/dht/krpc" "github.com/anacrolix/torrent/iplist" + "github.com/anacrolix/torrent/metainfo" ) const ( @@ -65,7 +66,11 @@ type ServerConfig struct { // Used to secure the server's ID. Defaults to the Conn's LocalAddr(). PublicIP net.IP - OnQuery func(*krpc.Msg, net.Addr) bool + // Hook received queries. Return true if you don't want to propagate to + // the default handlers. + OnQuery func(query *krpc.Msg, source net.Addr) (propagate bool) + // Called when a peer successfully announces to us. + OnAnnouncePeer func(infoHash metainfo.Hash, peer Peer) } // ServerStats instance is returned by Server.Stats() and stores Server metrics diff --git a/dht/expvar.go b/dht/expvar.go index 11a0a987..cf12ca68 100644 --- a/dht/expvar.go +++ b/dht/expvar.go @@ -10,7 +10,10 @@ var ( readNotKRPCDict = expvar.NewInt("dhtReadNotKRPCDict") readUnmarshalError = expvar.NewInt("dhtReadUnmarshalError") readQuery = expvar.NewInt("dhtReadQuery") + readQueryBad = expvar.NewInt("dhtQueryBad") + readAnnouncePeer = expvar.NewInt("dhtReadAnnouncePeer") announceErrors = expvar.NewInt("dhtAnnounceErrors") writeErrors = expvar.NewInt("dhtWriteErrors") writes = expvar.NewInt("dhtWrites") + readInvalidToken = expvar.NewInt("dhtReadInvalidToken") ) diff --git a/dht/krpc/msg.go b/dht/krpc/msg.go index b765b9f4..ff9d990b 100644 --- a/dht/krpc/msg.go +++ b/dht/krpc/msg.go @@ -29,16 +29,19 @@ type Msg struct { } type MsgArgs struct { - ID string `bencode:"id"` // ID of the quirying Node - InfoHash string `bencode:"info_hash"` // InfoHash of the torrent - Target string `bencode:"target"` // ID of the node sought + ID string `bencode:"id"` // ID of the quirying Node + InfoHash string `bencode:"info_hash"` // InfoHash of the torrent + Target string `bencode:"target"` // ID of the node sought + Token string `bencode:"token"` // Token received from an earlier get_peers query + Port int `bencode:"port"` // Senders torrent port + ImpliedPort int `bencode:"implied_port"` // Use senders apparent DHT port } type Return struct { - ID string `bencode:"id"` // ID of the querying node - Nodes CompactIPv4NodeInfo `bencode:"nodes,omitempty"` - Token string `bencode:"token,omitempty"` - Values []util.CompactPeer `bencode:"values,omitempty"` + ID string `bencode:"id"` // ID of the querying node + Nodes CompactIPv4NodeInfo `bencode:"nodes,omitempty"` // K closest nodes to the requested target + Token string `bencode:"token,omitempty"` // Token for future announce_peer + Values []util.CompactPeer `bencode:"values,omitempty"` // Torrent peers } var _ fmt.Stringer = Msg{} diff --git a/dht/server.go b/dht/server.go index bf269398..14fa5dcd 100644 --- a/dht/server.go +++ b/dht/server.go @@ -20,6 +20,7 @@ import ( "github.com/anacrolix/torrent/dht/krpc" "github.com/anacrolix/torrent/iplist" "github.com/anacrolix/torrent/logonce" + "github.com/anacrolix/torrent/metainfo" ) // A Server defines parameters for a DHT node server that is able to @@ -40,6 +41,7 @@ type Server struct { closed missinggo.Event ipBlockList iplist.Ranger badNodes *boom.BloomFilter + tokenServer tokenServer numConfirmedAnnounces int bootstrapNodes []string @@ -251,6 +253,8 @@ func (s *Server) nodeByID(id string) *node { return nil } +// TODO: Probably should write error messages back to senders if something is +// wrong. func (s *Server) handleQuery(source Addr, m krpc.Msg) { node := s.getNode(source, m.SenderID()) node.lastGotQuery = time.Now() @@ -280,8 +284,7 @@ func (s *Server) handleQuery(source Addr, m krpc.Msg) { } s.reply(source, m.T, krpc.Return{ Nodes: rNodes, - // TODO: Generate this dynamically, and store it for the source. - Token: "hi", + Token: s.createToken(source), }) case "find_node": // TODO: Extract common behaviour with get_peers. targetID := args.Target @@ -302,8 +305,27 @@ func (s *Server) handleQuery(source Addr, m krpc.Msg) { Nodes: rNodes, }) case "announce_peer": - // TODO(anacrolix): Implement this lolz. - // log.Print(m) + readAnnouncePeer.Add(1) + if !s.validToken(args.Token, source) { + readInvalidToken.Add(1) + return + } + if len(args.InfoHash) != 20 { + readQueryBad.Add(1) + return + } + if h := s.config.OnAnnouncePeer; h != nil { + var ih metainfo.Hash + copy(ih[:], args.InfoHash) + p := Peer{ + IP: source.UDPAddr().IP, + Port: args.Port, + } + if args.ImpliedPort != 0 { + p.Port = source.UDPAddr().Port + } + h(ih, p) + } case "vote": // TODO(anacrolix): Or reject, I don't think I want this. default: @@ -428,6 +450,14 @@ func (s *Server) ID() string { return s.id } +func (s *Server) createToken(addr Addr) string { + return s.tokenServer.CreateToken(addr) +} + +func (s *Server) validToken(token string, addr Addr) bool { + return s.tokenServer.ValidToken(token, addr) +} + func (s *Server) query(node Addr, q string, a map[string]interface{}, onResponse func(krpc.Msg)) (t *Transaction, err error) { tid := s.nextTransactionID() if a == nil { diff --git a/dht/tokens.go b/dht/tokens.go new file mode 100644 index 00000000..9ecb6e3e --- /dev/null +++ b/dht/tokens.go @@ -0,0 +1,54 @@ +package dht + +import ( + "crypto/sha1" + "encoding/binary" + "time" + + "github.com/bradfitz/iter" +) + +// Manages creation and validation of tokens issued to querying nodes. +type tokenServer struct { + secret []byte + interval time.Duration + maxIntervalDelta int + timeNow func() time.Time +} + +func (me tokenServer) CreateToken(addr Addr) string { + return me.createToken(addr, me.getTimeNow()) +} + +func (me tokenServer) createToken(addr Addr, t time.Time) string { + h := sha1.New() + ip := addr.UDPAddr().IP.To16() + if len(ip) != 16 { + panic(ip) + } + h.Write(ip) + ti := t.UnixNano() / int64(me.interval) + var b [8]byte + binary.BigEndian.PutUint64(b[:], uint64(ti)) + h.Write(b[:]) + h.Write(me.secret) + return string(h.Sum(nil)) +} + +func (me *tokenServer) ValidToken(token string, addr Addr) bool { + t := me.getTimeNow() + for range iter.N(me.maxIntervalDelta + 1) { + if me.createToken(addr, t) == token { + return true + } + t = t.Add(-me.interval) + } + return false +} + +func (me *tokenServer) getTimeNow() time.Time { + if me.timeNow == nil { + return time.Now() + } + return me.timeNow() +} diff --git a/dht/tokens_test.go b/dht/tokens_test.go new file mode 100644 index 00000000..ae19e663 --- /dev/null +++ b/dht/tokens_test.go @@ -0,0 +1,54 @@ +package dht + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTokenServer(t *testing.T) { + addr1 := NewAddr(&net.UDPAddr{ + IP: []byte{1, 2, 3, 4}, + }) + addr2 := NewAddr(&net.UDPAddr{ + IP: []byte{1, 2, 3, 3}, + }) + ts := tokenServer{ + secret: []byte("42"), + interval: 5 * time.Minute, + maxIntervalDelta: 2, + } + tok := ts.CreateToken(addr1) + assert.Len(t, tok, 20) + assert.True(t, ts.ValidToken(tok, addr1)) + assert.False(t, ts.ValidToken(tok[1:], addr1)) + assert.False(t, ts.ValidToken(tok, addr2)) + func() { + ts0 := ts + ts0.secret = nil + assert.False(t, ts0.ValidToken(tok, addr1)) + }() + now := time.Now() + setTime := func(t time.Time) { + ts.timeNow = func() time.Time { + return t + } + } + setTime(now) + tok = ts.CreateToken(addr1) + assert.True(t, ts.ValidToken(tok, addr1)) + setTime(time.Time{}) + assert.False(t, ts.ValidToken(tok, addr1)) + setTime(now.Add(-5 * time.Minute)) + assert.False(t, ts.ValidToken(tok, addr1)) + setTime(now) + assert.True(t, ts.ValidToken(tok, addr1)) + setTime(now.Add(5 * time.Minute)) + assert.True(t, ts.ValidToken(tok, addr1)) + setTime(now.Add(2 * 5 * time.Minute)) + assert.True(t, ts.ValidToken(tok, addr1)) + setTime(now.Add(3 * 5 * time.Minute)) + assert.False(t, ts.ValidToken(tok, addr1)) +} -- 2.50.0