]> Sergey Matveev's repositories - btrtrc.git/blobdiff - bencode/encode.go
Drop support for go 1.20
[btrtrc.git] / bencode / encode.go
index 196aa09539b944e8ba3502bc722bce4c2c58eae2..5e80cb16f5c60cc7c5c720cbd04e49479f3bcd33 100644 (file)
 package bencode
 
-import "bufio"
-import "reflect"
-import "runtime"
-import "strconv"
-import "sync"
-import "sort"
-
-func is_empty_value(v reflect.Value) bool {
-       switch v.Kind() {
-       case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
-               return v.Len() == 0
-       case reflect.Bool:
-               return !v.Bool()
-       case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-               return v.Int() == 0
-       case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
-               return v.Uint() == 0
-       case reflect.Float32, reflect.Float64:
-               return v.Float() == 0
-       case reflect.Interface, reflect.Ptr:
-               return v.IsNil()
-       }
-       return false
+import (
+       "io"
+       "math/big"
+       "reflect"
+       "runtime"
+       "sort"
+       "strconv"
+       "sync"
+
+       "github.com/anacrolix/missinggo"
+)
+
+func isEmptyValue(v reflect.Value) bool {
+       return missinggo.IsEmptyValue(v)
 }
 
-type encoder struct {
-       *bufio.Writer
+type Encoder struct {
+       w       io.Writer
        scratch [64]byte
 }
 
-func (e *encoder) encode(v interface{}) (err error) {
+func (e *Encoder) Encode(v interface{}) (err error) {
+       if v == nil {
+               return
+       }
        defer func() {
                if e := recover(); e != nil {
                        if _, ok := e.(runtime.Error); ok {
                                panic(e)
                        }
-                       err = e.(error)
+                       var ok bool
+                       err, ok = e.(error)
+                       if !ok {
+                               panic(e)
+                       }
                }
        }()
-       e.reflect_value(reflect.ValueOf(v))
-       return e.Flush()
+       e.reflectValue(reflect.ValueOf(v))
+       return nil
 }
 
-type string_values []reflect.Value
+type stringValues []reflect.Value
 
-func (sv string_values) Len() int           { return len(sv) }
-func (sv string_values) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
-func (sv string_values) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
-func (sv string_values) get(i int) string   { return sv[i].String() }
+func (sv stringValues) Len() int           { return len(sv) }
+func (sv stringValues) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
+func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
+func (sv stringValues) get(i int) string   { return sv[i].String() }
 
-func (e *encoder) write(s []byte) {
-       _, err := e.Write(s)
+func (e *Encoder) write(s []byte) {
+       _, err := e.w.Write(s)
        if err != nil {
                panic(err)
        }
 }
 
-func (e *encoder) write_string(s string) {
-       _, err := e.WriteString(s)
-       if err != nil {
-               panic(err)
+func (e *Encoder) writeString(s string) {
+       for s != "" {
+               n := copy(e.scratch[:], s)
+               s = s[n:]
+               e.write(e.scratch[:n])
        }
 }
 
-func (e *encoder) reflect_string(s string) {
-       b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
-       e.write(b)
-       e.write_string(":")
-       e.write_string(s)
+func (e *Encoder) reflectString(s string) {
+       e.writeStringPrefix(int64(len(s)))
+       e.writeString(s)
 }
 
-func (e *encoder) reflect_byte_slice(s []byte) {
-       b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
+func (e *Encoder) writeStringPrefix(l int64) {
+       b := strconv.AppendInt(e.scratch[:0], l, 10)
        e.write(b)
-       e.write_string(":")
+       e.writeString(":")
+}
+
+func (e *Encoder) reflectByteSlice(s []byte) {
+       e.writeStringPrefix(int64(len(s)))
        e.write(s)
 }
 
-// returns true if the value implements Marshaler interface and marshaling was
-// done successfully
-func (e *encoder) reflect_marshaler(v reflect.Value) bool {
-       m, ok := v.Interface().(Marshaler)
-       if !ok {
-               // T doesn't work, try *T
-               if v.Kind() != reflect.Ptr && v.CanAddr() {
-                       m, ok = v.Addr().Interface().(Marshaler)
-                       if ok {
-                               v = v.Addr()
-                       }
+// Returns true if the value implements Marshaler interface and marshaling was
+// done successfully.
+func (e *Encoder) reflectMarshaler(v reflect.Value) bool {
+       if !v.Type().Implements(marshalerType) {
+               if v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(marshalerType) {
+                       v = v.Addr()
+               } else {
+                       return false
                }
        }
-       if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
-               data, err := m.MarshalBencode()
-               if err != nil {
-                       panic(&MarshalerError{v.Type(), err})
-               }
-               e.write(data)
-               return true
+       m := v.Interface().(Marshaler)
+       data, err := m.MarshalBencode()
+       if err != nil {
+               panic(&MarshalerError{v.Type(), err})
        }
-
-       return false
+       e.write(data)
+       return true
 }
 
-func (e *encoder) reflect_value(v reflect.Value) {
-       if !v.IsValid() {
+var bigIntType = reflect.TypeOf((*big.Int)(nil)).Elem()
+
+func (e *Encoder) reflectValue(v reflect.Value) {
+       if e.reflectMarshaler(v) {
                return
        }
 
-       if e.reflect_marshaler(v) {
+       if v.Type() == bigIntType {
+               e.writeString("i")
+               bi := v.Interface().(big.Int)
+               e.writeString(bi.String())
+               e.writeString("e")
                return
        }
 
        switch v.Kind() {
        case reflect.Bool:
                if v.Bool() {
-                       e.write_string("i1e")
+                       e.writeString("i1e")
                } else {
-                       e.write_string("i0e")
+                       e.writeString("i0e")
                }
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+               e.writeString("i")
                b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
-               e.write_string("i")
                e.write(b)
-               e.write_string("e")
+               e.writeString("e")
        case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+               e.writeString("i")
                b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
-               e.write_string("i")
                e.write(b)
-               e.write_string("e")
+               e.writeString("e")
        case reflect.String:
-               e.reflect_string(v.String())
+               e.reflectString(v.String())
        case reflect.Struct:
-               e.write_string("d")
-               for _, ef := range encode_fields(v.Type()) {
-                       field_value := v.Field(ef.i)
-                       if ef.omit_empty && is_empty_value(field_value) {
+               e.writeString("d")
+               for _, ef := range getEncodeFields(v.Type()) {
+                       fieldValue := ef.i(v)
+                       if !fieldValue.IsValid() {
                                continue
                        }
-
-                       e.reflect_string(ef.tag)
-                       e.reflect_value(field_value)
+                       if ef.omitEmpty && isEmptyValue(fieldValue) {
+                               continue
+                       }
+                       e.reflectString(ef.tag)
+                       e.reflectValue(fieldValue)
                }
-               e.write_string("e")
+               e.writeString("e")
        case reflect.Map:
                if v.Type().Key().Kind() != reflect.String {
                        panic(&MarshalTypeError{v.Type()})
                }
                if v.IsNil() {
-                       e.write_string("de")
+                       e.writeString("de")
                        break
                }
-               e.write_string("d")
-               sv := string_values(v.MapKeys())
+               e.writeString("d")
+               sv := stringValues(v.MapKeys())
                sort.Sort(sv)
                for _, key := range sv {
-                       e.reflect_string(key.String())
-                       e.reflect_value(v.MapIndex(key))
+                       e.reflectString(key.String())
+                       e.reflectValue(v.MapIndex(key))
                }
-               e.write_string("e")
-       case reflect.Slice:
+               e.writeString("e")
+       case reflect.Slice, reflect.Array:
+               e.reflectSequence(v)
+       case reflect.Interface:
+               e.reflectValue(v.Elem())
+       case reflect.Ptr:
                if v.IsNil() {
-                       e.write_string("le")
-                       break
-               }
-               if v.Type().Elem().Kind() == reflect.Uint8 {
-                       s := v.Bytes()
-                       e.reflect_byte_slice(s)
-                       break
-               }
-               fallthrough
-       case reflect.Array:
-               e.write_string("l")
-               for i, n := 0, v.Len(); i < n; i++ {
-                       e.reflect_value(v.Index(i))
-               }
-               e.write_string("e")
-       case reflect.Interface, reflect.Ptr:
-               if v.IsNil() {
-                       break
+                       v = reflect.Zero(v.Type().Elem())
+               } else {
+                       v = v.Elem()
                }
-               e.reflect_value(v.Elem())
+               e.reflectValue(v)
        default:
                panic(&MarshalTypeError{v.Type()})
        }
 }
 
-type encode_field struct {
-       i          int
-       tag        string
-       omit_empty bool
+func (e *Encoder) reflectSequence(v reflect.Value) {
+       // Use bencode string-type
+       if v.Type().Elem().Kind() == reflect.Uint8 {
+               if v.Kind() != reflect.Slice {
+                       // Can't use []byte optimization
+                       if !v.CanAddr() {
+                               e.writeStringPrefix(int64(v.Len()))
+                               for i := 0; i < v.Len(); i++ {
+                                       var b [1]byte
+                                       b[0] = byte(v.Index(i).Uint())
+                                       e.write(b[:])
+                               }
+                               return
+                       }
+                       v = v.Slice(0, v.Len())
+               }
+               s := v.Bytes()
+               e.reflectByteSlice(s)
+               return
+       }
+       if v.IsNil() {
+               e.writeString("le")
+               return
+       }
+       e.writeString("l")
+       for i, n := 0, v.Len(); i < n; i++ {
+               e.reflectValue(v.Index(i))
+       }
+       e.writeString("e")
+}
+
+type encodeField struct {
+       i         func(v reflect.Value) reflect.Value
+       tag       string
+       omitEmpty bool
 }
 
-type encode_fields_sort_type []encode_field
+type encodeFieldsSortType []encodeField
 
-func (ef encode_fields_sort_type) Len() int           { return len(ef) }
-func (ef encode_fields_sort_type) Swap(i, j int)      { ef[i], ef[j] = ef[j], ef[i] }
-func (ef encode_fields_sort_type) Less(i, j int) bool { return ef[i].tag < ef[j].tag }
+func (ef encodeFieldsSortType) Len() int           { return len(ef) }
+func (ef encodeFieldsSortType) Swap(i, j int)      { ef[i], ef[j] = ef[j], ef[i] }
+func (ef encodeFieldsSortType) Less(i, j int) bool { return ef[i].tag < ef[j].tag }
 
 var (
-       type_cache_lock     sync.RWMutex
-       encode_fields_cache = make(map[reflect.Type][]encode_field)
+       typeCacheLock     sync.RWMutex
+       encodeFieldsCache = make(map[reflect.Type][]encodeField)
 )
 
-func encode_fields(t reflect.Type) []encode_field {
-       type_cache_lock.RLock()
-       fs, ok := encode_fields_cache[t]
-       type_cache_lock.RUnlock()
-       if ok {
-               return fs
-       }
-
-       type_cache_lock.Lock()
-       defer type_cache_lock.Unlock()
-       fs, ok = encode_fields_cache[t]
+func getEncodeFields(t reflect.Type) []encodeField {
+       typeCacheLock.RLock()
+       fs, ok := encodeFieldsCache[t]
+       typeCacheLock.RUnlock()
        if ok {
                return fs
        }
+       fs = makeEncodeFields(t)
+       typeCacheLock.Lock()
+       defer typeCacheLock.Unlock()
+       encodeFieldsCache[t] = fs
+       return fs
+}
 
-       for i, n := 0, t.NumField(); i < n; i++ {
+func makeEncodeFields(t reflect.Type) (fs []encodeField) {
+       for _i, n := 0, t.NumField(); _i < n; _i++ {
+               i := _i
                f := t.Field(i)
                if f.PkgPath != "" {
                        continue
                }
                if f.Anonymous {
+                       t := f.Type
+                       if t.Kind() == reflect.Ptr {
+                               t = t.Elem()
+                       }
+                       anonEFs := makeEncodeFields(t)
+                       for aefi := range anonEFs {
+                               anonEF := anonEFs[aefi]
+                               bottomField := anonEF
+                               bottomField.i = func(v reflect.Value) reflect.Value {
+                                       v = v.Field(i)
+                                       if v.Kind() == reflect.Ptr {
+                                               if v.IsNil() {
+                                                       // This will skip serializing this value.
+                                                       return reflect.Value{}
+                                               }
+                                               v = v.Elem()
+                                       }
+                                       return anonEF.i(v)
+                               }
+                               fs = append(fs, bottomField)
+                       }
                        continue
                }
-               var ef encode_field
-               ef.i = i
+               var ef encodeField
+               ef.i = func(v reflect.Value) reflect.Value {
+                       return v.Field(i)
+               }
                ef.tag = f.Name
 
-               tv := f.Tag.Get("bencode")
-               if tv != "" {
-                       if tv == "-" {
-                               continue
-                       }
-                       name, opts := parse_tag(tv)
-                       ef.tag = name
-                       ef.omit_empty = opts.contains("omitempty")
+               tv := getTag(f.Tag)
+               if tv.Ignore() {
+                       continue
+               }
+               if tv.Key() != "" {
+                       ef.tag = tv.Key()
                }
+               ef.omitEmpty = tv.OmitEmpty()
                fs = append(fs, ef)
        }
-       fss := encode_fields_sort_type(fs)
+       fss := encodeFieldsSortType(fs)
        sort.Sort(fss)
-       encode_fields_cache[t] = fs
        return fs
 }