From 1b66994c0a64d2d376a64aeb9e08f5cd5a4c6d1e Mon Sep 17 00:00:00 2001
From: Matt Joiner <anacrolix@gmail.com>
Date: Thu, 30 Sep 2021 09:01:10 +1000
Subject: [PATCH] Add some fuzzing in peer_protocol

---
 peer_protocol/decoder_fuzz_test.go | 23 +++++++++++++++++++++++
 peer_protocol/extended.go          |  5 +++++
 peer_protocol/int.go               |  5 +++++
 peer_protocol/msg.go               | 21 +++++++++++++++++++++
 peer_protocol/msg_fuzz_test.go     | 18 ++++++++++++++++++
 peer_protocol/protocol.go          |  5 +++++
 6 files changed, 77 insertions(+)
 create mode 100644 peer_protocol/decoder_fuzz_test.go
 create mode 100644 peer_protocol/msg_fuzz_test.go

diff --git a/peer_protocol/decoder_fuzz_test.go b/peer_protocol/decoder_fuzz_test.go
new file mode 100644
index 00000000..0fed779c
--- /dev/null
+++ b/peer_protocol/decoder_fuzz_test.go
@@ -0,0 +1,23 @@
+package peer_protocol
+
+import (
+	"bufio"
+	"bytes"
+	"testing"
+)
+
+func FuzzDecoder(f *testing.F) {
+	f.Add([]byte("\x00\x00\x00\x00"))
+	f.Add([]byte("\x00\x00\x00\x01\x00"))
+	f.Add([]byte("\x00\x00\x00\x03\x14\x00"))
+	f.Fuzz(func(t *testing.T, b []byte) {
+		d := Decoder{
+			R: bufio.NewReader(bytes.NewReader(b)),
+		}
+		var m Message
+		err := d.Decode(&m)
+		if err != nil {
+			t.Skip(err)
+		}
+	})
+}
diff --git a/peer_protocol/extended.go b/peer_protocol/extended.go
index cbefee6a..e6d935f7 100644
--- a/peer_protocol/extended.go
+++ b/peer_protocol/extended.go
@@ -31,3 +31,8 @@ const (
 
 	ExtensionDeleteNumber ExtensionNumber = 0
 )
+
+func (me *ExtensionNumber) UnmarshalBinary(b []byte) error {
+	*me = ExtensionNumber(b[0])
+	return nil
+}
diff --git a/peer_protocol/int.go b/peer_protocol/int.go
index a0d7cf89..f0203162 100644
--- a/peer_protocol/int.go
+++ b/peer_protocol/int.go
@@ -1,6 +1,7 @@
 package peer_protocol
 
 import (
+	"bytes"
 	"encoding/binary"
 	"io"
 )
@@ -11,6 +12,10 @@ func (i *Integer) Read(r io.Reader) error {
 	return binary.Read(r, binary.BigEndian, i)
 }
 
+func (i *Integer) UnmarshalBinary(b []byte) error {
+	return i.Read(bytes.NewReader(b))
+}
+
 // It's perfectly fine to cast these to an int. TODO: Or is it?
 func (i Integer) Int() int {
 	return int(i)
diff --git a/peer_protocol/msg.go b/peer_protocol/msg.go
index c0d94e37..bc53afbb 100644
--- a/peer_protocol/msg.go
+++ b/peer_protocol/msg.go
@@ -1,7 +1,9 @@
 package peer_protocol
 
 import (
+	"bufio"
 	"bytes"
+	"encoding"
 	"encoding/binary"
 	"fmt"
 )
@@ -19,6 +21,11 @@ type Message struct {
 	Port                 uint16
 }
 
+var _ interface {
+	encoding.BinaryUnmarshaler
+	encoding.BinaryMarshaler
+} = (*Message)(nil)
+
 func MakeCancelMessage(piece, offset, length Integer) Message {
 	return Message{
 		Type:   Cancel,
@@ -116,3 +123,17 @@ func marshalBitfield(bf []bool) (b []byte) {
 	}
 	return
 }
+
+func (me *Message) UnmarshalBinary(b []byte) error {
+	d := Decoder{
+		R: bufio.NewReader(bytes.NewReader(b)),
+	}
+	err := d.Decode(me)
+	if err != nil {
+		return err
+	}
+	if d.R.Buffered() != 0 {
+		return fmt.Errorf("%d trailing bytes", d.R.Buffered())
+	}
+	return nil
+}
diff --git a/peer_protocol/msg_fuzz_test.go b/peer_protocol/msg_fuzz_test.go
new file mode 100644
index 00000000..9e214b30
--- /dev/null
+++ b/peer_protocol/msg_fuzz_test.go
@@ -0,0 +1,18 @@
+package peer_protocol
+
+import (
+	"testing"
+
+	qt "github.com/frankban/quicktest"
+)
+
+func FuzzMessageMarshalBinary(f *testing.F) {
+	f.Fuzz(func(t *testing.T, b []byte) {
+		var m Message
+		if err := m.UnmarshalBinary(b); err != nil {
+			t.Skip(err)
+		}
+		b0 := m.MustMarshalBinary()
+		qt.Assert(t, b0, qt.DeepEquals, b)
+	})
+}
diff --git a/peer_protocol/protocol.go b/peer_protocol/protocol.go
index 05c6657a..bfeb6a04 100644
--- a/peer_protocol/protocol.go
+++ b/peer_protocol/protocol.go
@@ -12,6 +12,11 @@ func (mt MessageType) FastExtension() bool {
 	return mt >= Suggest && mt <= AllowedFast
 }
 
+func (mt *MessageType) UnmarshalBinary(b []byte) error {
+	*mt = MessageType(b[0])
+	return nil
+}
+
 const (
 	// BEP 3
 	Choke         MessageType = 0
-- 
2.51.0