]> Sergey Matveev's repositories - btrtrc.git/commitdiff
tracker: Add Announce.Context
authorMatt Joiner <anacrolix@gmail.com>
Wed, 28 Nov 2018 01:02:12 +0000 (12:02 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Wed, 28 Nov 2018 01:02:12 +0000 (12:02 +1100)
Use it to rewrite a test that fails with recent go versions due to logging after test completion.

tracker/http.go
tracker/tracker.go
tracker/udp.go
tracker/udp_test.go

index 69764040d8996a51587872ff6eabe145143f6b12..87927e773214ba348cadb1b09cf98e6c956f766b 100644 (file)
@@ -98,6 +98,9 @@ func announceHTTP(opt Announce, _url *url.URL) (ret AnnounceResponse, err error)
        req, err := http.NewRequest("GET", _url.String(), nil)
        req.Header.Set("User-Agent", opt.UserAgent)
        req.Host = opt.HostHeader
+       if opt.Context != nil {
+               req = req.WithContext(opt.Context)
+       }
        resp, err := (&http.Client{
                Timeout: time.Second * 15,
                Transport: &http.Transport{
index a56a99b74adb221ffa48f830f3eb374b5efde7df..e7260503df41e9458e14e10be9826b6a03dd4ec4 100644 (file)
@@ -1,6 +1,7 @@
 package tracker
 
 import (
+       "context"
        "errors"
        "net/http"
        "net/url"
@@ -61,6 +62,7 @@ type Announce struct {
        ClientIp4 krpc.NodeAddr
        // If the port is zero, it's assumed to be the same as the Request.Port
        ClientIp6 krpc.NodeAddr
+       Context   context.Context
 }
 
 // In an FP language with currying, what order what you put these params?
index 656cc7df5c551797d413ca0a1a491095bd495d09..c694285dc9ca5f2b5629f095d99ac9169cbc7633 100644 (file)
@@ -2,9 +2,9 @@ package tracker
 
 import (
        "bytes"
+       "context"
        "encoding"
        "encoding/binary"
-       "errors"
        "fmt"
        "io"
        "math/rand"
@@ -15,6 +15,7 @@ import (
        "github.com/anacrolix/dht/krpc"
        "github.com/anacrolix/missinggo"
        "github.com/anacrolix/missinggo/pproffd"
+       "github.com/pkg/errors"
 )
 
 type Action int32
@@ -188,39 +189,55 @@ func write(w io.Writer, data interface{}) error {
 
 // args is the binary serializable request body. trailer is optional data
 // following it, such as for BEP 41.
-func (c *udpAnnounce) request(action Action, args interface{}, options []byte) (responseBody *bytes.Buffer, err error) {
+func (c *udpAnnounce) request(action Action, args interface{}, options []byte) (*bytes.Buffer, error) {
        tid := newTransactionId()
-       err = c.write(&RequestHeader{
-               ConnectionId:  c.connectionId,
-               Action:        action,
-               TransactionId: tid,
-       }, args, options)
-       if err != nil {
-               return
+       if err := errors.Wrap(
+               c.write(
+                       &RequestHeader{
+                               ConnectionId:  c.connectionId,
+                               Action:        action,
+                               TransactionId: tid,
+                       }, args, options),
+               "writing request",
+       ); err != nil {
+               return nil, err
        }
        c.socket.SetReadDeadline(time.Now().Add(timeout(c.contiguousTimeouts)))
        b := make([]byte, 0x800) // 2KiB
        for {
-               var n int
-               n, err = c.socket.Read(b)
-               if opE, ok := err.(*net.OpError); ok {
-                       if opE.Timeout() {
-                               c.contiguousTimeouts++
-                               return
-                       }
+               var (
+                       n        int
+                       readErr  error
+                       readDone = make(chan struct{})
+               )
+               go func() {
+                       defer close(readDone)
+                       n, readErr = c.socket.Read(b)
+               }()
+               ctx := c.a.Context
+               if ctx == nil {
+                       ctx = context.Background()
                }
-               if err != nil {
-                       return
+               select {
+               case <-ctx.Done():
+                       return nil, ctx.Err()
+               case <-readDone:
+               }
+               if opE, ok := readErr.(*net.OpError); ok && opE.Timeout() {
+                       c.contiguousTimeouts++
+               }
+               if readErr != nil {
+                       return nil, errors.Wrap(readErr, "reading from socket")
                }
                buf := bytes.NewBuffer(b[:n])
                var h ResponseHeader
-               err = binary.Read(buf, binary.BigEndian, &h)
+               err := binary.Read(buf, binary.BigEndian, &h)
                switch err {
-               case io.ErrUnexpectedEOF:
+               default:
+                       panic(err)
+               case io.ErrUnexpectedEOF, io.EOF:
                        continue
                case nil:
-               default:
-                       return
                }
                if h.TransactionId != tid {
                        continue
@@ -229,8 +246,7 @@ func (c *udpAnnounce) request(action Action, args interface{}, options []byte) (
                if h.Action == ActionError {
                        err = errors.New(buf.String())
                }
-               responseBody = buf
-               return
+               return buf, err
        }
 }
 
index a6d4cb1531435860e4c9dc972a4ad9de361c97dd..ea500f18e6dc88c8e59095062348a068f3de29b2 100644 (file)
@@ -2,6 +2,7 @@ package tracker
 
 import (
        "bytes"
+       "context"
        "crypto/rand"
        "encoding/binary"
        "fmt"
@@ -160,8 +161,7 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
        rand.Read(req.PeerId[:])
        rand.Read(req.InfoHash[:])
        wg := sync.WaitGroup{}
-       success := make(chan bool)
-       fail := make(chan struct{})
+       ctx, cancel := context.WithCancel(context.Background())
        for _, url := range trackers {
                wg.Add(1)
                go func(url string) {
@@ -169,6 +169,7 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
                        resp, err := Announce{
                                TrackerUrl: url,
                                Request:    req,
+                               Context:    ctx,
                        }.Do()
                        if err != nil {
                                t.Logf("error announcing to %s: %s", url, err)
@@ -180,21 +181,10 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
                                t.Fatal(resp)
                        }
                        t.Logf("announced to %s", url)
-                       // TODO: Can probably get stuck here, but it's just a throwaway
-                       // test.
-                       success <- true
+                       cancel()
                }(url)
        }
-       go func() {
-               wg.Wait()
-               close(fail)
-       }()
-       select {
-       case <-fail:
-               // It doesn't matter if they all fail, the servers could just be down.
-       case <-success:
-               // Bail as quickly as we can. One success is enough.
-       }
+       wg.Wait()
 }
 
 // Check that URLPath option is done correctly.