]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Implement decoding hash request, reject and hashes
authorMatt Joiner <anacrolix@gmail.com>
Fri, 1 Mar 2024 03:22:41 +0000 (14:22 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Sat, 2 Mar 2024 02:02:55 +0000 (13:02 +1100)
peer_protocol/decoder.go
peer_protocol/fuzz_test.go
peer_protocol/messagetype_string.go
peer_protocol/protocol.go
v2hashes.go [new file with mode: 0644]

index 9dfe125b1a45e2960e8f746df1fc95ad17bb67d6..49eda4369959275889ce99a7c06681e6c103556d 100644 (file)
@@ -4,6 +4,7 @@ import (
        "bufio"
        "encoding/binary"
        "fmt"
+       g "github.com/anacrolix/generics"
        "io"
        "sync"
 
@@ -19,41 +20,109 @@ type Decoder struct {
        MaxLength Integer // TODO: Should this include the length header or not?
 }
 
-// io.EOF is returned if the source terminates cleanly on a message boundary.
-func (d *Decoder) Decode(msg *Message) (err error) {
-       var length Integer
-       err = length.Read(d.R)
-       if err != nil {
-               return fmt.Errorf("reading message length: %w", err)
+// This limits reads to the length of a message, returning io.EOF when the end of the message bytes
+// are reached. If you aren't expecting io.EOF, you should probably wrap it with expectReader.
+type decodeReader struct {
+       lr io.LimitedReader
+       br *bufio.Reader
+}
+
+func (dr *decodeReader) Init(r *bufio.Reader, length int64) {
+       dr.lr.R = r
+       dr.lr.N = length
+       dr.br = r
+}
+
+func (dr *decodeReader) ReadByte() (c byte, err error) {
+       if dr.lr.N <= 0 {
+               err = io.EOF
+               return
        }
-       if length > d.MaxLength {
-               return errors.New("message too long")
+       c, err = dr.br.ReadByte()
+       if err == nil {
+               dr.lr.N--
        }
-       if length == 0 {
-               msg.Keepalive = true
-               return
+       return
+}
+
+func (dr *decodeReader) Read(p []byte) (n int, err error) {
+       n, err = dr.lr.Read(p)
+       if dr.lr.N != 0 && err == io.EOF {
+               err = io.ErrUnexpectedEOF
+       }
+       return
+}
+
+func (dr *decodeReader) UnreadLength() int64 {
+       return dr.lr.N
+}
+
+// This expects reads to have enough bytes. io.EOF is mapped to io.ErrUnexpectedEOF. It's probably
+// not a good idea to pass this to functions that expect to read until the end of something, because
+// they will probably expect io.EOF.
+type expectReader struct {
+       dr *decodeReader
+}
+
+func (er expectReader) ReadByte() (c byte, err error) {
+       c, err = er.dr.ReadByte()
+       if err == io.EOF {
+               err = io.ErrUnexpectedEOF
        }
-       r := d.R
-       readByte := func() (byte, error) {
-               length--
-               return d.R.ReadByte()
+       return
+}
+
+func (er expectReader) Read(p []byte) (n int, err error) {
+       n, err = er.dr.Read(p)
+       if err == io.EOF {
+               err = io.ErrUnexpectedEOF
        }
-       // From this point onwards, EOF is unexpected
-       defer func() {
-               if err == io.EOF {
-                       err = io.ErrUnexpectedEOF
+       return
+}
+
+func (er expectReader) UnreadLength() int64 {
+       return er.dr.UnreadLength()
+}
+
+// io.EOF is returned if the source terminates cleanly on a message boundary.
+func (d *Decoder) Decode(msg *Message) (err error) {
+       var dr decodeReader
+       {
+               var length Integer
+               err = length.Read(d.R)
+               if err != nil {
+                       return fmt.Errorf("reading message length: %w", err)
                }
-       }()
-       c, err := readByte()
+               if length > d.MaxLength {
+                       return errors.New("message too long")
+               }
+               if length == 0 {
+                       msg.Keepalive = true
+                       return
+               }
+               dr.Init(d.R, int64(length))
+       }
+       r := expectReader{&dr}
+       c, err := r.ReadByte()
        if err != nil {
                return
        }
        msg.Type = MessageType(c)
-       // Can return directly in cases when err is not nil, or length is known to be zero.
+       err = readMessageAfterType(msg, &r, d.Pool)
+       if err != nil {
+               err = fmt.Errorf("reading fields for message type %v: %w", msg.Type, err)
+               return
+       }
+       if r.UnreadLength() != 0 {
+               err = fmt.Errorf("%v unused bytes in message type %v", r.UnreadLength(), msg.Type)
+       }
+       return
+}
+
+func readMessageAfterType(msg *Message, r *expectReader, piecePool *sync.Pool) (err error) {
        switch msg.Type {
        case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
        case Have, AllowedFast, Suggest:
-               length -= 4
                err = msg.Index.Read(r)
        case Request, Cancel, Reject:
                for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
@@ -62,67 +131,74 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                                break
                        }
                }
-               length -= 12
        case Bitfield:
-               b := make([]byte, length)
+               b := make([]byte, r.UnreadLength())
                _, err = io.ReadFull(r, b)
-               length = 0
                msg.Bitfield = unmarshalBitfield(b)
-               return
        case Piece:
                for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
-                       err := pi.Read(r)
+                       err = pi.Read(r)
                        if err != nil {
-                               return err
+                               return
                        }
                }
-               length -= 8
-               dataLen := int64(length)
-               if d.Pool == nil {
+               dataLen := r.UnreadLength()
+               if piecePool == nil {
                        msg.Piece = make([]byte, dataLen)
                } else {
-                       msg.Piece = *d.Pool.Get().(*[]byte)
+                       msg.Piece = *piecePool.Get().(*[]byte)
                        if int64(cap(msg.Piece)) < dataLen {
                                return errors.New("piece data longer than expected")
                        }
                        msg.Piece = msg.Piece[:dataLen]
                }
                _, err = io.ReadFull(r, msg.Piece)
-               length = 0
-               return
        case Extended:
                var b byte
-               b, err = readByte()
+               b, err = r.ReadByte()
                if err != nil {
                        break
                }
                msg.ExtendedID = ExtensionNumber(b)
-               msg.ExtendedPayload = make([]byte, length)
+               msg.ExtendedPayload = make([]byte, r.UnreadLength())
                _, err = io.ReadFull(r, msg.ExtendedPayload)
-               length = 0
-               return
        case Port:
                err = binary.Read(r, binary.BigEndian, &msg.Port)
-               length -= 2
+       case HashRequest, HashReject:
+               err = readHashRequest(r, msg)
+       case Hashes:
+               err = readHashRequest(r, msg)
+               numHashes := (r.UnreadLength() + 31) / 32
+               g.MakeSliceWithCap(&msg.Hashes, numHashes)
+               for range numHashes {
+                       var oneHash [32]byte
+                       _, err = io.ReadFull(r, oneHash[:])
+                       if err != nil {
+                               err = fmt.Errorf("error while reading hashes: %w", err)
+                               return
+                       }
+                       msg.Hashes = append(msg.Hashes, oneHash)
+               }
        default:
-               err = fmt.Errorf("unknown message type %#v", c)
-       }
-       if err == nil && length != 0 {
-               err = fmt.Errorf("%v unused bytes in message type %v", length, msg.Type)
+               err = errors.New("unhandled message type")
        }
        return
 }
 
-func readByte(r io.Reader) (b byte, err error) {
-       var arr [1]byte
-       n, err := r.Read(arr[:])
-       b = arr[0]
-       if n == 1 {
-               err = nil
+func readHashRequest(r io.Reader, msg *Message) (err error) {
+       _, err = io.ReadFull(r, msg.PiecesRoot[:])
+       if err != nil {
                return
        }
-       if err == nil {
-               panic(err)
+       return readSeq(r, &msg.BaseLayer, &msg.Index, &msg.Length, &msg.ProofLayers)
+}
+
+func readSeq(r io.Reader, data ...any) (err error) {
+       for _, d := range data {
+               err = binary.Read(r, binary.BigEndian, d)
+               if err != nil {
+                       return
+               }
        }
        return
 }
index 5241504853dfe90f4b4d904395ff943d87de704b..8ffdfd7b47dc55b66ff8725f9d24d7abac55fa01 100644 (file)
@@ -6,7 +6,6 @@ package peer_protocol
 import (
        "bufio"
        "bytes"
-       "errors"
        "io"
        "testing"
 
@@ -30,7 +29,7 @@ func FuzzDecoder(f *testing.F) {
                        var m Message
                        err := d.Decode(&m)
                        t.Log(err)
-                       if errors.Is(err, io.EOF) {
+                       if err == io.EOF {
                                break
                        }
                        if err == nil {
@@ -41,6 +40,7 @@ func FuzzDecoder(f *testing.F) {
                                t.Skip(err)
                        }
                }
+               t.Log(ms)
                var buf bytes.Buffer
                for _, m := range ms {
                        buf.Write(m.MustMarshalBinary())
index 7be19f4275b8a286084bf33564ebb4f44db73d92..e1ad6a88a83cc71825262bdd7adbf9bc03e074ad 100644 (file)
@@ -4,15 +4,41 @@ package peer_protocol
 
 import "strconv"
 
+func _() {
+       // An "invalid array index" compiler error signifies that the constant values have changed.
+       // Re-run the stringer command to generate them again.
+       var x [1]struct{}
+       _ = x[Choke-0]
+       _ = x[Unchoke-1]
+       _ = x[Interested-2]
+       _ = x[NotInterested-3]
+       _ = x[Have-4]
+       _ = x[Bitfield-5]
+       _ = x[Request-6]
+       _ = x[Piece-7]
+       _ = x[Cancel-8]
+       _ = x[Port-9]
+       _ = x[Suggest-13]
+       _ = x[HaveAll-14]
+       _ = x[HaveNone-15]
+       _ = x[Reject-16]
+       _ = x[AllowedFast-17]
+       _ = x[Extended-20]
+       _ = x[HashRequest-21]
+       _ = x[Hashes-22]
+       _ = x[HashReject-23]
+}
+
 const (
        _MessageType_name_0 = "ChokeUnchokeInterestedNotInterestedHaveBitfieldRequestPieceCancelPort"
        _MessageType_name_1 = "SuggestHaveAllHaveNoneRejectAllowedFast"
-       _MessageType_name_2 = "Extended"
+       _MessageType_name_2 = "ExtendedHashRequestHashesHashReject"
 )
 
 var (
        _MessageType_index_0 = [...]uint8{0, 5, 12, 22, 35, 39, 47, 54, 59, 65, 69}
        _MessageType_index_1 = [...]uint8{0, 7, 14, 22, 28, 39}
+       _MessageType_index_2 = [...]uint8{0, 8, 19, 25, 35}
 )
 
 func (i MessageType) String() string {
@@ -22,8 +48,9 @@ func (i MessageType) String() string {
        case 13 <= i && i <= 17:
                i -= 13
                return _MessageType_name_1[_MessageType_index_1[i]:_MessageType_index_1[i+1]]
-       case i == 20:
-               return _MessageType_name_2
+       case 20 <= i && i <= 23:
+               i -= 20
+               return _MessageType_name_2[_MessageType_index_2[i]:_MessageType_index_2[i+1]]
        default:
                return "MessageType(" + strconv.FormatInt(int64(i), 10) + ")"
        }
index 2f92ab324d2315abd7b8d483767a9b0c1080ab70..b1790d1c1038ee59258f7b09f0ac333549174100 100644 (file)
@@ -6,6 +6,7 @@ const (
 
 type MessageType byte
 
+// golang.org/x/tools/cmd/stringer
 //go:generate stringer -type=MessageType
 
 func (mt MessageType) FastExtension() bool {
@@ -43,9 +44,9 @@ const (
        Extended MessageType = 20
 
        // BEP 52
-       HashRequest = 21
-       Hashes      = 22
-       HashReject  = 23
+       HashRequest MessageType = 21
+       Hashes      MessageType = 22
+       HashReject  MessageType = 23
 )
 
 const (
diff --git a/v2hashes.go b/v2hashes.go
new file mode 100644 (file)
index 0000000..10cbafc
--- /dev/null
@@ -0,0 +1 @@
+package torrent