]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Move extended message handling into its own method
authorMatt Joiner <anacrolix@gmail.com>
Sat, 3 Feb 2018 03:06:10 +0000 (14:06 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Sat, 3 Feb 2018 03:06:10 +0000 (14:06 +1100)
connection.go

index 3ad80b5b244f11527cd57d656ee9f78159d95e96..9e025497643cdb28e735753eef7d39f47a33a85a 100644 (file)
@@ -944,110 +944,7 @@ func (c *connection) mainReadLoop() error {
                                t.chunkPool.Put(&msg.Piece)
                        }
                case pp.Extended:
-                       switch msg.ExtendedID {
-                       case pp.HandshakeExtendedID:
-                               // TODO: Create a bencode struct for this.
-                               var d map[string]interface{}
-                               err = bencode.Unmarshal(msg.ExtendedPayload, &d)
-                               if err != nil {
-                                       err = fmt.Errorf("error decoding extended message payload: %s", err)
-                                       break
-                               }
-                               // log.Printf("got handshake from %q: %#v", c.Socket.RemoteAddr().String(), d)
-                               if reqq, ok := d["reqq"]; ok {
-                                       if i, ok := reqq.(int64); ok {
-                                               c.PeerMaxRequests = int(i)
-                                       }
-                               }
-                               if v, ok := d["v"]; ok {
-                                       c.PeerClientName = v.(string)
-                               }
-                               if m, ok := d["m"]; ok {
-                                       mTyped, ok := m.(map[string]interface{})
-                                       if !ok {
-                                               err = errors.New("handshake m value is not dict")
-                                               break
-                                       }
-                                       if c.PeerExtensionIDs == nil {
-                                               c.PeerExtensionIDs = make(map[string]byte, len(mTyped))
-                                       }
-                                       for name, v := range mTyped {
-                                               id, ok := v.(int64)
-                                               if !ok {
-                                                       log.Printf("bad handshake m item extension ID type: %T", v)
-                                                       continue
-                                               }
-                                               if id == 0 {
-                                                       delete(c.PeerExtensionIDs, name)
-                                               } else {
-                                                       if c.PeerExtensionIDs[name] == 0 {
-                                                               supportedExtensionMessages.Add(name, 1)
-                                                       }
-                                                       c.PeerExtensionIDs[name] = byte(id)
-                                               }
-                                       }
-                               }
-                               metadata_sizeUntyped, ok := d["metadata_size"]
-                               if ok {
-                                       metadata_size, ok := metadata_sizeUntyped.(int64)
-                                       if !ok {
-                                               log.Printf("bad metadata_size type: %T", metadata_sizeUntyped)
-                                       } else {
-                                               err = t.setMetadataSize(metadata_size)
-                                               if err != nil {
-                                                       err = fmt.Errorf("error setting metadata size to %d", metadata_size)
-                                                       break
-                                               }
-                                       }
-                               }
-                               if _, ok := c.PeerExtensionIDs["ut_metadata"]; ok {
-                                       c.requestPendingMetadata()
-                               }
-                       case metadataExtendedId:
-                               err = cl.gotMetadataExtensionMsg(msg.ExtendedPayload, t, c)
-                               if err != nil {
-                                       err = fmt.Errorf("error handling metadata extension message: %s", err)
-                               }
-                       case pexExtendedId:
-                               if cl.config.DisablePEX {
-                                       break
-                               }
-                               var pexMsg peerExchangeMessage
-                               err = bencode.Unmarshal(msg.ExtendedPayload, &pexMsg)
-                               if err != nil {
-                                       err = fmt.Errorf("error unmarshalling PEX message: %s", err)
-                                       break
-                               }
-                               go func() {
-                                       cl.mu.Lock()
-                                       t.addPeers(func() (ret []Peer) {
-                                               for i, cp := range pexMsg.Added {
-                                                       p := Peer{
-                                                               IP:     make([]byte, 4),
-                                                               Port:   cp.Port,
-                                                               Source: peerSourcePEX,
-                                                       }
-                                                       if i < len(pexMsg.AddedFlags) && pexMsg.AddedFlags[i]&0x01 != 0 {
-                                                               p.SupportsEncryption = true
-                                                       }
-                                                       missinggo.CopyExact(p.IP, cp.IP[:])
-                                                       ret = append(ret, p)
-                                               }
-                                               return
-                                       }())
-                                       cl.mu.Unlock()
-                               }()
-                       default:
-                               err = fmt.Errorf("unexpected extended message ID: %v", msg.ExtendedID)
-                       }
-                       if err != nil {
-                               // That client uses its own extension IDs for outgoing message
-                               // types, which is incorrect.
-                               if bytes.HasPrefix(c.PeerID[:], []byte("-SD0100-")) ||
-                                       strings.HasPrefix(string(c.PeerID[:]), "-XL0012-") {
-                                       return nil
-                               }
-                       }
+                       err = c.onReadExtendedMsg(msg.ExtendedID, msg.ExtendedPayload)
                case pp.Port:
                        if cl.dHT == nil {
                                break
@@ -1069,6 +966,116 @@ func (c *connection) mainReadLoop() error {
        }
 }
 
+func (c *connection) onReadExtendedMsg(id byte, payload []byte) (err error) {
+       defer func() {
+               // TODO: Should we still do this?
+               if err != nil {
+                       // These clients use their own extension IDs for outgoing message
+                       // types, which is incorrect.
+                       if bytes.HasPrefix(c.PeerID[:], []byte("-SD0100-")) || strings.HasPrefix(string(c.PeerID[:]), "-XL0012-") {
+                               err = nil
+                       }
+               }
+       }()
+       t := c.t
+       cl := t.cl
+       switch id {
+       case pp.HandshakeExtendedID:
+               // TODO: Create a bencode struct for this.
+               var d map[string]interface{}
+               err := bencode.Unmarshal(payload, &d)
+               if err != nil {
+                       return fmt.Errorf("error decoding extended message payload: %s", err)
+               }
+               // log.Printf("got handshake from %q: %#v", c.Socket.RemoteAddr().String(), d)
+               if reqq, ok := d["reqq"]; ok {
+                       if i, ok := reqq.(int64); ok {
+                               c.PeerMaxRequests = int(i)
+                       }
+               }
+               if v, ok := d["v"]; ok {
+                       c.PeerClientName = v.(string)
+               }
+               if m, ok := d["m"]; ok {
+                       mTyped, ok := m.(map[string]interface{})
+                       if !ok {
+                               return errors.New("handshake m value is not dict")
+                       }
+                       if c.PeerExtensionIDs == nil {
+                               c.PeerExtensionIDs = make(map[string]byte, len(mTyped))
+                       }
+                       for name, v := range mTyped {
+                               id, ok := v.(int64)
+                               if !ok {
+                                       log.Printf("bad handshake m item extension ID type: %T", v)
+                                       continue
+                               }
+                               if id == 0 {
+                                       delete(c.PeerExtensionIDs, name)
+                               } else {
+                                       if c.PeerExtensionIDs[name] == 0 {
+                                               supportedExtensionMessages.Add(name, 1)
+                                       }
+                                       c.PeerExtensionIDs[name] = byte(id)
+                               }
+                       }
+               }
+               metadata_sizeUntyped, ok := d["metadata_size"]
+               if ok {
+                       metadata_size, ok := metadata_sizeUntyped.(int64)
+                       if !ok {
+                               log.Printf("bad metadata_size type: %T", metadata_sizeUntyped)
+                       } else {
+                               err = t.setMetadataSize(metadata_size)
+                               if err != nil {
+                                       return fmt.Errorf("error setting metadata size to %d", metadata_size)
+                               }
+                       }
+               }
+               if _, ok := c.PeerExtensionIDs["ut_metadata"]; ok {
+                       c.requestPendingMetadata()
+               }
+               return nil
+       case metadataExtendedId:
+               err := cl.gotMetadataExtensionMsg(payload, t, c)
+               if err != nil {
+                       return fmt.Errorf("error handling metadata extension message: %s", err)
+               }
+               return nil
+       case pexExtendedId:
+               if cl.config.DisablePEX {
+                       return nil
+               }
+               var pexMsg peerExchangeMessage
+               err := bencode.Unmarshal(payload, &pexMsg)
+               if err != nil {
+                       return fmt.Errorf("error unmarshalling PEX message: %s", err)
+               }
+               go func() {
+                       cl.mu.Lock()
+                       t.addPeers(func() (ret []Peer) {
+                               for i, cp := range pexMsg.Added {
+                                       p := Peer{
+                                               IP:     make([]byte, 4),
+                                               Port:   cp.Port,
+                                               Source: peerSourcePEX,
+                                       }
+                                       if i < len(pexMsg.AddedFlags) && pexMsg.AddedFlags[i]&0x01 != 0 {
+                                               p.SupportsEncryption = true
+                                       }
+                                       missinggo.CopyExact(p.IP, cp.IP[:])
+                                       ret = append(ret, p)
+                               }
+                               return
+                       }())
+                       cl.mu.Unlock()
+               }()
+               return nil
+       default:
+               return fmt.Errorf("unexpected extended message ID: %v", id)
+       }
+}
+
 // Set both the Reader and Writer for the connection from a single ReadWriter.
 func (cn *connection) setRW(rw io.ReadWriter) {
        cn.r = rw