Detect and close connections on invalid acks received by the QUIC dispatcher.

With this change, the QUIC dispatcher will track the largest packet number among the packets it has sent. When processing incoming packets, the `TlsChloExtractor` will check if any ACK frames in the received packet, acknowledge packet numbers that were not sent by the dispatcher. If such an "invalid ack" is found, the connection is considered to be in violation of the protocol and is statelessly closed by adding the connection ID to the time-wait list.

Protected by quic_restart_flag_quic_dispatcher_close_connection_on_invalid_ack.

PiperOrigin-RevId: 826482880
diff --git a/quiche/common/quiche_feature_flags_list.h b/quiche/common/quiche_feature_flags_list.h
index 172ea24..bb6f9e5 100755
--- a/quiche/common/quiche_feature_flags_list.h
+++ b/quiche/common/quiche_feature_flags_list.h
@@ -61,6 +61,7 @@
 QUICHE_FLAG(bool, quiche_reloadable_flag_quic_use_inlining_send_buffer2, true, true, "Uses an inlining version of QuicSendStreamBuffer.")
 QUICHE_FLAG(bool, quiche_reloadable_flag_quic_use_proof_source_get_cert_chains, false, false, "When true, quic::TlsServerHandshaker will use ProofSource::GetCertChains() instead of ProofSource::GetCertChain()")
 QUICHE_FLAG(bool, quiche_reloadable_flag_quic_use_received_client_addresses_cache, true, true, "If true, use a LRU cache to record client addresses of packets received on server's original address.")
+QUICHE_FLAG(bool, quiche_restart_flag_quic_dispatcher_close_connection_on_invalid_ack, false, false, "An invalid ack is an ack that the peer sent for a packet that was not sent by the dispatcher. If true, the dispatcher will close the connection if it receives an invalid ack.")
 QUICHE_FLAG(bool, quiche_restart_flag_quic_support_release_time_for_gso, false, false, "If true, QuicGsoBatchWriter will support release time if it is available and the process has the permission to do so.")
 QUICHE_FLAG(bool, quiche_restart_flag_quic_testonly_default_false, false, false, "A testonly restart flag that will always default to false.")
 QUICHE_FLAG(bool, quiche_restart_flag_quic_testonly_default_true, true, true, "A testonly restart flag that will always default to true.")
diff --git a/quiche/quic/core/http/end_to_end_test.cc b/quiche/quic/core/http/end_to_end_test.cc
index 3b0be31..451632a 100644
--- a/quiche/quic/core/http/end_to_end_test.cc
+++ b/quiche/quic/core/http/end_to_end_test.cc
@@ -1570,6 +1570,63 @@
   });
 }
 
+TEST_P(EndToEndTest, TestInvalidAckBeforeHandshakeClosesConnection) {
+  if (!version_.UsesTls()) {
+    ASSERT_TRUE(Initialize());
+    return;
+  }
+  if (!version_.HasIetfQuicFrames()) {
+    ASSERT_TRUE(Initialize());
+    return;
+  }
+  SetQuicRestartFlag(quic_dispatcher_close_connection_on_invalid_ack, true);
+  connect_to_server_on_initialize_ = false;
+  ASSERT_TRUE(Initialize());
+
+  // Create client without connecting.
+  client_writer_->set_fake_packet_loss_percentage(100);
+  client_.reset(CreateQuicClient(client_writer_));
+  client_->client()->Initialize();
+
+  QuicConnection* client_connection = GetClientConnection();
+  ASSERT_TRUE(client_connection);
+  client_writer_->Initialize(
+      QuicConnectionPeer::GetHelper(client_connection),
+      QuicConnectionPeer::GetAlarmFactory(client_connection),
+      std::make_unique<ClientDelegate>(client_->client()));
+
+  // Generate connection IDs for the crafted packet. Since the client hasn't
+  // connected yet, these are effectively new, random IDs.
+  QuicConnectionId server_connection_id = TestConnectionId(1);
+  QuicConnectionId client_connection_id = TestConnectionId(2);
+
+  // Manually craft and send an INITIAL packet with an invalid ACK frame.
+  QuicFrames frames;
+  frames.push_back(QuicFrame(QuicPingFrame()));
+  // This packet contains an invalid ack frame acking packet number 1 which has
+  // not been sent by the dispatcher yet.
+  frames.push_back(QuicFrame(new QuicAckFrame(InitAckFrame(1))));
+  frames.push_back(QuicFrame(QuicPaddingFrame(1200)));
+  std::unique_ptr<QuicEncryptedPacket> packet = MakeLongHeaderPacket(
+      version_, server_connection_id, frames, INITIAL, ENCRYPTION_INITIAL);
+  DeleteFrames(&frames);
+  ASSERT_TRUE(packet);
+
+  client_writer_->writer()->WritePacket(
+      packet->data(), packet->length(),
+      client_->client()->network_helper()->GetLatestClientAddress().host(),
+      server_address_, nullptr, packet_writer_params_);
+
+  // The server should see this as an invalid ACK and add the connection ID to a
+  // time-wait list. Subsequent connection attempts with the same connection ID
+  // should fail.
+  client_->UseConnectionId(server_connection_id);
+  client_->Connect();
+  EXPECT_FALSE(client_->connected());
+  EXPECT_THAT(client_->connection_error(),
+              IsError(IETF_QUIC_PROTOCOL_VIOLATION));
+}
+
 TEST_P(EndToEndTest, SendAndReceiveCoalescedPackets) {
   ASSERT_TRUE(Initialize());
   if (!version_.CanSendCoalescedPackets()) {
diff --git a/quiche/quic/core/quic_buffered_packet_store.cc b/quiche/quic/core/quic_buffered_packet_store.cc
index f70acb0..c9b946c 100644
--- a/quiche/quic/core/quic_buffered_packet_store.cc
+++ b/quiche/quic/core/quic_buffered_packet_store.cc
@@ -585,7 +585,7 @@
     std::vector<uint16_t>* out_cert_compression_algos,
     std::vector<std::string>* out_alpns, std::string* out_sni,
     bool* out_resumption_attempted, bool* out_early_data_attempted,
-    std::optional<uint8_t>* tls_alert) {
+    std::optional<uint8_t>* tls_alert, bool* out_invalid_ack) {
   QUICHE_DCHECK_NE(out_alpns, nullptr);
   QUICHE_DCHECK_NE(out_sni, nullptr);
   QUICHE_DCHECK_NE(tls_alert, nullptr);
@@ -598,7 +598,16 @@
     return false;
   }
   BufferedPacketListNode& node = *it->second;
-  node.tls_chlo_extractor.IngestPacket(version, packet);
+  if (GetQuicRestartFlag(quic_dispatcher_close_connection_on_invalid_ack)) {
+    QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_close_connection_on_invalid_ack,
+                              7, 7);
+    node.tls_chlo_extractor.IngestPacket(version, packet,
+                                         node.GetLastSentPacketNumber());
+    *out_invalid_ack = node.tls_chlo_extractor.has_invalid_ack();
+  } else {
+    node.tls_chlo_extractor.IngestPacket(version, packet);
+  }
+
   if (!node.tls_chlo_extractor.HasParsedFullChlo()) {
     *tls_alert = node.tls_chlo_extractor.tls_alert();
     return false;
diff --git a/quiche/quic/core/quic_buffered_packet_store.h b/quiche/quic/core/quic_buffered_packet_store.h
index 37bd432..28b9355 100644
--- a/quiche/quic/core/quic_buffered_packet_store.h
+++ b/quiche/quic/core/quic_buffered_packet_store.h
@@ -216,8 +216,11 @@
   // populated with the SNI tag in CHLO. |out_resumption_attempted| is populated
   // if the CHLO has the 'pre_shared_key' TLS extension.
   // |out_early_data_attempted| is populated if the CHLO has the 'early_data'
-  // TLS extension. When this returns false, and an unrecoverable error happened
-  // due to a TLS alert, |*tls_alert| will be set to the alert value.
+  // TLS extension. When this returns false, either an unrecoverable error
+  // happened due to a TLS alert, |*tls_alert| will be set to the alert value or
+  // an invalid ack is received that will cause a connection close,
+  // |*out_invalid_ack| will be set to true. An invalid ack is an ack that the
+  // peer sent for a packet that was not sent by the dispatcher.
   bool IngestPacketForTlsChloExtraction(
       const QuicConnectionId& connection_id, const ParsedQuicVersion& version,
       const QuicReceivedPacket& packet,
@@ -225,7 +228,7 @@
       std::vector<uint16_t>* out_cert_compression_algos,
       std::vector<std::string>* out_alpns, std::string* out_sni,
       bool* out_resumption_attempted, bool* out_early_data_attempted,
-      std::optional<uint8_t>* tls_alert);
+      std::optional<uint8_t>* tls_alert, bool* out_invalid_ack);
 
   // Returns the list of buffered packets for |connection_id| and removes them
   // from the store. Returns an empty list if no early arrived packets for this
diff --git a/quiche/quic/core/quic_buffered_packet_store_test.cc b/quiche/quic/core/quic_buffered_packet_store_test.cc
index f18c5b5..7552694 100644
--- a/quiche/quic/core/quic_buffered_packet_store_test.cc
+++ b/quiche/quic/core/quic_buffered_packet_store_test.cc
@@ -125,16 +125,15 @@
     EXPECT_CALL(mock_packet_writer_, IsWriteBlocked())
         .WillRepeatedly(Return(false));
     EXPECT_CALL(mock_packet_writer_, WritePacket(_, _, _, _, _, _))
-        .WillRepeatedly(
-            [&](const char* buffer, size_t buf_len, const QuicIpAddress&,
-                const QuicSocketAddress&, PerPacketOptions*,
-                const QuicPacketWriterParams&) {
-              // This packet is sent by the store and "received" by the client.
-              client_received_packets_.push_back(
-                  std::make_unique<ClientReceivedPacket>(
-                      buffer, buf_len, peer_address_, self_address_));
-              return WriteResult(WRITE_STATUS_OK, buf_len);
-            });
+        .WillRepeatedly([&](const char* buffer, size_t buf_len,
+                            const QuicIpAddress&, const QuicSocketAddress&,
+                            PerPacketOptions*, const QuicPacketWriterParams&) {
+          // This packet is sent by the store and "received" by the client.
+          client_received_packets_.push_back(
+              std::make_unique<ClientReceivedPacket>(
+                  buffer, buf_len, peer_address_, self_address_));
+          return WriteResult(WRITE_STATUS_OK, buf_len);
+        });
   }
 
  protected:
@@ -707,6 +706,7 @@
   bool early_data_attempted = false;
   QuicConfig config;
   std::optional<uint8_t> tls_alert;
+  bool has_invalid_ack = false;
 
   EXPECT_FALSE(store_.HasBufferedPackets(connection_id));
   EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_Q043_PACKET,
@@ -719,8 +719,8 @@
   EXPECT_FALSE(store_.IngestPacketForTlsChloExtraction(
       connection_id, valid_version_, packet_, &supported_groups,
       &cert_compression_algos, &alpns, &sni, &resumption_attempted,
-      &early_data_attempted, &tls_alert));
-
+      &early_data_attempted, &tls_alert, &has_invalid_ack));
+  EXPECT_FALSE(has_invalid_ack);
   store_.DiscardPackets(connection_id);
 
   // Force the TLS CHLO to span multiple packets.
@@ -745,11 +745,13 @@
   EXPECT_FALSE(store_.IngestPacketForTlsChloExtraction(
       connection_id, valid_version_, *packets[0], &supported_groups,
       &cert_compression_algos, &alpns, &sni, &resumption_attempted,
-      &early_data_attempted, &tls_alert));
+      &early_data_attempted, &tls_alert, &has_invalid_ack));
+  EXPECT_FALSE(has_invalid_ack);
   EXPECT_TRUE(store_.IngestPacketForTlsChloExtraction(
       connection_id, valid_version_, *packets[1], &supported_groups,
       &cert_compression_algos, &alpns, &sni, &resumption_attempted,
-      &early_data_attempted, &tls_alert));
+      &early_data_attempted, &tls_alert, &has_invalid_ack));
+  EXPECT_FALSE(has_invalid_ack);
 
   EXPECT_THAT(alpns, ElementsAre(AlpnForVersion(valid_version_)));
   EXPECT_FALSE(supported_groups.empty());
@@ -759,6 +761,73 @@
   EXPECT_FALSE(early_data_attempted);
 }
 
+TEST_F(QuicBufferedPacketStoreTest, MultiPacketChloInvalidAckDetected) {
+  QuicConnectionId connection_id = TestConnectionId();
+  std::vector<std::string> alpns;
+  std::vector<uint16_t> supported_groups;
+  std::vector<uint16_t> cert_compression_algos;
+  std::string sni;
+  bool resumption_attempted = false;
+  bool early_data_attempted = false;
+  QuicConfig config;
+  std::optional<uint8_t> tls_alert;
+  bool has_invalid_ack = false;
+
+  // Force the TLS CHLO to span multiple packets.
+  constexpr auto kCustomParameterId =
+      static_cast<TransportParameters::TransportParameterId>(0xff33);
+  std::string kCustomParameterValue(2000, '-');
+  config.custom_transport_parameters_to_send()[kCustomParameterId] =
+      kCustomParameterValue;
+  auto packets = GetFirstFlightOfPackets(valid_version_, config);
+  ASSERT_EQ(packets.size(), 2u);
+
+  // Create a packet with an invalid ack frame acking packet number 2, but the
+  // largest packet number sent is 1.
+  QuicFrames frames;
+  frames.push_back(QuicFrame(QuicPingFrame()));
+  frames.push_back(QuicFrame(QuicPaddingFrame(1200)));
+  // This ack frame is invalid because the largest packet number sent is 1 and
+  // the ack frame is acking packet number 1 and 2.
+  frames.push_back(
+      QuicFrame(new QuicAckFrame(InitAckFrame(QuicPacketNumber(2)))));
+
+  std::unique_ptr<QuicEncryptedPacket> encrypted_packet = MakeLongHeaderPacket(
+      valid_version_, connection_id, frames, INITIAL, ENCRYPTION_INITIAL);
+  DeleteFrames(&frames);
+
+  std::unique_ptr<QuicReceivedPacket> received_packet_with_invalid_ack =
+      std::unique_ptr<QuicReceivedPacket>(
+          ConstructReceivedPacket(*encrypted_packet, packet_time_));
+
+  // Enqueue the first packet in the CHLO.
+  EnqueuePacketToStore(store_, connection_id, IETF_QUIC_LONG_HEADER_PACKET,
+                       INITIAL, *packets[0], self_address_, peer_address_,
+                       valid_version_, kNoParsedChlo, connection_id_generator_);
+
+  EXPECT_TRUE(store_.HasBufferedPackets(connection_id));
+
+  // Ingest the first packet in the CHLO.
+  EXPECT_FALSE(store_.IngestPacketForTlsChloExtraction(
+      connection_id, valid_version_, *packets[0], &supported_groups,
+      &cert_compression_algos, &alpns, &sni, &resumption_attempted,
+      &early_data_attempted, &tls_alert, &has_invalid_ack));
+  EXPECT_FALSE(has_invalid_ack);
+
+  // Ingest the second packet with invalid ack frame in the CHLO.
+  EXPECT_FALSE(store_.IngestPacketForTlsChloExtraction(
+      connection_id, valid_version_, *received_packet_with_invalid_ack,
+      &supported_groups, &cert_compression_algos, &alpns, &sni,
+      &resumption_attempted, &early_data_attempted, &tls_alert,
+      &has_invalid_ack));
+
+  if (GetQuicRestartFlag(quic_dispatcher_close_connection_on_invalid_ack)) {
+    EXPECT_TRUE(has_invalid_ack);
+  } else {
+    EXPECT_FALSE(has_invalid_ack);
+  }
+}
+
 TEST_F(QuicBufferedPacketStoreTest, DeliverInitialPacketsFirst) {
   QuicConfig config;
   QuicConnectionId connection_id = TestConnectionId(1);
diff --git a/quiche/quic/core/quic_dispatcher.cc b/quiche/quic/core/quic_dispatcher.cc
index 3ba39c0..f9c31cb 100644
--- a/quiche/quic/core/quic_dispatcher.cc
+++ b/quiche/quic/core/quic_dispatcher.cc
@@ -565,17 +565,25 @@
   QuicErrorCode connection_close_error_code =
       QUIC_HANDSHAKE_FAILED_INVALID_CONNECTION;
 
-  // If a fatal TLS alert was received when extracting Client Hello,
-  // |tls_alert_error_detail| will be set and will be used as the error_details
-  // of the connection close.
-  std::string tls_alert_error_detail;
+  // If a fatal TLS alert was received when extracting Client Hello or a packet
+  // with an invalid ack was received, |error_detail| will be set and will
+  // be used as the error_details of the connection close.
+  std::string error_detail;
 
   if (fate == kFateProcess) {
     ExtractChloResult extract_chlo_result =
         TryExtractChloOrBufferEarlyPacket(*packet_info);
     auto& parsed_chlo = extract_chlo_result.parsed_chlo;
 
-    if (extract_chlo_result.tls_alert.has_value()) {
+    if (GetQuicRestartFlag(quic_dispatcher_close_connection_on_invalid_ack) &&
+        extract_chlo_result.has_invalid_ack) {
+      QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_close_connection_on_invalid_ack,
+                                6, 7);
+      // Invalid ack when parsing received packet.
+      fate = kFateTimeWait;
+      connection_close_error_code = IETF_QUIC_PROTOCOL_VIOLATION;
+      error_detail = "Received packet with invalid ack.";
+    } else if (extract_chlo_result.tls_alert.has_value()) {
       QUIC_BUG_IF(quic_dispatcher_parsed_chlo_and_tls_alert_coexist_1,
                   parsed_chlo.has_value())
           << "parsed_chlo and tls_alert should not be set at the same time.";
@@ -584,11 +592,10 @@
       uint8_t tls_alert = *extract_chlo_result.tls_alert;
       connection_close_error_code = TlsAlertToQuicErrorCode(tls_alert).value_or(
           connection_close_error_code);
-      tls_alert_error_detail =
-          absl::StrCat("TLS handshake failure from dispatcher (",
-                       EncryptionLevelToString(ENCRYPTION_INITIAL), ") ",
-                       static_cast<int>(tls_alert), ": ",
-                       SSL_alert_desc_string_long(tls_alert));
+      error_detail = absl::StrCat("TLS handshake failure from dispatcher (",
+                                  EncryptionLevelToString(ENCRYPTION_INITIAL),
+                                  ") ", static_cast<int>(tls_alert), ": ",
+                                  SSL_alert_desc_string_long(tls_alert));
     } else if (!parsed_chlo.has_value()) {
       // Client Hello incomplete. Packet has been buffered or (rarely) dropped.
       return;
@@ -615,8 +622,7 @@
                       << " to time-wait list.";
       QUIC_CODE_COUNT(quic_reject_fate_time_wait);
       const std::string& connection_close_error_detail =
-          tls_alert_error_detail.empty() ? "Reject connection"
-                                         : tls_alert_error_detail;
+          error_detail.empty() ? "Reject connection" : error_detail;
       StatelesslyTerminateConnection(
           packet_info->self_address, packet_info->peer_address,
           server_connection_id, packet_info->form, packet_info->version_flag,
@@ -657,13 +663,23 @@
           packet_info.destination_connection_id, packet_info.version,
           packet_info.packet, &supported_groups, &cert_compression_algos,
           &alpns, &sni, &resumption_attempted, &early_data_attempted,
-          &result.tls_alert);
+          &result.tls_alert, &result.has_invalid_ack);
     } else {
       // If we do not have a BufferedPacketList for this connection ID,
       // create a single-use one to check whether this packet contains a
       // full single-packet CHLO.
       TlsChloExtractor tls_chlo_extractor;
-      tls_chlo_extractor.IngestPacket(packet_info.version, packet_info.packet);
+      if (GetQuicRestartFlag(quic_dispatcher_close_connection_on_invalid_ack)) {
+        QUIC_RESTART_FLAG_COUNT_N(
+            quic_dispatcher_close_connection_on_invalid_ack, 1, 7);
+        tls_chlo_extractor.IngestPacket(packet_info.version, packet_info.packet,
+                                        QuicPacketNumber());
+        result.has_invalid_ack = tls_chlo_extractor.has_invalid_ack();
+      } else {
+        tls_chlo_extractor.IngestPacket(packet_info.version,
+                                        packet_info.packet);
+      }
+
       if (tls_chlo_extractor.HasParsedFullChlo()) {
         // This packet contains a full single-packet CHLO.
         has_full_tls_chlo = true;
@@ -678,6 +694,13 @@
       }
     }
 
+    if (GetQuicRestartFlag(quic_dispatcher_close_connection_on_invalid_ack) &&
+        result.has_invalid_ack) {
+      QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_close_connection_on_invalid_ack,
+                                2, 7);
+      return result;
+    }
+
     if (result.tls_alert.has_value()) {
       QUIC_BUG_IF(quic_dispatcher_parsed_chlo_and_tls_alert_coexist_2,
                   has_full_tls_chlo)
diff --git a/quiche/quic/core/quic_dispatcher.h b/quiche/quic/core/quic_dispatcher.h
index dce6f4f..b4ce3d6 100644
--- a/quiche/quic/core/quic_dispatcher.h
+++ b/quiche/quic/core/quic_dispatcher.h
@@ -370,6 +370,8 @@
     // If set, the TLS alert that will cause a connection close.
     // Always empty for Google QUIC.
     std::optional<uint8_t> tls_alert;
+    // If set, an invalid ack will cause a connection close.
+    bool has_invalid_ack = false;
   };
 
   // Try to extract information(sni, alpns, ...) if the full Client Hello has
diff --git a/quiche/quic/core/quic_dispatcher_test.cc b/quiche/quic/core/quic_dispatcher_test.cc
index 05d7453..2b3cf40 100644
--- a/quiche/quic/core/quic_dispatcher_test.cc
+++ b/quiche/quic/core/quic_dispatcher_test.cc
@@ -303,10 +303,17 @@
     return reinterpret_cast<MockQuicConnection*>(session2_->connection());
   }
 
-  QuicFrames CreatePaddedPingPacketFrames() {
+  QuicFrames CreatePaddedPingPacketFrames(int padding_length) {
     QuicFrames frames;
     frames.push_back(QuicFrame(QuicPingFrame()));
-    frames.push_back(QuicFrame(QuicPaddingFrame(100)));
+    frames.push_back(QuicFrame(QuicPaddingFrame(padding_length)));
+    return frames;
+  }
+
+  QuicFrames CreatePaddedPingAckPacketFrames(
+      int padding_length, std::vector<QuicAckBlock> ack_blocks) {
+    QuicFrames frames = CreatePaddedPingPacketFrames(padding_length);
+    frames.push_back(QuicFrame(new QuicAckFrame(InitAckFrame(ack_blocks))));
     return frames;
   }
 
@@ -458,7 +465,7 @@
       const QuicConnectionId& server_connection_id) {
     ProcessUndecryptableEarlyPacket(version_, peer_address,
                                     server_connection_id,
-                                    CreatePaddedPingPacketFrames());
+                                    CreatePaddedPingPacketFrames(100));
   }
 
   void ProcessUndecryptableEarlyPacket(
@@ -2550,7 +2557,7 @@
       const QuicConnectionId& server_connection_id) {
     QuicDispatcherTestBase::ProcessUndecryptableEarlyPacket(
         version, peer_address, server_connection_id,
-        CreatePaddedPingPacketFrames());
+        CreatePaddedPingPacketFrames(100));
   }
 
   void ProcessUndecryptableEarlyPacket(
@@ -2566,6 +2573,34 @@
                                     server_connection_id);
   }
 
+  std::unique_ptr<QuicReceivedPacket> ConstructEncryptedPacketWithAckFrames(
+      const QuicConnectionId& server_connection_id,
+      QuicLongHeaderType long_header_type, EncryptionLevel encryption_level,
+      const std::vector<QuicAckBlock>& ack_blocks) {
+    QuicFrames frames = CreatePaddedPingAckPacketFrames(
+        /*padding_length=*/1200, ack_blocks);
+    std::unique_ptr<QuicEncryptedPacket> encrypted_packet =
+        MakeLongHeaderPacket(version_, server_connection_id, frames,
+                             long_header_type, encryption_level);
+    DeleteFrames(&frames);
+    return std::unique_ptr<QuicReceivedPacket>(ConstructReceivedPacket(
+        *encrypted_packet, mock_helper_.GetClock()->Now()));
+  }
+
+  std::vector<std::unique_ptr<QuicReceivedPacket>> CreateMultiPacketChlo(
+      const QuicConnectionId conn_id) {
+    QuicConfig client_config = DefaultQuicConfig();
+    // Add a 2000-byte custom parameter to increase the length of the CHLO.
+    constexpr auto kCustomParameterId =
+        static_cast<TransportParameters::TransportParameterId>(0xff33);
+    std::string kCustomParameterValue(2000, '-');
+    client_config.custom_transport_parameters_to_send()[kCustomParameterId] =
+        kCustomParameterValue;
+    return GetFirstFlightOfPackets(version_, client_config, conn_id,
+                                   EmptyQuicConnectionId(),
+                                   TestClientCryptoConfig());
+  }
+
  protected:
   QuicSocketAddress client_addr_;
 };
@@ -2607,6 +2642,182 @@
   ProcessFirstFlight(conn_id);
 }
 
+TEST_P(BufferedPacketStoreTest,
+       ProcessNonChloPacketWithInvalidAckBeforeChloClosesConnection) {
+  if (!version_.UsesTls()) {
+    return;
+  }
+  CreateTimeWaitListManager();
+  QuicConnectionId conn_id = TestConnectionId(1);
+
+  // This packet contains an invalid ack frame acking packet number 1 which has
+  // not been sent by the dispatcher yet.
+  std::vector<QuicAckBlock> ack_blocks = {
+      {QuicPacketNumber(1), QuicPacketNumber(2)}};
+
+  std::unique_ptr<QuicReceivedPacket> received_packet_with_invalid_ack =
+      ConstructEncryptedPacketWithAckFrames(conn_id, INITIAL,
+                                            ENCRYPTION_INITIAL, ack_blocks);
+
+  if (GetQuicRestartFlag(quic_dispatcher_close_connection_on_invalid_ack)) {
+    //  As this packet contains an invalid ack, the dispatcher should add
+    //  the connection to time-wait list.
+    EXPECT_CALL(*time_wait_list_manager_,
+                ProcessPacket(_, _, conn_id, _, _, _));
+    ProcessReceivedPacket(std::move(received_packet_with_invalid_ack),
+                          client_addr_, version_, conn_id);
+    EXPECT_TRUE(time_wait_list_manager_->IsConnectionIdInTimeWait(conn_id));
+  } else {
+    ProcessReceivedPacket(std::move(received_packet_with_invalid_ack),
+                          client_addr_, version_, conn_id);
+    EXPECT_FALSE(time_wait_list_manager_->IsConnectionIdInTimeWait(conn_id));
+  }
+  EXPECT_EQ(dispatcher_->NumSessions(), 0u);
+}
+
+TEST_P(BufferedPacketStoreTest, ProcessMultiPacketChloWithValidAckFrame) {
+  if (!version_.UsesTls()) {
+    return;
+  }
+  QuicConnectionId conn_id = TestConnectionId(1);
+  CreateTimeWaitListManager();
+
+  std::vector<std::unique_ptr<QuicReceivedPacket>> packets =
+      CreateMultiPacketChlo(conn_id);
+  ASSERT_EQ(packets.size(), 2u);
+
+  // This packet contains a valid ack frame acking packet number 1 which has
+  // been sent by the dispatcher in response to the first part of packets[0].
+  std::vector<QuicAckBlock> ack_blocks = {
+      {QuicPacketNumber(1), QuicPacketNumber(2)}};
+
+  // This packet contains a valid ack frame.
+  std::unique_ptr<QuicReceivedPacket> received_packet_with_valid_ack_frame =
+      ConstructEncryptedPacketWithAckFrames(conn_id, INITIAL,
+                                            ENCRYPTION_INITIAL, ack_blocks);
+
+  // Processing the first packet should not create a new session.
+  ProcessReceivedPacket(std::move(packets[0]), client_addr_, version_, conn_id);
+
+  EXPECT_EQ(dispatcher_->NumSessions(), 0u)
+      << "No session should be created before the rest of the CHLO arrives.";
+
+  // This is a valid ack frame as it acks the first packet sent by the
+  // dispatcher in response to the first part of packets[0].
+  ProcessReceivedPacket(std::move(received_packet_with_valid_ack_frame),
+                        client_addr_, version_, conn_id);
+  EXPECT_FALSE(time_wait_list_manager_->IsConnectionIdInTimeWait(conn_id));
+
+  EXPECT_EQ(dispatcher_->NumSessions(), 0u)
+      << "No session should be created before the rest of the CHLO arrives.";
+
+  // Processing the second packet should create the new session.
+  EXPECT_CALL(*dispatcher_,
+              CreateQuicSession(conn_id, _, client_addr_, Eq(ExpectedAlpn()), _,
+                                MatchParsedClientHello(), _))
+      .WillOnce(Return(ByMove(CreateSession(
+          dispatcher_.get(), config_, conn_id, client_addr_, &mock_helper_,
+          &mock_alarm_factory_, &crypto_config_,
+          QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_))));
+  EXPECT_CALL(*reinterpret_cast<MockQuicConnection*>(session1_->connection()),
+              ProcessUdpPacket(_, _, _))
+      .Times(3);
+
+  ProcessReceivedPacket(std::move(packets[1]), client_addr_, version_, conn_id);
+  EXPECT_EQ(dispatcher_->NumSessions(), 1u);
+}
+
+TEST_P(BufferedPacketStoreTest,
+       ProcessNonChloPacketAckingPacketNumberZeroClosesConnection) {
+  if (!version_.UsesTls()) {
+    return;
+  }
+  QuicConnectionId conn_id = TestConnectionId(1);
+  CreateTimeWaitListManager();
+
+  std::vector<std::unique_ptr<QuicReceivedPacket>> packets =
+      CreateMultiPacketChlo(conn_id);
+  ASSERT_EQ(packets.size(), 2u);
+
+  // This packet contains an invalid ack frame acking packet number 0 and 1 but
+  // packet number 0 is not a valid packet number.
+  std::vector<QuicAckBlock> ack_blocks = {
+      {QuicPacketNumber(0), QuicPacketNumber(2)}};
+
+  // This packet contains an invalid ack frame with invalid ack range acking
+  // packet number 0.
+  std::unique_ptr<QuicReceivedPacket> received_packet_with_invalid_ack =
+      ConstructEncryptedPacketWithAckFrames(conn_id, INITIAL,
+                                            ENCRYPTION_INITIAL, ack_blocks);
+
+  // Processing the first packet should not create a new session.
+  ProcessReceivedPacket(std::move(packets[0]), client_addr_, version_, conn_id);
+
+  EXPECT_EQ(dispatcher_->NumSessions(), 0u)
+      << "No session should be created before the rest of the CHLO arrives.";
+
+  if (GetQuicRestartFlag(quic_dispatcher_close_connection_on_invalid_ack)) {
+    // Processing the packet with invalid ack should not create a new
+    // session and the connection should be added to time-wait list. The
+    // dispatcher should close the connection.
+    EXPECT_CALL(*time_wait_list_manager_,
+                ProcessPacket(_, _, conn_id, _, _, _));
+    ProcessReceivedPacket(std::move(received_packet_with_invalid_ack),
+                          client_addr_, version_, conn_id);
+    EXPECT_TRUE(time_wait_list_manager_->IsConnectionIdInTimeWait(conn_id));
+  } else {
+    ProcessReceivedPacket(std::move(received_packet_with_invalid_ack),
+                          client_addr_, version_, conn_id);
+    EXPECT_FALSE(time_wait_list_manager_->IsConnectionIdInTimeWait(conn_id));
+  }
+  EXPECT_EQ(dispatcher_->NumSessions(), 0u)
+      << "No session should be created before the rest of the CHLO arrives.";
+}
+
+TEST_P(BufferedPacketStoreTest,
+       ProcessMultiPacketChloWithInvalidAckClosesConnection) {
+  if (!version_.UsesTls()) {
+    return;
+  }
+  CreateTimeWaitListManager();
+  QuicConnectionId conn_id = TestConnectionId(1);
+
+  std::vector<std::unique_ptr<QuicReceivedPacket>> packets =
+      CreateMultiPacketChlo(conn_id);
+  ASSERT_EQ(packets.size(), 2u);
+
+  // This packet contains an invalid ack frame acking packet number 1 and 2 but
+  // packet number 2 has not been sent by the dispatcher yet.
+  std::vector<QuicAckBlock> ack_blocks = {
+      {QuicPacketNumber(1), QuicPacketNumber(3)}};
+
+  std::unique_ptr<QuicReceivedPacket> received_packet_with_invalid_ack =
+      ConstructEncryptedPacketWithAckFrames(conn_id, INITIAL,
+                                            ENCRYPTION_INITIAL, ack_blocks);
+
+  // Processing the first packet should not create a new session.
+  ProcessReceivedPacket(std::move(packets[0]), client_addr_, version_, conn_id);
+  EXPECT_EQ(dispatcher_->NumSessions(), 0u)
+      << "No session should be created before the rest of the CHLO arrives.";
+
+  if (GetQuicRestartFlag(quic_dispatcher_close_connection_on_invalid_ack)) {
+    // Processing the packet with invalid ack should not create a new
+    // session and the connection should be added to time-wait list. The
+    // dispatcher should close the connection.
+    EXPECT_CALL(*time_wait_list_manager_,
+                ProcessPacket(_, _, conn_id, _, _, _));
+    ProcessReceivedPacket(std::move(received_packet_with_invalid_ack),
+                          client_addr_, version_, conn_id);
+    EXPECT_TRUE(time_wait_list_manager_->IsConnectionIdInTimeWait(conn_id));
+  } else {
+    ProcessReceivedPacket(std::move(received_packet_with_invalid_ack),
+                          client_addr_, version_, conn_id);
+    EXPECT_FALSE(time_wait_list_manager_->IsConnectionIdInTimeWait(conn_id));
+  }
+  EXPECT_EQ(dispatcher_->NumSessions(), 0u)
+      << "No session should be created before the rest of the CHLO arrives.";
+}
+
 TEST_P(BufferedPacketStoreTest, ProcessNonChloPacketsUptoLimitAndProcessChlo) {
   InSequence s;
   QuicConnectionId conn_id = TestConnectionId(1);
@@ -3177,7 +3388,7 @@
   // Process non-CHLO packet. This ProcessUndecryptableEarlyPacket() but with
   // an injected step to set the ECN bits.
   std::unique_ptr<QuicEncryptedPacket> encrypted_packet =
-      MakeLongHeaderPacket(version_, conn_id, CreatePaddedPingPacketFrames(),
+      MakeLongHeaderPacket(version_, conn_id, CreatePaddedPingPacketFrames(100),
                            ZERO_RTT_PROTECTED, ENCRYPTION_ZERO_RTT);
   std::unique_ptr<QuicReceivedPacket> received_packet(ConstructReceivedPacket(
       *encrypted_packet, mock_helper_.GetClock()->Now(), ECN_ECT1));
diff --git a/quiche/quic/core/tls_chlo_extractor.cc b/quiche/quic/core/tls_chlo_extractor.cc
index bf946c7..59ff100 100644
--- a/quiche/quic/core/tls_chlo_extractor.cc
+++ b/quiche/quic/core/tls_chlo_extractor.cc
@@ -160,8 +160,9 @@
   return *this;
 }
 
-void TlsChloExtractor::IngestPacket(const ParsedQuicVersion& version,
-                                    const QuicReceivedPacket& packet) {
+void TlsChloExtractor::IngestPacket(
+    const ParsedQuicVersion& version, const QuicReceivedPacket& packet,
+    QuicPacketNumber dispatcher_largest_packet_number_sent) {
   if (state_ == State::kUnrecoverableFailure) {
     QUIC_DLOG(ERROR) << "Not ingesting packet after unrecoverable error";
     return;
@@ -192,6 +193,14 @@
     framer_->set_visitor(this);
   }
 
+  if (GetQuicRestartFlag(quic_dispatcher_close_connection_on_invalid_ack) &&
+      dispatcher_largest_packet_number_sent.IsInitialized()) {
+    QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_close_connection_on_invalid_ack,
+                              3, 7);
+    dispatcher_largest_packet_number_sent_ =
+        dispatcher_largest_packet_number_sent;
+  }
+
   // When the framer parses |packet|, if it sees a CRYPTO frame it will call
   // OnCryptoFrame below and that will set parsed_crypto_frame_in_this_packet_
   // to true.
@@ -236,6 +245,18 @@
   return true;
 }
 
+void TlsChloExtractor::OnError(QuicFramer* framer) {
+  // This is called when the framer encounters an error while parsing an ACK
+  // frame, including the case where the packet number 0 is acknowledged. Since
+  // QUICHE never sends packet number 0, this is invalid..
+  if (GetQuicRestartFlag(quic_dispatcher_close_connection_on_invalid_ack) &&
+      framer->error() == QUIC_INVALID_ACK_DATA) {
+    QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_close_connection_on_invalid_ack,
+                              4, 7);
+    has_invalid_ack_ = true;
+  }
+}
+
 // This is called by the framer if it detects a change in version during
 // parsing.
 bool TlsChloExtractor::OnProtocolVersionMismatch(ParsedQuicVersion version) {
@@ -278,6 +299,24 @@
   return true;
 }
 
+bool TlsChloExtractor::OnAckFrameStart(QuicPacketNumber largest_acked,
+                                       QuicTime::Delta /*ack_delay_time*/) {
+  if (!GetQuicRestartFlag(quic_dispatcher_close_connection_on_invalid_ack)) {
+    return true;
+  }
+  if (!dispatcher_largest_packet_number_sent_.IsInitialized() ||
+      largest_acked > dispatcher_largest_packet_number_sent_) {
+    QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_close_connection_on_invalid_ack,
+                              5, 7);
+    // The dispatcher sends packets sequentially from 1 to
+    // dispatcher_largest_packet_number_sent_ (inclusive), without skipping any
+    // packet numbers in between.
+    has_invalid_ack_ = true;
+    return false;
+  }
+  return true;
+}
+
 // Called by the QuicStreamSequencer when it receives a CRYPTO frame that
 // advances the amount of contiguous data we now have starting from offset 0.
 void TlsChloExtractor::OnDataAvailable() {
diff --git a/quiche/quic/core/tls_chlo_extractor.h b/quiche/quic/core/tls_chlo_extractor.h
index aeef76e..5de68de 100644
--- a/quiche/quic/core/tls_chlo_extractor.h
+++ b/quiche/quic/core/tls_chlo_extractor.h
@@ -7,6 +7,7 @@
 
 #include <cstdint>
 #include <memory>
+#include <optional>
 #include <string>
 #include <vector>
 
@@ -69,7 +70,18 @@
 
   // Ingests |packet| and attempts to parse out the CHLO.
   void IngestPacket(const ParsedQuicVersion& version,
-                    const QuicReceivedPacket& packet);
+                    const QuicReceivedPacket& packet) {
+    IngestPacket(version, packet, QuicPacketNumber());
+  }
+
+  // Ingests |packet| and attempts to parse out the CHLO.
+  // |dispatcher_largest_packet_number_sent| is the largest packet number the
+  // dispatcher has sent. It is used to validate ACKs in the client's initial
+  // packet. If the client acks a packet that has an invalid ack, then
+  // has_invalid_ack_ will be set to true.
+  void IngestPacket(const ParsedQuicVersion& version,
+                    const QuicReceivedPacket& packet,
+                    QuicPacketNumber dispatcher_largest_packet_number_sent);
 
   // Returns whether the ingested packets have allowed parsing a complete CHLO.
   bool HasParsedFullChlo() const {
@@ -85,7 +97,7 @@
   }
 
   // Methods from QuicFramerVisitorInterface.
-  void OnError(QuicFramer* /*framer*/) override {}
+  void OnError(QuicFramer* framer) override;
   bool OnProtocolVersionMismatch(ParsedQuicVersion version) override;
   void OnPacket() override {}
   void OnVersionNegotiationPacket(
@@ -110,10 +122,8 @@
                              bool /*has_decryption_key*/) override {}
   bool OnStreamFrame(const QuicStreamFrame& /*frame*/) override { return true; }
   bool OnCryptoFrame(const QuicCryptoFrame& frame) override;
-  bool OnAckFrameStart(QuicPacketNumber /*largest_acked*/,
-                       QuicTime::Delta /*ack_delay_time*/) override {
-    return true;
-  }
+  bool OnAckFrameStart(QuicPacketNumber largest_acked,
+                       QuicTime::Delta /*ack_delay_time*/) override;
   bool OnAckRange(QuicPacketNumber /*start*/,
                   QuicPacketNumber /*end*/) override {
     return true;
@@ -219,6 +229,7 @@
                             const std::string& details) override;
   QuicStreamId id() const override { return 0; }
   ParsedQuicVersion version() const override { return framer_->version(); }
+  bool has_invalid_ack() const { return has_invalid_ack_; }
 
  private:
   // Parses the length of the CHLO message by looking at the first four bytes.
@@ -293,6 +304,9 @@
   std::vector<uint8_t> transport_params_;
   // Exact TLS message bytes.
   std::vector<uint8_t> client_hello_bytes_;
+  QuicPacketNumber dispatcher_largest_packet_number_sent_;
+  // Whether the packet has an invalid ack.
+  bool has_invalid_ack_ = false;
 };
 
 // Convenience method to facilitate logging TlsChloExtractor::State.
diff --git a/quiche/quic/test_tools/quic_test_utils.cc b/quiche/quic/test_tools/quic_test_utils.cc
index d47c05e..246cdab 100644
--- a/quiche/quic/test_tools/quic_test_utils.cc
+++ b/quiche/quic/test_tools/quic_test_utils.cc
@@ -96,7 +96,7 @@
   QUICHE_DCHECK_GT(ack_blocks.size(), 0u);
 
   QuicAckFrame ack;
-  QuicPacketNumber end_of_previous_block(1);
+  QuicPacketNumber end_of_previous_block(0);
   for (const QuicAckBlock& block : ack_blocks) {
     QUICHE_DCHECK_GE(block.start, end_of_previous_block);
     QUICHE_DCHECK_GT(block.limit, block.start);
@@ -989,6 +989,7 @@
   header.packet_number_length = PACKET_4BYTE_PACKET_NUMBER;
   header.packet_number = QuicPacketNumber(33);
   header.long_packet_type = long_header_type;
+  header.form = IETF_QUIC_LONG_HEADER_PACKET;
   if (version.HasLongHeaderLengths()) {
     header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
     header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
@@ -998,7 +999,7 @@
                     kQuicDefaultConnectionIdLength);
   framer.SetInitialObfuscators(server_connection_id);
 
-  if (long_header_type != INITIAL) {
+  if (encryption_level != ENCRYPTION_INITIAL) {
     framer.SetEncrypter(encryption_level,
                         std::make_unique<TaggingEncrypter>(encryption_level));
   }