src/crypto/tls/conn.go                   | 14 +++++++++++---
src/crypto/tls/handshake_client_tls13.go |  4 ++--
src/crypto/tls/handshake_server_tls13.go |  4 ++--
src/crypto/tls/handshake_test.go         | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go
index fa5b87407f9da5ad1cfb111eef9c3eff49082b68..c6e6ec703026ffe4504d19af694ab3775922dae1 100644
--- a/src/crypto/tls/conn.go
+++ b/src/crypto/tls/conn.go
@@ -1367,7 +1367,7 @@
 		c.setWriteTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
 	}
 	newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
-	if err := c.setReadTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret); err != nil {
+	if err := c.setReadTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret, keyUpdate.updateRequested); err != nil {
 		return err
 	}
 
@@ -1701,12 +1701,20 @@
 // setReadTrafficSecret sets the read traffic secret for the given encryption level. If
 // being called at the same time as setWriteTrafficSecret, the caller must ensure the call
 // to setWriteTrafficSecret happens first so any alerts are sent at the write level.
-func (c *Conn) setReadTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) error {
+//
+// locked must be true when the caller already holds c.out's lock (the
+// KeyUpdate path passes keyUpdate.updateRequested); the alert is then sent
+// with sendAlertLocked instead of sendAlert.
+func (c *Conn) setReadTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte, locked bool) error {
 	// Ensure that there are no buffered handshake messages before changing the
 	// read keys, since that can cause messages to be parsed that were encrypted
 	// using old keys which are no longer appropriate.
 	if c.hand.Len() != 0 {
-		c.sendAlert(alertUnexpectedMessage)
+		if locked {
+			c.sendAlertLocked(alertUnexpectedMessage)
+		} else {
+			c.sendAlert(alertUnexpectedMessage)
+		}
 		return errors.New("tls: handshake buffer not empty before setting read traffic secret")
 	}
 	c.in.setTrafficSecret(suite, level, secret)
diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go
index cf793a476b6491c4a76302718ccfdb775fae008d..b0d489f7fe8429cd6e55928aa745f3d39a6d7680 100644
--- a/src/crypto/tls/handshake_client_tls13.go
+++ b/src/crypto/tls/handshake_client_tls13.go
@@ -517,7 +517,7 @@
 	clientSecret := handshakeSecret.ClientHandshakeTrafficSecret(hs.transcript)
 	c.setWriteTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
 	serverSecret := handshakeSecret.ServerHandshakeTrafficSecret(hs.transcript)
-	if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret); err != nil {
+	if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret, false); err != nil {
 		return err
 	}
 
@@ -736,7 +736,7 @@
 	// Derive secrets that take context through the server Finished.
 	hs.trafficSecret = hs.masterSecret.ClientApplicationTrafficSecret(hs.transcript)
 	serverSecret := hs.masterSecret.ServerApplicationTrafficSecret(hs.transcript)
-	if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret); err != nil {
+	if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret, false); err != nil {
 		return err
 	}
 
diff --git a/src/crypto/tls/handshake_server_tls13.go b/src/crypto/tls/handshake_server_tls13.go
index a4b205908ed9348859e444944e192c94c7aa7697..318833ecbeef4f5c49d87f36321270721fb81af2 100644
--- a/src/crypto/tls/handshake_server_tls13.go
+++ b/src/crypto/tls/handshake_server_tls13.go
@@ -785,7 +785,7 @@
 	serverSecret := hs.handshakeSecret.ServerHandshakeTrafficSecret(hs.transcript)
 	c.setWriteTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
 	clientSecret := hs.handshakeSecret.ClientHandshakeTrafficSecret(hs.transcript)
-	if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret); err != nil {
+	if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret, false); err != nil {
 		return err
 	}
 
@@ -1169,7 +1169,7 @@
 		c.sendAlert(alertDecryptError)
 		return errors.New("tls: invalid client finished hash")
 	}
-	if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret); err != nil {
+	if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret, false); err != nil {
 		return err
 	}
 
diff --git a/src/crypto/tls/handshake_test.go b/src/crypto/tls/handshake_test.go
index c70b2aad443b2a67a2fb0e03c18b2122e9ac82ae..631712dbf92cb4ef69a278277710ffd9e0cc9e81 100644
--- a/src/crypto/tls/handshake_test.go
+++ b/src/crypto/tls/handshake_test.go
@@ -772,3 +772,66 @@ outBuf[4] = byte(m)
 	outBuf = append(outBuf, marshalled...)
 	return outBuf, nil
 }
+
+// TestMultipleKeyUpdate checks that two KeyUpdate messages sent in a single
+// record are rejected with an unexpected_message alert: the second message
+// would otherwise still be buffered when the first one changes the read keys.
+// Both values of update_requested are covered, since the receiver handles a
+// requested update while holding its c.out lock and must send the alert with
+// sendAlertLocked in that case.
+func TestMultipleKeyUpdate(t *testing.T) {
+	for _, requestUpdate := range []bool{true, false} {
+		t.Run(fmt.Sprintf("requestUpdate=%t", requestUpdate), func(t *testing.T) {
+			c, s := localPipe(t)
+			cfg := testConfig.Clone()
+			cfg.MinVersion = VersionTLS13
+			cfg.MaxVersion = VersionTLS13
+			client := Client(c, cfg)
+			server := Server(s, cfg)
+
+			// Report the client handshake result over a channel: calling t
+			// from this goroutine could happen after the test has returned.
+			clientHandshakeErr := make(chan error, 1)
+			go func() {
+				clientHandshakeErr <- client.Handshake()
+				// Keep reading on the server side so that it processes the
+				// KeyUpdate records written below and answers with an alert.
+				io.Copy(io.Discard, server)
+			}()
+
+			if err := server.Handshake(); err != nil {
+				t.Fatalf("server handshake failed: %v", err)
+			}
+			if err := <-clientHandshakeErr; err != nil {
+				t.Fatalf("client handshake failed: %v", err)
+			}
+
+			// Bound the reads so a server that never answers fails the test
+			// instead of hanging it.
+			c.SetReadDeadline(time.Now().Add(1 * time.Second))
+			s.SetReadDeadline(time.Now().Add(1 * time.Second))
+
+			kuMsg, err := (&keyUpdateMsg{updateRequested: requestUpdate}).marshal()
+			if err != nil {
+				t.Fatalf("failed to marshal key update message: %v", err)
+			}
+
+			// Both messages go into one record, so the second is still in the
+			// server's handshake buffer when the first changes the read keys.
+			client.out.Lock()
+			_, err = client.writeRecordLocked(recordTypeHandshake, append(kuMsg, kuMsg...))
+			client.out.Unlock()
+			if err != nil {
+				t.Fatalf("failed to write key update messages: %v", err)
+			}
+
+			_, err = io.Copy(io.Discard, client)
+			if err == nil {
+				t.Fatal("expected multiple key update messages to cause an error, got nil")
+			}
+			if !strings.HasSuffix(err.Error(), "tls: unexpected message") {
+				t.Fatalf("unexpected error: %v", err)
+			}
+		})
+	}
+}