]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/msg.go
Send hash requests for missing v2 hashes
[btrtrc.git] / peer_protocol / msg.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "bytes"
6         "encoding"
7         "encoding/binary"
8         "fmt"
9 )
10
11 // This is a lazy union representing all the possible fields for messages. Go doesn't have ADTs, and
12 // I didn't choose to use type-assertions. Fields are ordered to minimize struct size and padding.
13 type Message struct {
14         PiecesRoot           [32]byte
15         Piece                []byte
16         Bitfield             []bool
17         ExtendedPayload      []byte
18         Hashes               [][32]byte
19         Index, Begin, Length Integer
20         BaseLayer            Integer
21         ProofLayers          Integer
22         Port                 uint16
23         Type                 MessageType
24         ExtendedID           ExtensionNumber
25         Keepalive            bool
26 }
27
28 var _ interface {
29         encoding.BinaryUnmarshaler
30         encoding.BinaryMarshaler
31 } = (*Message)(nil)
32
33 func MakeCancelMessage(piece, offset, length Integer) Message {
34         return Message{
35                 Type:   Cancel,
36                 Index:  piece,
37                 Begin:  offset,
38                 Length: length,
39         }
40 }
41
42 func (msg Message) RequestSpec() (ret RequestSpec) {
43         return RequestSpec{
44                 msg.Index,
45                 msg.Begin,
46                 func() Integer {
47                         if msg.Type == Piece {
48                                 return Integer(len(msg.Piece))
49                         } else {
50                                 return msg.Length
51                         }
52                 }(),
53         }
54 }
55
56 func (msg Message) MustMarshalBinary() []byte {
57         b, err := msg.MarshalBinary()
58         if err != nil {
59                 panic(err)
60         }
61         return b
62 }
63
64 func (msg Message) MarshalBinary() (data []byte, err error) {
65         // It might look like you could have a pool of buffers and preallocate the message length
66         // prefix, but because we have to return []byte, it becomes non-trivial to make this fast. You
67         // will need a benchmark.
68         var buf bytes.Buffer
69         mustWrite := func(data any) {
70                 err := binary.Write(&buf, binary.BigEndian, data)
71                 if err != nil {
72                         panic(err)
73                 }
74         }
75         writeConsecutive := func(data ...any) {
76                 for _, d := range data {
77                         mustWrite(d)
78                 }
79         }
80         if !msg.Keepalive {
81                 err = buf.WriteByte(byte(msg.Type))
82                 if err != nil {
83                         return
84                 }
85                 switch msg.Type {
86                 case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
87                 case Have, AllowedFast, Suggest:
88                         err = binary.Write(&buf, binary.BigEndian, msg.Index)
89                 case Request, Cancel, Reject:
90                         for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
91                                 err = binary.Write(&buf, binary.BigEndian, i)
92                                 if err != nil {
93                                         break
94                                 }
95                         }
96                 case Bitfield:
97                         _, err = buf.Write(marshalBitfield(msg.Bitfield))
98                 case Piece:
99                         for _, i := range []Integer{msg.Index, msg.Begin} {
100                                 err = binary.Write(&buf, binary.BigEndian, i)
101                                 if err != nil {
102                                         return
103                                 }
104                         }
105                         n, err := buf.Write(msg.Piece)
106                         if err != nil {
107                                 break
108                         }
109                         if n != len(msg.Piece) {
110                                 panic(n)
111                         }
112                 case Extended:
113                         err = buf.WriteByte(byte(msg.ExtendedID))
114                         if err != nil {
115                                 return
116                         }
117                         _, err = buf.Write(msg.ExtendedPayload)
118                 case Port:
119                         err = binary.Write(&buf, binary.BigEndian, msg.Port)
120                 case HashRequest:
121                         buf.Write(msg.PiecesRoot[:])
122                         writeConsecutive(msg.BaseLayer, msg.Index, msg.Length, msg.ProofLayers)
123                 default:
124                         err = fmt.Errorf("unknown message type: %v", msg.Type)
125                 }
126         }
127         data = make([]byte, 4+buf.Len())
128         binary.BigEndian.PutUint32(data, uint32(buf.Len()))
129         if buf.Len() != copy(data[4:], buf.Bytes()) {
130                 panic("bad copy")
131         }
132         return
133 }
134
135 func marshalBitfield(bf []bool) (b []byte) {
136         b = make([]byte, (len(bf)+7)/8)
137         for i, have := range bf {
138                 if !have {
139                         continue
140                 }
141                 c := b[i/8]
142                 c |= 1 << uint(7-i%8)
143                 b[i/8] = c
144         }
145         return
146 }
147
148 func (me *Message) UnmarshalBinary(b []byte) error {
149         d := Decoder{
150                 R: bufio.NewReader(bytes.NewReader(b)),
151         }
152         err := d.Decode(me)
153         if err != nil {
154                 return err
155         }
156         if d.R.Buffered() != 0 {
157                 return fmt.Errorf("%d trailing bytes", d.R.Buffered())
158         }
159         return nil
160 }