]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/handshake.go
Misc debug status, pex conn tracking improvements
[btrtrc.git] / peer_protocol / handshake.go
1 package peer_protocol
2
3 import (
4         "encoding/hex"
5         "errors"
6         "fmt"
7         "io"
8         "math/bits"
9         "strconv"
10         "strings"
11         "unsafe"
12
13         "github.com/anacrolix/torrent/metainfo"
14 )
15
16 type ExtensionBit uint
17
18 const (
19         ExtensionBitDHT      = 0  // http://www.bittorrent.org/beps/bep_0005.html
20         ExtensionBitExtended = 20 // http://www.bittorrent.org/beps/bep_0010.html
21         ExtensionBitFast     = 2  // http://www.bittorrent.org/beps/bep_0006.html
22 )
23
24 func handshakeWriter(w io.Writer, bb <-chan []byte, done chan<- error) {
25         var err error
26         for b := range bb {
27                 _, err = w.Write(b)
28                 if err != nil {
29                         break
30                 }
31         }
32         done <- err
33 }
34
35 type (
36         PeerExtensionBits [8]byte
37 )
38
39 var bitTags = []struct {
40         bit ExtensionBit
41         tag string
42 }{
43         // Ordered by their base protocol type values (PORT, fast.., EXTENDED)
44         {ExtensionBitDHT, "dht"},
45         {ExtensionBitFast, "fast"},
46         {ExtensionBitExtended, "ext"},
47 }
48
49 func (pex PeerExtensionBits) String() string {
50         pexHex := hex.EncodeToString(pex[:])
51         tags := make([]string, 0, len(bitTags)+1)
52         for _, bitTag := range bitTags {
53                 if pex.GetBit(bitTag.bit) {
54                         tags = append(tags, bitTag.tag)
55                         pex.SetBit(bitTag.bit, false)
56                 }
57         }
58         unknownCount := bits.OnesCount64(*(*uint64)((unsafe.Pointer(unsafe.SliceData(pex[:])))))
59         if unknownCount != 0 {
60                 tags = append(tags, fmt.Sprintf("%v unknown", unknownCount))
61         }
62         return fmt.Sprintf("%v (%s)", pexHex, strings.Join(tags, ", "))
63
64 }
65
66 func NewPeerExtensionBytes(bits ...ExtensionBit) (ret PeerExtensionBits) {
67         for _, b := range bits {
68                 ret.SetBit(b, true)
69         }
70         return
71 }
72
73 func (pex PeerExtensionBits) SupportsExtended() bool {
74         return pex.GetBit(ExtensionBitExtended)
75 }
76
77 func (pex PeerExtensionBits) SupportsDHT() bool {
78         return pex.GetBit(ExtensionBitDHT)
79 }
80
81 func (pex PeerExtensionBits) SupportsFast() bool {
82         return pex.GetBit(ExtensionBitFast)
83 }
84
85 func (pex *PeerExtensionBits) SetBit(bit ExtensionBit, on bool) {
86         if on {
87                 pex[7-bit/8] |= 1 << (bit % 8)
88         } else {
89                 pex[7-bit/8] &^= 1 << (bit % 8)
90         }
91 }
92
93 func (pex PeerExtensionBits) GetBit(bit ExtensionBit) bool {
94         return pex[7-bit/8]&(1<<(bit%8)) != 0
95 }
96
97 type HandshakeResult struct {
98         PeerExtensionBits
99         PeerID [20]byte
100         metainfo.Hash
101 }
102
103 // ih is nil if we expect the peer to declare the InfoHash, such as when the peer initiated the
104 // connection. Returns ok if the Handshake was successful, and err if there was an unexpected
105 // condition other than the peer simply abandoning the Handshake.
106 func Handshake(
107         sock io.ReadWriter, ih *metainfo.Hash, peerID [20]byte, extensions PeerExtensionBits,
108 ) (
109         res HandshakeResult, err error,
110 ) {
111         // Bytes to be sent to the peer. Should never block the sender.
112         postCh := make(chan []byte, 4)
113         // A single error value sent when the writer completes.
114         writeDone := make(chan error, 1)
115         // Performs writes to the socket and ensures posts don't block.
116         go handshakeWriter(sock, postCh, writeDone)
117
118         defer func() {
119                 close(postCh) // Done writing.
120                 if err != nil {
121                         return
122                 }
123                 // Wait until writes complete before returning from handshake.
124                 err = <-writeDone
125                 if err != nil {
126                         err = fmt.Errorf("error writing: %w", err)
127                 }
128         }()
129
130         post := func(bb []byte) {
131                 select {
132                 case postCh <- bb:
133                 default:
134                         panic("mustn't block while posting")
135                 }
136         }
137
138         post([]byte(Protocol))
139         post(extensions[:])
140         if ih != nil { // We already know what we want.
141                 post(ih[:])
142                 post(peerID[:])
143         }
144         var b [68]byte
145         _, err = io.ReadFull(sock, b[:68])
146         if err != nil {
147                 return res, fmt.Errorf("while reading: %w", err)
148         }
149         if string(b[:20]) != Protocol {
150                 return res, errors.New("unexpected protocol string")
151         }
152
153         copyExact := func(dst, src []byte) {
154                 if dstLen, srcLen := uint64(len(dst)), uint64(len(src)); dstLen != srcLen {
155                         panic("dst len " + strconv.FormatUint(dstLen, 10) + " != src len " + strconv.FormatUint(srcLen, 10))
156                 }
157                 copy(dst, src)
158         }
159         copyExact(res.PeerExtensionBits[:], b[20:28])
160         copyExact(res.Hash[:], b[28:48])
161         copyExact(res.PeerID[:], b[48:68])
162         // peerExtensions.Add(res.PeerExtensionBits.String(), 1)
163
164         // TODO: Maybe we can just drop peers here if we're not interested. This
165         // could prevent them trying to reconnect, falsely believing there was
166         // just a problem.
167         if ih == nil { // We were waiting for the peer to tell us what they wanted.
168                 post(res.Hash[:])
169                 post(peerID[:])
170         }
171
172         return
173 }