]> Sergey Matveev's repositories - btrtrc.git/blob - bencode/encode.go
Add bufio.Writer error handling, move Flush to the internal code.
[btrtrc.git] / bencode / encode.go
1 package bencode
2
3 import "bufio"
4 import "reflect"
5 import "runtime"
6 import "strconv"
7 import "sync"
8 import "sort"
9
10 func is_empty_value(v reflect.Value) bool {
11         switch v.Kind() {
12         case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
13                 return v.Len() == 0
14         case reflect.Bool:
15                 return !v.Bool()
16         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
17                 return v.Int() == 0
18         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
19                 return v.Uint() == 0
20         case reflect.Float32, reflect.Float64:
21                 return v.Float() == 0
22         case reflect.Interface, reflect.Ptr:
23                 return v.IsNil()
24         }
25         return false
26 }
27
28 type encoder struct {
29         *bufio.Writer
30         scratch [64]byte
31 }
32
33 func (e *encoder) encode(v interface{}) (err error) {
34         defer func() {
35                 if e := recover(); e != nil {
36                         if _, ok := e.(runtime.Error); ok {
37                                 panic(e)
38                         }
39                         err = e.(error)
40                 }
41         }()
42         e.reflect_value(reflect.ValueOf(v))
43         return e.Flush()
44 }
45
46 type string_values []reflect.Value
47
48 func (sv string_values) Len() int           { return len(sv) }
49 func (sv string_values) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
50 func (sv string_values) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
51 func (sv string_values) get(i int) string   { return sv[i].String() }
52
53 func (e *encoder) write(s []byte) {
54         _, err := e.Write(s)
55         if err != nil {
56                 panic(err)
57         }
58 }
59
60 func (e *encoder) write_string(s string) {
61         _, err := e.WriteString(s)
62         if err != nil {
63                 panic(err)
64         }
65 }
66
67 func (e *encoder) reflect_string(s string) {
68         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
69         e.write(b)
70         e.write_string(":")
71         e.write_string(s)
72 }
73
74 func (e *encoder) reflect_byte_slice(s []byte) {
75         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
76         e.write(b)
77         e.write_string(":")
78         e.write(s)
79 }
80
81 func (e *encoder) reflect_value(v reflect.Value) {
82         if !v.IsValid() {
83                 return
84         }
85
86         m, ok := v.Interface().(Marshaler)
87         if !ok {
88                 if v.Kind() != reflect.Ptr && v.CanAddr() {
89                         m, ok = v.Addr().Interface().(Marshaler)
90                         if ok {
91                                 v = v.Addr()
92                         }
93                 }
94         }
95         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
96                 data, err := m.MarshalBencode()
97                 if err != nil {
98                         panic(&MarshalerError{v.Type(), err})
99                 }
100                 e.write(data)
101         }
102
103         switch v.Kind() {
104         case reflect.Bool:
105                 if v.Bool() {
106                         e.write_string("i1e")
107                 } else {
108                         e.write_string("i0e")
109                 }
110         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
111                 b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
112                 e.write_string("i")
113                 e.write(b)
114                 e.write_string("e")
115         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
116                 b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
117                 e.write_string("i")
118                 e.write(b)
119                 e.write_string("e")
120         case reflect.String:
121                 e.reflect_string(v.String())
122         case reflect.Struct:
123                 e.write_string("d")
124                 for _, ef := range encode_fields(v.Type()) {
125                         field_value := v.Field(ef.i)
126                         if ef.omit_empty && is_empty_value(field_value) {
127                                 continue
128                         }
129
130                         e.reflect_string(ef.tag)
131                         e.reflect_value(field_value)
132                 }
133                 e.write_string("e")
134         case reflect.Map:
135                 if v.Type().Key().Kind() != reflect.String {
136                         panic(&MarshalTypeError{v.Type()})
137                 }
138                 if v.IsNil() {
139                         e.write_string("de")
140                         break
141                 }
142                 e.write_string("d")
143                 sv := string_values(v.MapKeys())
144                 sort.Sort(sv)
145                 for _, key := range sv {
146                         e.reflect_string(key.String())
147                         e.reflect_value(v.MapIndex(key))
148                 }
149                 e.write_string("e")
150         case reflect.Slice:
151                 if v.IsNil() {
152                         e.write_string("le")
153                         break
154                 }
155                 if v.Type().Elem().Kind() == reflect.Uint8 {
156                         s := v.Bytes()
157                         e.reflect_byte_slice(s)
158                         break
159                 }
160                 fallthrough
161         case reflect.Array:
162                 e.write_string("l")
163                 for i, n := 0, v.Len(); i < n; i++ {
164                         e.reflect_value(v.Index(i))
165                 }
166                 e.write_string("e")
167         case reflect.Interface, reflect.Ptr:
168                 if v.IsNil() {
169                         break
170                 }
171                 e.reflect_value(v.Elem())
172         default:
173                 panic(&MarshalTypeError{v.Type()})
174         }
175 }
176
177 type encode_field struct {
178         i          int
179         tag        string
180         omit_empty bool
181 }
182
183 type encode_fields_sort_type []encode_field
184
185 func (ef encode_fields_sort_type) Len() int           { return len(ef) }
186 func (ef encode_fields_sort_type) Swap(i, j int)      { ef[i], ef[j] = ef[j], ef[i] }
187 func (ef encode_fields_sort_type) Less(i, j int) bool { return ef[i].tag < ef[j].tag }
188
189 var (
190         type_cache_lock     sync.RWMutex
191         encode_fields_cache = make(map[reflect.Type][]encode_field)
192 )
193
194 func encode_fields(t reflect.Type) []encode_field {
195         type_cache_lock.RLock()
196         fs, ok := encode_fields_cache[t]
197         type_cache_lock.RUnlock()
198         if ok {
199                 return fs
200         }
201
202         type_cache_lock.Lock()
203         defer type_cache_lock.Unlock()
204         fs, ok = encode_fields_cache[t]
205         if ok {
206                 return fs
207         }
208
209         for i, n := 0, t.NumField(); i < n; i++ {
210                 f := t.Field(i)
211                 if f.PkgPath != "" {
212                         continue
213                 }
214                 if f.Anonymous {
215                         continue
216                 }
217                 var ef encode_field
218                 ef.i = i
219                 ef.tag = f.Name
220
221                 tv := f.Tag.Get("bencode")
222                 if tv != "" {
223                         if tv == "-" {
224                                 continue
225                         }
226                         name, opts := parse_tag(tv)
227                         ef.tag = name
228                         ef.omit_empty = opts.contains("omitempty")
229                 }
230                 fs = append(fs, ef)
231         }
232         fss := encode_fields_sort_type(fs)
233         sort.Sort(fss)
234         encode_fields_cache[t] = fs
235         return fs
236 }