]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/protocol.go
db32ac59ce99f31e86dd281b59acda40ba9535ab
[btrtrc.git] / peer_protocol / protocol.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "bytes"
6         "encoding/binary"
7         "errors"
8         "fmt"
9         "io"
10         "io/ioutil"
11 )
12
13 type (
14         MessageType byte
15         Integer     uint32
16 )
17
18 func (i *Integer) Read(r io.Reader) error {
19         return binary.Read(r, binary.BigEndian, i)
20 }
21
22 const (
23         Protocol = "\x13BitTorrent protocol"
24 )
25
26 const (
27         Choke         MessageType = iota
28         Unchoke                   // 1
29         Interested                // 2
30         NotInterested             // 3
31         Have                      // 4
32         Bitfield                  // 5
33         Request                   // 6
34         Piece                     // 7
35         Cancel                    // 8
36         Extended      = 20
37
38         HandshakeExtendedID = 0
39
40         RequestMetadataExtensionMsgType = 0
41         DataMetadataExtensionMsgType    = 1
42         RejectMetadataExtensionMsgType  = 2
43 )
44
45 type Message struct {
46         Keepalive            bool
47         Type                 MessageType
48         Index, Begin, Length Integer
49         Piece                []byte
50         Bitfield             []bool
51         ExtendedID           byte
52         ExtendedPayload      []byte
53 }
54
55 func (msg Message) MarshalBinary() (data []byte, err error) {
56         buf := &bytes.Buffer{}
57         if !msg.Keepalive {
58                 err = buf.WriteByte(byte(msg.Type))
59                 if err != nil {
60                         return
61                 }
62                 switch msg.Type {
63                 case Choke, Unchoke, Interested, NotInterested:
64                 case Have:
65                         err = binary.Write(buf, binary.BigEndian, msg.Index)
66                 case Request, Cancel:
67                         for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
68                                 err = binary.Write(buf, binary.BigEndian, i)
69                                 if err != nil {
70                                         break
71                                 }
72                         }
73                 case Bitfield:
74                         _, err = buf.Write(marshalBitfield(msg.Bitfield))
75                 case Piece:
76                         for _, i := range []Integer{msg.Index, msg.Begin} {
77                                 err = binary.Write(buf, binary.BigEndian, i)
78                                 if err != nil {
79                                         return
80                                 }
81                         }
82                         n, err := buf.Write(msg.Piece)
83                         if err != nil {
84                                 break
85                         }
86                         if n != len(msg.Piece) {
87                                 panic(n)
88                         }
89                 case Extended:
90                         err = buf.WriteByte(msg.ExtendedID)
91                         if err != nil {
92                                 return
93                         }
94                         _, err = buf.Write(msg.ExtendedPayload)
95                 default:
96                         err = fmt.Errorf("unknown message type: %v", msg.Type)
97                 }
98         }
99         data = make([]byte, 4+buf.Len())
100         binary.BigEndian.PutUint32(data, uint32(buf.Len()))
101         if buf.Len() != copy(data[4:], buf.Bytes()) {
102                 panic("bad copy")
103         }
104         return
105 }
106
107 type Decoder struct {
108         R         *bufio.Reader
109         MaxLength Integer // TODO: Should this include the length header or not?
110 }
111
112 // io.EOF is returned if the source terminates cleanly on a message boundary.
113 func (d *Decoder) Decode(msg *Message) (err error) {
114         var length Integer
115         err = binary.Read(d.R, binary.BigEndian, &length)
116         if err != nil {
117                 if err != io.EOF {
118                         err = fmt.Errorf("error reading message length: %s", err)
119                 }
120                 return
121         }
122         if length > d.MaxLength {
123                 return errors.New("message too long")
124         }
125         if length == 0 {
126                 msg.Keepalive = true
127                 return
128         }
129         msg.Keepalive = false
130         b := make([]byte, length)
131         _, err = io.ReadFull(d.R, b)
132         if err != nil {
133                 if err == io.EOF {
134                         err = io.ErrUnexpectedEOF
135                 }
136                 if err != io.ErrUnexpectedEOF {
137                         err = fmt.Errorf("error reading message: %s", err)
138                 }
139                 return
140         }
141         r := bytes.NewReader(b)
142         // Check that all of r was utilized.
143         defer func() {
144                 if err != nil {
145                         return
146                 }
147                 if r.Len() != 0 {
148                         err = fmt.Errorf("%d bytes unused in message type %d", r.Len(), msg.Type)
149                 }
150         }()
151         msg.Keepalive = false
152         c, err := r.ReadByte()
153         if err != nil {
154                 return
155         }
156         msg.Type = MessageType(c)
157         switch msg.Type {
158         case Choke, Unchoke, Interested, NotInterested:
159                 return
160         case Have:
161                 err = msg.Index.Read(r)
162         case Request, Cancel:
163                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
164                         err = data.Read(r)
165                         if err != nil {
166                                 break
167                         }
168                 }
169         case Bitfield:
170                 b := make([]byte, length-1)
171                 _, err = io.ReadFull(r, b)
172                 msg.Bitfield = unmarshalBitfield(b)
173         case Piece:
174                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
175                         err = pi.Read(r)
176                         if err != nil {
177                                 break
178                         }
179                 }
180                 if err != nil {
181                         break
182                 }
183                 msg.Piece, err = ioutil.ReadAll(r)
184         case Extended:
185                 msg.ExtendedID, err = r.ReadByte()
186                 if err != nil {
187                         break
188                 }
189                 msg.ExtendedPayload, err = ioutil.ReadAll(r)
190         default:
191                 err = fmt.Errorf("unknown message type %#v", c)
192         }
193         return
194 }
195
196 type Bytes []byte
197
198 func (b Bytes) MarshalBinary() ([]byte, error) {
199         return b, nil
200 }
201
202 func unmarshalBitfield(b []byte) (bf []bool) {
203         for _, c := range b {
204                 for i := 7; i >= 0; i-- {
205                         bf = append(bf, (c>>uint(i))&1 == 1)
206                 }
207         }
208         return
209 }
210
211 func marshalBitfield(bf []bool) (b []byte) {
212         b = make([]byte, (len(bf)+7)/8)
213         for i, have := range bf {
214                 if !have {
215                         continue
216                 }
217                 c := b[i/8]
218                 c |= 1 << uint(7-i%8)
219                 b[i/8] = c
220         }
221         return
222 }