Sergey Matveev's repositories - btrtrc.git/blob - bencode/encode.go
Drop support for go 1.20
[btrtrc.git] / bencode / encode.go
1 package bencode
2
3 import (
4         "io"
5         "math/big"
6         "reflect"
7         "runtime"
8         "sort"
9         "strconv"
10         "sync"
11
12         "github.com/anacrolix/missinggo"
13 )
14
// isEmptyValue reports whether v counts as empty, delegating the decision to
// missinggo.IsEmptyValue. The struct encoder consults it to decide whether a
// field marked omitEmpty should be left out of the output.
func isEmptyValue(v reflect.Value) bool {
	return missinggo.IsEmptyValue(v)
}
18
// Encoder writes bencoded values to an underlying io.Writer.
type Encoder struct {
	// w receives every byte of encoded output.
	w io.Writer
	// scratch is a reusable buffer used to format integers and to stage
	// string data before it is handed to w.
	scratch [64]byte
}
23
24 func (e *Encoder) Encode(v interface{}) (err error) {
25         if v == nil {
26                 return
27         }
28         defer func() {
29                 if e := recover(); e != nil {
30                         if _, ok := e.(runtime.Error); ok {
31                                 panic(e)
32                         }
33                         var ok bool
34                         err, ok = e.(error)
35                         if !ok {
36                                 panic(e)
37                         }
38                 }
39         }()
40         e.reflectValue(reflect.ValueOf(v))
41         return nil
42 }
43
// stringValues is a sort.Interface over reflect.Values, ordered by comparing
// the strings the values hold.
type stringValues []reflect.Value

// Len returns the number of values in the collection.
func (sv stringValues) Len() int {
	return len(sv)
}

// Swap exchanges the values at positions i and j.
func (sv stringValues) Swap(i, j int) {
	sv[j], sv[i] = sv[i], sv[j]
}

// Less reports whether the string at position i sorts before the one at j.
func (sv stringValues) Less(i, j int) bool {
	left, right := sv.get(i), sv.get(j)
	return left < right
}

// get returns the string held by the value at position i.
func (sv stringValues) get(i int) string {
	return sv[i].String()
}
50
51 func (e *Encoder) write(s []byte) {
52         _, err := e.w.Write(s)
53         if err != nil {
54                 panic(err)
55         }
56 }
57
58 func (e *Encoder) writeString(s string) {
59         for s != "" {
60                 n := copy(e.scratch[:], s)
61                 s = s[n:]
62                 e.write(e.scratch[:n])
63         }
64 }
65
66 func (e *Encoder) reflectString(s string) {
67         e.writeStringPrefix(int64(len(s)))
68         e.writeString(s)
69 }
70
71 func (e *Encoder) writeStringPrefix(l int64) {
72         b := strconv.AppendInt(e.scratch[:0], l, 10)
73         e.write(b)
74         e.writeString(":")
75 }
76
77 func (e *Encoder) reflectByteSlice(s []byte) {
78         e.writeStringPrefix(int64(len(s)))
79         e.write(s)
80 }
81
82 // Returns true if the value implements Marshaler interface and marshaling was
83 // done successfully.
84 func (e *Encoder) reflectMarshaler(v reflect.Value) bool {
85         if !v.Type().Implements(marshalerType) {
86                 if v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(marshalerType) {
87                         v = v.Addr()
88                 } else {
89                         return false
90                 }
91         }
92         m := v.Interface().(Marshaler)
93         data, err := m.MarshalBencode()
94         if err != nil {
95                 panic(&MarshalerError{v.Type(), err})
96         }
97         e.write(data)
98         return true
99 }
100
// bigIntType is the reflect.Type of big.Int itself (not *big.Int). The encoder
// compares value types against it to emit big integers as bencode integers.
var bigIntType = reflect.TypeOf((*big.Int)(nil)).Elem()
102
103 func (e *Encoder) reflectValue(v reflect.Value) {
104         if e.reflectMarshaler(v) {
105                 return
106         }
107
108         if v.Type() == bigIntType {
109                 e.writeString("i")
110                 bi := v.Interface().(big.Int)
111                 e.writeString(bi.String())
112                 e.writeString("e")
113                 return
114         }
115
116         switch v.Kind() {
117         case reflect.Bool:
118                 if v.Bool() {
119                         e.writeString("i1e")
120                 } else {
121                         e.writeString("i0e")
122                 }
123         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
124                 e.writeString("i")
125                 b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
126                 e.write(b)
127                 e.writeString("e")
128         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
129                 e.writeString("i")
130                 b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
131                 e.write(b)
132                 e.writeString("e")
133         case reflect.String:
134                 e.reflectString(v.String())
135         case reflect.Struct:
136                 e.writeString("d")
137                 for _, ef := range getEncodeFields(v.Type()) {
138                         fieldValue := ef.i(v)
139                         if !fieldValue.IsValid() {
140                                 continue
141                         }
142                         if ef.omitEmpty && isEmptyValue(fieldValue) {
143                                 continue
144                         }
145                         e.reflectString(ef.tag)
146                         e.reflectValue(fieldValue)
147                 }
148                 e.writeString("e")
149         case reflect.Map:
150                 if v.Type().Key().Kind() != reflect.String {
151                         panic(&MarshalTypeError{v.Type()})
152                 }
153                 if v.IsNil() {
154                         e.writeString("de")
155                         break
156                 }
157                 e.writeString("d")
158                 sv := stringValues(v.MapKeys())
159                 sort.Sort(sv)
160                 for _, key := range sv {
161                         e.reflectString(key.String())
162                         e.reflectValue(v.MapIndex(key))
163                 }
164                 e.writeString("e")
165         case reflect.Slice, reflect.Array:
166                 e.reflectSequence(v)
167         case reflect.Interface:
168                 e.reflectValue(v.Elem())
169         case reflect.Ptr:
170                 if v.IsNil() {
171                         v = reflect.Zero(v.Type().Elem())
172                 } else {
173                         v = v.Elem()
174                 }
175                 e.reflectValue(v)
176         default:
177                 panic(&MarshalTypeError{v.Type()})
178         }
179 }
180
181 func (e *Encoder) reflectSequence(v reflect.Value) {
182         // Use bencode string-type
183         if v.Type().Elem().Kind() == reflect.Uint8 {
184                 if v.Kind() != reflect.Slice {
185                         // Can't use []byte optimization
186                         if !v.CanAddr() {
187                                 e.writeStringPrefix(int64(v.Len()))
188                                 for i := 0; i < v.Len(); i++ {
189                                         var b [1]byte
190                                         b[0] = byte(v.Index(i).Uint())
191                                         e.write(b[:])
192                                 }
193                                 return
194                         }
195                         v = v.Slice(0, v.Len())
196                 }
197                 s := v.Bytes()
198                 e.reflectByteSlice(s)
199                 return
200         }
201         if v.IsNil() {
202                 e.writeString("le")
203                 return
204         }
205         e.writeString("l")
206         for i, n := 0, v.Len(); i < n; i++ {
207                 e.reflectValue(v.Index(i))
208         }
209         e.writeString("e")
210 }
211
// encodeField describes how one struct field is emitted as a dictionary entry.
type encodeField struct {
	// i extracts this field's value from a value of the owning struct type.
	// It may return an invalid reflect.Value, which tells the encoder to skip
	// the field.
	i func(v reflect.Value) reflect.Value
	// tag is the dictionary key written for the field.
	tag string
	// omitEmpty requests that the field be left out when its value is empty.
	omitEmpty bool
}
217
218 type encodeFieldsSortType []encodeField
219
220 func (ef encodeFieldsSortType) Len() int           { return len(ef) }
221 func (ef encodeFieldsSortType) Swap(i, j int)      { ef[i], ef[j] = ef[j], ef[i] }
222 func (ef encodeFieldsSortType) Less(i, j int) bool { return ef[i].tag < ef[j].tag }
223
var (
	// typeCacheLock guards encodeFieldsCache.
	typeCacheLock sync.RWMutex
	// encodeFieldsCache remembers the encode fields computed for each struct
	// type so they are not rebuilt on every encode.
	encodeFieldsCache = make(map[reflect.Type][]encodeField)
)
228
229 func getEncodeFields(t reflect.Type) []encodeField {
230         typeCacheLock.RLock()
231         fs, ok := encodeFieldsCache[t]
232         typeCacheLock.RUnlock()
233         if ok {
234                 return fs
235         }
236         fs = makeEncodeFields(t)
237         typeCacheLock.Lock()
238         defer typeCacheLock.Unlock()
239         encodeFieldsCache[t] = fs
240         return fs
241 }
242
243 func makeEncodeFields(t reflect.Type) (fs []encodeField) {
244         for _i, n := 0, t.NumField(); _i < n; _i++ {
245                 i := _i
246                 f := t.Field(i)
247                 if f.PkgPath != "" {
248                         continue
249                 }
250                 if f.Anonymous {
251                         t := f.Type
252                         if t.Kind() == reflect.Ptr {
253                                 t = t.Elem()
254                         }
255                         anonEFs := makeEncodeFields(t)
256                         for aefi := range anonEFs {
257                                 anonEF := anonEFs[aefi]
258                                 bottomField := anonEF
259                                 bottomField.i = func(v reflect.Value) reflect.Value {
260                                         v = v.Field(i)
261                                         if v.Kind() == reflect.Ptr {
262                                                 if v.IsNil() {
263                                                         // This will skip serializing this value.
264                                                         return reflect.Value{}
265                                                 }
266                                                 v = v.Elem()
267                                         }
268                                         return anonEF.i(v)
269                                 }
270                                 fs = append(fs, bottomField)
271                         }
272                         continue
273                 }
274                 var ef encodeField
275                 ef.i = func(v reflect.Value) reflect.Value {
276                         return v.Field(i)
277                 }
278                 ef.tag = f.Name
279
280                 tv := getTag(f.Tag)
281                 if tv.Ignore() {
282                         continue
283                 }
284                 if tv.Key() != "" {
285                         ef.tag = tv.Key()
286                 }
287                 ef.omitEmpty = tv.OmitEmpty()
288                 fs = append(fs, ef)
289         }
290         fss := encodeFieldsSortType(fs)
291         sort.Sort(fss)
292         return fs
293 }