]> Sergey Matveev's repositories - btrtrc.git/blob - bencode/encode.go
Fix Marshaler case.
[btrtrc.git] / bencode / encode.go
1 package bencode
2
3 import "bufio"
4 import "reflect"
5 import "runtime"
6 import "strconv"
7 import "sync"
8 import "sort"
9
10 func is_empty_value(v reflect.Value) bool {
11         switch v.Kind() {
12         case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
13                 return v.Len() == 0
14         case reflect.Bool:
15                 return !v.Bool()
16         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
17                 return v.Int() == 0
18         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
19                 return v.Uint() == 0
20         case reflect.Float32, reflect.Float64:
21                 return v.Float() == 0
22         case reflect.Interface, reflect.Ptr:
23                 return v.IsNil()
24         }
25         return false
26 }
27
28 type encoder struct {
29         *bufio.Writer
30         scratch [64]byte
31 }
32
33 func (e *encoder) encode(v interface{}) (err error) {
34         defer func() {
35                 if e := recover(); e != nil {
36                         if _, ok := e.(runtime.Error); ok {
37                                 panic(e)
38                         }
39                         err = e.(error)
40                 }
41         }()
42         e.reflect_value(reflect.ValueOf(v))
43         return e.Flush()
44 }
45
46 type string_values []reflect.Value
47
48 func (sv string_values) Len() int           { return len(sv) }
49 func (sv string_values) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
50 func (sv string_values) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
51 func (sv string_values) get(i int) string   { return sv[i].String() }
52
53 func (e *encoder) write(s []byte) {
54         _, err := e.Write(s)
55         if err != nil {
56                 panic(err)
57         }
58 }
59
60 func (e *encoder) write_string(s string) {
61         _, err := e.WriteString(s)
62         if err != nil {
63                 panic(err)
64         }
65 }
66
67 func (e *encoder) reflect_string(s string) {
68         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
69         e.write(b)
70         e.write_string(":")
71         e.write_string(s)
72 }
73
74 func (e *encoder) reflect_byte_slice(s []byte) {
75         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
76         e.write(b)
77         e.write_string(":")
78         e.write(s)
79 }
80
81 func (e *encoder) reflect_value(v reflect.Value) {
82         if !v.IsValid() {
83                 return
84         }
85
86         m, ok := v.Interface().(Marshaler)
87         if !ok {
88                 if v.Kind() != reflect.Ptr && v.CanAddr() {
89                         m, ok = v.Addr().Interface().(Marshaler)
90                         if ok {
91                                 v = v.Addr()
92                         }
93                 }
94         }
95         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
96                 data, err := m.MarshalBencode()
97                 if err != nil {
98                         panic(&MarshalerError{v.Type(), err})
99                 }
100                 e.write(data)
101                 return
102         }
103
104         switch v.Kind() {
105         case reflect.Bool:
106                 if v.Bool() {
107                         e.write_string("i1e")
108                 } else {
109                         e.write_string("i0e")
110                 }
111         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
112                 b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
113                 e.write_string("i")
114                 e.write(b)
115                 e.write_string("e")
116         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
117                 b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
118                 e.write_string("i")
119                 e.write(b)
120                 e.write_string("e")
121         case reflect.String:
122                 e.reflect_string(v.String())
123         case reflect.Struct:
124                 e.write_string("d")
125                 for _, ef := range encode_fields(v.Type()) {
126                         field_value := v.Field(ef.i)
127                         if ef.omit_empty && is_empty_value(field_value) {
128                                 continue
129                         }
130
131                         e.reflect_string(ef.tag)
132                         e.reflect_value(field_value)
133                 }
134                 e.write_string("e")
135         case reflect.Map:
136                 if v.Type().Key().Kind() != reflect.String {
137                         panic(&MarshalTypeError{v.Type()})
138                 }
139                 if v.IsNil() {
140                         e.write_string("de")
141                         break
142                 }
143                 e.write_string("d")
144                 sv := string_values(v.MapKeys())
145                 sort.Sort(sv)
146                 for _, key := range sv {
147                         e.reflect_string(key.String())
148                         e.reflect_value(v.MapIndex(key))
149                 }
150                 e.write_string("e")
151         case reflect.Slice:
152                 if v.IsNil() {
153                         e.write_string("le")
154                         break
155                 }
156                 if v.Type().Elem().Kind() == reflect.Uint8 {
157                         s := v.Bytes()
158                         e.reflect_byte_slice(s)
159                         break
160                 }
161                 fallthrough
162         case reflect.Array:
163                 e.write_string("l")
164                 for i, n := 0, v.Len(); i < n; i++ {
165                         e.reflect_value(v.Index(i))
166                 }
167                 e.write_string("e")
168         case reflect.Interface, reflect.Ptr:
169                 if v.IsNil() {
170                         break
171                 }
172                 e.reflect_value(v.Elem())
173         default:
174                 panic(&MarshalTypeError{v.Type()})
175         }
176 }
177
178 type encode_field struct {
179         i          int
180         tag        string
181         omit_empty bool
182 }
183
184 type encode_fields_sort_type []encode_field
185
186 func (ef encode_fields_sort_type) Len() int           { return len(ef) }
187 func (ef encode_fields_sort_type) Swap(i, j int)      { ef[i], ef[j] = ef[j], ef[i] }
188 func (ef encode_fields_sort_type) Less(i, j int) bool { return ef[i].tag < ef[j].tag }
189
190 var (
191         type_cache_lock     sync.RWMutex
192         encode_fields_cache = make(map[reflect.Type][]encode_field)
193 )
194
195 func encode_fields(t reflect.Type) []encode_field {
196         type_cache_lock.RLock()
197         fs, ok := encode_fields_cache[t]
198         type_cache_lock.RUnlock()
199         if ok {
200                 return fs
201         }
202
203         type_cache_lock.Lock()
204         defer type_cache_lock.Unlock()
205         fs, ok = encode_fields_cache[t]
206         if ok {
207                 return fs
208         }
209
210         for i, n := 0, t.NumField(); i < n; i++ {
211                 f := t.Field(i)
212                 if f.PkgPath != "" {
213                         continue
214                 }
215                 if f.Anonymous {
216                         continue
217                 }
218                 var ef encode_field
219                 ef.i = i
220                 ef.tag = f.Name
221
222                 tv := f.Tag.Get("bencode")
223                 if tv != "" {
224                         if tv == "-" {
225                                 continue
226                         }
227                         name, opts := parse_tag(tv)
228                         ef.tag = name
229                         ef.omit_empty = opts.contains("omitempty")
230                 }
231                 fs = append(fs, ef)
232         }
233         fss := encode_fields_sort_type(fs)
234         sort.Sort(fss)
235         encode_fields_cache[t] = fs
236         return fs
237 }