QUIC Key Update support

Handles key updates initiated by the peer and adds a QuicConnection method to initiate a key update locally; that method is currently only called from tests.

Protected by FLAGS_quic_reloadable_flag_quic_key_update_supported.

PiperOrigin-RevId: 336385088
Change-Id: If74d032e1d34e5392312f4b619d28c9f93a95265
diff --git a/quic/core/http/end_to_end_test.cc b/quic/core/http/end_to_end_test.cc
index c57fc6e..6145a79 100644
--- a/quic/core/http/end_to_end_test.cc
+++ b/quic/core/http/end_to_end_test.cc
@@ -4821,6 +4821,167 @@
       0u);
 }
 
+// Verifies that client-initiated key updates are counted on both endpoints
+// and that the count only advances when InitiateKeyUpdate() is called, not
+// on every request.
+TEST_P(EndToEndTest, KeyUpdateInitiatedByClient) {
+  SetQuicReloadableFlag(quic_key_update_supported, true);
+  ASSERT_TRUE(Initialize());
+  if (!version_.UsesTls()) {
+    // Key Update is only supported in TLS handshake.
+    return;
+  }
+
+  SendSynchronousFooRequestAndCheckResponse();
+  QuicConnection* client_connection = GetClientConnection();
+  ASSERT_TRUE(client_connection);
+  EXPECT_EQ(0u, client_connection->GetStats().key_update_count);
+
+  EXPECT_TRUE(client_connection->InitiateKeyUpdate());
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(1u, client_connection->GetStats().key_update_count);
+
+  // Another request alone must not trigger a further key update.
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(1u, client_connection->GetStats().key_update_count);
+
+  EXPECT_TRUE(client_connection->InitiateKeyUpdate());
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(2u, client_connection->GetStats().key_update_count);
+
+  // Pause the server thread while checking that the server counted both key
+  // updates.
+  server_thread_->Pause();
+  QuicConnection* server_connection = GetServerConnection();
+  if (server_connection) {
+    EXPECT_EQ(2u, server_connection->GetStats().key_update_count);
+  } else {
+    ADD_FAILURE() << "Missing server connection";
+  }
+  server_thread_->Resume();
+}
+
+// Verifies that server-initiated key updates are counted on both endpoints.
+TEST_P(EndToEndTest, KeyUpdateInitiatedByServer) {
+  SetQuicReloadableFlag(quic_key_update_supported, true);
+  ASSERT_TRUE(Initialize());
+  if (!version_.UsesTls()) {
+    // Key Update is only supported in TLS handshake.
+    return;
+  }
+
+  // Runs InitiateKeyUpdate() on the server connection via
+  // server_thread_->WaitUntil(). WaitUntil ensures the server has executed the
+  // key update before the next Foo request is sent; otherwise the server could
+  // receive that request first and the test would be flaky.
+  auto initiate_server_key_update = [this]() {
+    server_thread_->WaitUntil(
+        [this]() {
+          QuicConnection* server_connection = GetServerConnection();
+          if (server_connection != nullptr) {
+            EXPECT_TRUE(server_connection->InitiateKeyUpdate());
+          } else {
+            ADD_FAILURE() << "Missing server connection";
+          }
+          return true;
+        },
+        QuicTime::Delta::FromSeconds(5));
+  };
+
+  SendSynchronousFooRequestAndCheckResponse();
+  QuicConnection* client_connection = GetClientConnection();
+  ASSERT_TRUE(client_connection);
+  EXPECT_EQ(0u, client_connection->GetStats().key_update_count);
+
+  initiate_server_key_update();
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(1u, client_connection->GetStats().key_update_count);
+
+  // Another request alone must not trigger a further key update.
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(1u, client_connection->GetStats().key_update_count);
+
+  initiate_server_key_update();
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(2u, client_connection->GetStats().key_update_count);
+
+  // Pause the server thread while checking that the server counted both key
+  // updates.
+  server_thread_->Pause();
+  QuicConnection* server_connection = GetServerConnection();
+  if (server_connection) {
+    EXPECT_EQ(2u, server_connection->GetStats().key_update_count);
+  } else {
+    ADD_FAILURE() << "Missing server connection";
+  }
+  server_thread_->Resume();
+}
+
+// Verifies key updates when both endpoints locally initiate one before
+// receiving a packet in the new key phase from the peer.
+TEST_P(EndToEndTest, KeyUpdateInitiatedByBoth) {
+  SetQuicReloadableFlag(quic_key_update_supported, true);
+  ASSERT_TRUE(Initialize());
+  if (!version_.UsesTls()) {
+    // Key Update is only supported in TLS handshake.
+    return;
+  }
+
+  // Runs InitiateKeyUpdate() on the server connection via
+  // server_thread_->WaitUntil(). WaitUntil ensures the server has executed the
+  // key update before the client sends its next Foo request. Otherwise that
+  // request could trigger the server key update before the server can initiate
+  // it locally, so the test would no longer reach the intended state of both
+  // sides locally initiating a key update before receiving a packet in the new
+  // key phase from the other side. The test would also fail, since
+  // InitiateKeyUpdate() would not yet allow another key update and would
+  // return false.
+  auto initiate_server_key_update = [this]() {
+    server_thread_->WaitUntil(
+        [this]() {
+          QuicConnection* server_connection = GetServerConnection();
+          if (server_connection != nullptr) {
+            EXPECT_TRUE(server_connection->InitiateKeyUpdate());
+          } else {
+            ADD_FAILURE() << "Missing server connection";
+          }
+          return true;
+        },
+        QuicTime::Delta::FromSeconds(5));
+  };
+
+  SendSynchronousFooRequestAndCheckResponse();
+
+  initiate_server_key_update();
+  QuicConnection* client_connection = GetClientConnection();
+  ASSERT_TRUE(client_connection);
+  EXPECT_TRUE(client_connection->InitiateKeyUpdate());
+
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(1u, client_connection->GetStats().key_update_count);
+
+  // Another request alone must not trigger a further key update.
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(1u, client_connection->GetStats().key_update_count);
+
+  initiate_server_key_update();
+  EXPECT_TRUE(client_connection->InitiateKeyUpdate());
+
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(2u, client_connection->GetStats().key_update_count);
+
+  // Pause the server thread while checking that the server counted both key
+  // updates.
+  server_thread_->Pause();
+  QuicConnection* server_connection = GetServerConnection();
+  if (server_connection) {
+    EXPECT_EQ(2u, server_connection->GetStats().key_update_count);
+  } else {
+    ADD_FAILURE() << "Missing server connection";
+  }
+  server_thread_->Resume();
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic
diff --git a/quic/core/http/quic_spdy_session_test.cc b/quic/core/http/quic_spdy_session_test.cc
index 8af0087..bb081b3 100644
--- a/quic/core/http/quic_spdy_session_test.cc
+++ b/quic/core/http/quic_spdy_session_test.cc
@@ -145,6 +145,16 @@
   }
   void SetServerApplicationStateForResumption(
       std::unique_ptr<ApplicationState> /*application_state*/) override {}
+  // Key update hooks: local key update support is not claimed, and neither
+  // 1-RTT key factory produces a key (both return null).
+  bool KeyUpdateSupportedLocally() const override { return false; }
+  std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override {
+    return {};
+  }
+  std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+      override {
+    return {};
+  }
   const QuicCryptoNegotiatedParameters& crypto_negotiated_params()
       const override {
     return *params_;
diff --git a/quic/core/http/quic_spdy_stream_test.cc b/quic/core/http/quic_spdy_stream_test.cc
index 0abec1d..4adbaab 100644
--- a/quic/core/http/quic_spdy_stream_test.cc
+++ b/quic/core/http/quic_spdy_stream_test.cc
@@ -133,6 +133,17 @@
   }
   void SetServerApplicationStateForResumption(
       std::unique_ptr<ApplicationState> /*application_state*/) override {}
+  // Key update hooks. This test stream cannot supply new 1-RTT keys (both key
+  // factories below return null), so it does not claim local key update
+  // support either.
+  bool KeyUpdateSupportedLocally() const override { return false; }
+  std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+      override {
+    return nullptr;
+  }
+  std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override {
+    return nullptr;
+  }
   const QuicCryptoNegotiatedParameters& crypto_negotiated_params()
       const override {
     return *params_;