]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Add Reader.ReadContext
authorMatt Joiner <anacrolix@gmail.com>
Sat, 30 Apr 2016 01:08:29 +0000 (11:08 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Sat, 30 Apr 2016 01:08:29 +0000 (11:08 +1000)
Allows cancelling reads etc. Torrents that get stuck can result in Reads that won't return until the torrent is dropped.

reader.go
reader_test.go [new file with mode: 0644]

index 46e41dff06b40e4d53a54d8b34eb8ff613f3fb6a..709b69c7a056496b65bf43040d91c304b602e2db 100644 (file)
--- a/reader.go
+++ b/reader.go
@@ -1,6 +1,7 @@
 package torrent
 
 import (
+       "context"
        "errors"
        "io"
        "os"
@@ -91,11 +92,31 @@ func (r *Reader) waitReadable(off int64) {
 }
 
 func (r *Reader) Read(b []byte) (n int, err error) {
+       return r.ReadContext(b, context.Background())
+}
+
+func (r *Reader) ReadContext(b []byte, ctx context.Context) (n int, err error) {
+       // This is set under the Client lock if the Context is canceled.
+       var ctxErr error
+       if ctx.Done() != nil {
+               ctx, cancel := context.WithCancel(ctx)
+               // Abort the goroutine when the function returns.
+               defer cancel()
+               go func() {
+                       <-ctx.Done()
+                       r.t.cl.mu.Lock()
+                       ctxErr = ctx.Err()
+                       r.t.cl.event.Broadcast()
+                       r.t.cl.mu.Unlock()
+               }()
+       }
+       // Hmmm, if a Read gets stuck, this means you can't change position for
+       // other purposes. That seems reasonable, but unusual.
        r.opMu.Lock()
        defer r.opMu.Unlock()
        for len(b) != 0 {
                var n1 int
-               n1, err = r.readOnceAt(b, r.pos)
+               n1, err = r.readOnceAt(b, r.pos, &ctxErr)
                if n1 == 0 {
                        if err == nil {
                                panic("expected error")
@@ -123,28 +144,32 @@ func (r *Reader) torrentClosed() bool {
 
 // Wait until some data should be available to read. Tickles the client if it
 // isn't. Returns how much should be readable without blocking.
-func (r *Reader) waitAvailable(pos, wanted int64) (avail int64) {
+func (r *Reader) waitAvailable(pos, wanted int64, ctxErr *error) (avail int64) {
        r.t.cl.mu.Lock()
        defer r.t.cl.mu.Unlock()
-       for !r.readable(pos) {
+       for !r.readable(pos) && *ctxErr == nil {
                r.waitReadable(pos)
        }
        return r.available(pos, wanted)
 }
 
 // Performs at most one successful read to torrent storage.
-func (r *Reader) readOnceAt(b []byte, pos int64) (n int, err error) {
+func (r *Reader) readOnceAt(b []byte, pos int64, ctxErr *error) (n int, err error) {
        if pos >= r.t.length {
                err = io.EOF
                return
        }
        for {
-               avail := r.waitAvailable(pos, int64(len(b)))
+               avail := r.waitAvailable(pos, int64(len(b)), ctxErr)
                if avail == 0 {
                        if r.torrentClosed() {
                                err = errors.New("torrent closed")
                                return
                        }
+                       if *ctxErr != nil {
+                               err = *ctxErr
+                               return
+                       }
                }
                b1 := b[:avail]
                pi := int(pos / r.t.Info().PieceLength)
diff --git a/reader_test.go b/reader_test.go
new file mode 100644 (file)
index 0000000..4206976
--- /dev/null
@@ -0,0 +1,25 @@
+package torrent
+
+import (
+       "context"
+       "testing"
+       "time"
+
+       "github.com/stretchr/testify/require"
+
+       "github.com/anacrolix/torrent/internal/testutil"
+)
+
+func TestReaderReadContext(t *testing.T) {
+       cl, err := NewClient(&TestingConfig)
+       require.NoError(t, err)
+       defer cl.Close()
+       tt, err := cl.AddTorrent(testutil.GreetingMetaInfo())
+       require.NoError(t, err)
+       defer tt.Drop()
+       ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond))
+       r := tt.NewReader()
+       defer r.Close()
+       _, err = r.ReadContext(make([]byte, 1), ctx)
+       require.EqualValues(t, context.DeadlineExceeded, err)
+}