]> Sergey Matveev's repositories - btrtrc.git/blob - bencode/encode.go
Fix bugs, implement missing bits. Made a mess, can't describe properly.
[btrtrc.git] / bencode / encode.go
1 package bencode
2
3 import "bufio"
4 import "reflect"
5 import "runtime"
6 import "strconv"
7 import "sync"
8 import "sort"
9
10 func is_empty_value(v reflect.Value) bool {
11         switch v.Kind() {
12         case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
13                 return v.Len() == 0
14         case reflect.Bool:
15                 return !v.Bool()
16         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
17                 return v.Int() == 0
18         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
19                 return v.Uint() == 0
20         case reflect.Float32, reflect.Float64:
21                 return v.Float() == 0
22         case reflect.Interface, reflect.Ptr:
23                 return v.IsNil()
24         }
25         return false
26 }
27
28 type encoder struct {
29         *bufio.Writer
30         scratch [64]byte
31 }
32
33 func (e *encoder) encode(v interface{}) (err error) {
34         defer func() {
35                 if e := recover(); e != nil {
36                         if _, ok := e.(runtime.Error); ok {
37                                 panic(e)
38                         }
39                         err = e.(error)
40                 }
41         }()
42         e.reflect_value(reflect.ValueOf(v))
43         return e.Flush()
44 }
45
46 type string_values []reflect.Value
47
48 func (sv string_values) Len() int           { return len(sv) }
49 func (sv string_values) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
50 func (sv string_values) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
51 func (sv string_values) get(i int) string   { return sv[i].String() }
52
53 func (e *encoder) write(s []byte) {
54         _, err := e.Write(s)
55         if err != nil {
56                 panic(err)
57         }
58 }
59
60 func (e *encoder) write_string(s string) {
61         _, err := e.WriteString(s)
62         if err != nil {
63                 panic(err)
64         }
65 }
66
67 func (e *encoder) reflect_string(s string) {
68         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
69         e.write(b)
70         e.write_string(":")
71         e.write_string(s)
72 }
73
74 func (e *encoder) reflect_byte_slice(s []byte) {
75         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
76         e.write(b)
77         e.write_string(":")
78         e.write(s)
79 }
80
81 // returns true if the value implements Marshaler interface and marshaling was
82 // done successfully
83 func (e *encoder) reflect_marshaler(v reflect.Value) bool {
84         m, ok := v.Interface().(Marshaler)
85         if !ok {
86                 // T doesn't work, try *T
87                 if v.Kind() != reflect.Ptr && v.CanAddr() {
88                         m, ok = v.Addr().Interface().(Marshaler)
89                         if ok {
90                                 v = v.Addr()
91                         }
92                 }
93         }
94         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
95                 data, err := m.MarshalBencode()
96                 if err != nil {
97                         panic(&MarshalerError{v.Type(), err})
98                 }
99                 e.write(data)
100                 return true
101         }
102
103         return false
104 }
105
106 func (e *encoder) reflect_value(v reflect.Value) {
107         if !v.IsValid() {
108                 return
109         }
110
111         if e.reflect_marshaler(v) {
112                 return
113         }
114
115         switch v.Kind() {
116         case reflect.Bool:
117                 if v.Bool() {
118                         e.write_string("i1e")
119                 } else {
120                         e.write_string("i0e")
121                 }
122         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
123                 b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
124                 e.write_string("i")
125                 e.write(b)
126                 e.write_string("e")
127         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
128                 b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
129                 e.write_string("i")
130                 e.write(b)
131                 e.write_string("e")
132         case reflect.String:
133                 e.reflect_string(v.String())
134         case reflect.Struct:
135                 e.write_string("d")
136                 for _, ef := range encode_fields(v.Type()) {
137                         field_value := v.Field(ef.i)
138                         if ef.omit_empty && is_empty_value(field_value) {
139                                 continue
140                         }
141
142                         e.reflect_string(ef.tag)
143                         e.reflect_value(field_value)
144                 }
145                 e.write_string("e")
146         case reflect.Map:
147                 if v.Type().Key().Kind() != reflect.String {
148                         panic(&MarshalTypeError{v.Type()})
149                 }
150                 if v.IsNil() {
151                         e.write_string("de")
152                         break
153                 }
154                 e.write_string("d")
155                 sv := string_values(v.MapKeys())
156                 sort.Sort(sv)
157                 for _, key := range sv {
158                         e.reflect_string(key.String())
159                         e.reflect_value(v.MapIndex(key))
160                 }
161                 e.write_string("e")
162         case reflect.Slice:
163                 if v.IsNil() {
164                         e.write_string("le")
165                         break
166                 }
167                 if v.Type().Elem().Kind() == reflect.Uint8 {
168                         s := v.Bytes()
169                         e.reflect_byte_slice(s)
170                         break
171                 }
172                 fallthrough
173         case reflect.Array:
174                 e.write_string("l")
175                 for i, n := 0, v.Len(); i < n; i++ {
176                         e.reflect_value(v.Index(i))
177                 }
178                 e.write_string("e")
179         case reflect.Interface, reflect.Ptr:
180                 if v.IsNil() {
181                         break
182                 }
183                 e.reflect_value(v.Elem())
184         default:
185                 panic(&MarshalTypeError{v.Type()})
186         }
187 }
188
189 type encode_field struct {
190         i          int
191         tag        string
192         omit_empty bool
193 }
194
195 type encode_fields_sort_type []encode_field
196
197 func (ef encode_fields_sort_type) Len() int           { return len(ef) }
198 func (ef encode_fields_sort_type) Swap(i, j int)      { ef[i], ef[j] = ef[j], ef[i] }
199 func (ef encode_fields_sort_type) Less(i, j int) bool { return ef[i].tag < ef[j].tag }
200
201 var (
202         type_cache_lock     sync.RWMutex
203         encode_fields_cache = make(map[reflect.Type][]encode_field)
204 )
205
206 func encode_fields(t reflect.Type) []encode_field {
207         type_cache_lock.RLock()
208         fs, ok := encode_fields_cache[t]
209         type_cache_lock.RUnlock()
210         if ok {
211                 return fs
212         }
213
214         type_cache_lock.Lock()
215         defer type_cache_lock.Unlock()
216         fs, ok = encode_fields_cache[t]
217         if ok {
218                 return fs
219         }
220
221         for i, n := 0, t.NumField(); i < n; i++ {
222                 f := t.Field(i)
223                 if f.PkgPath != "" {
224                         continue
225                 }
226                 if f.Anonymous {
227                         continue
228                 }
229                 var ef encode_field
230                 ef.i = i
231                 ef.tag = f.Name
232
233                 tv := f.Tag.Get("bencode")
234                 if tv != "" {
235                         if tv == "-" {
236                                 continue
237                         }
238                         name, opts := parse_tag(tv)
239                         ef.tag = name
240                         ef.omit_empty = opts.contains("omitempty")
241                 }
242                 fs = append(fs, ef)
243         }
244         fss := encode_fields_sort_type(fs)
245         sort.Sort(fss)
246         encode_fields_cache[t] = fs
247         return fs
248 }