// send queries, and respond to the ones from the network.
// Each node has a globally unique identifier known as the "node ID."
// Node IDs are chosen at random from the same 160-bit space
-// as BitTorrent infohashes [2] and define the behaviour of the node.
+// as BitTorrent infohashes and define the behaviour of the node.
// Zero valued Server does not have a valid ID and thus
// is unable to function properly. Use `NewServer(nil)`
// to initialize a default node.
}
return true
}
-
-// Transaction keeps track of a message exchange between nodes,
-// such as a query message and a response message
-type Transaction struct {
- mu sync.Mutex
- remoteAddr dHTAddr
- t string
- response chan Msg
- onResponse func(Msg) // Called with the server locked.
- done chan struct{}
- queryPacket []byte
- timer *time.Timer
- s *Server
- retries int
- lastSend time.Time
- userOnResponse func(Msg)
-}
-
-// Set a function to be called with the response.
-func (t *Transaction) SetResponseHandler(f func(Msg)) {
- t.mu.Lock()
- defer t.mu.Unlock()
- t.userOnResponse = f
- t.tryHandleResponse()
-}
-
-func (t *Transaction) tryHandleResponse() {
- if t.userOnResponse == nil {
- return
- }
- select {
- case r := <-t.response:
- t.userOnResponse(r)
- // Shouldn't be called more than once.
- t.userOnResponse = nil
- default:
- }
-}
-
-func (t *Transaction) key() transactionKey {
- return transactionKey{
- t.remoteAddr.String(),
- t.t,
- }
-}
-
func jitterDuration(average time.Duration, plusMinus time.Duration) time.Duration {
return average - plusMinus/2 + time.Duration(rand.Int63n(int64(plusMinus)))
}
-func (t *Transaction) startTimer() {
- t.timer = time.AfterFunc(jitterDuration(queryResendEvery, time.Second), t.timerCallback)
-}
-
-func (t *Transaction) timerCallback() {
- t.mu.Lock()
- defer t.mu.Unlock()
- select {
- case <-t.done:
- return
- default:
- }
- if t.retries == 2 {
- t.timeout()
- return
- }
- t.retries++
- t.sendQuery()
- if t.timer.Reset(jitterDuration(queryResendEvery, time.Second)) {
- panic("timer should have fired to get here")
- }
-}
-
-func (t *Transaction) sendQuery() error {
- err := t.s.writeToNode(t.queryPacket, t.remoteAddr)
- if err != nil {
- return err
- }
- t.lastSend = time.Now()
- return nil
-}
-
-func (t *Transaction) timeout() {
- go func() {
- t.s.mu.Lock()
- defer t.s.mu.Unlock()
- t.s.nodeTimedOut(t.remoteAddr)
- }()
- t.close()
-}
-
-func (t *Transaction) close() {
- if t.closing() {
- return
- }
- t.queryPacket = nil
- close(t.response)
- t.tryHandleResponse()
- close(t.done)
- t.timer.Stop()
- go func() {
- t.s.mu.Lock()
- defer t.s.mu.Unlock()
- t.s.deleteTransaction(t)
- }()
-}
-
-func (t *Transaction) closing() bool {
- select {
- case <-t.done:
- return true
- default:
- return false
- }
-}
-
-// Abandon the transaction.
-func (t *Transaction) Close() {
- t.mu.Lock()
- defer t.mu.Unlock()
- t.close()
-}
-
-func (t *Transaction) handleResponse(m Msg) {
- t.mu.Lock()
- if t.closing() {
- t.mu.Unlock()
- return
- }
- close(t.done)
- t.mu.Unlock()
- if t.onResponse != nil {
- t.s.mu.Lock()
- t.onResponse(m)
- t.s.mu.Unlock()
- }
- t.queryPacket = nil
- select {
- case t.response <- m:
- default:
- panic("blocked handling response")
- }
- close(t.response)
- t.tryHandleResponse()
-}
-
func maskForIP(ip net.IP) []byte {
switch {
case ip.To4() != nil:
return
}
-// The size in bytes of a NodeInfo in its compact binary representation.
-const CompactIPv4NodeInfoLen = 26
-
-type NodeInfo struct {
- ID [20]byte
- Addr dHTAddr
-}
-
-// Writes the node info to its compact binary representation in b. See
-// CompactNodeInfoLen.
-func (ni *NodeInfo) PutCompact(b []byte) error {
- if n := copy(b[:], ni.ID[:]); n != 20 {
- panic(n)
- }
- ip := missinggo.AddrIP(ni.Addr).To4()
- if len(ip) != 4 {
- return errors.New("expected ipv4 address")
- }
- if n := copy(b[20:], ip); n != 4 {
- panic(n)
- }
- binary.BigEndian.PutUint16(b[24:], uint16(missinggo.AddrPort(ni.Addr)))
- return nil
-}
-
-func (cni *NodeInfo) UnmarshalCompactIPv4(b []byte) error {
- if len(b) != 26 {
- return errors.New("expected 26 bytes")
- }
- missinggo.CopyExact(cni.ID[:], b[:20])
- cni.Addr = newDHTAddr(&net.UDPAddr{
- IP: net.IPv4(b[20], b[21], b[22], b[23]),
- Port: int(binary.BigEndian.Uint16(b[24:26])),
- })
- return nil
-}
-
// Sends a ping query to the address given.
func (s *Server) Ping(node *net.UDPAddr) (*Transaction, error) {
s.mu.Lock()
--- /dev/null
+package dht
+
+import (
+ "encoding/binary"
+ "errors"
+ "net"
+
+ "github.com/anacrolix/missinggo"
+)
+
+// The size in bytes of a NodeInfo in its compact binary representation.
+const CompactIPv4NodeInfoLen = 26
+
+type NodeInfo struct {
+ ID [20]byte
+ Addr dHTAddr
+}
+
+// Writes the node info to its compact binary representation in b. See
+// CompactNodeInfoLen.
+func (ni *NodeInfo) PutCompact(b []byte) error {
+ if n := copy(b[:], ni.ID[:]); n != 20 {
+ panic(n)
+ }
+ ip := missinggo.AddrIP(ni.Addr).To4()
+ if len(ip) != 4 {
+ return errors.New("expected ipv4 address")
+ }
+ if n := copy(b[20:], ip); n != 4 {
+ panic(n)
+ }
+ binary.BigEndian.PutUint16(b[24:], uint16(missinggo.AddrPort(ni.Addr)))
+ return nil
+}
+
+func (cni *NodeInfo) UnmarshalCompactIPv4(b []byte) error {
+ if len(b) != CompactIPv4NodeInfoLen {
+ return errors.New("expected 26 bytes")
+ }
+ missinggo.CopyExact(cni.ID[:], b[:20])
+ cni.Addr = newDHTAddr(&net.UDPAddr{
+ IP: net.IPv4(b[20], b[21], b[22], b[23]),
+ Port: int(binary.BigEndian.Uint16(b[24:26])),
+ })
+ return nil
+}
--- /dev/null
+package dht
+
+import (
+ "sync"
+ "time"
+)
+
+// Transaction keeps track of a message exchange between nodes,
+// such as a query message and a response message
+type Transaction struct {
+ mu sync.Mutex
+ remoteAddr dHTAddr
+ t string
+ response chan Msg
+ onResponse func(Msg) // Called with the server locked.
+ done chan struct{}
+ queryPacket []byte
+ timer *time.Timer
+ s *Server
+ retries int
+ lastSend time.Time
+ userOnResponse func(Msg)
+}
+
+// SetResponseHandler sets up a function to be called when query response
+// arrives
+func (t *Transaction) SetResponseHandler(f func(Msg)) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.userOnResponse = f
+ t.tryHandleResponse()
+}
+
+func (t *Transaction) tryHandleResponse() {
+ if t.userOnResponse == nil {
+ return
+ }
+ select {
+ case r := <-t.response:
+ t.userOnResponse(r)
+ // Shouldn't be called more than once.
+ t.userOnResponse = nil
+ default:
+ }
+}
+
+func (t *Transaction) key() transactionKey {
+ return transactionKey{
+ t.remoteAddr.String(),
+ t.t,
+ }
+}
+
+func (t *Transaction) startTimer() {
+ t.timer = time.AfterFunc(jitterDuration(queryResendEvery, time.Second), t.timerCallback)
+}
+
+func (t *Transaction) timerCallback() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ select {
+ case <-t.done:
+ return
+ default:
+ }
+ if t.retries == 2 {
+ t.timeout()
+ return
+ }
+ t.retries++
+ t.sendQuery()
+ if t.timer.Reset(jitterDuration(queryResendEvery, time.Second)) {
+ panic("timer should have fired to get here")
+ }
+}
+
+func (t *Transaction) sendQuery() error {
+ err := t.s.writeToNode(t.queryPacket, t.remoteAddr)
+ if err != nil {
+ return err
+ }
+ t.lastSend = time.Now()
+ return nil
+}
+
+func (t *Transaction) timeout() {
+ go func() {
+ t.s.mu.Lock()
+ defer t.s.mu.Unlock()
+ t.s.nodeTimedOut(t.remoteAddr)
+ }()
+ t.close()
+}
+
+func (t *Transaction) close() {
+ if t.closing() {
+ return
+ }
+ t.queryPacket = nil
+ close(t.response)
+ t.tryHandleResponse()
+ close(t.done)
+ t.timer.Stop()
+ go func() {
+ t.s.mu.Lock()
+ defer t.s.mu.Unlock()
+ t.s.deleteTransaction(t)
+ }()
+}
+
+func (t *Transaction) closing() bool {
+ select {
+ case <-t.done:
+ return true
+ default:
+ return false
+ }
+}
+
+// Close (abandon) the transaction.
+func (t *Transaction) Close() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.close()
+}
+
+func (t *Transaction) handleResponse(m Msg) {
+ t.mu.Lock()
+ if t.closing() {
+ t.mu.Unlock()
+ return
+ }
+ close(t.done)
+ t.mu.Unlock()
+ if t.onResponse != nil {
+ t.s.mu.Lock()
+ t.onResponse(m)
+ t.s.mu.Unlock()
+ }
+ t.queryPacket = nil
+ select {
+ case t.response <- m:
+ default:
+ panic("blocked handling response")
+ }
+ close(t.response)
+ t.tryHandleResponse()
+}