]> Sergey Matveev's repositories - btrtrc.git/commitdiff
bencode: Support anonymous embedded struct pointers
authorMatt Joiner <anacrolix@gmail.com>
Sat, 22 May 2021 01:02:39 +0000 (11:02 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Mon, 7 Jun 2021 03:01:40 +0000 (13:01 +1000)
More to come if this line of improvement is retained.

bencode/decode.go
bencode/encode.go

index 51804614d6a6fbc97366975110657a091c489db6..e891f5be0dc8b868d136d3003721a45729e365e8 100644 (file)
@@ -243,21 +243,27 @@ func getDictField(dict reflect.Type, key string) dictField {
        }
 }
 
-type structField struct {
-       r   reflect.StructField
-       tag tag
-}
-
 var (
        structFieldsMu sync.Mutex
        structFields   = map[reflect.Type]map[string]dictField{}
 )
 
-func parseStructFields(struct_ reflect.Type, each func(string, dictField)) {
+func parseStructFields(struct_ reflect.Type, each func(key string, df dictField)) {
        for _i, n := 0, struct_.NumField(); _i < n; _i++ {
                i := _i
                f := struct_.Field(i)
                if f.Anonymous {
+                       parseStructFields(f.Type.Elem(), func(key string, df dictField) {
+                               innerGet := df.Get
+                               df.Get = func(value reflect.Value) func(reflect.Value) {
+                                       anonPtr := value.Field(i)
+                                       if anonPtr.IsNil() {
+                                               anonPtr.Set(reflect.New(f.Type.Elem()))
+                                       }
+                                       return innerGet(anonPtr.Elem())
+                               }
+                               each(key, df)
+                       })
                        continue
                }
                tagStr := f.Tag.Get("bencode")
index 443c11e7faf8b3b488d6dd1590297a8159ed7ccd..f25cfef82bdf4ed70df4728f5f3bf119a9a1c462 100644 (file)
@@ -133,13 +133,16 @@ func (e *Encoder) reflectValue(v reflect.Value) {
                e.reflectString(v.String())
        case reflect.Struct:
                e.writeString("d")
-               for _, ef := range encodeFields(v.Type()) {
-                       field_value := v.Field(ef.i)
-                       if ef.omit_empty && isEmptyValue(field_value) {
+               for _, ef := range getEncodeFields(v.Type()) {
+                       fieldValue := ef.i(v)
+                       if !fieldValue.IsValid() {
+                               continue
+                       }
+                       if ef.omitEmpty && isEmptyValue(fieldValue) {
                                continue
                        }
                        e.reflectString(ef.tag)
-                       e.reflectValue(field_value)
+                       e.reflectValue(fieldValue)
                }
                e.writeString("e")
        case reflect.Map:
@@ -190,9 +193,9 @@ func (e *Encoder) reflectValue(v reflect.Value) {
 }
 
 type encodeField struct {
-       i          int
-       tag        string
-       omit_empty bool
+       i         func(v reflect.Value) reflect.Value
+       tag       string
+       omitEmpty bool
 }
 
 type encodeFieldsSortType []encodeField
@@ -206,31 +209,47 @@ var (
        encodeFieldsCache = make(map[reflect.Type][]encodeField)
 )
 
-func encodeFields(t reflect.Type) []encodeField {
+func getEncodeFields(t reflect.Type) []encodeField {
        typeCacheLock.RLock()
        fs, ok := encodeFieldsCache[t]
        typeCacheLock.RUnlock()
        if ok {
                return fs
        }
-
+       fs = makeEncodeFields(t)
        typeCacheLock.Lock()
        defer typeCacheLock.Unlock()
-       fs, ok = encodeFieldsCache[t]
-       if ok {
-               return fs
-       }
+       encodeFieldsCache[t] = fs
+       return fs
+}
 
-       for i, n := 0, t.NumField(); i < n; i++ {
+func makeEncodeFields(t reflect.Type) (fs []encodeField) {
+       for _i, n := 0, t.NumField(); _i < n; _i++ {
+               i := _i
                f := t.Field(i)
                if f.PkgPath != "" {
                        continue
                }
                if f.Anonymous {
+                       anonEFs := makeEncodeFields(f.Type.Elem())
+                       for aefi := range anonEFs {
+                               anonEF := anonEFs[aefi]
+                               bottomField := anonEF
+                               bottomField.i = func(v reflect.Value) reflect.Value {
+                                       v = v.Field(i)
+                                       if v.IsNil() {
+                                               return reflect.Value{}
+                                       }
+                                       return anonEF.i(v.Elem())
+                               }
+                               fs = append(fs, bottomField)
+                       }
                        continue
                }
                var ef encodeField
-               ef.i = i
+               ef.i = func(v reflect.Value) reflect.Value {
+                       return v.Field(i)
+               }
                ef.tag = f.Name
 
                tv := getTag(f.Tag)
@@ -240,11 +259,10 @@ func encodeFields(t reflect.Type) []encodeField {
                if tv.Key() != "" {
                        ef.tag = tv.Key()
                }
-               ef.omit_empty = tv.OmitEmpty()
+               ef.omitEmpty = tv.OmitEmpty()
                fs = append(fs, ef)
        }
        fss := encodeFieldsSortType(fs)
        sort.Sort(fss)
-       encodeFieldsCache[t] = fs
        return fs
 }