]> Sergey Matveev's repositories - btrtrc.git/commitdiff
dht: Implement the DHT security extension
authorMatt Joiner <anacrolix@gmail.com>
Wed, 20 May 2015 12:23:50 +0000 (22:23 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Wed, 20 May 2015 12:23:50 +0000 (22:23 +1000)
dht/addr.go [new file with mode: 0644]
dht/dht.go
dht/dht_test.go

diff --git a/dht/addr.go b/dht/addr.go
new file mode 100644 (file)
index 0000000..ea6589d
--- /dev/null
@@ -0,0 +1,41 @@
+package dht
+
+import (
+       "net"
+
+       "github.com/anacrolix/torrent/util"
+)
+
+// Used internally to refer to node network addresses.
+type dHTAddr interface {
+       net.Addr
+       UDPAddr() *net.UDPAddr
+       IP() net.IP
+}
+
+// Speeds up some of the commonly called Addr methods.
+type cachedAddr struct {
+       a  net.Addr
+       s  string
+       ip net.IP
+}
+
+func (ca cachedAddr) Network() string {
+       return ca.a.Network()
+}
+
+func (ca cachedAddr) String() string {
+       return ca.s
+}
+
+func (ca cachedAddr) UDPAddr() *net.UDPAddr {
+       return ca.a.(*net.UDPAddr)
+}
+
+func (ca cachedAddr) IP() net.IP {
+       return ca.ip
+}
+
+func newDHTAddr(addr net.Addr) dHTAddr {
+       return cachedAddr{addr, addr.String(), util.AddrIP(addr)}
+}
index deb45e79449afb6cbb22bee2b96cd37b350600a0..c2432699dbba0ef5298430d5e5e2f71e4612b0e0 100644 (file)
@@ -11,6 +11,7 @@ import (
        "encoding/binary"
        "errors"
        "fmt"
+       "hash/crc32"
        "io"
        "log"
        "math/big"
@@ -51,32 +52,7 @@ type Server struct {
 
        numConfirmedAnnounces int
        bootstrapNodes        []string
-}
-
-type dHTAddr interface {
-       net.Addr
-       UDPAddr() *net.UDPAddr
-}
-
-type cachedAddr struct {
-       a net.Addr
-       s string
-}
-
-func (ca cachedAddr) Network() string {
-       return ca.a.Network()
-}
-
-func (ca cachedAddr) String() string {
-       return ca.s
-}
-
-func (ca cachedAddr) UDPAddr() *net.UDPAddr {
-       return ca.a.(*net.UDPAddr)
-}
-
-func newDHTAddr(addr net.Addr) dHTAddr {
-       return cachedAddr{addr, addr.String()}
+       config                ServerConfig
 }
 
 type ServerConfig struct {
@@ -86,6 +62,9 @@ type ServerConfig struct {
        Passive bool
        // DHT Bootstrap nodes
        BootstrapNodes []string
+       // Disable the DHT security extension:
+       // http://www.libtorrent.org/dht_sec.html.
+       NoSecurity bool
 }
 
 type ServerStats struct {
@@ -135,7 +114,9 @@ func NewServer(c *ServerConfig) (s *Server, err error) {
        if c == nil {
                c = &ServerConfig{}
        }
-       s = &Server{}
+       s = &Server{
+               config: *c,
+       }
        if c.Conn != nil {
                s.socket = c.Conn
        } else {
@@ -202,8 +183,11 @@ func (nid0 *nodeID) Distance(nid1 *nodeID) (ret big.Int) {
        return
 }
 
-func (nid *nodeID) String() string {
-       return string(nid.i.Bytes())
+func (nid *nodeID) ByteString() string {
+       var buf [20]byte
+       b := nid.i.Bytes()
+       copy(buf[20-len(b):], b)
+       return string(buf[:])
 }
 
 type node struct {
@@ -216,17 +200,27 @@ type node struct {
        lastSentQuery   time.Time
 }
 
+func (n *node) IsSecure() bool {
+       if n.id.IsUnset() {
+               return false
+       }
+       return nodeIdSecure(n.id.ByteString(), n.addr.IP())
+}
+
 func (n *node) idString() string {
-       return n.id.String()
+       return n.id.ByteString()
 }
 
 func (n *node) SetIDFromBytes(b []byte) {
+       if len(b) != 20 {
+               panic(b)
+       }
        n.id.i.SetBytes(b)
        n.id.set = true
 }
 
 func (n *node) SetIDFromString(s string) {
-       n.id.i.SetBytes([]byte(s))
+       n.SetIDFromBytes([]byte(s))
 }
 
 func (n *node) IDNotSet() bool {
@@ -485,6 +479,60 @@ func (t *Transaction) handleResponse(m Msg) {
        t.tryHandleResponse()
 }
 
+func maskForIP(ip net.IP) []byte {
+       switch {
+       case ip.To4() != nil:
+               return []byte{0x03, 0x0f, 0x3f, 0xff}
+       default:
+               return []byte{0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f, 0xff}
+       }
+}
+
+// Generate the CRC used to make or validate secure node ID.
+func crcIP(ip net.IP, rand uint8) uint32 {
+       if ip4 := ip.To4(); ip4 != nil {
+               ip = ip4
+       }
+       // Copy IP so we can make changes. Go sux at this.
+       ip = append(make(net.IP, 0, len(ip)), ip...)
+       mask := maskForIP(ip)
+       for i := range mask {
+               ip[i] &= mask[i]
+       }
+       r := rand & 7
+       ip[0] |= r << 5
+       return crc32.Checksum(ip[:len(mask)], crc32.MakeTable(crc32.Castagnoli))
+}
+
+// Makes a node ID valid, in-place.
+func secureNodeId(id []byte, ip net.IP) {
+       crc := crcIP(ip, id[19])
+       id[0] = byte(crc >> 24 & 0xff)
+       id[1] = byte(crc >> 16 & 0xff)
+       id[2] = byte(crc>>8&0xf8) | id[2]&7
+}
+
+// http://www.libtorrent.org/dht_sec.html
+func nodeIdSecure(id string, ip net.IP) bool {
+       if len(id) != 20 {
+               panic(fmt.Sprintf("%q", id))
+       }
+       if ip4 := ip.To4(); ip4 != nil {
+               ip = ip4
+       }
+       crc := crcIP(ip, id[19])
+       if id[0] != byte(crc>>24&0xff) {
+               return false
+       }
+       if id[1] != byte(crc>>16&0xff) {
+               return false
+       }
+       if id[2]&0xf8 != byte(crc>>8&0xf8) {
+               return false
+       }
+       return true
+}
+
 func (s *Server) setDefaults() (err error) {
        if s.id == "" {
                var id [20]byte
@@ -501,6 +549,7 @@ func (s *Server) setDefaults() (err error) {
                if len(id) != 20 {
                        panic(len(id))
                }
+               secureNodeId(id[:], util.AddrIP(s.socket.LocalAddr()))
                s.id = string(id[:])
        }
        s.nodes = make(map[string]*node, 10000)
@@ -558,7 +607,7 @@ func (s *Server) processPacket(b []byte, addr dHTAddr) {
                //log.Printf("unexpected message: %#v", d)
                return
        }
-       node := s.getNode(addr)
+       node := s.getNode(addr, d.ID())
        node.lastGotResponse = time.Now()
        // TODO: Update node ID as this is an authoritative packet.
        go t.handleResponse(d)
@@ -597,10 +646,7 @@ func (s *Server) AddNode(ni NodeInfo) {
        if s.nodes == nil {
                s.nodes = make(map[string]*node)
        }
-       n := s.getNode(ni.Addr)
-       if n.IDNotSet() {
-               n.SetIDFromBytes(ni.ID[:])
-       }
+       s.getNode(ni.Addr, string(ni.ID[:]))
 }
 
 func (s *Server) nodeByID(id string) *node {
@@ -614,7 +660,7 @@ func (s *Server) nodeByID(id string) *node {
 
 func (s *Server) handleQuery(source dHTAddr, m Msg) {
        args := m["a"].(map[string]interface{})
-       node := s.getNode(source)
+       node := s.getNode(source, m.ID())
        node.SetIDFromString(args["id"].(string))
        node.lastGotQuery = time.Now()
        // Don't respond.
@@ -705,19 +751,33 @@ func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) {
        }
 }
 
-func (s *Server) getNode(addr dHTAddr) (n *node) {
+// Returns a node struct for the addr. It is taken from the table or created
+// and possibly added if required and meets validity constraints.
+func (s *Server) getNode(addr dHTAddr, id string) (n *node) {
        addrStr := addr.String()
        n = s.nodes[addrStr]
-       if n == nil {
-               n = &node{
-                       addr: addr,
-               }
-               if len(s.nodes) < maxNodes {
-                       s.nodes[addrStr] = n
+       if n != nil {
+               if id != "" {
+                       n.SetIDFromString(id)
                }
+               return
+       }
+       n = &node{
+               addr: addr,
        }
+       if id != "" {
+               n.SetIDFromString(id)
+       }
+       if len(s.nodes) >= maxNodes {
+               return
+       }
+       if !s.config.NoSecurity && !n.IsSecure() {
+               return
+       }
+       s.nodes[addrStr] = n
        return
 }
+
 func (s *Server) nodeTimedOut(addr dHTAddr) {
        node, ok := s.nodes[addr.String()]
        if !ok {
@@ -813,7 +873,7 @@ func (s *Server) query(node dHTAddr, q string, a map[string]interface{}, onRespo
        if err != nil {
                return
        }
-       s.getNode(node).lastSentQuery = time.Now()
+       s.getNode(node, "").lastSentQuery = time.Now()
        t.startTimer()
        s.addTransaction(t)
        return
@@ -901,7 +961,7 @@ func (s *Server) liftNodes(d Msg) {
                if s.ipBlocked(util.AddrIP(cni.Addr)) {
                        continue
                }
-               n := s.getNode(cni.Addr)
+               n := s.getNode(cni.Addr, string(cni.ID[:]))
                n.SetIDFromBytes(cni.ID[:])
        }
 }
@@ -965,7 +1025,7 @@ func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *Transaction, err er
                s.liftNodes(m)
                at, ok := m.AnnounceToken()
                if ok {
-                       s.getNode(addr).announceToken = at
+                       s.getNode(addr, m.ID()).announceToken = at
                }
        })
        return
@@ -1122,7 +1182,7 @@ func (s *Server) closestNodes(k int, target nodeID, filter func(*node) bool) []*
        ids := sel.IDs()
        ret := make([]*node, 0, len(ids))
        for _, id := range ids {
-               ret = append(ret, idNodes[id.String()])
+               ret = append(ret, idNodes[id.ByteString()])
        }
        return ret
 }
index f5e1c63c41a9448b79ae902d08bce6aaeac3e22a..8091242a446ceb0f90fcfb67089113a4220362e1 100644 (file)
@@ -1,10 +1,13 @@
 package dht
 
 import (
+       "encoding/hex"
        "math/big"
        "math/rand"
        "net"
        "testing"
+
+       "github.com/anacrolix/torrent/util"
 )
 
 func TestSetNilBigInt(t *testing.T) {
@@ -94,9 +97,9 @@ func TestClosestNodes(t *testing.T) {
        }
        m := map[string]bool{}
        for _, id := range cn.IDs() {
-               m[id.String()] = true
+               m[id.ByteString()] = true
        }
-       if !m[testIDs[3].String()] || !m[testIDs[4].String()] {
+       if !m[testIDs[3].ByteString()] || !m[testIDs[4].ByteString()] {
                t.FailNow()
        }
 }
@@ -156,3 +159,55 @@ func TestPing(t *testing.T) {
                t.FailNow()
        }
 }
+
+func TestDHTSec(t *testing.T) {
+       for _, case_ := range []struct {
+               ipStr     string
+               nodeIDHex string
+               valid     bool
+       }{
+               // These 5 are from the spec example. They are all valid.
+               {"124.31.75.21", "5fbfbff10c5d6a4ec8a88e4c6ab4c28b95eee401", true},
+               {"21.75.31.124", "5a3ce9c14e7a08645677bbd1cfe7d8f956d53256", true},
+               {"65.23.51.170", "a5d43220bc8f112a3d426c84764f8c2a1150e616", true},
+               {"84.124.73.14", "1b0321dd1bb1fe518101ceef99462b947a01ff41", true},
+               {"43.213.53.83", "e56f6cbf5b7c4be0237986d5243b87aa6d51305a", true},
+               // spec[0] with one of the rand() bytes changed. Valid.
+               {"124.31.75.21", "5fbfbff10c5d7a4ec8a88e4c6ab4c28b95eee401", true},
+               // spec[1] with the 21st leading bit changed. Not Valid.
+               {"21.75.31.124", "5a3ce1c14e7a08645677bbd1cfe7d8f956d53256", false},
+               // spec[2] with the 22nd leading bit changed. Valid.
+               {"65.23.51.170", "a5d43620bc8f112a3d426c84764f8c2a1150e616", true},
+               // spec[3] with the 4th last bit changed. Valid.
+               {"84.124.73.14", "1b0321dd1bb1fe518101ceef99462b947a01fe01", true},
+               // spec[4] with the 3rd last bit changed. Not valid.
+               {"43.213.53.83", "e56f6cbf5b7c4be0237986d5243b87aa6d51303e", false},
+       } {
+               ip := net.ParseIP(case_.ipStr)
+               id, err := hex.DecodeString(case_.nodeIDHex)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               secure := nodeIdSecure(string(id), ip)
+               if secure != case_.valid {
+                       t.Fatalf("case failed: %v", case_)
+               }
+               if !secure {
+                       secureNodeId(id, ip)
+                       if !nodeIdSecure(string(id), ip) {
+                               t.Fatal("failed to secure node id")
+                       }
+               }
+       }
+}
+
+func TestServerDefaultNodeIdSecure(t *testing.T) {
+       s, err := NewServer(nil)
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer s.Close()
+       if !nodeIdSecure(s.ID(), util.AddrIP(s.Addr())) {
+               t.Fatal("not secure")
+       }
+}