QUIC+TLS server should temporarily retain 0-RTT keys so that re-ordered packets can be decoded.

Protected by FLAGS_quic_restart_flag_quic_server_temporarily_retain_tls_zero_rtt_keys.

PiperOrigin-RevId: 345557425
Change-Id: I27f0a13c98d7362ed702de26c33215252e686d2b
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc
index d7abc1a..825a0e9 100644
--- a/quic/core/quic_connection.cc
+++ b/quic/core/quic_connection.cc
@@ -191,6 +191,25 @@
   QuicConnection* connection_;
 };
 
+class DiscardZeroRttDecryptionKeysAlarmDelegate : public QuicAlarm::Delegate {
+ public:
+  explicit DiscardZeroRttDecryptionKeysAlarmDelegate(QuicConnection* connection)
+      : connection_(connection) {}
+  DiscardZeroRttDecryptionKeysAlarmDelegate(
+      const DiscardZeroRttDecryptionKeysAlarmDelegate&) = delete;
+  DiscardZeroRttDecryptionKeysAlarmDelegate& operator=(
+      const DiscardZeroRttDecryptionKeysAlarmDelegate&) = delete;
+
+  void OnAlarm() override {
+    DCHECK(connection_->connected());
+    QUIC_DLOG(INFO) << "0-RTT discard alarm fired";
+    connection_->RemoveDecrypter(ENCRYPTION_ZERO_RTT);
+  }
+
+ private:
+  QuicConnection* connection_;
+};
+
 // When the clearer goes out of scope, the coalesced packet gets cleared.
 class ScopedCoalescedPacketClearer {
  public:
@@ -305,6 +324,9 @@
       discard_previous_one_rtt_keys_alarm_(alarm_factory_->CreateAlarm(
           arena_.New<DiscardPreviousOneRttKeysAlarmDelegate>(this),
           &arena_)),
+      discard_zero_rtt_decryption_keys_alarm_(alarm_factory_->CreateAlarm(
+          arena_.New<DiscardZeroRttDecryptionKeysAlarmDelegate>(this),
+          &arena_)),
       visitor_(nullptr),
       debug_visitor_(nullptr),
       packet_creator_(server_connection_id_, &framer_, random_generator_, this),
@@ -427,6 +449,17 @@
     delete writer_;
   }
   ClearQueuedPackets();
+  if (stats_
+          .num_tls_server_zero_rtt_packets_received_after_discarding_decrypter >
+      0) {
+    QUIC_CODE_COUNT_N(
+        quic_server_received_tls_zero_rtt_packet_after_discarding_decrypter, 2,
+        3);
+  } else {
+    QUIC_CODE_COUNT_N(
+        quic_server_received_tls_zero_rtt_packet_after_discarding_decrypter, 3,
+        3);
+  }
 }
 
 void QuicConnection::ClearQueuedPackets() {
@@ -1150,6 +1183,22 @@
                                        EncryptionLevel level) {
   last_decrypted_packet_level_ = level;
   last_packet_decrypted_ = true;
+  if (level == ENCRYPTION_FORWARD_SECURE &&
+      !have_decrypted_first_one_rtt_packet_) {
+    have_decrypted_first_one_rtt_packet_ = true;
+    if (GetQuicRestartFlag(quic_server_temporarily_retain_tls_zero_rtt_keys) &&
+        version().UsesTls() && perspective_ == Perspective::IS_SERVER) {
+      // Servers MAY temporarily retain 0-RTT keys to allow decrypting reordered
+      // packets without requiring their contents to be retransmitted with 1-RTT
+      // keys. After receiving a 1-RTT packet, servers MUST discard 0-RTT keys
+      // within a short time; the RECOMMENDED time period is three times the
+      // Probe Timeout.
+      // https://quicwg.org/base-drafts/draft-ietf-quic-tls.html#name-discarding-0-rtt-keys
+      QUIC_RESTART_FLAG_COUNT(quic_server_temporarily_retain_tls_zero_rtt_keys);
+      discard_zero_rtt_decryption_keys_alarm_->Set(
+          clock_->ApproximateNow() + sent_packet_manager_.GetPtoDelay() * 3);
+    }
+  }
   if (EnforceAntiAmplificationLimit() &&
       (last_decrypted_packet_level_ == ENCRYPTION_HANDSHAKE ||
        last_decrypted_packet_level_ == ENCRYPTION_FORWARD_SECURE)) {
@@ -2246,6 +2295,16 @@
       }
     }
   }
+
+  if (version().UsesTls() && perspective_ == Perspective::IS_SERVER &&
+      decryption_level == ENCRYPTION_ZERO_RTT && !has_decryption_key &&
+      had_zero_rtt_decrypter_) {
+    QUIC_CODE_COUNT_N(
+        quic_server_received_tls_zero_rtt_packet_after_discarding_decrypter, 1,
+        3);
+    stats_
+        .num_tls_server_zero_rtt_packets_received_after_discarding_decrypter++;
+  }
 }
 
 bool QuicConnection::ShouldEnqueueUnDecryptablePacket(
@@ -3689,6 +3748,9 @@
 void QuicConnection::InstallDecrypter(
     EncryptionLevel level,
     std::unique_ptr<QuicDecrypter> decrypter) {
+  if (level == ENCRYPTION_ZERO_RTT) {
+    had_zero_rtt_decrypter_ = true;
+  }
   framer_.InstallDecrypter(level, std::move(decrypter));
   if (!undecryptable_packets_.empty() &&
       !process_undecryptable_packets_alarm_->IsSet()) {
@@ -4048,6 +4110,7 @@
   mtu_discovery_alarm_->Cancel();
   process_undecryptable_packets_alarm_->Cancel();
   discard_previous_one_rtt_keys_alarm_->Cancel();
+  discard_zero_rtt_decryption_keys_alarm_->Cancel();
   blackhole_detector_.StopDetection();
   idle_network_detector_.StopDetection();
 }
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h
index 327a696..8cc3a4b 100644
--- a/quic/core/quic_connection.h
+++ b/quic/core/quic_connection.h
@@ -1719,6 +1719,10 @@
   // An alarm that fires to discard keys for the previous key phase some time
   // after a key update has completed.
   QuicArenaScopedPtr<QuicAlarm> discard_previous_one_rtt_keys_alarm_;
+  // An alarm that fires to discard 0-RTT decryption keys some time after the
+  // first 1-RTT packet has been decrypted. Only used on server connections with
+  // TLS handshaker.
+  QuicArenaScopedPtr<QuicAlarm> discard_zero_rtt_decryption_keys_alarm_;
   // Neither visitor is owned by this class.
   QuicConnectionVisitorInterface* visitor_;
   QuicConnectionDebugVisitor* debug_visitor_;
@@ -1956,6 +1960,13 @@
   // Indicate whether AckFrequency frame has been sent.
   bool ack_frequency_sent_ = false;
 
+  // True if a 0-RTT decrypter was or is installed at some point in the
+  // connection's lifetime.
+  bool had_zero_rtt_decrypter_ = false;
+
+  // True after the first 1-RTT packet has successfully decrypted.
+  bool have_decrypted_first_one_rtt_packet_ = false;
+
   const bool fix_missing_initial_keys_ =
       GetQuicReloadableFlag(quic_fix_missing_initial_keys2);
 
diff --git a/quic/core/quic_connection_stats.cc b/quic/core/quic_connection_stats.cc
index d7440b4..a693677 100644
--- a/quic/core/quic_connection_stats.cc
+++ b/quic/core/quic_connection_stats.cc
@@ -58,6 +58,8 @@
   os << " key_update_count: " << s.key_update_count;
   os << " num_failed_authentication_packets_received: "
      << s.num_failed_authentication_packets_received;
+  os << " num_tls_server_zero_rtt_packets_received_after_discarding_decrypter: "
+     << s.num_tls_server_zero_rtt_packets_received_after_discarding_decrypter;
   os << " }";
 
   return os;
diff --git a/quic/core/quic_connection_stats.h b/quic/core/quic_connection_stats.h
index 95cf61e..409e8ec 100644
--- a/quic/core/quic_connection_stats.h
+++ b/quic/core/quic_connection_stats.h
@@ -178,6 +178,11 @@
   // Counts the number of undecryptable packets received across all keys. Does
   // not include packets where a decryption key for that level was absent.
   QuicPacketCount num_failed_authentication_packets_received = 0;
+
+  // Counts the number of QUIC+TLS 0-RTT packets received after 0-RTT decrypter
+  // was discarded, only on server connections.
+  QuicPacketCount
+      num_tls_server_zero_rtt_packets_received_after_discarding_decrypter = 0;
 };
 
 }  // namespace quic
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc
index b3dc2d8..e06eb86 100644
--- a/quic/core/quic_connection_test.cc
+++ b/quic/core/quic_connection_test.cc
@@ -462,6 +462,11 @@
         QuicConnectionPeer::GetDiscardPreviousOneRttKeysAlarm(this));
   }
 
+  TestAlarmFactory::TestAlarm* GetDiscardZeroRttDecryptionKeysAlarm() {
+    return reinterpret_cast<TestAlarmFactory::TestAlarm*>(
+        QuicConnectionPeer::GetDiscardZeroRttDecryptionKeysAlarm(this));
+  }
+
   TestAlarmFactory::TestAlarm* GetBlackholeDetectorAlarm() {
     return reinterpret_cast<TestAlarmFactory::TestAlarm*>(
         QuicConnectionPeer::GetBlackholeDetectorAlarm(this));
@@ -12879,6 +12884,102 @@
   }
 }
 
+TEST_P(QuicConnectionTest,
+       ServerReceivedZeroRttPacketAfterOneRttPacketWithoutRetainedKey) {
+  SetQuicRestartFlag(quic_server_temporarily_retain_tls_zero_rtt_keys, false);
+  if (!connection_.version().UsesTls()) {
+    return;
+  }
+
+  set_perspective(Perspective::IS_SERVER);
+  SetDecrypter(ENCRYPTION_ZERO_RTT,
+               std::make_unique<NullDecrypter>(Perspective::IS_SERVER));
+
+  EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1);
+  ProcessDataPacketAtLevel(1, !kHasStopWaiting, ENCRYPTION_ZERO_RTT);
+
+  // Finish handshake.
+  connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
+  notifier_.NeuterUnencryptedData();
+  connection_.NeuterUnencryptedPackets();
+  connection_.OnHandshakeComplete();
+  EXPECT_CALL(visitor_, GetHandshakeState())
+      .WillRepeatedly(Return(HANDSHAKE_COMPLETE));
+  // When quic_server_temporarily_retain_tls_zero_rtt_keys=false,
+  // TlsServerHandshaker::FinishHandshake will remove the ENCRYPTION_ZERO_RTT
+  // decrypter, simulate that here:
+  connection_.RemoveDecrypter(ENCRYPTION_ZERO_RTT);
+
+  EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1);
+  ProcessDataPacketAtLevel(3, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE);
+  EXPECT_FALSE(connection_.GetDiscardZeroRttDecryptionKeysAlarm()->IsSet());
+  EXPECT_EQ(
+      0u,
+      connection_.GetStats()
+          .num_tls_server_zero_rtt_packets_received_after_discarding_decrypter);
+
+  EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(0);
+  ProcessDataPacketAtLevel(2, !kHasStopWaiting, ENCRYPTION_ZERO_RTT);
+  EXPECT_EQ(
+      1u,
+      connection_.GetStats()
+          .num_tls_server_zero_rtt_packets_received_after_discarding_decrypter);
+}
+
+TEST_P(QuicConnectionTest,
+       ServerReceivedZeroRttPacketAfterOneRttPacketWithRetainedKey) {
+  SetQuicRestartFlag(quic_server_temporarily_retain_tls_zero_rtt_keys, true);
+  if (!connection_.version().UsesTls()) {
+    return;
+  }
+
+  set_perspective(Perspective::IS_SERVER);
+  SetDecrypter(ENCRYPTION_ZERO_RTT,
+               std::make_unique<NullDecrypter>(Perspective::IS_SERVER));
+
+  EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1);
+  ProcessDataPacketAtLevel(1, !kHasStopWaiting, ENCRYPTION_ZERO_RTT);
+
+  // Finish handshake.
+  connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
+  notifier_.NeuterUnencryptedData();
+  connection_.NeuterUnencryptedPackets();
+  connection_.OnHandshakeComplete();
+  EXPECT_CALL(visitor_, GetHandshakeState())
+      .WillRepeatedly(Return(HANDSHAKE_COMPLETE));
+
+  EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1);
+  ProcessDataPacketAtLevel(4, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE);
+  EXPECT_TRUE(connection_.GetDiscardZeroRttDecryptionKeysAlarm()->IsSet());
+
+  // 0-RTT packet received out of order should be decoded since the decrypter
+  // is temporarily retained.
+  EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1);
+  ProcessDataPacketAtLevel(2, !kHasStopWaiting, ENCRYPTION_ZERO_RTT);
+  EXPECT_EQ(
+      0u,
+      connection_.GetStats()
+          .num_tls_server_zero_rtt_packets_received_after_discarding_decrypter);
+
+  // Simulate the timeout for discarding 0-RTT keys passing.
+  connection_.GetDiscardZeroRttDecryptionKeysAlarm()->Fire();
+
+  // Another 0-RTT packet received now should not be decoded.
+  EXPECT_FALSE(connection_.GetDiscardZeroRttDecryptionKeysAlarm()->IsSet());
+  EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(0);
+  ProcessDataPacketAtLevel(3, !kHasStopWaiting, ENCRYPTION_ZERO_RTT);
+  EXPECT_EQ(
+      1u,
+      connection_.GetStats()
+          .num_tls_server_zero_rtt_packets_received_after_discarding_decrypter);
+
+  // The |discard_zero_rtt_decryption_keys_alarm_| should only be set on the
+  // first 1-RTT packet received.
+  EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1);
+  ProcessDataPacketAtLevel(5, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE);
+  EXPECT_FALSE(connection_.GetDiscardZeroRttDecryptionKeysAlarm()->IsSet());
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h
index f9530bb..7a8bde0 100644
--- a/quic/core/quic_flags_list.h
+++ b/quic/core/quic_flags_list.h
@@ -68,6 +68,7 @@
 QUIC_FLAG(FLAGS_quic_restart_flag_dont_fetch_quic_private_keys_from_leto, false)
 QUIC_FLAG(FLAGS_quic_restart_flag_quic_enable_zero_rtt_for_tls_v2, true)
 QUIC_FLAG(FLAGS_quic_restart_flag_quic_offload_pacing_to_usps2, false)
+QUIC_FLAG(FLAGS_quic_restart_flag_quic_server_temporarily_retain_tls_zero_rtt_keys, false)
 QUIC_FLAG(FLAGS_quic_restart_flag_quic_session_tickets_always_enabled, true)
 QUIC_FLAG(FLAGS_quic_restart_flag_quic_startup_faster_interval_set, false)
 QUIC_FLAG(FLAGS_quic_restart_flag_quic_support_release_time_for_gso, false)
diff --git a/quic/core/quic_one_block_arena.h b/quic/core/quic_one_block_arena.h
index d821803..15c0a99 100644
--- a/quic/core/quic_one_block_arena.h
+++ b/quic/core/quic_one_block_arena.h
@@ -75,7 +75,7 @@
 
 // QuicConnections currently use around 1KB of polymorphic types which would
 // ordinarily be on the heap. Instead, store them inline in an arena.
-using QuicConnectionArena = QuicOneBlockArena<1024>;
+using QuicConnectionArena = QuicOneBlockArena<1056>;
 
 }  // namespace quic
 
diff --git a/quic/core/quic_path_validator.cc b/quic/core/quic_path_validator.cc
index 04d1be2..e24e053 100644
--- a/quic/core/quic_path_validator.cc
+++ b/quic/core/quic_path_validator.cc
@@ -30,7 +30,7 @@
 }
 
 QuicPathValidator::QuicPathValidator(QuicAlarmFactory* alarm_factory,
-                                     QuicOneBlockArena<1024>* arena,
+                                     QuicConnectionArena* arena,
                                      SendDelegate* send_delegate,
                                      QuicRandom* random)
     : send_delegate_(send_delegate),
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index 3d0d8dc..93f6c58 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -404,7 +404,15 @@
   handshaker_delegate()->OnTlsHandshakeComplete();
   handshaker_delegate()->DiscardOldEncryptionKey(ENCRYPTION_HANDSHAKE);
   handshaker_delegate()->DiscardOldDecryptionKey(ENCRYPTION_HANDSHAKE);
-  handshaker_delegate()->DiscardOldDecryptionKey(ENCRYPTION_ZERO_RTT);
+  if (!GetQuicRestartFlag(quic_server_temporarily_retain_tls_zero_rtt_keys)) {
+    handshaker_delegate()->DiscardOldDecryptionKey(ENCRYPTION_ZERO_RTT);
+  } else {
+    // ENCRYPTION_ZERO_RTT decryption key is not discarded here as "Servers MAY
+    // temporarily retain 0-RTT keys to allow decrypting reordered packets
+    // without requiring their contents to be retransmitted with 1-RTT keys."
+    // It is expected that QuicConnection will discard the key at an
+    // appropriate time.
+  }
 }
 
 QuicAsyncStatus TlsServerHandshaker::VerifyCertChain(
diff --git a/quic/test_tools/crypto_test_utils.cc b/quic/test_tools/crypto_test_utils.cc
index 8303487..2ddbd3e 100644
--- a/quic/test_tools/crypto_test_utils.cc
+++ b/quic/test_tools/crypto_test_utils.cc
@@ -606,7 +606,7 @@
     if (level == ENCRYPTION_FORWARD_SECURE ||
         !((level == ENCRYPTION_HANDSHAKE || level == ENCRYPTION_ZERO_RTT ||
            client_encrypter == nullptr) &&
-          server_decrypter == nullptr)) {
+          (level == ENCRYPTION_ZERO_RTT || server_decrypter == nullptr))) {
       CompareCrypters(client_encrypter, server_decrypter,
                       "client " + EncryptionLevelString(level) + " write");
     }
diff --git a/quic/test_tools/quic_connection_peer.cc b/quic/test_tools/quic_connection_peer.cc
index 10bd838..dfdc4f7 100644
--- a/quic/test_tools/quic_connection_peer.cc
+++ b/quic/test_tools/quic_connection_peer.cc
@@ -157,6 +157,12 @@
 }
 
 // static
+QuicAlarm* QuicConnectionPeer::GetDiscardZeroRttDecryptionKeysAlarm(
+    QuicConnection* connection) {
+  return connection->discard_zero_rtt_decryption_keys_alarm_.get();
+}
+
+// static
 QuicPacketWriter* QuicConnectionPeer::GetWriter(QuicConnection* connection) {
   return connection->writer_;
 }
diff --git a/quic/test_tools/quic_connection_peer.h b/quic/test_tools/quic_connection_peer.h
index 1769033..889c06f 100644
--- a/quic/test_tools/quic_connection_peer.h
+++ b/quic/test_tools/quic_connection_peer.h
@@ -84,6 +84,8 @@
       QuicConnection* connection);
   static QuicAlarm* GetDiscardPreviousOneRttKeysAlarm(
       QuicConnection* connection);
+  static QuicAlarm* GetDiscardZeroRttDecryptionKeysAlarm(
+      QuicConnection* connection);
 
   static QuicPacketWriter* GetWriter(QuicConnection* connection);
   // If |owns_writer| is true, takes ownership of |writer|.