]> Sergey Matveev's repositories - btrtrc.git/blob - client.go
Fix request/chunk confusion, missing outgoing message prefix, protocol tests; improve...
[btrtrc.git] / client.go
1 package torrent
2
3 import (
4         "bitbucket.org/anacrolix/go.torrent/peer_protocol"
5         "bufio"
6         "container/list"
7         "crypto"
8         "crypto/rand"
9         "encoding"
10         "errors"
11         metainfo "github.com/nsf/libtorgo/torrent"
12         "io"
13         "launchpad.net/gommap"
14         "log"
15         "net"
16         "os"
17         "path/filepath"
18 )
19
20 const (
21         PieceHash   = crypto.SHA1
22         maxRequests = 10
23         chunkSize   = 0x4000 // 16KiB
24 )
25
26 type InfoHash [20]byte
27
28 type pieceSum [20]byte
29
30 func copyHashSum(dst, src []byte) {
31         if len(dst) != len(src) || copy(dst, src) != len(dst) {
32                 panic("hash sum sizes differ")
33         }
34 }
35
36 func BytesInfoHash(b []byte) (ih InfoHash) {
37         if len(b) != len(ih) || copy(ih[:], b) != len(ih) {
38                 panic("bad infohash bytes")
39         }
40         return
41 }
42
43 type pieceState uint8
44
45 const (
46         pieceStateUnknown = iota
47         pieceStateComplete
48         pieceStateIncomplete
49 )
50
51 type piece struct {
52         State             pieceState
53         Hash              pieceSum
54         PendingChunkSpecs map[chunkSpec]struct{}
55 }
56
57 type chunkSpec struct {
58         Begin, Length peer_protocol.Integer
59 }
60
61 type request struct {
62         Index peer_protocol.Integer
63         chunkSpec
64 }
65
66 type connection struct {
67         Socket net.Conn
68         post   chan encoding.BinaryMarshaler
69         write  chan []byte
70
71         Interested bool
72         Choked     bool
73         Requests   map[request]struct{}
74
75         PeerId         [20]byte
76         PeerInterested bool
77         PeerChoked     bool
78         PeerRequests   map[request]struct{}
79         PeerExtensions [8]byte
80         PeerPieces     []bool
81 }
82
83 func (c *connection) PeerHasPiece(index int) bool {
84         if c.PeerPieces == nil {
85                 return false
86         }
87         return c.PeerPieces[index]
88 }
89
90 func (c *connection) Post(msg encoding.BinaryMarshaler) {
91         c.post <- msg
92 }
93
94 func (c *connection) Request(chunk request) bool {
95         if len(c.Requests) >= maxRequests {
96                 return false
97         }
98         if _, ok := c.Requests[chunk]; !ok {
99                 c.Post(peer_protocol.Message{
100                         Type:   peer_protocol.Request,
101                         Index:  chunk.Index,
102                         Begin:  chunk.Begin,
103                         Length: chunk.Length,
104                 })
105         }
106         if c.Requests == nil {
107                 c.Requests = make(map[request]struct{}, maxRequests)
108         }
109         c.Requests[chunk] = struct{}{}
110         return true
111 }
112
113 func (c *connection) SetInterested(interested bool) {
114         if c.Interested == interested {
115                 return
116         }
117         c.Post(peer_protocol.Message{
118                 Type: func() peer_protocol.MessageType {
119                         if interested {
120                                 return peer_protocol.Interested
121                         } else {
122                                 return peer_protocol.NotInterested
123                         }
124                 }(),
125         })
126         c.Interested = interested
127 }
128
129 func (conn *connection) writer() {
130         for {
131                 b := <-conn.write
132                 n, err := conn.Socket.Write(b)
133                 if err != nil {
134                         log.Print(err)
135                         close(conn.write)
136                         break
137                 }
138                 if n != len(b) {
139                         panic("didn't write all bytes")
140                 }
141                 log.Printf("wrote %#v", string(b))
142         }
143 }
144
145 func (conn *connection) writeOptimizer() {
146         pending := list.New()
147         var nextWrite []byte
148         for {
149                 write := conn.write
150                 if pending.Len() == 0 {
151                         write = nil
152                 } else {
153                         var err error
154                         nextWrite, err = pending.Front().Value.(encoding.BinaryMarshaler).MarshalBinary()
155                         if err != nil {
156                                 panic(err)
157                         }
158                 }
159                 select {
160                 case msg := <-conn.post:
161                         pending.PushBack(msg)
162                 case write <- nextWrite:
163                         pending.Remove(pending.Front())
164                 }
165         }
166 }
167
168 type torrent struct {
169         InfoHash InfoHash
170         Pieces   []piece
171         Data     MMapSpan
172         MetaInfo *metainfo.MetaInfo
173         Conns    []*connection
174         Peers    []Peer
175 }
176
177 func (t *torrent) bitfield() (bf []bool) {
178         for _, p := range t.Pieces {
179                 bf = append(bf, p.State == pieceStateComplete)
180         }
181         return
182 }
183
184 func (t *torrent) pieceChunkSpecs(index int) (cs map[chunkSpec]struct{}) {
185         cs = make(map[chunkSpec]struct{}, (t.MetaInfo.PieceLength+chunkSize-1)/chunkSize)
186         c := chunkSpec{
187                 Begin: 0,
188         }
189         for left := peer_protocol.Integer(t.PieceSize(index)); left > 0; left -= c.Length {
190                 c.Length = left
191                 if c.Length > chunkSize {
192                         c.Length = chunkSize
193                 }
194                 cs[c] = struct{}{}
195                 c.Begin += c.Length
196         }
197         return
198 }
199
200 func (t *torrent) requestHeat() (ret map[request]int) {
201         ret = make(map[request]int)
202         for _, conn := range t.Conns {
203                 for req, _ := range conn.Requests {
204                         ret[req]++
205                 }
206         }
207         return
208 }
209
210 type Peer struct {
211         Id   [20]byte
212         IP   net.IP
213         Port int
214 }
215
216 func (t *torrent) PieceSize(piece int) (size int64) {
217         if piece == len(t.Pieces)-1 {
218                 size = t.Data.Size() % t.MetaInfo.PieceLength
219         }
220         if size == 0 {
221                 size = t.MetaInfo.PieceLength
222         }
223         return
224 }
225
226 func (t *torrent) PieceReader(piece int) io.Reader {
227         return io.NewSectionReader(t.Data, int64(piece)*t.MetaInfo.PieceLength, t.MetaInfo.PieceLength)
228 }
229
230 func (t *torrent) HashPiece(piece int) (ps pieceSum) {
231         hash := PieceHash.New()
232         n, err := io.Copy(hash, t.PieceReader(piece))
233         if err != nil {
234                 panic(err)
235         }
236         if n != t.PieceSize(piece) {
237                 panic("hashed wrong number of bytes")
238         }
239         copyHashSum(ps[:], hash.Sum(nil))
240         return
241 }
242
243 // func (t *torrent) bitfield
244
245 type client struct {
246         DataDir       string
247         HalfOpenLimit int
248         PeerId        [20]byte
249
250         halfOpen int
251         torrents map[InfoHash]*torrent
252
253         noTorrents      chan struct{}
254         addTorrent      chan *torrent
255         torrentFinished chan InfoHash
256         actorTask       chan func()
257 }
258
259 func NewClient(dataDir string) *client {
260         c := &client{
261                 DataDir:       dataDir,
262                 HalfOpenLimit: 10,
263
264                 torrents: make(map[InfoHash]*torrent),
265
266                 noTorrents:      make(chan struct{}),
267                 addTorrent:      make(chan *torrent),
268                 torrentFinished: make(chan InfoHash),
269                 actorTask:       make(chan func()),
270         }
271         _, err := rand.Read(c.PeerId[:])
272         if err != nil {
273                 panic("error generating peer id")
274         }
275         go c.run()
276         return c
277 }
278
279 func mmapTorrentData(metaInfo *metainfo.MetaInfo, location string) (mms MMapSpan, err error) {
280         defer func() {
281                 if err != nil {
282                         mms.Close()
283                         mms = nil
284                 }
285         }()
286         for _, miFile := range metaInfo.Files {
287                 fileName := filepath.Join(append([]string{location, metaInfo.Name}, miFile.Path...)...)
288                 err = os.MkdirAll(filepath.Dir(fileName), 0666)
289                 if err != nil {
290                         return
291                 }
292                 var file *os.File
293                 file, err = os.OpenFile(fileName, os.O_CREATE|os.O_RDWR, 0666)
294                 if err != nil {
295                         return
296                 }
297                 func() {
298                         defer file.Close()
299                         var fi os.FileInfo
300                         fi, err = file.Stat()
301                         if err != nil {
302                                 return
303                         }
304                         if fi.Size() < miFile.Length {
305                                 err = file.Truncate(miFile.Length)
306                                 if err != nil {
307                                         return
308                                 }
309                         }
310                         var mMap gommap.MMap
311                         mMap, err = gommap.MapRegion(file.Fd(), 0, miFile.Length, gommap.PROT_READ|gommap.PROT_WRITE, gommap.MAP_SHARED)
312                         if err != nil {
313                                 return
314                         }
315                         if int64(len(mMap)) != miFile.Length {
316                                 panic("mmap has wrong length")
317                         }
318                         mms = append(mms, MMap{mMap})
319                 }()
320                 if err != nil {
321                         return
322                 }
323         }
324         return
325 }
326
327 func (me *client) torrent(ih InfoHash) *torrent {
328         for _, t := range me.torrents {
329                 if t.InfoHash == ih {
330                         return t
331                 }
332         }
333         return nil
334 }
335
336 func (me *client) initiateConn(peer Peer, torrent *torrent) {
337         if peer.Id == me.PeerId {
338                 return
339         }
340         me.halfOpen++
341         go func() {
342                 conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{
343                         IP:   peer.IP,
344                         Port: peer.Port,
345                 })
346                 me.withContext(func() {
347                         me.halfOpen--
348                         me.openNewConns()
349                 })
350                 if err != nil {
351                         log.Printf("error connecting to peer: %s", err)
352                         return
353                 }
354                 log.Printf("connected to %s", conn.RemoteAddr())
355                 me.handshake(conn, torrent, peer.Id)
356         }()
357 }
358
359 func (me *torrent) haveAnyPieces() bool {
360         for _, piece := range me.Pieces {
361                 if piece.State == pieceStateComplete {
362                         return true
363                 }
364         }
365         return false
366 }
367
368 func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) {
369         conn := &connection{
370                 Socket:     sock,
371                 Choked:     true,
372                 PeerChoked: true,
373                 write:      make(chan []byte),
374                 post:       make(chan encoding.BinaryMarshaler),
375         }
376         go conn.writer()
377         go conn.writeOptimizer()
378         conn.post <- peer_protocol.Bytes(peer_protocol.Protocol)
379         conn.post <- peer_protocol.Bytes("\x00\x00\x00\x00\x00\x00\x00\x00")
380         if torrent != nil {
381                 conn.post <- peer_protocol.Bytes(torrent.InfoHash[:])
382                 conn.post <- peer_protocol.Bytes(me.PeerId[:])
383         }
384         var b [28]byte
385         _, err := io.ReadFull(conn.Socket, b[:])
386         if err != nil {
387                 log.Fatal(err)
388         }
389         if string(b[:20]) != peer_protocol.Protocol {
390                 log.Printf("wrong protocol: %#v", string(b[:20]))
391                 return
392         }
393         if 8 != copy(conn.PeerExtensions[:], b[20:]) {
394                 panic("wtf")
395         }
396         log.Printf("peer extensions: %#v", string(conn.PeerExtensions[:]))
397         var infoHash [20]byte
398         _, err = io.ReadFull(conn.Socket, infoHash[:])
399         if err != nil {
400                 return
401         }
402         _, err = io.ReadFull(conn.Socket, conn.PeerId[:])
403         if err != nil {
404                 return
405         }
406         if torrent == nil {
407                 torrent = me.torrent(infoHash)
408                 if torrent == nil {
409                         return
410                 }
411                 conn.post <- peer_protocol.Bytes(torrent.InfoHash[:])
412                 conn.post <- peer_protocol.Bytes(me.PeerId[:])
413         }
414         me.withContext(func() {
415                 me.addConnection(torrent, conn)
416                 if torrent.haveAnyPieces() {
417                         conn.Post(peer_protocol.Message{
418                                 Type:     peer_protocol.Bitfield,
419                                 Bitfield: torrent.bitfield(),
420                         })
421                 }
422                 go func() {
423                         defer me.withContext(func() {
424                                 me.dropConnection(torrent, conn)
425                         })
426                         err := me.runConnection(torrent, conn)
427                         if err != nil {
428                                 log.Print(err)
429                         }
430                 }()
431         })
432 }
433
434 func (me *client) peerGotPiece(torrent *torrent, conn *connection, piece int) {
435         if conn.PeerPieces == nil {
436                 conn.PeerPieces = make([]bool, len(torrent.Pieces))
437         }
438         conn.PeerPieces[piece] = true
439         if torrent.wantPiece(piece) {
440                 conn.SetInterested(true)
441                 me.replenishConnRequests(torrent, conn)
442         }
443 }
444
445 func (t *torrent) wantPiece(index int) bool {
446         return t.Pieces[index].State == pieceStateIncomplete
447 }
448
449 func (me *client) peerUnchoked(torrent *torrent, conn *connection) {
450         me.replenishConnRequests(torrent, conn)
451 }
452
453 func (me *client) runConnection(torrent *torrent, conn *connection) error {
454         decoder := peer_protocol.Decoder{
455                 R:         bufio.NewReader(conn.Socket),
456                 MaxLength: 256 * 1024,
457         }
458         for {
459                 msg := new(peer_protocol.Message)
460                 err := decoder.Decode(msg)
461                 if err != nil {
462                         return err
463                 }
464                 if msg.Keepalive {
465                         continue
466                 }
467                 go me.withContext(func() {
468                         log.Print(msg)
469                         var err error
470                         switch msg.Type {
471                         case peer_protocol.Choke:
472                                 conn.PeerChoked = true
473                         case peer_protocol.Unchoke:
474                                 conn.PeerChoked = false
475                                 me.peerUnchoked(torrent, conn)
476                         case peer_protocol.Interested:
477                                 conn.PeerInterested = true
478                         case peer_protocol.NotInterested:
479                                 conn.PeerInterested = false
480                         case peer_protocol.Have:
481                                 me.peerGotPiece(torrent, conn, int(msg.Index))
482                         case peer_protocol.Request:
483                                 conn.PeerRequests[request{
484                                         Index:     msg.Index,
485                                         chunkSpec: chunkSpec{msg.Begin, msg.Length},
486                                 }] = struct{}{}
487                         case peer_protocol.Bitfield:
488                                 if len(msg.Bitfield) < len(torrent.Pieces) {
489                                         err = errors.New("received invalid bitfield")
490                                         break
491                                 }
492                                 if conn.PeerPieces != nil {
493                                         err = errors.New("received unexpected bitfield")
494                                         break
495                                 }
496                                 conn.PeerPieces = msg.Bitfield[:len(torrent.Pieces)]
497                                 for index, has := range conn.PeerPieces {
498                                         if has {
499                                                 me.peerGotPiece(torrent, conn, index)
500                                         }
501                                 }
502                         default:
503                                 log.Printf("received unknown message type: %#v", msg.Type)
504                         }
505                         if err != nil {
506                                 log.Print(err)
507                                 me.dropConnection(torrent, conn)
508                         }
509                 })
510         }
511 }
512
513 func (me *client) dropConnection(torrent *torrent, conn *connection) {
514         conn.Socket.Close()
515         for i0, c := range torrent.Conns {
516                 if c != conn {
517                         continue
518                 }
519                 i1 := len(torrent.Conns) - 1
520                 if i0 != i1 {
521                         torrent.Conns[i0] = torrent.Conns[i1]
522                 }
523                 torrent.Conns = torrent.Conns[:i1]
524                 return
525         }
526         panic("no such connection")
527 }
528
529 func (me *client) addConnection(t *torrent, c *connection) bool {
530         for _, c := range t.Conns {
531                 if c.PeerId == c.PeerId {
532                         return false
533                 }
534         }
535         t.Conns = append(t.Conns, c)
536         return true
537 }
538
539 func (me *client) openNewConns() {
540         for _, t := range me.torrents {
541                 for len(t.Peers) != 0 {
542                         if me.halfOpen >= me.HalfOpenLimit {
543                                 return
544                         }
545                         p := t.Peers[0]
546                         t.Peers = t.Peers[1:]
547                         me.initiateConn(p, t)
548                 }
549         }
550 }
551
552 func (me *client) AddPeers(infoHash InfoHash, peers []Peer) (err error) {
553         me.withContext(func() {
554                 t := me.torrent(infoHash)
555                 if t == nil {
556                         err = errors.New("no such torrent")
557                         return
558                 }
559                 t.Peers = append(t.Peers, peers...)
560                 me.openNewConns()
561         })
562         return
563 }
564
565 func (me *client) AddTorrent(metaInfo *metainfo.MetaInfo) error {
566         torrent := &torrent{
567                 InfoHash: BytesInfoHash(metaInfo.InfoHash),
568         }
569         for offset := 0; offset < len(metaInfo.Pieces); offset += PieceHash.Size() {
570                 hash := metaInfo.Pieces[offset : offset+PieceHash.Size()]
571                 if len(hash) != PieceHash.Size() {
572                         return errors.New("bad piece hash in metainfo")
573                 }
574                 piece := piece{}
575                 copyHashSum(piece.Hash[:], hash)
576                 torrent.Pieces = append(torrent.Pieces, piece)
577         }
578         var err error
579         torrent.Data, err = mmapTorrentData(metaInfo, me.DataDir)
580         if err != nil {
581                 return err
582         }
583         torrent.MetaInfo = metaInfo
584         me.addTorrent <- torrent
585         return nil
586 }
587
588 func (me *client) WaitAll() {
589         <-me.noTorrents
590 }
591
592 func (me *client) Close() {
593 }
594
595 func (me *client) withContext(f func()) {
596         me.actorTask <- f
597 }
598
599 func (me *client) replenishConnRequests(torrent *torrent, conn *connection) {
600         if len(conn.Requests) >= maxRequests {
601                 return
602         }
603         if conn.PeerChoked {
604                 return
605         }
606         requestHeatMap := torrent.requestHeat()
607         for index, has := range conn.PeerPieces {
608                 if !has {
609                         continue
610                 }
611                 for chunkSpec, _ := range torrent.Pieces[index].PendingChunkSpecs {
612                         request := request{peer_protocol.Integer(index), chunkSpec}
613                         if heat := requestHeatMap[request]; heat > 0 {
614                                 continue
615                         }
616                         conn.SetInterested(true)
617                         if !conn.Request(request) {
618                                 return
619                         }
620                 }
621         }
622         //conn.SetInterested(false)
623
624 }
625
626 func (me *client) pieceHashed(ih InfoHash, piece int, correct bool) {
627         torrent := me.torrents[ih]
628         newState := func() pieceState {
629                 if correct {
630                         return pieceStateComplete
631                 } else {
632                         return pieceStateIncomplete
633                 }
634         }()
635         oldState := torrent.Pieces[piece].State
636         if newState == oldState {
637                 return
638         }
639         torrent.Pieces[piece].State = newState
640         if newState == pieceStateIncomplete {
641                 torrent.Pieces[piece].PendingChunkSpecs = torrent.pieceChunkSpecs(piece)
642         }
643         for _, conn := range torrent.Conns {
644                 if correct {
645                         conn.Post(peer_protocol.Message{
646                                 Type:  peer_protocol.Have,
647                                 Index: peer_protocol.Integer(piece),
648                         })
649                 } else {
650                         if conn.PeerHasPiece(piece) {
651                                 me.replenishConnRequests(torrent, conn)
652                         }
653                 }
654         }
655
656 }
657
658 func (me *client) run() {
659         for {
660                 noTorrents := me.noTorrents
661                 if len(me.torrents) != 0 {
662                         noTorrents = nil
663                 }
664                 select {
665                 case noTorrents <- struct{}{}:
666                 case torrent := <-me.addTorrent:
667                         if _, ok := me.torrents[torrent.InfoHash]; ok {
668                                 break
669                         }
670                         me.torrents[torrent.InfoHash] = torrent
671                         go func() {
672                                 for _piece := range torrent.Pieces {
673                                         piece := _piece
674                                         sum := torrent.HashPiece(piece)
675                                         me.withContext(func() {
676                                                 me.pieceHashed(torrent.InfoHash, piece, sum == torrent.Pieces[piece].Hash)
677                                         })
678                                 }
679                         }()
680                 case infoHash := <-me.torrentFinished:
681                         delete(me.torrents, infoHash)
682                 case task := <-me.actorTask:
683                         task()
684                 }
685         }
686 }