]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Rewrite UDP tracker client
authorMatt Joiner <anacrolix@gmail.com>
Tue, 22 Jun 2021 12:36:43 +0000 (22:36 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Tue, 22 Jun 2021 12:36:43 +0000 (22:36 +1000)
12 files changed:
tracker/server.go
tracker/tracker.go
tracker/udp.go
tracker/udp/announce.go [new file with mode: 0644]
tracker/udp/client.go [new file with mode: 0644]
tracker/udp/dispatcher.go [new file with mode: 0644]
tracker/udp/options.go [new file with mode: 0644]
tracker/udp/protocol.go [new file with mode: 0644]
tracker/udp/timeout.go [new file with mode: 0644]
tracker/udp/timeout_test.go [new file with mode: 0644]
tracker/udp/transaction.go [new file with mode: 0644]
tracker/udp_test.go

index 34417be6fc1225df79f32f0d1f6ef94ab3268cc0..59c64f1aaa8376fe542a48e9e8c62b28e38ab658 100644 (file)
@@ -10,6 +10,7 @@ import (
 
        "github.com/anacrolix/dht/v2/krpc"
        "github.com/anacrolix/missinggo"
+       "github.com/anacrolix/torrent/tracker/udp"
 )
 
 type torrent struct {
@@ -36,7 +37,7 @@ func marshal(parts ...interface{}) (ret []byte, err error) {
        return
 }
 
-func (s *server) respond(addr net.Addr, rh ResponseHeader, parts ...interface{}) (err error) {
+func (s *server) respond(addr net.Addr, rh udp.ResponseHeader, parts ...interface{}) (err error) {
        b, err := marshal(append([]interface{}{rh}, parts...)...)
        if err != nil {
                return
@@ -61,34 +62,34 @@ func (s *server) serveOne() (err error) {
                return
        }
        r := bytes.NewReader(b[:n])
-       var h RequestHeader
-       err = readBody(r, &h)
+       var h udp.RequestHeader
+       err = udp.Read(r, &h)
        if err != nil {
                return
        }
        switch h.Action {
-       case ActionConnect:
-               if h.ConnectionId != connectRequestConnectionId {
+       case udp.ActionConnect:
+               if h.ConnectionId != udp.ConnectRequestConnectionId {
                        return
                }
                connId := s.newConn()
-               err = s.respond(addr, ResponseHeader{
-                       ActionConnect,
+               err = s.respond(addr, udp.ResponseHeader{
+                       udp.ActionConnect,
                        h.TransactionId,
-               }, ConnectionResponse{
+               }, udp.ConnectionResponse{
                        connId,
                })
                return
-       case ActionAnnounce:
+       case udp.ActionAnnounce:
                if _, ok := s.conns[h.ConnectionId]; !ok {
-                       s.respond(addr, ResponseHeader{
+                       s.respond(addr, udp.ResponseHeader{
                                TransactionId: h.TransactionId,
-                               Action:        ActionError,
+                               Action:        udp.ActionError,
                        }, []byte("not connected"))
                        return
                }
                var ar AnnounceRequest
-               err = readBody(r, &ar)
+               err = udp.Read(r, &ar)
                if err != nil {
                        return
                }
@@ -104,10 +105,10 @@ func (s *server) serveOne() (err error) {
                if err != nil {
                        panic(err)
                }
-               err = s.respond(addr, ResponseHeader{
+               err = s.respond(addr, udp.ResponseHeader{
                        TransactionId: h.TransactionId,
-                       Action:        ActionAnnounce,
-               }, AnnounceResponseHeader{
+                       Action:        udp.ActionAnnounce,
+               }, udp.AnnounceResponseHeader{
                        Interval: 900,
                        Leechers: t.Leechers,
                        Seeders:  t.Seeders,
@@ -115,9 +116,9 @@ func (s *server) serveOne() (err error) {
                return
        default:
                err = fmt.Errorf("unhandled action: %d", h.Action)
-               s.respond(addr, ResponseHeader{
+               s.respond(addr, udp.ResponseHeader{
                        TransactionId: h.TransactionId,
-                       Action:        ActionError,
+                       Action:        udp.ActionError,
                }, []byte("unhandled action"))
                return
        }
index 1b6d1412bb203fcf7e900de5661f8b228c70148c..0a187574476beab4e70590c55fb7ddad26797bf1 100644 (file)
@@ -8,23 +8,10 @@ import (
        "time"
 
        "github.com/anacrolix/dht/v2/krpc"
+       "github.com/anacrolix/torrent/tracker/udp"
 )
 
-// Marshalled as binary by the UDP client, so be careful making changes.
-type AnnounceRequest struct {
-       InfoHash   [20]byte
-       PeerId     [20]byte
-       Downloaded int64
-       Left       int64 // If less than 0, math.MaxInt64 will be used for HTTP trackers instead.
-       Uploaded   int64
-       // Apparently this is optional. None can be used for announces done at
-       // regular intervals.
-       Event     AnnounceEvent
-       IPAddress uint32
-       Key       int32
-       NumWant   int32 // How many peer addresses are desired. -1 for default.
-       Port      uint16
-} // 82 bytes
+type AnnounceRequest = udp.AnnounceRequest
 
 type AnnounceResponse struct {
        Interval int32 // Minimum seconds the local peer should wait before next announce.
@@ -33,12 +20,7 @@ type AnnounceResponse struct {
        Peers    []Peer
 }
 
-type AnnounceEvent int32
-
-func (e AnnounceEvent) String() string {
-       // See BEP 3, "event", and https://github.com/anacrolix/torrent/issues/416#issuecomment-751427001.
-       return []string{"", "completed", "started", "stopped"}[e]
-}
+type AnnounceEvent = udp.AnnounceEvent
 
 const (
        None      AnnounceEvent = iota
index 033598e50f5163ab489038f51515c601e8a757c7..4fb00b117381834d33a54526281cbd4e356ff0ca 100644 (file)
 package tracker
 
 import (
-       "bytes"
-       "context"
        "encoding"
        "encoding/binary"
-       "fmt"
-       "io"
-       "math/rand"
        "net"
        "net/url"
-       "time"
 
        "github.com/anacrolix/dht/v2/krpc"
        "github.com/anacrolix/missinggo"
-       "github.com/anacrolix/missinggo/pproffd"
-       "github.com/pkg/errors"
+       "github.com/anacrolix/torrent/tracker/udp"
 )
 
-type Action int32
-
-const (
-       ActionConnect Action = iota
-       ActionAnnounce
-       ActionScrape
-       ActionError
-
-       connectRequestConnectionId = 0x41727101980
-
-       // BEP 41
-       optionTypeEndOfOptions = 0
-       optionTypeNOP          = 1
-       optionTypeURLData      = 2
-)
-
-type ConnectionRequest struct {
-       ConnectionId int64
-       Action       int32
-       TransctionId int32
-}
-
-type ConnectionResponse struct {
-       ConnectionId int64
-}
-
-type ResponseHeader struct {
-       Action        Action
-       TransactionId int32
-}
-
-type RequestHeader struct {
-       ConnectionId  int64
-       Action        Action
-       TransactionId int32
-} // 16 bytes
-
-type AnnounceResponseHeader struct {
-       Interval int32
-       Leechers int32
-       Seeders  int32
-}
-
-func newTransactionId() int32 {
-       return int32(rand.Uint32())
-}
-
-func timeout(contiguousTimeouts int) (d time.Duration) {
-       if contiguousTimeouts > 8 {
-               contiguousTimeouts = 8
-       }
-       d = 15 * time.Second
-       for ; contiguousTimeouts > 0; contiguousTimeouts-- {
-               d *= 2
-       }
-       return
-}
-
 type udpAnnounce struct {
-       contiguousTimeouts   int
-       connectionIdReceived time.Time
-       connectionId         int64
-       socket               net.Conn
-       url                  url.URL
-       a                    *Announce
+       url url.URL
+       a   *Announce
 }
 
 func (c *udpAnnounce) Close() error {
-       if c.socket != nil {
-               return c.socket.Close()
-       }
        return nil
 }
 
-func (c *udpAnnounce) ipv6() bool {
+func (c *udpAnnounce) ipv6(conn net.Conn) bool {
        if c.a.UdpNetwork == "udp6" {
                return true
        }
-       rip := missinggo.AddrIP(c.socket.RemoteAddr())
+       rip := missinggo.AddrIP(conn.RemoteAddr())
        return rip.To16() != nil && rip.To4() == nil
 }
 
 func (c *udpAnnounce) Do(req AnnounceRequest) (res AnnounceResponse, err error) {
-       err = c.connect()
+       conn, err := net.Dial(c.dialNetwork(), c.url.Host)
        if err != nil {
                return
        }
-       reqURI := c.url.RequestURI()
-       if c.ipv6() {
+       defer conn.Close()
+       if c.ipv6(conn) {
                // BEP 15
                req.IPAddress = 0
        } else if req.IPAddress == 0 && c.a.ClientIp4.IP != nil {
                req.IPAddress = binary.BigEndian.Uint32(c.a.ClientIp4.IP.To4())
        }
-       // Clearly this limits the request URI to 255 bytes. BEP 41 supports
-       // longer but I'm not fussed.
-       options := append([]byte{optionTypeURLData, byte(len(reqURI))}, []byte(reqURI)...)
-       vars.Add("udp tracker announces", 1)
-       b, err := c.request(ActionAnnounce, req, options)
-       if err != nil {
-               return
-       }
-       var h AnnounceResponseHeader
-       err = readBody(b, &h)
-       if err != nil {
-               if err == io.EOF {
-                       err = io.ErrUnexpectedEOF
+       d := udp.Dispatcher{}
+       go func() {
+               for {
+                       b := make([]byte, 0x800)
+                       n, err := conn.Read(b)
+                       if err != nil {
+                               break
+                       }
+                       d.Dispatch(b[:n])
                }
-               err = fmt.Errorf("error parsing announce response: %s", err)
-               return
+       }()
+       cl := udp.Client{
+               Dispatcher: &d,
+               Writer:     conn,
        }
-       res.Interval = h.Interval
-       res.Leechers = h.Leechers
-       res.Seeders = h.Seeders
        nas := func() interface {
                encoding.BinaryUnmarshaler
                NodeAddrs() []krpc.NodeAddr
        } {
-               if c.ipv6() {
+               if c.ipv6(conn) {
                        return &krpc.CompactIPv6NodeAddrs{}
                } else {
                        return &krpc.CompactIPv4NodeAddrs{}
                }
        }()
-       err = nas.UnmarshalBinary(b.Bytes())
+       h, err := cl.Announce(c.a.Context, req, nas, udp.Options{RequestUri: c.url.RequestURI()})
        if err != nil {
                return
        }
+       res.Interval = h.Interval
+       res.Leechers = h.Leechers
+       res.Seeders = h.Seeders
        for _, cp := range nas.NodeAddrs() {
                res.Peers = append(res.Peers, Peer{}.FromNodeAddr(cp))
        }
        return
 }
 
-// body is the binary serializable request body. trailer is optional data
-// following it, such as for BEP 41.
-func (c *udpAnnounce) write(h *RequestHeader, body interface{}, trailer []byte) (err error) {
-       var buf bytes.Buffer
-       err = binary.Write(&buf, binary.BigEndian, h)
-       if err != nil {
-               panic(err)
-       }
-       if body != nil {
-               err = binary.Write(&buf, binary.BigEndian, body)
-               if err != nil {
-                       panic(err)
-               }
-       }
-       _, err = buf.Write(trailer)
-       if err != nil {
-               return
-       }
-       n, err := c.socket.Write(buf.Bytes())
-       if err != nil {
-               return
-       }
-       if n != buf.Len() {
-               panic("write should send all or error")
-       }
-       return
-}
-
-func read(r io.Reader, data interface{}) error {
-       return binary.Read(r, binary.BigEndian, data)
-}
-
-func write(w io.Writer, data interface{}) error {
-       return binary.Write(w, binary.BigEndian, data)
-}
-
-// args is the binary serializable request body. trailer is optional data
-// following it, such as for BEP 41.
-func (c *udpAnnounce) request(action Action, args interface{}, options []byte) (*bytes.Buffer, error) {
-       tid := newTransactionId()
-       if err := errors.Wrap(
-               c.write(
-                       &RequestHeader{
-                               ConnectionId:  c.connectionId,
-                               Action:        action,
-                               TransactionId: tid,
-                       }, args, options),
-               "writing request",
-       ); err != nil {
-               return nil, err
-       }
-       c.socket.SetReadDeadline(time.Now().Add(timeout(c.contiguousTimeouts)))
-       b := make([]byte, 0x800) // 2KiB
-       for {
-               var (
-                       n        int
-                       readErr  error
-                       readDone = make(chan struct{})
-               )
-               go func() {
-                       defer close(readDone)
-                       n, readErr = c.socket.Read(b)
-               }()
-               ctx := c.a.Context
-               if ctx == nil {
-                       ctx = context.Background()
-               }
-               select {
-               case <-ctx.Done():
-                       return nil, ctx.Err()
-               case <-readDone:
-               }
-               if opE, ok := readErr.(*net.OpError); ok && opE.Timeout() {
-                       c.contiguousTimeouts++
-               }
-               if readErr != nil {
-                       return nil, errors.Wrap(readErr, "reading from socket")
-               }
-               buf := bytes.NewBuffer(b[:n])
-               var h ResponseHeader
-               err := binary.Read(buf, binary.BigEndian, &h)
-               switch err {
-               default:
-                       panic(err)
-               case io.ErrUnexpectedEOF, io.EOF:
-                       continue
-               case nil:
-               }
-               if h.TransactionId != tid {
-                       continue
-               }
-               c.contiguousTimeouts = 0
-               if h.Action == ActionError {
-                       err = errors.New(buf.String())
-               }
-               return buf, err
-       }
-}
-
-func readBody(r io.Reader, data ...interface{}) (err error) {
-       for _, datum := range data {
-               err = binary.Read(r, binary.BigEndian, datum)
-               if err != nil {
-                       break
-               }
-       }
-       return
-}
-
-func (c *udpAnnounce) connected() bool {
-       return !c.connectionIdReceived.IsZero() && time.Now().Before(c.connectionIdReceived.Add(time.Minute))
-}
-
 func (c *udpAnnounce) dialNetwork() string {
        if c.a.UdpNetwork != "" {
                return c.a.UdpNetwork
@@ -272,40 +85,7 @@ func (c *udpAnnounce) dialNetwork() string {
        return "udp"
 }
 
-func (c *udpAnnounce) connect() (err error) {
-       if c.connected() {
-               return nil
-       }
-       c.connectionId = connectRequestConnectionId
-       if c.socket == nil {
-               hmp := missinggo.SplitHostMaybePort(c.url.Host)
-               if hmp.NoPort {
-                       hmp.NoPort = false
-                       hmp.Port = 80
-               }
-               c.socket, err = net.Dial(c.dialNetwork(), hmp.String())
-               if err != nil {
-                       return
-               }
-               c.socket = pproffd.WrapNetConn(c.socket)
-       }
-       vars.Add("udp tracker connects", 1)
-       b, err := c.request(ActionConnect, nil, nil)
-       if err != nil {
-               return
-       }
-       var res ConnectionResponse
-       err = readBody(b, &res)
-       if err != nil {
-               return
-       }
-       c.connectionId = res.ConnectionId
-       c.connectionIdReceived = time.Now()
-       return
-}
-
-// TODO: Split on IPv6, as BEP 15 says response peer decoding depends on
-// network in use.
+// TODO: Split on IPv6, as BEP 15 says response peer decoding depends on network in use.
 func announceUDP(opt Announce, _url *url.URL) (AnnounceResponse, error) {
        ua := udpAnnounce{
                url: *_url,
diff --git a/tracker/udp/announce.go b/tracker/udp/announce.go
new file mode 100644 (file)
index 0000000..1573c27
--- /dev/null
@@ -0,0 +1,35 @@
+package udp
+
+import (
+       "encoding"
+
+       "github.com/anacrolix/dht/v2/krpc"
+)
+
+// Marshalled as binary by the UDP client, so be careful making changes.
+type AnnounceRequest struct {
+       InfoHash   [20]byte
+       PeerId     [20]byte
+       Downloaded int64
+       Left       int64 // If less than 0, math.MaxInt64 will be used for HTTP trackers instead.
+       Uploaded   int64
+       // Apparently this is optional. None can be used for announces done at
+       // regular intervals.
+       Event     AnnounceEvent
+       IPAddress uint32
+       Key       int32
+       NumWant   int32 // How many peer addresses are desired. -1 for default.
+       Port      uint16
+} // 82 bytes
+
+type AnnounceEvent int32
+
+func (e AnnounceEvent) String() string {
+       // See BEP 3, "event", and https://github.com/anacrolix/torrent/issues/416#issuecomment-751427001.
+       return []string{"", "completed", "started", "stopped"}[e]
+}
+
+type AnnounceResponsePeers interface {
+       encoding.BinaryUnmarshaler
+       NodeAddrs() []krpc.NodeAddr
+}
diff --git a/tracker/udp/client.go b/tracker/udp/client.go
new file mode 100644 (file)
index 0000000..54099ff
--- /dev/null
@@ -0,0 +1,132 @@
+package udp
+
+import (
+       "bytes"
+       "context"
+       "encoding/binary"
+       "errors"
+       "fmt"
+       "io"
+       "time"
+)
+
+type Client struct {
+       connId       ConnectionId
+       connIdIssued time.Time
+       Dispatcher   *Dispatcher
+       Writer       io.Writer
+}
+
+func (cl *Client) Announce(
+       ctx context.Context, req AnnounceRequest, peers AnnounceResponsePeers, opts Options,
+) (
+       respHdr AnnounceResponseHeader, err error,
+) {
+       body, err := marshal(req)
+       if err != nil {
+               return
+       }
+       respBody, err := cl.request(ctx, ActionAnnounce, append(body, opts.Encode()...))
+       if err != nil {
+               return
+       }
+       r := bytes.NewBuffer(respBody)
+       err = Read(r, &respHdr)
+       if err != nil {
+               err = fmt.Errorf("reading response header: %w", err)
+               return
+       }
+       err = peers.UnmarshalBinary(r.Bytes())
+       if err != nil {
+               err = fmt.Errorf("reading response peers: %w", err)
+       }
+       return
+}
+
+func (cl *Client) connect(ctx context.Context) (err error) {
+       if time.Since(cl.connIdIssued) < time.Minute {
+               return nil
+       }
+       respBody, err := cl.request(ctx, ActionConnect, nil)
+       if err != nil {
+               return err
+       }
+       var connResp ConnectionResponse
+       err = binary.Read(bytes.NewReader(respBody), binary.BigEndian, &connResp)
+       if err != nil {
+               return
+       }
+       cl.connId = connResp.ConnectionId
+       cl.connIdIssued = time.Now()
+       return
+}
+
+func (cl *Client) connIdForRequest(ctx context.Context, action Action) (id ConnectionId, err error) {
+       if action == ActionConnect {
+               id = ConnectRequestConnectionId
+               return
+       }
+       err = cl.connect(ctx)
+       if err != nil {
+               return
+       }
+       id = cl.connId
+       return
+}
+
+func (cl *Client) requestWriter(ctx context.Context, action Action, body []byte, tId TransactionId) (err error) {
+       var buf bytes.Buffer
+       for n := 0; ; n++ {
+               var connId ConnectionId
+               connId, err = cl.connIdForRequest(ctx, action)
+               if err != nil {
+                       return
+               }
+               buf.Reset()
+               err = binary.Write(&buf, binary.BigEndian, RequestHeader{
+                       ConnectionId:  connId,
+                       Action:        action,
+                       TransactionId: tId,
+               })
+               if err != nil {
+                       panic(err)
+               }
+               buf.Write(body)
+               _, err = cl.Writer.Write(buf.Bytes())
+               if err != nil {
+                       return
+               }
+               select {
+               case <-ctx.Done():
+                       return ctx.Err()
+               case <-time.After(timeout(n)):
+               }
+       }
+}
+
+func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, err error) {
+       respChan := make(chan DispatchedResponse, 1)
+       t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) {
+               respChan <- dr
+       })
+       defer t.End()
+       writeErr := make(chan error, 1)
+       go func() {
+               writeErr <- cl.requestWriter(ctx, action, body, t.Id())
+       }()
+       select {
+       case dr := <-respChan:
+               if dr.Header.Action == action {
+                       respBody = dr.Body
+               } else if dr.Header.Action == ActionError {
+                       err = errors.New(string(dr.Body))
+               } else {
+                       err = fmt.Errorf("unexpected response action %v", dr.Header.Action)
+               }
+       case err = <-writeErr:
+               err = fmt.Errorf("write error: %w", err)
+       case <-ctx.Done():
+               err = ctx.Err()
+       }
+       return
+}
diff --git a/tracker/udp/dispatcher.go b/tracker/udp/dispatcher.go
new file mode 100644 (file)
index 0000000..907eb15
--- /dev/null
@@ -0,0 +1,64 @@
+package udp
+
+import (
+       "bytes"
+       "fmt"
+       "sync"
+)
+
+type Dispatcher struct {
+       mu           sync.RWMutex
+       transactions map[TransactionId]Transaction
+}
+
+func (me *Dispatcher) Dispatch(b []byte) error {
+       buf := bytes.NewBuffer(b)
+       var rh ResponseHeader
+       err := Read(buf, &rh)
+       if err != nil {
+               return err
+       }
+       me.mu.RLock()
+       defer me.mu.RUnlock()
+       if t, ok := me.transactions[rh.TransactionId]; ok {
+               t.h(DispatchedResponse{
+                       Header: rh,
+                       Body:   buf.Bytes(),
+               })
+               return nil
+       } else {
+               return fmt.Errorf("unknown transaction id %v", rh.TransactionId)
+       }
+}
+
+func (me *Dispatcher) forgetTransaction(id TransactionId) {
+       me.mu.Lock()
+       defer me.mu.Unlock()
+       delete(me.transactions, id)
+}
+
+func (me *Dispatcher) NewTransaction(h TransactionResponseHandler) Transaction {
+       me.mu.Lock()
+       defer me.mu.Unlock()
+       for {
+               id := RandomTransactionId()
+               if _, ok := me.transactions[id]; ok {
+                       continue
+               }
+               t := Transaction{
+                       d:  me,
+                       h:  h,
+                       id: id,
+               }
+               if me.transactions == nil {
+                       me.transactions = make(map[TransactionId]Transaction)
+               }
+               me.transactions[id] = t
+               return t
+       }
+}
+
+type DispatchedResponse struct {
+       Header ResponseHeader
+       Body   []byte
+}
diff --git a/tracker/udp/options.go b/tracker/udp/options.go
new file mode 100644 (file)
index 0000000..a2c223d
--- /dev/null
@@ -0,0 +1,24 @@
+package udp
+
+import (
+       "math"
+)
+
+type Options struct {
+       RequestUri string
+}
+
+func (opts Options) Encode() (ret []byte) {
+       for {
+               l := len(opts.RequestUri)
+               if l == 0 {
+                       break
+               }
+               if l > math.MaxUint8 {
+                       l = math.MaxUint8
+               }
+               ret = append(append(ret, optionTypeURLData, byte(l)), opts.RequestUri[:l]...)
+               opts.RequestUri = opts.RequestUri[l:]
+       }
+       return
+}
diff --git a/tracker/udp/protocol.go b/tracker/udp/protocol.go
new file mode 100644 (file)
index 0000000..365d3c5
--- /dev/null
@@ -0,0 +1,69 @@
+package udp
+
+import (
+       "bytes"
+       "encoding/binary"
+       "io"
+)
+
+type Action int32
+
+const (
+       ActionConnect Action = iota
+       ActionAnnounce
+       ActionScrape
+       ActionError
+
+       ConnectRequestConnectionId = 0x41727101980
+
+       // BEP 41
+       optionTypeEndOfOptions = 0
+       optionTypeNOP          = 1
+       optionTypeURLData      = 2
+)
+
+type TransactionId = int32
+
+type ConnectionId = int64
+
+type ConnectionRequest struct {
+       ConnectionId  ConnectionId
+       Action        Action
+       TransactionId TransactionId
+}
+
+type ConnectionResponse struct {
+       ConnectionId ConnectionId
+}
+
+type ResponseHeader struct {
+       Action        Action
+       TransactionId TransactionId
+}
+
+type RequestHeader struct {
+       ConnectionId  ConnectionId
+       Action        Action
+       TransactionId TransactionId
+} // 16 bytes
+
+type AnnounceResponseHeader struct {
+       Interval int32
+       Leechers int32
+       Seeders  int32
+}
+
+func marshal(data interface{}) (b []byte, err error) {
+       var buf bytes.Buffer
+       err = binary.Write(&buf, binary.BigEndian, data)
+       b = buf.Bytes()
+       return
+}
+
+func Write(w io.Writer, data interface{}) error {
+       return binary.Write(w, binary.BigEndian, data)
+}
+
+func Read(r io.Reader, data interface{}) error {
+       return binary.Read(r, binary.BigEndian, data)
+}
diff --git a/tracker/udp/timeout.go b/tracker/udp/timeout.go
new file mode 100644 (file)
index 0000000..b5e1832
--- /dev/null
@@ -0,0 +1,18 @@
+package udp
+
+import (
+       "time"
+)
+
+const maxTimeout = 3840 * time.Second
+
+func timeout(contiguousTimeouts int) (d time.Duration) {
+       if contiguousTimeouts > 8 {
+               contiguousTimeouts = 8
+       }
+       d = 15 * time.Second
+       for ; contiguousTimeouts > 0; contiguousTimeouts-- {
+               d *= 2
+       }
+       return
+}
diff --git a/tracker/udp/timeout_test.go b/tracker/udp/timeout_test.go
new file mode 100644 (file)
index 0000000..4bb0dc8
--- /dev/null
@@ -0,0 +1,15 @@
+package udp
+
+import (
+       "math"
+       "testing"
+
+       qt "github.com/frankban/quicktest"
+)
+
+func TestTimeoutMax(t *testing.T) {
+       c := qt.New(t)
+       c.Check(timeout(8), qt.Equals, maxTimeout)
+       c.Check(timeout(9), qt.Equals, maxTimeout)
+       c.Check(timeout(math.MaxInt32), qt.Equals, maxTimeout)
+}
diff --git a/tracker/udp/transaction.go b/tracker/udp/transaction.go
new file mode 100644 (file)
index 0000000..2018b35
--- /dev/null
@@ -0,0 +1,23 @@
+package udp
+
+import "math/rand"
+
+func RandomTransactionId() TransactionId {
+       return TransactionId(rand.Uint32())
+}
+
+type TransactionResponseHandler func(dr DispatchedResponse)
+
+type Transaction struct {
+       id int32
+       d  *Dispatcher
+       h  TransactionResponseHandler
+}
+
+func (t *Transaction) Id() TransactionId {
+       return t.id
+}
+
+func (t *Transaction) End() {
+       t.d.forgetTransaction(t.id)
+}
index d33550f14e80506f439ccf1345818815d7f90084..39afa80c8cf9750a9ebdabc31d2cbd971b383f35 100644 (file)
@@ -5,6 +5,7 @@ import (
        "context"
        "crypto/rand"
        "encoding/binary"
+       "errors"
        "fmt"
        "io"
        "io/ioutil"
@@ -12,10 +13,11 @@ import (
        "net/url"
        "sync"
        "testing"
+       "time"
 
        "github.com/anacrolix/dht/v2/krpc"
        _ "github.com/anacrolix/envpprof"
-       "github.com/pkg/errors"
+       "github.com/anacrolix/torrent/tracker/udp"
        "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/require"
 )
@@ -47,7 +49,7 @@ func TestMarshalAnnounceResponse(t *testing.T) {
        require.EqualValues(t,
                "\x7f\x00\x00\x01\x00\x02\xff\x00\x00\x03\x00\x04",
                b)
-       require.EqualValues(t, 12, binary.Size(AnnounceResponseHeader{}))
+       require.EqualValues(t, 12, binary.Size(udp.AnnounceResponseHeader{}))
 }
 
 // Failure to write an entire packet to UDP is expected to given an error.
@@ -74,7 +76,7 @@ func TestLongWriteUDP(t *testing.T) {
 }
 
 func TestShortBinaryRead(t *testing.T) {
-       var data ResponseHeader
+       var data udp.ResponseHeader
        err := binary.Read(bytes.NewBufferString("\x00\x00\x00\x01"), binary.BigEndian, &data)
        if err != io.ErrUnexpectedEOF {
                t.FailNow()
@@ -137,12 +139,20 @@ func TestUDPTracker(t *testing.T) {
        }
        rand.Read(req.PeerId[:])
        copy(req.InfoHash[:], []uint8{0xa3, 0x56, 0x41, 0x43, 0x74, 0x23, 0xe6, 0x26, 0xd9, 0x38, 0x25, 0x4a, 0x6b, 0x80, 0x49, 0x10, 0xa6, 0x67, 0xa, 0xc1})
+       var ctx context.Context
+       if dl, ok := t.Deadline(); ok {
+               var cancel func()
+               ctx, cancel = context.WithDeadline(context.Background(), dl.Add(-time.Second))
+               defer cancel()
+       }
        ar, err := Announce{
                TrackerUrl: trackers[0],
                Request:    req,
+               Context:    ctx,
        }.Do()
        // Skip any net errors as we don't control the server.
-       if _, ok := errors.Cause(err).(net.Error); ok {
+       var ne net.Error
+       if errors.As(err, &ne) {
                t.Skip(err)
        }
        require.NoError(t, err)
@@ -163,6 +173,12 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
        rand.Read(req.InfoHash[:])
        wg := sync.WaitGroup{}
        ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+       if dl, ok := t.Deadline(); ok {
+               var cancel func()
+               ctx, cancel = context.WithDeadline(ctx, dl.Add(-time.Second))
+               defer cancel()
+       }
        for _, url := range trackers {
                wg.Add(1)
                go func(url string) {
@@ -196,6 +212,7 @@ func TestURLPathOption(t *testing.T) {
                panic(err)
        }
        defer conn.Close()
+       announceErr := make(chan error)
        go func() {
                _, err := Announce{
                        TrackerUrl: (&url.URL{
@@ -204,34 +221,35 @@ func TestURLPathOption(t *testing.T) {
                                Path:   "/announce",
                        }).String(),
                }.Do()
-               if err != nil {
-                       defer conn.Close()
-               }
-               require.NoError(t, err)
+               defer conn.Close()
+               announceErr <- err
        }()
        var b [512]byte
        _, addr, _ := conn.ReadFrom(b[:])
        r := bytes.NewReader(b[:])
-       var h RequestHeader
-       read(r, &h)
+       var h udp.RequestHeader
+       udp.Read(r, &h)
        w := &bytes.Buffer{}
-       write(w, ResponseHeader{
+       udp.Write(w, udp.ResponseHeader{
+               Action:        udp.ActionConnect,
                TransactionId: h.TransactionId,
        })
-       write(w, ConnectionResponse{42})
+       udp.Write(w, udp.ConnectionResponse{42})
        conn.WriteTo(w.Bytes(), addr)
        n, _, _ := conn.ReadFrom(b[:])
        r = bytes.NewReader(b[:n])
-       read(r, &h)
-       read(r, &AnnounceRequest{})
+       udp.Read(r, &h)
+       udp.Read(r, &AnnounceRequest{})
        all, _ := ioutil.ReadAll(r)
        if string(all) != "\x02\x09/announce" {
                t.FailNow()
        }
        w = &bytes.Buffer{}
-       write(w, ResponseHeader{
+       udp.Write(w, udp.ResponseHeader{
+               Action:        udp.ActionAnnounce,
                TransactionId: h.TransactionId,
        })
-       write(w, AnnounceResponseHeader{})
+       udp.Write(w, udp.AnnounceResponseHeader{})
        conn.WriteTo(w.Bytes(), addr)
+       require.NoError(t, <-announceErr)
 }