From: Matt Joiner Date: Sat, 14 Dec 2013 11:21:45 +0000 (+1100) Subject: Fixes and tests for UDP tracker protocol X-Git-Tag: v1.0.0~1780 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=5f093c38033e8d21d116020661be8868cc0b9f98;p=btrtrc.git Fixes and tests for UDP tracker protocol --- diff --git a/tracker/tracker.go b/tracker/tracker.go index 576353b5..5ce7500c 100644 --- a/tracker/tracker.go +++ b/tracker/tracker.go @@ -1,6 +1,7 @@ package tracker import ( + "errors" "net" "net/url" ) @@ -34,17 +35,33 @@ type Peer struct { const ( None AnnounceEvent = iota + Completed + Started + Stopped ) type Client interface { Announce(*AnnounceRequest) (AnnounceResponse, error) + Connect() error } -var schemes = make(map[string]func(*url.URL) Client) +var ( + ErrNotConnected = errors.New("not connected") + ErrBadScheme = errors.New("unknown scheme") + + schemes = make(map[string]func(*url.URL) Client) +) func RegisterClientScheme(scheme string, newFunc func(*url.URL) Client) { + schemes[scheme] = newFunc } -func New(url *url.URL) Client { - return schemes[url.Scheme](url) +func New(url *url.URL) (cl Client, err error) { + newFunc, ok := schemes[url.Scheme] + if !ok { + err = ErrBadScheme + return + } + cl = newFunc(url) + return } diff --git a/tracker/udp/udp_tracker.go b/tracker/udp/udp_tracker.go index 1bb57b37..71d79830 100644 --- a/tracker/udp/udp_tracker.go +++ b/tracker/udp/udp_tracker.go @@ -4,6 +4,7 @@ import ( "bitbucket.org/anacrolix/go.torrent/tracker" "bytes" "encoding/binary" + "errors" "io" "math/rand" "net" @@ -57,7 +58,9 @@ func init() { } func newClient(url *url.URL) tracker.Client { - return &client{} + return &client{ + url: url, + } } func newTransactionId() int32 { @@ -80,11 +83,12 @@ type client struct { connectionIdReceived time.Time connectionId int64 socket net.Conn + url *url.URL } func (c *client) Announce(req *tracker.AnnounceRequest) (res tracker.AnnounceResponse, err error) { - err = c.connect() - if err != nil { + if !c.connected() { + err = tracker.ErrNotConnected return } b, err := c.request(Announce, req) @@ -124,9 +128,11 @@ func (c *client) write(h *RequestHeader, body interface{}) (err error) { if err != nil { panic(err) } - err = binary.Write(buf, binary.BigEndian, body) - if err != nil { - panic(err) + if body != nil { + err = binary.Write(buf, binary.BigEndian, body) + if err != nil { + panic(err) + } } n, err := c.socket.Write(buf.Bytes()) if err != nil { @@ -172,9 +178,6 @@ func (c *client) request(action Action, args interface{}) (responseBody *bytes.R default: return } - if h.Action != action { - continue - } if h.TransactionId != tid { continue } @@ -197,11 +200,21 @@ func readBody(r *bytes.Reader, data ...interface{}) (err error) { return } -func (c *client) connect() (err error) { - if !c.connectionIdReceived.IsZero() && time.Now().Before(c.connectionIdReceived.Add(time.Minute)) { +func (c *client) connected() bool { + return !c.connectionIdReceived.IsZero() && time.Now().Before(c.connectionIdReceived.Add(time.Minute)) +} + +func (c *client) Connect() (err error) { + if c.connected() { return nil } c.connectionId = 0x41727101980 + if c.socket == nil { + c.socket, err = net.Dial("udp", c.url.Host) + if err != nil { + return + } + } b, err := c.request(Connect, nil) if err != nil { return diff --git a/tracker/udp/udp_tracker_test.go b/tracker/udp/udp_tracker_test.go index 4a451ddb..c9ba04c2 100644 --- a/tracker/udp/udp_tracker_test.go +++ b/tracker/udp/udp_tracker_test.go @@ -1,10 +1,14 @@ package udp_tracker import ( + "bitbucket.org/anacrolix/go.torrent/tracker" "bytes" + "crypto/rand" "encoding/binary" + "encoding/hex" "io" "net" + "net/url" "syscall" "testing" ) @@ -78,3 +82,36 @@ func TestConvertInt16ToInt(t *testing.T) { t.FailNow() } } + +func TestUDPTracker(t *testing.T) { + tr, err := tracker.New(func() *url.URL { + u, err := url.Parse("udp://tracker.openbittorrent.com:80/announce") + if err != nil { + t.Fatal(err) + } + return u + }()) + if err != nil { + t.Fatal(err) + } + if err := tr.Connect(); err != nil { + t.Fatal(err) + } + req := tracker.AnnounceRequest{ + NumWant: -1, + Event: tracker.Started, + } + rand.Read(req.PeerId[:]) + n, err := hex.Decode(req.InfoHash[:], []byte("c833bb2b5e7bcb9c07f4c020b4be430c28ba7cdb")) + if err != nil { + t.Fatal(err) + } + if n != len(req.InfoHash) { + panic("nope") + } + resp, err := tr.Announce(&req) + if err != nil { + t.Fatal(err) + } + t.Log(resp) +}