From 8db0c29da909354e010cfcd67ac2ea16ffdc57df Mon Sep 17 00:00:00 2001 From: nsf Date: Sun, 24 Jun 2012 16:59:04 +0600 Subject: [PATCH] Add bufio.Writer error handling, move Flush to the internal code. --- bencode/api.go | 14 +++++++-- bencode/encode.go | 77 +++++++++++++++++++++++++++++++++-------------- 2 files changed, 65 insertions(+), 26 deletions(-) diff --git a/bencode/api.go b/bencode/api.go index 5f21f363..28828d5e 100644 --- a/bencode/api.go +++ b/bencode/api.go @@ -72,6 +72,15 @@ func (e *SyntaxError) Error() string { "): " + e.what } +type MarshalerError struct { + Type reflect.Type + Err error +} + +func (e *MarshalerError) Error() string { + return "bencode: error calling MarshalBencode for type " + e.Type.String() + ": " + e.Err.Error() +} + //---------------------------------------------------------------------------- // Interfaces //---------------------------------------------------------------------------- @@ -97,8 +106,7 @@ func Marshal(v interface{}) ([]byte, error) { if err != nil { return nil, err } - err = e.Flush() - return buf.Bytes(), err + return buf.Bytes(), nil } func Unmarshal(data []byte, v interface{}) error { @@ -135,5 +143,5 @@ func (e *Encoder) Encode(v interface{}) error { if err != nil { return err } - return e.e.Flush() + return nil } diff --git a/bencode/encode.go b/bencode/encode.go index 82f5b4c2..eee55818 100644 --- a/bencode/encode.go +++ b/bencode/encode.go @@ -40,7 +40,7 @@ func (e *encoder) encode(v interface{}) (err error) { } }() e.reflect_value(reflect.ValueOf(v)) - return nil + return e.Flush() } type string_values []reflect.Value @@ -50,18 +50,32 @@ func (sv string_values) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] } func (sv string_values) Less(i, j int) bool { return sv.get(i) < sv.get(j) } func (sv string_values) get(i int) string { return sv[i].String() } +func (e *encoder) write(s []byte) { + _, err := e.Write(s) + if err != nil { + panic(err) + } +} + +func (e *encoder) write_string(s string) { + _, err := e.WriteString(s) + if err != nil { + panic(err) + } +} + func (e *encoder) reflect_string(s string) { b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10) - e.Write(b) - e.WriteString(":") - e.WriteString(s) + e.write(b) + e.write_string(":") + e.write_string(s) } func (e *encoder) reflect_byte_slice(s []byte) { b := strconv.AppendInt(e.scratch[:0], int64(len(s)), 10) - e.Write(b) - e.WriteString(":") - e.Write(s) + e.write(b) + e.write_string(":") + e.write(s) } func (e *encoder) reflect_value(v reflect.Value) { @@ -69,27 +83,44 @@ func (e *encoder) reflect_value(v reflect.Value) { return } + m, ok := v.Interface().(Marshaler) + if !ok { + if v.Kind() != reflect.Ptr && v.CanAddr() { + m, ok = v.Addr().Interface().(Marshaler) + if ok { + v = v.Addr() + } + } + } + if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) { + data, err := m.MarshalBencode() + if err != nil { + panic(&MarshalerError{v.Type(), err}) + } + e.write(data) + } + switch v.Kind() { case reflect.Bool: if v.Bool() { - e.WriteString("i1e") + e.write_string("i1e") } else { - e.WriteString("i0e") + e.write_string("i0e") } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: b := strconv.AppendInt(e.scratch[:0], v.Int(), 10) - e.WriteString("i") - e.Write(b) - e.WriteString("e") + e.write_string("i") + e.write(b) + e.write_string("e") case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10) - e.WriteString("i") - e.Write(b) - e.WriteString("e") + e.write_string("i") + e.write(b) + e.write_string("e") case reflect.String: e.reflect_string(v.String()) case reflect.Struct: - e.WriteString("d") + e.write_string("d") for _, ef := range encode_fields(v.Type()) { field_value := v.Field(ef.i) if ef.omit_empty && is_empty_value(field_value) { @@ -99,26 +130,26 @@ func (e *encoder) reflect_value(v reflect.Value) { e.reflect_string(ef.tag) e.reflect_value(field_value) } - e.WriteString("e") + e.write_string("e") case reflect.Map: if v.Type().Key().Kind() != reflect.String { panic(&MarshalTypeError{v.Type()}) } if v.IsNil() { - e.WriteString("de") + e.write_string("de") break } - e.WriteString("d") + e.write_string("d") sv := string_values(v.MapKeys()) sort.Sort(sv) for _, key := range sv { e.reflect_string(key.String()) e.reflect_value(v.MapIndex(key)) } - e.WriteString("e") + e.write_string("e") case reflect.Slice: if v.IsNil() { - e.WriteString("le") + e.write_string("le") break } if v.Type().Elem().Kind() == reflect.Uint8 { @@ -128,11 +159,11 @@ func (e *encoder) reflect_value(v reflect.Value) { } fallthrough case reflect.Array: - e.WriteString("l") + e.write_string("l") for i, n := 0, v.Len(); i < n; i++ { e.reflect_value(v.Index(i)) } - e.WriteString("e") + e.write_string("e") case reflect.Interface, reflect.Ptr: if v.IsNil() { break -- 2.48.1