From 259356ccd6d1f921f50400b41bfbd477a8757df8 Mon Sep 17 00:00:00 2001
From: Matt Joiner <anacrolix@gmail.com>
Date: Fri, 21 May 2021 23:50:29 +1000
Subject: [PATCH] Rework bencode decoding so it might support embedded structs

---
 bencode/decode.go      | 103 +++++++++++++++++++++--------------------
 bencode/decode_test.go |   3 +-
 bencode/tags.go        |   3 ++
 3 files changed, 58 insertions(+), 51 deletions(-)

diff --git a/bencode/decode.go b/bencode/decode.go
index 8b22fa73..51804614 100644
--- a/bencode/decode.go
+++ b/bencode/decode.go
@@ -205,50 +205,40 @@ func (d *Decoder) parseString(v reflect.Value) error {
 
 // Info for parsing a dict value.
 type dictField struct {
-	Value reflect.Value // Storage for the parsed value.
-	// True if field value should be parsed into Value. If false, the value
-	// should be parsed and discarded.
-	Ok                       bool
-	Set                      func() // Call this after parsing into Value.
-	IgnoreUnmarshalTypeError bool
+	Type reflect.Type
+	Get  func(value reflect.Value) func(reflect.Value)
+	Tags tag
 }
 
 // Returns specifics for parsing a dict field value.
-func getDictField(dict reflect.Value, key string) dictField {
+func getDictField(dict reflect.Type, key string) dictField {
 	// get valuev as a map value or as a struct field
 	switch dict.Kind() {
 	case reflect.Map:
-		value := reflect.New(dict.Type().Elem()).Elem()
 		return dictField{
-			Value: value,
-			Ok:    true,
-			Set: func() {
-				if dict.IsNil() {
-					dict.Set(reflect.MakeMap(dict.Type()))
+			Type: dict.Elem(),
+			Get: func(mapValue reflect.Value) func(reflect.Value) {
+				return func(value reflect.Value) {
+					if mapValue.IsNil() {
+						mapValue.Set(reflect.MakeMap(dict))
+					}
+					// Assigns the value into the map.
+					//log.Printf("map type: %v", mapValue.Type())
+					mapValue.SetMapIndex(reflect.ValueOf(key).Convert(dict.Key()), value)
 				}
-				// Assigns the value into the map.
-				dict.SetMapIndex(reflect.ValueOf(key).Convert(dict.Type().Key()), value)
 			},
 		}
 	case reflect.Struct:
-		sf, ok := getStructFieldForKey(dict.Type(), key)
-		if !ok {
-			return dictField{}
-		}
-		if sf.r.PkgPath != "" {
-			panic(&UnmarshalFieldError{
-				Key:   key,
-				Type:  dict.Type(),
-				Field: sf.r,
-			})
-		}
-		return dictField{
-			Value:                    dict.FieldByIndex(sf.r.Index),
-			Ok:                       true,
-			Set:                      func() {},
-			IgnoreUnmarshalTypeError: sf.tag.IgnoreUnmarshalTypeError(),
-		}
+		return getStructFieldForKey(dict, key)
+		//if sf.r.PkgPath != "" {
+		//	panic(&UnmarshalFieldError{
+		//		Key:   key,
+		//		Type:  dict.Type(),
+		//		Field: sf.r,
+		//	})
+		//}
 	default:
+		panic("unimplemented")
 		return dictField{}
 	}
 }
@@ -260,11 +250,12 @@ type structField struct {
 
 var (
 	structFieldsMu sync.Mutex
-	structFields   = map[reflect.Type]map[string]structField{}
+	structFields   = map[reflect.Type]map[string]dictField{}
 )
 
-func parseStructFields(struct_ reflect.Type, each func(string, structField)) {
-	for i, n := 0, struct_.NumField(); i < n; i++ {
+func parseStructFields(struct_ reflect.Type, each func(string, dictField)) {
+	for _i, n := 0, struct_.NumField(); _i < n; _i++ {
+		i := _i
 		f := struct_.Field(i)
 		if f.Anonymous {
 			continue
@@ -278,25 +269,35 @@ func parseStructFields(struct_ reflect.Type, each func(string, structField)) {
 		if key == "" {
 			key = f.Name
 		}
-		each(key, structField{f, tag})
+		each(key, dictField{f.Type, func(value reflect.Value) func(reflect.Value) {
+			return value.Field(i).Set
+		}, tag})
 	}
 }
 
 func saveStructFields(struct_ reflect.Type) {
-	m := make(map[string]structField)
-	parseStructFields(struct_, func(key string, sf structField) {
+	m := make(map[string]dictField)
+	parseStructFields(struct_, func(key string, sf dictField) {
 		m[key] = sf
 	})
 	structFields[struct_] = m
 }
 
-func getStructFieldForKey(struct_ reflect.Type, key string) (f structField, ok bool) {
+func getStructFieldForKey(struct_ reflect.Type, key string) (f dictField) {
 	structFieldsMu.Lock()
 	if _, ok := structFields[struct_]; !ok {
 		saveStructFields(struct_)
 	}
-	f, ok = structFields[struct_][key]
+	f, ok := structFields[struct_][key]
 	structFieldsMu.Unlock()
+	if !ok {
+		var discard interface{}
+		return dictField{
+			Type: reflect.TypeOf(discard),
+			Get:  func(reflect.Value) func(reflect.Value) { return func(reflect.Value) {} },
+			Tags: nil,
+		}
+	}
 	return
 }
 
@@ -314,31 +315,33 @@ func (d *Decoder) parseDict(v reflect.Value) error {
 			return nil
 		}
 
-		df := getDictField(v, keyStr)
+		df := getDictField(v.Type(), keyStr)
 
 		// now we need to actually parse it
-		if df.Ok {
-			// log.Printf("parsing ok struct field for key %q", keyStr)
-			ok, err = d.parseValue(df.Value)
-		} else {
+		if df.Type == nil {
 			// Discard the value, there's nowhere to put it.
 			var if_ interface{}
 			if_, ok = d.parseValueInterface()
 			if if_ == nil {
-				err = fmt.Errorf("error parsing value for key %q", keyStr)
+				return fmt.Errorf("error parsing value for key %q", keyStr)
 			}
+			if !ok {
+				return fmt.Errorf("missing value for key %q", keyStr)
+			}
+			continue
 		}
+		setValue := reflect.New(df.Type).Elem()
+		//log.Printf("parsing into %v", setValue.Type())
+		ok, err = d.parseValue(setValue)
 		if err != nil {
-			if _, ok := err.(*UnmarshalTypeError); !ok || !df.IgnoreUnmarshalTypeError {
+			if _, ok := err.(*UnmarshalTypeError); !ok || !df.Tags.IgnoreUnmarshalTypeError() {
 				return fmt.Errorf("parsing value for key %q: %s", keyStr, err)
 			}
 		}
 		if !ok {
 			return fmt.Errorf("missing value for key %q", keyStr)
 		}
-		if df.Ok {
-			df.Set()
-		}
+		df.Get(v)(setValue)
 	}
 }
 
diff --git a/bencode/decode_test.go b/bencode/decode_test.go
index 4b72edbb..056a399a 100644
--- a/bencode/decode_test.go
+++ b/bencode/decode_test.go
@@ -7,6 +7,7 @@ import (
 	"reflect"
 	"testing"
 
+	qt "github.com/frankban/quicktest"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -144,7 +145,7 @@ func TestIgnoreUnmarshalTypeError(t *testing.T) {
 	}{}
 	require.Error(t, Unmarshal([]byte("d6:Normal5:helloe"), &s))
 	assert.NoError(t, Unmarshal([]byte("d6:Ignore5:helloe"), &s))
-	require.Nil(t, Unmarshal([]byte("d6:Ignorei42ee"), &s))
+	qt.Assert(t, Unmarshal([]byte("d6:Ignorei42ee"), &s), qt.IsNil)
 	assert.EqualValues(t, 42, s.Ignore)
 }
 
diff --git a/bencode/tags.go b/bencode/tags.go
index 50bdc72b..d4adeb24 100644
--- a/bencode/tags.go
+++ b/bencode/tags.go
@@ -24,6 +24,9 @@ func (me tag) Key() string {
 }
 
 func (me tag) HasOpt(opt string) bool {
+	if len(me) < 1 {
+		return false
+	}
 	for _, s := range me[1:] {
 		if s == opt {
 			return true
-- 
2.51.0