From 9b1a769bef25fffc027550c854574709015e801a Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Sat, 3 Feb 2018 14:06:10 +1100 Subject: [PATCH] Move extended message handling into its own method --- connection.go | 215 ++++++++++++++++++++++++++------------------------ 1 file changed, 111 insertions(+), 104 deletions(-) diff --git a/connection.go b/connection.go index 3ad80b5b..9e025497 100644 --- a/connection.go +++ b/connection.go @@ -944,110 +944,7 @@ func (c *connection) mainReadLoop() error { t.chunkPool.Put(&msg.Piece) } case pp.Extended: - switch msg.ExtendedID { - case pp.HandshakeExtendedID: - // TODO: Create a bencode struct for this. - var d map[string]interface{} - err = bencode.Unmarshal(msg.ExtendedPayload, &d) - if err != nil { - err = fmt.Errorf("error decoding extended message payload: %s", err) - break - } - // log.Printf("got handshake from %q: %#v", c.Socket.RemoteAddr().String(), d) - if reqq, ok := d["reqq"]; ok { - if i, ok := reqq.(int64); ok { - c.PeerMaxRequests = int(i) - } - } - if v, ok := d["v"]; ok { - c.PeerClientName = v.(string) - } - if m, ok := d["m"]; ok { - mTyped, ok := m.(map[string]interface{}) - if !ok { - err = errors.New("handshake m value is not dict") - break - } - if c.PeerExtensionIDs == nil { - c.PeerExtensionIDs = make(map[string]byte, len(mTyped)) - } - for name, v := range mTyped { - id, ok := v.(int64) - if !ok { - log.Printf("bad handshake m item extension ID type: %T", v) - continue - } - if id == 0 { - delete(c.PeerExtensionIDs, name) - } else { - if c.PeerExtensionIDs[name] == 0 { - supportedExtensionMessages.Add(name, 1) - } - c.PeerExtensionIDs[name] = byte(id) - } - } - } - metadata_sizeUntyped, ok := d["metadata_size"] - if ok { - metadata_size, ok := metadata_sizeUntyped.(int64) - if !ok { - log.Printf("bad metadata_size type: %T", metadata_sizeUntyped) - } else { - err = t.setMetadataSize(metadata_size) - if err != nil { - err = fmt.Errorf("error setting metadata size to %d", metadata_size) - break - } - } - } - if _, ok := c.PeerExtensionIDs["ut_metadata"]; ok { - c.requestPendingMetadata() - } - case metadataExtendedId: - err = cl.gotMetadataExtensionMsg(msg.ExtendedPayload, t, c) - if err != nil { - err = fmt.Errorf("error handling metadata extension message: %s", err) - } - case pexExtendedId: - if cl.config.DisablePEX { - break - } - var pexMsg peerExchangeMessage - err = bencode.Unmarshal(msg.ExtendedPayload, &pexMsg) - if err != nil { - err = fmt.Errorf("error unmarshalling PEX message: %s", err) - break - } - go func() { - cl.mu.Lock() - t.addPeers(func() (ret []Peer) { - for i, cp := range pexMsg.Added { - p := Peer{ - IP: make([]byte, 4), - Port: cp.Port, - Source: peerSourcePEX, - } - if i < len(pexMsg.AddedFlags) && pexMsg.AddedFlags[i]&0x01 != 0 { - p.SupportsEncryption = true - } - missinggo.CopyExact(p.IP, cp.IP[:]) - ret = append(ret, p) - } - return - }()) - cl.mu.Unlock() - }() - default: - err = fmt.Errorf("unexpected extended message ID: %v", msg.ExtendedID) - } - if err != nil { - // That client uses its own extension IDs for outgoing message - // types, which is incorrect. - if bytes.HasPrefix(c.PeerID[:], []byte("-SD0100-")) || - strings.HasPrefix(string(c.PeerID[:]), "-XL0012-") { - return nil - } - } + err = c.onReadExtendedMsg(msg.ExtendedID, msg.ExtendedPayload) case pp.Port: if cl.dHT == nil { break @@ -1069,6 +966,116 @@ func (c *connection) mainReadLoop() error { } } +func (c *connection) onReadExtendedMsg(id byte, payload []byte) (err error) { + defer func() { + // TODO: Should we still do this? + if err != nil { + // These clients use their own extension IDs for outgoing message + // types, which is incorrect. + if bytes.HasPrefix(c.PeerID[:], []byte("-SD0100-")) || strings.HasPrefix(string(c.PeerID[:]), "-XL0012-") { + err = nil + } + } + }() + t := c.t + cl := t.cl + switch id { + case pp.HandshakeExtendedID: + // TODO: Create a bencode struct for this. + var d map[string]interface{} + err := bencode.Unmarshal(payload, &d) + if err != nil { + return fmt.Errorf("error decoding extended message payload: %s", err) + } + // log.Printf("got handshake from %q: %#v", c.Socket.RemoteAddr().String(), d) + if reqq, ok := d["reqq"]; ok { + if i, ok := reqq.(int64); ok { + c.PeerMaxRequests = int(i) + } + } + if v, ok := d["v"]; ok { + c.PeerClientName = v.(string) + } + if m, ok := d["m"]; ok { + mTyped, ok := m.(map[string]interface{}) + if !ok { + return errors.New("handshake m value is not dict") + } + if c.PeerExtensionIDs == nil { + c.PeerExtensionIDs = make(map[string]byte, len(mTyped)) + } + for name, v := range mTyped { + id, ok := v.(int64) + if !ok { + log.Printf("bad handshake m item extension ID type: %T", v) + continue + } + if id == 0 { + delete(c.PeerExtensionIDs, name) + } else { + if c.PeerExtensionIDs[name] == 0 { + supportedExtensionMessages.Add(name, 1) + } + c.PeerExtensionIDs[name] = byte(id) + } + } + } + metadata_sizeUntyped, ok := d["metadata_size"] + if ok { + metadata_size, ok := metadata_sizeUntyped.(int64) + if !ok { + log.Printf("bad metadata_size type: %T", metadata_sizeUntyped) + } else { + err = t.setMetadataSize(metadata_size) + if err != nil { + return fmt.Errorf("error setting metadata size to %d", metadata_size) + } + } + } + if _, ok := c.PeerExtensionIDs["ut_metadata"]; ok { + c.requestPendingMetadata() + } + return nil + case metadataExtendedId: + err := cl.gotMetadataExtensionMsg(payload, t, c) + if err != nil { + return fmt.Errorf("error handling metadata extension message: %s", err) + } + return nil + case pexExtendedId: + if cl.config.DisablePEX { + return nil + } + var pexMsg peerExchangeMessage + err := bencode.Unmarshal(payload, &pexMsg) + if err != nil { + return fmt.Errorf("error unmarshalling PEX message: %s", err) + } + go func() { + cl.mu.Lock() + t.addPeers(func() (ret []Peer) { + for i, cp := range pexMsg.Added { + p := Peer{ + IP: make([]byte, 4), + Port: cp.Port, + Source: peerSourcePEX, + } + if i < len(pexMsg.AddedFlags) && pexMsg.AddedFlags[i]&0x01 != 0 { + p.SupportsEncryption = true + } + missinggo.CopyExact(p.IP, cp.IP[:]) + ret = append(ret, p) + } + return + }()) + cl.mu.Unlock() + }() + return nil + default: + return fmt.Errorf("unexpected extended message ID: %v", id) + } +} + // Set both the Reader and Writer for the connection from a single ReadWriter. func (cn *connection) setRW(rw io.ReadWriter) { cn.r = rw -- 2.50.0