From 44ec4d9bdb989b0280f68b921867e155440eafbb Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Wed, 16 Dec 2015 15:15:59 +1100 Subject: [PATCH] Add NodeIdHex config option --- dht/dht.go | 6 ++++++ dht/dht_test.go | 16 ++++++++++++++++ dht/server.go | 9 +++++++++ 3 files changed, 31 insertions(+) diff --git a/dht/dht.go b/dht/dht.go index 31bef132..ec00c1fa 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -40,6 +40,12 @@ type transactionKey struct { type ServerConfig struct { // Listen address. Used if Conn is nil. Addr string + + // Set NodeId Manually. Caller must ensure that, if NodeId does not + // conform to DHT Security Extensions, that NoSecurity is also set. This + // should be given as a HEX string. + NodeIdHex string + Conn net.PacketConn // Don't respond to queries from other nodes. Passive bool diff --git a/dht/dht_test.go b/dht/dht_test.go index c3afe6a5..08a9697a 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -184,6 +184,22 @@ func TestServerDefaultNodeIdSecure(t *testing.T) { } } +func TestServerCustomNodeId(t *testing.T) { + customId := "5a3ce1c14e7a08645677bbd1cfe7d8f956d53256" + id, err := hex.DecodeString(customId) + assert.NoError(t, err) + // How to test custom *secure* Id when tester computers will have + // different Ids? Generate custom ids for local IPs and use + // mini-Id? + s, err := NewServer(&ServerConfig{ + NodeIdHex: customId, + NoDefaultBootstrap: true, + }) + require.NoError(t, err) + defer s.Close() + assert.Equal(t, string(id), s.ID()) +} + func TestAnnounceTimeout(t *testing.T) { s, err := NewServer(&ServerConfig{ BootstrapNodes: []string{"1.2.3.4:5"}, diff --git a/dht/server.go b/dht/server.go index b755289d..b25579c1 100644 --- a/dht/server.go +++ b/dht/server.go @@ -3,6 +3,7 @@ package dht import ( "crypto" "encoding/binary" + "encoding/hex" "errors" "fmt" "io" @@ -85,6 +86,14 @@ func NewServer(c *ServerConfig) (s *Server, err error) { } } s.bootstrapNodes = c.BootstrapNodes + if c.NodeIdHex != "" { + var rawID []byte + rawID, err = hex.DecodeString(c.NodeIdHex) + if err != nil { + return + } + s.id = string(rawID) + } err = s.init() if err != nil { return -- 2.48.1