Parameterize ChloExtractorTest by QUIC version This CL also adds ALPN extraction to this test. gfe-relnote: n/a, test-only PiperOrigin-RevId: 307801505 Change-Id: I9c8b7be8edba6dc531746d6bb35fc20559afe28f
diff --git a/quic/core/chlo_extractor_test.cc b/quic/core/chlo_extractor_test.cc index e293902..1b6a616 100644 --- a/quic/core/chlo_extractor_test.cc +++ b/quic/core/chlo_extractor_test.cc
@@ -12,6 +12,7 @@ #include "net/third_party/quiche/src/quic/core/quic_utils.h" #include "net/third_party/quiche/src/quic/platform/api/quic_test.h" #include "net/third_party/quiche/src/quic/test_tools/crypto_test_utils.h" +#include "net/third_party/quiche/src/quic/test_tools/first_flight.h" #include "net/third_party/quiche/src/quic/test_tools/quic_test_utils.h" #include "net/third_party/quiche/src/common/platform/api/quiche_arraysize.h" #include "net/third_party/quiche/src/common/platform/api/quiche_string_piece.h" @@ -32,50 +33,54 @@ version_ = version; connection_id_ = connection_id; chlo_ = chlo.DebugString(); + quiche::QuicheStringPiece alpn_value; + if (chlo.GetStringPiece(kALPN, &alpn_value)) { + alpn_ = std::string(alpn_value); + } } QuicConnectionId connection_id() const { return connection_id_; } QuicTransportVersion transport_version() const { return version_; } const std::string& chlo() const { return chlo_; } + const std::string& alpn() const { return alpn_; } private: QuicConnectionId connection_id_; QuicTransportVersion version_; std::string chlo_; + std::string alpn_; }; -class ChloExtractorTest : public QuicTest { +class ChloExtractorTest : public QuicTestWithParam<ParsedQuicVersion> { public: - ChloExtractorTest() { - header_.destination_connection_id = TestConnectionId(); - header_.destination_connection_id_included = CONNECTION_ID_PRESENT; - header_.version_flag = true; - header_.version = AllSupportedVersions().front(); - header_.reset_flag = false; - header_.packet_number_length = PACKET_4BYTE_PACKET_NUMBER; - header_.packet_number = QuicPacketNumber(1); - if (QuicVersionHasLongHeaderLengths(header_.version.transport_version)) { - header_.retry_token_length_length = VARIABLE_LENGTH_INTEGER_LENGTH_1; - header_.length_length = VARIABLE_LENGTH_INTEGER_LENGTH_2; - } - } + ChloExtractorTest() : version_(GetParam()) {} - void MakePacket(ParsedQuicVersion version, - quiche::QuicheStringPiece data, + void MakePacket(quiche::QuicheStringPiece data, bool munge_offset, bool munge_stream_id) { + QuicPacketHeader header; + header.destination_connection_id = TestConnectionId(); + header.destination_connection_id_included = CONNECTION_ID_PRESENT; + header.version_flag = true; + header.version = version_; + header.reset_flag = false; + header.packet_number_length = PACKET_4BYTE_PACKET_NUMBER; + header.packet_number = QuicPacketNumber(1); + if (version_.HasLongHeaderLengths()) { + header.retry_token_length_length = VARIABLE_LENGTH_INTEGER_LENGTH_1; + header.length_length = VARIABLE_LENGTH_INTEGER_LENGTH_2; + } QuicFrames frames; size_t offset = 0; if (munge_offset) { offset++; } - QuicFramer framer(SupportedVersions(header_.version), QuicTime::Zero(), + QuicFramer framer(SupportedVersions(version_), QuicTime::Zero(), Perspective::IS_CLIENT, kQuicDefaultConnectionIdLength); framer.SetInitialObfuscators(TestConnectionId()); - if (!QuicVersionUsesCryptoFrames(version.transport_version) || - munge_stream_id) { + if (!version_.UsesCryptoFrames() || munge_stream_id) { QuicStreamId stream_id = - QuicUtils::GetCryptoStreamId(version.transport_version); + QuicUtils::GetCryptoStreamId(version_.transport_version); if (munge_stream_id) { stream_id++; } @@ -86,11 +91,11 @@ QuicFrame(new QuicCryptoFrame(ENCRYPTION_INITIAL, offset, data))); } std::unique_ptr<QuicPacket> packet( - BuildUnsizedDataPacket(&framer, header_, frames)); + BuildUnsizedDataPacket(&framer, header, frames)); EXPECT_TRUE(packet != nullptr); size_t encrypted_length = - framer.EncryptPayload(ENCRYPTION_INITIAL, header_.packet_number, - *packet, buffer_, QUICHE_ARRAYSIZE(buffer_)); + framer.EncryptPayload(ENCRYPTION_INITIAL, header.packet_number, *packet, + buffer_, QUICHE_ARRAYSIZE(buffer_)); ASSERT_NE(0u, encrypted_length); packet_ = std::make_unique<QuicEncryptedPacket>(buffer_, encrypted_length); EXPECT_TRUE(packet_ != nullptr); @@ -98,79 +103,77 @@ } protected: + ParsedQuicVersion version_; TestDelegate delegate_; - QuicPacketHeader header_; std::unique_ptr<QuicEncryptedPacket> packet_; char buffer_[kMaxOutgoingPacketSize]; }; -TEST_F(ChloExtractorTest, FindsValidChlo) { +INSTANTIATE_TEST_SUITE_P( + ChloExtractorTests, + ChloExtractorTest, + ::testing::ValuesIn(AllSupportedVersionsWithQuicCrypto()), + ::testing::PrintToStringParamName()); + +TEST_P(ChloExtractorTest, FindsValidChlo) { CryptoHandshakeMessage client_hello; client_hello.set_tag(kCHLO); std::string client_hello_str(client_hello.GetSerialized().AsStringPiece()); - // Construct a CHLO with each supported version - for (ParsedQuicVersion version : AllSupportedVersions()) { - SCOPED_TRACE(version); - header_.version = version; - if (QuicVersionHasLongHeaderLengths(version.transport_version) && - header_.version_flag) { - header_.retry_token_length_length = VARIABLE_LENGTH_INTEGER_LENGTH_1; - header_.length_length = VARIABLE_LENGTH_INTEGER_LENGTH_2; - } else { - header_.retry_token_length_length = VARIABLE_LENGTH_INTEGER_LENGTH_0; - header_.length_length = VARIABLE_LENGTH_INTEGER_LENGTH_0; - } - MakePacket(version, client_hello_str, /*munge_offset*/ false, - /*munge_stream_id*/ false); - EXPECT_TRUE(ChloExtractor::Extract(*packet_, version, {}, &delegate_, - kQuicDefaultConnectionIdLength)) - << ParsedQuicVersionToString(version); - EXPECT_EQ(version.transport_version, delegate_.transport_version()); - EXPECT_EQ(header_.destination_connection_id, delegate_.connection_id()); - EXPECT_EQ(client_hello.DebugString(), delegate_.chlo()) - << ParsedQuicVersionToString(version); - } + + MakePacket(client_hello_str, /*munge_offset=*/false, + /*munge_stream_id=*/false); + EXPECT_TRUE(ChloExtractor::Extract(*packet_, version_, {}, &delegate_, + kQuicDefaultConnectionIdLength)); + EXPECT_EQ(version_.transport_version, delegate_.transport_version()); + EXPECT_EQ(TestConnectionId(), delegate_.connection_id()); + EXPECT_EQ(client_hello.DebugString(), delegate_.chlo()); } -TEST_F(ChloExtractorTest, DoesNotFindValidChloOnWrongStream) { - ParsedQuicVersion version = AllSupportedVersions()[0]; - if (QuicVersionUsesCryptoFrames(version.transport_version)) { +TEST_P(ChloExtractorTest, DoesNotFindValidChloOnWrongStream) { + if (version_.UsesCryptoFrames()) { + // When crypto frames are in use we do not use stream frames. return; } CryptoHandshakeMessage client_hello; client_hello.set_tag(kCHLO); std::string client_hello_str(client_hello.GetSerialized().AsStringPiece()); - MakePacket(version, client_hello_str, - /*munge_offset*/ false, /*munge_stream_id*/ true); - EXPECT_FALSE(ChloExtractor::Extract(*packet_, version, {}, &delegate_, + MakePacket(client_hello_str, + /*munge_offset=*/false, /*munge_stream_id=*/true); + EXPECT_FALSE(ChloExtractor::Extract(*packet_, version_, {}, &delegate_, kQuicDefaultConnectionIdLength)); } -TEST_F(ChloExtractorTest, DoesNotFindValidChloOnWrongOffset) { - ParsedQuicVersion version = AllSupportedVersions()[0]; +TEST_P(ChloExtractorTest, DoesNotFindValidChloOnWrongOffset) { CryptoHandshakeMessage client_hello; client_hello.set_tag(kCHLO); std::string client_hello_str(client_hello.GetSerialized().AsStringPiece()); - MakePacket(version, client_hello_str, /*munge_offset*/ true, - /*munge_stream_id*/ false); - EXPECT_FALSE(ChloExtractor::Extract(*packet_, version, {}, &delegate_, + MakePacket(client_hello_str, /*munge_offset=*/true, + /*munge_stream_id=*/false); + EXPECT_FALSE(ChloExtractor::Extract(*packet_, version_, {}, &delegate_, kQuicDefaultConnectionIdLength)); } -TEST_F(ChloExtractorTest, DoesNotFindInvalidChlo) { - ParsedQuicVersion version = AllSupportedVersions()[0]; - if (QuicVersionUsesCryptoFrames(version.transport_version)) { - return; - } - MakePacket(version, "foo", /*munge_offset*/ false, - /*munge_stream_id*/ true); - EXPECT_FALSE(ChloExtractor::Extract(*packet_, version, {}, &delegate_, +TEST_P(ChloExtractorTest, DoesNotFindInvalidChlo) { + MakePacket("foo", /*munge_offset=*/false, + /*munge_stream_id=*/false); + EXPECT_FALSE(ChloExtractor::Extract(*packet_, version_, {}, &delegate_, kQuicDefaultConnectionIdLength)); } +TEST_P(ChloExtractorTest, FirstFlight) { + std::vector<std::unique_ptr<QuicReceivedPacket>> packets = + GetFirstFlightOfPackets(version_); + ASSERT_EQ(packets.size(), 1u); + EXPECT_TRUE(ChloExtractor::Extract(*packets[0], version_, {}, &delegate_, + kQuicDefaultConnectionIdLength)); + EXPECT_EQ(version_.transport_version, delegate_.transport_version()); + EXPECT_EQ(TestConnectionId(), delegate_.connection_id()); + EXPECT_EQ(AlpnForVersion(version_), delegate_.alpn()); +} + } // namespace } // namespace test } // namespace quic