]> Sergey Matveev's repositories - btrtrc.git/blob - bencode/encode.go
Initial commit.
[btrtrc.git] / bencode / encode.go
1 package bencode
2
3 import "bufio"
4 import "reflect"
5 import "runtime"
6 import "strconv"
7 import "sync"
8 import "sort"
9
10 func is_empty_value(v reflect.Value) bool {
11         switch v.Kind() {
12         case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
13                 return v.Len() == 0
14         case reflect.Bool:
15                 return !v.Bool()
16         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
17                 return v.Int() == 0
18         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
19                 return v.Uint() == 0
20         case reflect.Float32, reflect.Float64:
21                 return v.Float() == 0
22         case reflect.Interface, reflect.Ptr:
23                 return v.IsNil()
24         }
25         return false
26 }
27
28 type encoder struct {
29         *bufio.Writer
30         scratch [64]byte
31 }
32
33 func (e *encoder) encode(v interface{}) (err error) {
34         defer func() {
35                 if e := recover(); e != nil {
36                         if _, ok := e.(runtime.Error); ok {
37                                 panic(e)
38                         }
39                         err = e.(error)
40                 }
41         }()
42         e.reflect_value(reflect.ValueOf(v))
43         return nil
44 }
45
46 type string_values []reflect.Value
47
48 func (sv string_values) Len() int           { return len(sv) }
49 func (sv string_values) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
50 func (sv string_values) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
51 func (sv string_values) get(i int) string   { return sv[i].String() }
52
53 func (e *encoder) reflect_string(s string) {
54         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
55         e.Write(b)
56         e.WriteString(":")
57         e.WriteString(s)
58 }
59
60 func (e *encoder) reflect_byte_slice(s []byte) {
61         b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
62         e.Write(b)
63         e.WriteString(":")
64         e.Write(s)
65 }
66
67 func (e *encoder) reflect_value(v reflect.Value) {
68         if !v.IsValid() {
69                 return
70         }
71
72         switch v.Kind() {
73         case reflect.Bool:
74                 if v.Bool() {
75                         e.WriteString("i1e")
76                 } else {
77                         e.WriteString("i0e")
78                 }
79         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
80                 b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
81                 e.WriteString("i")
82                 e.Write(b)
83                 e.WriteString("e")
84         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
85                 b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
86                 e.WriteString("i")
87                 e.Write(b)
88                 e.WriteString("e")
89         case reflect.String:
90                 e.reflect_string(v.String())
91         case reflect.Struct:
92                 e.WriteString("d")
93                 for _, ef := range encode_fields(v.Type()) {
94                         field_value := v.Field(ef.i)
95                         if ef.omit_empty && is_empty_value(field_value) {
96                                 continue
97                         }
98
99                         e.reflect_string(ef.tag)
100                         e.reflect_value(field_value)
101                 }
102                 e.WriteString("e")
103         case reflect.Map:
104                 if v.Type().Key().Kind() != reflect.String {
105                         panic(&MarshalTypeError{v.Type()})
106                 }
107                 if v.IsNil() {
108                         e.WriteString("de")
109                         break
110                 }
111                 e.WriteString("d")
112                 sv := string_values(v.MapKeys())
113                 sort.Sort(sv)
114                 for _, key := range sv {
115                         e.reflect_string(key.String())
116                         e.reflect_value(v.MapIndex(key))
117                 }
118                 e.WriteString("e")
119         case reflect.Slice:
120                 if v.IsNil() {
121                         e.WriteString("le")
122                         break
123                 }
124                 if v.Type().Elem().Kind() == reflect.Uint8 {
125                         s := v.Bytes()
126                         e.reflect_byte_slice(s)
127                         break
128                 }
129                 fallthrough
130         case reflect.Array:
131                 e.WriteString("l")
132                 for i, n := 0, v.Len(); i < n; i++ {
133                         e.reflect_value(v.Index(i))
134                 }
135                 e.WriteString("e")
136         case reflect.Interface, reflect.Ptr:
137                 if v.IsNil() {
138                         break
139                 }
140                 e.reflect_value(v.Elem())
141         default:
142                 panic(&MarshalTypeError{v.Type()})
143         }
144 }
145
146 type encode_field struct {
147         i          int
148         tag        string
149         omit_empty bool
150 }
151
152 type encode_fields_sort_type []encode_field
153
154 func (ef encode_fields_sort_type) Len() int           { return len(ef) }
155 func (ef encode_fields_sort_type) Swap(i, j int)      { ef[i], ef[j] = ef[j], ef[i] }
156 func (ef encode_fields_sort_type) Less(i, j int) bool { return ef[i].tag < ef[j].tag }
157
158 var (
159         type_cache_lock     sync.RWMutex
160         encode_fields_cache = make(map[reflect.Type][]encode_field)
161 )
162
163 func encode_fields(t reflect.Type) []encode_field {
164         type_cache_lock.RLock()
165         fs, ok := encode_fields_cache[t]
166         type_cache_lock.RUnlock()
167         if ok {
168                 return fs
169         }
170
171         type_cache_lock.Lock()
172         defer type_cache_lock.Unlock()
173         fs, ok = encode_fields_cache[t]
174         if ok {
175                 return fs
176         }
177
178         for i, n := 0, t.NumField(); i < n; i++ {
179                 f := t.Field(i)
180                 if f.PkgPath != "" {
181                         continue
182                 }
183                 if f.Anonymous {
184                         continue
185                 }
186                 var ef encode_field
187                 ef.i = i
188                 ef.tag = f.Name
189
190                 tv := f.Tag.Get("bencode")
191                 if tv != "" {
192                         if tv == "-" {
193                                 continue
194                         }
195                         name, opts := parse_tag(tv)
196                         ef.tag = name
197                         ef.omit_empty = opts.contains("omitempty")
198                 }
199                 fs = append(fs, ef)
200         }
201         fss := encode_fields_sort_type(fs)
202         sort.Sort(fss)
203         encode_fields_cache[t] = fs
204         return fs
205 }