]> Sergey Matveev's repositories - btrtrc.git/blobdiff - peer_protocol/decoder.go
Implement decoding hash request, reject and hashes
[btrtrc.git] / peer_protocol / decoder.go
index 9dfe125b1a45e2960e8f746df1fc95ad17bb67d6..49eda4369959275889ce99a7c06681e6c103556d 100644 (file)
@@ -4,6 +4,7 @@ import (
        "bufio"
        "encoding/binary"
        "fmt"
+       g "github.com/anacrolix/generics"
        "io"
        "sync"
 
@@ -19,41 +20,109 @@ type Decoder struct {
        MaxLength Integer // TODO: Should this include the length header or not?
 }
 
-// io.EOF is returned if the source terminates cleanly on a message boundary.
-func (d *Decoder) Decode(msg *Message) (err error) {
-       var length Integer
-       err = length.Read(d.R)
-       if err != nil {
-               return fmt.Errorf("reading message length: %w", err)
+// This limits reads to the length of a message, returning io.EOF when the end of the message bytes
+// are reached. If you aren't expecting io.EOF, you should probably wrap it with expectReader.
+type decodeReader struct {
+       lr io.LimitedReader
+       br *bufio.Reader
+}
+
+func (dr *decodeReader) Init(r *bufio.Reader, length int64) {
+       dr.lr.R = r
+       dr.lr.N = length
+       dr.br = r
+}
+
+func (dr *decodeReader) ReadByte() (c byte, err error) {
+       if dr.lr.N <= 0 {
+               err = io.EOF
+               return
        }
-       if length > d.MaxLength {
-               return errors.New("message too long")
+       c, err = dr.br.ReadByte()
+       if err == nil {
+               dr.lr.N--
        }
-       if length == 0 {
-               msg.Keepalive = true
-               return
+       return
+}
+
+func (dr *decodeReader) Read(p []byte) (n int, err error) {
+       n, err = dr.lr.Read(p)
+       if dr.lr.N != 0 && err == io.EOF {
+               err = io.ErrUnexpectedEOF
+       }
+       return
+}
+
+func (dr *decodeReader) UnreadLength() int64 {
+       return dr.lr.N
+}
+
+// This expects reads to have enough bytes. io.EOF is mapped to io.ErrUnexpectedEOF. It's probably
+// not a good idea to pass this to functions that expect to read until the end of something, because
+// they will probably expect io.EOF.
+type expectReader struct {
+       dr *decodeReader
+}
+
+func (er expectReader) ReadByte() (c byte, err error) {
+       c, err = er.dr.ReadByte()
+       if err == io.EOF {
+               err = io.ErrUnexpectedEOF
        }
-       r := d.R
-       readByte := func() (byte, error) {
-               length--
-               return d.R.ReadByte()
+       return
+}
+
+func (er expectReader) Read(p []byte) (n int, err error) {
+       n, err = er.dr.Read(p)
+       if err == io.EOF {
+               err = io.ErrUnexpectedEOF
        }
-       // From this point onwards, EOF is unexpected
-       defer func() {
-               if err == io.EOF {
-                       err = io.ErrUnexpectedEOF
+       return
+}
+
+func (er expectReader) UnreadLength() int64 {
+       return er.dr.UnreadLength()
+}
+
+// io.EOF is returned if the source terminates cleanly on a message boundary.
+func (d *Decoder) Decode(msg *Message) (err error) {
+       var dr decodeReader
+       {
+               var length Integer
+               err = length.Read(d.R)
+               if err != nil {
+                       return fmt.Errorf("reading message length: %w", err)
                }
-       }()
-       c, err := readByte()
+               if length > d.MaxLength {
+                       return errors.New("message too long")
+               }
+               if length == 0 {
+                       msg.Keepalive = true
+                       return
+               }
+               dr.Init(d.R, int64(length))
+       }
+       r := expectReader{&dr}
+       c, err := r.ReadByte()
        if err != nil {
                return
        }
        msg.Type = MessageType(c)
-       // Can return directly in cases when err is not nil, or length is known to be zero.
+       err = readMessageAfterType(msg, &r, d.Pool)
+       if err != nil {
+               err = fmt.Errorf("reading fields for message type %v: %w", msg.Type, err)
+               return
+       }
+       if r.UnreadLength() != 0 {
+               err = fmt.Errorf("%v unused bytes in message type %v", r.UnreadLength(), msg.Type)
+       }
+       return
+}
+
+func readMessageAfterType(msg *Message, r *expectReader, piecePool *sync.Pool) (err error) {
        switch msg.Type {
        case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
        case Have, AllowedFast, Suggest:
-               length -= 4
                err = msg.Index.Read(r)
        case Request, Cancel, Reject:
                for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
@@ -62,67 +131,74 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                                break
                        }
                }
-               length -= 12
        case Bitfield:
-               b := make([]byte, length)
+               b := make([]byte, r.UnreadLength())
                _, err = io.ReadFull(r, b)
-               length = 0
                msg.Bitfield = unmarshalBitfield(b)
-               return
        case Piece:
                for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
-                       err := pi.Read(r)
+                       err = pi.Read(r)
                        if err != nil {
-                               return err
+                               return
                        }
                }
-               length -= 8
-               dataLen := int64(length)
-               if d.Pool == nil {
+               dataLen := r.UnreadLength()
+               if piecePool == nil {
                        msg.Piece = make([]byte, dataLen)
                } else {
-                       msg.Piece = *d.Pool.Get().(*[]byte)
+                       msg.Piece = *piecePool.Get().(*[]byte)
                        if int64(cap(msg.Piece)) < dataLen {
                                return errors.New("piece data longer than expected")
                        }
                        msg.Piece = msg.Piece[:dataLen]
                }
                _, err = io.ReadFull(r, msg.Piece)
-               length = 0
-               return
        case Extended:
                var b byte
-               b, err = readByte()
+               b, err = r.ReadByte()
                if err != nil {
                        break
                }
                msg.ExtendedID = ExtensionNumber(b)
-               msg.ExtendedPayload = make([]byte, length)
+               msg.ExtendedPayload = make([]byte, r.UnreadLength())
                _, err = io.ReadFull(r, msg.ExtendedPayload)
-               length = 0
-               return
        case Port:
                err = binary.Read(r, binary.BigEndian, &msg.Port)
-               length -= 2
+       case HashRequest, HashReject:
+               err = readHashRequest(r, msg)
+       case Hashes:
+               err = readHashRequest(r, msg)
+               numHashes := (r.UnreadLength() + 31) / 32
+               g.MakeSliceWithCap(&msg.Hashes, numHashes)
+               for range numHashes {
+                       var oneHash [32]byte
+                       _, err = io.ReadFull(r, oneHash[:])
+                       if err != nil {
+                               err = fmt.Errorf("error while reading hashes: %w", err)
+                               return
+                       }
+                       msg.Hashes = append(msg.Hashes, oneHash)
+               }
        default:
-               err = fmt.Errorf("unknown message type %#v", c)
-       }
-       if err == nil && length != 0 {
-               err = fmt.Errorf("%v unused bytes in message type %v", length, msg.Type)
+               err = errors.New("unhandled message type")
        }
        return
 }
 
-func readByte(r io.Reader) (b byte, err error) {
-       var arr [1]byte
-       n, err := r.Read(arr[:])
-       b = arr[0]
-       if n == 1 {
-               err = nil
+func readHashRequest(r io.Reader, msg *Message) (err error) {
+       _, err = io.ReadFull(r, msg.PiecesRoot[:])
+       if err != nil {
                return
        }
-       if err == nil {
-               panic(err)
+       return readSeq(r, &msg.BaseLayer, &msg.Index, &msg.Length, &msg.ProofLayers)
+}
+
+func readSeq(r io.Reader, data ...any) (err error) {
+       for _, d := range data {
+               err = binary.Read(r, binary.BigEndian, d)
+               if err != nil {
+                       return
+               }
        }
        return
 }