]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Fixes and tests for UDP tracker protocol
authorMatt Joiner <anacrolix@gmail.com>
Sat, 14 Dec 2013 11:21:45 +0000 (22:21 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Sat, 14 Dec 2013 11:21:45 +0000 (22:21 +1100)
tracker/tracker.go
tracker/udp/udp_tracker.go
tracker/udp/udp_tracker_test.go

index 576353b51cad1b9512994e2e213717180879efaf..5ce7500c3da0ed1f564414d6a09e395142190620 100644 (file)
@@ -1,6 +1,7 @@
 package tracker
 
 import (
+       "errors"
        "net"
        "net/url"
 )
@@ -34,17 +35,33 @@ type Peer struct {
 
 const (
        None AnnounceEvent = iota
+       Completed
+       Started
+       Stopped
 )
 
 type Client interface {
        Announce(*AnnounceRequest) (AnnounceResponse, error)
+       Connect() error
 }
 
-var schemes = make(map[string]func(*url.URL) Client)
+var (
+       ErrNotConnected = errors.New("not connected")
+       ErrBadScheme    = errors.New("unknown scheme")
+
+       schemes = make(map[string]func(*url.URL) Client)
+)
 
 func RegisterClientScheme(scheme string, newFunc func(*url.URL) Client) {
+       schemes[scheme] = newFunc
 }
 
-func New(url *url.URL) Client {
-       return schemes[url.Scheme](url)
+func New(url *url.URL) (cl Client, err error) {
+       newFunc, ok := schemes[url.Scheme]
+       if !ok {
+               err = ErrBadScheme
+               return
+       }
+       cl = newFunc(url)
+       return
 }
index 1bb57b37c33e7ec6b2c615732f2b5e37d3208ad7..71d798300ebe8e15ff6c1533c92e4dba47e5d672 100644 (file)
@@ -4,6 +4,7 @@ import (
        "bitbucket.org/anacrolix/go.torrent/tracker"
        "bytes"
        "encoding/binary"
+       "errors"
        "io"
        "math/rand"
        "net"
@@ -57,7 +58,9 @@ func init() {
 }
 
 func newClient(url *url.URL) tracker.Client {
-       return &client{}
+       return &client{
+               url: url,
+       }
 }
 
 func newTransactionId() int32 {
@@ -80,11 +83,12 @@ type client struct {
        connectionIdReceived time.Time
        connectionId         int64
        socket               net.Conn
+       url                  *url.URL
 }
 
 func (c *client) Announce(req *tracker.AnnounceRequest) (res tracker.AnnounceResponse, err error) {
-       err = c.connect()
-       if err != nil {
+       if !c.connected() {
+               err = tracker.ErrNotConnected
                return
        }
        b, err := c.request(Announce, req)
@@ -124,9 +128,11 @@ func (c *client) write(h *RequestHeader, body interface{}) (err error) {
        if err != nil {
                panic(err)
        }
-       err = binary.Write(buf, binary.BigEndian, body)
-       if err != nil {
-               panic(err)
+       if body != nil {
+               err = binary.Write(buf, binary.BigEndian, body)
+               if err != nil {
+                       panic(err)
+               }
        }
        n, err := c.socket.Write(buf.Bytes())
        if err != nil {
@@ -172,9 +178,6 @@ func (c *client) request(action Action, args interface{}) (responseBody *bytes.R
                default:
                        return
                }
-               if h.Action != action {
-                       continue
-               }
                if h.TransactionId != tid {
                        continue
                }
@@ -197,11 +200,21 @@ func readBody(r *bytes.Reader, data ...interface{}) (err error) {
        return
 }
 
-func (c *client) connect() (err error) {
-       if !c.connectionIdReceived.IsZero() && time.Now().Before(c.connectionIdReceived.Add(time.Minute)) {
+func (c *client) connected() bool {
+       return !c.connectionIdReceived.IsZero() && time.Now().Before(c.connectionIdReceived.Add(time.Minute))
+}
+
+func (c *client) Connect() (err error) {
+       if c.connected() {
                return nil
        }
        c.connectionId = 0x41727101980
+       if c.socket == nil {
+               c.socket, err = net.Dial("udp", c.url.Host)
+               if err != nil {
+                       return
+               }
+       }
        b, err := c.request(Connect, nil)
        if err != nil {
                return
index 4a451ddb2b6df667f50eccc8742d04a3182a9d69..c9ba04c2ed464f6f1587af6d536b6896d0f332ca 100644 (file)
@@ -1,10 +1,14 @@
 package udp_tracker
 
 import (
+       "bitbucket.org/anacrolix/go.torrent/tracker"
        "bytes"
+       "crypto/rand"
        "encoding/binary"
+       "encoding/hex"
        "io"
        "net"
+       "net/url"
        "syscall"
        "testing"
 )
@@ -78,3 +82,36 @@ func TestConvertInt16ToInt(t *testing.T) {
                t.FailNow()
        }
 }
+
+func TestUDPTracker(t *testing.T) {
+       tr, err := tracker.New(func() *url.URL {
+               u, err := url.Parse("udp://tracker.openbittorrent.com:80/announce")
+               if err != nil {
+                       t.Fatal(err)
+               }
+               return u
+       }())
+       if err != nil {
+               t.Fatal(err)
+       }
+       if err := tr.Connect(); err != nil {
+               t.Fatal(err)
+       }
+       req := tracker.AnnounceRequest{
+               NumWant: -1,
+               Event:   tracker.Started,
+       }
+       rand.Read(req.PeerId[:])
+       n, err := hex.Decode(req.InfoHash[:], []byte("c833bb2b5e7bcb9c07f4c020b4be430c28ba7cdb"))
+       if err != nil {
+               t.Fatal(err)
+       }
+       if n != len(req.InfoHash) {
+               panic("nope")
+       }
+       resp, err := tr.Announce(&req)
+       if err != nil {
+               t.Fatal(err)
+       }
+       t.Log(resp)
+}