From de81778ecc8524b9fe1ae8fa02992b0088fa5d68 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Fri, 30 May 2025 13:50:32 +1000 Subject: [PATCH] Add some synchronization on partial file renames --- storage/file-misc.go | 7 +++++- storage/file-piece.go | 30 ++++++++++++++--------- storage/file-torrent-io.go | 50 ++++++++++++++++++++------------------ storage/file-torrent.go | 5 ++-- torrent.go | 2 +- 5 files changed, 56 insertions(+), 38 deletions(-) diff --git a/storage/file-misc.go b/storage/file-misc.go index c298e516..07e98e92 100644 --- a/storage/file-misc.go +++ b/storage/file-misc.go @@ -6,6 +6,7 @@ import ( "io/fs" "os" "path/filepath" + "sync" "github.com/anacrolix/torrent/segments" ) @@ -51,13 +52,17 @@ func CreateNativeZeroLengthFile(name string) error { } type file struct { + // This protects high level OS file state like partial file name, permission mod, renaming etc. + mu sync.RWMutex // The safe, OS-local file path. safeOsPath string beginPieceIndex int endPieceIndex int length int64 + // Utility value to help the race detector find issues for us. + race byte } -func (f file) partFilePath() string { +func (f *file) partFilePath() string { return f.safeOsPath + ".part" } diff --git a/storage/file-piece.go b/storage/file-piece.go index d52e7035..31bc9463 100644 --- a/storage/file-piece.go +++ b/storage/file-piece.go @@ -44,10 +44,10 @@ func (me *filePieceImpl) extent() segments.Extent { } } -func (me *filePieceImpl) pieceFiles() iter.Seq2[int, file] { - return func(yield func(int, file) bool) { +func (me *filePieceImpl) pieceFiles() iter.Seq2[int, *file] { + return func(yield func(int, *file) bool) { for fileIndex := range me.t.segmentLocater.LocateIter(me.extent()) { - if !yield(fileIndex, me.t.files[fileIndex]) { + if !yield(fileIndex, &me.t.files[fileIndex]) { return } } @@ -69,12 +69,14 @@ func (me *filePieceImpl) Completion() Completion { // If it's allegedly complete, check that its constituent files have the necessary length. for i, extent := range me.t.segmentLocater.LocateIter(me.extent()) { noFiles = false - file := me.t.files[i] - s, err := os.Stat(file.partFilePath()) - if errors.Is(err, fs.ErrNotExist) { + file := &me.t.files[i] + file.mu.RLock() + s, err := os.Stat(file.safeOsPath) + if me.partFiles() && errors.Is(err, fs.ErrNotExist) { // Can we use shared files for this? Is it faster? - s, err = os.Stat(file.safeOsPath) + s, err = os.Stat(file.partFilePath()) } + file.mu.RUnlock() if err != nil { me.logger().Warn( "error checking file for piece marked as complete", @@ -134,7 +136,7 @@ func (me *filePieceImpl) MarkComplete() (err error) { return } -func (me *filePieceImpl) allFilePiecesComplete(f file) (ret g.Result[bool]) { +func (me *filePieceImpl) allFilePiecesComplete(f *file) (ret g.Result[bool]) { next, stop := iter.Pull(GetPieceCompletionRange( me.t.pieceCompletion(), me.t.infoHash, @@ -176,7 +178,10 @@ func (me *filePieceImpl) MarkNotComplete() (err error) { } -func (me *filePieceImpl) promotePartFile(f file) (err error) { +func (me *filePieceImpl) promotePartFile(f *file) (err error) { + f.mu.Lock() + defer f.mu.Unlock() + f.race++ if me.partFiles() { err = me.exclRenameIfExists(f.partFilePath(), f.safeOsPath) if err != nil { @@ -240,7 +245,10 @@ func (me *filePieceImpl) exclRenameIfExists(from, to string) error { return nil } -func (me *filePieceImpl) onFileNotComplete(f file) (err error) { +func (me *filePieceImpl) onFileNotComplete(f *file) (err error) { + f.mu.Lock() + defer f.mu.Unlock() + f.race++ if me.partFiles() { err = me.exclRenameIfExists(f.safeOsPath, f.partFilePath()) if err != nil { @@ -265,7 +273,7 @@ func (me *filePieceImpl) onFileNotComplete(f file) (err error) { return } -func (me *filePieceImpl) pathForWrite(f file) string { +func (me *filePieceImpl) pathForWrite(f *file) string { return me.t.pathForWrite(f) } diff --git a/storage/file-torrent-io.go b/storage/file-torrent-io.go index e57690e0..b7fdb70d 100644 --- a/storage/file-torrent-io.go +++ b/storage/file-torrent-io.go @@ -16,9 +16,10 @@ type fileTorrentImplIO struct { } // Returns EOF on short or missing file. -func (fst fileTorrentImplIO) readFileAt(file file, b []byte, off int64) (n int, err error) { +func (fst fileTorrentImplIO) readFileAt(file *file, b []byte, off int64) (n int, err error) { fst.fts.logger().Debug("readFileAt", "file.safeOsPath", file.safeOsPath) var f sharedFileIf + file.mu.RLock() // Fine to open once under each name on a unix system. We could make the shared file keys more // constrained but it shouldn't matter. TODO: Ensure at most one of the names exist. if fst.fts.partFiles() { @@ -27,6 +28,7 @@ func (fst fileTorrentImplIO) readFileAt(file file, b []byte, off int64) (n int, if err == nil && f == nil || errors.Is(err, fs.ErrNotExist) { f, err = sharedFiles.Open(file.safeOsPath) } + file.mu.RUnlock() if errors.Is(err, fs.ErrNotExist) { // File missing is treated the same as a short file. Should we propagate this through the // interface now that fs.ErrNotExist is a thing? @@ -57,7 +59,7 @@ func (fst fileTorrentImplIO) readFileAt(file file, b []byte, off int64) (n int, // Only returns EOF at the end of the torrent. Premature EOF is ErrUnexpectedEOF. func (fst fileTorrentImplIO) ReadAt(b []byte, off int64) (n int, err error) { fst.fts.segmentLocater.Locate(segments.Extent{off, int64(len(b))}, func(i int, e segments.Extent) bool { - n1, err1 := fst.readFileAt(fst.fts.files[i], b[:e.Length], e.Start) + n1, err1 := fst.readFileAt(&fst.fts.files[i], b[:e.Length], e.Start) n += n1 b = b[n1:] err = err1 @@ -69,7 +71,7 @@ func (fst fileTorrentImplIO) ReadAt(b []byte, off int64) (n int, err error) { return } -func (fst fileTorrentImplIO) openForWrite(file file) (f *os.File, err error) { +func (fst fileTorrentImplIO) openForWrite(file *file) (f *os.File, err error) { // It might be possible to have a writable handle shared files cache if we need it. fst.fts.logger().Debug("openForWrite", "file.safeOsPath", file.safeOsPath) p := fst.fts.pathForWrite(file) @@ -94,25 +96,27 @@ func (fst fileTorrentImplIO) openForWrite(file file) (f *os.File, err error) { func (fst fileTorrentImplIO) WriteAt(p []byte, off int64) (n int, err error) { // log.Printf("write at %v: %v bytes", off, len(p)) - fst.fts.segmentLocater.Locate(segments.Extent{off, int64(len(p))}, func(i int, e segments.Extent) bool { - var f *os.File - f, err = fst.openForWrite(fst.fts.files[i]) - if err != nil { - return false - } - var n1 int - n1, err = f.WriteAt(p[:e.Length], e.Start) - // log.Printf("%v %v wrote %v: %v", i, e, n1, err) - closeErr := f.Close() - n += n1 - p = p[n1:] - if err == nil { - err = closeErr - } - if err == nil && int64(n1) != e.Length { - err = io.ErrShortWrite - } - return err == nil - }) + fst.fts.segmentLocater.Locate( + segments.Extent{off, int64(len(p))}, + func(i int, e segments.Extent) bool { + var f *os.File + f, err = fst.openForWrite(&fst.fts.files[i]) + if err != nil { + return false + } + var n1 int + n1, err = f.WriteAt(p[:e.Length], e.Start) + // log.Printf("%v %v wrote %v: %v", i, e, n1, err) + closeErr := f.Close() + n += n1 + p = p[n1:] + if err == nil { + err = closeErr + } + if err == nil && int64(n1) != e.Length { + err = io.ErrShortWrite + } + return err == nil + }) return } diff --git a/storage/file-torrent.go b/storage/file-torrent.go index 09caa255..ad5a6b2a 100644 --- a/storage/file-torrent.go +++ b/storage/file-torrent.go @@ -76,7 +76,7 @@ func (fts *fileTorrentImpl) partFiles() bool { return fts.client.opts.partFiles() } -func (fts *fileTorrentImpl) pathForWrite(f file) string { +func (fts *fileTorrentImpl) pathForWrite(f *file) string { if fts.partFiles() { return f.partFilePath() } @@ -106,7 +106,8 @@ func (fs *fileTorrentImpl) Close() error { } func (fts *fileTorrentImpl) Flush() error { - for _, f := range fts.files { + for i := range fts.files { + f := &fts.files[i] fts.logger().Debug("flushing", "file.safeOsPath", f.safeOsPath) if err := fsync(fts.pathForWrite(f)); err != nil { return err diff --git a/torrent.go b/torrent.go index 2f9d42e2..09a9b81b 100644 --- a/torrent.go +++ b/torrent.go @@ -1117,7 +1117,7 @@ func (t *Torrent) offsetRequest(off int64) (req Request, ok bool) { } func (t *Torrent) writeChunk(piece int, begin int64, data []byte) (err error) { - n, err := t.pieces[piece].Storage().WriteAt(data, begin) + n, err := t.piece(piece).Storage().WriteAt(data, begin) if err == nil && n != len(data) { err = io.ErrShortWrite } -- 2.51.0