]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/protocol.go
peer_protocol.Decoder.Decode: Avoid allocating another intermediate reader
[btrtrc.git] / peer_protocol / protocol.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "bytes"
6         "encoding/binary"
7         "errors"
8         "fmt"
9         "io"
10         "io/ioutil"
11         "sync"
12 )
13
14 type (
15         MessageType byte
16         Integer     uint32
17 )
18
19 func (i *Integer) Read(r io.Reader) error {
20         return binary.Read(r, binary.BigEndian, i)
21 }
22
23 // It's perfectly fine to cast these to an int.
24 func (i Integer) Int() int {
25         return int(i)
26 }
27
28 const (
29         Protocol = "\x13BitTorrent protocol"
30 )
31
32 const (
33         Choke         MessageType = iota
34         Unchoke                   // 1
35         Interested                // 2
36         NotInterested             // 3
37         Have                      // 4
38         Bitfield                  // 5
39         Request                   // 6
40         Piece                     // 7
41         Cancel                    // 8
42         Port                      // 9
43
44         // BEP 6
45         Suggest     = 0xd  // 13
46         HaveAll     = 0xe  // 14
47         HaveNone    = 0xf  // 15
48         Reject      = 0x10 // 16
49         AllowedFast = 0x11 // 17
50
51         Extended = 20
52
53         HandshakeExtendedID = 0
54
55         RequestMetadataExtensionMsgType = 0
56         DataMetadataExtensionMsgType    = 1
57         RejectMetadataExtensionMsgType  = 2
58 )
59
60 type Message struct {
61         Keepalive            bool
62         Type                 MessageType
63         Index, Begin, Length Integer
64         Piece                []byte
65         Bitfield             []bool
66         ExtendedID           byte
67         ExtendedPayload      []byte
68         Port                 uint16
69 }
70
71 func (msg Message) MarshalBinary() (data []byte, err error) {
72         buf := &bytes.Buffer{}
73         if !msg.Keepalive {
74                 err = buf.WriteByte(byte(msg.Type))
75                 if err != nil {
76                         return
77                 }
78                 switch msg.Type {
79                 case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
80                 case Have:
81                         err = binary.Write(buf, binary.BigEndian, msg.Index)
82                 case Request, Cancel, Reject:
83                         for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
84                                 err = binary.Write(buf, binary.BigEndian, i)
85                                 if err != nil {
86                                         break
87                                 }
88                         }
89                 case Bitfield:
90                         _, err = buf.Write(marshalBitfield(msg.Bitfield))
91                 case Piece:
92                         for _, i := range []Integer{msg.Index, msg.Begin} {
93                                 err = binary.Write(buf, binary.BigEndian, i)
94                                 if err != nil {
95                                         return
96                                 }
97                         }
98                         n, err := buf.Write(msg.Piece)
99                         if err != nil {
100                                 break
101                         }
102                         if n != len(msg.Piece) {
103                                 panic(n)
104                         }
105                 case Extended:
106                         err = buf.WriteByte(msg.ExtendedID)
107                         if err != nil {
108                                 return
109                         }
110                         _, err = buf.Write(msg.ExtendedPayload)
111                 case Port:
112                         err = binary.Write(buf, binary.BigEndian, msg.Port)
113                 default:
114                         err = fmt.Errorf("unknown message type: %v", msg.Type)
115                 }
116         }
117         data = make([]byte, 4+buf.Len())
118         binary.BigEndian.PutUint32(data, uint32(buf.Len()))
119         if buf.Len() != copy(data[4:], buf.Bytes()) {
120                 panic("bad copy")
121         }
122         return
123 }
124
125 type Decoder struct {
126         R         *bufio.Reader
127         Pool      *sync.Pool
128         MaxLength Integer // TODO: Should this include the length header or not?
129 }
130
131 func readByte(r io.Reader) (b byte, err error) {
132         var arr [1]byte
133         n, err := r.Read(arr[:])
134         b = arr[0]
135         if n == 1 {
136                 err = nil
137                 return
138         }
139         if err == nil {
140                 panic(err)
141         }
142         return
143 }
144
145 // io.EOF is returned if the source terminates cleanly on a message boundary.
146 func (d *Decoder) Decode(msg *Message) (err error) {
147         var length Integer
148         err = binary.Read(d.R, binary.BigEndian, &length)
149         if err != nil {
150                 if err != io.EOF {
151                         err = fmt.Errorf("error reading message length: %s", err)
152                 }
153                 return
154         }
155         if length > d.MaxLength {
156                 return errors.New("message too long")
157         }
158         if length == 0 {
159                 msg.Keepalive = true
160                 return
161         }
162         msg.Keepalive = false
163         r := &io.LimitedReader{d.R, int64(length)}
164         // Check that all of r was utilized.
165         defer func() {
166                 if err != nil {
167                         return
168                 }
169                 if r.N != 0 {
170                         err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
171                 }
172         }()
173         msg.Keepalive = false
174         c, err := readByte(r)
175         if err != nil {
176                 return
177         }
178         msg.Type = MessageType(c)
179         switch msg.Type {
180         case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
181                 return
182         case Have:
183                 err = msg.Index.Read(r)
184         case Request, Cancel, Reject:
185                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
186                         err = data.Read(r)
187                         if err != nil {
188                                 break
189                         }
190                 }
191         case Bitfield:
192                 b := make([]byte, length-1)
193                 _, err = io.ReadFull(r, b)
194                 msg.Bitfield = unmarshalBitfield(b)
195         case Piece:
196                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
197                         err = pi.Read(r)
198                         if err != nil {
199                                 break
200                         }
201                 }
202                 if err != nil {
203                         break
204                 }
205                 //msg.Piece, err = ioutil.ReadAll(r)
206                 b := d.Pool.Get().([]byte)
207                 n, err := io.ReadFull(r, b)
208                 if err != nil {
209                         if err != io.ErrUnexpectedEOF || n != int(length-9) {
210                                 return err
211                         }
212                         b = b[0:n]
213                 }
214                 msg.Piece = b
215         case Extended:
216                 msg.ExtendedID, err = readByte(r)
217                 if err != nil {
218                         break
219                 }
220                 msg.ExtendedPayload, err = ioutil.ReadAll(r)
221         case Port:
222                 err = binary.Read(r, binary.BigEndian, &msg.Port)
223         default:
224                 err = fmt.Errorf("unknown message type %#v", c)
225         }
226         return
227 }
228
229 type Bytes []byte
230
231 func (b Bytes) MarshalBinary() ([]byte, error) {
232         return b, nil
233 }
234
235 func unmarshalBitfield(b []byte) (bf []bool) {
236         for _, c := range b {
237                 for i := 7; i >= 0; i-- {
238                         bf = append(bf, (c>>uint(i))&1 == 1)
239                 }
240         }
241         return
242 }
243
244 func marshalBitfield(bf []bool) (b []byte) {
245         b = make([]byte, (len(bf)+7)/8)
246         for i, have := range bf {
247                 if !have {
248                         continue
249                 }
250                 c := b[i/8]
251                 c |= 1 << uint(7-i%8)
252                 b[i/8] = c
253         }
254         return
255 }