From c00e6f51ba4c03218f3d7ad83fce62c873e97fd5 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Fri, 1 Mar 2024 14:22:41 +1100 Subject: [PATCH] Implement decoding hash request, reject and hashes --- peer_protocol/decoder.go | 180 ++++++++++++++++++++-------- peer_protocol/fuzz_test.go | 4 +- peer_protocol/messagetype_string.go | 33 ++++- peer_protocol/protocol.go | 7 +- v2hashes.go | 1 + 5 files changed, 165 insertions(+), 60 deletions(-) create mode 100644 v2hashes.go diff --git a/peer_protocol/decoder.go b/peer_protocol/decoder.go index 9dfe125b..49eda436 100644 --- a/peer_protocol/decoder.go +++ b/peer_protocol/decoder.go @@ -4,6 +4,7 @@ import ( "bufio" "encoding/binary" "fmt" + g "github.com/anacrolix/generics" "io" "sync" @@ -19,41 +20,109 @@ type Decoder struct { MaxLength Integer // TODO: Should this include the length header or not? } -// io.EOF is returned if the source terminates cleanly on a message boundary. -func (d *Decoder) Decode(msg *Message) (err error) { - var length Integer - err = length.Read(d.R) - if err != nil { - return fmt.Errorf("reading message length: %w", err) +// This limits reads to the length of a message, returning io.EOF when the end of the message bytes +// are reached. If you aren't expecting io.EOF, you should probably wrap it with expectReader. +type decodeReader struct { + lr io.LimitedReader + br *bufio.Reader +} + +func (dr *decodeReader) Init(r *bufio.Reader, length int64) { + dr.lr.R = r + dr.lr.N = length + dr.br = r +} + +func (dr *decodeReader) ReadByte() (c byte, err error) { + if dr.lr.N <= 0 { + err = io.EOF + return } - if length > d.MaxLength { - return errors.New("message too long") + c, err = dr.br.ReadByte() + if err == nil { + dr.lr.N-- } - if length == 0 { - msg.Keepalive = true - return + return +} + +func (dr *decodeReader) Read(p []byte) (n int, err error) { + n, err = dr.lr.Read(p) + if dr.lr.N != 0 && err == io.EOF { + err = io.ErrUnexpectedEOF + } + return +} + +func (dr *decodeReader) UnreadLength() int64 { + return dr.lr.N +} + +// This expects reads to have enough bytes. io.EOF is mapped to io.ErrUnexpectedEOF. It's probably +// not a good idea to pass this to functions that expect to read until the end of something, because +// they will probably expect io.EOF. +type expectReader struct { + dr *decodeReader +} + +func (er expectReader) ReadByte() (c byte, err error) { + c, err = er.dr.ReadByte() + if err == io.EOF { + err = io.ErrUnexpectedEOF } - r := d.R - readByte := func() (byte, error) { - length-- - return d.R.ReadByte() + return +} + +func (er expectReader) Read(p []byte) (n int, err error) { + n, err = er.dr.Read(p) + if err == io.EOF { + err = io.ErrUnexpectedEOF } - // From this point onwards, EOF is unexpected - defer func() { - if err == io.EOF { - err = io.ErrUnexpectedEOF + return +} + +func (er expectReader) UnreadLength() int64 { + return er.dr.UnreadLength() +} + +// io.EOF is returned if the source terminates cleanly on a message boundary. +func (d *Decoder) Decode(msg *Message) (err error) { + var dr decodeReader + { + var length Integer + err = length.Read(d.R) + if err != nil { + return fmt.Errorf("reading message length: %w", err) } - }() - c, err := readByte() + if length > d.MaxLength { + return errors.New("message too long") + } + if length == 0 { + msg.Keepalive = true + return + } + dr.Init(d.R, int64(length)) + } + r := expectReader{&dr} + c, err := r.ReadByte() if err != nil { return } msg.Type = MessageType(c) - // Can return directly in cases when err is not nil, or length is known to be zero. + err = readMessageAfterType(msg, &r, d.Pool) + if err != nil { + err = fmt.Errorf("reading fields for message type %v: %w", msg.Type, err) + return + } + if r.UnreadLength() != 0 { + err = fmt.Errorf("%v unused bytes in message type %v", r.UnreadLength(), msg.Type) + } + return +} + +func readMessageAfterType(msg *Message, r *expectReader, piecePool *sync.Pool) (err error) { switch msg.Type { case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone: case Have, AllowedFast, Suggest: - length -= 4 err = msg.Index.Read(r) case Request, Cancel, Reject: for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} { @@ -62,67 +131,74 @@ func (d *Decoder) Decode(msg *Message) (err error) { break } } - length -= 12 case Bitfield: - b := make([]byte, length) + b := make([]byte, r.UnreadLength()) _, err = io.ReadFull(r, b) - length = 0 msg.Bitfield = unmarshalBitfield(b) - return case Piece: for _, pi := range []*Integer{&msg.Index, &msg.Begin} { - err := pi.Read(r) + err = pi.Read(r) if err != nil { - return err + return } } - length -= 8 - dataLen := int64(length) - if d.Pool == nil { + dataLen := r.UnreadLength() + if piecePool == nil { msg.Piece = make([]byte, dataLen) } else { - msg.Piece = *d.Pool.Get().(*[]byte) + msg.Piece = *piecePool.Get().(*[]byte) if int64(cap(msg.Piece)) < dataLen { return errors.New("piece data longer than expected") } msg.Piece = msg.Piece[:dataLen] } _, err = io.ReadFull(r, msg.Piece) - length = 0 - return case Extended: var b byte - b, err = readByte() + b, err = r.ReadByte() if err != nil { break } msg.ExtendedID = ExtensionNumber(b) - msg.ExtendedPayload = make([]byte, length) + msg.ExtendedPayload = make([]byte, r.UnreadLength()) _, err = io.ReadFull(r, msg.ExtendedPayload) - length = 0 - return case Port: err = binary.Read(r, binary.BigEndian, &msg.Port) - length -= 2 + case HashRequest, HashReject: + err = readHashRequest(r, msg) + case Hashes: + err = readHashRequest(r, msg) + numHashes := (r.UnreadLength() + 31) / 32 + g.MakeSliceWithCap(&msg.Hashes, numHashes) + for range numHashes { + var oneHash [32]byte + _, err = io.ReadFull(r, oneHash[:]) + if err != nil { + err = fmt.Errorf("error while reading hashes: %w", err) + return + } + msg.Hashes = append(msg.Hashes, oneHash) + } default: - err = fmt.Errorf("unknown message type %#v", c) - } - if err == nil && length != 0 { - err = fmt.Errorf("%v unused bytes in message type %v", length, msg.Type) + err = errors.New("unhandled message type") } return } -func readByte(r io.Reader) (b byte, err error) { - var arr [1]byte - n, err := r.Read(arr[:]) - b = arr[0] - if n == 1 { - err = nil +func readHashRequest(r io.Reader, msg *Message) (err error) { + _, err = io.ReadFull(r, msg.PiecesRoot[:]) + if err != nil { return } - if err == nil { - panic(err) + return readSeq(r, &msg.BaseLayer, &msg.Index, &msg.Length, &msg.ProofLayers) +} + +func readSeq(r io.Reader, data ...any) (err error) { + for _, d := range data { + err = binary.Read(r, binary.BigEndian, d) + if err != nil { + return + } } return } diff --git a/peer_protocol/fuzz_test.go b/peer_protocol/fuzz_test.go index 52415048..8ffdfd7b 100644 --- a/peer_protocol/fuzz_test.go +++ b/peer_protocol/fuzz_test.go @@ -6,7 +6,6 @@ package peer_protocol import ( "bufio" "bytes" - "errors" "io" "testing" @@ -30,7 +29,7 @@ func FuzzDecoder(f *testing.F) { var m Message err := d.Decode(&m) t.Log(err) - if errors.Is(err, io.EOF) { + if err == io.EOF { break } if err == nil { @@ -41,6 +40,7 @@ func FuzzDecoder(f *testing.F) { t.Skip(err) } } + t.Log(ms) var buf bytes.Buffer for _, m := range ms { buf.Write(m.MustMarshalBinary()) diff --git a/peer_protocol/messagetype_string.go b/peer_protocol/messagetype_string.go index 7be19f42..e1ad6a88 100644 --- a/peer_protocol/messagetype_string.go +++ b/peer_protocol/messagetype_string.go @@ -4,15 +4,41 @@ package peer_protocol import "strconv" +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Choke-0] + _ = x[Unchoke-1] + _ = x[Interested-2] + _ = x[NotInterested-3] + _ = x[Have-4] + _ = x[Bitfield-5] + _ = x[Request-6] + _ = x[Piece-7] + _ = x[Cancel-8] + _ = x[Port-9] + _ = x[Suggest-13] + _ = x[HaveAll-14] + _ = x[HaveNone-15] + _ = x[Reject-16] + _ = x[AllowedFast-17] + _ = x[Extended-20] + _ = x[HashRequest-21] + _ = x[Hashes-22] + _ = x[HashReject-23] +} + const ( _MessageType_name_0 = "ChokeUnchokeInterestedNotInterestedHaveBitfieldRequestPieceCancelPort" _MessageType_name_1 = "SuggestHaveAllHaveNoneRejectAllowedFast" - _MessageType_name_2 = "Extended" + _MessageType_name_2 = "ExtendedHashRequestHashesHashReject" ) var ( _MessageType_index_0 = [...]uint8{0, 5, 12, 22, 35, 39, 47, 54, 59, 65, 69} _MessageType_index_1 = [...]uint8{0, 7, 14, 22, 28, 39} + _MessageType_index_2 = [...]uint8{0, 8, 19, 25, 35} ) func (i MessageType) String() string { @@ -22,8 +48,9 @@ func (i MessageType) String() string { case 13 <= i && i <= 17: i -= 13 return _MessageType_name_1[_MessageType_index_1[i]:_MessageType_index_1[i+1]] - case i == 20: - return _MessageType_name_2 + case 20 <= i && i <= 23: + i -= 20 + return _MessageType_name_2[_MessageType_index_2[i]:_MessageType_index_2[i+1]] default: return "MessageType(" + strconv.FormatInt(int64(i), 10) + ")" } diff --git a/peer_protocol/protocol.go b/peer_protocol/protocol.go index 2f92ab32..b1790d1c 100644 --- a/peer_protocol/protocol.go +++ b/peer_protocol/protocol.go @@ -6,6 +6,7 @@ const ( type MessageType byte +// golang.org/x/tools/cmd/stringer //go:generate stringer -type=MessageType func (mt MessageType) FastExtension() bool { @@ -43,9 +44,9 @@ const ( Extended MessageType = 20 // BEP 52 - HashRequest = 21 - Hashes = 22 - HashReject = 23 + HashRequest MessageType = 21 + Hashes MessageType = 22 + HashReject MessageType = 23 ) const ( diff --git a/v2hashes.go b/v2hashes.go new file mode 100644 index 00000000..10cbafc7 --- /dev/null +++ b/v2hashes.go @@ -0,0 +1 @@ +package torrent -- 2.44.0