]> Sergey Matveev's repositories - btrtrc.git/blob - bencode/encode.go
57f0b80513189e957bfff416d18d93fce4b5f63e
[btrtrc.git] / bencode / encode.go
1 package bencode
2
3 import (
4         "io"
5         "reflect"
6         "runtime"
7         "sort"
8         "strconv"
9         "sync"
10
11         "github.com/anacrolix/missinggo"
12 )
13
14 func isEmptyValue(v reflect.Value) bool {
15         return missinggo.IsEmptyValue(v)
16 }
17
// Encoder writes bencoded values to an underlying buffered writer.
type Encoder struct {
	// w is the destination. Encode calls Flush on it after a successful
	// encoding.
	w interface {
		io.Writer
		WriteString(string) (int, error)
		Flush() error
	}
	// scratch is a reusable buffer handed to strconv.Append* when formatting
	// integers and string length prefixes.
	scratch [64]byte
}
26
27 func (e *Encoder) Encode(v interface{}) (err error) {
28         if v == nil {
29                 return
30         }
31         defer func() {
32                 if e := recover(); e != nil {
33                         if _, ok := e.(runtime.Error); ok {
34                                 panic(e)
35                         }
36                         var ok bool
37                         err, ok = e.(error)
38                         if !ok {
39                                 panic(e)
40                         }
41                 }
42         }()
43         e.reflectValue(reflect.ValueOf(v))
44         return e.w.Flush()
45 }
46
// string_values implements sort.Interface over reflect.Values, ordering them
// by the result of Value.String (plain Go string comparison). It is meant for
// string-kinded map keys, so dictionary entries can be emitted in key order.
type string_values []reflect.Value

func (vs string_values) Len() int {
	return len(vs)
}

func (vs string_values) Swap(i, j int) {
	tmp := vs[i]
	vs[i] = vs[j]
	vs[j] = tmp
}

func (vs string_values) Less(i, j int) bool {
	left, right := vs.get(i), vs.get(j)
	return left < right
}

// get returns the string contents of the i-th value.
func (vs string_values) get(i int) string {
	return vs[i].String()
}
53
54 func (e *Encoder) write(s []byte) {
55         _, err := e.w.Write(s)
56         if err != nil {
57                 panic(err)
58         }
59 }
60
61 func (e *Encoder) writeString(s string) {
62         _, err := e.w.WriteString(s)
63         if err != nil {
64                 panic(err)
65         }
66 }
67
68 func (e *Encoder) reflectString(s string) {
69         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
70         e.write(b)
71         e.writeString(":")
72         e.writeString(s)
73 }
74
75 func (e *Encoder) reflectByteSlice(s []byte) {
76         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
77         e.write(b)
78         e.writeString(":")
79         e.write(s)
80 }
81
82 // returns true if the value implements Marshaler interface and marshaling was
83 // done successfully
84 func (e *Encoder) reflectMarshaler(v reflect.Value) bool {
85         m, ok := v.Interface().(Marshaler)
86         if !ok {
87                 // T doesn't work, try *T
88                 if v.Kind() != reflect.Ptr && v.CanAddr() {
89                         m, ok = v.Addr().Interface().(Marshaler)
90                         if ok {
91                                 v = v.Addr()
92                         }
93                 }
94         }
95         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
96                 data, err := m.MarshalBencode()
97                 if err != nil {
98                         panic(&MarshalerError{v.Type(), err})
99                 }
100                 e.write(data)
101                 return true
102         }
103
104         return false
105 }
106
107 func (e *Encoder) reflectValue(v reflect.Value) {
108
109         if e.reflectMarshaler(v) {
110                 return
111         }
112
113         switch v.Kind() {
114         case reflect.Bool:
115                 if v.Bool() {
116                         e.writeString("i1e")
117                 } else {
118                         e.writeString("i0e")
119                 }
120         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
121                 b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
122                 e.writeString("i")
123                 e.write(b)
124                 e.writeString("e")
125         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
126                 b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
127                 e.writeString("i")
128                 e.write(b)
129                 e.writeString("e")
130         case reflect.String:
131                 e.reflectString(v.String())
132         case reflect.Struct:
133                 e.writeString("d")
134                 for _, ef := range encodeFields(v.Type()) {
135                         field_value := v.Field(ef.i)
136                         if ef.omit_empty && isEmptyValue(field_value) {
137                                 continue
138                         }
139                         e.reflectString(ef.tag)
140                         e.reflectValue(field_value)
141                 }
142                 e.writeString("e")
143         case reflect.Map:
144                 if v.Type().Key().Kind() != reflect.String {
145                         panic(&MarshalTypeError{v.Type()})
146                 }
147                 if v.IsNil() {
148                         e.writeString("de")
149                         break
150                 }
151                 e.writeString("d")
152                 sv := string_values(v.MapKeys())
153                 sort.Sort(sv)
154                 for _, key := range sv {
155                         e.reflectString(key.String())
156                         e.reflectValue(v.MapIndex(key))
157                 }
158                 e.writeString("e")
159         case reflect.Slice:
160                 if v.IsNil() {
161                         e.writeString("le")
162                         break
163                 }
164                 if v.Type().Elem().Kind() == reflect.Uint8 {
165                         s := v.Bytes()
166                         e.reflectByteSlice(s)
167                         break
168                 }
169                 fallthrough
170         case reflect.Array:
171                 e.writeString("l")
172                 for i, n := 0, v.Len(); i < n; i++ {
173                         e.reflectValue(v.Index(i))
174                 }
175                 e.writeString("e")
176         case reflect.Interface:
177                 e.reflectValue(v.Elem())
178         case reflect.Ptr:
179                 if v.IsNil() {
180                         v = reflect.Zero(v.Type().Elem())
181                 } else {
182                         v = v.Elem()
183                 }
184                 e.reflectValue(v)
185         default:
186                 panic(&MarshalTypeError{v.Type()})
187         }
188 }
189
// encodeField describes one struct field selected by encodeFields for
// encoding.
type encodeField struct {
	i          int    // index of the field within its struct type
	tag        string // dictionary key written for the field
	omit_empty bool   // skip the field when isEmptyValue reports it empty
}

// encodeFieldsSortType orders encodeField values by their dictionary key.
type encodeFieldsSortType []encodeField

func (fs encodeFieldsSortType) Len() int {
	return len(fs)
}

func (fs encodeFieldsSortType) Swap(a, b int) {
	fs[a], fs[b] = fs[b], fs[a]
}

func (fs encodeFieldsSortType) Less(a, b int) bool {
	return fs[a].tag < fs[b].tag
}
201
202 var (
203         typeCacheLock     sync.RWMutex
204         encodeFieldsCache = make(map[reflect.Type][]encodeField)
205 )
206
207 func encodeFields(t reflect.Type) []encodeField {
208         typeCacheLock.RLock()
209         fs, ok := encodeFieldsCache[t]
210         typeCacheLock.RUnlock()
211         if ok {
212                 return fs
213         }
214
215         typeCacheLock.Lock()
216         defer typeCacheLock.Unlock()
217         fs, ok = encodeFieldsCache[t]
218         if ok {
219                 return fs
220         }
221
222         for i, n := 0, t.NumField(); i < n; i++ {
223                 f := t.Field(i)
224                 if f.PkgPath != "" {
225                         continue
226                 }
227                 if f.Anonymous {
228                         continue
229                 }
230                 var ef encodeField
231                 ef.i = i
232                 ef.tag = f.Name
233
234                 tv := f.Tag.Get("bencode")
235                 if tv != "" {
236                         if tv == "-" {
237                                 continue
238                         }
239                         name, opts := parseTag(tv)
240                         if name != "" {
241                                 ef.tag = name
242                         }
243                         ef.omit_empty = opts.contains("omitempty")
244                 }
245                 fs = append(fs, ef)
246         }
247         fss := encodeFieldsSortType(fs)
248         sort.Sort(fss)
249         encodeFieldsCache[t] = fs
250         return fs
251 }