]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Add some synchronization on partial file renames
authorMatt Joiner <anacrolix@gmail.com>
Fri, 30 May 2025 03:50:32 +0000 (13:50 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Fri, 30 May 2025 03:50:32 +0000 (13:50 +1000)
storage/file-misc.go
storage/file-piece.go
storage/file-torrent-io.go
storage/file-torrent.go
torrent.go

index c298e51677b617e5dec2e030fee1bec5783acd8d..07e98e92ea0ab60b648b455b35d1f99506df29f2 100644 (file)
@@ -6,6 +6,7 @@ import (
        "io/fs"
        "os"
        "path/filepath"
+       "sync"
 
        "github.com/anacrolix/torrent/segments"
 )
@@ -51,13 +52,17 @@ func CreateNativeZeroLengthFile(name string) error {
 }
 
 type file struct {
+       // This protects high level OS file state like partial file name, permission mod, renaming etc.
+       mu sync.RWMutex
        // The safe, OS-local file path.
        safeOsPath      string
        beginPieceIndex int
        endPieceIndex   int
        length          int64
+       // Utility value to help the race detector find issues for us.
+       race byte
 }
 
-func (f file) partFilePath() string {
+func (f *file) partFilePath() string {
        return f.safeOsPath + ".part"
 }
index d52e70351cef59752774111070bb2fede09644b1..31bc94634f5c6648b99560417934bb708b01edc1 100644 (file)
@@ -44,10 +44,10 @@ func (me *filePieceImpl) extent() segments.Extent {
        }
 }
 
-func (me *filePieceImpl) pieceFiles() iter.Seq2[int, file] {
-       return func(yield func(int, file) bool) {
+func (me *filePieceImpl) pieceFiles() iter.Seq2[int, *file] {
+       return func(yield func(int, *file) bool) {
                for fileIndex := range me.t.segmentLocater.LocateIter(me.extent()) {
-                       if !yield(fileIndex, me.t.files[fileIndex]) {
+                       if !yield(fileIndex, &me.t.files[fileIndex]) {
                                return
                        }
                }
@@ -69,12 +69,14 @@ func (me *filePieceImpl) Completion() Completion {
                // If it's allegedly complete, check that its constituent files have the necessary length.
                for i, extent := range me.t.segmentLocater.LocateIter(me.extent()) {
                        noFiles = false
-                       file := me.t.files[i]
-                       s, err := os.Stat(file.partFilePath())
-                       if errors.Is(err, fs.ErrNotExist) {
+                       file := &me.t.files[i]
+                       file.mu.RLock()
+                       s, err := os.Stat(file.safeOsPath)
+                       if me.partFiles() && errors.Is(err, fs.ErrNotExist) {
                                // Can we use shared files for this? Is it faster?
-                               s, err = os.Stat(file.safeOsPath)
+                               s, err = os.Stat(file.partFilePath())
                        }
+                       file.mu.RUnlock()
                        if err != nil {
                                me.logger().Warn(
                                        "error checking file for piece marked as complete",
@@ -134,7 +136,7 @@ func (me *filePieceImpl) MarkComplete() (err error) {
        return
 }
 
-func (me *filePieceImpl) allFilePiecesComplete(f file) (ret g.Result[bool]) {
+func (me *filePieceImpl) allFilePiecesComplete(f *file) (ret g.Result[bool]) {
        next, stop := iter.Pull(GetPieceCompletionRange(
                me.t.pieceCompletion(),
                me.t.infoHash,
@@ -176,7 +178,10 @@ func (me *filePieceImpl) MarkNotComplete() (err error) {
 
 }
 
-func (me *filePieceImpl) promotePartFile(f file) (err error) {
+func (me *filePieceImpl) promotePartFile(f *file) (err error) {
+       f.mu.Lock()
+       defer f.mu.Unlock()
+       f.race++
        if me.partFiles() {
                err = me.exclRenameIfExists(f.partFilePath(), f.safeOsPath)
                if err != nil {
@@ -240,7 +245,10 @@ func (me *filePieceImpl) exclRenameIfExists(from, to string) error {
        return nil
 }
 
-func (me *filePieceImpl) onFileNotComplete(f file) (err error) {
+func (me *filePieceImpl) onFileNotComplete(f *file) (err error) {
+       f.mu.Lock()
+       defer f.mu.Unlock()
+       f.race++
        if me.partFiles() {
                err = me.exclRenameIfExists(f.safeOsPath, f.partFilePath())
                if err != nil {
@@ -265,7 +273,7 @@ func (me *filePieceImpl) onFileNotComplete(f file) (err error) {
        return
 }
 
-func (me *filePieceImpl) pathForWrite(f file) string {
+func (me *filePieceImpl) pathForWrite(f *file) string {
        return me.t.pathForWrite(f)
 }
 
index e57690e0a58c336e494a73a1b44335b87d587708..b7fdb70db10a148f42b9770b6a9e8ca86a02968f 100644 (file)
@@ -16,9 +16,10 @@ type fileTorrentImplIO struct {
 }
 
 // Returns EOF on short or missing file.
-func (fst fileTorrentImplIO) readFileAt(file file, b []byte, off int64) (n int, err error) {
+func (fst fileTorrentImplIO) readFileAt(file *file, b []byte, off int64) (n int, err error) {
        fst.fts.logger().Debug("readFileAt", "file.safeOsPath", file.safeOsPath)
        var f sharedFileIf
+       file.mu.RLock()
        // Fine to open once under each name on a unix system. We could make the shared file keys more
        // constrained but it shouldn't matter. TODO: Ensure at most one of the names exist.
        if fst.fts.partFiles() {
@@ -27,6 +28,7 @@ func (fst fileTorrentImplIO) readFileAt(file file, b []byte, off int64) (n int,
        if err == nil && f == nil || errors.Is(err, fs.ErrNotExist) {
                f, err = sharedFiles.Open(file.safeOsPath)
        }
+       file.mu.RUnlock()
        if errors.Is(err, fs.ErrNotExist) {
                // File missing is treated the same as a short file. Should we propagate this through the
                // interface now that fs.ErrNotExist is a thing?
@@ -57,7 +59,7 @@ func (fst fileTorrentImplIO) readFileAt(file file, b []byte, off int64) (n int,
 // Only returns EOF at the end of the torrent. Premature EOF is ErrUnexpectedEOF.
 func (fst fileTorrentImplIO) ReadAt(b []byte, off int64) (n int, err error) {
        fst.fts.segmentLocater.Locate(segments.Extent{off, int64(len(b))}, func(i int, e segments.Extent) bool {
-               n1, err1 := fst.readFileAt(fst.fts.files[i], b[:e.Length], e.Start)
+               n1, err1 := fst.readFileAt(&fst.fts.files[i], b[:e.Length], e.Start)
                n += n1
                b = b[n1:]
                err = err1
@@ -69,7 +71,7 @@ func (fst fileTorrentImplIO) ReadAt(b []byte, off int64) (n int, err error) {
        return
 }
 
-func (fst fileTorrentImplIO) openForWrite(file file) (f *os.File, err error) {
+func (fst fileTorrentImplIO) openForWrite(file *file) (f *os.File, err error) {
        // It might be possible to have a writable handle shared files cache if we need it.
        fst.fts.logger().Debug("openForWrite", "file.safeOsPath", file.safeOsPath)
        p := fst.fts.pathForWrite(file)
@@ -94,25 +96,27 @@ func (fst fileTorrentImplIO) openForWrite(file file) (f *os.File, err error) {
 
 func (fst fileTorrentImplIO) WriteAt(p []byte, off int64) (n int, err error) {
        // log.Printf("write at %v: %v bytes", off, len(p))
-       fst.fts.segmentLocater.Locate(segments.Extent{off, int64(len(p))}, func(i int, e segments.Extent) bool {
-               var f *os.File
-               f, err = fst.openForWrite(fst.fts.files[i])
-               if err != nil {
-                       return false
-               }
-               var n1 int
-               n1, err = f.WriteAt(p[:e.Length], e.Start)
-               // log.Printf("%v %v wrote %v: %v", i, e, n1, err)
-               closeErr := f.Close()
-               n += n1
-               p = p[n1:]
-               if err == nil {
-                       err = closeErr
-               }
-               if err == nil && int64(n1) != e.Length {
-                       err = io.ErrShortWrite
-               }
-               return err == nil
-       })
+       fst.fts.segmentLocater.Locate(
+               segments.Extent{off, int64(len(p))},
+               func(i int, e segments.Extent) bool {
+                       var f *os.File
+                       f, err = fst.openForWrite(&fst.fts.files[i])
+                       if err != nil {
+                               return false
+                       }
+                       var n1 int
+                       n1, err = f.WriteAt(p[:e.Length], e.Start)
+                       // log.Printf("%v %v wrote %v: %v", i, e, n1, err)
+                       closeErr := f.Close()
+                       n += n1
+                       p = p[n1:]
+                       if err == nil {
+                               err = closeErr
+                       }
+                       if err == nil && int64(n1) != e.Length {
+                               err = io.ErrShortWrite
+                       }
+                       return err == nil
+               })
        return
 }
index 09caa255323e4447962f31233e1a7c39886b9610..ad5a6b2a9d5aca53f8af678867ae34aae6133810 100644 (file)
@@ -76,7 +76,7 @@ func (fts *fileTorrentImpl) partFiles() bool {
        return fts.client.opts.partFiles()
 }
 
-func (fts *fileTorrentImpl) pathForWrite(f file) string {
+func (fts *fileTorrentImpl) pathForWrite(f *file) string {
        if fts.partFiles() {
                return f.partFilePath()
        }
@@ -106,7 +106,8 @@ func (fs *fileTorrentImpl) Close() error {
 }
 
 func (fts *fileTorrentImpl) Flush() error {
-       for _, f := range fts.files {
+       for i := range fts.files {
+               f := &fts.files[i]
                fts.logger().Debug("flushing", "file.safeOsPath", f.safeOsPath)
                if err := fsync(fts.pathForWrite(f)); err != nil {
                        return err
index 2f9d42e2b331fa6fc671ba52291322638c47ac22..09a9b81ba9d6ce2ec0d494ac75b8eb05bffbac11 100644 (file)
@@ -1117,7 +1117,7 @@ func (t *Torrent) offsetRequest(off int64) (req Request, ok bool) {
 }
 
 func (t *Torrent) writeChunk(piece int, begin int64, data []byte) (err error) {
-       n, err := t.pieces[piece].Storage().WriteAt(data, begin)
+       n, err := t.piece(piece).Storage().WriteAt(data, begin)
        if err == nil && n != len(data) {
                err = io.ErrShortWrite
        }