]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/protocol.go
Add peer_protocol.Integer.Int()
[btrtrc.git] / peer_protocol / protocol.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "bytes"
6         "encoding/binary"
7         "errors"
8         "fmt"
9         "io"
10         "io/ioutil"
11 )
12
13 type (
14         MessageType byte
15         Integer     uint32
16 )
17
18 func (i *Integer) Read(r io.Reader) error {
19         return binary.Read(r, binary.BigEndian, i)
20 }
21
22 // It's perfectly fine to cast these to an int.
23 func (i Integer) Int() int {
24         return int(i)
25 }
26
27 const (
28         Protocol = "\x13BitTorrent protocol"
29 )
30
31 const (
32         Choke         MessageType = iota
33         Unchoke                   // 1
34         Interested                // 2
35         NotInterested             // 3
36         Have                      // 4
37         Bitfield                  // 5
38         Request                   // 6
39         Piece                     // 7
40         Cancel                    // 8
41         Port                      // 9
42
43         // BEP 6
44         Suggest     = 0xd  // 13
45         HaveAll     = 0xe  // 14
46         HaveNone    = 0xf  // 15
47         Reject      = 0x10 // 16
48         AllowedFast = 0x11 // 17
49
50         Extended = 20
51
52         HandshakeExtendedID = 0
53
54         RequestMetadataExtensionMsgType = 0
55         DataMetadataExtensionMsgType    = 1
56         RejectMetadataExtensionMsgType  = 2
57 )
58
59 type Message struct {
60         Keepalive            bool
61         Type                 MessageType
62         Index, Begin, Length Integer
63         Piece                []byte
64         Bitfield             []bool
65         ExtendedID           byte
66         ExtendedPayload      []byte
67         Port                 uint16
68 }
69
70 func (msg Message) MarshalBinary() (data []byte, err error) {
71         buf := &bytes.Buffer{}
72         if !msg.Keepalive {
73                 err = buf.WriteByte(byte(msg.Type))
74                 if err != nil {
75                         return
76                 }
77                 switch msg.Type {
78                 case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
79                 case Have:
80                         err = binary.Write(buf, binary.BigEndian, msg.Index)
81                 case Request, Cancel, Reject:
82                         for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
83                                 err = binary.Write(buf, binary.BigEndian, i)
84                                 if err != nil {
85                                         break
86                                 }
87                         }
88                 case Bitfield:
89                         _, err = buf.Write(marshalBitfield(msg.Bitfield))
90                 case Piece:
91                         for _, i := range []Integer{msg.Index, msg.Begin} {
92                                 err = binary.Write(buf, binary.BigEndian, i)
93                                 if err != nil {
94                                         return
95                                 }
96                         }
97                         n, err := buf.Write(msg.Piece)
98                         if err != nil {
99                                 break
100                         }
101                         if n != len(msg.Piece) {
102                                 panic(n)
103                         }
104                 case Extended:
105                         err = buf.WriteByte(msg.ExtendedID)
106                         if err != nil {
107                                 return
108                         }
109                         _, err = buf.Write(msg.ExtendedPayload)
110                 case Port:
111                         err = binary.Write(buf, binary.BigEndian, msg.Port)
112                 default:
113                         err = fmt.Errorf("unknown message type: %v", msg.Type)
114                 }
115         }
116         data = make([]byte, 4+buf.Len())
117         binary.BigEndian.PutUint32(data, uint32(buf.Len()))
118         if buf.Len() != copy(data[4:], buf.Bytes()) {
119                 panic("bad copy")
120         }
121         return
122 }
123
124 type Decoder struct {
125         R         *bufio.Reader
126         MaxLength Integer // TODO: Should this include the length header or not?
127 }
128
129 // io.EOF is returned if the source terminates cleanly on a message boundary.
130 func (d *Decoder) Decode(msg *Message) (err error) {
131         var length Integer
132         err = binary.Read(d.R, binary.BigEndian, &length)
133         if err != nil {
134                 if err != io.EOF {
135                         err = fmt.Errorf("error reading message length: %s", err)
136                 }
137                 return
138         }
139         if length > d.MaxLength {
140                 return errors.New("message too long")
141         }
142         if length == 0 {
143                 msg.Keepalive = true
144                 return
145         }
146         msg.Keepalive = false
147         b := make([]byte, length)
148         _, err = io.ReadFull(d.R, b)
149         if err != nil {
150                 if err == io.EOF {
151                         err = io.ErrUnexpectedEOF
152                 }
153                 if err != io.ErrUnexpectedEOF {
154                         err = fmt.Errorf("error reading message: %s", err)
155                 }
156                 return
157         }
158         r := bytes.NewReader(b)
159         // Check that all of r was utilized.
160         defer func() {
161                 if err != nil {
162                         return
163                 }
164                 if r.Len() != 0 {
165                         err = fmt.Errorf("%d bytes unused in message type %d", r.Len(), msg.Type)
166                 }
167         }()
168         msg.Keepalive = false
169         c, err := r.ReadByte()
170         if err != nil {
171                 return
172         }
173         msg.Type = MessageType(c)
174         switch msg.Type {
175         case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
176                 return
177         case Have:
178                 err = msg.Index.Read(r)
179         case Request, Cancel, Reject:
180                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
181                         err = data.Read(r)
182                         if err != nil {
183                                 break
184                         }
185                 }
186         case Bitfield:
187                 b := make([]byte, length-1)
188                 _, err = io.ReadFull(r, b)
189                 msg.Bitfield = unmarshalBitfield(b)
190         case Piece:
191                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
192                         err = pi.Read(r)
193                         if err != nil {
194                                 break
195                         }
196                 }
197                 if err != nil {
198                         break
199                 }
200                 msg.Piece, err = ioutil.ReadAll(r)
201         case Extended:
202                 msg.ExtendedID, err = r.ReadByte()
203                 if err != nil {
204                         break
205                 }
206                 msg.ExtendedPayload, err = ioutil.ReadAll(r)
207         case Port:
208                 err = binary.Read(r, binary.BigEndian, &msg.Port)
209         default:
210                 err = fmt.Errorf("unknown message type %#v", c)
211         }
212         return
213 }
214
215 type Bytes []byte
216
217 func (b Bytes) MarshalBinary() ([]byte, error) {
218         return b, nil
219 }
220
221 func unmarshalBitfield(b []byte) (bf []bool) {
222         for _, c := range b {
223                 for i := 7; i >= 0; i-- {
224                         bf = append(bf, (c>>uint(i))&1 == 1)
225                 }
226         }
227         return
228 }
229
230 func marshalBitfield(bf []bool) (b []byte) {
231         b = make([]byte, (len(bf)+7)/8)
232         for i, have := range bf {
233                 if !have {
234                         continue
235                 }
236                 c := b[i/8]
237                 c |= 1 << uint(7-i%8)
238                 b[i/8] = c
239         }
240         return
241 }