]> Sergey Matveev's repositories - btrtrc.git/blobdiff - bencode/encode.go
Drop support for go 1.20
[btrtrc.git] / bencode / encode.go
index 8542d430ce33b7da28a4352d7b42065d680a6c41..5e80cb16f5c60cc7c5c720cbd04e49479f3bcd33 100644 (file)
@@ -41,12 +41,12 @@ func (e *Encoder) Encode(v interface{}) (err error) {
        return nil
 }
 
-type string_values []reflect.Value
+type stringValues []reflect.Value
 
-func (sv string_values) Len() int           { return len(sv) }
-func (sv string_values) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
-func (sv string_values) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
-func (sv string_values) get(i int) string   { return sv[i].String() }
+func (sv stringValues) Len() int           { return len(sv) }
+func (sv stringValues) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
+func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
+func (sv stringValues) get(i int) string   { return sv[i].String() }
 
 func (e *Encoder) write(s []byte) {
        _, err := e.w.Write(s)
@@ -64,16 +64,18 @@ func (e *Encoder) writeString(s string) {
 }
 
 func (e *Encoder) reflectString(s string) {
-       b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
-       e.write(b)
-       e.writeString(":")
+       e.writeStringPrefix(int64(len(s)))
        e.writeString(s)
 }
 
-func (e *Encoder) reflectByteSlice(s []byte) {
-       b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
+func (e *Encoder) writeStringPrefix(l int64) {
+       b := strconv.AppendInt(e.scratch[:0], l, 10)
        e.write(b)
        e.writeString(":")
+}
+
+func (e *Encoder) reflectByteSlice(s []byte) {
+       e.writeStringPrefix(int64(len(s)))
        e.write(s)
 }
 
@@ -96,10 +98,9 @@ func (e *Encoder) reflectMarshaler(v reflect.Value) bool {
        return true
 }
 
-var bigIntType = reflect.TypeOf(big.Int{})
+var bigIntType = reflect.TypeOf((*big.Int)(nil)).Elem()
 
 func (e *Encoder) reflectValue(v reflect.Value) {
-
        if e.reflectMarshaler(v) {
                return
        }
@@ -133,13 +134,16 @@ func (e *Encoder) reflectValue(v reflect.Value) {
                e.reflectString(v.String())
        case reflect.Struct:
                e.writeString("d")
-               for _, ef := range encodeFields(v.Type()) {
-                       field_value := v.Field(ef.i)
-                       if ef.omit_empty && isEmptyValue(field_value) {
+               for _, ef := range getEncodeFields(v.Type()) {
+                       fieldValue := ef.i(v)
+                       if !fieldValue.IsValid() {
+                               continue
+                       }
+                       if ef.omitEmpty && isEmptyValue(fieldValue) {
                                continue
                        }
                        e.reflectString(ef.tag)
-                       e.reflectValue(field_value)
+                       e.reflectValue(fieldValue)
                }
                e.writeString("e")
        case reflect.Map:
@@ -151,30 +155,15 @@ func (e *Encoder) reflectValue(v reflect.Value) {
                        break
                }
                e.writeString("d")
-               sv := string_values(v.MapKeys())
+               sv := stringValues(v.MapKeys())
                sort.Sort(sv)
                for _, key := range sv {
                        e.reflectString(key.String())
                        e.reflectValue(v.MapIndex(key))
                }
                e.writeString("e")
-       case reflect.Slice:
-               if v.IsNil() {
-                       e.writeString("le")
-                       break
-               }
-               if v.Type().Elem().Kind() == reflect.Uint8 {
-                       s := v.Bytes()
-                       e.reflectByteSlice(s)
-                       break
-               }
-               fallthrough
-       case reflect.Array:
-               e.writeString("l")
-               for i, n := 0, v.Len(); i < n; i++ {
-                       e.reflectValue(v.Index(i))
-               }
-               e.writeString("e")
+       case reflect.Slice, reflect.Array:
+               e.reflectSequence(v)
        case reflect.Interface:
                e.reflectValue(v.Elem())
        case reflect.Ptr:
@@ -189,10 +178,41 @@ func (e *Encoder) reflectValue(v reflect.Value) {
        }
 }
 
+func (e *Encoder) reflectSequence(v reflect.Value) {
+       // Use bencode string-type
+       if v.Type().Elem().Kind() == reflect.Uint8 {
+               if v.Kind() != reflect.Slice {
+                       // Can't use []byte optimization
+                       if !v.CanAddr() {
+                               e.writeStringPrefix(int64(v.Len()))
+                               for i := 0; i < v.Len(); i++ {
+                                       var b [1]byte
+                                       b[0] = byte(v.Index(i).Uint())
+                                       e.write(b[:])
+                               }
+                               return
+                       }
+                       v = v.Slice(0, v.Len())
+               }
+               s := v.Bytes()
+               e.reflectByteSlice(s)
+               return
+       }
+       if v.IsNil() {
+               e.writeString("le")
+               return
+       }
+       e.writeString("l")
+       for i, n := 0, v.Len(); i < n; i++ {
+               e.reflectValue(v.Index(i))
+       }
+       e.writeString("e")
+}
+
 type encodeField struct {
-       i          int
-       tag        string
-       omit_empty bool
+       i         func(v reflect.Value) reflect.Value
+       tag       string
+       omitEmpty bool
 }
 
 type encodeFieldsSortType []encodeField
@@ -206,31 +226,55 @@ var (
        encodeFieldsCache = make(map[reflect.Type][]encodeField)
 )
 
-func encodeFields(t reflect.Type) []encodeField {
+func getEncodeFields(t reflect.Type) []encodeField {
        typeCacheLock.RLock()
        fs, ok := encodeFieldsCache[t]
        typeCacheLock.RUnlock()
        if ok {
                return fs
        }
-
+       fs = makeEncodeFields(t)
        typeCacheLock.Lock()
        defer typeCacheLock.Unlock()
-       fs, ok = encodeFieldsCache[t]
-       if ok {
-               return fs
-       }
+       encodeFieldsCache[t] = fs
+       return fs
+}
 
-       for i, n := 0, t.NumField(); i < n; i++ {
+func makeEncodeFields(t reflect.Type) (fs []encodeField) {
+       for _i, n := 0, t.NumField(); _i < n; _i++ {
+               i := _i
                f := t.Field(i)
                if f.PkgPath != "" {
                        continue
                }
                if f.Anonymous {
+                       t := f.Type
+                       if t.Kind() == reflect.Ptr {
+                               t = t.Elem()
+                       }
+                       anonEFs := makeEncodeFields(t)
+                       for aefi := range anonEFs {
+                               anonEF := anonEFs[aefi]
+                               bottomField := anonEF
+                               bottomField.i = func(v reflect.Value) reflect.Value {
+                                       v = v.Field(i)
+                                       if v.Kind() == reflect.Ptr {
+                                               if v.IsNil() {
+                                                       // This will skip serializing this value.
+                                                       return reflect.Value{}
+                                               }
+                                               v = v.Elem()
+                                       }
+                                       return anonEF.i(v)
+                               }
+                               fs = append(fs, bottomField)
+                       }
                        continue
                }
                var ef encodeField
-               ef.i = i
+               ef.i = func(v reflect.Value) reflect.Value {
+                       return v.Field(i)
+               }
                ef.tag = f.Name
 
                tv := getTag(f.Tag)
@@ -240,11 +284,10 @@ func encodeFields(t reflect.Type) []encodeField {
                if tv.Key() != "" {
                        ef.tag = tv.Key()
                }
-               ef.omit_empty = tv.OmitEmpty()
+               ef.omitEmpty = tv.OmitEmpty()
                fs = append(fs, ef)
        }
        fss := encodeFieldsSortType(fs)
        sort.Sort(fss)
-       encodeFieldsCache[t] = fs
        return fs
 }