--- /dev/null
+package udp_tracker
+
+import (
+ "bitbucket.org/anacrolix/go.torrent/tracker"
+ "bytes"
+ "encoding/binary"
+ "io"
+ "math/rand"
+ "net"
+ "net/url"
+ "time"
+)
+
+type Action int32
+
+const (
+ Connect Action = iota
+ Announce
+ Scrape
+ Error
+)
+
+type ConnectionRequest struct {
+ ConnectionId int64
+ Action int32
+ TransctionId int32
+}
+
+type ConnectionResponse struct {
+ ConnectionId int64
+}
+
+type ResponseHeader struct {
+ Action Action
+ TransactionId int32
+}
+
+type RequestHeader struct {
+ ConnectionId int64
+ Action Action
+ TransactionId int32
+}
+
+type AnnounceResponseHeader struct {
+ Interval int32
+ Leechers int32
+ Seeders int32
+}
+
+type Peer struct {
+ IP [4]byte
+ Port uint16
+}
+
+func init() {
+ tracker.RegisterClientScheme("udp", newClient)
+}
+
+func newClient(url *url.URL) tracker.Client {
+ return &client{}
+}
+
+func newTransactionId() int32 {
+ return int32(rand.Uint32())
+}
+
+func timeout(contiguousTimeouts int) (d time.Duration) {
+ if contiguousTimeouts > 8 {
+ contiguousTimeouts = 8
+ }
+ d = 15 * time.Second
+ for ; contiguousTimeouts > 0; contiguousTimeouts-- {
+ d *= 2
+ }
+ return
+}
+
+type client struct {
+ contiguousTimeouts int
+ connectionIdReceived time.Time
+ connectionId int64
+ socket net.Conn
+}
+
+func (c *client) Announce(req *tracker.AnnounceRequest) (res tracker.AnnounceResponse, err error) {
+ err = c.connect()
+ if err != nil {
+ return
+ }
+ b, err := c.request(Announce, req)
+ if err != nil {
+ return
+ }
+ var (
+ h AnnounceResponseHeader
+ ps []Peer
+ )
+ err = readBody(b, &h, &ps)
+ if err != nil {
+ return
+ }
+ res.Interval = h.Interval
+ res.Leechers = h.Leechers
+ res.Seeders = h.Seeders
+ for _, p := range ps {
+ res.Peers = append(res.Peers, tracker.Peer{
+ IP: p.IP[:],
+ Port: int(p.Port),
+ })
+ }
+ return
+}
+
+func (c *client) write(h *RequestHeader, body interface{}) (err error) {
+ buf := &bytes.Buffer{}
+ err = binary.Write(buf, binary.BigEndian, h)
+ if err != nil {
+ panic(err)
+ }
+ err = binary.Write(buf, binary.BigEndian, body)
+ if err != nil {
+ panic(err)
+ }
+ n, err := c.socket.Write(buf.Bytes())
+ if err != nil {
+ return
+ }
+ if n != buf.Len() {
+ panic("write should send all or error")
+ }
+ return
+}
+
+func (c *client) request(action Action, args interface{}) (responseBody []byte, err error) {
+ tid := newTransactionId()
+ err = c.write(&RequestHeader{
+ ConnectionId: c.connectionId,
+ Action: action,
+ TransactionId: tid,
+ }, args)
+ if err != nil {
+ return
+ }
+ c.socket.SetDeadline(time.Now().Add(timeout(c.contiguousTimeouts)))
+ b := make([]byte, 0x10000) // IP limits packet size to 64KB
+ for {
+ var n int
+ n, err = c.socket.Read(b)
+ if opE, ok := err.(*net.OpError); ok {
+ if opE.Timeout() {
+ c.contiguousTimeouts++
+ return
+ }
+ }
+ if err != nil {
+ return
+ }
+ buf := bytes.NewBuffer(b[:n])
+ var h ResponseHeader
+ err = binary.Read(buf, binary.BigEndian, &h)
+ switch err {
+ case io.ErrUnexpectedEOF:
+ continue
+ case nil:
+ default:
+ return
+ }
+ if h.Action != action {
+ continue
+ }
+ if h.TransactionId != tid {
+ continue
+ }
+ c.contiguousTimeouts = 0
+ responseBody = buf.Bytes()
+ return
+ }
+}
+
+func readBody(b []byte, data ...interface{}) (err error) {
+ r := bytes.NewReader(b)
+ for _, datum := range data {
+ err = binary.Read(r, binary.BigEndian, datum)
+ if err != nil {
+ break
+ }
+ }
+ return
+}
+
+func (c *client) connect() (err error) {
+ if !c.connectionIdReceived.IsZero() && time.Now().Before(c.connectionIdReceived.Add(time.Minute)) {
+ return nil
+ }
+ c.connectionId = 0x41727101980
+ b, err := c.request(Connect, nil)
+ if err != nil {
+ return
+ }
+ var res ConnectionResponse
+ err = readBody(b, &res)
+ if err != nil {
+ return
+ }
+ c.connectionId = res.ConnectionId
+ c.connectionIdReceived = time.Now()
+ return
+}
--- /dev/null
+package udp_tracker
+
+import (
+ "bytes"
+ "encoding/binary"
+ "io"
+ "net"
+ "syscall"
+ "testing"
+)
+
+func TestNetIPv4Bytes(t *testing.T) {
+ ip := net.IP([]byte{127, 0, 0, 1})
+ if ip.String() != "127.0.0.1" {
+ t.FailNow()
+ }
+ if string(ip) != "\x7f\x00\x00\x01" {
+ t.Fatal([]byte(ip))
+ }
+}
+
+func TestMarshalAnnounceResponse(t *testing.T) {
+ w := bytes.NewBuffer(nil)
+ if err := binary.Write(w, binary.BigEndian, []Peer{{[4]byte{127, 0, 0, 1}, 2}, {[4]byte{255, 0, 0, 3}, 4}}); err != nil {
+ t.Fatalf("error writing udp announce response addrs: %s", err)
+ }
+ if w.String() != "\x7f\x00\x00\x01\x00\x02\xff\x00\x00\x03\x00\x04" {
+ t.FailNow()
+ }
+ if binary.Size(AnnounceResponseHeader{}) != 12 {
+ t.FailNow()
+ }
+}
+
+// Failure to write an entire packet to UDP is expected to given an error.
+func TestLongWriteUDP(t *testing.T) {
+ l, err := net.ListenUDP("udp", nil)
+ defer l.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ c, err := net.DialUDP("udp", nil, l.LocalAddr().(*net.UDPAddr))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ for msgLen := 1; ; msgLen *= 2 {
+ n, err := c.Write(make([]byte, msgLen))
+ if err != nil {
+ err := err.(*net.OpError).Err
+ if err != syscall.EMSGSIZE {
+ t.Fatalf("write error isn't EMSGSIZE: %s", err)
+ }
+ return
+ }
+ if n < msgLen {
+ t.FailNow()
+ }
+ }
+}
+
+func TestShortBinaryRead(t *testing.T) {
+ var data ResponseHeader
+ err := binary.Read(bytes.NewBufferString("\x00\x00\x00\x01"), binary.BigEndian, &data)
+ if data.Action != 0 {
+ t.Log("optimistic binary read now works?!")
+ }
+ switch err {
+ case io.ErrUnexpectedEOF:
+ default:
+ // TODO
+ }
+}
+
+func TestConvertInt16ToInt(t *testing.T) {
+ i := 50000
+ if int(uint16(int16(i))) != 50000 {
+ t.FailNow()
+ }
+}