]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/protocol.go
Implement the DHT Port message
[btrtrc.git] / peer_protocol / protocol.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "bytes"
6         "encoding/binary"
7         "errors"
8         "fmt"
9         "io"
10         "io/ioutil"
11 )
12
13 type (
14         MessageType byte
15         Integer     uint32
16 )
17
18 func (i *Integer) Read(r io.Reader) error {
19         return binary.Read(r, binary.BigEndian, i)
20 }
21
22 const (
23         Protocol = "\x13BitTorrent protocol"
24 )
25
26 const (
27         Choke         MessageType = iota
28         Unchoke                   // 1
29         Interested                // 2
30         NotInterested             // 3
31         Have                      // 4
32         Bitfield                  // 5
33         Request                   // 6
34         Piece                     // 7
35         Cancel                    // 8
36         Port                      // 9
37         Extended      = 20
38
39         HandshakeExtendedID = 0
40
41         RequestMetadataExtensionMsgType = 0
42         DataMetadataExtensionMsgType    = 1
43         RejectMetadataExtensionMsgType  = 2
44 )
45
46 type Message struct {
47         Keepalive            bool
48         Type                 MessageType
49         Index, Begin, Length Integer
50         Piece                []byte
51         Bitfield             []bool
52         ExtendedID           byte
53         ExtendedPayload      []byte
54         Port                 uint16
55 }
56
57 func (msg Message) MarshalBinary() (data []byte, err error) {
58         buf := &bytes.Buffer{}
59         if !msg.Keepalive {
60                 err = buf.WriteByte(byte(msg.Type))
61                 if err != nil {
62                         return
63                 }
64                 switch msg.Type {
65                 case Choke, Unchoke, Interested, NotInterested:
66                 case Have:
67                         err = binary.Write(buf, binary.BigEndian, msg.Index)
68                 case Request, Cancel:
69                         for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
70                                 err = binary.Write(buf, binary.BigEndian, i)
71                                 if err != nil {
72                                         break
73                                 }
74                         }
75                 case Bitfield:
76                         _, err = buf.Write(marshalBitfield(msg.Bitfield))
77                 case Piece:
78                         for _, i := range []Integer{msg.Index, msg.Begin} {
79                                 err = binary.Write(buf, binary.BigEndian, i)
80                                 if err != nil {
81                                         return
82                                 }
83                         }
84                         n, err := buf.Write(msg.Piece)
85                         if err != nil {
86                                 break
87                         }
88                         if n != len(msg.Piece) {
89                                 panic(n)
90                         }
91                 case Extended:
92                         err = buf.WriteByte(msg.ExtendedID)
93                         if err != nil {
94                                 return
95                         }
96                         _, err = buf.Write(msg.ExtendedPayload)
97                 case Port:
98                         err = binary.Write(buf, binary.BigEndian, msg.Port)
99                 default:
100                         err = fmt.Errorf("unknown message type: %v", msg.Type)
101                 }
102         }
103         data = make([]byte, 4+buf.Len())
104         binary.BigEndian.PutUint32(data, uint32(buf.Len()))
105         if buf.Len() != copy(data[4:], buf.Bytes()) {
106                 panic("bad copy")
107         }
108         return
109 }
110
111 type Decoder struct {
112         R         *bufio.Reader
113         MaxLength Integer // TODO: Should this include the length header or not?
114 }
115
116 // io.EOF is returned if the source terminates cleanly on a message boundary.
117 func (d *Decoder) Decode(msg *Message) (err error) {
118         var length Integer
119         err = binary.Read(d.R, binary.BigEndian, &length)
120         if err != nil {
121                 if err != io.EOF {
122                         err = fmt.Errorf("error reading message length: %s", err)
123                 }
124                 return
125         }
126         if length > d.MaxLength {
127                 return errors.New("message too long")
128         }
129         if length == 0 {
130                 msg.Keepalive = true
131                 return
132         }
133         msg.Keepalive = false
134         b := make([]byte, length)
135         _, err = io.ReadFull(d.R, b)
136         if err != nil {
137                 if err == io.EOF {
138                         err = io.ErrUnexpectedEOF
139                 }
140                 if err != io.ErrUnexpectedEOF {
141                         err = fmt.Errorf("error reading message: %s", err)
142                 }
143                 return
144         }
145         r := bytes.NewReader(b)
146         // Check that all of r was utilized.
147         defer func() {
148                 if err != nil {
149                         return
150                 }
151                 if r.Len() != 0 {
152                         err = fmt.Errorf("%d bytes unused in message type %d", r.Len(), msg.Type)
153                 }
154         }()
155         msg.Keepalive = false
156         c, err := r.ReadByte()
157         if err != nil {
158                 return
159         }
160         msg.Type = MessageType(c)
161         switch msg.Type {
162         case Choke, Unchoke, Interested, NotInterested:
163                 return
164         case Have:
165                 err = msg.Index.Read(r)
166         case Request, Cancel:
167                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
168                         err = data.Read(r)
169                         if err != nil {
170                                 break
171                         }
172                 }
173         case Bitfield:
174                 b := make([]byte, length-1)
175                 _, err = io.ReadFull(r, b)
176                 msg.Bitfield = unmarshalBitfield(b)
177         case Piece:
178                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
179                         err = pi.Read(r)
180                         if err != nil {
181                                 break
182                         }
183                 }
184                 if err != nil {
185                         break
186                 }
187                 msg.Piece, err = ioutil.ReadAll(r)
188         case Extended:
189                 msg.ExtendedID, err = r.ReadByte()
190                 if err != nil {
191                         break
192                 }
193                 msg.ExtendedPayload, err = ioutil.ReadAll(r)
194         case Port:
195                 err = binary.Read(r, binary.BigEndian, &msg.Port)
196         default:
197                 err = fmt.Errorf("unknown message type %#v", c)
198         }
199         return
200 }
201
202 type Bytes []byte
203
204 func (b Bytes) MarshalBinary() ([]byte, error) {
205         return b, nil
206 }
207
208 func unmarshalBitfield(b []byte) (bf []bool) {
209         for _, c := range b {
210                 for i := 7; i >= 0; i-- {
211                         bf = append(bf, (c>>uint(i))&1 == 1)
212                 }
213         }
214         return
215 }
216
217 func marshalBitfield(bf []bool) (b []byte) {
218         b = make([]byte, (len(bf)+7)/8)
219         for i, have := range bf {
220                 if !have {
221                         continue
222                 }
223                 c := b[i/8]
224                 c |= 1 << uint(7-i%8)
225                 b[i/8] = c
226         }
227         return
228 }