Sergey Matveev's repositories - btrtrc.git/blob - bencode/encode.go
bencode: Remove private types encoder and decoder
[btrtrc.git] / bencode / encode.go
1 package bencode
2
3 import (
4         "bufio"
5         "reflect"
6         "runtime"
7         "sort"
8         "strconv"
9         "sync"
10
11         "github.com/anacrolix/missinggo"
12 )
13
14 func is_empty_value(v reflect.Value) bool {
15         return missinggo.IsEmptyValue(v)
16 }
17
18 type Encoder struct {
19         *bufio.Writer
20         scratch [64]byte
21 }
22
23 func (e *Encoder) Encode(v interface{}) (err error) {
24         if v == nil {
25                 return
26         }
27         defer func() {
28                 if e := recover(); e != nil {
29                         if _, ok := e.(runtime.Error); ok {
30                                 panic(e)
31                         }
32                         var ok bool
33                         err, ok = e.(error)
34                         if !ok {
35                                 panic(e)
36                         }
37                 }
38         }()
39         e.reflect_value(reflect.ValueOf(v))
40         return e.Flush()
41 }
42
// string_values implements sort.Interface over reflect.Values, ordering them
// by the result of Value.String. It is used to sort string map keys.
type string_values []reflect.Value

func (sv string_values) Len() int {
	return len(sv)
}

func (sv string_values) Swap(i, j int) {
	sv[j], sv[i] = sv[i], sv[j]
}

func (sv string_values) Less(i, j int) bool {
	a, b := sv.get(i), sv.get(j)
	return a < b
}

// get returns the string held by the i'th value.
func (sv string_values) get(i int) string {
	return sv[i].String()
}
49
50 func (e *Encoder) write(s []byte) {
51         _, err := e.Write(s)
52         if err != nil {
53                 panic(err)
54         }
55 }
56
57 func (e *Encoder) write_string(s string) {
58         _, err := e.WriteString(s)
59         if err != nil {
60                 panic(err)
61         }
62 }
63
64 func (e *Encoder) reflect_string(s string) {
65         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
66         e.write(b)
67         e.write_string(":")
68         e.write_string(s)
69 }
70
71 func (e *Encoder) reflect_byte_slice(s []byte) {
72         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
73         e.write(b)
74         e.write_string(":")
75         e.write(s)
76 }
77
78 // returns true if the value implements Marshaler interface and marshaling was
79 // done successfully
80 func (e *Encoder) reflect_marshaler(v reflect.Value) bool {
81         m, ok := v.Interface().(Marshaler)
82         if !ok {
83                 // T doesn't work, try *T
84                 if v.Kind() != reflect.Ptr && v.CanAddr() {
85                         m, ok = v.Addr().Interface().(Marshaler)
86                         if ok {
87                                 v = v.Addr()
88                         }
89                 }
90         }
91         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
92                 data, err := m.MarshalBencode()
93                 if err != nil {
94                         panic(&MarshalerError{v.Type(), err})
95                 }
96                 e.write(data)
97                 return true
98         }
99
100         return false
101 }
102
103 func (e *Encoder) reflect_value(v reflect.Value) {
104
105         if e.reflect_marshaler(v) {
106                 return
107         }
108
109         switch v.Kind() {
110         case reflect.Bool:
111                 if v.Bool() {
112                         e.write_string("i1e")
113                 } else {
114                         e.write_string("i0e")
115                 }
116         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
117                 b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
118                 e.write_string("i")
119                 e.write(b)
120                 e.write_string("e")
121         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
122                 b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
123                 e.write_string("i")
124                 e.write(b)
125                 e.write_string("e")
126         case reflect.String:
127                 e.reflect_string(v.String())
128         case reflect.Struct:
129                 e.write_string("d")
130                 for _, ef := range encode_fields(v.Type()) {
131                         field_value := v.Field(ef.i)
132                         if ef.omit_empty && is_empty_value(field_value) {
133                                 continue
134                         }
135                         e.reflect_string(ef.tag)
136                         e.reflect_value(field_value)
137                 }
138                 e.write_string("e")
139         case reflect.Map:
140                 if v.Type().Key().Kind() != reflect.String {
141                         panic(&MarshalTypeError{v.Type()})
142                 }
143                 if v.IsNil() {
144                         e.write_string("de")
145                         break
146                 }
147                 e.write_string("d")
148                 sv := string_values(v.MapKeys())
149                 sort.Sort(sv)
150                 for _, key := range sv {
151                         e.reflect_string(key.String())
152                         e.reflect_value(v.MapIndex(key))
153                 }
154                 e.write_string("e")
155         case reflect.Slice:
156                 if v.IsNil() {
157                         e.write_string("le")
158                         break
159                 }
160                 if v.Type().Elem().Kind() == reflect.Uint8 {
161                         s := v.Bytes()
162                         e.reflect_byte_slice(s)
163                         break
164                 }
165                 fallthrough
166         case reflect.Array:
167                 e.write_string("l")
168                 for i, n := 0, v.Len(); i < n; i++ {
169                         e.reflect_value(v.Index(i))
170                 }
171                 e.write_string("e")
172         case reflect.Interface:
173                 e.reflect_value(v.Elem())
174         case reflect.Ptr:
175                 if v.IsNil() {
176                         v = reflect.Zero(v.Type().Elem())
177                 } else {
178                         v = v.Elem()
179                 }
180                 e.reflect_value(v)
181         default:
182                 panic(&MarshalTypeError{v.Type()})
183         }
184 }
185
// encode_field describes one struct field that takes part in encoding.
type encode_field struct {
	i          int    // index of the field within its struct type
	tag        string // dictionary key the field is written under
	omit_empty bool   // skip the field when it holds an empty value
}

// encode_fields_sort_type implements sort.Interface over encode_field
// values, ordering them by dictionary key (tag).
type encode_fields_sort_type []encode_field

func (ef encode_fields_sort_type) Len() int {
	return len(ef)
}

func (ef encode_fields_sort_type) Swap(i, j int) {
	ef[j], ef[i] = ef[i], ef[j]
}

func (ef encode_fields_sort_type) Less(i, j int) bool {
	return ef[i].tag < ef[j].tag
}
197
198 var (
199         type_cache_lock     sync.RWMutex
200         encode_fields_cache = make(map[reflect.Type][]encode_field)
201 )
202
203 func encode_fields(t reflect.Type) []encode_field {
204         type_cache_lock.RLock()
205         fs, ok := encode_fields_cache[t]
206         type_cache_lock.RUnlock()
207         if ok {
208                 return fs
209         }
210
211         type_cache_lock.Lock()
212         defer type_cache_lock.Unlock()
213         fs, ok = encode_fields_cache[t]
214         if ok {
215                 return fs
216         }
217
218         for i, n := 0, t.NumField(); i < n; i++ {
219                 f := t.Field(i)
220                 if f.PkgPath != "" {
221                         continue
222                 }
223                 if f.Anonymous {
224                         continue
225                 }
226                 var ef encode_field
227                 ef.i = i
228                 ef.tag = f.Name
229
230                 tv := f.Tag.Get("bencode")
231                 if tv != "" {
232                         if tv == "-" {
233                                 continue
234                         }
235                         name, opts := parse_tag(tv)
236                         if name != "" {
237                                 ef.tag = name
238                         }
239                         ef.omit_empty = opts.contains("omitempty")
240                 }
241                 fs = append(fs, ef)
242         }
243         fss := encode_fields_sort_type(fs)
244         sort.Sort(fss)
245         encode_fields_cache[t] = fs
246         return fs
247 }