]> Sergey Matveev's repositories - btrtrc.git/commitdiff
cmd/torrent: Add download --save-metainfos and fix up signal notification
authorMatt Joiner <anacrolix@gmail.com>
Thu, 17 Mar 2022 04:08:06 +0000 (15:08 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Thu, 17 Mar 2022 04:08:06 +0000 (15:08 +1100)
cmd/torrent/download.go

index 0b5d4d584b62d923cbe407bf9d40e1da5e3fec21..123b3659a483c54ee69def1d3b530e5f66abda1d 100644 (file)
@@ -1,6 +1,7 @@
 package main
 
 import (
+       "context"
        "errors"
        "expvar"
        "fmt"
@@ -15,7 +16,6 @@ import (
        "time"
 
        "github.com/anacrolix/log"
-       "github.com/anacrolix/missinggo/v2"
        "github.com/anacrolix/tagflag"
        "github.com/anacrolix/torrent"
        "github.com/anacrolix/torrent/iplist"
@@ -89,7 +89,7 @@ func resolveTestPeers(addrs []string) (ret []torrent.PeerInfo) {
        return
 }
 
-func addTorrents(client *torrent.Client, flags downloadFlags) error {
+func addTorrents(ctx context.Context, client *torrent.Client, flags downloadFlags, wg *sync.WaitGroup) error {
        testPeers := resolveTestPeers(flags.TestPeer)
        for _, arg := range flags.Torrent {
                t, err := func() (*torrent.Torrent, error) {
@@ -137,10 +137,30 @@ func addTorrents(client *torrent.Client, flags downloadFlags) error {
                        torrentBar(t, flags.PieceStates)
                }
                t.AddPeers(testPeers)
+               wg.Add(1)
                go func() {
-                       <-t.GotInfo()
+                       defer wg.Done()
+                       select {
+                       case <-ctx.Done():
+                               return
+                       case <-t.GotInfo():
+                       }
+                       if flags.SaveMetainfos {
+                               path := fmt.Sprintf("%v.torrent", t.InfoHash().HexString())
+                               err := writeMetainfoToFile(t.Metainfo(), path)
+                               if err == nil {
+                                       log.Printf("wrote %q", path)
+                               } else {
+                                       log.Printf("error writing %q: %v", path, err)
+                               }
+                       }
                        if len(flags.File) == 0 {
                                t.DownloadAll()
+                               wg.Add(1)
+                               go func() {
+                                       defer wg.Done()
+                                       waitForPieces(ctx, t, 0, t.NumPieces())
+                               }()
                                if flags.LinearDiscard {
                                        r := t.NewReader()
                                        io.Copy(io.Discard, r)
@@ -150,6 +170,11 @@ func addTorrents(client *torrent.Client, flags downloadFlags) error {
                                for _, f := range t.Files() {
                                        for _, fileArg := range flags.File {
                                                if f.DisplayPath() == fileArg {
+                                                       wg.Add(1)
+                                                       go func() {
+                                                               defer wg.Done()
+                                                               waitForPieces(ctx, t, f.BeginPieceIndex(), f.EndPieceIndex())
+                                                       }()
                                                        f.Download()
                                                        if flags.LinearDiscard {
                                                                r := f.NewReader()
@@ -167,12 +192,52 @@ func addTorrents(client *torrent.Client, flags downloadFlags) error {
        return nil
 }
 
+func waitForPieces(ctx context.Context, t *torrent.Torrent, beginIndex, endIndex int) {
+       sub := t.SubscribePieceStateChanges()
+       defer sub.Close()
+       pending := make(map[int]struct{})
+       for i := beginIndex; i < endIndex; i++ {
+               pending[i] = struct{}{}
+       }
+       expected := storage.Completion{
+               Complete: true,
+               Ok:       true,
+       }
+       for {
+               select {
+               case ev := <-sub.Values:
+                       if ev.Completion == expected {
+                               delete(pending, ev.Index)
+                       }
+                       if len(pending) == 0 {
+                               return
+                       }
+               case <-ctx.Done():
+                       return
+               }
+       }
+}
+
+func writeMetainfoToFile(mi metainfo.MetaInfo, path string) error {
+       f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0640)
+       if err != nil {
+               return err
+       }
+       defer f.Close()
+       err = mi.Write(f)
+       if err != nil {
+               return err
+       }
+       return f.Close()
+}
+
 type downloadFlags struct {
        Debug bool
        DownloadCmd
 }
 
 type DownloadCmd struct {
+       SaveMetainfos      bool
        Mmap               bool           `help:"memory-map torrent data"`
        Seed               bool           `help:"seed after download is complete"`
        Addr               string         `help:"network listen addr"`
@@ -211,15 +276,6 @@ func statsEnabled(flags downloadFlags) bool {
        return flags.Stats
 }
 
-func exitSignalHandlers(notify *missinggo.SynchronizedEvent) {
-       c := make(chan os.Signal, 1)
-       signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
-       for {
-               log.Printf("close signal received: %+v", <-c)
-               notify.Set()
-       }
-}
-
 func downloadErr(flags downloadFlags) error {
        clientConfig := torrent.NewDefaultClientConfig()
        clientConfig.DisableWebseeds = flags.DisableWebseeds
@@ -269,35 +325,29 @@ func downloadErr(flags downloadFlags) error {
        }
        clientConfig.MaxUnverifiedBytes = flags.MaxUnverifiedBytes.Int64()
 
-       var stop missinggo.SynchronizedEvent
-       defer func() {
-               stop.Set()
-       }()
+       ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+       defer cancel()
 
        client, err := torrent.NewClient(clientConfig)
        if err != nil {
                return fmt.Errorf("creating client: %w", err)
        }
-       var clientClose sync.Once // In certain situations, close was being called more than once.
-       defer clientClose.Do(func() { client.Close() })
-       go exitSignalHandlers(&stop)
-       go func() {
-               <-stop.C()
-               clientClose.Do(func() { client.Close() })
-       }()
+       defer client.Close()
 
        // Write status on the root path on the default HTTP muxer. This will be bound to localhost
        // somewhere if GOPPROF is set, thanks to the envpprof import.
        http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
                client.WriteStatus(w)
        })
-       err = addTorrents(client, flags)
-       started := time.Now()
+       var wg sync.WaitGroup
+       err = addTorrents(ctx, client, flags, &wg)
        if err != nil {
                return fmt.Errorf("adding torrents: %w", err)
        }
+       started := time.Now()
        defer outputStats(client, flags)
-       if client.WaitAll() {
+       wg.Wait()
+       if ctx.Err() == nil {
                log.Print("downloaded ALL the torrents")
        } else {
                err = errors.New("y u no complete torrents?!")
@@ -314,7 +364,7 @@ func downloadErr(flags downloadFlags) error {
                        log.Print("no torrents to seed")
                } else {
                        outputStats(client, flags)
-                       <-stop.C()
+                       <-ctx.Done()
                }
        }
        spew.Dump(expvar.Get("torrent").(*expvar.Map).Get("chunks received"))