]> Sergey Matveev's repositories - btrtrc.git/commitdiff
dht: Make Msg a struct with bencode tags
authorMatt Joiner <anacrolix@gmail.com>
Fri, 23 Oct 2015 01:41:45 +0000 (12:41 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Fri, 23 Oct 2015 01:41:45 +0000 (12:41 +1100)
cmd/dht-get-peers/main.go
cmd/dht-ping/main.go
cmd/dht-server/main.go
dht/announce.go
dht/compactNodeInfo.go [new file with mode: 0644]
dht/dht.go
dht/dht_test.go
dht/krpcError.go [new file with mode: 0644]
dht/msg.go [new file with mode: 0644]
util/types.go

index 646104d9b4f3a65d7a850f5ce32bd5e4ab2468ba..f03e9fd1c721bbf1a92fd7143e03d9b427b3f9e0 100644 (file)
@@ -39,7 +39,7 @@ func loadTable() error {
        defer f.Close()
        added := 0
        for {
-               b := make([]byte, dht.CompactNodeInfoLen)
+               b := make([]byte, dht.CompactIPv4NodeInfoLen)
                _, err := io.ReadFull(f, b)
                if err == io.EOF {
                        break
@@ -48,7 +48,7 @@ func loadTable() error {
                        return fmt.Errorf("error reading table file: %s", err)
                }
                var ni dht.NodeInfo
-               err = ni.UnmarshalCompact(b)
+               err = ni.UnmarshalCompactIPv4(b)
                if err != nil {
                        return fmt.Errorf("error unmarshaling compact node info: %s", err)
                }
@@ -101,7 +101,7 @@ func saveTable() error {
        }
        defer f.Close()
        for _, nodeInfo := range goodNodes {
-               var b [dht.CompactNodeInfoLen]byte
+               var b [dht.CompactIPv4NodeInfoLen]byte
                err := nodeInfo.PutCompact(b[:])
                if err != nil {
                        return fmt.Errorf("error compacting node info: %s", err)
index 0588a84819ad91b18beee03b9b010b1ac702c66a..bae5843171a81264b858bb4860b525328f05c3a2 100644 (file)
@@ -68,11 +68,8 @@ pingResponses:
        for _ = range pingStrAddrs {
                select {
                case resp := <-pingResponses:
-                       if resp.krpc == nil {
-                               break
-                       }
                        responses++
-                       fmt.Printf("%-65s %s\n", fmt.Sprintf("%x (%s):", resp.krpc["r"].(map[string]interface{})["id"].(string), resp.addr), resp.rtt)
+                       fmt.Printf("%-65s %s\n", fmt.Sprintf("%x (%s):", resp.krpc.R.ID, resp.addr), resp.rtt)
                case <-timeoutChan:
                        break pingResponses
                }
index 19799ce8006b57c394cda338db4d779bafd6db17..f0792b8d03c09c19dead682ae7c959717cd47559 100644 (file)
@@ -32,7 +32,7 @@ func loadTable() error {
        defer f.Close()
        added := 0
        for {
-               b := make([]byte, dht.CompactNodeInfoLen)
+               b := make([]byte, dht.CompactIPv4NodeInfoLen)
                _, err := io.ReadFull(f, b)
                if err == io.EOF {
                        break
@@ -41,7 +41,7 @@ func loadTable() error {
                        return fmt.Errorf("error reading table file: %s", err)
                }
                var ni dht.NodeInfo
-               err = ni.UnmarshalCompact(b)
+               err = ni.UnmarshalCompactIPv4(b)
                if err != nil {
                        return fmt.Errorf("error unmarshaling compact node info: %s", err)
                }
@@ -84,7 +84,7 @@ func saveTable() error {
        }
        defer f.Close()
        for _, nodeInfo := range goodNodes {
-               var b [dht.CompactNodeInfoLen]byte
+               var b [dht.CompactIPv4NodeInfoLen]byte
                err := nodeInfo.PutCompact(b[:])
                if err != nil {
                        return fmt.Errorf("error compacting node info: %s", err)
index a7c55868d618a4e881c69350ebd728f260a3be50..f87b76f200f29851f8cf2a40f5c3238ae456fcb6 100644 (file)
@@ -160,36 +160,35 @@ func (me *Announce) getPeers(addr dHTAddr) error {
        }
        t.SetResponseHandler(func(m Msg) {
                // Register suggested nodes closer to the target info-hash.
-               me.mu.Lock()
-               for _, n := range m.Nodes() {
-                       me.responseNode(n)
-               }
-               me.mu.Unlock()
+               if m.R != nil {
+                       me.mu.Lock()
+                       for _, n := range m.R.Nodes {
+                               me.responseNode(n)
+                       }
+                       me.mu.Unlock()
 
-               if vs := m.Values(); vs != nil {
-                       for _, cp := range vs {
-                               if cp.Port == 0 {
-                                       me.server.mu.Lock()
-                                       me.server.badNode(addr)
-                                       me.server.mu.Unlock()
-                                       return
+                       if vs := m.R.Values; len(vs) != 0 {
+                               nodeInfo := NodeInfo{
+                                       Addr: t.remoteAddr,
+                               }
+                               copy(nodeInfo.ID[:], m.SenderID())
+                               select {
+                               case me.values <- PeersValues{
+                                       Peers: func() (ret []Peer) {
+                                               for _, cp := range vs {
+                                                       ret = append(ret, Peer(cp))
+                                               }
+                                               return
+                                       }(),
+                                       NodeInfo: nodeInfo,
+                               }:
+                               case <-me.stop:
                                }
                        }
-                       nodeInfo := NodeInfo{
-                               Addr: t.remoteAddr,
-                       }
-                       copy(nodeInfo.ID[:], m.SenderID())
-                       select {
-                       case me.values <- PeersValues{
-                               Peers:    vs,
-                               NodeInfo: nodeInfo,
-                       }:
-                       case <-me.stop:
-                       }
-               }
 
-               if at, ok := m.AnnounceToken(); ok {
-                       me.announcePeer(addr, at)
+                       if at := m.R.Token; at != "" {
+                               me.announcePeer(addr, at)
+                       }
                }
 
                me.mu.Lock()
diff --git a/dht/compactNodeInfo.go b/dht/compactNodeInfo.go
new file mode 100644 (file)
index 0000000..e12bb0f
--- /dev/null
@@ -0,0 +1,49 @@
+package dht
+
+import (
+       "bytes"
+       "encoding/binary"
+       "errors"
+       "fmt"
+
+       "github.com/anacrolix/torrent/bencode"
+)
+
+type CompactIPv4NodeInfo []NodeInfo
+
+var _ bencode.Unmarshaler = &CompactIPv4NodeInfo{}
+
+func (me *CompactIPv4NodeInfo) UnmarshalBencode(_b []byte) (err error) {
+       var b []byte
+       err = bencode.Unmarshal(_b, &b)
+       if err != nil {
+               return
+       }
+       if len(b)%CompactIPv4NodeInfoLen != 0 {
+               err = fmt.Errorf("bad length: %d", len(b))
+               return
+       }
+       for i := 0; i < len(b); i += CompactIPv4NodeInfoLen {
+               var ni NodeInfo
+               err = ni.UnmarshalCompactIPv4(b[i : i+CompactIPv4NodeInfoLen])
+               if err != nil {
+                       return
+               }
+               *me = append(*me, ni)
+       }
+       return
+}
+
+func (me CompactIPv4NodeInfo) MarshalBencode() (ret []byte, err error) {
+       var buf bytes.Buffer
+       for _, ni := range me {
+               buf.Write(ni.ID[:])
+               if ni.Addr == nil {
+                       err = errors.New("nil addr in node info")
+                       return
+               }
+               buf.Write(ni.Addr.IP().To4())
+               binary.Write(&buf, binary.BigEndian, uint16(ni.Addr.UDPAddr().Port))
+       }
+       return bencode.Marshal(buf.Bytes())
+}
index 518457557b5fe06a1c4d2eed08cbd343d660b7a2..bea3f8ddfe9f1e3fa0c0401e47068a480e51d676 100644 (file)
@@ -28,11 +28,13 @@ import (
        "github.com/anacrolix/torrent/bencode"
        "github.com/anacrolix/torrent/iplist"
        "github.com/anacrolix/torrent/logonce"
-       "github.com/anacrolix/torrent/util"
 )
 
 const (
-       maxNodes         = 320
+       maxNodes = 320
+)
+
+var (
        queryResendEvery = 5 * time.Second
 )
 
@@ -266,105 +268,6 @@ func (n *node) DefinitelyGood() bool {
        return true
 }
 
-// A wrapper around the unmarshalled KRPC dict that constitutes messages in
-// the DHT. There are various helpers for extracting common data from the
-// message. In normal use, Msg is abstracted away for you, but it can be of
-// interest.
-type Msg map[string]interface{}
-
-var _ fmt.Stringer = Msg{}
-
-func (m Msg) String() string {
-       return fmt.Sprintf("%#v", m)
-}
-
-func (m Msg) T() (t string) {
-       tif, ok := m["t"]
-       if !ok {
-               return
-       }
-       t, _ = tif.(string)
-       return
-}
-
-func (m Msg) Args() map[string]interface{} {
-       defer func() {
-               recover()
-       }()
-       return m["a"].(map[string]interface{})
-}
-
-func (m Msg) SenderID() string {
-       defer func() {
-               recover()
-       }()
-       switch m["y"].(string) {
-       case "q":
-               return m.Args()["id"].(string)
-       case "r":
-               return m["r"].(map[string]interface{})["id"].(string)
-       }
-       return ""
-}
-
-// Suggested nodes in a response.
-func (m Msg) Nodes() (nodes []NodeInfo) {
-       b := func() string {
-               defer func() {
-                       recover()
-               }()
-               return m["r"].(map[string]interface{})["nodes"].(string)
-       }()
-       if len(b)%26 != 0 {
-               return
-       }
-       for i := 0; i < len(b); i += 26 {
-               var n NodeInfo
-               err := n.UnmarshalCompact([]byte(b[i : i+26]))
-               if err != nil {
-                       continue
-               }
-               nodes = append(nodes, n)
-       }
-       return
-}
-
-type KRPCError struct {
-       Code int
-       Msg  string
-}
-
-func (me KRPCError) Error() string {
-       return fmt.Sprintf("KRPC error %d: %s", me.Code, me.Msg)
-}
-
-var _ error = KRPCError{}
-
-func (m Msg) Error() (ret *KRPCError) {
-       if m["y"] != "e" {
-               return
-       }
-       ret = &KRPCError{}
-       switch e := m["e"].(type) {
-       case []interface{}:
-               ret.Code = int(e[0].(int64))
-               ret.Msg = e[1].(string)
-       case string:
-               ret.Msg = e
-       default:
-               logonce.Stderr.Printf(`KRPC error "e" value has unexpected type: %T`, e)
-       }
-       return
-}
-
-// Returns the token given in response to a get_peers request for future
-// announce_peer requests to that node.
-func (m Msg) AnnounceToken() (token string, ok bool) {
-       defer func() { recover() }()
-       token, ok = m["r"].(map[string]interface{})["token"].(string)
-       return
-}
-
 type Transaction struct {
        mu             sync.Mutex
        remoteAddr     dHTAddr
@@ -648,12 +551,12 @@ func (s *Server) processPacket(b []byte, addr dHTAddr) {
        }
        s.mu.Lock()
        defer s.mu.Unlock()
-       if d["y"] == "q" {
+       if d.Y == "q" {
                readQuery.Add(1)
                s.handleQuery(addr, d)
                return
        }
-       t := s.findResponseTransaction(d.T(), addr)
+       t := s.findResponseTransaction(d.T, addr)
        if t == nil {
                //log.Printf("unexpected message: %#v", d)
                return
@@ -722,15 +625,12 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
        if s.config.Passive {
                return
        }
-       args := m.Args()
-       if args == nil {
-               return
-       }
-       switch m["q"] {
+       args := m.A
+       switch m.Q {
        case "ping":
-               s.reply(source, m["t"].(string), nil)
+               s.reply(source, m.T, Return{})
        case "get_peers": // TODO: Extract common behaviour with find_node.
-               targetID := args["info_hash"].(string)
+               targetID := args.InfoHash
                if len(targetID) != 20 {
                        break
                }
@@ -739,19 +639,13 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
                for _, node := range s.closestGoodNodes(8, targetID) {
                        rNodes = append(rNodes, node.NodeInfo())
                }
-               nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
-               for i, ni := range rNodes {
-                       err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
-                       if err != nil {
-                               panic(err)
-                       }
-               }
-               s.reply(source, m["t"].(string), map[string]interface{}{
-                       "nodes": string(nodesBytes),
-                       "token": "hi",
+               s.reply(source, m.T, Return{
+                       Nodes: rNodes,
+                       // TODO: Generate this dynamically, and store it for the source.
+                       Token: "hi",
                })
        case "find_node": // TODO: Extract common behaviour with get_peers.
-               targetID := args["target"].(string)
+               targetID := args.Target
                if len(targetID) != 20 {
                        log.Printf("bad DHT query: %v", m)
                        return
@@ -760,24 +654,13 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
                if node := s.nodeByID(targetID); node != nil {
                        rNodes = append(rNodes, node.NodeInfo())
                } else {
+                       // This will probably cause a crash for IPv6, but meh.
                        for _, node := range s.closestGoodNodes(8, targetID) {
                                rNodes = append(rNodes, node.NodeInfo())
                        }
                }
-               nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
-               for i, ni := range rNodes {
-                       // TODO: Put IPv6 nodes into the correct dict element.
-                       if ni.Addr.UDPAddr().IP.To4() == nil {
-                               continue
-                       }
-                       err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
-                       if err != nil {
-                               log.Printf("error compacting %#v: %s", ni, err)
-                               continue
-                       }
-               }
-               s.reply(source, m["t"].(string), map[string]interface{}{
-                       "nodes": string(nodesBytes),
+               s.reply(source, m.T, Return{
+                       Nodes: rNodes,
                })
        case "announce_peer":
                // TODO(anacrolix): Implement this lolz.
@@ -785,20 +668,17 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
        case "vote":
                // TODO(anacrolix): Or reject, I don't think I want this.
        default:
-               log.Printf("%s: not handling received query: q=%s", s, m["q"])
+               log.Printf("%s: not handling received query: q=%s", s, m.Q)
                return
        }
 }
 
-func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) {
-       if r == nil {
-               r = make(map[string]interface{}, 1)
-       }
-       r["id"] = s.ID()
-       m := map[string]interface{}{
-               "t": t,
-               "y": "r",
-               "r": r,
+func (s *Server) reply(addr dHTAddr, t string, r Return) {
+       r.ID = s.ID()
+       m := Msg{
+               T: t,
+               Y: "r",
+               R: &r,
        }
        b, err := bencode.Marshal(m)
        if err != nil {
@@ -947,7 +827,7 @@ func (s *Server) query(node dHTAddr, q string, a map[string]interface{}, onRespo
 }
 
 // The size in bytes of a NodeInfo in its compact binary representation.
-const CompactNodeInfoLen = 26
+const CompactIPv4NodeInfoLen = 26
 
 type NodeInfo struct {
        ID   [20]byte
@@ -971,7 +851,7 @@ func (ni *NodeInfo) PutCompact(b []byte) error {
        return nil
 }
 
-func (cni *NodeInfo) UnmarshalCompact(b []byte) error {
+func (cni *NodeInfo) UnmarshalCompactIPv4(b []byte) error {
        if len(b) != 26 {
                return errors.New("expected 26 bytes")
        }
@@ -1019,10 +899,10 @@ func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token str
 
 // Add response nodes to node table.
 func (s *Server) liftNodes(d Msg) {
-       if d["y"] != "r" {
+       if d.Y != "r" {
                return
        }
-       for _, cni := range d.Nodes() {
+       for _, cni := range d.R.Nodes {
                if missinggo.AddrPort(cni.Addr) == 0 {
                        // TODO: Why would people even do this?
                        continue
@@ -1057,44 +937,6 @@ func (me *Peer) String() string {
        return net.JoinHostPort(me.IP.String(), strconv.FormatInt(int64(me.Port), 10))
 }
 
-// In a get_peers response, the addresses of torrent clients involved with the
-// queried info-hash.
-func (m Msg) Values() (vs []Peer) {
-       v := func() interface{} {
-               defer func() {
-                       recover()
-               }()
-               return m["r"].(map[string]interface{})["values"]
-       }()
-       if v == nil {
-               return
-       }
-       vl, ok := v.([]interface{})
-       if !ok {
-               if missinggo.CryHeard() {
-                       log.Printf(`unexpected krpc "values" field: %#v`, v)
-               }
-               return
-       }
-       vs = make([]Peer, 0, len(vl))
-       for _, i := range vl {
-               s, ok := i.(string)
-               if !ok {
-                       panic(i)
-               }
-               // Because it's a list of strings, we can let the length of the string
-               // determine the IP version of the compact peer.
-               var cp util.CompactPeer
-               err := cp.UnmarshalBinary([]byte(s))
-               if err != nil {
-                       log.Printf("error decoding values list element: %s", err)
-                       continue
-               }
-               vs = append(vs, Peer{cp.IP[:], int(cp.Port)})
-       }
-       return
-}
-
 func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *Transaction, err error) {
        if len(infoHash) != 20 {
                err = fmt.Errorf("infohash has bad length")
@@ -1102,10 +944,7 @@ func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *Transaction, err er
        }
        t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}, func(m Msg) {
                s.liftNodes(m)
-               at, ok := m.AnnounceToken()
-               if ok {
-                       s.getNode(addr, m.SenderID()).announceToken = at
-               }
+               s.getNode(addr, m.SenderID()).announceToken = m.R.Token
        })
        return
 }
index f84ccb68f8a9788518fe07acc6b26ffe9ca3ac8c..b8a8897c5b849ec58d8a8b14ed6c17fc452c6029 100644 (file)
@@ -9,6 +9,10 @@ import (
 
        "github.com/anacrolix/missinggo"
        "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/require"
+
+       "github.com/anacrolix/torrent/bencode"
+       "github.com/anacrolix/torrent/util"
 )
 
 func TestSetNilBigInt(t *testing.T) {
@@ -25,7 +29,7 @@ func TestMarshalCompactNodeInfo(t *testing.T) {
                t.Fatal(err)
        }
        cni.Addr = newDHTAddr(addr)
-       var b [CompactNodeInfoLen]byte
+       var b [CompactIPv4NodeInfoLen]byte
        cni.PutCompact(b[:])
        if err != nil {
                t.Fatal(err)
@@ -106,14 +110,12 @@ func TestClosestNodes(t *testing.T) {
 }
 
 func TestUnmarshalGetPeersResponse(t *testing.T) {
-       gpr := Msg{
-               "r": map[string]interface{}{
-                       "values": []interface{}{"\x01\x02\x03\x04\x05\x06", "\x07\x08\x09\x0a\x0b\x0c"},
-                       "nodes":  "\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x02\x03\x04\x05\x06\x07\x08\x09\x02\x03\x04\x05\x06\x07" + "\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x02\x03\x04\x05\x06\x07\x08\x09\x02\x03\x04\x05\x06\x07",
-               },
-       }
-       assert.EqualValues(t, 2, len(gpr.Values()))
-       assert.EqualValues(t, 2, len(gpr.Nodes()))
+       var msg Msg
+       err := bencode.Unmarshal([]byte("d1:rd6:valuesl6:\x01\x02\x03\x04\x05\x066:\x07\x08\x09\x0a\x0b\x0ce5:nodes52:\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x02\x03\x04\x05\x06\x07\x08\x09\x02\x03\x04\x05\x06\x07\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x02\x03\x04\x05\x06\x07\x08\x09\x02\x03\x04\x05\x06\x07ee"), &msg)
+       require.NoError(t, err)
+       assert.Len(t, msg.R.Values, 2)
+       assert.Len(t, msg.R.Nodes, 2)
+       assert.Nil(t, msg.E)
 }
 
 func TestDHTDefaultConfig(t *testing.T) {
@@ -203,3 +205,74 @@ func TestServerDefaultNodeIdSecure(t *testing.T) {
                t.Fatal("not secure")
        }
 }
+
+func testMarshalUnmarshalMsg(t *testing.T, m Msg, expected string) {
+       b, err := bencode.Marshal(m)
+       require.NoError(t, err)
+       assert.Equal(t, expected, string(b))
+       var _m Msg
+       err = bencode.Unmarshal([]byte(expected), &_m)
+       assert.NoError(t, err)
+       assert.EqualValues(t, m, _m)
+       assert.EqualValues(t, m.R, _m.R)
+}
+
+func TestMarshalUnmarshalMsg(t *testing.T) {
+       testMarshalUnmarshalMsg(t, Msg{}, "d1:t0:1:y0:e")
+       testMarshalUnmarshalMsg(t, Msg{
+               Y: "q",
+               Q: "ping",
+               T: "hi",
+       }, "d1:q4:ping1:t2:hi1:y1:qe")
+       testMarshalUnmarshalMsg(t, Msg{
+               Y: "e",
+               T: "42",
+               E: &KRPCError{Code: 200, Msg: "fuck"},
+       }, "d1:eli200e4:fucke1:t2:421:y1:ee")
+       testMarshalUnmarshalMsg(t, Msg{
+               Y: "r",
+               T: "\x8c%",
+               R: &Return{},
+       }, "d1:rd2:id0:5:token0:e1:t2:\x8c%1:y1:re")
+       testMarshalUnmarshalMsg(t, Msg{
+               Y: "r",
+               T: "\x8c%",
+               R: &Return{
+                       Nodes: CompactIPv4NodeInfo{
+                               NodeInfo{
+                                       Addr: newDHTAddr(&net.UDPAddr{
+                                               IP:   net.IPv4(1, 2, 3, 4),
+                                               Port: 0x1234,
+                                       }),
+                               },
+                       },
+               },
+       }, "d1:rd2:id0:5:nodes26:\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x1245:token0:e1:t2:\x8c%1:y1:re")
+       testMarshalUnmarshalMsg(t, Msg{
+               Y: "r",
+               T: "\x8c%",
+               R: &Return{
+                       Values: []util.CompactPeer{
+                               util.CompactPeer{
+                                       IP:   net.IPv4(1, 2, 3, 4).To4(),
+                                       Port: 0x5678,
+                               },
+                       },
+               },
+       }, "d1:rd2:id0:5:token0:6:valuesl6:\x01\x02\x03\x04\x56\x78ee1:t2:\x8c%1:y1:re")
+}
+
+func TestAnnounceTimeout(t *testing.T) {
+       s, err := NewServer(&ServerConfig{
+               BootstrapNodes: []string{"1.2.3.4:5"},
+       })
+       require.NoError(t, err)
+       a, err := s.Announce("12341234123412341234", 0, true)
+       <-a.Peers
+       a.Close()
+       s.Close()
+}
+
+func TestEqualPointers(t *testing.T) {
+       assert.EqualValues(t, &Msg{R: &Return{}}, &Msg{R: &Return{}})
+}
diff --git a/dht/krpcError.go b/dht/krpcError.go
new file mode 100644 (file)
index 0000000..c52023e
--- /dev/null
@@ -0,0 +1,45 @@
+package dht
+
+import (
+       "fmt"
+
+       "github.com/anacrolix/torrent/bencode"
+)
+
+// Represented as a string or list in bencode.
+type KRPCError struct {
+       Code int
+       Msg  string
+}
+
+var (
+       _ bencode.Unmarshaler = &KRPCError{}
+       _ bencode.Marshaler   = &KRPCError{}
+       _ error               = KRPCError{}
+)
+
+func (me *KRPCError) UnmarshalBencode(_b []byte) (err error) {
+       var _v interface{}
+       err = bencode.Unmarshal(_b, &_v)
+       if err != nil {
+               return
+       }
+       switch v := _v.(type) {
+       case []interface{}:
+               me.Code = int(v[0].(int64))
+               me.Msg = v[1].(string)
+       case string:
+               me.Msg = v
+       default:
+               err = fmt.Errorf(`KRPC error bencode value has unexpected type: %T`, _v)
+       }
+       return
+}
+
+func (me KRPCError) MarshalBencode() (ret []byte, err error) {
+       return bencode.Marshal([]interface{}{me.Code, me.Msg})
+}
+
+func (me KRPCError) Error() string {
+       return fmt.Sprintf("KRPC error %d: %s", me.Code, me.Msg)
+}
diff --git a/dht/msg.go b/dht/msg.go
new file mode 100644 (file)
index 0000000..991a354
--- /dev/null
@@ -0,0 +1,52 @@
+package dht
+
+import (
+       "fmt"
+
+       "github.com/anacrolix/torrent/util"
+)
+
+// The unmarshalled KRPC dict message.
+type Msg struct {
+       Q string `bencode:"q,omitempty"`
+       A *struct {
+               ID       string `bencode:"id"`
+               InfoHash string `bencode:"info_hash"`
+               Target   string `bencode:"target"`
+       } `bencode:"a,omitempty"`
+       T string     `bencode:"t"`
+       Y string     `bencode:"y"`
+       R *Return    `bencode:"r,omitempty"`
+       E *KRPCError `bencode:"e,omitempty"`
+}
+
+type Return struct {
+       ID     string              `bencode:"id"`
+       Nodes  CompactIPv4NodeInfo `bencode:"nodes,omitempty"`
+       Token  string              `bencode:"token"`
+       Values []util.CompactPeer  `bencode:"values,omitempty"`
+}
+
+var _ fmt.Stringer = Msg{}
+
+func (m Msg) String() string {
+       return fmt.Sprintf("%#v", m)
+}
+
+// The node ID of the source of this Msg.
+func (m Msg) SenderID() string {
+       switch m.Y {
+       case "q":
+               return m.A.ID
+       case "r":
+               return m.R.ID
+       }
+       return ""
+}
+
+func (m Msg) Error() *KRPCError {
+       if m.Y != "e" {
+               return nil
+       }
+       return m.E
+}
index 942b9d73d07d750076b9da5654d0ecc47bc66a63..d0655f148f9763b0e37d38b0bb0420e5cf012266 100644 (file)
@@ -4,6 +4,7 @@ import (
        "encoding"
        "encoding/binary"
        "errors"
+       "fmt"
        "net"
 
        "github.com/bradfitz/iter"
@@ -46,6 +47,22 @@ type CompactPeer struct {
        Port int
 }
 
+var (
+       _ bencode.Marshaler   = &CompactPeer{}
+       _ bencode.Unmarshaler = &CompactPeer{}
+)
+
+func (me CompactPeer) MarshalBencode() (ret []byte, err error) {
+       ip := me.IP
+       if ip4 := ip.To4(); ip4 != nil {
+               ip = ip4
+       }
+       ret = make([]byte, len(ip)+2)
+       copy(ret, ip)
+       binary.BigEndian.PutUint16(ret[len(ip):], uint16(me.Port))
+       return bencode.Marshal(ret)
+}
+
 func (me *CompactPeer) UnmarshalBinary(b []byte) error {
        switch len(b) {
        case 18:
@@ -53,7 +70,7 @@ func (me *CompactPeer) UnmarshalBinary(b []byte) error {
        case 6:
                me.IP = make([]byte, 4)
        default:
-               return errors.New("bad length")
+               return fmt.Errorf("bad compact peer string: %q", b)
        }
        copy(me.IP, b)
        b = b[len(me.IP):]
@@ -61,6 +78,15 @@ func (me *CompactPeer) UnmarshalBinary(b []byte) error {
        return nil
 }
 
+func (me *CompactPeer) UnmarshalBencode(b []byte) (err error) {
+       var _b []byte
+       err = bencode.Unmarshal(b, &_b)
+       if err != nil {
+               return
+       }
+       return me.UnmarshalBinary(_b)
+}
+
 func UnmarshalIPv4CompactPeers(b []byte) (ret []CompactPeer, err error) {
        if len(b)%6 != 0 {
                err = errors.New("bad length")