]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/protocol.go
Implementing bitfields and connection message handling
[btrtrc.git] / peer_protocol / protocol.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "bytes"
6         "encoding/binary"
7         "errors"
8         "fmt"
9         "io"
10 )
11
12 type (
13         MessageType byte
14         Integer     uint32
15 )
16
17 func (i *Integer) Read(r io.Reader) error {
18         return binary.Read(r, binary.BigEndian, i)
19 }
20
21 const (
22         Protocol = "\x13BitTorrent protocol"
23 )
24
25 const (
26         Choke MessageType = iota
27         Unchoke
28         Interested
29         NotInterested
30         Have
31         Bitfield
32         Request
33         Piece
34         Cancel
35 )
36
37 type Message struct {
38         Keepalive            bool
39         Type                 MessageType
40         Index, Begin, Length Integer
41         Piece                []byte
42         Bitfield             []bool
43 }
44
45 func (msg Message) MarshalBinary() (data []byte, err error) {
46         buf := &bytes.Buffer{}
47         if msg.Keepalive {
48                 data = buf.Bytes()
49                 return
50         }
51         err = buf.WriteByte(byte(msg.Type))
52         if err != nil {
53                 return
54         }
55         switch msg.Type {
56         case Choke, Unchoke, Interested, NotInterested:
57         case Have:
58                 err = binary.Write(buf, binary.BigEndian, msg.Index)
59         case Request, Cancel:
60                 for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
61                         err = binary.Write(buf, binary.BigEndian, i)
62                         if err != nil {
63                                 break
64                         }
65                 }
66         case Bitfield:
67                 _, err = buf.Write(marshalBitfield(msg.Bitfield))
68         default:
69                 err = errors.New("unknown message type")
70         }
71         data = buf.Bytes()
72         return
73 }
74
75 type Decoder struct {
76         R         *bufio.Reader
77         MaxLength Integer
78 }
79
80 func (d *Decoder) Decode(msg *Message) (err error) {
81         var length Integer
82         err = binary.Read(d.R, binary.BigEndian, &length)
83         if err != nil {
84                 return
85         }
86         if length > d.MaxLength {
87                 return errors.New("message too long")
88         }
89         if length == 0 {
90                 msg.Keepalive = true
91                 return
92         }
93         msg.Keepalive = false
94         c, err := d.R.ReadByte()
95         if err != nil {
96                 return
97         }
98         msg.Type = MessageType(c)
99         switch msg.Type {
100         case Choke, Unchoke, Interested, NotInterested:
101                 return
102         case Have:
103                 err = msg.Index.Read(d.R)
104         case Request, Cancel:
105                 err = binary.Read(d.R, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length})
106         case Bitfield:
107                 b := make([]byte, length-1)
108                 _, err = io.ReadFull(d.R, b)
109                 msg.Bitfield = unmarshalBitfield(b)
110         default:
111                 err = fmt.Errorf("unknown message type %#v", c)
112         }
113         return
114 }
115
116 func encodeMessage(type_ MessageType, data interface{}) []byte {
117         w := &bytes.Buffer{}
118         w.WriteByte(byte(type_))
119         err := binary.Write(w, binary.BigEndian, data)
120         if err != nil {
121                 panic(err)
122         }
123         return w.Bytes()
124 }
125
126 type Bytes []byte
127
128 func (b Bytes) MarshalBinary() ([]byte, error) {
129         return b, nil
130 }
131
132 func unmarshalBitfield(b []byte) (bf []bool) {
133         for _, c := range b {
134                 for i := 7; i >= 0; i-- {
135                         bf = append(bf, (c>>uint(i))&1 == 1)
136                 }
137         }
138         return
139 }
140
141 func marshalBitfield(bf []bool) (b []byte) {
142         b = make([]byte, (len(bf)+7)/8)
143         for i, have := range bf {
144                 if !have {
145                         continue
146                 }
147                 c := b[i/8]
148                 c |= 1 << uint(7-i%8)
149                 b[i/8] = c
150         }
151         return
152 }