]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Rewrite cmd/dht-ping
authorMatt Joiner <anacrolix@gmail.com>
Mon, 7 Dec 2015 13:45:42 +0000 (00:45 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Mon, 7 Dec 2015 13:45:42 +0000 (00:45 +1100)
cmd/dht-ping/main.go

index 94d93368d3332f1d0e2f571b7ea0abaf782a915d..b23aa0cfe2301462a461dc0158aac0d70c10001a 100644 (file)
@@ -2,83 +2,92 @@
 package main
 
 import (
-       "flag"
        "fmt"
        "log"
+       "math"
        "net"
        "os"
        "time"
 
+       "github.com/anacrolix/tagflag"
+       "github.com/bradfitz/iter"
+
        "github.com/anacrolix/torrent/dht"
 )
 
-type pingResponse struct {
-       addr  string
-       krpc  dht.Msg
-       msgOk bool
-       rtt   time.Duration
-}
-
 func main() {
        log.SetFlags(log.LstdFlags | log.Lshortfile)
-       timeout := flag.Duration("timeout", -1, "maximum timeout")
-       flag.Parse()
-       pingStrAddrs := flag.Args()
-       if len(pingStrAddrs) == 0 {
-               os.Stderr.WriteString("u must specify addrs of nodes to ping e.g. router.bittorrent.com:6881\n")
-               os.Exit(2)
+       var args = struct {
+               Timeout time.Duration
+               Nodes   []string `type:"pos" arity:"+" help:"nodes to ping e.g. router.bittorrent.com:6881"`
+       }{
+               Timeout: math.MaxInt64,
        }
+       tagflag.Parse(&args)
        s, err := dht.NewServer(nil)
        if err != nil {
                log.Fatal(err)
        }
        log.Printf("dht server on %s", s.Addr())
-       pingResponsesChan := make(chan pingResponse)
-       timeoutChan := make(chan struct{})
-       go func() {
-               for i, netloc := range pingStrAddrs {
-                       if i != 0 {
-                               time.Sleep(1 * time.Millisecond)
-                       }
-                       addr, err := net.ResolveUDPAddr("udp4", netloc)
-                       if err != nil {
-                               log.Fatal(err)
-                       }
-                       t, err := s.Ping(addr)
-                       if err != nil {
-                               log.Fatal(err)
-                       }
-                       start := time.Now()
-                       t.SetResponseHandler(func(addr string) func(dht.Msg, bool) {
-                               return func(resp dht.Msg, ok bool) {
-                                       pingResponsesChan <- pingResponse{
-                                               addr:  addr,
-                                               krpc:  resp,
-                                               rtt:   time.Now().Sub(start),
-                                               msgOk: ok,
-                                       }
-                               }
-                       }(netloc))
-               }
-               if *timeout >= 0 {
-                       time.Sleep(*timeout)
-                       close(timeoutChan)
-               }
-       }()
-       responses := 0
-pingResponsesLoop:
-       for _ = range pingStrAddrs {
+       timeout := time.After(args.Timeout)
+       pongChan := make(chan pong)
+       startPings(s, pongChan, args.Nodes)
+       numResp := receivePongs(pongChan, timeout, len(args.Nodes))
+       fmt.Printf("%d/%d responses (%f%%)\n", numResp, len(args.Nodes), 100*float64(numResp)/float64(len(args.Nodes)))
+}
+
+func receivePongs(pongChan chan pong, timeout <-chan time.Time, maxPongs int) (numResp int) {
+       for range iter.N(maxPongs) {
                select {
-               case resp := <-pingResponsesChan:
-                       if !resp.msgOk {
+               case pong := <-pongChan:
+                       if !pong.msgOk {
                                break
                        }
-                       responses++
-                       fmt.Printf("%-65s %s\n", fmt.Sprintf("%x (%s):", resp.krpc.SenderID(), resp.addr), resp.rtt)
-               case <-timeoutChan:
-                       break pingResponsesLoop
+                       numResp++
+                       fmt.Printf("%-65s %s\n", fmt.Sprintf("%x (%s):", pong.krpc.SenderID(), pong.addr), pong.rtt)
+               case <-timeout:
+                       fmt.Fprintf(os.Stderr, "timed out\n")
+                       return
+               }
+       }
+       return
+}
+
+func startPings(s *dht.Server, pongChan chan pong, nodes []string) {
+       for i, addr := range nodes {
+               if i != 0 {
+                       // Put a small sleep between pings to avoid network issues.
+                       time.Sleep(1 * time.Millisecond)
                }
+               ping(addr, pongChan, s)
        }
-       // timeouts := len(pingStrAddrs) - responses
-       fmt.Printf("%d/%d responses (%f%%)\n", responses, len(pingStrAddrs), 100*float64(responses)/float64(len(pingStrAddrs)))
+}
+
+type pong struct {
+       addr  string
+       krpc  dht.Msg
+       msgOk bool
+       rtt   time.Duration
+}
+
+func ping(netloc string, pongChan chan pong, s *dht.Server) {
+       addr, err := net.ResolveUDPAddr("udp4", netloc)
+       if err != nil {
+               log.Fatal(err)
+       }
+       t, err := s.Ping(addr)
+       if err != nil {
+               log.Fatal(err)
+       }
+       start := time.Now()
+       t.SetResponseHandler(func(addr string) func(dht.Msg, bool) {
+               return func(resp dht.Msg, ok bool) {
+                       pongChan <- pong{
+                               addr:  addr,
+                               krpc:  resp,
+                               rtt:   time.Now().Sub(start),
+                               msgOk: ok,
+                       }
+               }
+       }(netloc))
 }