]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Add bufio.Writer error handling, move Flush to the internal code.
authornsf <no.smile.face@gmail.com>
Sun, 24 Jun 2012 10:59:04 +0000 (16:59 +0600)
committernsf <no.smile.face@gmail.com>
Sun, 24 Jun 2012 10:59:04 +0000 (16:59 +0600)
bencode/api.go
bencode/encode.go

index 5f21f36348b4480a8746a919922b716873a13aa1..28828d5ecaa197be293ab99a33abcfe346da018d 100644 (file)
@@ -72,6 +72,15 @@ func (e *SyntaxError) Error() string {
                "): " + e.what
 }
 
+type MarshalerError struct {
+       Type reflect.Type
+       Err  error
+}
+
+func (e *MarshalerError) Error() string {
+       return "bencode: error calling MarshalBencode for type " + e.Type.String() + ": " + e.Err.Error()
+}
+
 //----------------------------------------------------------------------------
 // Interfaces
 //----------------------------------------------------------------------------
@@ -97,8 +106,7 @@ func Marshal(v interface{}) ([]byte, error) {
        if err != nil {
                return nil, err
        }
-       err = e.Flush()
-       return buf.Bytes(), err
+       return buf.Bytes(), nil
 }
 
 func Unmarshal(data []byte, v interface{}) error {
@@ -135,5 +143,5 @@ func (e *Encoder) Encode(v interface{}) error {
        if err != nil {
                return err
        }
-       return e.e.Flush()
+       return nil
 }
index 82f5b4c2d982c2e168c3429e475e8e24eeac733b..eee55818899c9d6332417a2872720e93f6f0dbcf 100644 (file)
@@ -40,7 +40,7 @@ func (e *encoder) encode(v interface{}) (err error) {
                }
        }()
        e.reflect_value(reflect.ValueOf(v))
-       return nil
+       return e.Flush()
 }
 
 type string_values []reflect.Value
@@ -50,18 +50,32 @@ func (sv string_values) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
 func (sv string_values) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
 func (sv string_values) get(i int) string   { return sv[i].String() }
 
+func (e *encoder) write(s []byte) {
+       _, err := e.Write(s)
+       if err != nil {
+               panic(err)
+       }
+}
+
+func (e *encoder) write_string(s string) {
+       _, err := e.WriteString(s)
+       if err != nil {
+               panic(err)
+       }
+}
+
 func (e *encoder) reflect_string(s string) {
        b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
-       e.Write(b)
-       e.WriteString(":")
-       e.WriteString(s)
+       e.write(b)
+       e.write_string(":")
+       e.write_string(s)
 }
 
 func (e *encoder) reflect_byte_slice(s []byte) {
        b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10)
-       e.Write(b)
-       e.WriteString(":")
-       e.Write(s)
+       e.write(b)
+       e.write_string(":")
+       e.write(s)
 }
 
 func (e *encoder) reflect_value(v reflect.Value) {
@@ -69,27 +83,44 @@ func (e *encoder) reflect_value(v reflect.Value) {
                return
        }
 
+       m, ok := v.Interface().(Marshaler)
+       if !ok {
+               if v.Kind() != reflect.Ptr && v.CanAddr() {
+                       m, ok = v.Addr().Interface().(Marshaler)
+                       if ok {
+                               v = v.Addr()
+                       }
+               }
+       }
+       if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
+               data, err := m.MarshalBencode()
+               if err != nil {
+                       panic(&MarshalerError{v.Type(), err})
+               }
+               e.write(data)
+       }
+
        switch v.Kind() {
        case reflect.Bool:
                if v.Bool() {
-                       e.WriteString("i1e")
+                       e.write_string("i1e")
                } else {
-                       e.WriteString("i0e")
+                       e.write_string("i0e")
                }
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
                b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
-               e.WriteString("i")
-               e.Write(b)
-               e.WriteString("e")
+               e.write_string("i")
+               e.write(b)
+               e.write_string("e")
        case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
                b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
-               e.WriteString("i")
-               e.Write(b)
-               e.WriteString("e")
+               e.write_string("i")
+               e.write(b)
+               e.write_string("e")
        case reflect.String:
                e.reflect_string(v.String())
        case reflect.Struct:
-               e.WriteString("d")
+               e.write_string("d")
                for _, ef := range encode_fields(v.Type()) {
                        field_value := v.Field(ef.i)
                        if ef.omit_empty && is_empty_value(field_value) {
@@ -99,26 +130,26 @@ func (e *encoder) reflect_value(v reflect.Value) {
                        e.reflect_string(ef.tag)
                        e.reflect_value(field_value)
                }
-               e.WriteString("e")
+               e.write_string("e")
        case reflect.Map:
                if v.Type().Key().Kind() != reflect.String {
                        panic(&MarshalTypeError{v.Type()})
                }
                if v.IsNil() {
-                       e.WriteString("de")
+                       e.write_string("de")
                        break
                }
-               e.WriteString("d")
+               e.write_string("d")
                sv := string_values(v.MapKeys())
                sort.Sort(sv)
                for _, key := range sv {
                        e.reflect_string(key.String())
                        e.reflect_value(v.MapIndex(key))
                }
-               e.WriteString("e")
+               e.write_string("e")
        case reflect.Slice:
                if v.IsNil() {
-                       e.WriteString("le")
+                       e.write_string("le")
                        break
                }
                if v.Type().Elem().Kind() == reflect.Uint8 {
@@ -128,11 +159,11 @@ func (e *encoder) reflect_value(v reflect.Value) {
                }
                fallthrough
        case reflect.Array:
-               e.WriteString("l")
+               e.write_string("l")
                for i, n := 0, v.Len(); i < n; i++ {
                        e.reflect_value(v.Index(i))
                }
-               e.WriteString("e")
+               e.write_string("e")
        case reflect.Interface, reflect.Ptr:
                if v.IsNil() {
                        break