]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/protocol.go
75e33cab453abd5db59bbf9a0dd1fef07cc9f4c8
[btrtrc.git] / peer_protocol / protocol.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "bytes"
6         "encoding/binary"
7         "errors"
8         "fmt"
9         "io"
10         "io/ioutil"
11 )
12
13 type (
14         MessageType byte
15         Integer     uint32
16 )
17
18 func (i *Integer) Read(r io.Reader) error {
19         return binary.Read(r, binary.BigEndian, i)
20 }
21
22 const (
23         Protocol = "\x13BitTorrent protocol"
24 )
25
26 const (
27         Choke         MessageType = iota
28         Unchoke                   // 1
29         Interested                // 2
30         NotInterested             // 3
31         Have                      // 4
32         Bitfield                  // 5
33         Request                   // 6
34         Piece                     // 7
35         Cancel                    // 8
36 )
37
38 type Message struct {
39         Keepalive            bool
40         Type                 MessageType
41         Index, Begin, Length Integer
42         Piece                []byte
43         Bitfield             []bool
44 }
45
46 func (msg Message) MarshalBinary() (data []byte, err error) {
47         buf := &bytes.Buffer{}
48         if msg.Keepalive {
49                 data = buf.Bytes()
50                 return
51         }
52         err = buf.WriteByte(byte(msg.Type))
53         if err != nil {
54                 return
55         }
56         switch msg.Type {
57         case Choke, Unchoke, Interested, NotInterested:
58         case Have:
59                 err = binary.Write(buf, binary.BigEndian, msg.Index)
60         case Request, Cancel:
61                 for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
62                         err = binary.Write(buf, binary.BigEndian, i)
63                         if err != nil {
64                                 break
65                         }
66                 }
67         case Bitfield:
68                 _, err = buf.Write(marshalBitfield(msg.Bitfield))
69         case Piece:
70                 for _, i := range []Integer{msg.Index, msg.Begin} {
71                         err = binary.Write(buf, binary.BigEndian, i)
72                         if err != nil {
73                                 return
74                         }
75                 }
76                 n, err := buf.Write(msg.Piece)
77                 if err != nil {
78                         break
79                 }
80                 if n != len(msg.Piece) {
81                         panic(n)
82                 }
83         default:
84                 err = fmt.Errorf("unknown message type: %s", msg.Type)
85         }
86         data = make([]byte, 4+buf.Len())
87         binary.BigEndian.PutUint32(data, uint32(buf.Len()))
88         if buf.Len() != copy(data[4:], buf.Bytes()) {
89                 panic("bad copy")
90         }
91         return
92 }
93
94 type Decoder struct {
95         R         *bufio.Reader
96         MaxLength Integer // TODO: Should this include the length header or not?
97 }
98
99 func (d *Decoder) Decode(msg *Message) (err error) {
100         var length Integer
101         err = binary.Read(d.R, binary.BigEndian, &length)
102         if err != nil {
103                 return
104         }
105         if length > d.MaxLength {
106                 return errors.New("message too long")
107         }
108         if length == 0 {
109                 msg.Keepalive = true
110                 return
111         }
112         msg.Keepalive = false
113         b := make([]byte, length)
114         _, err = io.ReadFull(d.R, b)
115         if err == io.EOF {
116                 err = io.ErrUnexpectedEOF
117                 return
118         }
119         if err != nil {
120                 return
121         }
122         r := bytes.NewReader(b)
123         defer func() {
124                 written, _ := io.Copy(ioutil.Discard, r)
125                 if written != 0 && err == nil {
126                         err = fmt.Errorf("short read on message type %d, left %d bytes", msg.Type, written)
127                 } else if err == io.EOF {
128                         err = io.ErrUnexpectedEOF
129                 }
130         }()
131         msg.Keepalive = false
132         c, err := r.ReadByte()
133         if err != nil {
134                 return
135         }
136         msg.Type = MessageType(c)
137         switch msg.Type {
138         case Choke, Unchoke, Interested, NotInterested:
139                 return
140         case Have:
141                 err = msg.Index.Read(r)
142         case Request, Cancel:
143                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
144                         err = data.Read(r)
145                         if err != nil {
146                                 break
147                         }
148                 }
149         case Bitfield:
150                 b := make([]byte, length-1)
151                 _, err = io.ReadFull(r, b)
152                 msg.Bitfield = unmarshalBitfield(b)
153         case Piece:
154                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
155                         err = pi.Read(r)
156                         if err != nil {
157                                 break
158                         }
159                 }
160                 if err != nil {
161                         break
162                 }
163                 msg.Piece, err = ioutil.ReadAll(r)
164         default:
165                 err = fmt.Errorf("unknown message type %#v", c)
166         }
167         return
168 }
169
170 type Bytes []byte
171
172 func (b Bytes) MarshalBinary() ([]byte, error) {
173         return b, nil
174 }
175
176 func unmarshalBitfield(b []byte) (bf []bool) {
177         for _, c := range b {
178                 for i := 7; i >= 0; i-- {
179                         bf = append(bf, (c>>uint(i))&1 == 1)
180                 }
181         }
182         return
183 }
184
185 func marshalBitfield(bf []bool) (b []byte) {
186         b = make([]byte, (len(bf)+7)/8)
187         for i, have := range bf {
188                 if !have {
189                         continue
190                 }
191                 c := b[i/8]
192                 c |= 1 << uint(7-i%8)
193                 b[i/8] = c
194         }
195         return
196 }