Improve ECN coverage in QuicConnectionTest

PiperOrigin-RevId: 501664181
diff --git a/quiche/quic/core/quic_connection_test.cc b/quiche/quic/core/quic_connection_test.cc
index a3a4144..98027d4 100644
--- a/quiche/quic/core/quic_connection_test.cc
+++ b/quiche/quic/core/quic_connection_test.cc
@@ -6,6 +6,7 @@
 
 #include <errno.h>
 
+#include <cstdint>
 #include <memory>
 #include <ostream>
 #include <string>
@@ -876,13 +877,28 @@
 
   size_t ProcessFramePacketAtLevel(uint64_t number, QuicFrame frame,
                                    EncryptionLevel level) {
-    QuicFrames frames;
-    frames.push_back(frame);
-    return ProcessFramesPacketAtLevel(number, frames, level);
+    return ProcessFramePacketAtLevelWithEcn(number, frame, level, ECN_NOT_ECT);
   }
 
-  size_t ProcessFramesPacketAtLevel(uint64_t number, const QuicFrames& frames,
+  size_t ProcessFramePacketAtLevelWithEcn(uint64_t number, QuicFrame frame,
+                                          EncryptionLevel level,
+                                          QuicEcnCodepoint ecn_codepoint) {
+    QuicFrames frames;
+    frames.push_back(frame);
+    return ProcessFramesPacketAtLevelWithEcn(number, frames, level,
+                                             ecn_codepoint);
+  }
+
+  size_t ProcessFramesPacketAtLevel(uint64_t number, QuicFrames frames,
                                     EncryptionLevel level) {
+    return ProcessFramesPacketAtLevelWithEcn(number, frames, level,
+                                             ECN_NOT_ECT);
+  }
+
+  size_t ProcessFramesPacketAtLevelWithEcn(uint64_t number,
+                                           const QuicFrames& frames,
+                                           EncryptionLevel level,
+                                           QuicEcnCodepoint ecn_codepoint) {
     QuicPacketHeader header = ConstructPacketHeader(number, level);
     // Set the correct encryption level and encrypter on peer_creator and
     // peer_framer, respectively.
@@ -907,7 +923,8 @@
                                     buffer, kMaxOutgoingPacketSize);
     connection_.ProcessUdpPacket(
         kSelfAddress, kPeerAddress,
-        QuicReceivedPacket(buffer, encrypted_length, clock_.Now(), false));
+        QuicReceivedPacket(buffer, encrypted_length, clock_.Now(), false, 0,
+                           true, nullptr, 0, false, ecn_codepoint));
     if (connection_.GetSendAlarm()->IsSet()) {
       connection_.GetSendAlarm()->Fire();
     }
@@ -924,6 +941,11 @@
   };
 
   size_t ProcessCoalescedPacket(std::vector<PacketInfo> packets) {
+    return ProcessCoalescedPacket(packets, ECN_NOT_ECT);
+  }
+
+  size_t ProcessCoalescedPacket(std::vector<PacketInfo> packets,
+                                QuicEcnCodepoint ecn_codepoint) {
     char coalesced_buffer[kMaxOutgoingPacketSize];
     size_t coalesced_size = 0;
     bool contains_initial = false;
@@ -970,7 +992,7 @@
     connection_.ProcessUdpPacket(
         kSelfAddress, kPeerAddress,
         QuicReceivedPacket(coalesced_buffer, coalesced_size, clock_.Now(),
-                           false));
+                           false, 0, true, nullptr, 0, false, ecn_codepoint));
     if (connection_.GetSendAlarm()->IsSet()) {
       connection_.GetSendAlarm()->Fire();
     }
@@ -16499,59 +16521,193 @@
 
 TEST_P(QuicConnectionTest, EcnMarksCorrectlyRecorded) {
   set_perspective(Perspective::IS_SERVER);
-  QuicPacketHeader header = ConstructPacketHeader(1, ENCRYPTION_FORWARD_SECURE);
   QuicFrames frames;
-  QuicPingFrame ping_frame;
-  QuicPaddingFrame padding_frame;
-  frames.push_back(QuicFrame(ping_frame));
-  frames.push_back(QuicFrame(padding_frame));
-  std::unique_ptr<QuicPacket> packet =
-      BuildUnsizedDataPacket(&peer_framer_, header, frames);
-  char buffer[kMaxOutgoingPacketSize];
-  size_t encrypted_length = peer_framer_.EncryptPayload(
-      ENCRYPTION_FORWARD_SECURE, QuicPacketNumber(1), *packet, buffer,
-      kMaxOutgoingPacketSize);
-  QuicReceivedPacket received_packet(buffer, encrypted_length, clock_.Now(),
-                                     false, 0, true, nullptr, 0, false,
-                                     ECN_ECT0);
-  if (connection_.SupportsMultiplePacketNumberSpaces()) {
-    EXPECT_FALSE(connection_.received_packet_manager()
-                     .GetAckFrame(APPLICATION_DATA)
-                     .ecn_counters.has_value());
-    connection_.ProcessUdpPacket(kSelfAddress, kPeerAddress, received_packet);
-    if (GetQuicRestartFlag(quic_receive_ecn)) {
-      EXPECT_TRUE(connection_.received_packet_manager()
-                      .GetAckFrame(APPLICATION_DATA)
-                      .ecn_counters.has_value());
-      EXPECT_EQ(connection_.received_packet_manager()
-                    .GetAckFrame(APPLICATION_DATA)
-                    .ecn_counters->ect0,
-                1);
-    } else {
-      EXPECT_FALSE(connection_.received_packet_manager()
-                       .GetAckFrame(APPLICATION_DATA)
-                       .ecn_counters.has_value());
-    }
+  frames.push_back(QuicFrame(QuicPingFrame()));
+  frames.push_back(QuicFrame(QuicPaddingFrame(7)));
+  QuicAckFrame ack_frame =
+      connection_.SupportsMultiplePacketNumberSpaces()
+          ? connection_.received_packet_manager().GetAckFrame(APPLICATION_DATA)
+          : connection_.received_packet_manager().ack_frame();
+  EXPECT_FALSE(ack_frame.ecn_counters.has_value());
+
+  ProcessFramesPacketAtLevelWithEcn(1, frames, ENCRYPTION_FORWARD_SECURE,
+                                    ECN_ECT0);
+  ack_frame =
+      connection_.SupportsMultiplePacketNumberSpaces()
+          ? connection_.received_packet_manager().GetAckFrame(APPLICATION_DATA)
+          : connection_.received_packet_manager().ack_frame();
+  if (GetQuicRestartFlag(quic_receive_ecn)) {
+    ASSERT_TRUE(ack_frame.ecn_counters.has_value());
+    EXPECT_EQ(ack_frame.ecn_counters->ect0, 1);
   } else {
-    EXPECT_FALSE(connection_.received_packet_manager()
-                     .ack_frame()
-                     .ecn_counters.has_value());
-    connection_.ProcessUdpPacket(kSelfAddress, kPeerAddress, received_packet);
-    if (GetQuicRestartFlag(quic_receive_ecn)) {
-      EXPECT_TRUE(connection_.received_packet_manager()
-                      .ack_frame()
-                      .ecn_counters.has_value());
-      EXPECT_EQ(
-          connection_.received_packet_manager().ack_frame().ecn_counters->ect0,
-          1);
-    } else {
-      EXPECT_FALSE(connection_.received_packet_manager()
-                       .ack_frame()
-                       .ecn_counters.has_value());
-    }
+    EXPECT_FALSE(ack_frame.ecn_counters.has_value());
   }
 }
 
+TEST_P(QuicConnectionTest, EcnMarksCoalescedPacket) {
+  if (!connection_.version().CanSendCoalescedPackets() ||
+      !GetQuicRestartFlag(quic_receive_ecn)) {
+    return;
+  }
+  QuicCryptoFrame crypto_frame1{ENCRYPTION_HANDSHAKE, 0, "foo"};
+  QuicFrames frames1;
+  frames1.push_back(QuicFrame(&crypto_frame1));
+  QuicFrames frames2;
+  QuicCryptoFrame crypto_frame2{ENCRYPTION_FORWARD_SECURE, 0, "bar"};
+  frames2.push_back(QuicFrame(&crypto_frame2));
+  std::vector<PacketInfo> packets = {{2, frames1, ENCRYPTION_HANDSHAKE},
+                                     {3, frames2, ENCRYPTION_FORWARD_SECURE}};
+  QuicAckFrame ack_frame =
+      connection_.SupportsMultiplePacketNumberSpaces()
+          ? connection_.received_packet_manager().GetAckFrame(APPLICATION_DATA)
+          : connection_.received_packet_manager().ack_frame();
+  EXPECT_FALSE(ack_frame.ecn_counters.has_value());
+  ack_frame =
+      connection_.SupportsMultiplePacketNumberSpaces()
+          ? connection_.received_packet_manager().GetAckFrame(HANDSHAKE_DATA)
+          : connection_.received_packet_manager().ack_frame();
+  EXPECT_FALSE(ack_frame.ecn_counters.has_value());
+  // Deliver packets.
+  connection_.SetEncrypter(
+      ENCRYPTION_HANDSHAKE,
+      std::make_unique<TaggingEncrypter>(ENCRYPTION_HANDSHAKE));
+  EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(2);
+  ProcessCoalescedPacket(packets, ECN_ECT0);
+  ack_frame =
+      connection_.SupportsMultiplePacketNumberSpaces()
+          ? connection_.received_packet_manager().GetAckFrame(HANDSHAKE_DATA)
+          : connection_.received_packet_manager().ack_frame();
+  ASSERT_TRUE(ack_frame.ecn_counters.has_value());
+  EXPECT_EQ(ack_frame.ecn_counters->ect0,
+            connection_.SupportsMultiplePacketNumberSpaces() ? 1 : 2);
+  if (connection_.SupportsMultiplePacketNumberSpaces()) {
+    ack_frame = connection_.SupportsMultiplePacketNumberSpaces()
+                    ? connection_.received_packet_manager().GetAckFrame(
+                          APPLICATION_DATA)
+                    : connection_.received_packet_manager().ack_frame();
+    EXPECT_TRUE(ack_frame.ecn_counters.has_value());
+    EXPECT_EQ(ack_frame.ecn_counters->ect0, 1);
+  }
+}
+
+TEST_P(QuicConnectionTest, EcnMarksUndecryptableCoalescedPacket) {
+  if (!connection_.version().CanSendCoalescedPackets() ||
+      !GetQuicRestartFlag(quic_receive_ecn)) {
+    return;
+  }
+  // SetFromConfig is always called after construction from InitializeSession.
+  EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _));
+  QuicConfig config;
+  config.set_max_undecryptable_packets(100);
+  connection_.SetFromConfig(config);
+  QuicCryptoFrame crypto_frame1{ENCRYPTION_HANDSHAKE, 0, "foo"};
+  QuicFrames frames1;
+  frames1.push_back(QuicFrame(&crypto_frame1));
+  QuicFrames frames2;
+  QuicCryptoFrame crypto_frame2{ENCRYPTION_FORWARD_SECURE, 0, "bar"};
+  frames2.push_back(QuicFrame(&crypto_frame2));
+  std::vector<PacketInfo> packets = {{2, frames1, ENCRYPTION_HANDSHAKE},
+                                     {3, frames2, ENCRYPTION_FORWARD_SECURE}};
+  char coalesced_buffer[kMaxOutgoingPacketSize];
+  size_t coalesced_size = 0;
+  for (const auto& packet : packets) {
+    QuicPacketHeader header =
+        ConstructPacketHeader(packet.packet_number, packet.level);
+    // Set the correct encryption level and encrypter on peer_creator and
+    // peer_framer, respectively.
+    peer_creator_.set_encryption_level(packet.level);
+    peer_framer_.SetEncrypter(packet.level,
+                              std::make_unique<TaggingEncrypter>(packet.level));
+    // Set the corresponding decrypter.
+    if (packet.level == ENCRYPTION_HANDSHAKE) {
+      connection_.SetEncrypter(
+          packet.level, std::make_unique<TaggingEncrypter>(packet.level));
+      connection_.SetDefaultEncryptionLevel(packet.level);
+      SetDecrypter(packet.level,
+                   std::make_unique<StrictTaggingDecrypter>(packet.level));
+    }
+    // Forward Secure packet is undecryptable.
+    std::unique_ptr<QuicPacket> constructed_packet(
+        ConstructPacket(header, packet.frames));
+
+    char buffer[kMaxOutgoingPacketSize];
+    size_t encrypted_length = peer_framer_.EncryptPayload(
+        packet.level, QuicPacketNumber(packet.packet_number),
+        *constructed_packet, buffer, kMaxOutgoingPacketSize);
+    QUICHE_DCHECK_LE(coalesced_size + encrypted_length, kMaxOutgoingPacketSize);
+    memcpy(coalesced_buffer + coalesced_size, buffer, encrypted_length);
+    coalesced_size += encrypted_length;
+  }
+  QuicAckFrame ack_frame =
+      connection_.SupportsMultiplePacketNumberSpaces()
+          ? connection_.received_packet_manager().GetAckFrame(APPLICATION_DATA)
+          : connection_.received_packet_manager().ack_frame();
+  EXPECT_FALSE(ack_frame.ecn_counters.has_value());
+  ack_frame =
+      connection_.SupportsMultiplePacketNumberSpaces()
+          ? connection_.received_packet_manager().GetAckFrame(HANDSHAKE_DATA)
+          : connection_.received_packet_manager().ack_frame();
+  EXPECT_FALSE(ack_frame.ecn_counters.has_value());
+  // Deliver packets, but first remove the Forward Secure decrypter so that
+  // packet has to be buffered.
+  connection_.RemoveDecrypter(ENCRYPTION_FORWARD_SECURE);
+  EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(1);
+  EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1);
+  connection_.ProcessUdpPacket(
+      kSelfAddress, kPeerAddress,
+      QuicReceivedPacket(coalesced_buffer, coalesced_size, clock_.Now(), false,
+                         0, true, nullptr, 0, true, ECN_ECT0));
+  if (connection_.GetSendAlarm()->IsSet()) {
+    connection_.GetSendAlarm()->Fire();
+  }
+  ack_frame =
+      connection_.SupportsMultiplePacketNumberSpaces()
+          ? connection_.received_packet_manager().GetAckFrame(HANDSHAKE_DATA)
+          : connection_.received_packet_manager().ack_frame();
+  ASSERT_TRUE(ack_frame.ecn_counters.has_value());
+  EXPECT_EQ(ack_frame.ecn_counters->ect0, 1);
+  if (connection_.SupportsMultiplePacketNumberSpaces()) {
+    ack_frame = connection_.SupportsMultiplePacketNumberSpaces()
+                    ? connection_.received_packet_manager().GetAckFrame(
+                          APPLICATION_DATA)
+                    : connection_.received_packet_manager().ack_frame();
+    EXPECT_FALSE(ack_frame.ecn_counters.has_value());
+  }
+  // Send PING packet with ECN_CE, which will change the ECN codepoint in
+  // last_received_packet_info_.
+  ProcessFramePacketAtLevelWithEcn(4, QuicFrame(QuicPingFrame()),
+                                   ENCRYPTION_HANDSHAKE, ECN_CE);
+  ack_frame =
+      connection_.SupportsMultiplePacketNumberSpaces()
+          ? connection_.received_packet_manager().GetAckFrame(HANDSHAKE_DATA)
+          : connection_.received_packet_manager().ack_frame();
+  ASSERT_TRUE(ack_frame.ecn_counters.has_value());
+  EXPECT_EQ(ack_frame.ecn_counters->ect0, 1);
+  EXPECT_EQ(ack_frame.ecn_counters->ce, 1);
+  if (connection_.SupportsMultiplePacketNumberSpaces()) {
+    ack_frame = connection_.SupportsMultiplePacketNumberSpaces()
+                    ? connection_.received_packet_manager().GetAckFrame(
+                          APPLICATION_DATA)
+                    : connection_.received_packet_manager().ack_frame();
+    EXPECT_FALSE(ack_frame.ecn_counters.has_value());
+  }
+  // Install decrypter for ENCRYPTION_FORWARD_SECURE. Make sure the original
+  // ECN codepoint is incremented.
+  EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(1);
+  SetDecrypter(
+      ENCRYPTION_FORWARD_SECURE,
+      std::make_unique<StrictTaggingDecrypter>(ENCRYPTION_FORWARD_SECURE));
+  connection_.GetProcessUndecryptablePacketsAlarm()->Fire();
+  ack_frame =
+      connection_.SupportsMultiplePacketNumberSpaces()
+          ? connection_.received_packet_manager().GetAckFrame(APPLICATION_DATA)
+          : connection_.received_packet_manager().ack_frame();
+  ASSERT_TRUE(ack_frame.ecn_counters.has_value());
+  // Should be recorded as ECT(0), not CE.
+  EXPECT_EQ(ack_frame.ecn_counters->ect0,
+            connection_.SupportsMultiplePacketNumberSpaces() ? 1 : 2);
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic