Copy the ECN codepoint in QuicReceivedPacket::Clone().

Cloned incoming packets will not carry a ECN codepoint even when the original packet has one, so the resulting ECN feedback will be incorrect -- this will cause the peer to abandon ECN, potentially impairing performance.

The relevant instance of cloned packets is Buffered Packets at the beginning of the connection. ECN peers will abandon ECN if early packets are buffered.

The only other instance is the QuicProxyConnectionMap. Fixing Clone() is necessary but not sufficient for getting QuicProxyDispatcher to forward ECN properly.

There are tests for both downstream effects of this bug.

Protected by FLAGS_quic_reloadable_flag_quic_clone_ecn.

PiperOrigin-RevId: 595547494
diff --git a/quiche/quic/core/quic_buffered_packet_store_test.cc b/quiche/quic/core/quic_buffered_packet_store_test.cc
index 13f75f8..b241fe1 100644
--- a/quiche/quic/core/quic_buffered_packet_store_test.cc
+++ b/quiche/quic/core/quic_buffered_packet_store_test.cc
@@ -668,6 +668,23 @@
     previous_packet_type = long_packet_type;
   }
 }
+
+// Test for b/316633326.
+TEST_F(QuicBufferedPacketStoreTest, BufferedPacketRetainsEcn) {
+  SetQuicReloadableFlag(quic_clone_ecn, true);
+  QuicConnectionId connection_id = TestConnectionId(1);
+  QuicReceivedPacket ect1_packet(packet_content_.data(), packet_content_.size(),
+                                 packet_time_, false, 0, true, nullptr, 0,
+                                 false, ECN_ECT1);
+  store_.EnqueuePacket(connection_id, false, ect1_packet, self_address_,
+                       peer_address_, valid_version_, kNoParsedChlo, nullptr);
+  BufferedPacketList delivered_packets = store_.DeliverPackets(connection_id);
+  EXPECT_THAT(delivered_packets.buffered_packets, SizeIs(1));
+  for (const auto& packet : delivered_packets.buffered_packets) {
+    EXPECT_EQ(packet.packet->ecn_codepoint(), ECN_ECT1);
+  }
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic
diff --git a/quiche/quic/core/quic_dispatcher_test.cc b/quiche/quic/core/quic_dispatcher_test.cc
index f921147..0e6b848 100644
--- a/quiche/quic/core/quic_dispatcher_test.cc
+++ b/quiche/quic/core/quic_dispatcher_test.cc
@@ -3122,6 +3122,66 @@
   dispatcher_->ProcessBufferedChlos(kMaxNumSessionsToCreate);
 }
 
+TEST_P(BufferedPacketStoreTest, BufferedChloWithEcn) {
+  if (!version_.HasIetfQuicFrames()) {
+    return;
+  }
+  SetQuicReloadableFlag(quic_clone_ecn, true);
+  SetQuicRestartFlag(quic_support_ect1, true);
+  InSequence s;
+  QuicConnectionId conn_id = TestConnectionId(1);
+  // Process non-CHLO packet. This ProcessUndecryptableEarlyPacket() but with
+  // an injected step to set the ECN bits.
+  std::unique_ptr<QuicEncryptedPacket> encrypted_packet =
+      GetUndecryptableEarlyPacket(version_, conn_id);
+  std::unique_ptr<QuicReceivedPacket> received_packet(ConstructReceivedPacket(
+      *encrypted_packet, mock_helper_.GetClock()->Now(), ECN_ECT1));
+  ProcessReceivedPacket(std::move(received_packet), client_addr_, version_,
+                        conn_id);
+  EXPECT_EQ(0u, dispatcher_->NumSessions())
+      << "No session should be created before CHLO arrives.";
+
+  // When CHLO arrives, a new session should be created, and all packets
+  // buffered should be delivered to the session.
+  EXPECT_CALL(connection_id_generator_,
+              MaybeReplaceConnectionId(conn_id, version_))
+      .WillOnce(Return(std::nullopt));
+  EXPECT_CALL(*dispatcher_,
+              CreateQuicSession(conn_id, _, client_addr_, Eq(ExpectedAlpn()), _,
+                                MatchParsedClientHello(), _))
+      .WillOnce(Return(ByMove(CreateSession(
+          dispatcher_.get(), config_, conn_id, client_addr_, &mock_helper_,
+          &mock_alarm_factory_, &crypto_config_,
+          QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_))));
+  bool got_ect1 = false;
+  bool got_ce = false;
+  EXPECT_CALL(*reinterpret_cast<MockQuicConnection*>(session1_->connection()),
+              ProcessUdpPacket(_, _, _))
+      .Times(2)  // non-CHLO + CHLO.
+      .WillRepeatedly(WithArg<2>(Invoke([&](const QuicReceivedPacket& packet) {
+        switch (packet.ecn_codepoint()) {
+          case ECN_ECT1:
+            got_ect1 = true;
+            break;
+          case ECN_CE:
+            got_ce = true;
+            break;
+          default:
+            break;
+        }
+      })));
+  QuicConnectionId client_connection_id = TestConnectionId(2);
+  std::vector<std::unique_ptr<QuicReceivedPacket>> packets =
+      GetFirstFlightOfPackets(version_, DefaultQuicConfig(), conn_id,
+                              client_connection_id, TestClientCryptoConfig(),
+                              ECN_CE);
+  for (auto&& packet : packets) {
+    ProcessReceivedPacket(std::move(packet), client_addr_, version_, conn_id);
+  }
+  EXPECT_TRUE(got_ect1);
+  EXPECT_TRUE(got_ce);
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic
diff --git a/quiche/quic/core/quic_flags_list.h b/quiche/quic/core/quic_flags_list.h
index ea2876d..23508a0 100644
--- a/quiche/quic/core/quic_flags_list.h
+++ b/quiche/quic/core/quic_flags_list.h
@@ -103,6 +103,8 @@
 QUIC_FLAG(quic_reloadable_flag_quic_send_placeholder_ticket_when_encrypt_ticket_fails, true)
 // When true, allows sending of QUIC packets marked ECT(1). A different flag (TBD) will actually utilize this capability to send ECT(1).
 QUIC_FLAG(quic_restart_flag_quic_support_ect1, false)
+// When true, correctly stores the ECN mark on incoming packets when buffered while waiting for a crypto context.
+QUIC_FLAG(quic_reloadable_flag_quic_clone_ecn, false)
 // When true, defaults to BBR congestion control instead of Cubic.
 QUIC_FLAG(quic_reloadable_flag_quic_default_to_bbr, false)
 // When true, report received ECN markings to the peer. Replaces quic_receive_ecn2 to use correct codepoints.
diff --git a/quiche/quic/core/quic_packets.cc b/quiche/quic/core/quic_packets.cc
index 907c7c8..e864251 100644
--- a/quiche/quic/core/quic_packets.cc
+++ b/quiche/quic/core/quic_packets.cc
@@ -5,6 +5,7 @@
 #include "quiche/quic/core/quic_packets.h"
 
 #include <algorithm>
+#include <memory>
 #include <utility>
 
 #include "absl/strings/escaping.h"
@@ -14,7 +15,6 @@
 #include "quiche/quic/core/quic_types.h"
 #include "quiche/quic/core/quic_utils.h"
 #include "quiche/quic/core/quic_versions.h"
-#include "quiche/quic/platform/api/quic_flag_utils.h"
 #include "quiche/quic/platform/api/quic_flags.h"
 
 namespace quic {
@@ -373,13 +373,30 @@
   if (this->packet_headers()) {
     char* headers_buffer = new char[this->headers_length()];
     memcpy(headers_buffer, this->packet_headers(), this->headers_length());
+    if (GetQuicReloadableFlag(quic_clone_ecn)) {
+      QUIC_RELOADABLE_FLAG_COUNT_N(quic_clone_ecn, 1, 2);
+      return std::make_unique<QuicReceivedPacket>(
+          buffer, this->length(), receipt_time(), true, ttl(), ttl() >= 0,
+          headers_buffer, this->headers_length(), true, this->ecn_codepoint());
+    } else {
+      return std::make_unique<QuicReceivedPacket>(
+          buffer, this->length(), receipt_time(), true, ttl(), ttl() >= 0,
+          headers_buffer, this->headers_length(), true);
+    }
     return std::make_unique<QuicReceivedPacket>(
         buffer, this->length(), receipt_time(), true, ttl(), ttl() >= 0,
-        headers_buffer, this->headers_length(), true);
+        headers_buffer, this->headers_length(), true, this->ecn_codepoint());
   }
 
-  return std::make_unique<QuicReceivedPacket>(
-      buffer, this->length(), receipt_time(), true, ttl(), ttl() >= 0);
+  if (GetQuicReloadableFlag(quic_clone_ecn)) {
+    QUIC_RELOADABLE_FLAG_COUNT_N(quic_clone_ecn, 2, 2);
+    return std::make_unique<QuicReceivedPacket>(
+        buffer, this->length(), receipt_time(), true, ttl(), ttl() >= 0,
+        nullptr, 0, false, this->ecn_codepoint());
+  } else {
+    return std::make_unique<QuicReceivedPacket>(
+        buffer, this->length(), receipt_time(), true, ttl(), ttl() >= 0);
+  }
 }
 
 std::ostream& operator<<(std::ostream& os, const QuicReceivedPacket& s) {
diff --git a/quiche/quic/core/quic_packets.h b/quiche/quic/core/quic_packets.h
index 3d70fec..f317840 100644
--- a/quiche/quic/core/quic_packets.h
+++ b/quiche/quic/core/quic_packets.h
@@ -280,6 +280,10 @@
                                                 const QuicEncryptedPacket& s);
 };
 
+namespace test {
+class QuicReceivedPacketPeer;
+}  // namespace test
+
 // A received encrypted QUIC packet, with a recorded time of receipt.
 class QUICHE_EXPORT QuicReceivedPacket : public QuicEncryptedPacket {
  public:
@@ -325,6 +329,8 @@
   QuicEcnCodepoint ecn_codepoint() const { return ecn_codepoint_; }
 
  private:
+  friend class test::QuicReceivedPacketPeer;
+
   const QuicTime receipt_time_;
   int ttl_;
   // Points to the start of packet headers.
diff --git a/quiche/quic/core/quic_packets_test.cc b/quiche/quic/core/quic_packets_test.cc
index 4e6598d..82feda6 100644
--- a/quiche/quic/core/quic_packets_test.cc
+++ b/quiche/quic/core/quic_packets_test.cc
@@ -4,7 +4,12 @@
 
 #include "quiche/quic/core/quic_packets.h"
 
+#include <memory>
+
 #include "absl/memory/memory.h"
+#include "quiche/quic/core/quic_time.h"
+#include "quiche/quic/core/quic_types.h"
+#include "quiche/quic/platform/api/quic_flags.h"
 #include "quiche/quic/platform/api/quic_test.h"
 #include "quiche/quic/test_tools/quic_test_utils.h"
 #include "quiche/common/test_tools/quiche_test_utils.h"
@@ -115,6 +120,16 @@
   EXPECT_EQ(1000u, copy2->encrypted_length);
 }
 
+TEST_F(QuicPacketsTest, CloneReceivedPacket) {
+  SetQuicReloadableFlag(quic_clone_ecn, true);
+  char header[4] = "bar";
+  QuicReceivedPacket packet("foo", 3, QuicTime::Zero(), false, 0, true, header,
+                            sizeof(header) - 1, false,
+                            QuicEcnCodepoint::ECN_ECT1);
+  std::unique_ptr<QuicReceivedPacket> copy = packet.Clone();
+  EXPECT_EQ(packet.ecn_codepoint(), copy->ecn_codepoint());
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic
diff --git a/quiche/quic/test_tools/first_flight.cc b/quiche/quic/test_tools/first_flight.cc
index 3214c56..820dbbb 100644
--- a/quiche/quic/test_tools/first_flight.cc
+++ b/quiche/quic/test_tools/first_flight.cc
@@ -20,6 +20,7 @@
 #include "quiche/quic/platform/api/quic_socket_address.h"
 #include "quiche/quic/test_tools/crypto_test_utils.h"
 #include "quiche/quic/test_tools/mock_connection_id_generator.h"
+#include "quiche/quic/test_tools/quic_connection_peer.h"
 #include "quiche/quic/test_tools/quic_test_utils.h"
 
 namespace quic {
@@ -53,7 +54,7 @@
             std::make_unique<QuicCryptoClientConfig>(
                 crypto_test_utils::ProofVerifierForTesting())) {}
 
-  void GenerateFirstFlight() {
+  void GenerateFirstFlight(QuicEcnCodepoint ecn = ECN_NOT_ECT) {
     crypto_config_->set_alpn(AlpnForVersion(version_));
     connection_ = new QuicConnection(
         server_connection_id_,
@@ -62,6 +63,10 @@
         &alarm_factory_, &writer_,
         /*owns_writer=*/false, Perspective::IS_CLIENT,
         ParsedQuicVersionVector{version_}, connection_id_generator_);
+    if (ecn != ECN_NOT_ECT) {
+      QuicConnectionPeer::DisableEcnCodepointValidation(connection_);
+      connection_->set_ecn_codepoint(ecn);
+    }
     connection_->set_client_connection_id(client_connection_id_);
     session_ = std::make_unique<QuicSpdyClientSession>(
         config_, ParsedQuicVersionVector{version_},
@@ -115,17 +120,28 @@
     const ParsedQuicVersion& version, const QuicConfig& config,
     const QuicConnectionId& server_connection_id,
     const QuicConnectionId& client_connection_id,
-    std::unique_ptr<QuicCryptoClientConfig> crypto_config) {
+    std::unique_ptr<QuicCryptoClientConfig> crypto_config,
+    QuicEcnCodepoint ecn) {
   FirstFlightExtractor first_flight_extractor(
       version, config, server_connection_id, client_connection_id,
       std::move(crypto_config));
-  first_flight_extractor.GenerateFirstFlight();
+  first_flight_extractor.GenerateFirstFlight(ecn);
   return first_flight_extractor.ConsumePackets();
 }
 
 std::vector<std::unique_ptr<QuicReceivedPacket>> GetFirstFlightOfPackets(
     const ParsedQuicVersion& version, const QuicConfig& config,
     const QuicConnectionId& server_connection_id,
+    const QuicConnectionId& client_connection_id,
+    std::unique_ptr<QuicCryptoClientConfig> crypto_config) {
+  return GetFirstFlightOfPackets(version, config, server_connection_id,
+                                 client_connection_id, std::move(crypto_config),
+                                 ECN_NOT_ECT);
+}
+
+std::vector<std::unique_ptr<QuicReceivedPacket>> GetFirstFlightOfPackets(
+    const ParsedQuicVersion& version, const QuicConfig& config,
+    const QuicConnectionId& server_connection_id,
     const QuicConnectionId& client_connection_id) {
   FirstFlightExtractor first_flight_extractor(
       version, config, server_connection_id, client_connection_id);
diff --git a/quiche/quic/test_tools/first_flight.h b/quiche/quic/test_tools/first_flight.h
index 7360cf3..2fe86c5 100644
--- a/quiche/quic/test_tools/first_flight.h
+++ b/quiche/quic/test_tools/first_flight.h
@@ -53,7 +53,7 @@
   }
   bool SupportsReleaseTime() const override { return false; }
   bool IsBatchMode() const override { return false; }
-  bool SupportsEcn() const override { return false; }
+  bool SupportsEcn() const override { return true; }
   QuicPacketBuffer GetNextWriteLocation(
       const QuicIpAddress& /*self_address*/,
       const QuicSocketAddress& /*peer_address*/) override {
@@ -82,7 +82,8 @@
     const ParsedQuicVersion& version, const QuicConfig& config,
     const QuicConnectionId& server_connection_id,
     const QuicConnectionId& client_connection_id,
-    std::unique_ptr<QuicCryptoClientConfig> crypto_config);
+    std::unique_ptr<QuicCryptoClientConfig> crypto_config,
+    QuicEcnCodepoint ecn);
 
 // Below are various convenience overloads that use default values for the
 // omitted parameters:
@@ -91,6 +92,12 @@
 // |client_connection_id| = EmptyQuicConnectionId().
 // |crypto_config| =
 //     QuicCryptoClientConfig(crypto_test_utils::ProofVerifierForTesting())
+// |ecn| = ECN_NOT_ECT
+std::vector<std::unique_ptr<QuicReceivedPacket>> GetFirstFlightOfPackets(
+    const ParsedQuicVersion& version, const QuicConfig& config,
+    const QuicConnectionId& server_connection_id,
+    const QuicConnectionId& client_connection_id,
+    std::unique_ptr<QuicCryptoClientConfig> crypto_config);
 std::vector<std::unique_ptr<QuicReceivedPacket>> GetFirstFlightOfPackets(
     const ParsedQuicVersion& version, const QuicConfig& config,
     const QuicConnectionId& server_connection_id,
diff --git a/quiche/quic/test_tools/quic_test_utils.cc b/quiche/quic/test_tools/quic_test_utils.cc
index 6298cda..8ffb8e8 100644
--- a/quiche/quic/test_tools/quic_test_utils.cc
+++ b/quiche/quic/test_tools/quic_test_utils.cc
@@ -27,6 +27,8 @@
 #include "quiche/quic/core/quic_data_writer.h"
 #include "quiche/quic/core/quic_framer.h"
 #include "quiche/quic/core/quic_packet_creator.h"
+#include "quiche/quic/core/quic_packets.h"
+#include "quiche/quic/core/quic_time.h"
 #include "quiche/quic/core/quic_types.h"
 #include "quiche/quic/core/quic_utils.h"
 #include "quiche/quic/core/quic_versions.h"
@@ -1009,10 +1011,16 @@
 
 QuicReceivedPacket* ConstructReceivedPacket(
     const QuicEncryptedPacket& encrypted_packet, QuicTime receipt_time) {
+  return ConstructReceivedPacket(encrypted_packet, receipt_time, ECN_NOT_ECT);
+}
+
+QuicReceivedPacket* ConstructReceivedPacket(
+    const QuicEncryptedPacket& encrypted_packet, QuicTime receipt_time,
+    QuicEcnCodepoint ecn) {
   char* buffer = new char[encrypted_packet.length()];
   memcpy(buffer, encrypted_packet.data(), encrypted_packet.length());
   return new QuicReceivedPacket(buffer, encrypted_packet.length(), receipt_time,
-                                true);
+                                true, 0, true, nullptr, 0, false, ecn);
 }
 
 QuicEncryptedPacket* ConstructMisFramedEncryptedPacket(
diff --git a/quiche/quic/test_tools/quic_test_utils.h b/quiche/quic/test_tools/quic_test_utils.h
index 52b4c12..8fdaf98 100644
--- a/quiche/quic/test_tools/quic_test_utils.h
+++ b/quiche/quic/test_tools/quic_test_utils.h
@@ -33,6 +33,7 @@
 #include "quiche/quic/core/quic_path_validator.h"
 #include "quiche/quic/core/quic_sent_packet_manager.h"
 #include "quiche/quic/core/quic_server_id.h"
+#include "quiche/quic/core/quic_time.h"
 #include "quiche/quic/core/quic_types.h"
 #include "quiche/quic/core/quic_utils.h"
 #include "quiche/quic/platform/api/quic_socket_address.h"
@@ -163,6 +164,9 @@
 // of the returned pointer.
 QuicReceivedPacket* ConstructReceivedPacket(
     const QuicEncryptedPacket& encrypted_packet, QuicTime receipt_time);
+QuicReceivedPacket* ConstructReceivedPacket(
+    const QuicEncryptedPacket& encrypted_packet, QuicTime receipt_time,
+    QuicEcnCodepoint ecn);
 
 // Create an encrypted packet for testing whose data portion erroneous.
 // The specific way the data portion is erroneous is not specified, but