Sergey Matveev's repositories - btrtrc.git/blob - handshake.go
Make peerID a public type
[btrtrc.git] / handshake.go
1 package torrent
2
3 import (
4         "bytes"
5         "encoding/hex"
6         "fmt"
7         "io"
8         "net"
9         "time"
10
11         "github.com/anacrolix/missinggo"
12
13         "github.com/anacrolix/torrent/metainfo"
14         "github.com/anacrolix/torrent/mse"
15         pp "github.com/anacrolix/torrent/peer_protocol"
16 )
17
// ExtensionBit identifies a single flag inside the 8 reserved bytes of the
// handshake (see peerExtensionBytes). Bit 0 is the least significant bit of
// the last reserved byte.
type ExtensionBit uint

// Known reserved-bit positions, in ascending bit order. The constants are
// left untyped so they can be used both as ExtensionBit and as plain
// integers.
const (
	// http://www.bittorrent.org/beps/bep_0005.html
	ExtensionBitDHT = 0
	// http://www.bittorrent.org/beps/bep_0006.html
	ExtensionBitFast = 2
	// http://www.bittorrent.org/beps/bep_0010.html
	ExtensionBitExtended = 20
)
25
// handshakeWriter drains bb, writing every received chunk to w in order.
// It stops at the first failed write, or when bb is closed, and then sends
// exactly one value on done: the write error, or nil if every write
// succeeded. Chunks still queued after a failed write are not consumed.
func handshakeWriter(w io.Writer, bb <-chan []byte, done chan<- error) {
	for chunk := range bb {
		if _, werr := w.Write(chunk); werr != nil {
			done <- werr
			return
		}
	}
	done <- nil
}
36
37 type (
38         peerExtensionBytes [8]byte
39 )
40
41 func (pex peerExtensionBytes) SupportsExtended() bool {
42         return pex.GetBit(ExtensionBitExtended)
43 }
44
45 func (pex peerExtensionBytes) SupportsDHT() bool {
46         return pex.GetBit(ExtensionBitDHT)
47 }
48
49 func (pex peerExtensionBytes) SupportsFast() bool {
50         return pex.GetBit(ExtensionBitFast)
51 }
52
53 func (pex *peerExtensionBytes) SetBit(bit ExtensionBit) {
54         pex[7-bit/8] |= 1 << bit % 8
55 }
56
57 func (pex peerExtensionBytes) GetBit(bit ExtensionBit) bool {
58         return pex[7-bit/8]&(1<<(bit%8)) != 0
59 }
60
// handshakeResult carries the three fields that handshake copies out of the
// peer's 68-byte handshake message.
type handshakeResult struct {
	// Reserved bytes sent by the peer (message bytes 20-27).
	peerExtensionBytes
	// Peer ID sent by the peer (message bytes 48-67).
	PeerID
	// Info hash sent by the peer (message bytes 28-47).
	metainfo.Hash
}
66
67 // ih is nil if we expect the peer to declare the InfoHash, such as when the
68 // peer initiated the connection. Returns ok if the handshake was successful,
69 // and err if there was an unexpected condition other than the peer simply
70 // abandoning the handshake.
71 func handshake(sock io.ReadWriter, ih *metainfo.Hash, peerID [20]byte, extensions peerExtensionBytes) (res handshakeResult, ok bool, err error) {
72         // Bytes to be sent to the peer. Should never block the sender.
73         postCh := make(chan []byte, 4)
74         // A single error value sent when the writer completes.
75         writeDone := make(chan error, 1)
76         // Performs writes to the socket and ensures posts don't block.
77         go handshakeWriter(sock, postCh, writeDone)
78
79         defer func() {
80                 close(postCh) // Done writing.
81                 if !ok {
82                         return
83                 }
84                 if err != nil {
85                         panic(err)
86                 }
87                 // Wait until writes complete before returning from handshake.
88                 err = <-writeDone
89                 if err != nil {
90                         err = fmt.Errorf("error writing: %s", err)
91                 }
92         }()
93
94         post := func(bb []byte) {
95                 select {
96                 case postCh <- bb:
97                 default:
98                         panic("mustn't block while posting")
99                 }
100         }
101
102         post([]byte(pp.Protocol))
103         post(extensions[:])
104         if ih != nil { // We already know what we want.
105                 post(ih[:])
106                 post(peerID[:])
107         }
108         var b [68]byte
109         _, err = io.ReadFull(sock, b[:68])
110         if err != nil {
111                 err = nil
112                 return
113         }
114         if string(b[:20]) != pp.Protocol {
115                 return
116         }
117         missinggo.CopyExact(&res.peerExtensionBytes, b[20:28])
118         missinggo.CopyExact(&res.Hash, b[28:48])
119         missinggo.CopyExact(&res.PeerID, b[48:68])
120         peerExtensions.Add(hex.EncodeToString(res.peerExtensionBytes[:]), 1)
121
122         // TODO: Maybe we can just drop peers here if we're not interested. This
123         // could prevent them trying to reconnect, falsely believing there was
124         // just a problem.
125         if ih == nil { // We were waiting for the peer to tell us what they wanted.
126                 post(res.Hash[:])
127                 post(peerID[:])
128         }
129
130         ok = true
131         return
132 }
133
// deadlineReader pairs a raw connection with the reader that data is actually
// pulled from, providing the interface we want for using the connection in
// the message loop: every Read first pushes the connection's read deadline
// forward.
type deadlineReader struct {
	nc net.Conn  // connection whose read deadline is refreshed
	r  io.Reader // source the bytes are read from
}

// Read refreshes the read deadline on the underlying connection and then
// delegates to the wrapped reader. If the deadline cannot be set, nothing is
// read and the failure is returned.
func (r deadlineReader) Read(b []byte) (int, error) {
	// Keep-alives should be received every 2 mins. Give a bit of gracetime.
	const readTimeout = 150 * time.Second

	deadline := time.Now().Add(readTimeout)
	if err := r.nc.SetReadDeadline(deadline); err != nil {
		return 0, fmt.Errorf("error setting read deadline: %s", err)
	}
	return r.r.Read(b)
}
149
150 func handleEncryption(
151         rw io.ReadWriter,
152         skeys mse.SecretKeyIter,
153         policy EncryptionPolicy,
154 ) (
155         ret io.ReadWriter,
156         headerEncrypted bool,
157         cryptoMethod uint32,
158         err error,
159 ) {
160         if !policy.ForceEncryption {
161                 var protocol [len(pp.Protocol)]byte
162                 _, err = io.ReadFull(rw, protocol[:])
163                 if err != nil {
164                         return
165                 }
166                 rw = struct {
167                         io.Reader
168                         io.Writer
169                 }{
170                         io.MultiReader(bytes.NewReader(protocol[:]), rw),
171                         rw,
172                 }
173                 if string(protocol[:]) == pp.Protocol {
174                         ret = rw
175                         return
176                 }
177         }
178         headerEncrypted = true
179         ret, err = mse.ReceiveHandshake(rw, skeys, func(provides uint32) uint32 {
180                 cryptoMethod = func() uint32 {
181                         switch {
182                         case policy.ForceEncryption:
183                                 return mse.CryptoMethodRC4
184                         case policy.DisableEncryption:
185                                 return mse.CryptoMethodPlaintext
186                         case policy.PreferNoEncryption && provides&mse.CryptoMethodPlaintext != 0:
187                                 return mse.CryptoMethodPlaintext
188                         default:
189                                 return mse.DefaultCryptoSelector(provides)
190                         }
191                 }()
192                 return cryptoMethod
193         })
194         return
195 }