Sergey Matveev's repositories - btrtrc.git/blob - bencode/encode.go
bencode: Fix marshalling of unaddressable array of bytes
[btrtrc.git] / bencode / encode.go
1 package bencode
2
3 import (
4         "io"
5         "math/big"
6         "reflect"
7         "runtime"
8         "sort"
9         "strconv"
10         "sync"
11
12         "github.com/anacrolix/missinggo"
13 )
14
15 func isEmptyValue(v reflect.Value) bool {
16         return missinggo.IsEmptyValue(v)
17 }
18
// Encoder writes the bencoding of Go values to an underlying io.Writer.
//
// Failures inside the encoder are signalled by panicking; Encode recovers
// panics that carry an error value and returns them to the caller.
type Encoder struct {
	// w receives every byte of encoded output.
	w io.Writer
	// scratch is a reusable buffer for formatting integers and for staging
	// string data, so individual writes do not need to allocate.
	scratch [64]byte
}
23
24 func (e *Encoder) Encode(v interface{}) (err error) {
25         if v == nil {
26                 return
27         }
28         defer func() {
29                 if e := recover(); e != nil {
30                         if _, ok := e.(runtime.Error); ok {
31                                 panic(e)
32                         }
33                         var ok bool
34                         err, ok = e.(error)
35                         if !ok {
36                                 panic(e)
37                         }
38                 }
39         }()
40         e.reflectValue(reflect.ValueOf(v))
41         return nil
42 }
43
// stringValues is a sort.Interface over reflected values of string kind,
// ordering them by their string contents. It is used to emit map keys in
// sorted order.
type stringValues []reflect.Value

func (sv stringValues) Len() int { return len(sv) }

func (sv stringValues) Swap(i, j int) {
	sv[j], sv[i] = sv[i], sv[j]
}

func (sv stringValues) Less(i, j int) bool {
	left, right := sv.get(i), sv.get(j)
	return left < right
}

// get returns the string held by the i'th value.
func (sv stringValues) get(i int) string {
	return sv[i].String()
}
50
51 func (e *Encoder) write(s []byte) {
52         _, err := e.w.Write(s)
53         if err != nil {
54                 panic(err)
55         }
56 }
57
58 func (e *Encoder) writeString(s string) {
59         for s != "" {
60                 n := copy(e.scratch[:], s)
61                 s = s[n:]
62                 e.write(e.scratch[:n])
63         }
64 }
65
66 func (e *Encoder) reflectString(s string) {
67         e.writeStringPrefix(int64(len(s)))
68         e.writeString(s)
69 }
70
71 func (e *Encoder) writeStringPrefix(l int64) {
72         b := strconv.AppendInt(e.scratch[:0], l, 10)
73         e.write(b)
74         e.writeString(":")
75 }
76
77 func (e *Encoder) reflectByteSlice(s []byte) {
78         e.writeStringPrefix(int64(len(s)))
79         e.write(s)
80 }
81
82 // Returns true if the value implements Marshaler interface and marshaling was
83 // done successfully.
84 func (e *Encoder) reflectMarshaler(v reflect.Value) bool {
85         if !v.Type().Implements(marshalerType) {
86                 if v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(marshalerType) {
87                         v = v.Addr()
88                 } else {
89                         return false
90                 }
91         }
92         m := v.Interface().(Marshaler)
93         data, err := m.MarshalBencode()
94         if err != nil {
95                 panic(&MarshalerError{v.Type(), err})
96         }
97         e.write(data)
98         return true
99 }
100
101 var bigIntType = reflect.TypeOf((*big.Int)(nil)).Elem()
102
103 func (e *Encoder) reflectValue(v reflect.Value) {
104
105         if e.reflectMarshaler(v) {
106                 return
107         }
108
109         if v.Type() == bigIntType {
110                 e.writeString("i")
111                 bi := v.Interface().(big.Int)
112                 e.writeString(bi.String())
113                 e.writeString("e")
114                 return
115         }
116
117         switch v.Kind() {
118         case reflect.Bool:
119                 if v.Bool() {
120                         e.writeString("i1e")
121                 } else {
122                         e.writeString("i0e")
123                 }
124         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
125                 e.writeString("i")
126                 b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
127                 e.write(b)
128                 e.writeString("e")
129         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
130                 e.writeString("i")
131                 b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
132                 e.write(b)
133                 e.writeString("e")
134         case reflect.String:
135                 e.reflectString(v.String())
136         case reflect.Struct:
137                 e.writeString("d")
138                 for _, ef := range getEncodeFields(v.Type()) {
139                         fieldValue := ef.i(v)
140                         if !fieldValue.IsValid() {
141                                 continue
142                         }
143                         if ef.omitEmpty && isEmptyValue(fieldValue) {
144                                 continue
145                         }
146                         e.reflectString(ef.tag)
147                         e.reflectValue(fieldValue)
148                 }
149                 e.writeString("e")
150         case reflect.Map:
151                 if v.Type().Key().Kind() != reflect.String {
152                         panic(&MarshalTypeError{v.Type()})
153                 }
154                 if v.IsNil() {
155                         e.writeString("de")
156                         break
157                 }
158                 e.writeString("d")
159                 sv := stringValues(v.MapKeys())
160                 sort.Sort(sv)
161                 for _, key := range sv {
162                         e.reflectString(key.String())
163                         e.reflectValue(v.MapIndex(key))
164                 }
165                 e.writeString("e")
166         case reflect.Slice, reflect.Array:
167                 e.reflectSequence(v)
168         case reflect.Interface:
169                 e.reflectValue(v.Elem())
170         case reflect.Ptr:
171                 if v.IsNil() {
172                         v = reflect.Zero(v.Type().Elem())
173                 } else {
174                         v = v.Elem()
175                 }
176                 e.reflectValue(v)
177         default:
178                 panic(&MarshalTypeError{v.Type()})
179         }
180 }
181
182 func (e *Encoder) reflectSequence(v reflect.Value) {
183         // Use bencode string-type
184         if v.Type().Elem().Kind() == reflect.Uint8 {
185                 if v.Kind() != reflect.Slice {
186                         // Can't use []byte optimization
187                         if !v.CanAddr() {
188                                 e.writeStringPrefix(int64(v.Len()))
189                                 for i := 0; i < v.Len(); i++ {
190                                         var b [1]byte
191                                         b[0] = byte(v.Index(i).Uint())
192                                         e.write(b[:])
193                                 }
194                                 return
195                         }
196                         v = v.Slice(0, v.Len())
197                 }
198                 s := v.Bytes()
199                 e.reflectByteSlice(s)
200                 return
201         }
202         if v.IsNil() {
203                 e.writeString("le")
204                 return
205         }
206         e.writeString("l")
207         for i, n := 0, v.Len(); i < n; i++ {
208                 e.reflectValue(v.Index(i))
209         }
210         e.writeString("e")
211 }
212
// encodeField describes one struct field to emit as a dictionary entry.
type encodeField struct {
	// i extracts the field's value from the enclosing struct value. It may
	// return an invalid Value, in which case the field is skipped.
	i func(v reflect.Value) reflect.Value
	// tag is the dictionary key the field is written under.
	tag string
	// omitEmpty requests that empty values of this field be skipped.
	omitEmpty bool
}

// encodeFieldsSortType is a sort.Interface that orders fields by tag, so
// dictionary entries are emitted in key order.
type encodeFieldsSortType []encodeField

func (ef encodeFieldsSortType) Len() int { return len(ef) }

func (ef encodeFieldsSortType) Swap(i, j int) {
	tmp := ef[i]
	ef[i] = ef[j]
	ef[j] = tmp
}

func (ef encodeFieldsSortType) Less(i, j int) bool {
	return ef[j].tag > ef[i].tag
}
224
225 var (
226         typeCacheLock     sync.RWMutex
227         encodeFieldsCache = make(map[reflect.Type][]encodeField)
228 )
229
230 func getEncodeFields(t reflect.Type) []encodeField {
231         typeCacheLock.RLock()
232         fs, ok := encodeFieldsCache[t]
233         typeCacheLock.RUnlock()
234         if ok {
235                 return fs
236         }
237         fs = makeEncodeFields(t)
238         typeCacheLock.Lock()
239         defer typeCacheLock.Unlock()
240         encodeFieldsCache[t] = fs
241         return fs
242 }
243
244 func makeEncodeFields(t reflect.Type) (fs []encodeField) {
245         for _i, n := 0, t.NumField(); _i < n; _i++ {
246                 i := _i
247                 f := t.Field(i)
248                 if f.PkgPath != "" {
249                         continue
250                 }
251                 if f.Anonymous {
252                         t := f.Type
253                         if t.Kind() == reflect.Ptr {
254                                 t = t.Elem()
255                         }
256                         anonEFs := makeEncodeFields(t)
257                         for aefi := range anonEFs {
258                                 anonEF := anonEFs[aefi]
259                                 bottomField := anonEF
260                                 bottomField.i = func(v reflect.Value) reflect.Value {
261                                         v = v.Field(i)
262                                         if v.Kind() == reflect.Ptr {
263                                                 if v.IsNil() {
264                                                         // This will skip serializing this value.
265                                                         return reflect.Value{}
266                                                 }
267                                                 v = v.Elem()
268                                         }
269                                         return anonEF.i(v)
270                                 }
271                                 fs = append(fs, bottomField)
272                         }
273                         continue
274                 }
275                 var ef encodeField
276                 ef.i = func(v reflect.Value) reflect.Value {
277                         return v.Field(i)
278                 }
279                 ef.tag = f.Name
280
281                 tv := getTag(f.Tag)
282                 if tv.Ignore() {
283                         continue
284                 }
285                 if tv.Key() != "" {
286                         ef.tag = tv.Key()
287                 }
288                 ef.omitEmpty = tv.OmitEmpty()
289                 fs = append(fs, ef)
290         }
291         fss := encodeFieldsSortType(fs)
292         sort.Sort(fss)
293         return fs
294 }