Sergey Matveev's repositories - btrtrc.git/blob - bencode/encode.go
bencode: Add ignore_unmarshal_type_error tag
[btrtrc.git] / bencode / encode.go
1 package bencode
2
3 import (
4         "io"
5         "math/big"
6         "reflect"
7         "runtime"
8         "sort"
9         "strconv"
10         "sync"
11
12         "github.com/anacrolix/missinggo"
13 )
14
15 func isEmptyValue(v reflect.Value) bool {
16         return missinggo.IsEmptyValue(v)
17 }
18
// Encoder writes bencoded values to an output sink.
type Encoder struct {
	// w is the output sink. Besides plain byte writes it must accept
	// strings directly and support Flush, which Encode calls once a value
	// has been written completely.
	w interface {
		io.Writer
		WriteString(string) (int, error)
		Flush() error
	}
	// scratch is reusable space for formatting integers and length
	// prefixes with strconv.Append* without allocating.
	scratch [64]byte
}
27
28 func (e *Encoder) Encode(v interface{}) (err error) {
29         if v == nil {
30                 return
31         }
32         defer func() {
33                 if e := recover(); e != nil {
34                         if _, ok := e.(runtime.Error); ok {
35                                 panic(e)
36                         }
37                         var ok bool
38                         err, ok = e.(error)
39                         if !ok {
40                                 panic(e)
41                         }
42                 }
43         }()
44         e.reflectValue(reflect.ValueOf(v))
45         return e.w.Flush()
46 }
47
// string_values is a slice of reflect.Values that implements sort.Interface,
// ordering the elements by the result of their String method. It is used to
// sort map keys before they are written.
type string_values []reflect.Value

// Len returns the number of values.
func (sv string_values) Len() int {
	return len(sv)
}

// Swap exchanges the values at positions i and j.
func (sv string_values) Swap(i, j int) {
	sv[j], sv[i] = sv[i], sv[j]
}

// Less reports whether the string at position i sorts before the one at j.
func (sv string_values) Less(i, j int) bool {
	left, right := sv.get(i), sv.get(j)
	return left < right
}

// get returns the string form of the value at position i.
func (sv string_values) get(i int) string {
	return sv[i].String()
}
54
55 func (e *Encoder) write(s []byte) {
56         _, err := e.w.Write(s)
57         if err != nil {
58                 panic(err)
59         }
60 }
61
62 func (e *Encoder) writeString(s string) {
63         _, err := e.w.WriteString(s)
64         if err != nil {
65                 panic(err)
66         }
67 }
68
69 func (e *Encoder) reflectString(s string) {
70         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
71         e.write(b)
72         e.writeString(":")
73         e.writeString(s)
74 }
75
76 func (e *Encoder) reflectByteSlice(s []byte) {
77         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
78         e.write(b)
79         e.writeString(":")
80         e.write(s)
81 }
82
83 // returns true if the value implements Marshaler interface and marshaling was
84 // done successfully
85 func (e *Encoder) reflectMarshaler(v reflect.Value) bool {
86         m, ok := v.Interface().(Marshaler)
87         if !ok {
88                 // T doesn't work, try *T
89                 if v.Kind() != reflect.Ptr && v.CanAddr() {
90                         m, ok = v.Addr().Interface().(Marshaler)
91                         if ok {
92                                 v = v.Addr()
93                         }
94                 }
95         }
96         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
97                 data, err := m.MarshalBencode()
98                 if err != nil {
99                         panic(&MarshalerError{v.Type(), err})
100                 }
101                 e.write(data)
102                 return true
103         }
104
105         return false
106 }
107
108 func (e *Encoder) reflectValue(v reflect.Value) {
109
110         if e.reflectMarshaler(v) {
111                 return
112         }
113
114         switch t := v.Interface().(type) {
115         case big.Int:
116                 e.writeString("i")
117                 e.writeString(t.String())
118                 e.writeString("e")
119                 return
120         }
121
122         switch v.Kind() {
123         case reflect.Bool:
124                 if v.Bool() {
125                         e.writeString("i1e")
126                 } else {
127                         e.writeString("i0e")
128                 }
129         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
130                 b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
131                 e.writeString("i")
132                 e.write(b)
133                 e.writeString("e")
134         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
135                 b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
136                 e.writeString("i")
137                 e.write(b)
138                 e.writeString("e")
139         case reflect.String:
140                 e.reflectString(v.String())
141         case reflect.Struct:
142                 e.writeString("d")
143                 for _, ef := range encodeFields(v.Type()) {
144                         field_value := v.Field(ef.i)
145                         if ef.omit_empty && isEmptyValue(field_value) {
146                                 continue
147                         }
148                         e.reflectString(ef.tag)
149                         e.reflectValue(field_value)
150                 }
151                 e.writeString("e")
152         case reflect.Map:
153                 if v.Type().Key().Kind() != reflect.String {
154                         panic(&MarshalTypeError{v.Type()})
155                 }
156                 if v.IsNil() {
157                         e.writeString("de")
158                         break
159                 }
160                 e.writeString("d")
161                 sv := string_values(v.MapKeys())
162                 sort.Sort(sv)
163                 for _, key := range sv {
164                         e.reflectString(key.String())
165                         e.reflectValue(v.MapIndex(key))
166                 }
167                 e.writeString("e")
168         case reflect.Slice:
169                 if v.IsNil() {
170                         e.writeString("le")
171                         break
172                 }
173                 if v.Type().Elem().Kind() == reflect.Uint8 {
174                         s := v.Bytes()
175                         e.reflectByteSlice(s)
176                         break
177                 }
178                 fallthrough
179         case reflect.Array:
180                 e.writeString("l")
181                 for i, n := 0, v.Len(); i < n; i++ {
182                         e.reflectValue(v.Index(i))
183                 }
184                 e.writeString("e")
185         case reflect.Interface:
186                 e.reflectValue(v.Elem())
187         case reflect.Ptr:
188                 if v.IsNil() {
189                         v = reflect.Zero(v.Type().Elem())
190                 } else {
191                         v = v.Elem()
192                 }
193                 e.reflectValue(v)
194         default:
195                 panic(&MarshalTypeError{v.Type()})
196         }
197 }
198
// encodeField describes one struct field selected for encoding.
type encodeField struct {
	// i is the field's index within its struct type, as used with
	// reflect.Value.Field.
	i int
	// tag is the dictionary key the field is written under.
	tag string
	// omit_empty marks a field that is skipped when its value is empty.
	omit_empty bool
}

// encodeFieldsSortType implements sort.Interface over a slice of encodeField,
// ordering the entries by ascending tag.
type encodeFieldsSortType []encodeField

// Len returns the number of fields.
func (ef encodeFieldsSortType) Len() int {
	return len(ef)
}

// Swap exchanges the fields at positions i and j.
func (ef encodeFieldsSortType) Swap(i, j int) {
	ef[j], ef[i] = ef[i], ef[j]
}

// Less reports whether the tag at position i sorts before the tag at j.
func (ef encodeFieldsSortType) Less(i, j int) bool {
	left, right := ef[i].tag, ef[j].tag
	return left < right
}
210
211 var (
212         typeCacheLock     sync.RWMutex
213         encodeFieldsCache = make(map[reflect.Type][]encodeField)
214 )
215
216 func encodeFields(t reflect.Type) []encodeField {
217         typeCacheLock.RLock()
218         fs, ok := encodeFieldsCache[t]
219         typeCacheLock.RUnlock()
220         if ok {
221                 return fs
222         }
223
224         typeCacheLock.Lock()
225         defer typeCacheLock.Unlock()
226         fs, ok = encodeFieldsCache[t]
227         if ok {
228                 return fs
229         }
230
231         for i, n := 0, t.NumField(); i < n; i++ {
232                 f := t.Field(i)
233                 if f.PkgPath != "" {
234                         continue
235                 }
236                 if f.Anonymous {
237                         continue
238                 }
239                 var ef encodeField
240                 ef.i = i
241                 ef.tag = f.Name
242
243                 tv := getTag(f.Tag)
244                 if tv.Ignore() {
245                         continue
246                 }
247                 if tv.Key() != "" {
248                         ef.tag = tv.Key()
249                 }
250                 ef.omit_empty = tv.OmitEmpty()
251                 fs = append(fs, ef)
252         }
253         fss := encodeFieldsSortType(fs)
254         sort.Sort(fss)
255         encodeFieldsCache[t] = fs
256         return fs
257 }