]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Implement the DHT Port message
authorMatt Joiner <anacrolix@gmail.com>
Mon, 25 Aug 2014 12:12:16 +0000 (22:12 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Mon, 25 Aug 2014 12:12:16 +0000 (22:12 +1000)
client.go
peer_protocol/protocol.go
peer_protocol/protocol_test.go

index ecbbcc4ce55ceb5d0b8415a68e4cf9b87ca2fb6a..d1ae457f6ed02720a556bbd7ffb7c0867dc9c570 100644 (file)
--- a/client.go
+++ b/client.go
@@ -54,7 +54,11 @@ var (
        postedCancels               = expvar.NewInt("postedCancels")
 )
 
-const extensionBytes = "\x00\x00\x00\x00\x00\x10\x00\x00"
+// Justification for set bits follows.
+//
+// Extension protocol: http://www.bittorrent.org/beps/bep_0010.html
+// DHT: http://www.bittorrent.org/beps/bep_0005.html
+const extensionBytes = "\x00\x00\x00\x00\x00\x10\x00\x01"
 
 // Currently doesn't really queue, but should in the future.
 func (cl *Client) queuePieceCheck(t *torrent, pieceIndex pp.Integer) {
@@ -531,6 +535,13 @@ func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerS
                        Bitfield: torrent.bitfield(),
                })
        }
+       if conn.PeerExtensionBytes[7]&0x01 != 0 && me.dHT != nil {
+               addr, _ := me.dHT.LocalAddr().(*net.UDPAddr)
+               conn.Post(pp.Message{
+                       Type: pp.Port,
+                       Port: uint16(addr.Port),
+               })
+       }
        err = me.connectionLoop(torrent, conn)
        if err != nil {
                err = fmt.Errorf("during Connection loop with peer %q: %s", conn.PeerID, err)
@@ -860,6 +871,16 @@ func (me *Client) connectionLoop(t *torrent, c *connection) error {
                        if err != nil {
                                log.Printf("peer extension map: %#v", c.PeerExtensionIDs)
                        }
+               case pp.Port:
+                       if me.dHT == nil {
+                               break
+                       }
+                       addr, _ := c.Socket.RemoteAddr().(*net.TCPAddr)
+                       _, err = me.dHT.Ping(&net.UDPAddr{
+                               IP:   addr.IP,
+                               Zone: addr.Zone,
+                               Port: int(msg.Port),
+                       })
                default:
                        err = fmt.Errorf("received unknown message type: %#v", msg.Type)
                }
index db32ac59ce99f31e86dd281b59acda40ba9535ab..607edf5979ab90cd5e38b89f0ff3796ace0d5f0c 100644 (file)
@@ -33,6 +33,7 @@ const (
        Request                   // 6
        Piece                     // 7
        Cancel                    // 8
+       Port                      // 9
        Extended      = 20
 
        HandshakeExtendedID = 0
@@ -50,6 +51,7 @@ type Message struct {
        Bitfield             []bool
        ExtendedID           byte
        ExtendedPayload      []byte
+       Port                 uint16
 }
 
 func (msg Message) MarshalBinary() (data []byte, err error) {
@@ -92,6 +94,8 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
                                return
                        }
                        _, err = buf.Write(msg.ExtendedPayload)
+               case Port:
+                       err = binary.Write(buf, binary.BigEndian, msg.Port)
                default:
                        err = fmt.Errorf("unknown message type: %v", msg.Type)
                }
@@ -187,6 +191,8 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                        break
                }
                msg.ExtendedPayload, err = ioutil.ReadAll(r)
+       case Port:
+               err = binary.Read(r, binary.BigEndian, &msg.Port)
        default:
                err = fmt.Errorf("unknown message type %#v", c)
        }
index d2fc8d4b0a5c4bc3c8be0999d8391e3422f331c0..580d26b2e695a4217d34a6246e07e041f19164db 100644 (file)
@@ -121,3 +121,34 @@ func TestMarshalKeepalive(t *testing.T) {
                t.Fatalf("marshalled keepalive is %q, expected %q", bs, expected)
        }
 }
+
+func TestMarshalPortMsg(t *testing.T) {
+       b, err := (Message{
+               Type: Port,
+               Port: 0xaabb,
+       }).MarshalBinary()
+       if err != nil {
+               t.Fatal(err)
+       }
+       if string(b) != "\x00\x00\x00\x03\x09\xaa\xbb" {
+               t.FailNow()
+       }
+}
+
+func TestUnmarshalPortMsg(t *testing.T) {
+       var m Message
+       d := Decoder{
+               R:         bufio.NewReader(bytes.NewBufferString("\x00\x00\x00\x03\x09\xaa\xbb")),
+               MaxLength: 8,
+       }
+       err := d.Decode(&m)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if m.Type != Port {
+               t.FailNow()
+       }
+       if m.Port != 0xaabb {
+               t.FailNow()
+       }
+}