]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/protocol.go
8c309bcbfcc525327026d41e16350aa77ad35af9
[btrtrc.git] / peer_protocol / protocol.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "bytes"
6         "encoding/binary"
7         "errors"
8         "fmt"
9         "io"
10         "io/ioutil"
11         "sync"
12 )
13
14 type (
15         MessageType byte
16         Integer     uint32
17 )
18
19 func (i *Integer) Read(r io.Reader) error {
20         return binary.Read(r, binary.BigEndian, i)
21 }
22
23 // It's perfectly fine to cast these to an int.
24 func (i Integer) Int() int {
25         return int(i)
26 }
27
28 const (
29         Protocol = "\x13BitTorrent protocol"
30 )
31
32 const (
33         Choke         MessageType = iota
34         Unchoke                   // 1
35         Interested                // 2
36         NotInterested             // 3
37         Have                      // 4
38         Bitfield                  // 5
39         Request                   // 6
40         Piece                     // 7
41         Cancel                    // 8
42         Port                      // 9
43
44         // BEP 6
45         Suggest     = 0xd  // 13
46         HaveAll     = 0xe  // 14
47         HaveNone    = 0xf  // 15
48         Reject      = 0x10 // 16
49         AllowedFast = 0x11 // 17
50
51         Extended = 20
52
53         HandshakeExtendedID = 0
54
55         RequestMetadataExtensionMsgType = 0
56         DataMetadataExtensionMsgType    = 1
57         RejectMetadataExtensionMsgType  = 2
58 )
59
60 type Message struct {
61         Keepalive            bool
62         Type                 MessageType
63         Index, Begin, Length Integer
64         Piece                []byte
65         Bitfield             []bool
66         ExtendedID           byte
67         ExtendedPayload      []byte
68         Port                 uint16
69 }
70
71 func (msg Message) MustMarshalBinary() []byte {
72         b, err := msg.MarshalBinary()
73         if err != nil {
74                 panic(err)
75         }
76         return b
77 }
78
79 func (msg Message) MarshalBinary() (data []byte, err error) {
80         buf := &bytes.Buffer{}
81         if !msg.Keepalive {
82                 err = buf.WriteByte(byte(msg.Type))
83                 if err != nil {
84                         return
85                 }
86                 switch msg.Type {
87                 case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
88                 case Have:
89                         err = binary.Write(buf, binary.BigEndian, msg.Index)
90                 case Request, Cancel, Reject:
91                         for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
92                                 err = binary.Write(buf, binary.BigEndian, i)
93                                 if err != nil {
94                                         break
95                                 }
96                         }
97                 case Bitfield:
98                         _, err = buf.Write(marshalBitfield(msg.Bitfield))
99                 case Piece:
100                         for _, i := range []Integer{msg.Index, msg.Begin} {
101                                 err = binary.Write(buf, binary.BigEndian, i)
102                                 if err != nil {
103                                         return
104                                 }
105                         }
106                         n, err := buf.Write(msg.Piece)
107                         if err != nil {
108                                 break
109                         }
110                         if n != len(msg.Piece) {
111                                 panic(n)
112                         }
113                 case Extended:
114                         err = buf.WriteByte(msg.ExtendedID)
115                         if err != nil {
116                                 return
117                         }
118                         _, err = buf.Write(msg.ExtendedPayload)
119                 case Port:
120                         err = binary.Write(buf, binary.BigEndian, msg.Port)
121                 default:
122                         err = fmt.Errorf("unknown message type: %v", msg.Type)
123                 }
124         }
125         data = make([]byte, 4+buf.Len())
126         binary.BigEndian.PutUint32(data, uint32(buf.Len()))
127         if buf.Len() != copy(data[4:], buf.Bytes()) {
128                 panic("bad copy")
129         }
130         return
131 }
132
133 type Decoder struct {
134         R         *bufio.Reader
135         Pool      *sync.Pool
136         MaxLength Integer // TODO: Should this include the length header or not?
137 }
138
139 func readByte(r io.Reader) (b byte, err error) {
140         var arr [1]byte
141         n, err := r.Read(arr[:])
142         b = arr[0]
143         if n == 1 {
144                 err = nil
145                 return
146         }
147         if err == nil {
148                 panic(err)
149         }
150         return
151 }
152
153 // io.EOF is returned if the source terminates cleanly on a message boundary.
154 func (d *Decoder) Decode(msg *Message) (err error) {
155         var length Integer
156         err = binary.Read(d.R, binary.BigEndian, &length)
157         if err != nil {
158                 if err != io.EOF {
159                         err = fmt.Errorf("error reading message length: %s", err)
160                 }
161                 return
162         }
163         if length > d.MaxLength {
164                 return errors.New("message too long")
165         }
166         if length == 0 {
167                 msg.Keepalive = true
168                 return
169         }
170         msg.Keepalive = false
171         r := &io.LimitedReader{d.R, int64(length)}
172         // Check that all of r was utilized.
173         defer func() {
174                 if err != nil {
175                         return
176                 }
177                 if r.N != 0 {
178                         err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
179                 }
180         }()
181         msg.Keepalive = false
182         c, err := readByte(r)
183         if err != nil {
184                 return
185         }
186         msg.Type = MessageType(c)
187         switch msg.Type {
188         case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
189                 return
190         case Have:
191                 err = msg.Index.Read(r)
192         case Request, Cancel, Reject:
193                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
194                         err = data.Read(r)
195                         if err != nil {
196                                 break
197                         }
198                 }
199         case Bitfield:
200                 b := make([]byte, length-1)
201                 _, err = io.ReadFull(r, b)
202                 msg.Bitfield = unmarshalBitfield(b)
203         case Piece:
204                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
205                         err = pi.Read(r)
206                         if err != nil {
207                                 break
208                         }
209                 }
210                 if err != nil {
211                         break
212                 }
213                 //msg.Piece, err = ioutil.ReadAll(r)
214                 b := d.Pool.Get().([]byte)
215                 n, err := io.ReadFull(r, b)
216                 if err != nil {
217                         if err != io.ErrUnexpectedEOF || n != int(length-9) {
218                                 return err
219                         }
220                         b = b[0:n]
221                 }
222                 msg.Piece = b
223         case Extended:
224                 msg.ExtendedID, err = readByte(r)
225                 if err != nil {
226                         break
227                 }
228                 msg.ExtendedPayload, err = ioutil.ReadAll(r)
229         case Port:
230                 err = binary.Read(r, binary.BigEndian, &msg.Port)
231         default:
232                 err = fmt.Errorf("unknown message type %#v", c)
233         }
234         return
235 }
236
237 type Bytes []byte
238
239 func (b Bytes) MarshalBinary() ([]byte, error) {
240         return b, nil
241 }
242
243 func unmarshalBitfield(b []byte) (bf []bool) {
244         for _, c := range b {
245                 for i := 7; i >= 0; i-- {
246                         bf = append(bf, (c>>uint(i))&1 == 1)
247                 }
248         }
249         return
250 }
251
252 func marshalBitfield(bf []bool) (b []byte) {
253         b = make([]byte, (len(bf)+7)/8)
254         for i, have := range bf {
255                 if !have {
256                         continue
257                 }
258                 c := b[i/8]
259                 c |= 1 << uint(7-i%8)
260                 b[i/8] = c
261         }
262         return
263 }
264
265 func MakeCancelMessage(piece, offset, length Integer) Message {
266         return Message{
267                 Type:   Cancel,
268                 Index:  piece,
269                 Begin:  offset,
270                 Length: length,
271         }
272 }