]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Fix bugs, implement missing bits. Made a mess, can't describe properly.
authornsf <no.smile.face@gmail.com>
Wed, 27 Jun 2012 20:21:26 +0000 (02:21 +0600)
committernsf <no.smile.face@gmail.com>
Wed, 27 Jun 2012 20:21:26 +0000 (02:21 +0600)
bencode/both_test.go
bencode/decode.go
bencode/decode_test.go
bencode/encode.go

index 837e611573254eb82f8128ebb845f3daff04be82..a000f35b78deb10448f6f859213be0a56fb4d47d 100644 (file)
@@ -3,7 +3,6 @@ package bencode
 import "testing"
 import "bytes"
 import "io/ioutil"
-import "time"
 
 func load_file(name string, t *testing.T) []byte {
        data, err := ioutil.ReadFile(name)
@@ -13,8 +12,8 @@ func load_file(name string, t *testing.T) []byte {
        return data
 }
 
-func TestBothInterface(t *testing.T) {
-       data1 := load_file("_testdata/archlinux-2011.08.19-netinstall-i686.iso.torrent", t)
+func test_file_interface(t *testing.T, filename string) {
+       data1 := load_file(filename, t)
        var iface interface{}
 
        err := Unmarshal(data1, &iface)
@@ -30,6 +29,12 @@ func TestBothInterface(t *testing.T) {
        if !bytes.Equal(data1, data2) {
                t.Fatalf("equality expected\n")
        }
+
+}
+
+func TestBothInterface(t *testing.T) {
+       test_file_interface(t, "_testdata/archlinux-2011.08.19-netinstall-i686.iso.torrent")
+       test_file_interface(t, "_testdata/continuum.torrent")
 }
 
 type torrent_file struct {
@@ -50,8 +55,8 @@ type torrent_file struct {
        URLList      interface{} `bencode:"url-list,omitempty"`
 }
 
-func TestBoth(t *testing.T) {
-       data1 := load_file("_testdata/archlinux-2011.08.19-netinstall-i686.iso.torrent", t)
+func test_file(t *testing.T, filename string) {
+       data1 := load_file(filename, t)
        var f torrent_file
 
        err := Unmarshal(data1, &f)
@@ -59,19 +64,17 @@ func TestBoth(t *testing.T) {
                t.Fatal(err)
        }
 
-       t.Logf("Name: %s\n", f.Info.Name)
-       t.Logf("Length: %v bytes\n", f.Info.Length)
-       t.Logf("Announce: %s\n", f.Announce)
-       t.Logf("CreationDate: %s\n", time.Unix(f.CreationDate, 0).String())
-       t.Logf("CreatedBy: %s\n", f.CreatedBy)
-       t.Logf("Comment: %s\n", f.Comment)
-
        data2, err := Marshal(&f)
        if err != nil {
                t.Fatal(err)
        }
 
        if !bytes.Equal(data1, data2) {
+               println(string(data2))
                t.Fatalf("equality expected")
        }
 }
+
+func TestBoth(t *testing.T) {
+       test_file(t, "_testdata/archlinux-2011.08.19-netinstall-i686.iso.torrent")
+}
index 903dfe4a477abc2bb6dbbbc8ea6b3967fb7a4275..72251e50910bf4a76bc6d6717e4bbf6b851a615b 100644 (file)
@@ -319,20 +319,105 @@ func (d *decoder) parse_list(v reflect.Value) {
        }
 }
 
+func (d *decoder) read_one_value() bool {
+       b, err := d.ReadByte()
+       if err != nil {
+               panic(err)
+       }
+       if b == 'e' {
+               d.UnreadByte()
+               return false
+       } else {
+               d.offset++
+               d.buf.WriteByte(b)
+       }
+
+       switch b {
+       case 'd', 'l':
+               // read until there is nothing to read
+               for d.read_one_value() {}
+               // consume 'e' as well
+               b = d.read_byte()
+               d.buf.WriteByte(b)
+       case 'i':
+               d.read_until('e')
+               d.buf.WriteString("e")
+       default:
+               if b >= '0' && b <= '9' {
+                       start := d.buf.Len() - 1
+                       d.read_until(':')
+                       length, err := strconv.ParseInt(d.buf.String()[start:], 10, 64)
+                       check_for_int_parse_error(err, d.offset - 1)
+
+                       d.buf.WriteString(":")
+                       n, err := io.CopyN(&d.buf, d, length)
+                       d.offset += n
+                       if err != nil {
+                               check_for_unexpected_eof(err, d.offset)
+                               panic(&SyntaxError{
+                                       Offset: d.offset,
+                                       what:   "unexpected I/O error: " + err.Error(),
+                               })
+                       }
+                       break
+               }
+
+               // unknown value
+               panic(&SyntaxError{
+                       Offset: d.offset - 1,
+                       what:   "unknown value type (invalid bencode?)",
+               })
+       }
+
+       return true
+
+}
+
+func (d *decoder) parse_unmarshaler(v reflect.Value) bool {
+       m, ok := v.Interface().(Unmarshaler)
+       if !ok {
+               // T doesn't work, try *T
+               if v.Kind() != reflect.Ptr && v.CanAddr() {
+                       m, ok = v.Addr().Interface().(Unmarshaler)
+                       if ok {
+                               v = v.Addr()
+                       }
+               }
+       }
+       if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
+               if d.read_one_value() {
+                       err := m.UnmarshalBencode(d.buf.Bytes())
+                       d.buf.Reset()
+                       if err != nil {
+                               panic(err)
+                       }
+                       return true
+               }
+               d.buf.Reset()
+       }
+
+       return false
+}
+
 // returns true if there was a value and it's now stored in 'v', otherwise there
 // was an end symbol ("e") and no value was stored
 func (d *decoder) parse_value(v reflect.Value) bool {
-       if pv := v; pv.Kind() == reflect.Ptr {
+       // we support one level of indirection at the moment
+       if v.Kind() == reflect.Ptr {
                // if the pointer is nil, allocate a new element of the type it
                // points to
-               if pv.IsNil() {
-                       pv.Set(reflect.New(pv.Type().Elem()))
+               if v.IsNil() {
+                       v.Set(reflect.New(v.Type().Elem()))
                }
-               v = pv.Elem()
+               v = v.Elem()
+       }
+
+       if d.parse_unmarshaler(v) {
+               return true
        }
 
-       // common case
-       if v.Kind() == reflect.Interface {
+       // common case: interface{}
+       if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
                iface, _ := d.parse_value_interface()
                v.Set(reflect.ValueOf(iface))
                return true
index 4f06f391a5f1070da83e73057071b55e087a39c8..b0714a746ed22b69e3b8f1e9352349cec4c0d24d 100644 (file)
@@ -34,3 +34,44 @@ func TestRandomDecode(t *testing.T) {
                }
        }
 }
+
+func check_error(t *testing.T, err error) {
+       if err != nil {
+               t.Error(err)
+       }
+}
+
+func assert_equal(t *testing.T, x, y interface{}) {
+       if !reflect.DeepEqual(x, y) {
+               t.Errorf("got: %v (%T), expected: %v (%T)\n", x, x, y, y)
+       }
+}
+
+type unmarshaler_int struct {
+       x int
+}
+
+func (this *unmarshaler_int) UnmarshalBencode(data []byte) error {
+       return Unmarshal(data, &this.x)
+}
+
+type unmarshaler_string struct {
+       x string
+}
+
+func (this *unmarshaler_string) UnmarshalBencode(data []byte) error {
+       this.x = string(data)
+       return nil
+}
+
+func TestUnmarshalerBencode(t *testing.T) {
+       var i unmarshaler_int
+       var ss []unmarshaler_string
+       check_error(t, Unmarshal([]byte("i71e"), &i))
+       assert_equal(t, i.x, 71)
+       check_error(t, Unmarshal([]byte("l5:hello5:fruit3:waye"), &ss))
+       assert_equal(t, ss[0].x, "5:hello")
+       assert_equal(t, ss[1].x, "5:fruit")
+       assert_equal(t, ss[2].x, "3:way")
+
+}
index 4a6eea171df537aecb4c87927d44d2dc5a87444a..196aa09539b944e8ba3502bc722bce4c2c58eae2 100644 (file)
@@ -78,13 +78,12 @@ func (e *encoder) reflect_byte_slice(s []byte) {
        e.write(s)
 }
 
-func (e *encoder) reflect_value(v reflect.Value) {
-       if !v.IsValid() {
-               return
-       }
-
+// returns true if the value implements Marshaler interface and marshaling was
+// done successfully
+func (e *encoder) reflect_marshaler(v reflect.Value) bool {
        m, ok := v.Interface().(Marshaler)
        if !ok {
+               // T doesn't work, try *T
                if v.Kind() != reflect.Ptr && v.CanAddr() {
                        m, ok = v.Addr().Interface().(Marshaler)
                        if ok {
@@ -98,6 +97,18 @@ func (e *encoder) reflect_value(v reflect.Value) {
                        panic(&MarshalerError{v.Type(), err})
                }
                e.write(data)
+               return true
+       }
+
+       return false
+}
+
+func (e *encoder) reflect_value(v reflect.Value) {
+       if !v.IsValid() {
+               return
+       }
+
+       if e.reflect_marshaler(v) {
                return
        }