]> Sergey Matveev's repositories - btrtrc.git/blob - client.go
Squash some more debug messages, fix some other error handling
[btrtrc.git] / client.go
1 /*
2 Package torrent implements a torrent client.
3
4 Simple example:
5
6         c := &Client{}
7         c.Start()
8         defer c.Stop()
9         if err := c.AddTorrent(externalMetaInfoPackageSux); err != nil {
		return fmt.Errorf("error adding torrent: %s", err)
11         }
12         c.WaitAll()
13         log.Print("erhmahgerd, torrent downloaded")
14
15 */
16 package torrent
17
18 import (
19         "bitbucket.org/anacrolix/go.torrent/dht"
20         "bitbucket.org/anacrolix/go.torrent/util"
21         "bufio"
22         "container/list"
23         "crypto/rand"
24         "crypto/sha1"
25         "errors"
26         "fmt"
27         "io"
28         "log"
29         mathRand "math/rand"
30         "net"
31         "os"
32         "sync"
33         "syscall"
34         "time"
35
36         "github.com/anacrolix/libtorgo/metainfo"
37         "github.com/nsf/libtorgo/bencode"
38
39         pp "bitbucket.org/anacrolix/go.torrent/peer_protocol"
40         "bitbucket.org/anacrolix/go.torrent/tracker"
41         _ "bitbucket.org/anacrolix/go.torrent/tracker/udp"
42 )
43
44 // Currently doesn't really queue, but should in the future.
45 func (cl *Client) queuePieceCheck(t *torrent, pieceIndex pp.Integer) {
46         piece := t.Pieces[pieceIndex]
47         if piece.QueuedForHash {
48                 return
49         }
50         piece.QueuedForHash = true
51         go cl.verifyPiece(t, pieceIndex)
52 }
53
54 // Queues the torrent data for the given region for download. The beginning of
55 // the region is given highest priority to allow a subsequent read at the same
56 // offset to return data ASAP.
57 func (me *Client) PrioritizeDataRegion(ih InfoHash, off, len_ int64) error {
58         me.mu.Lock()
59         defer me.mu.Unlock()
60         t := me.torrent(ih)
61         if t == nil {
62                 return errors.New("no such active torrent")
63         }
64         if !t.haveInfo() {
65                 return errors.New("missing metadata")
66         }
67         newPriorities := make([]request, 0, (len_+chunkSize-1)/chunkSize)
68         for len_ > 0 {
69                 req, ok := t.offsetRequest(off)
70                 if !ok {
71                         return errors.New("bad offset")
72                 }
73                 reqOff := t.requestOffset(req)
74                 // Gain the alignment adjustment.
75                 len_ += off - reqOff
76                 // Lose the length of this block.
77                 len_ -= int64(req.Length)
78                 off = reqOff + int64(req.Length)
79                 if !t.wantPiece(int(req.Index)) {
80                         continue
81                 }
82                 newPriorities = append(newPriorities, req)
83         }
84         if len(newPriorities) == 0 {
85                 return nil
86         }
87         t.Priorities.PushFront(newPriorities[0])
88         for _, req := range newPriorities[1:] {
89                 t.Priorities.PushBack(req)
90         }
91         for _, cn := range t.Conns {
92                 me.replenishConnRequests(t, cn)
93         }
94         return nil
95 }
96
97 type dataSpec struct {
98         InfoHash
99         request
100 }
101
102 type Client struct {
103         DataDir          string
104         HalfOpenLimit    int
105         PeerId           [20]byte
106         Listener         net.Listener
107         DisableTrackers  bool
108         DownloadStrategy DownloadStrategy
109         DHT              *dht.Server
110
111         mu    sync.Mutex
112         event sync.Cond
113         quit  chan struct{}
114
115         halfOpen   int
116         torrents   map[InfoHash]*torrent
117         dataWaiter chan struct{}
118 }
119
120 func (cl *Client) WriteStatus(w io.Writer) {
121         cl.mu.Lock()
122         defer cl.mu.Unlock()
123         fmt.Fprintf(w, "Half open: %d\n", cl.halfOpen)
124         fmt.Fprintf(w, "DHT nodes: %d\n", cl.DHT.NumNodes())
125         fmt.Fprintln(w)
126         for _, t := range cl.torrents {
127                 fmt.Fprintf(w, "%s: %f%%\n", t.Name(), func() float32 {
128                         if !t.haveInfo() {
129                                 return 0
130                         } else {
131                                 return 100 * (1 - float32(t.BytesLeft())/float32(t.Length()))
132                         }
133                 }())
134                 t.WriteStatus(w)
135                 fmt.Fprintln(w)
136         }
137 }
138
139 // Read torrent data at the given offset. Returns ErrDataNotReady if the data
140 // isn't available.
141 func (cl *Client) TorrentReadAt(ih InfoHash, off int64, p []byte) (n int, err error) {
142         cl.mu.Lock()
143         defer cl.mu.Unlock()
144         t := cl.torrent(ih)
145         if t == nil {
146                 err = errors.New("unknown torrent")
147                 return
148         }
149         index := pp.Integer(off / int64(t.UsualPieceSize()))
150         // Reading outside the bounds of a file is an error.
151         if index < 0 {
152                 err = os.ErrInvalid
153                 return
154         }
155         if int(index) >= len(t.Pieces) {
156                 err = io.EOF
157                 return
158         }
159         t.lastReadPiece = int(index)
160         piece := t.Pieces[index]
161         pieceOff := pp.Integer(off % int64(t.PieceLength(0)))
162         high := int(t.PieceLength(index) - pieceOff)
163         if high < len(p) {
164                 p = p[:high]
165         }
166         for cs, _ := range piece.PendingChunkSpecs {
167                 chunkOff := int64(pieceOff) - int64(cs.Begin)
168                 if chunkOff >= int64(t.PieceLength(index)) {
169                         panic(chunkOff)
170                 }
171                 if 0 <= chunkOff && chunkOff < int64(cs.Length) {
172                         // read begins in a pending chunk
173                         err = ErrDataNotReady
174                         return
175                 }
176                 // pending chunk caps available data
177                 if chunkOff < 0 && int64(len(p)) > -chunkOff {
178                         p = p[:-chunkOff]
179                 }
180         }
181         return t.Data.ReadAt(p, off)
182 }
183
184 // Starts the client. Defaults are applied. The client will begin accepting
185 // connections and tracking.
186 func (c *Client) Start() {
187         c.event.L = &c.mu
188         c.torrents = make(map[InfoHash]*torrent)
189         if c.HalfOpenLimit == 0 {
190                 c.HalfOpenLimit = 10
191         }
192         o := copy(c.PeerId[:], BEP20)
193         _, err := rand.Read(c.PeerId[o:])
194         if err != nil {
195                 panic("error generating peer id")
196         }
197         c.quit = make(chan struct{})
198         if c.DownloadStrategy == nil {
199                 c.DownloadStrategy = &DefaultDownloadStrategy{}
200         }
201         if c.Listener != nil {
202                 go c.acceptConnections()
203         }
204 }
205
206 func (cl *Client) stopped() bool {
207         select {
208         case <-cl.quit:
209                 return true
210         default:
211                 return false
212         }
213 }
214
215 // Stops the client. All connections to peers are closed and all activity will
216 // come to a halt.
217 func (me *Client) Stop() {
218         me.mu.Lock()
219         close(me.quit)
220         me.event.Broadcast()
221         for _, t := range me.torrents {
222                 for _, c := range t.Conns {
223                         c.Close()
224                 }
225         }
226         me.mu.Unlock()
227 }
228
229 func (cl *Client) acceptConnections() {
230         for {
231                 conn, err := cl.Listener.Accept()
232                 select {
233                 case <-cl.quit:
234                         if conn != nil {
235                                 conn.Close()
236                         }
237                         return
238                 default:
239                 }
240                 if err != nil {
241                         log.Print(err)
242                         return
243                 }
244                 // log.Printf("accepted connection from %s", conn.RemoteAddr())
245                 go func() {
246                         if err := cl.runConnection(conn, nil, peerSourceIncoming); err != nil {
247                                 log.Print(err)
248                         }
249                 }()
250         }
251 }
252
253 func (me *Client) torrent(ih InfoHash) *torrent {
254         for _, t := range me.torrents {
255                 if t.InfoHash == ih {
256                         return t
257                 }
258         }
259         return nil
260 }
261
262 func (me *Client) initiateConn(peer Peer, torrent *torrent) {
263         if peer.Id == me.PeerId {
264                 return
265         }
266         me.halfOpen++
267         go func() {
268                 addr := &net.TCPAddr{
269                         IP:   peer.IP,
270                         Port: peer.Port,
271                 }
272                 conn, err := net.DialTimeout(addr.Network(), addr.String(), dialTimeout)
273
274                 go func() {
275                         me.mu.Lock()
276                         defer me.mu.Unlock()
277                         if me.halfOpen == 0 {
278                                 panic("assert")
279                         }
280                         me.halfOpen--
281                         me.openNewConns()
282                 }()
283
284                 if netOpErr, ok := err.(*net.OpError); ok {
285                         if netOpErr.Timeout() {
286                                 return
287                         }
288                         switch netOpErr.Err {
289                         case syscall.ECONNREFUSED, syscall.EHOSTUNREACH:
290                                 return
291                         }
292                 }
293                 if err != nil {
294                         log.Printf("error connecting to peer: %s %#v", err, err)
295                         return
296                 }
297                 // log.Printf("connected to %s", conn.RemoteAddr())
298                 err = me.runConnection(conn, torrent, peer.Source)
299                 if err != nil {
300                         log.Print(err)
301                 }
302         }()
303 }
304
305 func (cl *Client) incomingPeerPort() int {
306         if cl.Listener == nil {
307                 return 0
308         }
309         _, p, err := net.SplitHostPort(cl.Listener.Addr().String())
310         if err != nil {
311                 panic(err)
312         }
313         var i int
314         _, err = fmt.Sscanf(p, "%d", &i)
315         if err != nil {
316                 panic(err)
317         }
318         return i
319 }
320
321 func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerSource) (err error) {
322         conn := &connection{
323                 Discovery:       discovery,
324                 Socket:          sock,
325                 Choked:          true,
326                 PeerChoked:      true,
327                 write:           make(chan []byte),
328                 post:            make(chan pp.Message),
329                 PeerMaxRequests: 250, // Default in libtorrent is 250.
330         }
331         defer func() {
332                 // There's a lock and deferred unlock later in this function. The
333                 // client will not be locked when this deferred is invoked.
334                 me.mu.Lock()
335                 defer me.mu.Unlock()
336                 conn.Close()
337         }()
338         go conn.writer()
339         // go conn.writeOptimizer()
340         conn.write <- pp.Bytes(pp.Protocol)
341         conn.write <- pp.Bytes("\x00\x00\x00\x00\x00\x10\x00\x00")
342         if torrent != nil {
343                 conn.write <- pp.Bytes(torrent.InfoHash[:])
344                 conn.write <- pp.Bytes(me.PeerId[:])
345         }
346         var b [28]byte
347         _, err = io.ReadFull(conn.Socket, b[:])
348         if err == io.EOF {
349                 return nil
350         }
351         if err != nil {
352                 err = fmt.Errorf("when reading protocol and extensions: %s", err)
353                 return
354         }
355         if string(b[:20]) != pp.Protocol {
356                 // err = fmt.Errorf("wrong protocol: %#v", string(b[:20]))
357                 return
358         }
359         if 8 != copy(conn.PeerExtensions[:], b[20:]) {
360                 panic("wtf")
361         }
362         // log.Printf("peer extensions: %#v", string(conn.PeerExtensions[:]))
363         var infoHash [20]byte
364         _, err = io.ReadFull(conn.Socket, infoHash[:])
365         if err != nil {
366                 return fmt.Errorf("reading peer info hash: %s", err)
367         }
368         _, err = io.ReadFull(conn.Socket, conn.PeerId[:])
369         if err != nil {
370                 return fmt.Errorf("reading peer id: %s", err)
371         }
372         if torrent == nil {
373                 torrent = me.torrent(infoHash)
374                 if torrent == nil {
375                         return
376                 }
377                 conn.write <- pp.Bytes(torrent.InfoHash[:])
378                 conn.write <- pp.Bytes(me.PeerId[:])
379         }
380         me.mu.Lock()
381         defer me.mu.Unlock()
382         if !me.addConnection(torrent, conn) {
383                 return
384         }
385         go conn.writeOptimizer(time.Minute)
386         if conn.PeerExtensions[5]&0x10 != 0 {
387                 conn.Post(pp.Message{
388                         Type:       pp.Extended,
389                         ExtendedID: pp.HandshakeExtendedID,
390                         ExtendedPayload: func() []byte {
391                                 d := map[string]interface{}{
392                                         "m": map[string]int{
393                                                 "ut_metadata": 1,
394                                                 "ut_pex":      2,
395                                         },
396                                         "v": "go.torrent dev",
397                                 }
398                                 if torrent.metadataSizeKnown() {
399                                         d["metadata_size"] = torrent.metadataSize()
400                                 }
401                                 if p := me.incomingPeerPort(); p != 0 {
402                                         d["p"] = p
403                                 }
404                                 b, err := bencode.Marshal(d)
405                                 if err != nil {
406                                         panic(err)
407                                 }
408                                 return b
409                         }(),
410                 })
411         }
412         if torrent.haveAnyPieces() {
413                 conn.Post(pp.Message{
414                         Type:     pp.Bitfield,
415                         Bitfield: torrent.bitfield(),
416                 })
417         }
418         err = me.connectionLoop(torrent, conn)
419         if err != nil {
420                 err = fmt.Errorf("during Connection loop: %s", err)
421         }
422         me.dropConnection(torrent, conn)
423         return
424 }
425
426 func (me *Client) peerGotPiece(t *torrent, c *connection, piece int) {
427         for piece >= len(c.PeerPieces) {
428                 c.PeerPieces = append(c.PeerPieces, false)
429         }
430         c.PeerPieces[piece] = true
431         if t.wantPiece(piece) {
432                 me.replenishConnRequests(t, c)
433         }
434 }
435
436 func (me *Client) peerUnchoked(torrent *torrent, conn *connection) {
437         me.replenishConnRequests(torrent, conn)
438 }
439
440 func (cl *Client) connCancel(t *torrent, cn *connection, r request) (ok bool) {
441         ok = cn.Cancel(r)
442         if ok {
443                 cl.DownloadStrategy.DeleteRequest(t, r)
444         }
445         return
446 }
447
448 func (cl *Client) connDeleteRequest(t *torrent, cn *connection, r request) {
449         if !cn.RequestPending(r) {
450                 return
451         }
452         cl.DownloadStrategy.DeleteRequest(t, r)
453         delete(cn.Requests, r)
454 }
455
456 func (cl *Client) requestPendingMetadata(t *torrent, c *connection) {
457         if t.haveInfo() {
458                 return
459         }
460         var pending []int
461         for index := 0; index < t.MetadataPieceCount(); index++ {
462                 if !t.HaveMetadataPiece(index) {
463                         pending = append(pending, index)
464                 }
465         }
466         for _, i := range mathRand.Perm(len(pending)) {
467                 c.Post(pp.Message{
468                         Type:       pp.Extended,
469                         ExtendedID: byte(c.PeerExtensionIDs["ut_metadata"]),
470                         ExtendedPayload: func() []byte {
471                                 b, err := bencode.Marshal(map[string]int{
472                                         "msg_type": 0,
473                                         "piece":    pending[i],
474                                 })
475                                 if err != nil {
476                                         panic(err)
477                                 }
478                                 return b
479                         }(),
480                 })
481         }
482 }
483
484 func (cl *Client) completedMetadata(t *torrent) {
485         h := sha1.New()
486         h.Write(t.MetaData)
487         var ih InfoHash
488         copy(ih[:], h.Sum(nil)[:])
489         if ih != t.InfoHash {
490                 log.Print("bad metadata")
491                 t.InvalidateMetadata()
492                 return
493         }
494         var info metainfo.Info
495         err := bencode.Unmarshal(t.MetaData, &info)
496         if err != nil {
497                 log.Printf("error unmarshalling metadata: %s", err)
498                 t.InvalidateMetadata()
499                 return
500         }
501         // TODO(anacrolix): If this fails, I think something harsher should be
502         // done.
503         err = cl.setMetaData(t, info, t.MetaData)
504         if err != nil {
505                 log.Printf("error setting metadata: %s", err)
506                 t.InvalidateMetadata()
507                 return
508         }
509         log.Printf("%s: got metadata from peers", t)
510 }
511
512 func (cl *Client) gotMetadataExtensionMsg(payload []byte, t *torrent, c *connection) (err error) {
513         var d map[string]int
514         err = bencode.Unmarshal(payload, &d)
515         if err != nil {
516                 err = fmt.Errorf("error unmarshalling payload: %s: %q", err, payload)
517                 return
518         }
519         msgType, ok := d["msg_type"]
520         if !ok {
521                 err = errors.New("missing msg_type field")
522                 return
523         }
524         piece := d["piece"]
525         switch msgType {
526         case pp.DataMetadataExtensionMsgType:
527                 if t.haveInfo() {
528                         break
529                 }
530                 t.SaveMetadataPiece(piece, payload[len(payload)-metadataPieceSize(d["total_size"], piece):])
531                 if !t.HaveAllMetadataPieces() {
532                         break
533                 }
534                 cl.completedMetadata(t)
535         case pp.RequestMetadataExtensionMsgType:
536                 if !t.HaveMetadataPiece(piece) {
537                         c.Post(t.NewMetadataExtensionMessage(c, pp.RejectMetadataExtensionMsgType, d["piece"], nil))
538                         break
539                 }
540                 c.Post(t.NewMetadataExtensionMessage(c, pp.DataMetadataExtensionMsgType, piece, t.MetaData[(1<<14)*piece:(1<<14)*piece+t.metadataPieceSize(piece)]))
541         case pp.RejectMetadataExtensionMsgType:
542         default:
543                 err = errors.New("unknown msg_type value")
544         }
545         return
546 }
547
548 type peerExchangeMessage struct {
549         Added      util.CompactPeers `bencode:"added"`
550         AddedFlags []byte            `bencode:"added.f"`
551         Dropped    []tracker.Peer    `bencode:"dropped"`
552 }
553
554 func (me *Client) connectionLoop(t *torrent, c *connection) error {
555         decoder := pp.Decoder{
556                 R:         bufio.NewReader(c.Socket),
557                 MaxLength: 256 * 1024,
558         }
559         for {
560                 me.mu.Unlock()
561                 var msg pp.Message
562                 err := decoder.Decode(&msg)
563                 me.mu.Lock()
564                 if c.closed {
565                         return nil
566                 }
567                 if err != nil {
568                         if me.stopped() || err == io.EOF {
569                                 return nil
570                         }
571                         return err
572                 }
573                 if msg.Keepalive {
574                         continue
575                 }
576                 switch msg.Type {
577                 case pp.Choke:
578                         c.PeerChoked = true
579                         for r := range c.Requests {
580                                 me.connDeleteRequest(t, c, r)
581                         }
582                 case pp.Unchoke:
583                         c.PeerChoked = false
584                         me.peerUnchoked(t, c)
585                 case pp.Interested:
586                         c.PeerInterested = true
587                         // TODO: This should be done from a dedicated unchoking routine.
588                         c.Unchoke()
589                 case pp.NotInterested:
590                         c.PeerInterested = false
591                         c.Choke()
592                 case pp.Have:
593                         me.peerGotPiece(t, c, int(msg.Index))
594                 case pp.Request:
595                         if c.PeerRequests == nil {
596                                 c.PeerRequests = make(map[request]struct{}, maxRequests)
597                         }
598                         request := newRequest(msg.Index, msg.Begin, msg.Length)
599                         // TODO: Requests should be satisfied from a dedicated upload routine.
600                         // c.PeerRequests[request] = struct{}{}
601                         p := make([]byte, msg.Length)
602                         n, err := t.Data.ReadAt(p, int64(t.PieceLength(0))*int64(msg.Index)+int64(msg.Begin))
603                         if err != nil {
604                                 return fmt.Errorf("reading t data to serve request %q: %s", request, err)
605                         }
606                         if n != int(msg.Length) {
607                                 return fmt.Errorf("bad request: %v", msg)
608                         }
609                         c.Post(pp.Message{
610                                 Type:  pp.Piece,
611                                 Index: msg.Index,
612                                 Begin: msg.Begin,
613                                 Piece: p,
614                         })
615                 case pp.Cancel:
616                         req := newRequest(msg.Index, msg.Begin, msg.Length)
617                         if !c.PeerCancel(req) {
618                                 log.Printf("received unexpected cancel: %v", req)
619                         }
620                 case pp.Bitfield:
621                         if c.PeerPieces != nil {
622                                 err = errors.New("received unexpected bitfield")
623                                 break
624                         }
625                         if t.haveInfo() {
626                                 if len(msg.Bitfield) < t.NumPieces() {
627                                         err = errors.New("received invalid bitfield")
628                                         break
629                                 }
630                                 msg.Bitfield = msg.Bitfield[:t.NumPieces()]
631                         }
632                         c.PeerPieces = msg.Bitfield
633                         for index, has := range c.PeerPieces {
634                                 if has {
635                                         me.peerGotPiece(t, c, index)
636                                 }
637                         }
638                 case pp.Piece:
639                         err = me.downloadedChunk(t, c, &msg)
640                 case pp.Extended:
641                         switch msg.ExtendedID {
642                         case pp.HandshakeExtendedID:
643                                 // TODO: Create a bencode struct for this.
644                                 var d map[string]interface{}
645                                 err = bencode.Unmarshal(msg.ExtendedPayload, &d)
646                                 if err != nil {
647                                         err = fmt.Errorf("error decoding extended message payload: %s", err)
648                                         break
649                                 }
650                                 if reqq, ok := d["reqq"]; ok {
651                                         if i, ok := reqq.(int64); ok {
652                                                 c.PeerMaxRequests = int(i)
653                                         }
654                                 }
655                                 if v, ok := d["v"]; ok {
656                                         c.PeerClientName = v.(string)
657                                 }
658                                 m, ok := d["m"]
659                                 if !ok {
660                                         err = errors.New("handshake missing m item")
661                                         break
662                                 }
663                                 mTyped, ok := m.(map[string]interface{})
664                                 if !ok {
665                                         err = errors.New("handshake m value is not dict")
666                                         break
667                                 }
668                                 if c.PeerExtensionIDs == nil {
669                                         c.PeerExtensionIDs = make(map[string]int64, len(mTyped))
670                                 }
671                                 for name, v := range mTyped {
672                                         id, ok := v.(int64)
673                                         if !ok {
674                                                 log.Printf("bad handshake m item extension ID type: %T", v)
675                                                 continue
676                                         }
677                                         if id == 0 {
678                                                 delete(c.PeerExtensionIDs, name)
679                                         } else {
680                                                 c.PeerExtensionIDs[name] = id
681                                         }
682                                 }
683                                 metadata_sizeUntyped, ok := d["metadata_size"]
684                                 if ok {
685                                         metadata_size, ok := metadata_sizeUntyped.(int64)
686                                         if !ok {
687                                                 log.Printf("bad metadata_size type: %T", metadata_sizeUntyped)
688                                         } else {
689                                                 t.SetMetadataSize(metadata_size)
690                                         }
691                                 }
692                                 if _, ok := c.PeerExtensionIDs["ut_metadata"]; ok {
693                                         me.requestPendingMetadata(t, c)
694                                 }
695                         case 1:
696                                 err = me.gotMetadataExtensionMsg(msg.ExtendedPayload, t, c)
697                                 if err != nil {
698                                         err = fmt.Errorf("error handling metadata extension message: %s", err)
699                                 }
700                         case 2:
701                                 var pexMsg peerExchangeMessage
702                                 err := bencode.Unmarshal(msg.ExtendedPayload, &pexMsg)
703                                 if err != nil {
704                                         err = fmt.Errorf("error unmarshalling PEX message: %s", err)
705                                         break
706                                 }
707                                 go func() {
708                                         err := me.AddPeers(t.InfoHash, func() (ret []Peer) {
709                                                 for _, cp := range pexMsg.Added {
710                                                         p := Peer{
711                                                                 IP:     make([]byte, 4),
712                                                                 Port:   int(cp.Port),
713                                                                 Source: peerSourcePEX,
714                                                         }
715                                                         if n := copy(p.IP, cp.IP[:]); n != 4 {
716                                                                 panic(n)
717                                                         }
718                                                         ret = append(ret, p)
719                                                 }
720                                                 return
721                                         }())
722                                         if err != nil {
723                                                 log.Printf("error adding PEX peers: %s", err)
724                                                 return
725                                         }
726                                         log.Printf("added %d peers from PEX", len(pexMsg.Added))
727                                 }()
728                         default:
729                                 err = fmt.Errorf("unexpected extended message ID: %v", msg.ExtendedID)
730                         }
731                 default:
732                         err = fmt.Errorf("received unknown message type: %#v", msg.Type)
733                 }
734                 if err != nil {
735                         return err
736                 }
737         }
738 }
739
740 func (me *Client) dropConnection(torrent *torrent, conn *connection) {
741         conn.Socket.Close()
742         for r := range conn.Requests {
743                 me.connDeleteRequest(torrent, conn, r)
744         }
745         for i0, c := range torrent.Conns {
746                 if c != conn {
747                         continue
748                 }
749                 i1 := len(torrent.Conns) - 1
750                 if i0 != i1 {
751                         torrent.Conns[i0] = torrent.Conns[i1]
752                 }
753                 torrent.Conns = torrent.Conns[:i1]
754                 return
755         }
756         panic("connection not found")
757 }
758
759 func (me *Client) addConnection(t *torrent, c *connection) bool {
760         if me.stopped() {
761                 return false
762         }
763         for _, c0 := range t.Conns {
764                 if c.PeerId == c0.PeerId {
765                         // Already connected to a client with that ID.
766                         return false
767                 }
768         }
769         t.Conns = append(t.Conns, c)
770         return true
771 }
772
773 func (me *Client) openNewConns() {
774         for _, t := range me.torrents {
775                 for len(t.Peers) != 0 {
776                         if me.halfOpen >= me.HalfOpenLimit {
777                                 return
778                         }
779                         p := t.Peers[0]
780                         t.Peers = t.Peers[1:]
781                         me.initiateConn(p, t)
782                 }
783         }
784 }
785
786 // Adds peers to the swarm for the torrent corresponding to infoHash.
787 func (me *Client) AddPeers(infoHash InfoHash, peers []Peer) error {
788         me.mu.Lock()
789         t := me.torrent(infoHash)
790         if t == nil {
791                 return errors.New("no such torrent")
792         }
793         t.Peers = append(t.Peers, peers...)
794         me.openNewConns()
795         me.mu.Unlock()
796         return nil
797 }
798
799 func (cl *Client) setMetaData(t *torrent, md metainfo.Info, bytes []byte) (err error) {
800         err = t.setMetadata(md, cl.DataDir, bytes)
801         if err != nil {
802                 return
803         }
804         // Queue all pieces for hashing. This is done sequentially to avoid
805         // spamming goroutines.
806         for _, p := range t.Pieces {
807                 p.QueuedForHash = true
808         }
809         go func() {
810                 for i := range t.Pieces {
811                         cl.verifyPiece(t, pp.Integer(i))
812                 }
813         }()
814
815         cl.DownloadStrategy.TorrentStarted(t)
816         return
817 }
818
819 // Prepare a Torrent without any attachment to a Client. That means we can
820 // initialize fields all fields that don't require the Client without locking
821 // it.
822 func newTorrent(ih InfoHash, announceList [][]string) (t *torrent, err error) {
823         t = &torrent{
824                 InfoHash: ih,
825         }
826         t.Trackers = make([][]tracker.Client, len(announceList))
827         for tierIndex := range announceList {
828                 tier := t.Trackers[tierIndex]
829                 for _, url := range announceList[tierIndex] {
830                         tr, err := tracker.New(url)
831                         if err != nil {
832                                 log.Print(err)
833                                 continue
834                         }
835                         tier = append(tier, tr)
836                 }
837                 // The trackers within each tier must be shuffled before use.
838                 // http://stackoverflow.com/a/12267471/149482
839                 // http://www.bittorrent.org/beps/bep_0012.html#order-of-processing
840                 for i := range tier {
841                         j := mathRand.Intn(i + 1)
842                         tier[i], tier[j] = tier[j], tier[i]
843                 }
844                 t.Trackers[tierIndex] = tier
845         }
846         return
847 }
848
849 func (cl *Client) AddMagnet(uri string) (err error) {
850         m, err := ParseMagnetURI(uri)
851         if err != nil {
852                 return
853         }
854         t, err := newTorrent(m.InfoHash, [][]string{m.Trackers})
855         if err != nil {
856                 return
857         }
858         t.DisplayName = m.DisplayName
859         cl.mu.Lock()
860         defer cl.mu.Unlock()
861         err = cl.addTorrent(t)
862         if err != nil {
863                 t.Close()
864         }
865         return
866 }
867
868 func (me *Client) addTorrent(t *torrent) (err error) {
869         if _, ok := me.torrents[t.InfoHash]; ok {
870                 err = fmt.Errorf("torrent infohash collision")
871                 return
872         }
873         me.torrents[t.InfoHash] = t
874         if !me.DisableTrackers {
875                 go me.announceTorrent(t)
876         }
877         if me.DHT != nil {
878                 go me.announceTorrentDHT(t)
879         }
880         return
881 }
882
883 // Adds the torrent to the client.
884 func (me *Client) AddTorrent(metaInfo *metainfo.MetaInfo) (err error) {
885         t, err := newTorrent(BytesInfoHash(metaInfo.Info.Hash), metaInfo.AnnounceList)
886         if err != nil {
887                 return
888         }
889         me.mu.Lock()
890         defer me.mu.Unlock()
891         err = me.addTorrent(t)
892         if err != nil {
893                 return
894         }
895         err = me.setMetaData(t, metaInfo.Info.Info, metaInfo.Info.Bytes)
896         if err != nil {
897                 return
898         }
899         return
900 }
901
902 func (cl *Client) listenerAnnouncePort() (port int16) {
903         l := cl.Listener
904         if l == nil {
905                 return
906         }
907         addr := l.Addr()
908         switch data := addr.(type) {
909         case *net.TCPAddr:
910                 return int16(data.Port)
911         case *net.UDPAddr:
912                 return int16(data.Port)
913         default:
914                 log.Printf("unknown listener addr type: %T", addr)
915         }
916         return
917 }
918
919 func (cl *Client) announceTorrentDHT(t *torrent) {
920         for {
921                 ps, err := cl.DHT.GetPeers(string(t.InfoHash[:]))
922                 if err != nil {
923                         log.Printf("error getting peers from dht: %s", err)
924                         return
925                 }
926                 nextScrape := time.After(1 * time.Minute)
927         getPeers:
928                 for {
929                         select {
930                         case <-nextScrape:
931                                 break getPeers
932                         case cps, ok := <-ps.Values:
933                                 if !ok {
934                                         break getPeers
935                                 }
936                                 err = cl.AddPeers(t.InfoHash, func() (ret []Peer) {
937                                         for _, cp := range cps {
938                                                 ret = append(ret, Peer{
939                                                         IP:     cp.IP[:],
940                                                         Port:   int(cp.Port),
941                                                         Source: peerSourceDHT,
942                                                 })
943                                                 // log.Printf("peer from dht: %s", &net.UDPAddr{
944                                                 //      IP:   cp.IP[:],
945                                                 //      Port: int(cp.Port),
946                                                 // })
947                                         }
948                                         return
949                                 }())
950                                 if err != nil {
951                                         log.Printf("error adding peers from dht for torrent %q: %s", t, err)
952                                         break getPeers
953                                 }
954                                 // log.Printf("got %d peers from dht for torrent %q", len(cps), t)
955                         }
956                 }
957                 ps.Close()
958         }
959 }
960
961 func (cl *Client) announceTorrent(t *torrent) {
962         req := tracker.AnnounceRequest{
963                 Event:    tracker.Started,
964                 NumWant:  -1,
965                 Port:     cl.listenerAnnouncePort(),
966                 PeerId:   cl.PeerId,
967                 InfoHash: t.InfoHash,
968         }
969 newAnnounce:
970         for {
971                 cl.mu.Lock()
972                 req.Left = t.BytesLeft()
973                 cl.mu.Unlock()
974                 for _, tier := range t.Trackers {
975                         for trIndex, tr := range tier {
976                                 if err := tr.Connect(); err != nil {
977                                         log.Print(err)
978                                         continue
979                                 }
980                                 resp, err := tr.Announce(&req)
981                                 if err != nil {
982                                         log.Print(err)
983                                         continue
984                                 }
985                                 var peers []Peer
986                                 for _, peer := range resp.Peers {
987                                         peers = append(peers, Peer{
988                                                 IP:   peer.IP,
989                                                 Port: peer.Port,
990                                         })
991                                 }
992                                 err = cl.AddPeers(t.InfoHash, peers)
993                                 if err != nil {
994                                         log.Print(err)
995                                 } else {
996                                         log.Printf("%s: %d new peers from %s", t, len(peers), tr)
997                                 }
998                                 tier[0], tier[trIndex] = tier[trIndex], tier[0]
999                                 time.Sleep(time.Second * time.Duration(resp.Interval))
1000                                 req.Event = tracker.None
1001                                 continue newAnnounce
1002                         }
1003                 }
1004                 time.Sleep(5 * time.Second)
1005         }
1006 }
1007
1008 func (cl *Client) allTorrentsCompleted() bool {
1009         for _, t := range cl.torrents {
1010                 if !t.haveAllPieces() {
1011                         return false
1012                 }
1013         }
1014         return true
1015 }
1016
1017 // Returns true when all torrents are completely downloaded and false if the
1018 // client is stopped before that.
1019 func (me *Client) WaitAll() bool {
1020         me.mu.Lock()
1021         defer me.mu.Unlock()
1022         for !me.allTorrentsCompleted() {
1023                 if me.stopped() {
1024                         return false
1025                 }
1026                 me.event.Wait()
1027         }
1028         return true
1029 }
1030
1031 func (cl *Client) assertRequestHeat() {
1032         dds, ok := cl.DownloadStrategy.(*DefaultDownloadStrategy)
1033         if !ok {
1034                 return
1035         }
1036         for _, t := range cl.torrents {
1037                 m := make(map[request]int, 3000)
1038                 for _, cn := range t.Conns {
1039                         for r := range cn.Requests {
1040                                 m[r]++
1041                         }
1042                 }
1043                 for r, h := range dds.heat[t] {
1044                         if m[r] != h {
1045                                 panic(fmt.Sprintln(m[r], h))
1046                         }
1047                 }
1048         }
1049 }
1050
1051 func (me *Client) replenishConnRequests(t *torrent, c *connection) {
1052         if !t.haveInfo() {
1053                 return
1054         }
1055         me.DownloadStrategy.FillRequests(t, c)
1056         //me.assertRequestHeat()
1057         if len(c.Requests) == 0 && !c.PeerChoked {
1058                 c.SetInterested(false)
1059         }
1060 }
1061
1062 func (me *Client) downloadedChunk(t *torrent, c *connection, msg *pp.Message) error {
1063         req := newRequest(msg.Index, msg.Begin, pp.Integer(len(msg.Piece)))
1064
1065         // Request has been satisfied.
1066         me.connDeleteRequest(t, c, req)
1067
1068         defer me.replenishConnRequests(t, c)
1069
1070         // Do we actually want this chunk?
1071         if _, ok := t.Pieces[req.Index].PendingChunkSpecs[req.chunkSpec]; !ok {
1072                 log.Printf("got unnecessary chunk from %v: %q", req, string(c.PeerId[:]))
1073                 return nil
1074         }
1075
1076         // Write the chunk out.
1077         err := t.WriteChunk(int(msg.Index), int64(msg.Begin), msg.Piece)
1078         if err != nil {
1079                 return err
1080         }
1081
1082         // Record that we have the chunk.
1083         delete(t.Pieces[req.Index].PendingChunkSpecs, req.chunkSpec)
1084         t.PiecesByBytesLeft.ValueChanged(t.Pieces[req.Index].bytesLeftElement)
1085         if len(t.Pieces[req.Index].PendingChunkSpecs) == 0 {
1086                 me.queuePieceCheck(t, req.Index)
1087         }
1088
1089         // Unprioritize the chunk.
1090         var next *list.Element
1091         for e := t.Priorities.Front(); e != nil; e = next {
1092                 next = e.Next()
1093                 if e.Value.(request) == req {
1094                         t.Priorities.Remove(e)
1095                 }
1096         }
1097
1098         // Cancel pending requests for this chunk.
1099         cancelled := false
1100         for _, c := range t.Conns {
1101                 if me.connCancel(t, c, req) {
1102                         cancelled = true
1103                         me.replenishConnRequests(t, c)
1104                 }
1105         }
1106         if cancelled {
1107                 log.Printf("cancelled concurrent requests for %v", req)
1108         }
1109
1110         me.dataReady(dataSpec{t.InfoHash, req})
1111         return nil
1112 }
1113
1114 func (cl *Client) dataReady(ds dataSpec) {
1115         if cl.dataWaiter != nil {
1116                 close(cl.dataWaiter)
1117         }
1118         cl.dataWaiter = nil
1119 }
1120
1121 // Returns a channel that is closed when new data has become available in the
1122 // client.
1123 func (me *Client) DataWaiter() <-chan struct{} {
1124         me.mu.Lock()
1125         defer me.mu.Unlock()
1126         if me.dataWaiter == nil {
1127                 me.dataWaiter = make(chan struct{})
1128         }
1129         return me.dataWaiter
1130 }
1131
1132 func (me *Client) pieceHashed(t *torrent, piece pp.Integer, correct bool) {
1133         p := t.Pieces[piece]
1134         p.EverHashed = true
1135         if correct {
1136                 p.PendingChunkSpecs = nil
1137                 // log.Printf("%s: got piece %d, (%d/%d)", t, piece, t.NumPiecesCompleted(), t.NumPieces())
1138                 var next *list.Element
1139                 for e := t.Priorities.Front(); e != nil; e = next {
1140                         next = e.Next()
1141                         if e.Value.(request).Index == piece {
1142                                 t.Priorities.Remove(e)
1143                         }
1144                 }
1145                 me.dataReady(dataSpec{
1146                         t.InfoHash,
1147                         request{
1148                                 pp.Integer(piece),
1149                                 chunkSpec{0, pp.Integer(t.PieceLength(piece))},
1150                         },
1151                 })
1152         } else {
1153                 if len(p.PendingChunkSpecs) == 0 {
1154                         t.pendAllChunkSpecs(piece)
1155                 }
1156         }
1157         for _, conn := range t.Conns {
1158                 if correct {
1159                         conn.Post(pp.Message{
1160                                 Type:  pp.Have,
1161                                 Index: pp.Integer(piece),
1162                         })
1163                         // TODO: Cancel requests for this piece.
1164                 } else {
1165                         if conn.PeerHasPiece(piece) {
1166                                 me.replenishConnRequests(t, conn)
1167                         }
1168                 }
1169         }
1170         me.event.Broadcast()
1171 }
1172
1173 func (cl *Client) verifyPiece(t *torrent, index pp.Integer) {
1174         cl.mu.Lock()
1175         p := t.Pieces[index]
1176         for p.Hashing {
1177                 cl.event.Wait()
1178         }
1179         p.Hashing = true
1180         p.QueuedForHash = false
1181         cl.mu.Unlock()
1182         sum := t.HashPiece(index)
1183         cl.mu.Lock()
1184         p.Hashing = false
1185         cl.pieceHashed(t, index, sum == p.Hash)
1186         cl.mu.Unlock()
1187 }
1188
1189 func (me *Client) Torrents() (ret []*torrent) {
1190         me.mu.Lock()
1191         for _, t := range me.torrents {
1192                 ret = append(ret, t)
1193         }
1194         me.mu.Unlock()
1195         return
1196 }