]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/protocol.go
Keepalives weren't marshalled correctly
[btrtrc.git] / peer_protocol / protocol.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "bytes"
6         "encoding/binary"
7         "errors"
8         "fmt"
9         "io"
10         "io/ioutil"
11 )
12
13 type (
14         MessageType byte
15         Integer     uint32
16 )
17
18 func (i *Integer) Read(r io.Reader) error {
19         return binary.Read(r, binary.BigEndian, i)
20 }
21
22 const (
23         Protocol = "\x13BitTorrent protocol"
24 )
25
26 const (
27         Choke         MessageType = iota
28         Unchoke                   // 1
29         Interested                // 2
30         NotInterested             // 3
31         Have                      // 4
32         Bitfield                  // 5
33         Request                   // 6
34         Piece                     // 7
35         Cancel                    // 8
36 )
37
38 type Message struct {
39         Keepalive            bool
40         Type                 MessageType
41         Index, Begin, Length Integer
42         Piece                []byte
43         Bitfield             []bool
44 }
45
46 func (msg Message) MarshalBinary() (data []byte, err error) {
47         buf := &bytes.Buffer{}
48         if !msg.Keepalive {
49                 err = buf.WriteByte(byte(msg.Type))
50                 if err != nil {
51                         return
52                 }
53                 switch msg.Type {
54                 case Choke, Unchoke, Interested, NotInterested:
55                 case Have:
56                         err = binary.Write(buf, binary.BigEndian, msg.Index)
57                 case Request, Cancel:
58                         for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
59                                 err = binary.Write(buf, binary.BigEndian, i)
60                                 if err != nil {
61                                         break
62                                 }
63                         }
64                 case Bitfield:
65                         _, err = buf.Write(marshalBitfield(msg.Bitfield))
66                 case Piece:
67                         for _, i := range []Integer{msg.Index, msg.Begin} {
68                                 err = binary.Write(buf, binary.BigEndian, i)
69                                 if err != nil {
70                                         return
71                                 }
72                         }
73                         n, err := buf.Write(msg.Piece)
74                         if err != nil {
75                                 break
76                         }
77                         if n != len(msg.Piece) {
78                                 panic(n)
79                         }
80                 default:
81                         err = fmt.Errorf("unknown message type: %s", msg.Type)
82                 }
83         }
84         data = make([]byte, 4+buf.Len())
85         binary.BigEndian.PutUint32(data, uint32(buf.Len()))
86         if buf.Len() != copy(data[4:], buf.Bytes()) {
87                 panic("bad copy")
88         }
89         return
90 }
91
92 type Decoder struct {
93         R         *bufio.Reader
94         MaxLength Integer // TODO: Should this include the length header or not?
95 }
96
97 func (d *Decoder) Decode(msg *Message) (err error) {
98         var length Integer
99         err = binary.Read(d.R, binary.BigEndian, &length)
100         if err != nil {
101                 return
102         }
103         if length > d.MaxLength {
104                 return errors.New("message too long")
105         }
106         if length == 0 {
107                 msg.Keepalive = true
108                 return
109         }
110         msg.Keepalive = false
111         b := make([]byte, length)
112         _, err = io.ReadFull(d.R, b)
113         if err == io.EOF {
114                 err = io.ErrUnexpectedEOF
115                 return
116         }
117         if err != nil {
118                 return
119         }
120         r := bytes.NewReader(b)
121         defer func() {
122                 written, _ := io.Copy(ioutil.Discard, r)
123                 if written != 0 && err == nil {
124                         err = fmt.Errorf("short read on message type %d, left %d bytes", msg.Type, written)
125                 } else if err == io.EOF {
126                         err = io.ErrUnexpectedEOF
127                 }
128         }()
129         msg.Keepalive = false
130         c, err := r.ReadByte()
131         if err != nil {
132                 return
133         }
134         msg.Type = MessageType(c)
135         switch msg.Type {
136         case Choke, Unchoke, Interested, NotInterested:
137                 return
138         case Have:
139                 err = msg.Index.Read(r)
140         case Request, Cancel:
141                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
142                         err = data.Read(r)
143                         if err != nil {
144                                 break
145                         }
146                 }
147         case Bitfield:
148                 b := make([]byte, length-1)
149                 _, err = io.ReadFull(r, b)
150                 msg.Bitfield = unmarshalBitfield(b)
151         case Piece:
152                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
153                         err = pi.Read(r)
154                         if err != nil {
155                                 break
156                         }
157                 }
158                 if err != nil {
159                         break
160                 }
161                 msg.Piece, err = ioutil.ReadAll(r)
162         default:
163                 err = fmt.Errorf("unknown message type %#v", c)
164         }
165         return
166 }
167
168 type Bytes []byte
169
170 func (b Bytes) MarshalBinary() ([]byte, error) {
171         return b, nil
172 }
173
174 func unmarshalBitfield(b []byte) (bf []bool) {
175         for _, c := range b {
176                 for i := 7; i >= 0; i-- {
177                         bf = append(bf, (c>>uint(i))&1 == 1)
178                 }
179         }
180         return
181 }
182
183 func marshalBitfield(bf []bool) (b []byte) {
184         b = make([]byte, (len(bf)+7)/8)
185         for i, have := range bf {
186                 if !have {
187                         continue
188                 }
189                 c := b[i/8]
190                 c |= 1 << uint(7-i%8)
191                 b[i/8] = c
192         }
193         return
194 }