]> Sergey Matveev's repositories - btrtrc.git/blob - webseed/client.go
webseed: Close unused part responses after error
[btrtrc.git] / webseed / client.go
1 package webseed
2
3 import (
4         "bytes"
5         "context"
6         "fmt"
7         "io"
8         "net/http"
9         "strings"
10
11         "github.com/RoaringBitmap/roaring"
12         "github.com/anacrolix/torrent/common"
13         "github.com/anacrolix/torrent/metainfo"
14         "github.com/anacrolix/torrent/segments"
15 )
16
17 type RequestSpec = segments.Extent
18
19 type requestPartResult struct {
20         resp *http.Response
21         err  error
22 }
23
24 type requestPart struct {
25         req    *http.Request
26         e      segments.Extent
27         result chan requestPartResult
28 }
29
30 type Request struct {
31         cancel func()
32         Result chan RequestResult
33 }
34
35 func (r Request) Cancel() {
36         r.cancel()
37 }
38
39 type Client struct {
40         HttpClient *http.Client
41         Url        string
42         fileIndex  segments.Index
43         info       *metainfo.Info
44         // The pieces we can request with the Url. We're more likely to ban/block at the file-level
45         // given that's how requests are mapped to webseeds, but the torrent.Client works at the piece
46         // level. We can map our file-level adjustments to the pieces here.
47         Pieces roaring.Bitmap
48 }
49
50 func (me *Client) SetInfo(info *metainfo.Info) {
51         if !strings.HasSuffix(me.Url, "/") && info.IsDir() {
52                 // In my experience, this is a non-conforming webseed. For example the
53                 // http://ia600500.us.archive.org/1/items URLs in archive.org torrents.
54                 return
55         }
56         me.fileIndex = segments.NewIndex(common.LengthIterFromUpvertedFiles(info.UpvertedFiles()))
57         me.info = info
58         me.Pieces.AddRange(0, uint64(info.NumPieces()))
59 }
60
61 type RequestResult struct {
62         Bytes []byte
63         Err   error
64 }
65
66 func (ws *Client) NewRequest(r RequestSpec) Request {
67         ctx, cancel := context.WithCancel(context.Background())
68         var requestParts []requestPart
69         if !ws.fileIndex.Locate(r, func(i int, e segments.Extent) bool {
70                 req, err := NewRequest(ws.Url, i, ws.info, e.Start, e.Length)
71                 if err != nil {
72                         panic(err)
73                 }
74                 req = req.WithContext(ctx)
75                 part := requestPart{
76                         req:    req,
77                         result: make(chan requestPartResult, 1),
78                         e:      e,
79                 }
80                 go func() {
81                         resp, err := ws.HttpClient.Do(req)
82                         part.result <- requestPartResult{
83                                 resp: resp,
84                                 err:  err,
85                         }
86                 }()
87                 requestParts = append(requestParts, part)
88                 return true
89         }) {
90                 panic("request out of file bounds")
91         }
92         req := Request{
93                 cancel: cancel,
94                 Result: make(chan RequestResult, 1),
95         }
96         go func() {
97                 b, err := readRequestPartResponses(ctx, requestParts)
98                 req.Result <- RequestResult{
99                         Bytes: b,
100                         Err:   err,
101                 }
102         }()
103         return req
104 }
105
106 type ErrBadResponse struct {
107         Msg      string
108         Response *http.Response
109 }
110
111 func (me ErrBadResponse) Error() string {
112         return me.Msg
113 }
114
115 func recvPartResult(ctx context.Context, buf io.Writer, part requestPart) error {
116         result := <-part.result
117         if result.err != nil {
118                 return result.err
119         }
120         defer result.resp.Body.Close()
121         if ctx.Err() != nil {
122                 return ctx.Err()
123         }
124         switch result.resp.StatusCode {
125         case http.StatusPartialContent:
126         case http.StatusOK:
127                 if part.e.Start != 0 {
128                         return ErrBadResponse{"got status ok but request was at offset", result.resp}
129                 }
130         default:
131                 return ErrBadResponse{
132                         fmt.Sprintf("unhandled response status code (%v)", result.resp.StatusCode),
133                         result.resp,
134                 }
135         }
136         copied, err := io.Copy(buf, result.resp.Body)
137         if err != nil {
138                 return err
139         }
140         if copied != part.e.Length {
141                 return fmt.Errorf("got %v bytes, expected %v", copied, part.e.Length)
142         }
143         return nil
144 }
145
146 func readRequestPartResponses(ctx context.Context, parts []requestPart) ([]byte, error) {
147         ctx, cancel := context.WithCancel(ctx)
148         defer cancel()
149         var buf bytes.Buffer
150         firstErr := make(chan error, 1)
151         go func() {
152                 for _, part := range parts {
153                         err := recvPartResult(ctx, &buf, part)
154                         if err != nil {
155                                 // Ensure no further unnecessary response reads occur.
156                                 cancel()
157                                 select {
158                                 case firstErr <- fmt.Errorf("reading %q at %q: %w", part.req.URL, part.req.Header.Get("Range"), err):
159                                 default:
160                                 }
161                         }
162                 }
163                 select {
164                 case firstErr <- nil:
165                 default:
166                 }
167         }()
168         // This can't be merged into the return statement, because buf.Bytes is called first!
169         err := <-firstErr
170         return buf.Bytes(), err
171 }