Close QUIC+TLS connection if decrypted a 0-RTT packet with higher packet number than 1-RTT packet number

Protected by FLAGS_quic_reloadable_flag_quic_close_connection_on_0rtt_packet_number_higher_than_1rtt.

PiperOrigin-RevId: 346385500
Change-Id: Id1baf34e93dc5f1ee68e0dce7523bd1bb265111f
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc
index 941bd21..5d26ab0 100644
--- a/quic/core/quic_connection_test.cc
+++ b/quic/core/quic_connection_test.cc
@@ -12861,6 +12861,119 @@
             retry_token);
 }
 
+TEST_P(QuicConnectionTest,
+       ServerReceivedZeroRttWithHigherPacketNumberThanOneRttAndFlagDisabled) {
+  SetQuicRestartFlag(quic_server_temporarily_retain_tls_zero_rtt_keys, true);
+  SetQuicReloadableFlag(
+      quic_close_connection_on_0rtt_packet_number_higher_than_1rtt, false);
+  if (!connection_.version().UsesTls()) {
+    return;
+  }
+
+  // The code that checks for this error piggybacks on some book-keeping state
+  // kept for key update, so enable key update for the test.
+  std::string error_details;
+  TransportParameters params;
+  params.key_update_not_yet_supported = false;
+  QuicConfig config;
+  EXPECT_THAT(config.ProcessTransportParameters(
+                  params, /* is_resumption = */ false, &error_details),
+              IsQuicNoError());
+  config.SetKeyUpdateSupportedLocally();
+  QuicConfigPeer::SetNegotiated(&config, true);
+  QuicConfigPeer::SetReceivedOriginalConnectionId(&config,
+                                                  connection_.connection_id());
+  QuicConfigPeer::SetReceivedInitialSourceConnectionId(
+      &config, connection_.connection_id());
+  EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _));
+  connection_.SetFromConfig(config);
+
+  set_perspective(Perspective::IS_SERVER);
+  SetDecrypter(ENCRYPTION_ZERO_RTT,
+               std::make_unique<NullDecrypter>(Perspective::IS_SERVER));
+
+  EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1);
+  ProcessDataPacketAtLevel(1, !kHasStopWaiting, ENCRYPTION_ZERO_RTT);
+
+  // Finish handshake.
+  connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
+  notifier_.NeuterUnencryptedData();
+  connection_.NeuterUnencryptedPackets();
+  connection_.OnHandshakeComplete();
+  EXPECT_CALL(visitor_, GetHandshakeState())
+      .WillRepeatedly(Return(HANDSHAKE_COMPLETE));
+
+  // Decrypt a 1-RTT packet.
+  EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1);
+  ProcessDataPacketAtLevel(2, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE);
+  EXPECT_TRUE(connection_.GetDiscardZeroRttDecryptionKeysAlarm()->IsSet());
+
+  // 0-RTT packet with higher packet number than a 1-RTT packet is invalid, but
+  // accepted as the
+  // quic_close_connection_on_0rtt_packet_number_higher_than_1rtt
+  // flag is disabled.
+  EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1);
+  ProcessDataPacketAtLevel(3, !kHasStopWaiting, ENCRYPTION_ZERO_RTT);
+  EXPECT_TRUE(connection_.connected());
+}
+
+TEST_P(QuicConnectionTest,
+       ServerReceivedZeroRttWithHigherPacketNumberThanOneRtt) {
+  SetQuicRestartFlag(quic_server_temporarily_retain_tls_zero_rtt_keys, true);
+  SetQuicReloadableFlag(
+      quic_close_connection_on_0rtt_packet_number_higher_than_1rtt, true);
+  if (!connection_.version().UsesTls()) {
+    return;
+  }
+
+  // The code that checks for this error piggybacks on some book-keeping state
+  // kept for key update, so enable key update for the test.
+  std::string error_details;
+  TransportParameters params;
+  params.key_update_not_yet_supported = false;
+  QuicConfig config;
+  EXPECT_THAT(config.ProcessTransportParameters(
+                  params, /* is_resumption = */ false, &error_details),
+              IsQuicNoError());
+  config.SetKeyUpdateSupportedLocally();
+  QuicConfigPeer::SetNegotiated(&config, true);
+  QuicConfigPeer::SetReceivedOriginalConnectionId(&config,
+                                                  connection_.connection_id());
+  QuicConfigPeer::SetReceivedInitialSourceConnectionId(
+      &config, connection_.connection_id());
+  EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _));
+  connection_.SetFromConfig(config);
+
+  set_perspective(Perspective::IS_SERVER);
+  SetDecrypter(ENCRYPTION_ZERO_RTT,
+               std::make_unique<NullDecrypter>(Perspective::IS_SERVER));
+
+  EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1);
+  ProcessDataPacketAtLevel(1, !kHasStopWaiting, ENCRYPTION_ZERO_RTT);
+
+  // Finish handshake.
+  connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
+  notifier_.NeuterUnencryptedData();
+  connection_.NeuterUnencryptedPackets();
+  connection_.OnHandshakeComplete();
+  EXPECT_CALL(visitor_, GetHandshakeState())
+      .WillRepeatedly(Return(HANDSHAKE_COMPLETE));
+
+  // Decrypt a 1-RTT packet.
+  EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1);
+  ProcessDataPacketAtLevel(2, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE);
+  EXPECT_TRUE(connection_.GetDiscardZeroRttDecryptionKeysAlarm()->IsSet());
+
+  // 0-RTT packet with higher packet number than a 1-RTT packet is invalid and
+  // should cause the connection to be closed.
+  EXPECT_CALL(visitor_, BeforeConnectionCloseSent());
+  EXPECT_CALL(visitor_, OnConnectionClosed(_, _));
+  ProcessDataPacketAtLevel(3, !kHasStopWaiting, ENCRYPTION_ZERO_RTT);
+  EXPECT_FALSE(connection_.connected());
+  TestConnectionCloseQuicErrorCode(
+      QUIC_INVALID_0RTT_PACKET_NUMBER_OUT_OF_ORDER);
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic
diff --git a/quic/core/quic_error_codes.cc b/quic/core/quic_error_codes.cc
index 65db041..1aa9760 100644
--- a/quic/core/quic_error_codes.cc
+++ b/quic/core/quic_error_codes.cc
@@ -108,6 +108,7 @@
     RETURN_STRING_LITERAL(QUIC_TOO_MANY_OPEN_STREAMS);
     RETURN_STRING_LITERAL(QUIC_PUBLIC_RESET);
     RETURN_STRING_LITERAL(QUIC_INVALID_VERSION);
+    RETURN_STRING_LITERAL(QUIC_INVALID_0RTT_PACKET_NUMBER_OUT_OF_ORDER);
     RETURN_STRING_LITERAL(QUIC_INVALID_HEADER_ID);
     RETURN_STRING_LITERAL(QUIC_INVALID_NEGOTIATED_VALUE);
     RETURN_STRING_LITERAL(QUIC_DECOMPRESSION_FAILURE);
@@ -384,6 +385,8 @@
       return {true, static_cast<uint64_t>(INTERNAL_ERROR)};
     case QUIC_INVALID_VERSION:
       return {true, static_cast<uint64_t>(PROTOCOL_VIOLATION)};
+    case QUIC_INVALID_0RTT_PACKET_NUMBER_OUT_OF_ORDER:
+      return {true, static_cast<uint64_t>(PROTOCOL_VIOLATION)};
     case QUIC_INVALID_HEADER_ID:
       return {true, static_cast<uint64_t>(INTERNAL_ERROR)};
     case QUIC_INVALID_NEGOTIATED_VALUE:
diff --git a/quic/core/quic_error_codes.h b/quic/core/quic_error_codes.h
index a1cd7a6..20ab087 100644
--- a/quic/core/quic_error_codes.h
+++ b/quic/core/quic_error_codes.h
@@ -563,8 +563,11 @@
   // timeout.
   QUIC_MAX_AGE_TIMEOUT = 191,
 
+  // Decrypted a 0-RTT packet with a higher packet number than a 1-RTT packet.
+  QUIC_INVALID_0RTT_PACKET_NUMBER_OUT_OF_ORDER = 192,
+
   // No error. Used as bound while iterating.
-  QUIC_LAST_ERROR = 192,
+  QUIC_LAST_ERROR = 193,
 };
 // QuicErrorCodes is encoded as four octets on-the-wire when doing Google QUIC,
 // or a varint62 when doing IETF QUIC. Ensure that its value does not exceed
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h
index 4f435a6..d5f27e6 100644
--- a/quic/core/quic_flags_list.h
+++ b/quic/core/quic_flags_list.h
@@ -15,6 +15,7 @@
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_bbr2_fewer_startup_round_trips, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_bbr2_use_bytes_delivered, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_can_send_ack_frequency, true)
+QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_close_connection_on_0rtt_packet_number_higher_than_1rtt, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_conservative_bursts, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_conservative_cwnd_and_pacing_gains, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_default_enable_5rto_blackhole_detection2, true)
diff --git a/quic/core/quic_framer.cc b/quic/core/quic_framer.cc
index 6b29cae..7803b59 100644
--- a/quic/core/quic_framer.cc
+++ b/quic/core/quic_framer.cc
@@ -4781,6 +4781,22 @@
       decrypted_buffer, decrypted_length, buffer_length);
   if (success) {
     visitor_->OnDecryptedPacket(udp_packet_length, level);
+    if (GetQuicReloadableFlag(
+            quic_close_connection_on_0rtt_packet_number_higher_than_1rtt)) {
+      QUIC_RELOADABLE_FLAG_COUNT(
+          quic_close_connection_on_0rtt_packet_number_higher_than_1rtt);
+      if (level == ENCRYPTION_ZERO_RTT &&
+          current_key_phase_first_received_packet_number_.IsInitialized() &&
+          header.packet_number >
+              current_key_phase_first_received_packet_number_) {
+        set_detailed_error(absl::StrCat(
+            "Decrypted a 0-RTT packet with a packet number ",
+            header.packet_number.ToString(),
+            " which is higher than a 1-RTT packet number ",
+            current_key_phase_first_received_packet_number_.ToString()));
+        return RaiseError(QUIC_INVALID_0RTT_PACKET_NUMBER_OUT_OF_ORDER);
+      }
+    }
     *decrypted_level = level;
     potential_peer_key_update_attempt_count_ = 0;
     if (attempt_key_update) {