Sergey Matveev's repositories - btrtrc.git/blob - bencode/encode.go
bencode.Marshal: Get rid of the intermediate buffer
[btrtrc.git] / bencode / encode.go
1 package bencode
2
3 import (
4         "io"
5         "math/big"
6         "reflect"
7         "runtime"
8         "sort"
9         "strconv"
10         "sync"
11
12         "github.com/anacrolix/missinggo"
13 )
14
15 func isEmptyValue(v reflect.Value) bool {
16         return missinggo.IsEmptyValue(v)
17 }
18
// Encoder writes bencoded values to an underlying io.Writer.
type Encoder struct {
	// w receives every byte of encoded output.
	w io.Writer
	// scratch is reused as the destination for integer formatting and as a
	// staging area when copying strings out to w.
	scratch [64]byte
}
23
24 func (e *Encoder) Encode(v interface{}) (err error) {
25         if v == nil {
26                 return
27         }
28         defer func() {
29                 if e := recover(); e != nil {
30                         if _, ok := e.(runtime.Error); ok {
31                                 panic(e)
32                         }
33                         var ok bool
34                         err, ok = e.(error)
35                         if !ok {
36                                 panic(e)
37                         }
38                 }
39         }()
40         e.reflectValue(reflect.ValueOf(v))
41         return nil
42 }
43
// string_values is a sort.Interface over reflect.Values, ordered by the result
// of each value's String method. The encoder uses it to order map keys.
type string_values []reflect.Value

func (sv string_values) Len() int {
	return len(sv)
}

func (sv string_values) Swap(i, j int) {
	tmp := sv[i]
	sv[i] = sv[j]
	sv[j] = tmp
}

func (sv string_values) Less(i, j int) bool {
	left, right := sv.get(i), sv.get(j)
	return left < right
}

// get returns the string held by the i'th value.
func (sv string_values) get(i int) string {
	return sv[i].String()
}
50
51 func (e *Encoder) write(s []byte) {
52         _, err := e.w.Write(s)
53         if err != nil {
54                 panic(err)
55         }
56 }
57
58 func (e *Encoder) writeString(s string) {
59         for s != "" {
60                 n := copy(e.scratch[:], s)
61                 s = s[n:]
62                 e.write(e.scratch[:n])
63         }
64 }
65
66 func (e *Encoder) reflectString(s string) {
67         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
68         e.write(b)
69         e.writeString(":")
70         e.writeString(s)
71 }
72
73 func (e *Encoder) reflectByteSlice(s []byte) {
74         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
75         e.write(b)
76         e.writeString(":")
77         e.write(s)
78 }
79
80 // returns true if the value implements Marshaler interface and marshaling was
81 // done successfully
82 func (e *Encoder) reflectMarshaler(v reflect.Value) bool {
83         m, ok := v.Interface().(Marshaler)
84         if !ok {
85                 // T doesn't work, try *T
86                 if v.Kind() != reflect.Ptr && v.CanAddr() {
87                         m, ok = v.Addr().Interface().(Marshaler)
88                         if ok {
89                                 v = v.Addr()
90                         }
91                 }
92         }
93         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
94                 data, err := m.MarshalBencode()
95                 if err != nil {
96                         panic(&MarshalerError{v.Type(), err})
97                 }
98                 e.write(data)
99                 return true
100         }
101
102         return false
103 }
104
// bigIntType is the reflect.Type of math/big.Int (the struct, not the
// pointer). The encoder special-cases it as an integer.
var bigIntType = reflect.TypeOf((*big.Int)(nil)).Elem()
106
107 func (e *Encoder) reflectValue(v reflect.Value) {
108
109         if e.reflectMarshaler(v) {
110                 return
111         }
112
113         if v.Type() == bigIntType {
114                 e.writeString("i")
115                 bi := v.Interface().(big.Int)
116                 e.writeString(bi.String())
117                 e.writeString("e")
118                 return
119         }
120
121         switch v.Kind() {
122         case reflect.Bool:
123                 if v.Bool() {
124                         e.writeString("i1e")
125                 } else {
126                         e.writeString("i0e")
127                 }
128         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
129                 e.writeString("i")
130                 b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
131                 e.write(b)
132                 e.writeString("e")
133         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
134                 e.writeString("i")
135                 b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
136                 e.write(b)
137                 e.writeString("e")
138         case reflect.String:
139                 e.reflectString(v.String())
140         case reflect.Struct:
141                 e.writeString("d")
142                 for _, ef := range encodeFields(v.Type()) {
143                         field_value := v.Field(ef.i)
144                         if ef.omit_empty && isEmptyValue(field_value) {
145                                 continue
146                         }
147                         e.reflectString(ef.tag)
148                         e.reflectValue(field_value)
149                 }
150                 e.writeString("e")
151         case reflect.Map:
152                 if v.Type().Key().Kind() != reflect.String {
153                         panic(&MarshalTypeError{v.Type()})
154                 }
155                 if v.IsNil() {
156                         e.writeString("de")
157                         break
158                 }
159                 e.writeString("d")
160                 sv := string_values(v.MapKeys())
161                 sort.Sort(sv)
162                 for _, key := range sv {
163                         e.reflectString(key.String())
164                         e.reflectValue(v.MapIndex(key))
165                 }
166                 e.writeString("e")
167         case reflect.Slice:
168                 if v.IsNil() {
169                         e.writeString("le")
170                         break
171                 }
172                 if v.Type().Elem().Kind() == reflect.Uint8 {
173                         s := v.Bytes()
174                         e.reflectByteSlice(s)
175                         break
176                 }
177                 fallthrough
178         case reflect.Array:
179                 e.writeString("l")
180                 for i, n := 0, v.Len(); i < n; i++ {
181                         e.reflectValue(v.Index(i))
182                 }
183                 e.writeString("e")
184         case reflect.Interface:
185                 e.reflectValue(v.Elem())
186         case reflect.Ptr:
187                 if v.IsNil() {
188                         v = reflect.Zero(v.Type().Elem())
189                 } else {
190                         v = v.Elem()
191                 }
192                 e.reflectValue(v)
193         default:
194                 panic(&MarshalTypeError{v.Type()})
195         }
196 }
197
// encodeField describes one struct field the encoder will emit.
type encodeField struct {
	i          int    // index of the field within its struct type
	tag        string // dictionary key written for the field
	omit_empty bool   // skip the field when its value is empty
}

// encodeFieldsSortType orders encodeField values by tag, using Go's byte-wise
// string comparison.
type encodeFieldsSortType []encodeField

func (ef encodeFieldsSortType) Len() int {
	return len(ef)
}

func (ef encodeFieldsSortType) Swap(i, j int) {
	ef[j], ef[i] = ef[i], ef[j]
}

func (ef encodeFieldsSortType) Less(i, j int) bool {
	left, right := ef[i].tag, ef[j].tag
	return left < right
}

// typeCacheLock guards encodeFieldsCache.
var typeCacheLock sync.RWMutex

// encodeFieldsCache memoizes the per-type field lists built by encodeFields.
var encodeFieldsCache = map[reflect.Type][]encodeField{}
214
215 func encodeFields(t reflect.Type) []encodeField {
216         typeCacheLock.RLock()
217         fs, ok := encodeFieldsCache[t]
218         typeCacheLock.RUnlock()
219         if ok {
220                 return fs
221         }
222
223         typeCacheLock.Lock()
224         defer typeCacheLock.Unlock()
225         fs, ok = encodeFieldsCache[t]
226         if ok {
227                 return fs
228         }
229
230         for i, n := 0, t.NumField(); i < n; i++ {
231                 f := t.Field(i)
232                 if f.PkgPath != "" {
233                         continue
234                 }
235                 if f.Anonymous {
236                         continue
237                 }
238                 var ef encodeField
239                 ef.i = i
240                 ef.tag = f.Name
241
242                 tv := getTag(f.Tag)
243                 if tv.Ignore() {
244                         continue
245                 }
246                 if tv.Key() != "" {
247                         ef.tag = tv.Key()
248                 }
249                 ef.omit_empty = tv.OmitEmpty()
250                 fs = append(fs, ef)
251         }
252         fss := encodeFieldsSortType(fs)
253         sort.Sort(fss)
254         encodeFieldsCache[t] = fs
255         return fs
256 }