]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Rework bencode decoding so it might support embedded structs
authorMatt Joiner <anacrolix@gmail.com>
Fri, 21 May 2021 13:50:29 +0000 (23:50 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Mon, 7 Jun 2021 03:01:40 +0000 (13:01 +1000)
bencode/decode.go
bencode/decode_test.go
bencode/tags.go

index 8b22fa734f6f4f19fc26c6d1260d3210165b18e5..51804614d6a6fbc97366975110657a091c489db6 100644 (file)
@@ -205,50 +205,40 @@ func (d *Decoder) parseString(v reflect.Value) error {
 
 // Info for parsing a dict value.
 type dictField struct {
-       Value reflect.Value // Storage for the parsed value.
-       // True if field value should be parsed into Value. If false, the value
-       // should be parsed and discarded.
-       Ok                       bool
-       Set                      func() // Call this after parsing into Value.
-       IgnoreUnmarshalTypeError bool
+       Type reflect.Type
+       Get  func(value reflect.Value) func(reflect.Value)
+       Tags tag
 }
 
 // Returns specifics for parsing a dict field value.
-func getDictField(dict reflect.Value, key string) dictField {
+func getDictField(dict reflect.Type, key string) dictField {
        // get valuev as a map value or as a struct field
        switch dict.Kind() {
        case reflect.Map:
-               value := reflect.New(dict.Type().Elem()).Elem()
                return dictField{
-                       Value: value,
-                       Ok:    true,
-                       Set: func() {
-                               if dict.IsNil() {
-                                       dict.Set(reflect.MakeMap(dict.Type()))
+                       Type: dict.Elem(),
+                       Get: func(mapValue reflect.Value) func(reflect.Value) {
+                               return func(value reflect.Value) {
+                                       if mapValue.IsNil() {
+                                               mapValue.Set(reflect.MakeMap(dict))
+                                       }
+                                       // Assigns the value into the map.
+                                       //log.Printf("map type: %v", mapValue.Type())
+                                       mapValue.SetMapIndex(reflect.ValueOf(key).Convert(dict.Key()), value)
                                }
-                               // Assigns the value into the map.
-                               dict.SetMapIndex(reflect.ValueOf(key).Convert(dict.Type().Key()), value)
                        },
                }
        case reflect.Struct:
-               sf, ok := getStructFieldForKey(dict.Type(), key)
-               if !ok {
-                       return dictField{}
-               }
-               if sf.r.PkgPath != "" {
-                       panic(&UnmarshalFieldError{
-                               Key:   key,
-                               Type:  dict.Type(),
-                               Field: sf.r,
-                       })
-               }
-               return dictField{
-                       Value:                    dict.FieldByIndex(sf.r.Index),
-                       Ok:                       true,
-                       Set:                      func() {},
-                       IgnoreUnmarshalTypeError: sf.tag.IgnoreUnmarshalTypeError(),
-               }
+               return getStructFieldForKey(dict, key)
+               //if sf.r.PkgPath != "" {
+               //      panic(&UnmarshalFieldError{
+               //              Key:   key,
+               //              Type:  dict.Type(),
+               //              Field: sf.r,
+               //      })
+               //}
        default:
+               panic("unimplemented")
                return dictField{}
        }
 }
@@ -260,11 +250,12 @@ type structField struct {
 
 var (
        structFieldsMu sync.Mutex
-       structFields   = map[reflect.Type]map[string]structField{}
+       structFields   = map[reflect.Type]map[string]dictField{}
 )
 
-func parseStructFields(struct_ reflect.Type, each func(string, structField)) {
-       for i, n := 0, struct_.NumField(); i < n; i++ {
+func parseStructFields(struct_ reflect.Type, each func(string, dictField)) {
+       for _i, n := 0, struct_.NumField(); _i < n; _i++ {
+               i := _i
                f := struct_.Field(i)
                if f.Anonymous {
                        continue
@@ -278,25 +269,35 @@ func parseStructFields(struct_ reflect.Type, each func(string, structField)) {
                if key == "" {
                        key = f.Name
                }
-               each(key, structField{f, tag})
+               each(key, dictField{f.Type, func(value reflect.Value) func(reflect.Value) {
+                       return value.Field(i).Set
+               }, tag})
        }
 }
 
 func saveStructFields(struct_ reflect.Type) {
-       m := make(map[string]structField)
-       parseStructFields(struct_, func(key string, sf structField) {
+       m := make(map[string]dictField)
+       parseStructFields(struct_, func(key string, sf dictField) {
                m[key] = sf
        })
        structFields[struct_] = m
 }
 
-func getStructFieldForKey(struct_ reflect.Type, key string) (f structField, ok bool) {
+func getStructFieldForKey(struct_ reflect.Type, key string) (f dictField) {
        structFieldsMu.Lock()
        if _, ok := structFields[struct_]; !ok {
                saveStructFields(struct_)
        }
-       f, ok = structFields[struct_][key]
+       f, ok := structFields[struct_][key]
        structFieldsMu.Unlock()
+       if !ok {
+               var discard interface{}
+               return dictField{
+                       Type: reflect.TypeOf(discard),
+                       Get:  func(reflect.Value) func(reflect.Value) { return func(reflect.Value) {} },
+                       Tags: nil,
+               }
+       }
        return
 }
 
@@ -314,31 +315,33 @@ func (d *Decoder) parseDict(v reflect.Value) error {
                        return nil
                }
 
-               df := getDictField(v, keyStr)
+               df := getDictField(v.Type(), keyStr)
 
                // now we need to actually parse it
-               if df.Ok {
-                       // log.Printf("parsing ok struct field for key %q", keyStr)
-                       ok, err = d.parseValue(df.Value)
-               } else {
+               if df.Type == nil {
                        // Discard the value, there's nowhere to put it.
                        var if_ interface{}
                        if_, ok = d.parseValueInterface()
                        if if_ == nil {
-                               err = fmt.Errorf("error parsing value for key %q", keyStr)
+                               return fmt.Errorf("error parsing value for key %q", keyStr)
                        }
+                       if !ok {
+                               return fmt.Errorf("missing value for key %q", keyStr)
+                       }
+                       continue
                }
+               setValue := reflect.New(df.Type).Elem()
+               //log.Printf("parsing into %v", setValue.Type())
+               ok, err = d.parseValue(setValue)
                if err != nil {
-                       if _, ok := err.(*UnmarshalTypeError); !ok || !df.IgnoreUnmarshalTypeError {
+                       if _, ok := err.(*UnmarshalTypeError); !ok || !df.Tags.IgnoreUnmarshalTypeError() {
                                return fmt.Errorf("parsing value for key %q: %s", keyStr, err)
                        }
                }
                if !ok {
                        return fmt.Errorf("missing value for key %q", keyStr)
                }
-               if df.Ok {
-                       df.Set()
-               }
+               df.Get(v)(setValue)
        }
 }
 
index 4b72edbb802376fc4389b0aefc29eaf12e581664..056a399a43ed1211151f3695f58e79d66e780ba1 100644 (file)
@@ -7,6 +7,7 @@ import (
        "reflect"
        "testing"
 
+       qt "github.com/frankban/quicktest"
        "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/require"
 )
@@ -144,7 +145,7 @@ func TestIgnoreUnmarshalTypeError(t *testing.T) {
        }{}
        require.Error(t, Unmarshal([]byte("d6:Normal5:helloe"), &s))
        assert.NoError(t, Unmarshal([]byte("d6:Ignore5:helloe"), &s))
-       require.Nil(t, Unmarshal([]byte("d6:Ignorei42ee"), &s))
+       qt.Assert(t, Unmarshal([]byte("d6:Ignorei42ee"), &s), qt.IsNil)
        assert.EqualValues(t, 42, s.Ignore)
 }
 
index 50bdc72b74e1e1cd24480fd5af849066c3ff6570..d4adeb24275ce96fc7ffa9773caba56d5d9d5ca2 100644 (file)
@@ -24,6 +24,9 @@ func (me tag) Key() string {
 }
 
 func (me tag) HasOpt(opt string) bool {
+       if len(me) < 1 {
+               return false
+       }
        for _, s := range me[1:] {
                if s == opt {
                        return true