From: Matt Joiner Date: Wed, 20 May 2015 12:23:50 +0000 (+1000) Subject: dht: Implement the DHT security extension X-Git-Tag: v1.0.0~1180 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=b7061506070cbe91d7116865d3d7d1cf8c66c8dc;p=btrtrc.git dht: Implement the DHT security extension --- diff --git a/dht/addr.go b/dht/addr.go new file mode 100644 index 00000000..ea6589d4 --- /dev/null +++ b/dht/addr.go @@ -0,0 +1,41 @@ +package dht + +import ( + "net" + + "github.com/anacrolix/torrent/util" +) + +// Used internally to refer to node network addresses. +type dHTAddr interface { + net.Addr + UDPAddr() *net.UDPAddr + IP() net.IP +} + +// Speeds up some of the commonly called Addr methods. +type cachedAddr struct { + a net.Addr + s string + ip net.IP +} + +func (ca cachedAddr) Network() string { + return ca.a.Network() +} + +func (ca cachedAddr) String() string { + return ca.s +} + +func (ca cachedAddr) UDPAddr() *net.UDPAddr { + return ca.a.(*net.UDPAddr) +} + +func (ca cachedAddr) IP() net.IP { + return ca.ip +} + +func newDHTAddr(addr net.Addr) dHTAddr { + return cachedAddr{addr, addr.String(), util.AddrIP(addr)} +} diff --git a/dht/dht.go b/dht/dht.go index deb45e79..c2432699 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -11,6 +11,7 @@ import ( "encoding/binary" "errors" "fmt" + "hash/crc32" "io" "log" "math/big" @@ -51,32 +52,7 @@ type Server struct { numConfirmedAnnounces int bootstrapNodes []string -} - -type dHTAddr interface { - net.Addr - UDPAddr() *net.UDPAddr -} - -type cachedAddr struct { - a net.Addr - s string -} - -func (ca cachedAddr) Network() string { - return ca.a.Network() -} - -func (ca cachedAddr) String() string { - return ca.s -} - -func (ca cachedAddr) UDPAddr() *net.UDPAddr { - return ca.a.(*net.UDPAddr) -} - -func newDHTAddr(addr net.Addr) dHTAddr { - return cachedAddr{addr, addr.String()} + config ServerConfig } type ServerConfig struct { @@ -86,6 +62,9 @@ type ServerConfig struct { Passive bool // DHT Bootstrap nodes BootstrapNodes []string + // Disable the DHT security extension: + // http://www.libtorrent.org/dht_sec.html. + NoSecurity bool } type ServerStats struct { @@ -135,7 +114,9 @@ func NewServer(c *ServerConfig) (s *Server, err error) { if c == nil { c = &ServerConfig{} } - s = &Server{} + s = &Server{ + config: *c, + } if c.Conn != nil { s.socket = c.Conn } else { @@ -202,8 +183,11 @@ func (nid0 *nodeID) Distance(nid1 *nodeID) (ret big.Int) { return } -func (nid *nodeID) String() string { - return string(nid.i.Bytes()) +func (nid *nodeID) ByteString() string { + var buf [20]byte + b := nid.i.Bytes() + copy(buf[20-len(b):], b) + return string(buf[:]) } type node struct { @@ -216,17 +200,27 @@ type node struct { lastSentQuery time.Time } +func (n *node) IsSecure() bool { + if n.id.IsUnset() { + return false + } + return nodeIdSecure(n.id.ByteString(), n.addr.IP()) +} + func (n *node) idString() string { - return n.id.String() + return n.id.ByteString() } func (n *node) SetIDFromBytes(b []byte) { + if len(b) != 20 { + panic(b) + } n.id.i.SetBytes(b) n.id.set = true } func (n *node) SetIDFromString(s string) { - n.id.i.SetBytes([]byte(s)) + n.SetIDFromBytes([]byte(s)) } func (n *node) IDNotSet() bool { @@ -485,6 +479,60 @@ func (t *Transaction) handleResponse(m Msg) { t.tryHandleResponse() } +func maskForIP(ip net.IP) []byte { + switch { + case ip.To4() != nil: + return []byte{0x03, 0x0f, 0x3f, 0xff} + default: + return []byte{0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f, 0xff} + } +} + +// Generate the CRC used to make or validate secure node ID. +func crcIP(ip net.IP, rand uint8) uint32 { + if ip4 := ip.To4(); ip4 != nil { + ip = ip4 + } + // Copy IP so we can make changes. Go sux at this. + ip = append(make(net.IP, 0, len(ip)), ip...) + mask := maskForIP(ip) + for i := range mask { + ip[i] &= mask[i] + } + r := rand & 7 + ip[0] |= r << 5 + return crc32.Checksum(ip[:len(mask)], crc32.MakeTable(crc32.Castagnoli)) +} + +// Makes a node ID valid, in-place. +func secureNodeId(id []byte, ip net.IP) { + crc := crcIP(ip, id[19]) + id[0] = byte(crc >> 24 & 0xff) + id[1] = byte(crc >> 16 & 0xff) + id[2] = byte(crc>>8&0xf8) | id[2]&7 +} + +// http://www.libtorrent.org/dht_sec.html +func nodeIdSecure(id string, ip net.IP) bool { + if len(id) != 20 { + panic(fmt.Sprintf("%q", id)) + } + if ip4 := ip.To4(); ip4 != nil { + ip = ip4 + } + crc := crcIP(ip, id[19]) + if id[0] != byte(crc>>24&0xff) { + return false + } + if id[1] != byte(crc>>16&0xff) { + return false + } + if id[2]&0xf8 != byte(crc>>8&0xf8) { + return false + } + return true +} + func (s *Server) setDefaults() (err error) { if s.id == "" { var id [20]byte @@ -501,6 +549,7 @@ func (s *Server) setDefaults() (err error) { if len(id) != 20 { panic(len(id)) } + secureNodeId(id[:], util.AddrIP(s.socket.LocalAddr())) s.id = string(id[:]) } s.nodes = make(map[string]*node, 10000) @@ -558,7 +607,7 @@ func (s *Server) processPacket(b []byte, addr dHTAddr) { //log.Printf("unexpected message: %#v", d) return } - node := s.getNode(addr) + node := s.getNode(addr, d.ID()) node.lastGotResponse = time.Now() // TODO: Update node ID as this is an authoritative packet. go t.handleResponse(d) @@ -597,10 +646,7 @@ func (s *Server) AddNode(ni NodeInfo) { if s.nodes == nil { s.nodes = make(map[string]*node) } - n := s.getNode(ni.Addr) - if n.IDNotSet() { - n.SetIDFromBytes(ni.ID[:]) - } + s.getNode(ni.Addr, string(ni.ID[:])) } func (s *Server) nodeByID(id string) *node { @@ -614,7 +660,7 @@ func (s *Server) nodeByID(id string) *node { func (s *Server) handleQuery(source dHTAddr, m Msg) { args := m["a"].(map[string]interface{}) - node := s.getNode(source) + node := s.getNode(source, m.ID()) node.SetIDFromString(args["id"].(string)) node.lastGotQuery = time.Now() // Don't respond. @@ -705,19 +751,33 @@ func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) { } } -func (s *Server) getNode(addr dHTAddr) (n *node) { +// Returns a node struct for the addr. It is taken from the table or created +// and possibly added if required and meets validity constraints. +func (s *Server) getNode(addr dHTAddr, id string) (n *node) { addrStr := addr.String() n = s.nodes[addrStr] - if n == nil { - n = &node{ - addr: addr, - } - if len(s.nodes) < maxNodes { - s.nodes[addrStr] = n + if n != nil { + if id != "" { + n.SetIDFromString(id) } + return + } + n = &node{ + addr: addr, } + if id != "" { + n.SetIDFromString(id) + } + if len(s.nodes) >= maxNodes { + return + } + if !s.config.NoSecurity && !n.IsSecure() { + return + } + s.nodes[addrStr] = n return } + func (s *Server) nodeTimedOut(addr dHTAddr) { node, ok := s.nodes[addr.String()] if !ok { @@ -813,7 +873,7 @@ func (s *Server) query(node dHTAddr, q string, a map[string]interface{}, onRespo if err != nil { return } - s.getNode(node).lastSentQuery = time.Now() + s.getNode(node, "").lastSentQuery = time.Now() t.startTimer() s.addTransaction(t) return @@ -901,7 +961,7 @@ func (s *Server) liftNodes(d Msg) { if s.ipBlocked(util.AddrIP(cni.Addr)) { continue } - n := s.getNode(cni.Addr) + n := s.getNode(cni.Addr, string(cni.ID[:])) n.SetIDFromBytes(cni.ID[:]) } } @@ -965,7 +1025,7 @@ func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *Transaction, err er s.liftNodes(m) at, ok := m.AnnounceToken() if ok { - s.getNode(addr).announceToken = at + s.getNode(addr, m.ID()).announceToken = at } }) return @@ -1122,7 +1182,7 @@ func (s *Server) closestNodes(k int, target nodeID, filter func(*node) bool) []* ids := sel.IDs() ret := make([]*node, 0, len(ids)) for _, id := range ids { - ret = append(ret, idNodes[id.String()]) + ret = append(ret, idNodes[id.ByteString()]) } return ret } diff --git a/dht/dht_test.go b/dht/dht_test.go index f5e1c63c..8091242a 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -1,10 +1,13 @@ package dht import ( + "encoding/hex" "math/big" "math/rand" "net" "testing" + + "github.com/anacrolix/torrent/util" ) func TestSetNilBigInt(t *testing.T) { @@ -94,9 +97,9 @@ func TestClosestNodes(t *testing.T) { } m := map[string]bool{} for _, id := range cn.IDs() { - m[id.String()] = true + m[id.ByteString()] = true } - if !m[testIDs[3].String()] || !m[testIDs[4].String()] { + if !m[testIDs[3].ByteString()] || !m[testIDs[4].ByteString()] { t.FailNow() } } @@ -156,3 +159,55 @@ func TestPing(t *testing.T) { t.FailNow() } } + +func TestDHTSec(t *testing.T) { + for _, case_ := range []struct { + ipStr string + nodeIDHex string + valid bool + }{ + // These 5 are from the spec example. They are all valid. + {"124.31.75.21", "5fbfbff10c5d6a4ec8a88e4c6ab4c28b95eee401", true}, + {"21.75.31.124", "5a3ce9c14e7a08645677bbd1cfe7d8f956d53256", true}, + {"65.23.51.170", "a5d43220bc8f112a3d426c84764f8c2a1150e616", true}, + {"84.124.73.14", "1b0321dd1bb1fe518101ceef99462b947a01ff41", true}, + {"43.213.53.83", "e56f6cbf5b7c4be0237986d5243b87aa6d51305a", true}, + // spec[0] with one of the rand() bytes changed. Valid. + {"124.31.75.21", "5fbfbff10c5d7a4ec8a88e4c6ab4c28b95eee401", true}, + // spec[1] with the 21st leading bit changed. Not Valid. + {"21.75.31.124", "5a3ce1c14e7a08645677bbd1cfe7d8f956d53256", false}, + // spec[2] with the 22nd leading bit changed. Valid. + {"65.23.51.170", "a5d43620bc8f112a3d426c84764f8c2a1150e616", true}, + // spec[3] with the 4th last bit changed. Valid. + {"84.124.73.14", "1b0321dd1bb1fe518101ceef99462b947a01fe01", true}, + // spec[4] with the 3rd last bit changed. Not valid. + {"43.213.53.83", "e56f6cbf5b7c4be0237986d5243b87aa6d51303e", false}, + } { + ip := net.ParseIP(case_.ipStr) + id, err := hex.DecodeString(case_.nodeIDHex) + if err != nil { + t.Fatal(err) + } + secure := nodeIdSecure(string(id), ip) + if secure != case_.valid { + t.Fatalf("case failed: %v", case_) + } + if !secure { + secureNodeId(id, ip) + if !nodeIdSecure(string(id), ip) { + t.Fatal("failed to secure node id") + } + } + } +} + +func TestServerDefaultNodeIdSecure(t *testing.T) { + s, err := NewServer(nil) + if err != nil { + t.Fatal(err) + } + defer s.Close() + if !nodeIdSecure(s.ID(), util.AddrIP(s.Addr())) { + t.Fatal("not secure") + } +}