]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/protocol.go
9900a033c3509b94471be2630076aa390fd85fe6
[btrtrc.git] / peer_protocol / protocol.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "bytes"
6         "encoding/binary"
7         "errors"
8         "fmt"
9         "io"
10         "io/ioutil"
11 )
12
13 type (
14         MessageType byte
15         Integer     uint32
16 )
17
18 func (i *Integer) Read(r io.Reader) error {
19         return binary.Read(r, binary.BigEndian, i)
20 }
21
22 const (
23         Protocol = "\x13BitTorrent protocol"
24 )
25
26 const (
27         Choke         MessageType = iota
28         Unchoke                   // 1
29         Interested                // 2
30         NotInterested             // 3
31         Have                      // 4
32         Bitfield                  // 5
33         Request                   // 6
34         Piece                     // 7
35         Cancel                    // 8
36         Extended      = 20
37
38         HandshakeExtendedID = 0
39 )
40
41 type Message struct {
42         Keepalive            bool
43         Type                 MessageType
44         Index, Begin, Length Integer
45         Piece                []byte
46         Bitfield             []bool
47         ExtendedID           byte
48         ExtendedPayload      []byte
49 }
50
51 func (msg Message) MarshalBinary() (data []byte, err error) {
52         buf := &bytes.Buffer{}
53         if !msg.Keepalive {
54                 err = buf.WriteByte(byte(msg.Type))
55                 if err != nil {
56                         return
57                 }
58                 switch msg.Type {
59                 case Choke, Unchoke, Interested, NotInterested:
60                 case Have:
61                         err = binary.Write(buf, binary.BigEndian, msg.Index)
62                 case Request, Cancel:
63                         for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
64                                 err = binary.Write(buf, binary.BigEndian, i)
65                                 if err != nil {
66                                         break
67                                 }
68                         }
69                 case Bitfield:
70                         _, err = buf.Write(marshalBitfield(msg.Bitfield))
71                 case Piece:
72                         for _, i := range []Integer{msg.Index, msg.Begin} {
73                                 err = binary.Write(buf, binary.BigEndian, i)
74                                 if err != nil {
75                                         return
76                                 }
77                         }
78                         n, err := buf.Write(msg.Piece)
79                         if err != nil {
80                                 break
81                         }
82                         if n != len(msg.Piece) {
83                                 panic(n)
84                         }
85                 case Extended:
86                         err = buf.WriteByte(msg.ExtendedID)
87                         if err != nil {
88                                 return
89                         }
90                         _, err = buf.Write(msg.ExtendedPayload)
91                 default:
92                         err = fmt.Errorf("unknown message type: %v", msg.Type)
93                 }
94         }
95         data = make([]byte, 4+buf.Len())
96         binary.BigEndian.PutUint32(data, uint32(buf.Len()))
97         if buf.Len() != copy(data[4:], buf.Bytes()) {
98                 panic("bad copy")
99         }
100         return
101 }
102
103 type Decoder struct {
104         R         *bufio.Reader
105         MaxLength Integer // TODO: Should this include the length header or not?
106 }
107
108 func (d *Decoder) Decode(msg *Message) (err error) {
109         var length Integer
110         err = binary.Read(d.R, binary.BigEndian, &length)
111         if err != nil {
112                 return
113         }
114         if length > d.MaxLength {
115                 return errors.New("message too long")
116         }
117         if length == 0 {
118                 msg.Keepalive = true
119                 return
120         }
121         msg.Keepalive = false
122         b := make([]byte, length)
123         _, err = io.ReadFull(d.R, b)
124         if err == io.EOF {
125                 err = io.ErrUnexpectedEOF
126                 return
127         }
128         if err != nil {
129                 return
130         }
131         r := bytes.NewReader(b)
132         defer func() {
133                 written, _ := io.Copy(ioutil.Discard, r)
134                 if written != 0 && err == nil {
135                         err = fmt.Errorf("short read on message type %d, left %d bytes", msg.Type, written)
136                 } else if err == io.EOF {
137                         err = io.ErrUnexpectedEOF
138                 }
139         }()
140         msg.Keepalive = false
141         c, err := r.ReadByte()
142         if err != nil {
143                 return
144         }
145         msg.Type = MessageType(c)
146         switch msg.Type {
147         case Choke, Unchoke, Interested, NotInterested:
148                 return
149         case Have:
150                 err = msg.Index.Read(r)
151         case Request, Cancel:
152                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
153                         err = data.Read(r)
154                         if err != nil {
155                                 break
156                         }
157                 }
158         case Bitfield:
159                 b := make([]byte, length-1)
160                 _, err = io.ReadFull(r, b)
161                 msg.Bitfield = unmarshalBitfield(b)
162         case Piece:
163                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
164                         err = pi.Read(r)
165                         if err != nil {
166                                 break
167                         }
168                 }
169                 if err != nil {
170                         break
171                 }
172                 msg.Piece, err = ioutil.ReadAll(r)
173         case Extended:
174                 msg.ExtendedID, err = r.ReadByte()
175                 if err != nil {
176                         break
177                 }
178                 msg.ExtendedPayload, err = ioutil.ReadAll(r)
179         default:
180                 err = fmt.Errorf("unknown message type %#v", c)
181         }
182         return
183 }
184
185 type Bytes []byte
186
187 func (b Bytes) MarshalBinary() ([]byte, error) {
188         return b, nil
189 }
190
191 func unmarshalBitfield(b []byte) (bf []bool) {
192         for _, c := range b {
193                 for i := 7; i >= 0; i-- {
194                         bf = append(bf, (c>>uint(i))&1 == 1)
195                 }
196         }
197         return
198 }
199
200 func marshalBitfield(bf []bool) (b []byte) {
201         b = make([]byte, (len(bf)+7)/8)
202         for i, have := range bf {
203                 if !have {
204                         continue
205                 }
206                 c := b[i/8]
207                 c |= 1 << uint(7-i%8)
208                 b[i/8] = c
209         }
210         return
211 }