"encoding/binary"
"errors"
"fmt"
+ "hash/crc32"
"io"
"log"
"math/big"
numConfirmedAnnounces int
bootstrapNodes []string
-}
-
-type dHTAddr interface {
- net.Addr
- UDPAddr() *net.UDPAddr
-}
-
-type cachedAddr struct {
- a net.Addr
- s string
-}
-
-func (ca cachedAddr) Network() string {
- return ca.a.Network()
-}
-
-func (ca cachedAddr) String() string {
- return ca.s
-}
-
-func (ca cachedAddr) UDPAddr() *net.UDPAddr {
- return ca.a.(*net.UDPAddr)
-}
-
-func newDHTAddr(addr net.Addr) dHTAddr {
- return cachedAddr{addr, addr.String()}
+ config ServerConfig
}
type ServerConfig struct {
Passive bool
// DHT Bootstrap nodes
BootstrapNodes []string
+ // Disable the DHT security extension:
+ // http://www.libtorrent.org/dht_sec.html.
+ NoSecurity bool
}
type ServerStats struct {
if c == nil {
c = &ServerConfig{}
}
- s = &Server{}
+ s = &Server{
+ config: *c,
+ }
if c.Conn != nil {
s.socket = c.Conn
} else {
return
}
-func (nid *nodeID) String() string {
- return string(nid.i.Bytes())
+func (nid *nodeID) ByteString() string {
+ var buf [20]byte
+ b := nid.i.Bytes()
+ copy(buf[20-len(b):], b)
+ return string(buf[:])
}
type node struct {
lastSentQuery time.Time
}
+func (n *node) IsSecure() bool {
+ if n.id.IsUnset() {
+ return false
+ }
+ return nodeIdSecure(n.id.ByteString(), n.addr.IP())
+}
+
func (n *node) idString() string {
- return n.id.String()
+ return n.id.ByteString()
}
func (n *node) SetIDFromBytes(b []byte) {
+ if len(b) != 20 {
+ panic(b)
+ }
n.id.i.SetBytes(b)
n.id.set = true
}
func (n *node) SetIDFromString(s string) {
- n.id.i.SetBytes([]byte(s))
+ n.SetIDFromBytes([]byte(s))
}
func (n *node) IDNotSet() bool {
t.tryHandleResponse()
}
+func maskForIP(ip net.IP) []byte {
+ switch {
+ case ip.To4() != nil:
+ return []byte{0x03, 0x0f, 0x3f, 0xff}
+ default:
+ return []byte{0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f, 0xff}
+ }
+}
+
+// Generate the CRC used to make or validate secure node ID.
+func crcIP(ip net.IP, rand uint8) uint32 {
+ if ip4 := ip.To4(); ip4 != nil {
+ ip = ip4
+ }
+ // Copy IP so we can make changes. Go sux at this.
+ ip = append(make(net.IP, 0, len(ip)), ip...)
+ mask := maskForIP(ip)
+ for i := range mask {
+ ip[i] &= mask[i]
+ }
+ r := rand & 7
+ ip[0] |= r << 5
+ return crc32.Checksum(ip[:len(mask)], crc32.MakeTable(crc32.Castagnoli))
+}
+
+// Makes a node ID valid, in-place.
+func secureNodeId(id []byte, ip net.IP) {
+ crc := crcIP(ip, id[19])
+ id[0] = byte(crc >> 24 & 0xff)
+ id[1] = byte(crc >> 16 & 0xff)
+ id[2] = byte(crc>>8&0xf8) | id[2]&7
+}
+
+// http://www.libtorrent.org/dht_sec.html
+func nodeIdSecure(id string, ip net.IP) bool {
+ if len(id) != 20 {
+ panic(fmt.Sprintf("%q", id))
+ }
+ if ip4 := ip.To4(); ip4 != nil {
+ ip = ip4
+ }
+ crc := crcIP(ip, id[19])
+ if id[0] != byte(crc>>24&0xff) {
+ return false
+ }
+ if id[1] != byte(crc>>16&0xff) {
+ return false
+ }
+ if id[2]&0xf8 != byte(crc>>8&0xf8) {
+ return false
+ }
+ return true
+}
+
func (s *Server) setDefaults() (err error) {
if s.id == "" {
var id [20]byte
if len(id) != 20 {
panic(len(id))
}
+ secureNodeId(id[:], util.AddrIP(s.socket.LocalAddr()))
s.id = string(id[:])
}
s.nodes = make(map[string]*node, 10000)
//log.Printf("unexpected message: %#v", d)
return
}
- node := s.getNode(addr)
+ node := s.getNode(addr, d.ID())
node.lastGotResponse = time.Now()
// TODO: Update node ID as this is an authoritative packet.
go t.handleResponse(d)
if s.nodes == nil {
s.nodes = make(map[string]*node)
}
- n := s.getNode(ni.Addr)
- if n.IDNotSet() {
- n.SetIDFromBytes(ni.ID[:])
- }
+ s.getNode(ni.Addr, string(ni.ID[:]))
}
func (s *Server) nodeByID(id string) *node {
func (s *Server) handleQuery(source dHTAddr, m Msg) {
args := m["a"].(map[string]interface{})
- node := s.getNode(source)
+ node := s.getNode(source, m.ID())
node.SetIDFromString(args["id"].(string))
node.lastGotQuery = time.Now()
// Don't respond.
}
}
-func (s *Server) getNode(addr dHTAddr) (n *node) {
+// Returns a node struct for the addr. It is taken from the table or created
+// and possibly added if required and meets validity constraints.
+func (s *Server) getNode(addr dHTAddr, id string) (n *node) {
addrStr := addr.String()
n = s.nodes[addrStr]
- if n == nil {
- n = &node{
- addr: addr,
- }
- if len(s.nodes) < maxNodes {
- s.nodes[addrStr] = n
+ if n != nil {
+ if id != "" {
+ n.SetIDFromString(id)
}
+ return
+ }
+ n = &node{
+ addr: addr,
}
+ if id != "" {
+ n.SetIDFromString(id)
+ }
+ if len(s.nodes) >= maxNodes {
+ return
+ }
+ if !s.config.NoSecurity && !n.IsSecure() {
+ return
+ }
+ s.nodes[addrStr] = n
return
}
+
func (s *Server) nodeTimedOut(addr dHTAddr) {
node, ok := s.nodes[addr.String()]
if !ok {
if err != nil {
return
}
- s.getNode(node).lastSentQuery = time.Now()
+ s.getNode(node, "").lastSentQuery = time.Now()
t.startTimer()
s.addTransaction(t)
return
if s.ipBlocked(util.AddrIP(cni.Addr)) {
continue
}
- n := s.getNode(cni.Addr)
+ n := s.getNode(cni.Addr, string(cni.ID[:]))
n.SetIDFromBytes(cni.ID[:])
}
}
s.liftNodes(m)
at, ok := m.AnnounceToken()
if ok {
- s.getNode(addr).announceToken = at
+ s.getNode(addr, m.ID()).announceToken = at
}
})
return
ids := sel.IDs()
ret := make([]*node, 0, len(ids))
for _, id := range ids {
- ret = append(ret, idNodes[id.String()])
+ ret = append(ret, idNodes[id.ByteString()])
}
return ret
}
package dht
import (
+ "encoding/hex"
"math/big"
"math/rand"
"net"
"testing"
+
+ "github.com/anacrolix/torrent/util"
)
func TestSetNilBigInt(t *testing.T) {
}
m := map[string]bool{}
for _, id := range cn.IDs() {
- m[id.String()] = true
+ m[id.ByteString()] = true
}
- if !m[testIDs[3].String()] || !m[testIDs[4].String()] {
+ if !m[testIDs[3].ByteString()] || !m[testIDs[4].ByteString()] {
t.FailNow()
}
}
t.FailNow()
}
}
+
+func TestDHTSec(t *testing.T) {
+ for _, case_ := range []struct {
+ ipStr string
+ nodeIDHex string
+ valid bool
+ }{
+ // These 5 are from the spec example. They are all valid.
+ {"124.31.75.21", "5fbfbff10c5d6a4ec8a88e4c6ab4c28b95eee401", true},
+ {"21.75.31.124", "5a3ce9c14e7a08645677bbd1cfe7d8f956d53256", true},
+ {"65.23.51.170", "a5d43220bc8f112a3d426c84764f8c2a1150e616", true},
+ {"84.124.73.14", "1b0321dd1bb1fe518101ceef99462b947a01ff41", true},
+ {"43.213.53.83", "e56f6cbf5b7c4be0237986d5243b87aa6d51305a", true},
+ // spec[0] with one of the rand() bytes changed. Valid.
+ {"124.31.75.21", "5fbfbff10c5d7a4ec8a88e4c6ab4c28b95eee401", true},
+ // spec[1] with the 21st leading bit changed. Not Valid.
+ {"21.75.31.124", "5a3ce1c14e7a08645677bbd1cfe7d8f956d53256", false},
+ // spec[2] with the 22nd leading bit changed. Valid.
+ {"65.23.51.170", "a5d43620bc8f112a3d426c84764f8c2a1150e616", true},
+ // spec[3] with the 4th last bit changed. Valid.
+ {"84.124.73.14", "1b0321dd1bb1fe518101ceef99462b947a01fe01", true},
+ // spec[4] with the 3rd last bit changed. Not valid.
+ {"43.213.53.83", "e56f6cbf5b7c4be0237986d5243b87aa6d51303e", false},
+ } {
+ ip := net.ParseIP(case_.ipStr)
+ id, err := hex.DecodeString(case_.nodeIDHex)
+ if err != nil {
+ t.Fatal(err)
+ }
+ secure := nodeIdSecure(string(id), ip)
+ if secure != case_.valid {
+ t.Fatalf("case failed: %v", case_)
+ }
+ if !secure {
+ secureNodeId(id, ip)
+ if !nodeIdSecure(string(id), ip) {
+ t.Fatal("failed to secure node id")
+ }
+ }
+ }
+}
+
+func TestServerDefaultNodeIdSecure(t *testing.T) {
+ s, err := NewServer(nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer s.Close()
+ if !nodeIdSecure(s.ID(), util.AddrIP(s.Addr())) {
+ t.Fatal("not secure")
+ }
+}