From: Matt Joiner <anacrolix@gmail.com>
Date: Sat, 3 Feb 2018 03:06:10 +0000 (+1100)
Subject: Move extended message handling into its own method
X-Git-Tag: v1.0.0~214
X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=9b1a769bef25fffc027550c854574709015e801a;p=btrtrc.git

Move extended message handling into its own method
---

diff --git a/connection.go b/connection.go
index 3ad80b5b..9e025497 100644
--- a/connection.go
+++ b/connection.go
@@ -944,110 +944,7 @@ func (c *connection) mainReadLoop() error {
 				t.chunkPool.Put(&msg.Piece)
 			}
 		case pp.Extended:
-			switch msg.ExtendedID {
-			case pp.HandshakeExtendedID:
-				// TODO: Create a bencode struct for this.
-				var d map[string]interface{}
-				err = bencode.Unmarshal(msg.ExtendedPayload, &d)
-				if err != nil {
-					err = fmt.Errorf("error decoding extended message payload: %s", err)
-					break
-				}
-				// log.Printf("got handshake from %q: %#v", c.Socket.RemoteAddr().String(), d)
-				if reqq, ok := d["reqq"]; ok {
-					if i, ok := reqq.(int64); ok {
-						c.PeerMaxRequests = int(i)
-					}
-				}
-				if v, ok := d["v"]; ok {
-					c.PeerClientName = v.(string)
-				}
-				if m, ok := d["m"]; ok {
-					mTyped, ok := m.(map[string]interface{})
-					if !ok {
-						err = errors.New("handshake m value is not dict")
-						break
-					}
-					if c.PeerExtensionIDs == nil {
-						c.PeerExtensionIDs = make(map[string]byte, len(mTyped))
-					}
-					for name, v := range mTyped {
-						id, ok := v.(int64)
-						if !ok {
-							log.Printf("bad handshake m item extension ID type: %T", v)
-							continue
-						}
-						if id == 0 {
-							delete(c.PeerExtensionIDs, name)
-						} else {
-							if c.PeerExtensionIDs[name] == 0 {
-								supportedExtensionMessages.Add(name, 1)
-							}
-							c.PeerExtensionIDs[name] = byte(id)
-						}
-					}
-				}
-				metadata_sizeUntyped, ok := d["metadata_size"]
-				if ok {
-					metadata_size, ok := metadata_sizeUntyped.(int64)
-					if !ok {
-						log.Printf("bad metadata_size type: %T", metadata_sizeUntyped)
-					} else {
-						err = t.setMetadataSize(metadata_size)
-						if err != nil {
-							err = fmt.Errorf("error setting metadata size to %d", metadata_size)
-							break
-						}
-					}
-				}
-				if _, ok := c.PeerExtensionIDs["ut_metadata"]; ok {
-					c.requestPendingMetadata()
-				}
-			case metadataExtendedId:
-				err = cl.gotMetadataExtensionMsg(msg.ExtendedPayload, t, c)
-				if err != nil {
-					err = fmt.Errorf("error handling metadata extension message: %s", err)
-				}
-			case pexExtendedId:
-				if cl.config.DisablePEX {
-					break
-				}
-				var pexMsg peerExchangeMessage
-				err = bencode.Unmarshal(msg.ExtendedPayload, &pexMsg)
-				if err != nil {
-					err = fmt.Errorf("error unmarshalling PEX message: %s", err)
-					break
-				}
-				go func() {
-					cl.mu.Lock()
-					t.addPeers(func() (ret []Peer) {
-						for i, cp := range pexMsg.Added {
-							p := Peer{
-								IP:     make([]byte, 4),
-								Port:   cp.Port,
-								Source: peerSourcePEX,
-							}
-							if i < len(pexMsg.AddedFlags) && pexMsg.AddedFlags[i]&0x01 != 0 {
-								p.SupportsEncryption = true
-							}
-							missinggo.CopyExact(p.IP, cp.IP[:])
-							ret = append(ret, p)
-						}
-						return
-					}())
-					cl.mu.Unlock()
-				}()
-			default:
-				err = fmt.Errorf("unexpected extended message ID: %v", msg.ExtendedID)
-			}
-			if err != nil {
-				// That client uses its own extension IDs for outgoing message
-				// types, which is incorrect.
-				if bytes.HasPrefix(c.PeerID[:], []byte("-SD0100-")) ||
-					strings.HasPrefix(string(c.PeerID[:]), "-XL0012-") {
-					return nil
-				}
-			}
+			err = c.onReadExtendedMsg(msg.ExtendedID, msg.ExtendedPayload)
 		case pp.Port:
 			if cl.dHT == nil {
 				break
@@ -1069,6 +966,116 @@ func (c *connection) mainReadLoop() error {
 	}
 }
 
+func (c *connection) onReadExtendedMsg(id byte, payload []byte) (err error) {
+	defer func() {
+		// TODO: Should we still do this?
+		if err != nil {
+			// These clients use their own extension IDs for outgoing message
+			// types, which is incorrect.
+			if bytes.HasPrefix(c.PeerID[:], []byte("-SD0100-")) || strings.HasPrefix(string(c.PeerID[:]), "-XL0012-") {
+				err = nil
+			}
+		}
+	}()
+	t := c.t
+	cl := t.cl
+	switch id {
+	case pp.HandshakeExtendedID:
+		// TODO: Create a bencode struct for this.
+		var d map[string]interface{}
+		err := bencode.Unmarshal(payload, &d)
+		if err != nil {
+			return fmt.Errorf("error decoding extended message payload: %s", err)
+		}
+		// log.Printf("got handshake from %q: %#v", c.Socket.RemoteAddr().String(), d)
+		if reqq, ok := d["reqq"]; ok {
+			if i, ok := reqq.(int64); ok {
+				c.PeerMaxRequests = int(i)
+			}
+		}
+		if v, ok := d["v"]; ok {
+			c.PeerClientName = v.(string)
+		}
+		if m, ok := d["m"]; ok {
+			mTyped, ok := m.(map[string]interface{})
+			if !ok {
+				return errors.New("handshake m value is not dict")
+			}
+			if c.PeerExtensionIDs == nil {
+				c.PeerExtensionIDs = make(map[string]byte, len(mTyped))
+			}
+			for name, v := range mTyped {
+				id, ok := v.(int64)
+				if !ok {
+					log.Printf("bad handshake m item extension ID type: %T", v)
+					continue
+				}
+				if id == 0 {
+					delete(c.PeerExtensionIDs, name)
+				} else {
+					if c.PeerExtensionIDs[name] == 0 {
+						supportedExtensionMessages.Add(name, 1)
+					}
+					c.PeerExtensionIDs[name] = byte(id)
+				}
+			}
+		}
+		metadata_sizeUntyped, ok := d["metadata_size"]
+		if ok {
+			metadata_size, ok := metadata_sizeUntyped.(int64)
+			if !ok {
+				log.Printf("bad metadata_size type: %T", metadata_sizeUntyped)
+			} else {
+				err = t.setMetadataSize(metadata_size)
+				if err != nil {
+					return fmt.Errorf("error setting metadata size to %d", metadata_size)
+				}
+			}
+		}
+		if _, ok := c.PeerExtensionIDs["ut_metadata"]; ok {
+			c.requestPendingMetadata()
+		}
+		return nil
+	case metadataExtendedId:
+		err := cl.gotMetadataExtensionMsg(payload, t, c)
+		if err != nil {
+			return fmt.Errorf("error handling metadata extension message: %s", err)
+		}
+		return nil
+	case pexExtendedId:
+		if cl.config.DisablePEX {
+			return nil
+		}
+		var pexMsg peerExchangeMessage
+		err := bencode.Unmarshal(payload, &pexMsg)
+		if err != nil {
+			return fmt.Errorf("error unmarshalling PEX message: %s", err)
+		}
+		go func() {
+			cl.mu.Lock()
+			t.addPeers(func() (ret []Peer) {
+				for i, cp := range pexMsg.Added {
+					p := Peer{
+						IP:     make([]byte, 4),
+						Port:   cp.Port,
+						Source: peerSourcePEX,
+					}
+					if i < len(pexMsg.AddedFlags) && pexMsg.AddedFlags[i]&0x01 != 0 {
+						p.SupportsEncryption = true
+					}
+					missinggo.CopyExact(p.IP, cp.IP[:])
+					ret = append(ret, p)
+				}
+				return
+			}())
+			cl.mu.Unlock()
+		}()
+		return nil
+	default:
+		return fmt.Errorf("unexpected extended message ID: %v", id)
+	}
+}
+
 // Set both the Reader and Writer for the connection from a single ReadWriter.
 func (cn *connection) setRW(rw io.ReadWriter) {
 	cn.r = rw