]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/handshake.go
acdc3da58faa2af1cb3e1bce906488751edaf148
[btrtrc.git] / peer_protocol / handshake.go
1 package peer_protocol
2
3 import (
4         "encoding/hex"
5         "errors"
6         "fmt"
7         "io"
8         "strconv"
9
10         "github.com/anacrolix/torrent/metainfo"
11 )
12
13 type ExtensionBit uint
14
15 const (
16         ExtensionBitDHT      = 0  // http://www.bittorrent.org/beps/bep_0005.html
17         ExtensionBitExtended = 20 // http://www.bittorrent.org/beps/bep_0010.html
18         ExtensionBitFast     = 2  // http://www.bittorrent.org/beps/bep_0006.html
19 )
20
21 func handshakeWriter(w io.Writer, bb <-chan []byte, done chan<- error) {
22         var err error
23         for b := range bb {
24                 _, err = w.Write(b)
25                 if err != nil {
26                         break
27                 }
28         }
29         done <- err
30 }
31
32 type (
33         PeerExtensionBits [8]byte
34 )
35
36 func (pex PeerExtensionBits) String() string {
37         return hex.EncodeToString(pex[:])
38 }
39
40 func NewPeerExtensionBytes(bits ...ExtensionBit) (ret PeerExtensionBits) {
41         for _, b := range bits {
42                 ret.SetBit(b, true)
43         }
44         return
45 }
46
47 func (pex PeerExtensionBits) SupportsExtended() bool {
48         return pex.GetBit(ExtensionBitExtended)
49 }
50
51 func (pex PeerExtensionBits) SupportsDHT() bool {
52         return pex.GetBit(ExtensionBitDHT)
53 }
54
55 func (pex PeerExtensionBits) SupportsFast() bool {
56         return pex.GetBit(ExtensionBitFast)
57 }
58
59 func (pex *PeerExtensionBits) SetBit(bit ExtensionBit, on bool) {
60         if on {
61                 pex[7-bit/8] |= 1 << (bit % 8)
62         } else {
63                 pex[7-bit/8] &^= 1 << (bit % 8)
64         }
65 }
66
67 func (pex PeerExtensionBits) GetBit(bit ExtensionBit) bool {
68         return pex[7-bit/8]&(1<<(bit%8)) != 0
69 }
70
71 type HandshakeResult struct {
72         PeerExtensionBits
73         PeerID [20]byte
74         metainfo.Hash
75 }
76
77 // ih is nil if we expect the peer to declare the InfoHash, such as when the peer initiated the
78 // connection. Returns ok if the Handshake was successful, and err if there was an unexpected
79 // condition other than the peer simply abandoning the Handshake.
80 func Handshake(
81         sock io.ReadWriter, ih *metainfo.Hash, peerID [20]byte, extensions PeerExtensionBits,
82 ) (
83         res HandshakeResult, err error,
84 ) {
85         // Bytes to be sent to the peer. Should never block the sender.
86         postCh := make(chan []byte, 4)
87         // A single error value sent when the writer completes.
88         writeDone := make(chan error, 1)
89         // Performs writes to the socket and ensures posts don't block.
90         go handshakeWriter(sock, postCh, writeDone)
91
92         defer func() {
93                 close(postCh) // Done writing.
94                 if err != nil {
95                         return
96                 }
97                 // Wait until writes complete before returning from handshake.
98                 err = <-writeDone
99                 if err != nil {
100                         err = fmt.Errorf("error writing: %w", err)
101                 }
102         }()
103
104         post := func(bb []byte) {
105                 select {
106                 case postCh <- bb:
107                 default:
108                         panic("mustn't block while posting")
109                 }
110         }
111
112         post([]byte(Protocol))
113         post(extensions[:])
114         if ih != nil { // We already know what we want.
115                 post(ih[:])
116                 post(peerID[:])
117         }
118         var b [68]byte
119         _, err = io.ReadFull(sock, b[:68])
120         if err != nil {
121                 return res, fmt.Errorf("while reading: %w", err)
122         }
123         if string(b[:20]) != Protocol {
124                 return res, errors.New("unexpected protocol string")
125         }
126
127         copyExact := func(dst, src []byte) {
128                 if dstLen, srcLen := uint64(len(dst)), uint64(len(src)); dstLen != srcLen {
129                         panic("dst len " + strconv.FormatUint(dstLen, 10) + " != src len " + strconv.FormatUint(srcLen, 10))
130                 }
131                 copy(dst, src)
132         }
133         copyExact(res.PeerExtensionBits[:], b[20:28])
134         copyExact(res.Hash[:], b[28:48])
135         copyExact(res.PeerID[:], b[48:68])
136         // peerExtensions.Add(res.PeerExtensionBits.String(), 1)
137
138         // TODO: Maybe we can just drop peers here if we're not interested. This
139         // could prevent them trying to reconnect, falsely believing there was
140         // just a problem.
141         if ih == nil { // We were waiting for the peer to tell us what they wanted.
142                 post(res.Hash[:])
143                 post(peerID[:])
144         }
145
146         return
147 }