]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Tidy up the torrent and DHT APIs
authorMatt Joiner <anacrolix@gmail.com>
Thu, 21 Aug 2014 08:07:06 +0000 (18:07 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Thu, 21 Aug 2014 08:07:06 +0000 (18:07 +1000)
client.go
client_test.go
cmd/dht-get-peers/main.go
cmd/dht-ping/main.go
cmd/dht-server/main.go
cmd/torrent/main.go
cmd/torrentfs/main.go
config.go [new file with mode: 0644]
dht/dht.go
fs/torrentfs_test.go

index 529d31d03f77ca73b513a46ba8a36b6362862283..8e44663a12f08e0a68cbc6aba46bf5a32f0f4659 100644 (file)
--- a/client.go
+++ b/client.go
@@ -16,8 +16,6 @@ Simple example:
 package torrent
 
 import (
-       "bitbucket.org/anacrolix/go.torrent/dht"
-       "bitbucket.org/anacrolix/go.torrent/util"
        "bufio"
        "crypto/rand"
        "crypto/sha1"
@@ -32,6 +30,9 @@ import (
        "syscall"
        "time"
 
+       "bitbucket.org/anacrolix/go.torrent/dht"
+       . "bitbucket.org/anacrolix/go.torrent/util"
+
        "github.com/anacrolix/libtorgo/metainfo"
        "github.com/nsf/libtorgo/bencode"
 
@@ -63,7 +64,7 @@ func (me *Client) PrioritizeDataRegion(ih InfoHash, off, len_ int64) error {
        if !t.haveInfo() {
                return errors.New("missing metadata")
        }
-       me.DownloadStrategy.TorrentPrioritize(t, off, len_)
+       me.downloadStrategy.TorrentPrioritize(t, off, len_)
        for _, cn := range t.Conns {
                me.replenishConnRequests(t, cn)
        }
@@ -76,13 +77,13 @@ type dataSpec struct {
 }
 
 type Client struct {
-       DataDir          string
-       HalfOpenLimit    int
-       PeerId           [20]byte
-       Listener         net.Listener
-       DisableTrackers  bool
-       DownloadStrategy DownloadStrategy
-       DHT              *dht.Server
+       dataDir          string
+       halfOpenLimit    int
+       peerID           [20]byte
+       listener         net.Listener
+       disableTrackers  bool
+       downloadStrategy DownloadStrategy
+       dHT              *dht.Server
 
        mu    sync.Mutex
        event sync.Cond
@@ -93,18 +94,23 @@ type Client struct {
        dataWaiter chan struct{}
 }
 
+func (me *Client) ListenAddr() net.Addr {
+       return me.listener.Addr()
+}
+
 func (cl *Client) WriteStatus(w io.Writer) {
        cl.mu.Lock()
        defer cl.mu.Unlock()
-       if cl.Listener != nil {
-               fmt.Fprintf(w, "Listening on %s\n", cl.Listener.Addr())
+       if cl.listener != nil {
+               fmt.Fprintf(w, "Listening on %s\n", cl.listener.Addr())
        } else {
                fmt.Fprintf(w, "No listening torrent port!\n")
        }
-       fmt.Fprintf(w, "Peer ID: %q\n", cl.PeerId)
+       fmt.Fprintf(w, "Peer ID: %q\n", cl.peerID)
        fmt.Fprintf(w, "Half open outgoing connections: %d\n", cl.halfOpen)
-       if cl.DHT != nil {
-               fmt.Fprintf(w, "DHT nodes: %d\n", cl.DHT.NumNodes())
+       if cl.dHT != nil {
+               fmt.Fprintf(w, "DHT nodes: %d\n", cl.dHT.NumNodes())
+               fmt.Fprintf(w, "DHT Server ID: %x\n", cl.dHT.IDString())
        }
        fmt.Fprintln(w)
        for _, t := range cl.torrents {
@@ -164,26 +170,50 @@ func (cl *Client) TorrentReadAt(ih InfoHash, off int64, p []byte) (n int, err er
        return t.Data.ReadAt(p, off)
 }
 
-// Starts the client. Defaults are applied. The client will begin accepting
-// connections and tracking.
-func (c *Client) Start() {
-       c.event.L = &c.mu
-       c.torrents = make(map[InfoHash]*torrent)
-       if c.HalfOpenLimit == 0 {
-               c.HalfOpenLimit = 10
+func NewClient(cfg *Config) (cl *Client, err error) {
+       if cfg == nil {
+               cfg = &Config{}
        }
-       o := copy(c.PeerId[:], BEP20)
-       _, err := rand.Read(c.PeerId[o:])
+
+       cl = &Client{
+               disableTrackers:  cfg.DisableTrackers,
+               downloadStrategy: cfg.DownloadStrategy,
+               halfOpenLimit:    100,
+               dataDir:          cfg.DataDir,
+
+               quit:     make(chan struct{}),
+               torrents: make(map[InfoHash]*torrent),
+       }
+       cl.event.L = &cl.mu
+
+       o := copy(cl.peerID[:], BEP20)
+       _, err = rand.Read(cl.peerID[o:])
        if err != nil {
                panic("error generating peer id")
        }
-       c.quit = make(chan struct{})
-       if c.DownloadStrategy == nil {
-               c.DownloadStrategy = &DefaultDownloadStrategy{}
+
+       if cl.downloadStrategy == nil {
+               cl.downloadStrategy = &DefaultDownloadStrategy{}
+       }
+
+       cl.listener, err = net.Listen("tcp", cfg.ListenAddr)
+       if err != nil {
+               return
        }
-       if c.Listener != nil {
-               go c.acceptConnections()
+       if cl.listener != nil {
+               go cl.acceptConnections()
        }
+
+       if !cfg.NoDHT {
+               cl.dHT, err = dht.NewServer(&dht.ServerConfig{
+                       Addr: cfg.ListenAddr,
+               })
+               if err != nil {
+                       return
+               }
+       }
+
+       return
 }
 
 func (cl *Client) stopped() bool {
@@ -211,7 +241,7 @@ func (me *Client) Stop() {
 
 func (cl *Client) acceptConnections() {
        for {
-               conn, err := cl.Listener.Accept()
+               conn, err := cl.listener.Accept()
                select {
                case <-cl.quit:
                        if conn != nil {
@@ -245,7 +275,7 @@ func (me *Client) torrent(ih InfoHash) *torrent {
 // Start the process of connecting to the given peer for the given torrent if
 // appropriate.
 func (me *Client) initiateConn(peer Peer, torrent *torrent) {
-       if peer.Id == me.PeerId {
+       if peer.Id == me.peerID {
                return
        }
        me.halfOpen++
@@ -291,10 +321,10 @@ func (me *Client) initiateConn(peer Peer, torrent *torrent) {
 }
 
 func (cl *Client) incomingPeerPort() int {
-       if cl.Listener == nil {
+       if cl.listener == nil {
                return 0
        }
-       _, p, err := net.SplitHostPort(cl.Listener.Addr().String())
+       _, p, err := net.SplitHostPort(cl.listener.Addr().String())
        if err != nil {
                panic(err)
        }
@@ -452,7 +482,7 @@ func (me *Client) peerUnchoked(torrent *torrent, conn *connection) {
 func (cl *Client) connCancel(t *torrent, cn *connection, r request) (ok bool) {
        ok = cn.Cancel(r)
        if ok {
-               cl.DownloadStrategy.DeleteRequest(t, r)
+               cl.downloadStrategy.DeleteRequest(t, r)
        }
        return
 }
@@ -461,7 +491,7 @@ func (cl *Client) connDeleteRequest(t *torrent, cn *connection, r request) {
        if !cn.RequestPending(r) {
                return
        }
-       cl.DownloadStrategy.DeleteRequest(t, r)
+       cl.downloadStrategy.DeleteRequest(t, r)
        delete(cn.Requests, r)
 }
 
@@ -788,7 +818,7 @@ func (me *Client) addConnection(t *torrent, c *connection) bool {
 func (me *Client) openNewConns() {
        for _, t := range me.torrents {
                for len(t.Peers) != 0 {
-                       if me.halfOpen >= me.HalfOpenLimit {
+                       if me.halfOpen >= me.halfOpenLimit {
                                return
                        }
                        p := t.Peers[0]
@@ -812,7 +842,7 @@ func (me *Client) AddPeers(infoHash InfoHash, peers []Peer) error {
 }
 
 func (cl *Client) setMetaData(t *torrent, md metainfo.Info, bytes []byte) (err error) {
-       err = t.setMetadata(md, cl.DataDir, bytes)
+       err = t.setMetadata(md, cl.dataDir, bytes)
        if err != nil {
                return
        }
@@ -827,7 +857,7 @@ func (cl *Client) setMetaData(t *torrent, md metainfo.Info, bytes []byte) (err e
                }
        }()
 
-       cl.DownloadStrategy.TorrentStarted(t)
+       cl.downloadStrategy.TorrentStarted(t)
        return
 }
 
@@ -902,10 +932,10 @@ func (me *Client) addTorrent(t *torrent) (err error) {
                return
        }
        me.torrents[t.InfoHash] = t
-       if !me.DisableTrackers {
+       if !me.disableTrackers {
                go me.announceTorrent(t)
        }
-       if me.DHT != nil {
+       if me.dHT != nil {
                go me.announceTorrentDHT(t)
        }
        return
@@ -940,7 +970,7 @@ func (me *Client) AddTorrentFromFile(name string) (err error) {
 }
 
 func (cl *Client) listenerAnnouncePort() (port int16) {
-       l := cl.Listener
+       l := cl.listener
        if l == nil {
                return
        }
@@ -958,7 +988,7 @@ func (cl *Client) listenerAnnouncePort() (port int16) {
 
 func (cl *Client) announceTorrentDHT(t *torrent) {
        for {
-               ps, err := cl.DHT.GetPeers(string(t.InfoHash[:]))
+               ps, err := cl.dHT.GetPeers(string(t.InfoHash[:]))
                if err != nil {
                        log.Printf("error getting peers from dht: %s", err)
                        return
@@ -1003,7 +1033,7 @@ func (cl *Client) announceTorrent(t *torrent) {
                Event:    tracker.Started,
                NumWant:  -1,
                Port:     cl.listenerAnnouncePort(),
-               PeerId:   cl.PeerId,
+               PeerId:   cl.peerID,
                InfoHash: t.InfoHash,
        }
 newAnnounce:
@@ -1072,7 +1102,7 @@ func (me *Client) WaitAll() bool {
 }
 
 func (cl *Client) assertRequestHeat() {
-       dds, ok := cl.DownloadStrategy.(*DefaultDownloadStrategy)
+       dds, ok := cl.downloadStrategy.(*DefaultDownloadStrategy)
        if !ok {
                return
        }
@@ -1095,7 +1125,7 @@ func (me *Client) replenishConnRequests(t *torrent, c *connection) {
        if !t.haveInfo() {
                return
        }
-       me.DownloadStrategy.FillRequests(t, c)
+       me.downloadStrategy.FillRequests(t, c)
        //me.assertRequestHeat()
        if len(c.Requests) == 0 && !c.PeerChoked {
                c.SetInterested(false)
@@ -1130,7 +1160,7 @@ func (me *Client) downloadedChunk(t *torrent, c *connection, msg *pp.Message) er
        }
 
        // Unprioritize the chunk.
-       me.DownloadStrategy.TorrentGotChunk(t, req)
+       me.downloadStrategy.TorrentGotChunk(t, req)
 
        // Cancel pending requests for this chunk.
        cancelled := false
@@ -1171,7 +1201,7 @@ func (me *Client) pieceHashed(t *torrent, piece pp.Integer, correct bool) {
        p.EverHashed = true
        if correct {
                p.PendingChunkSpecs = nil
-               me.DownloadStrategy.TorrentGotPiece(t, int(piece))
+               me.downloadStrategy.TorrentGotPiece(t, int(piece))
                me.dataReady(dataSpec{
                        t.InfoHash,
                        request{
index b6a6439b330256389ffb688ac2bb1da8199b3f5a..1d1e9c30d7440b5e60ae49bdaad60e34d750dc34 100644 (file)
@@ -7,6 +7,14 @@ import (
        "testing"
 )
 
+func TestClientDefault(t *testing.T) {
+       cl, err := NewClient(nil)
+       if err != nil {
+               t.Fatal(err)
+       }
+       cl.Stop()
+}
+
 func TestAddTorrentNoSupportedTrackerSchemes(t *testing.T) {
        t.SkipNow()
 }
index b9138a4ddf79cf5c825a53e9356a80854a66e9cf..779888b67d094fb8932b3d03082bfe78eeb4e445 100644 (file)
@@ -1,9 +1,6 @@
 package main
 
 import (
-       "bitbucket.org/anacrolix/go.torrent/dht"
-       "bitbucket.org/anacrolix/go.torrent/tracker"
-       _ "bitbucket.org/anacrolix/go.torrent/util/profile"
        "flag"
        "fmt"
        "io"
@@ -11,6 +8,11 @@ import (
        "net"
        "os"
        "os/signal"
+       "time"
+
+       "bitbucket.org/anacrolix/go.torrent/dht"
+       "bitbucket.org/anacrolix/go.torrent/util"
+       _ "bitbucket.org/anacrolix/go.torrent/util/profile"
 )
 
 type pingResponse struct {
@@ -23,7 +25,7 @@ var (
        serveAddr     = flag.String("serveAddr", ":0", "local UDP address")
        infoHash      = flag.String("infoHash", "", "torrent infohash")
 
-       s dht.Server
+       s *dht.Server
 )
 
 func loadTable() error {
@@ -74,22 +76,17 @@ func init() {
                log.Fatal("require 20 byte infohash")
        }
        var err error
-       s.Socket, err = net.ListenUDP("udp4", func() *net.UDPAddr {
-               addr, err := net.ResolveUDPAddr("udp4", *serveAddr)
-               if err != nil {
-                       log.Fatalf("error resolving serve addr: %s", err)
-               }
-               return addr
-       }())
+       s, err = dht.NewServer(&dht.ServerConfig{
+               Addr: *serveAddr,
+       })
        if err != nil {
                log.Fatal(err)
        }
-       s.Init()
        err = loadTable()
        if err != nil {
                log.Fatalf("error loading table: %s", err)
        }
-       log.Printf("dht server on %s, ID is %q", s.Socket.LocalAddr(), s.IDString())
+       log.Printf("dht server on %s, ID is %q", s.LocalAddr(), s.IDString())
        setupSignals()
 }
 
@@ -131,36 +128,29 @@ func setupSignals() {
 }
 
 func main() {
-       go func() {
-               defer s.StopServing()
-               if err := s.Bootstrap(); err != nil {
-                       log.Printf("error bootstrapping: %s", err)
-                       return
-               }
-               saveTable()
+       seen := make(map[util.CompactPeer]struct{})
+       for {
                ps, err := s.GetPeers(*infoHash)
                if err != nil {
                        log.Fatal(err)
                }
-               seen := make(map[tracker.CompactPeer]struct{})
-               for sl := range ps.Values {
-                       for _, p := range sl {
-                               if _, ok := seen[p]; ok {
-                                       continue
+               go func() {
+                       for sl := range ps.Values {
+                               for _, p := range sl {
+                                       if _, ok := seen[p]; ok {
+                                               continue
+                                       }
+                                       seen[p] = struct{}{}
+                                       fmt.Println((&net.UDPAddr{
+                                               IP:   p.IP[:],
+                                               Port: int(p.Port),
+                                       }).String())
                                }
-                               seen[p] = struct{}{}
-                               fmt.Println((&net.UDPAddr{
-                                       IP:   p.IP[:],
-                                       Port: int(p.Port),
-                               }).String())
                        }
-               }
-       }()
-       err := s.Serve()
+               }()
+               time.Sleep(15 * time.Second)
+       }
        if err := saveTable(); err != nil {
                log.Printf("error saving node table: %s", err)
        }
-       if err != nil {
-               log.Fatalf("error serving dht: %s", err)
-       }
 }
index 6c686492666d031efd54a6ae47d241bd2513af90..b5cf5236ee925a4d44d934fa4c5dabf44861c9b9 100644 (file)
@@ -1,11 +1,12 @@
 package main
 
 import (
-       "bitbucket.org/anacrolix/go.torrent/dht"
        "flag"
        "log"
        "net"
        "os"
+
+       "bitbucket.org/anacrolix/go.torrent/dht"
 )
 
 type pingResponse struct {
@@ -21,20 +22,11 @@ func main() {
                os.Stderr.WriteString("u must specify addrs of nodes to ping e.g. router.bittorrent.com:6881\n")
                os.Exit(2)
        }
-       s := dht.Server{}
-       var err error
-       s.Socket, err = net.ListenUDP("udp4", nil)
+       s, err := dht.NewServer(nil)
        if err != nil {
                log.Fatal(err)
        }
-       log.Printf("dht server on %s", s.Socket.LocalAddr())
-       s.Init()
-       go func() {
-               err := s.Serve()
-               if err != nil {
-                       log.Fatal(err)
-               }
-       }()
+       log.Printf("dht server on %s", s.LocalAddr())
        pingResponses := make(chan pingResponse)
        for _, netloc := range pingStrAddrs {
                addr, err := net.ResolveUDPAddr("udp4", netloc)
index 667a8c8a4b95903968095b94e0e6c514842fd271..d7140ffc33e42ff34eaf40525de698fccd6f067d 100644 (file)
@@ -1,14 +1,14 @@
 package main
 
 import (
-       "bitbucket.org/anacrolix/go.torrent/dht"
        "flag"
        "fmt"
        "io"
        "log"
-       "net"
        "os"
        "os/signal"
+
+       "bitbucket.org/anacrolix/go.torrent/dht"
 )
 
 type pingResponse struct {
@@ -20,7 +20,7 @@ var (
        tableFileName = flag.String("tableFile", "", "name of file for storing node info")
        serveAddr     = flag.String("serveAddr", ":0", "local UDP address")
 
-       s dht.Server
+       s *dht.Server
 )
 
 func loadTable() error {
@@ -61,22 +61,17 @@ func init() {
        log.SetFlags(log.LstdFlags | log.Lshortfile)
        flag.Parse()
        var err error
-       s.Socket, err = net.ListenUDP("udp4", func() *net.UDPAddr {
-               addr, err := net.ResolveUDPAddr("udp4", *serveAddr)
-               if err != nil {
-                       log.Fatalf("error resolving serve addr: %s", err)
-               }
-               return addr
-       }())
+       s, err = dht.NewServer(&dht.ServerConfig{
+               Addr: *serveAddr,
+       })
        if err != nil {
                log.Fatal(err)
        }
-       s.Init()
        err = loadTable()
        if err != nil {
                log.Fatalf("error loading table: %s", err)
        }
-       log.Printf("dht server on %s, ID is %q", s.Socket.LocalAddr(), s.IDString())
+       log.Printf("dht server on %s, ID is %q", s.LocalAddr(), s.IDString())
        setupSignals()
 }
 
@@ -118,20 +113,8 @@ func setupSignals() {
 }
 
 func main() {
-       go func() {
-               err := s.Bootstrap()
-               if err != nil {
-                       log.Printf("error bootstrapping: %s", err)
-                       s.StopServing()
-               } else {
-                       log.Print("bootstrapping complete")
-               }
-       }()
-       err := s.Serve()
+       select {}
        if err := saveTable(); err != nil {
                log.Printf("error saving node table: %s", err)
        }
-       if err != nil {
-               log.Fatalf("error serving dht: %s", err)
-       }
 }
index 140c35845065d7e12c24e43f0225e66092b3a41f..8442f5f72f9c8f4b060278169a1ba74a51d4b765 100644 (file)
@@ -1,8 +1,6 @@
 package main
 
 import (
-       "bitbucket.org/anacrolix/go.torrent/dht"
-       "bitbucket.org/anacrolix/go.torrent/util"
        "flag"
        "fmt"
        "log"
@@ -12,6 +10,8 @@ import (
        "os"
        "strings"
 
+       "bitbucket.org/anacrolix/go.torrent/util"
+
        "github.com/anacrolix/libtorgo/metainfo"
 
        "bitbucket.org/anacrolix/go.torrent"
@@ -32,57 +32,21 @@ func init() {
        flag.Parse()
 }
 
-func makeListener() net.Listener {
-       l, err := net.Listen("tcp", *listenAddr)
-       if err != nil {
-               log.Fatal(err)
-       }
-       return l
-}
-
 func main() {
        if *httpAddr != "" {
                util.LoggedHTTPServe(*httpAddr)
        }
-       dhtServer := &dht.Server{
-               Socket: func() *net.UDPConn {
-                       addr, err := net.ResolveUDPAddr("udp4", *listenAddr)
-                       if err != nil {
-                               log.Fatalf("error resolving dht listen addr: %s", err)
-                       }
-                       s, err := net.ListenUDP("udp4", addr)
-                       if err != nil {
-                               log.Fatalf("error creating dht socket: %s", err)
-                       }
-                       return s
-               }(),
-       }
-       err := dhtServer.Init()
-       if err != nil {
-               log.Fatalf("error initing dht server: %s", err)
-       }
-       go func() {
-               err := dhtServer.Serve()
-               if err != nil {
-                       log.Fatalf("error serving dht: %s", err)
-               }
-       }()
-       go func() {
-               err := dhtServer.Bootstrap()
-               if err != nil {
-                       log.Printf("error bootstrapping dht server: %s", err)
-               }
-       }()
-       client := torrent.Client{
+       client, err := torrent.NewClient(&torrent.Config{
                DataDir:         *downloadDir,
-               Listener:        makeListener(),
                DisableTrackers: *disableTrackers,
-               DHT:             dhtServer,
+               ListenAddr:      *listenAddr,
+       })
+       if err != nil {
+               log.Fatalf("error creating client: %s", err)
        }
        http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
                client.WriteStatus(w)
        })
-       client.Start()
        defer client.Stop()
        if flag.NArg() == 0 {
                fmt.Fprintln(os.Stderr, "no torrents specified")
index 16fa3bf2d5e81e835bb700d980a0c5be6e3f371d..34cdbf8a7dddbf0145c829aa000bf495dc2a5863 100644 (file)
@@ -45,14 +45,6 @@ func init() {
        flag.StringVar(&mountDir, "mountDir", "", "location the torrent contents are made available")
 }
 
-func makeListener() net.Listener {
-       l, err := net.Listen("tcp", *listenAddr)
-       if err != nil {
-               log.Fatal(err)
-       }
-       return l
-}
-
 func resolveTestPeerAddr() {
        if *testPeer == "" {
                return
@@ -113,16 +105,15 @@ func main() {
        // TODO: Think about the ramifications of exiting not due to a signal.
        setSignalHandlers()
        defer conn.Close()
-       client := &torrent.Client{
+       client, err := torrent.NewClient(&torrent.Config{
                DataDir:          downloadDir,
                DisableTrackers:  *disableTrackers,
                DownloadStrategy: torrent.NewResponsiveDownloadStrategy(*readaheadBytes),
-               Listener:         makeListener(),
-       }
+               ListenAddr:       *listenAddr,
+       })
        http.DefaultServeMux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
                client.WriteStatus(w)
        })
-       client.Start()
        dw, err := dirwatch.New(torrentPath)
        if err != nil {
                log.Fatal(err)
diff --git a/config.go b/config.go
new file mode 100644 (file)
index 0000000..40e1fd0
--- /dev/null
+++ b/config.go
@@ -0,0 +1,9 @@
+package torrent
+
+type Config struct {
+       DataDir          string
+       ListenAddr       string
+       DisableTrackers  bool
+       DownloadStrategy DownloadStrategy
+       NoDHT            bool
+}
index 3cb1f0f47e81c8cf3ea776555efee0f0e5497a3d..c30ab942ae8d433953a1fede62c42df86a016956 100644 (file)
@@ -1,25 +1,25 @@
 package dht
 
 import (
-       "bitbucket.org/anacrolix/go.torrent/tracker"
-       "bitbucket.org/anacrolix/go.torrent/util"
        "crypto"
        _ "crypto/sha1"
        "encoding/binary"
        "errors"
        "fmt"
-       "github.com/nsf/libtorgo/bencode"
        "io"
        "log"
        "net"
        "os"
        "sync"
        "time"
+
+       "bitbucket.org/anacrolix/go.torrent/util"
+       "github.com/nsf/libtorgo/bencode"
 )
 
 type Server struct {
-       ID               string
-       Socket           *net.UDPConn
+       id               string
+       socket           *net.UDPConn
        transactions     []*transaction
        transactionIDInt uint64
        nodes            map[string]*Node // Keyed by *net.UDPAddr.String().
@@ -27,8 +27,47 @@ type Server struct {
        closed           chan struct{}
 }
 
+type ServerConfig struct {
+       Addr string
+}
+
+func (s *Server) LocalAddr() net.Addr {
+       return s.socket.LocalAddr()
+}
+
+func makeSocket(addr string) (socket *net.UDPConn, err error) {
+       addr_, err := net.ResolveUDPAddr("", addr)
+       if err != nil {
+               return
+       }
+       socket, err = net.ListenUDP("udp", addr_)
+       return
+}
+
+func NewServer(c *ServerConfig) (s *Server, err error) {
+       s = &Server{}
+       s.socket, err = makeSocket(c.Addr)
+       if err != nil {
+               return
+       }
+       err = s.init()
+       if err != nil {
+               return
+       }
+       go func() {
+               panic(s.serve())
+       }()
+       go func() {
+               err := s.bootstrap()
+               if err != nil {
+                       panic(err)
+               }
+       }()
+       return
+}
+
 func (s *Server) String() string {
-       return fmt.Sprintf("dht server on %s", s.Socket.LocalAddr())
+       return fmt.Sprintf("dht server on %s", s.socket.LocalAddr())
 }
 
 type Node struct {
@@ -96,25 +135,14 @@ func (t *transaction) handleResponse(m Msg) {
 }
 
 func (s *Server) setDefaults() (err error) {
-       if s.Socket == nil {
-               var addr *net.UDPAddr
-               addr, err = net.ResolveUDPAddr("", ":6882")
-               if err != nil {
-                       return
-               }
-               s.Socket, err = net.ListenUDP("udp", addr)
-               if err != nil {
-                       return
-               }
-       }
-       if s.ID == "" {
+       if s.id == "" {
                var id [20]byte
                h := crypto.SHA1.New()
                ss, err := os.Hostname()
                if err != nil {
                        log.Print(err)
                }
-               ss += s.Socket.LocalAddr().String()
+               ss += s.socket.LocalAddr().String()
                h.Write([]byte(ss))
                if b := h.Sum(id[:0:20]); len(b) != 20 {
                        panic(len(b))
@@ -122,13 +150,13 @@ func (s *Server) setDefaults() (err error) {
                if len(id) != 20 {
                        panic(len(id))
                }
-               s.ID = string(id[:])
+               s.id = string(id[:])
        }
        s.nodes = make(map[string]*Node, 10000)
        return
 }
 
-func (s *Server) Init() (err error) {
+func (s *Server) init() (err error) {
        err = s.setDefaults()
        if err != nil {
                return
@@ -137,10 +165,10 @@ func (s *Server) Init() (err error) {
        return
 }
 
-func (s *Server) Serve() error {
+func (s *Server) serve() error {
        for {
                var b [0x10000]byte
-               n, addr, err := s.Socket.ReadFromUDP(b[:])
+               n, addr, err := s.socket.ReadFromUDP(b[:])
                if err != nil {
                        return err
                }
@@ -296,7 +324,7 @@ func (s *Server) getNode(addr *net.UDPAddr) (n *Node) {
 }
 
 func (s *Server) writeToNode(b []byte, node *net.UDPAddr) (err error) {
-       n, err := s.Socket.WriteTo(b, node)
+       n, err := s.socket.WriteTo(b, node)
        if err != nil {
                err = fmt.Errorf("error writing %d bytes to %s: %s", len(b), node, err)
                return
@@ -347,10 +375,10 @@ func (s *Server) addTransaction(t *transaction) {
 }
 
 func (s *Server) IDString() string {
-       if len(s.ID) != 20 {
+       if len(s.id) != 20 {
                panic("bad node id")
        }
-       return s.ID
+       return s.id
 }
 
 func (s *Server) timeoutTransaction(t *transaction) {
@@ -560,14 +588,9 @@ func (s *Server) findNode(addr *net.UDPAddr, targetID string) (t *transaction, e
        return
 }
 
-type getPeersResponse struct {
-       Values []tracker.CompactPeer `bencode:"values"`
-       Nodes  util.CompactPeers     `bencode:"nodes"`
-}
-
 type peerStream struct {
        mu     sync.Mutex
-       Values chan []tracker.CompactPeer
+       Values chan []util.CompactPeer
        stop   chan struct{}
 }
 
@@ -577,12 +600,11 @@ func (ps *peerStream) Close() {
        case <-ps.stop:
        default:
                close(ps.stop)
-               close(ps.Values)
        }
        ps.mu.Unlock()
 }
 
-func extractValues(m Msg) (vs []tracker.CompactPeer) {
+func extractValues(m Msg) (vs []util.CompactPeer) {
        r, ok := m["r"]
        if !ok {
                return
@@ -595,19 +617,17 @@ func extractValues(m Msg) (vs []tracker.CompactPeer) {
        if !ok {
                return
        }
-       // log.Fatal(m)
        vl, ok := v.([]interface{})
        if !ok {
                panic(v)
        }
-       vs = make([]tracker.CompactPeer, 0, len(vl))
+       vs = make([]util.CompactPeer, 0, len(vl))
        for _, i := range vl {
-               // log.Printf("%T", i)
                s, ok := i.(string)
                if !ok {
                        panic(i)
                }
-               var cp tracker.CompactPeer
+               var cp util.CompactPeer
                err := cp.UnmarshalBinary([]byte(s))
                if err != nil {
                        log.Printf("error decoding values list element: %s", err)
@@ -620,7 +640,7 @@ func extractValues(m Msg) (vs []tracker.CompactPeer) {
 
 func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) {
        ps = &peerStream{
-               Values: make(chan []tracker.CompactPeer),
+               Values: make(chan []util.CompactPeer),
                stop:   make(chan struct{}),
        }
        done := make(chan struct{})
@@ -657,7 +677,7 @@ func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) {
                        case <-s.closed:
                        }
                }
-               ps.Close()
+               close(ps.Values)
        }()
        return
 }
@@ -689,7 +709,7 @@ func (s *Server) addRootNode() error {
 }
 
 // Populates the node table.
-func (s *Server) Bootstrap() (err error) {
+func (s *Server) bootstrap() (err error) {
        s.mu.Lock()
        defer s.mu.Unlock()
        if len(s.nodes) == 0 {
@@ -702,7 +722,7 @@ func (s *Server) Bootstrap() (err error) {
                var outstanding sync.WaitGroup
                for _, node := range s.nodes {
                        var t *transaction
-                       t, err = s.findNode(node.addr, s.ID)
+                       t, err = s.findNode(node.addr, s.id)
                        if err != nil {
                                return
                        }
@@ -768,7 +788,7 @@ func (s *Server) Nodes() (nis []NodeInfo) {
 }
 
 func (s *Server) StopServing() {
-       s.Socket.Close()
+       s.socket.Close()
        s.mu.Lock()
        select {
        case <-s.closed:
index 9b26eb88b245a397f45ec0ea0026b4f4deeb7812..fb214bdf04c1bf60072cc878850f502468de7aab 100644 (file)
@@ -6,19 +6,26 @@ import (
        "io/ioutil"
        "log"
        "net"
+       "net/http"
+       _ "net/http/pprof"
        "os"
        "path/filepath"
        "testing"
        "time"
 
+       "bitbucket.org/anacrolix/go.torrent"
        "bitbucket.org/anacrolix/go.torrent/testutil"
+       "bitbucket.org/anacrolix/go.torrent/util"
+       "github.com/anacrolix/libtorgo/metainfo"
 
        "bazil.org/fuse"
        fusefs "bazil.org/fuse/fs"
-       "bitbucket.org/anacrolix/go.torrent"
-       "github.com/anacrolix/libtorgo/metainfo"
 )
 
+func init() {
+       go http.ListenAndServe(":6061", nil)
+}
+
 func TestTCPAddrString(t *testing.T) {
        ta := &net.TCPAddr{
                IP:   net.IPv4(127, 0, 0, 1),
@@ -65,6 +72,7 @@ func newGreetingLayout() (tl testLayout, err error) {
        metaInfoBuf := &bytes.Buffer{}
        testutil.CreateMetaInfo(name, metaInfoBuf)
        tl.Metainfo, err = metainfo.Load(metaInfoBuf)
+       log.Printf("%x", tl.Metainfo.Info.Pieces)
        return
 }
 
@@ -79,14 +87,14 @@ func TestUnmountWedged(t *testing.T) {
                        t.Log(err)
                }
        }()
-       client := torrent.Client{
+       client, err := torrent.NewClient(&torrent.Config{
                DataDir:         filepath.Join(layout.BaseDir, "incomplete"),
                DisableTrackers: true,
-       }
-       client.Start()
+               NoDHT:           true,
+       })
        log.Printf("%+v", *layout.Metainfo)
        client.AddTorrent(layout.Metainfo)
-       fs := New(&client)
+       fs := New(client)
        fuseConn, err := fuse.Mount(layout.MountDir)
        if err != nil {
                t.Fatal(err)
@@ -125,40 +133,40 @@ func TestDownloadOnDemand(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
-       seeder := torrent.Client{
-               DataDir: layout.Completed,
-               Listener: func() net.Listener {
-                       conn, err := net.Listen("tcp", ":0")
-                       if err != nil {
-                               panic(err)
-                       }
-                       return conn
-               }(),
+       seeder, err := torrent.NewClient(&torrent.Config{
+               DataDir:         layout.Completed,
                DisableTrackers: true,
-       }
-       defer seeder.Listener.Close()
-       seeder.Start()
+               NoDHT:           true,
+       })
+       http.HandleFunc("/seeder", func(w http.ResponseWriter, req *http.Request) {
+               seeder.WriteStatus(w)
+       })
        defer seeder.Stop()
        err = seeder.AddMagnet(fmt.Sprintf("magnet:?xt=urn:btih:%x", layout.Metainfo.Info.Hash))
        if err != nil {
                t.Fatal(err)
        }
-       leecher := torrent.Client{
+       leecher, err := torrent.NewClient(&torrent.Config{
                DataDir:          filepath.Join(layout.BaseDir, "download"),
-               DownloadStrategy: &torrent.ResponsiveDownloadStrategy{},
+               DownloadStrategy: torrent.NewResponsiveDownloadStrategy(0),
                DisableTrackers:  true,
-       }
-       leecher.Start()
+               NoDHT:            true,
+       })
+       http.HandleFunc("/leecher", func(w http.ResponseWriter, req *http.Request) {
+               leecher.WriteStatus(w)
+       })
        defer leecher.Stop()
        leecher.AddTorrent(layout.Metainfo)
-       leecher.AddPeers(torrent.BytesInfoHash(layout.Metainfo.Info.Hash), []torrent.Peer{func() torrent.Peer {
-               tcpAddr := seeder.Listener.Addr().(*net.TCPAddr)
+       var ih torrent.InfoHash
+       util.CopyExact(ih[:], layout.Metainfo.Info.Hash)
+       leecher.AddPeers(ih, []torrent.Peer{func() torrent.Peer {
+               tcpAddr := seeder.ListenAddr().(*net.TCPAddr)
                return torrent.Peer{
                        IP:   tcpAddr.IP,
                        Port: tcpAddr.Port,
                }
        }()})
-       fs := New(&leecher)
+       fs := New(leecher)
        mountDir := layout.MountDir
        fuseConn, err := fuse.Mount(layout.MountDir)
        if err != nil {