Add ECN counters to ACKs sent by the QuicBufferedPacketStore. Also add a test that checks for the existence of the ACK, with the ECN count. Protected by quic_reloadable_flag_quic_ecn_in_first_ack. PiperOrigin-RevId: 694503577
diff --git a/quiche/common/quiche_feature_flags_list.h b/quiche/common/quiche_feature_flags_list.h index 48bd038..b4f0c7b 100755 --- a/quiche/common/quiche_feature_flags_list.h +++ b/quiche/common/quiche_feature_flags_list.h
@@ -31,6 +31,7 @@ QUICHE_FLAG(bool, quiche_reloadable_flag_quic_disable_version_q046, false, true, "If true, disable QUIC version Q046.") QUICHE_FLAG(bool, quiche_reloadable_flag_quic_disable_version_rfcv1, false, false, "If true, disable QUIC version h3 (RFCv1).") QUICHE_FLAG(bool, quiche_reloadable_flag_quic_discard_initial_packet_with_key_dropped, false, true, "If true, discard INITIAL packet if the key has been dropped.") +QUICHE_FLAG(bool, quiche_reloadable_flag_quic_ecn_in_first_ack, false, false, "When true, reports ECN in counts in the ACK of the a client initial that goes in the buffered packet store.") QUICHE_FLAG(bool, quiche_reloadable_flag_quic_enable_disable_resumption, true, true, "If true, disable resumption when receiving NRES connection option.") QUICHE_FLAG(bool, quiche_reloadable_flag_quic_enable_mtu_discovery_at_server, false, false, "If true, QUIC will default enable MTU discovery at server, with a target of 1450 bytes.") QUICHE_FLAG(bool, quiche_reloadable_flag_quic_enable_server_on_wire_ping, true, true, "If true, enable server retransmittable on wire PING.")
diff --git a/quiche/quic/core/quic_buffered_packet_store.cc b/quiche/quic/core/quic_buffered_packet_store.cc index 0392ecf..4421b66 100644 --- a/quiche/quic/core/quic_buffered_packet_store.cc +++ b/quiche/quic/core/quic_buffered_packet_store.cc
@@ -69,6 +69,19 @@ QuicBufferedPacketStore* connection_store_; }; +std::optional<QuicEcnCounts> SinglePacketEcnCount( + QuicEcnCodepoint ecn_codepoint) { + switch (ecn_codepoint) { + case ECN_CE: + return QuicEcnCounts(0, 0, 1); + case ECN_ECT0: + return QuicEcnCounts(1, 0, 0); + case ECN_ECT1: + return QuicEcnCounts(0, 1, 0); + default: + return std::nullopt; + } +} } // namespace BufferedPacket::BufferedPacket(std::unique_ptr<QuicReceivedPacket> packet, @@ -313,7 +326,11 @@ initial_ack_frame.packets.Add(sent_packet.received_packet_number); } initial_ack_frame.largest_acked = initial_ack_frame.packets.Max(); - + if (GetQuicReloadableFlag(quic_ecn_in_first_ack)) { + QUIC_RELOADABLE_FLAG_COUNT(quic_ecn_in_first_ack); + initial_ack_frame.ecn_counters = + SinglePacketEcnCount(packet_info.packet.ecn_codepoint()); + } if (!creator.AddFrame(QuicFrame(&initial_ack_frame), NOT_RETRANSMISSION)) { QUIC_BUG(quic_dispatcher_add_ack_frame_failed) << "Unable to add ack frame to an empty packet while acking packet "
diff --git a/quiche/quic/core/quic_buffered_packet_store_test.cc b/quiche/quic/core/quic_buffered_packet_store_test.cc index 107acaa..60495dc 100644 --- a/quiche/quic/core/quic_buffered_packet_store_test.cc +++ b/quiche/quic/core/quic_buffered_packet_store_test.cc
@@ -6,6 +6,7 @@ #include <cstddef> #include <cstdint> +#include <cstring> #include <list> #include <memory> #include <optional> @@ -16,12 +17,16 @@ #include "absl/strings/string_view.h" #include "quiche/quic/core/connection_id_generator.h" #include "quiche/quic/core/crypto/transport_parameters.h" +#include "quiche/quic/core/frames/quic_frame.h" +#include "quiche/quic/core/frames/quic_padding_frame.h" #include "quiche/quic/core/quic_connection_id.h" #include "quiche/quic/core/quic_constants.h" #include "quiche/quic/core/quic_dispatcher.h" #include "quiche/quic/core/quic_dispatcher_stats.h" #include "quiche/quic/core/quic_error_codes.h" #include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packet_number.h" +#include "quiche/quic/core/quic_packet_writer.h" #include "quiche/quic/core/quic_packets.h" #include "quiche/quic/core/quic_time.h" #include "quiche/quic/core/quic_types.h" @@ -34,6 +39,7 @@ #include "quiche/quic/test_tools/mock_connection_id_generator.h" #include "quiche/quic/test_tools/quic_buffered_packet_store_peer.h" #include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/quiche_endian.h" namespace quic { static const size_t kDefaultMaxConnectionsInStore = 100; @@ -50,11 +56,14 @@ using BufferedPacket = QuicBufferedPacketStore::BufferedPacket; using BufferedPacketList = QuicBufferedPacketStore::BufferedPacketList; using EnqueuePacketResult = QuicBufferedPacketStore::EnqueuePacketResult; +using ::testing::_; using ::testing::A; using ::testing::Conditional; using ::testing::Each; using ::testing::ElementsAre; +using ::testing::Invoke; using ::testing::Ne; +using ::testing::Return; using ::testing::SizeIs; using ::testing::Truly; @@ -149,6 +158,83 @@ EXPECT_FALSE(store_.HasBufferedPackets(connection_id)); } +TEST_F(QuicBufferedPacketStoreTest, SimpleEnqueueAckSent) { + SetQuicReloadableFlag(quic_ecn_in_first_ack, true); + QuicConnectionId connection_id = TestConnectionId(1); + MockPacketWriter writer; + store_.set_writer(&writer); + // Build a decryptable Initial packet with PADDING. + QuicFramer client_framer(ParsedQuicVersionVector{ParsedQuicVersion::RFCv1()}, + QuicTime::Zero(), Perspective::IS_CLIENT, 8); + client_framer.SetInitialObfuscators(connection_id); + QuicPacketHeader header; + header.destination_connection_id = connection_id; + header.version_flag = true; + header.packet_number = QuicPacketNumber(1); + header.packet_number_length = PACKET_1BYTE_PACKET_NUMBER; + header.long_packet_type = INITIAL; + header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2; + header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1; + QuicFrames frames = {QuicFrame(QuicPaddingFrame(1200))}; + + char* buffer = new char[1500]; + EncryptionLevel level = HeaderToEncryptionLevel(header); + size_t length = + client_framer.BuildDataPacket(header, frames, buffer, 1500, level); + + ASSERT_GT(length, 0); + + // Re-construct the data packet with data ownership. + auto data = std::make_unique<QuicPacket>( + buffer, length, /* owns_buffer */ true, + GetIncludedDestinationConnectionIdLength(header), + GetIncludedSourceConnectionIdLength(header), header.version_flag, + header.nonce != nullptr, header.packet_number_length, + header.retry_token_length_length, header.retry_token.length(), + header.length_length); + unsigned char raw[1500] = {}; + size_t final_size = client_framer.EncryptPayload( + ENCRYPTION_INITIAL, header.packet_number, *data, (char*)raw, 1500); + QuicReceivedPacket packet((char*)raw, final_size, QuicTime::Zero(), false, 0, + true, nullptr, 0, false, ECN_ECT1); + + EXPECT_CALL(writer, IsWriteBlocked()).WillOnce(Return(false)); + std::unique_ptr<QuicEncryptedPacket> ack_packet; + EXPECT_CALL(writer, WritePacket(_, _, _, _, _, _)) + .WillOnce(Invoke([&](const char* buffer, size_t buf_len, + const QuicIpAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/, + PerPacketOptions* /*options*/, + const QuicPacketWriterParams& /*params*/) { + auto tmp_packet = + std::make_unique<QuicEncryptedPacket>(buffer, buf_len); + ack_packet = tmp_packet->Clone(); + return WriteResult(WRITE_STATUS_OK, 1); + })); + EXPECT_CALL(writer, Flush()); + EnqueuePacketToStore(store_, connection_id, IETF_QUIC_LONG_HEADER_PACKET, + INITIAL, packet, self_address_, peer_address_, + ParsedQuicVersion::RFCv1(), kNoParsedChlo, + connection_id_generator_); + const BufferedPacketList* buffered_list = store_.GetPacketList(connection_id); + ASSERT_NE(buffered_list, nullptr); + ASSERT_EQ(buffered_list->dispatcher_sent_packets.size(), 1); + EXPECT_EQ(buffered_list->dispatcher_sent_packets[0].largest_acked, + QuicPacketNumber(1)); + + // Decrypt the packet, and verify it reports ECN. + MockFramerVisitor mock_framer_visitor; + client_framer.set_visitor(&mock_framer_visitor); + EXPECT_CALL(mock_framer_visitor, OnPacket()).Times(1); + EXPECT_CALL(mock_framer_visitor, OnAckFrameStart(_, _)) + .WillOnce(Return(true)); + EXPECT_CALL(mock_framer_visitor, OnAckRange(_, _)).WillOnce(Return(true)); + std::optional<QuicEcnCounts> counts = QuicEcnCounts(0, 1, 0); + EXPECT_CALL(mock_framer_visitor, OnAckFrameEnd(_, counts)) + .WillOnce(Return(true)); + client_framer.ProcessPacket(*ack_packet); +} + TEST_F(QuicBufferedPacketStoreTest, DifferentPacketAddressOnOneConnection) { QuicSocketAddress addr_with_new_port(QuicIpAddress::Any4(), 256); QuicConnectionId connection_id = TestConnectionId(1);