Sergey Matveev's repositories - btrtrc.git/blob - bencode/encode.go
bencode: Encode arrays of bytes as strings
[btrtrc.git] / bencode / encode.go
1 package bencode
2
3 import (
4         "io"
5         "math/big"
6         "reflect"
7         "runtime"
8         "sort"
9         "strconv"
10         "sync"
11
12         "github.com/anacrolix/missinggo"
13 )
14
15 func isEmptyValue(v reflect.Value) bool {
16         return missinggo.IsEmptyValue(v)
17 }
18
// Encoder writes bencoded values to an io.Writer.
type Encoder struct {
	// w receives every byte of encoded output.
	w       io.Writer
	// scratch is reused for formatting integers and for staging string
	// chunks before they are handed to w.
	scratch [64]byte
}
23
24 func (e *Encoder) Encode(v interface{}) (err error) {
25         if v == nil {
26                 return
27         }
28         defer func() {
29                 if e := recover(); e != nil {
30                         if _, ok := e.(runtime.Error); ok {
31                                 panic(e)
32                         }
33                         var ok bool
34                         err, ok = e.(error)
35                         if !ok {
36                                 panic(e)
37                         }
38                 }
39         }()
40         e.reflectValue(reflect.ValueOf(v))
41         return nil
42 }
43
// stringValues is a sort.Interface over string-kinded reflect.Values (map
// keys). It orders them byte-wise by their string contents, which is the
// key order bencode dictionaries use.
type stringValues []reflect.Value

func (sv stringValues) Len() int { return len(sv) }

func (sv stringValues) Swap(i, j int) {
	sv[j], sv[i] = sv[i], sv[j]
}

func (sv stringValues) Less(i, j int) bool {
	a, b := sv.get(i), sv.get(j)
	return a < b
}

// get returns the string held by the i-th value.
func (sv stringValues) get(i int) string {
	return sv[i].String()
}
50
51 func (e *Encoder) write(s []byte) {
52         _, err := e.w.Write(s)
53         if err != nil {
54                 panic(err)
55         }
56 }
57
58 func (e *Encoder) writeString(s string) {
59         for s != "" {
60                 n := copy(e.scratch[:], s)
61                 s = s[n:]
62                 e.write(e.scratch[:n])
63         }
64 }
65
66 func (e *Encoder) reflectString(s string) {
67         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
68         e.write(b)
69         e.writeString(":")
70         e.writeString(s)
71 }
72
73 func (e *Encoder) reflectByteSlice(s []byte) {
74         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
75         e.write(b)
76         e.writeString(":")
77         e.write(s)
78 }
79
80 // Returns true if the value implements Marshaler interface and marshaling was
81 // done successfully.
82 func (e *Encoder) reflectMarshaler(v reflect.Value) bool {
83         if !v.Type().Implements(marshalerType) {
84                 if v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(marshalerType) {
85                         v = v.Addr()
86                 } else {
87                         return false
88                 }
89         }
90         m := v.Interface().(Marshaler)
91         data, err := m.MarshalBencode()
92         if err != nil {
93                 panic(&MarshalerError{v.Type(), err})
94         }
95         e.write(data)
96         return true
97 }
98
// bigIntType is the reflect.Type of the math/big.Int struct (not *big.Int).
var bigIntType = reflect.TypeOf(big.Int{})
100
101 func (e *Encoder) reflectValue(v reflect.Value) {
102
103         if e.reflectMarshaler(v) {
104                 return
105         }
106
107         if v.Type() == bigIntType {
108                 e.writeString("i")
109                 bi := v.Interface().(big.Int)
110                 e.writeString(bi.String())
111                 e.writeString("e")
112                 return
113         }
114
115         switch v.Kind() {
116         case reflect.Bool:
117                 if v.Bool() {
118                         e.writeString("i1e")
119                 } else {
120                         e.writeString("i0e")
121                 }
122         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
123                 e.writeString("i")
124                 b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
125                 e.write(b)
126                 e.writeString("e")
127         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
128                 e.writeString("i")
129                 b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
130                 e.write(b)
131                 e.writeString("e")
132         case reflect.String:
133                 e.reflectString(v.String())
134         case reflect.Struct:
135                 e.writeString("d")
136                 for _, ef := range getEncodeFields(v.Type()) {
137                         fieldValue := ef.i(v)
138                         if !fieldValue.IsValid() {
139                                 continue
140                         }
141                         if ef.omitEmpty && isEmptyValue(fieldValue) {
142                                 continue
143                         }
144                         e.reflectString(ef.tag)
145                         e.reflectValue(fieldValue)
146                 }
147                 e.writeString("e")
148         case reflect.Map:
149                 if v.Type().Key().Kind() != reflect.String {
150                         panic(&MarshalTypeError{v.Type()})
151                 }
152                 if v.IsNil() {
153                         e.writeString("de")
154                         break
155                 }
156                 e.writeString("d")
157                 sv := stringValues(v.MapKeys())
158                 sort.Sort(sv)
159                 for _, key := range sv {
160                         e.reflectString(key.String())
161                         e.reflectValue(v.MapIndex(key))
162                 }
163                 e.writeString("e")
164         case reflect.Slice:
165                 e.reflectSlice(v)
166         case reflect.Array:
167                 e.reflectSlice(v.Slice(0, v.Len()))
168         case reflect.Interface:
169                 e.reflectValue(v.Elem())
170         case reflect.Ptr:
171                 if v.IsNil() {
172                         v = reflect.Zero(v.Type().Elem())
173                 } else {
174                         v = v.Elem()
175                 }
176                 e.reflectValue(v)
177         default:
178                 panic(&MarshalTypeError{v.Type()})
179         }
180 }
181
182 func (e *Encoder) reflectSlice(v reflect.Value) {
183         if v.Type().Elem().Kind() == reflect.Uint8 {
184                 // This can panic if v is not addressable, such as by passing an array of bytes by value. We
185                 // could copy them and make a slice to the copy, or the user could just avoid doing this. It
186                 // remains to be seen.
187                 s := v.Bytes()
188                 e.reflectByteSlice(s)
189                 return
190         }
191         if v.IsNil() {
192                 e.writeString("le")
193                 return
194         }
195         e.writeString("l")
196         for i, n := 0, v.Len(); i < n; i++ {
197                 e.reflectValue(v.Index(i))
198         }
199         e.writeString("e")
200 }
201
// encodeField describes one dictionary entry produced for a struct field.
type encodeField struct {
	// i extracts the field's value from the enclosing struct value. It may
	// return the zero reflect.Value, which tells the encoder to skip it.
	i         func(v reflect.Value) reflect.Value
	// tag is the dictionary key written for the field.
	tag       string
	// omitEmpty causes the field to be skipped when its value is empty.
	omitEmpty bool
}

// encodeFieldsSortType is a sort.Interface ordering encodeFields by tag.
type encodeFieldsSortType []encodeField

func (fs encodeFieldsSortType) Len() int { return len(fs) }

func (fs encodeFieldsSortType) Swap(i, j int) {
	fs[j], fs[i] = fs[i], fs[j]
}

func (fs encodeFieldsSortType) Less(i, j int) bool {
	return fs[i].tag < fs[j].tag
}
213
214 var (
215         typeCacheLock     sync.RWMutex
216         encodeFieldsCache = make(map[reflect.Type][]encodeField)
217 )
218
219 func getEncodeFields(t reflect.Type) []encodeField {
220         typeCacheLock.RLock()
221         fs, ok := encodeFieldsCache[t]
222         typeCacheLock.RUnlock()
223         if ok {
224                 return fs
225         }
226         fs = makeEncodeFields(t)
227         typeCacheLock.Lock()
228         defer typeCacheLock.Unlock()
229         encodeFieldsCache[t] = fs
230         return fs
231 }
232
233 func makeEncodeFields(t reflect.Type) (fs []encodeField) {
234         for _i, n := 0, t.NumField(); _i < n; _i++ {
235                 i := _i
236                 f := t.Field(i)
237                 if f.PkgPath != "" {
238                         continue
239                 }
240                 if f.Anonymous {
241                         t := f.Type
242                         if t.Kind() == reflect.Ptr {
243                                 t = t.Elem()
244                         }
245                         anonEFs := makeEncodeFields(t)
246                         for aefi := range anonEFs {
247                                 anonEF := anonEFs[aefi]
248                                 bottomField := anonEF
249                                 bottomField.i = func(v reflect.Value) reflect.Value {
250                                         v = v.Field(i)
251                                         if v.Kind() == reflect.Ptr {
252                                                 if v.IsNil() {
253                                                         // This will skip serializing this value.
254                                                         return reflect.Value{}
255                                                 }
256                                                 v = v.Elem()
257                                         }
258                                         return anonEF.i(v)
259                                 }
260                                 fs = append(fs, bottomField)
261                         }
262                         continue
263                 }
264                 var ef encodeField
265                 ef.i = func(v reflect.Value) reflect.Value {
266                         return v.Field(i)
267                 }
268                 ef.tag = f.Name
269
270                 tv := getTag(f.Tag)
271                 if tv.Ignore() {
272                         continue
273                 }
274                 if tv.Key() != "" {
275                         ef.tag = tv.Key()
276                 }
277                 ef.omitEmpty = tv.OmitEmpty()
278                 fs = append(fs, ef)
279         }
280         fss := encodeFieldsSortType(fs)
281         sort.Sort(fss)
282         return fs
283 }