]> Sergey Matveev's repositories - btrtrc.git/commitdiff
metainfo: Add Magnet.Params for more open handling
authorMatt Joiner <anacrolix@gmail.com>
Tue, 24 Sep 2019 05:52:18 +0000 (15:52 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Tue, 24 Sep 2019 05:52:18 +0000 (15:52 +1000)
Addresses #310.

metainfo/magnet.go

index 8da585be4129ecf95cf9b010a4f89e2aa6b46225..c9221a074b720db39e1c227fd579b06b90bd5597 100644 (file)
@@ -3,6 +3,7 @@ package metainfo
 import (
        "encoding/base32"
        "encoding/hex"
+       "errors"
        "fmt"
        "net/url"
        "strings"
@@ -13,65 +14,98 @@ type Magnet struct {
        InfoHash    Hash
        Trackers    []string
        DisplayName string
+       Params      url.Values
 }
 
 const xtPrefix = "urn:btih:"
 
 func (m Magnet) String() string {
-       // net.URL likes to assume //, and encodes ':' on us, so we do most of
-       // this manually.
-       ret := "magnet:?xt="
-       ret += xtPrefix + hex.EncodeToString(m.InfoHash[:])
-       if m.DisplayName != "" {
-               ret += "&dn=" + url.QueryEscape(m.DisplayName)
+       // Deep-copy m.Params
+       vs := make(url.Values, len(m.Params)+len(m.Trackers)+2)
+       for k, v := range m.Params {
+               vs[k] = append([]string(nil), v...)
        }
+
+       vs.Add("xt", xtPrefix+m.InfoHash.HexString())
        for _, tr := range m.Trackers {
-               ret += "&tr=" + url.QueryEscape(tr)
+               vs.Add("tr", tr)
+       }
+       if m.DisplayName != "" {
+               vs.Add("dn", m.DisplayName)
        }
-       return ret
+
+       return (&url.URL{
+               Scheme:   "magnet",
+               RawQuery: vs.Encode(),
+       }).String()
 }
 
 // ParseMagnetURI parses Magnet-formatted URIs into a Magnet instance
 func ParseMagnetURI(uri string) (m Magnet, err error) {
        u, err := url.Parse(uri)
        if err != nil {
-               err = fmt.Errorf("error parsing uri: %s", err)
+               err = fmt.Errorf("error parsing uri: %w", err)
                return
        }
        if u.Scheme != "magnet" {
-               err = fmt.Errorf("unexpected scheme: %q", u.Scheme)
+               err = fmt.Errorf("unexpected scheme %q", u.Scheme)
                return
        }
-       xt := u.Query().Get("xt")
-       if !strings.HasPrefix(xt, xtPrefix) {
-               err = fmt.Errorf("bad xt parameter")
+       q := u.Query()
+       xt := q.Get("xt")
+       m.InfoHash, err = parseInfohash(q.Get("xt"))
+       if err != nil {
+               err = fmt.Errorf("error parsing infohash %q: %w", xt, err)
                return
        }
-       infoHash := xt[len(xtPrefix):]
-
-       // BTIH hash can be in HEX or BASE32 encoding
-       // will assign appropriate func judging from symbol length
-       var decode func(dst, src []byte) (int, error)
-       switch len(infoHash) {
-       case 40:
-               decode = hex.Decode
-       case 32:
-               decode = base32.StdEncoding.Decode
+       dropFirst(q, "xt")
+       m.DisplayName = q.Get("dn")
+       dropFirst(q, "dn")
+       m.Trackers = q["tr"]
+       delete(q, "tr")
+       if len(q) == 0 {
+               q = nil
        }
+       m.Params = q
+       return
+}
 
+func parseInfohash(xt string) (ih Hash, err error) {
+       if !strings.HasPrefix(xt, xtPrefix) {
+               err = errors.New("bad xt parameter prefix")
+               return
+       }
+       encoded := xt[len(xtPrefix):]
+       decode := func() func(dst, src []byte) (int, error) {
+               switch len(encoded) {
+               case 40:
+                       return hex.Decode
+               case 32:
+                       return base32.StdEncoding.Decode
+               }
+               return nil
+       }()
        if decode == nil {
-               err = fmt.Errorf("unhandled xt parameter encoding: encoded length %d", len(infoHash))
+               err = fmt.Errorf("unhandled xt parameter encoding (encoded length %d)", len(encoded))
                return
        }
-       n, err := decode(m.InfoHash[:], []byte(infoHash))
+       n, err := decode(ih[:], []byte(encoded))
        if err != nil {
-               err = fmt.Errorf("error decoding xt: %s", err)
+               err = fmt.Errorf("error decoding xt: %w", err)
                return
        }
        if n != 20 {
                panic(n)
        }
-       m.DisplayName = u.Query().Get("dn")
-       m.Trackers = u.Query()["tr"]
        return
 }
+
+func dropFirst(vs url.Values, key string) {
+       sl := vs[key]
+       switch len(sl) {
+       case 0, 1:
+               vs.Del(key)
+       default:
+               vs[key] = sl[1:]
+       }
+}