]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Add OnQuery hook, thanks to Cathal Garvey
authorMatt Joiner <anacrolix@gmail.com>
Wed, 16 Dec 2015 04:20:37 +0000 (15:20 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Wed, 16 Dec 2015 04:20:37 +0000 (15:20 +1100)
dht/dht.go
dht/dht_test.go
dht/server.go

index ec00c1faa73e704e22cc7f2b4a0448597d98ac99..4fef734ddd7882056c1757de9812fc58475bd459 100644 (file)
@@ -65,6 +65,8 @@ type ServerConfig struct {
        IPBlocklist iplist.Ranger
        // Used to secure the server's ID. Defaults to the Conn's LocalAddr().
        PublicIP net.IP
+
+       OnQuery func(*Msg, net.Addr) bool
 }
 
 // ServerStats instance is returned by Server.Stats() and stores Server metrics
@@ -139,6 +141,8 @@ func (n *node) IsSecure() bool {
        if n.id.IsUnset() {
                return false
        }
+       // TODO (@onetruecathal): Exempt local peers from security
+       // check as per security extension recommendations
        return NodeIdSecure(n.id.ByteString(), n.addr.IP())
 }
 
index 082b15b33a6a756620d31a9b766ff45244f1ef2d..affeba837b474fea16b34d1614b2158b613f901a 100644 (file)
@@ -6,6 +6,7 @@ import (
        "math/rand"
        "net"
        "testing"
+       "time"
 
        "github.com/anacrolix/missinggo"
        "github.com/stretchr/testify/assert"
@@ -217,3 +218,53 @@ func TestAnnounceTimeout(t *testing.T) {
 func TestEqualPointers(t *testing.T) {
        assert.EqualValues(t, &Msg{R: &Return{}}, &Msg{R: &Return{}})
 }
+
+func TestHook(t *testing.T) {
+       t.Log("TestHook: Starting with Ping intercept/passthrough")
+       srv, err := NewServer(&ServerConfig{
+               Addr:               "127.0.0.1:5678",
+               NoDefaultBootstrap: true,
+       })
+       require.NoError(t, err)
+       defer srv.Close()
+       // Establish server with a hook attached to "ping"
+       hookCalled := make(chan bool)
+       srv0, err := NewServer(&ServerConfig{
+               Addr:           "127.0.0.1:5679",
+               BootstrapNodes: []string{"127.0.0.1:5678"},
+               OnQuery: func(m *Msg, addr net.Addr) bool {
+                       if m.Q == "ping" {
+                               hookCalled <- true
+                       }
+                       return true
+               },
+       })
+       require.NoError(t, err)
+       defer srv0.Close()
+       // Ping srv0 from srv to trigger hook. Should also receive a response.
+       t.Log("TestHook: Servers created, hook for ping established. Calling Ping.")
+       tn, err := srv.Ping(&net.UDPAddr{
+               IP:   []byte{127, 0, 0, 1},
+               Port: srv0.Addr().(*net.UDPAddr).Port,
+       })
+       assert.NoError(t, err)
+       defer tn.Close()
+       // Await response from hooked server
+       tn.SetResponseHandler(func(msg Msg, b bool) {
+               t.Log("TestHook: Sender received response from pinged hook server, so normal execution resumed.")
+       })
+       // Await signal that hook has been called.
+       select {
+       case <-hookCalled:
+               {
+                       // Success, hook was triggered. Todo: Ensure that "ok" channel
+                       // receives, also, indicating normal handling proceeded also.
+                       t.Log("TestHook: Received ping, hook called and returned to normal execution!")
+                       return
+               }
+       case <-time.After(time.Second * 1):
+               {
+                       t.Error("Failed to see evidence of ping hook being called after 2 seconds.")
+               }
+       }
+}
index b25579c1f5eed7849965f016e1fe488b1681106c..73a4cc3c1a2a947df32e62f37b9fe10af284cc86 100644 (file)
@@ -251,6 +251,12 @@ func (s *Server) nodeByID(id string) *node {
 func (s *Server) handleQuery(source dHTAddr, m Msg) {
        node := s.getNode(source, m.SenderID())
        node.lastGotQuery = time.Now()
+       if s.config.OnQuery != nil {
+               propagate := s.config.OnQuery(&m, source.UDPAddr())
+               if !propagate {
+                       return
+               }
+       }
        // Don't respond.
        if s.config.Passive {
                return
@@ -340,6 +346,7 @@ func (s *Server) getNode(addr dHTAddr, id string) (n *node) {
        if len(s.nodes) >= maxNodes {
                return
        }
+       // Exclude insecure nodes from the node table.
        if !s.config.NoSecurity && !n.IsSecure() {
                return
        }