Parameterize ChloExtractorTest by QUIC version

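This CL also makes the test delegate record the ALPN carried in the CHLO, and adds a FirstFlight test that checks it against AlpnForVersion(). DoesNotFindInvalidChlo now runs for every version instead of returning early when crypto frames are in use.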

gfe-relnote: n/a, test-only
PiperOrigin-RevId: 307801505
Change-Id: I9c8b7be8edba6dc531746d6bb35fc20559afe28f
diff --git a/quic/core/chlo_extractor_test.cc b/quic/core/chlo_extractor_test.cc
index e293902..1b6a616 100644
--- a/quic/core/chlo_extractor_test.cc
+++ b/quic/core/chlo_extractor_test.cc
@@ -12,6 +12,7 @@
 #include "net/third_party/quiche/src/quic/core/quic_utils.h"
 #include "net/third_party/quiche/src/quic/platform/api/quic_test.h"
 #include "net/third_party/quiche/src/quic/test_tools/crypto_test_utils.h"
+#include "net/third_party/quiche/src/quic/test_tools/first_flight.h"
 #include "net/third_party/quiche/src/quic/test_tools/quic_test_utils.h"
 #include "net/third_party/quiche/src/common/platform/api/quiche_arraysize.h"
 #include "net/third_party/quiche/src/common/platform/api/quiche_string_piece.h"
@@ -32,50 +33,54 @@
     version_ = version;
     connection_id_ = connection_id;
     chlo_ = chlo.DebugString();
+    quiche::QuicheStringPiece alpn_value;
+    if (chlo.GetStringPiece(kALPN, &alpn_value)) {
+      alpn_ = std::string(alpn_value);
+    }
   }
 
   QuicConnectionId connection_id() const { return connection_id_; }
   QuicTransportVersion transport_version() const { return version_; }
   const std::string& chlo() const { return chlo_; }
+  const std::string& alpn() const { return alpn_; }
 
  private:
   QuicConnectionId connection_id_;
   QuicTransportVersion version_;
   std::string chlo_;
+  std::string alpn_;
 };
 
-class ChloExtractorTest : public QuicTest {
+class ChloExtractorTest : public QuicTestWithParam<ParsedQuicVersion> {
  public:
-  ChloExtractorTest() {
-    header_.destination_connection_id = TestConnectionId();
-    header_.destination_connection_id_included = CONNECTION_ID_PRESENT;
-    header_.version_flag = true;
-    header_.version = AllSupportedVersions().front();
-    header_.reset_flag = false;
-    header_.packet_number_length = PACKET_4BYTE_PACKET_NUMBER;
-    header_.packet_number = QuicPacketNumber(1);
-    if (QuicVersionHasLongHeaderLengths(header_.version.transport_version)) {
-      header_.retry_token_length_length = VARIABLE_LENGTH_INTEGER_LENGTH_1;
-      header_.length_length = VARIABLE_LENGTH_INTEGER_LENGTH_2;
-    }
-  }
+  ChloExtractorTest() : version_(GetParam()) {}
 
-  void MakePacket(ParsedQuicVersion version,
-                  quiche::QuicheStringPiece data,
+  void MakePacket(quiche::QuicheStringPiece data,
                   bool munge_offset,
                   bool munge_stream_id) {
+    QuicPacketHeader header;
+    header.destination_connection_id = TestConnectionId();
+    header.destination_connection_id_included = CONNECTION_ID_PRESENT;
+    header.version_flag = true;
+    header.version = version_;
+    header.reset_flag = false;
+    header.packet_number_length = PACKET_4BYTE_PACKET_NUMBER;
+    header.packet_number = QuicPacketNumber(1);
+    if (version_.HasLongHeaderLengths()) {
+      header.retry_token_length_length = VARIABLE_LENGTH_INTEGER_LENGTH_1;
+      header.length_length = VARIABLE_LENGTH_INTEGER_LENGTH_2;
+    }
     QuicFrames frames;
     size_t offset = 0;
     if (munge_offset) {
       offset++;
     }
-    QuicFramer framer(SupportedVersions(header_.version), QuicTime::Zero(),
+    QuicFramer framer(SupportedVersions(version_), QuicTime::Zero(),
                       Perspective::IS_CLIENT, kQuicDefaultConnectionIdLength);
     framer.SetInitialObfuscators(TestConnectionId());
-    if (!QuicVersionUsesCryptoFrames(version.transport_version) ||
-        munge_stream_id) {
+    if (!version_.UsesCryptoFrames() || munge_stream_id) {
       QuicStreamId stream_id =
-          QuicUtils::GetCryptoStreamId(version.transport_version);
+          QuicUtils::GetCryptoStreamId(version_.transport_version);
       if (munge_stream_id) {
         stream_id++;
       }
@@ -86,11 +91,11 @@
           QuicFrame(new QuicCryptoFrame(ENCRYPTION_INITIAL, offset, data)));
     }
     std::unique_ptr<QuicPacket> packet(
-        BuildUnsizedDataPacket(&framer, header_, frames));
+        BuildUnsizedDataPacket(&framer, header, frames));
     EXPECT_TRUE(packet != nullptr);
     size_t encrypted_length =
-        framer.EncryptPayload(ENCRYPTION_INITIAL, header_.packet_number,
-                              *packet, buffer_, QUICHE_ARRAYSIZE(buffer_));
+        framer.EncryptPayload(ENCRYPTION_INITIAL, header.packet_number, *packet,
+                              buffer_, QUICHE_ARRAYSIZE(buffer_));
     ASSERT_NE(0u, encrypted_length);
     packet_ = std::make_unique<QuicEncryptedPacket>(buffer_, encrypted_length);
     EXPECT_TRUE(packet_ != nullptr);
@@ -98,79 +103,77 @@
   }
 
  protected:
+  ParsedQuicVersion version_;
   TestDelegate delegate_;
-  QuicPacketHeader header_;
   std::unique_ptr<QuicEncryptedPacket> packet_;
   char buffer_[kMaxOutgoingPacketSize];
 };
 
-TEST_F(ChloExtractorTest, FindsValidChlo) {
+INSTANTIATE_TEST_SUITE_P(
+    ChloExtractorTests,
+    ChloExtractorTest,
+    ::testing::ValuesIn(AllSupportedVersionsWithQuicCrypto()),
+    ::testing::PrintToStringParamName());
+
+TEST_P(ChloExtractorTest, FindsValidChlo) {
   CryptoHandshakeMessage client_hello;
   client_hello.set_tag(kCHLO);
 
   std::string client_hello_str(client_hello.GetSerialized().AsStringPiece());
-  // Construct a CHLO with each supported version
-  for (ParsedQuicVersion version : AllSupportedVersions()) {
-    SCOPED_TRACE(version);
-    header_.version = version;
-    if (QuicVersionHasLongHeaderLengths(version.transport_version) &&
-        header_.version_flag) {
-      header_.retry_token_length_length = VARIABLE_LENGTH_INTEGER_LENGTH_1;
-      header_.length_length = VARIABLE_LENGTH_INTEGER_LENGTH_2;
-    } else {
-      header_.retry_token_length_length = VARIABLE_LENGTH_INTEGER_LENGTH_0;
-      header_.length_length = VARIABLE_LENGTH_INTEGER_LENGTH_0;
-    }
-    MakePacket(version, client_hello_str, /*munge_offset*/ false,
-               /*munge_stream_id*/ false);
-    EXPECT_TRUE(ChloExtractor::Extract(*packet_, version, {}, &delegate_,
-                                       kQuicDefaultConnectionIdLength))
-        << ParsedQuicVersionToString(version);
-    EXPECT_EQ(version.transport_version, delegate_.transport_version());
-    EXPECT_EQ(header_.destination_connection_id, delegate_.connection_id());
-    EXPECT_EQ(client_hello.DebugString(), delegate_.chlo())
-        << ParsedQuicVersionToString(version);
-  }
+
+  MakePacket(client_hello_str, /*munge_offset=*/false,
+             /*munge_stream_id=*/false);
+  EXPECT_TRUE(ChloExtractor::Extract(*packet_, version_, {}, &delegate_,
+                                     kQuicDefaultConnectionIdLength));
+  EXPECT_EQ(version_.transport_version, delegate_.transport_version());
+  EXPECT_EQ(TestConnectionId(), delegate_.connection_id());
+  EXPECT_EQ(client_hello.DebugString(), delegate_.chlo());
 }
 
-TEST_F(ChloExtractorTest, DoesNotFindValidChloOnWrongStream) {
-  ParsedQuicVersion version = AllSupportedVersions()[0];
-  if (QuicVersionUsesCryptoFrames(version.transport_version)) {
+TEST_P(ChloExtractorTest, DoesNotFindValidChloOnWrongStream) {
+  if (version_.UsesCryptoFrames()) {
+    // When crypto frames are in use we do not use stream frames.
     return;
   }
   CryptoHandshakeMessage client_hello;
   client_hello.set_tag(kCHLO);
 
   std::string client_hello_str(client_hello.GetSerialized().AsStringPiece());
-  MakePacket(version, client_hello_str,
-             /*munge_offset*/ false, /*munge_stream_id*/ true);
-  EXPECT_FALSE(ChloExtractor::Extract(*packet_, version, {}, &delegate_,
+  MakePacket(client_hello_str,
+             /*munge_offset=*/false, /*munge_stream_id=*/true);
+  EXPECT_FALSE(ChloExtractor::Extract(*packet_, version_, {}, &delegate_,
                                       kQuicDefaultConnectionIdLength));
 }
 
-TEST_F(ChloExtractorTest, DoesNotFindValidChloOnWrongOffset) {
-  ParsedQuicVersion version = AllSupportedVersions()[0];
+TEST_P(ChloExtractorTest, DoesNotFindValidChloOnWrongOffset) {
   CryptoHandshakeMessage client_hello;
   client_hello.set_tag(kCHLO);
 
   std::string client_hello_str(client_hello.GetSerialized().AsStringPiece());
-  MakePacket(version, client_hello_str, /*munge_offset*/ true,
-             /*munge_stream_id*/ false);
-  EXPECT_FALSE(ChloExtractor::Extract(*packet_, version, {}, &delegate_,
+  MakePacket(client_hello_str, /*munge_offset=*/true,
+             /*munge_stream_id=*/false);
+  EXPECT_FALSE(ChloExtractor::Extract(*packet_, version_, {}, &delegate_,
                                       kQuicDefaultConnectionIdLength));
 }
 
-TEST_F(ChloExtractorTest, DoesNotFindInvalidChlo) {
-  ParsedQuicVersion version = AllSupportedVersions()[0];
-  if (QuicVersionUsesCryptoFrames(version.transport_version)) {
-    return;
-  }
-  MakePacket(version, "foo", /*munge_offset*/ false,
-             /*munge_stream_id*/ true);
-  EXPECT_FALSE(ChloExtractor::Extract(*packet_, version, {}, &delegate_,
+TEST_P(ChloExtractorTest, DoesNotFindInvalidChlo) {
+  MakePacket("foo", /*munge_offset=*/false,
+             /*munge_stream_id=*/false);
+  EXPECT_FALSE(ChloExtractor::Extract(*packet_, version_, {}, &delegate_,
                                       kQuicDefaultConnectionIdLength));
 }
 
+TEST_P(ChloExtractorTest, FirstFlight) {
+  std::vector<std::unique_ptr<QuicReceivedPacket>> packets =
+      GetFirstFlightOfPackets(version_);
+  ASSERT_EQ(1u, packets.size());  // Expect the CHLO in a single packet.
+  ASSERT_TRUE(ChloExtractor::Extract(*packets[0], version_, {}, &delegate_,
+                                     kQuicDefaultConnectionIdLength));
+  EXPECT_EQ(version_.transport_version, delegate_.transport_version());
+  EXPECT_EQ(TestConnectionId(), delegate_.connection_id());
+  EXPECT_EQ(AlpnForVersion(version_), delegate_.alpn());
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic