Fix NetSLO QUIC prober tests

These broke when we landed cl/335846456 because the tests were hand-crafting QUIC packets. This CL adds a new test-only API to QUIC to perform this parsing, and then uses that API from the NetSLO tests. This CL also slightly tweaks the NetSLO code to populate the connection ID length to improve memory safety.

PiperOrigin-RevId: 335974800
Change-Id: Ide0a6e30b6510f606fd40da91f9b61598883d926
diff --git a/quic/core/quic_dispatcher_test.cc b/quic/core/quic_dispatcher_test.cc
index e90489a..ae0d594 100644
--- a/quic/core/quic_dispatcher_test.cc
+++ b/quic/core/quic_dispatcher_test.cc
@@ -1317,7 +1317,7 @@
   ASSERT_EQ(1u, saving_writer->packets()->size());
 
   char source_connection_id_bytes[255] = {};
-  uint8_t source_connection_id_length = 0;
+  uint8_t source_connection_id_length = sizeof(source_connection_id_bytes);
   std::string detailed_error = "foobar";
   EXPECT_TRUE(QuicFramer::ParseServerVersionNegotiationProbeResponse(
       (*(saving_writer->packets()))[0]->data(),
@@ -1361,7 +1361,7 @@
   ASSERT_EQ(1u, saving_writer->packets()->size());
 
   char source_connection_id_bytes[255] = {};
-  uint8_t source_connection_id_length = 0;
+  uint8_t source_connection_id_length = sizeof(source_connection_id_bytes);
   std::string detailed_error = "foobar";
   EXPECT_TRUE(QuicFramer::ParseServerVersionNegotiationProbeResponse(
       (*(saving_writer->packets()))[0]->data(),
diff --git a/quic/core/quic_framer.cc b/quic/core/quic_framer.cc
index e57277c..cbf1959 100644
--- a/quic/core/quic_framer.cc
+++ b/quic/core/quic_framer.cc
@@ -6716,6 +6716,13 @@
     *detailed_error = "Received unexpected destination connection ID length";
     return false;
   }
+  if (*source_connection_id_length_out < source_connection_id.length()) {
+    *detailed_error = quiche::QuicheStrCat(
+        "*source_connection_id_length_out too small ",
+        static_cast<int>(*source_connection_id_length_out), " < ",
+        static_cast<int>(source_connection_id.length()));
+    return false;
+  }
 
   memcpy(source_connection_id_bytes, source_connection_id.data(),
          source_connection_id.length());
diff --git a/quic/core/quic_framer.h b/quic/core/quic_framer.h
index 4fcd671..c9820d4 100644
--- a/quic/core/quic_framer.h
+++ b/quic/core/quic_framer.h
@@ -397,7 +397,7 @@
       uint64_t retry_token_length,
       QuicVariableLengthIntegerLength length_length);
 
-  // Parses the unencryoted fields in a QUIC header using |reader| as input,
+  // Parses the unencrypted fields in a QUIC header using |reader| as input,
   // stores the result in the other parameters.
   // |expected_destination_connection_id_length| is only used for short headers.
   static QuicErrorCode ParsePublicHeader(
@@ -417,7 +417,7 @@
       quiche::QuicheStringPiece* retry_token,
       std::string* detailed_error);
 
-  // Parses the unencryoted fields in |packet| and stores them in the other
+  // Parses the unencrypted fields in |packet| and stores them in the other
   // parameters. This can only be called on the server.
   // |expected_destination_connection_id_length| is only used for short headers.
   static QuicErrorCode ParsePublicHeaderDispatcher(
@@ -637,10 +637,12 @@
   // WriteClientVersionNegotiationProbePacket. |packet_bytes| must point to
   // |packet_length| bytes in memory which represent the response.
   // |packet_length| must be greater or equal to 6. This method will fill in
-  // |source_connection_id_bytes| which must point to at least 18 bytes in
-  // memory. |source_connection_id_length_out| will contain the length of the
-  // received source connection ID, which on success will match the contents of
-  // the destination connection ID passed in to
+  // |source_connection_id_bytes| which must point to at least
+  // |*source_connection_id_length_out| bytes in memory.
+  // |*source_connection_id_length_out| must be at least 18.
+  // |*source_connection_id_length_out| will contain the length of the received
+  // source connection ID, which on success will match the contents of the
+  // destination connection ID passed in to
   // WriteClientVersionNegotiationProbePacket. In the case of a failure,
   // |detailed_error| will be filled in with an explanation of what failed.
   static bool ParseServerVersionNegotiationProbeResponse(
diff --git a/quic/core/quic_framer_test.cc b/quic/core/quic_framer_test.cc
index 4da3e03..688aa02 100644
--- a/quic/core/quic_framer_test.cc
+++ b/quic/core/quic_framer_test.cc
@@ -14296,7 +14296,7 @@
   // clang-format on
   char probe_payload_bytes[] = {0x56, 0x4e, 0x20, 0x70, 0x6c, 0x7a, 0x20, 0x21};
   char parsed_probe_payload_bytes[255] = {};
-  uint8_t parsed_probe_payload_length = 0;
+  uint8_t parsed_probe_payload_length = sizeof(parsed_probe_payload_bytes);
   std::string parse_detailed_error = "";
   EXPECT_TRUE(QuicFramer::ParseServerVersionNegotiationProbeResponse(
       reinterpret_cast<const char*>(packet), sizeof(packet),
@@ -14327,7 +14327,7 @@
   // clang-format on
   char probe_payload_bytes[] = {0x56, 0x4e, 0x20, 0x70, 0x6c, 0x7a, 0x20, 0x21};
   char parsed_probe_payload_bytes[255] = {};
-  uint8_t parsed_probe_payload_length = 0;
+  uint8_t parsed_probe_payload_length = sizeof(parsed_probe_payload_bytes);
   std::string parse_detailed_error = "";
   EXPECT_TRUE(QuicFramer::ParseServerVersionNegotiationProbeResponse(
       reinterpret_cast<const char*>(packet), sizeof(packet),
@@ -14339,6 +14339,96 @@
       probe_payload_bytes, sizeof(probe_payload_bytes));
 }
 
+TEST_P(QuicFramerTest, ParseClientVersionNegotiationProbePacket) {
+  SetQuicFlag(FLAGS_quic_prober_uses_length_prefixed_connection_ids, true);
+  char packet[1200];
+  char input_destination_connection_id_bytes[] = {0x56, 0x4e, 0x20, 0x70,
+                                                  0x6c, 0x7a, 0x20, 0x21};
+  ASSERT_TRUE(QuicFramer::WriteClientVersionNegotiationProbePacket(
+      packet, sizeof(packet), input_destination_connection_id_bytes,
+      sizeof(input_destination_connection_id_bytes)));
+  char parsed_destination_connection_id_bytes[255] = {0};
+  uint8_t parsed_destination_connection_id_length =
+      sizeof(parsed_destination_connection_id_bytes);
+  ASSERT_TRUE(ParseClientVersionNegotiationProbePacket(
+      packet, sizeof(packet), parsed_destination_connection_id_bytes,
+      &parsed_destination_connection_id_length));
+  quiche::test::CompareCharArraysWithHexError(
+      "parsed destination connection ID",
+      parsed_destination_connection_id_bytes,
+      parsed_destination_connection_id_length,
+      input_destination_connection_id_bytes,
+      sizeof(input_destination_connection_id_bytes));
+}
+
+TEST_P(QuicFramerTest, WriteServerVersionNegotiationProbeResponse) {
+  SetQuicFlag(FLAGS_quic_prober_uses_length_prefixed_connection_ids, true);
+  char packet[1200];
+  size_t packet_length = sizeof(packet);
+  char input_source_connection_id_bytes[] = {0x56, 0x4e, 0x20, 0x70,
+                                             0x6c, 0x7a, 0x20, 0x21};
+  ASSERT_TRUE(WriteServerVersionNegotiationProbeResponse(
+      packet, &packet_length, input_source_connection_id_bytes,
+      sizeof(input_source_connection_id_bytes)));
+  char parsed_source_connection_id_bytes[255] = {0};
+  uint8_t parsed_source_connection_id_length =
+      sizeof(parsed_source_connection_id_bytes);
+  std::string detailed_error;
+  ASSERT_TRUE(QuicFramer::ParseServerVersionNegotiationProbeResponse(
+      packet, packet_length, parsed_source_connection_id_bytes,
+      &parsed_source_connection_id_length, &detailed_error))
+      << detailed_error;
+  quiche::test::CompareCharArraysWithHexError(
+      "parsed destination connection ID", parsed_source_connection_id_bytes,
+      parsed_source_connection_id_length, input_source_connection_id_bytes,
+      sizeof(input_source_connection_id_bytes));
+}
+
+TEST_P(QuicFramerTest, ParseClientVersionNegotiationProbePacketOld) {
+  SetQuicFlag(FLAGS_quic_prober_uses_length_prefixed_connection_ids, false);
+  char packet[1200];
+  char input_destination_connection_id_bytes[] = {0x56, 0x4e, 0x20, 0x70,
+                                                  0x6c, 0x7a, 0x20, 0x21};
+  ASSERT_TRUE(QuicFramer::WriteClientVersionNegotiationProbePacket(
+      packet, sizeof(packet), input_destination_connection_id_bytes,
+      sizeof(input_destination_connection_id_bytes)));
+  char parsed_destination_connection_id_bytes[255] = {0};
+  uint8_t parsed_destination_connection_id_length =
+      sizeof(parsed_destination_connection_id_bytes);
+  ASSERT_TRUE(ParseClientVersionNegotiationProbePacket(
+      packet, sizeof(packet), parsed_destination_connection_id_bytes,
+      &parsed_destination_connection_id_length));
+  quiche::test::CompareCharArraysWithHexError(
+      "parsed destination connection ID",
+      parsed_destination_connection_id_bytes,
+      parsed_destination_connection_id_length,
+      input_destination_connection_id_bytes,
+      sizeof(input_destination_connection_id_bytes));
+}
+
+TEST_P(QuicFramerTest, WriteServerVersionNegotiationProbeResponseOld) {
+  SetQuicFlag(FLAGS_quic_prober_uses_length_prefixed_connection_ids, false);
+  char packet[1200];
+  size_t packet_length = sizeof(packet);
+  char input_source_connection_id_bytes[] = {0x56, 0x4e, 0x20, 0x70,
+                                             0x6c, 0x7a, 0x20, 0x21};
+  ASSERT_TRUE(WriteServerVersionNegotiationProbeResponse(
+      packet, &packet_length, input_source_connection_id_bytes,
+      sizeof(input_source_connection_id_bytes)));
+  char parsed_source_connection_id_bytes[255] = {0};
+  uint8_t parsed_source_connection_id_length =
+      sizeof(parsed_source_connection_id_bytes);
+  std::string detailed_error;
+  ASSERT_TRUE(QuicFramer::ParseServerVersionNegotiationProbeResponse(
+      packet, packet_length, parsed_source_connection_id_bytes,
+      &parsed_source_connection_id_length, &detailed_error))
+      << detailed_error;
+  quiche::test::CompareCharArraysWithHexError(
+      "parsed destination connection ID", parsed_source_connection_id_bytes,
+      parsed_source_connection_id_length, input_source_connection_id_bytes,
+      sizeof(input_source_connection_id_bytes));
+}
+
 TEST_P(QuicFramerTest, ClientConnectionIdFromLongHeaderToClient) {
   if (!framer_.version().HasIetfInvariantHeader()) {
     // This test requires an IETF long header.
diff --git a/quic/test_tools/quic_test_utils.cc b/quic/test_tools/quic_test_utils.cc
index 6fa77c0..e904c03 100644
--- a/quic/test_tools/quic_test_utils.cc
+++ b/quic/test_tools/quic_test_utils.cc
@@ -1519,5 +1519,108 @@
   packet_buffer_free_list_.push_back(p);
 }
 
+bool WriteServerVersionNegotiationProbeResponse(
+    char* packet_bytes,
+    QuicByteCount* packet_length_out,
+    const char* source_connection_id_bytes,
+    uint8_t source_connection_id_length) {
+  if (packet_bytes == nullptr) {
+    QUIC_BUG << "Invalid packet_bytes";
+    return false;
+  }
+  if (packet_length_out == nullptr) {
+    QUIC_BUG << "Invalid packet_length_out";
+    return false;
+  }
+  QuicConnectionId source_connection_id(source_connection_id_bytes,
+                                        source_connection_id_length);
+  const bool use_length_prefix =
+      GetQuicFlag(FLAGS_quic_prober_uses_length_prefixed_connection_ids);
+  std::unique_ptr<QuicEncryptedPacket> encrypted_packet =
+      QuicFramer::BuildVersionNegotiationPacket(
+          source_connection_id, EmptyQuicConnectionId(),
+          /*ietf_quic=*/true, use_length_prefix, ParsedQuicVersionVector{});
+  if (!encrypted_packet) {
+    QUIC_BUG << "Failed to create version negotiation packet";
+    return false;
+  }
+  if (*packet_length_out < encrypted_packet->length()) {
+    QUIC_BUG << "Invalid *packet_length_out " << *packet_length_out << " < "
+             << encrypted_packet->length();
+    return false;
+  }
+  *packet_length_out = encrypted_packet->length();
+  memcpy(packet_bytes, encrypted_packet->data(), *packet_length_out);
+  return true;
+}
+
+bool ParseClientVersionNegotiationProbePacket(
+    const char* packet_bytes,
+    QuicByteCount packet_length,
+    char* destination_connection_id_bytes,
+    uint8_t* destination_connection_id_length_out) {
+  if (packet_bytes == nullptr) {
+    QUIC_BUG << "Invalid packet_bytes";
+    return false;
+  }
+  if (packet_length < kMinPacketSizeForVersionNegotiation ||
+      packet_length > 65535) {
+    QUIC_BUG << "Invalid packet_length";
+    return false;
+  }
+  if (destination_connection_id_bytes == nullptr) {
+    QUIC_BUG << "Invalid destination_connection_id_bytes";
+    return false;
+  }
+  if (destination_connection_id_length_out == nullptr) {
+    QUIC_BUG << "Invalid destination_connection_id_length_out";
+    return false;
+  }
+  if (*destination_connection_id_length_out <
+      kQuicMinimumInitialConnectionIdLength) {
+    QUIC_BUG << "Invalid *destination_connection_id_length_out";
+    return false;
+  }
+
+  QuicEncryptedPacket encrypted_packet(packet_bytes, packet_length);
+  PacketHeaderFormat format;
+  QuicLongHeaderType long_packet_type;
+  bool version_present, has_length_prefix, retry_token_present;
+  QuicVersionLabel version_label;
+  ParsedQuicVersion parsed_version = ParsedQuicVersion::Unsupported();
+  QuicConnectionId destination_connection_id, source_connection_id;
+  quiche::QuicheStringPiece retry_token;
+  std::string detailed_error;
+  QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher(
+      encrypted_packet,
+      /*expected_destination_connection_id_length=*/0, &format,
+      &long_packet_type, &version_present, &has_length_prefix, &version_label,
+      &parsed_version, &destination_connection_id, &source_connection_id,
+      &retry_token_present, &retry_token, &detailed_error);
+  if (error != QUIC_NO_ERROR) {
+    QUIC_BUG << "Failed to parse packet: " << detailed_error;
+    return false;
+  }
+  if (!version_present) {
+    QUIC_BUG << "Packet is not a long header";
+    return false;
+  }
+  if (destination_connection_id.length() <
+      kQuicMinimumInitialConnectionIdLength) {
+    QUIC_BUG << "Invalid destination_connection_id length "
+             << static_cast<int>(destination_connection_id.length());
+    return false;
+  }
+  if (*destination_connection_id_length_out <
+      destination_connection_id.length()) {
+    QUIC_BUG << "destination_connection_id_length_out too small";
+    return false;
+  }
+  *destination_connection_id_length_out = destination_connection_id.length();
+  memcpy(destination_connection_id_bytes, destination_connection_id.data(),
+         *destination_connection_id_length_out);
+  return true;
+}
+
 }  // namespace test
 }  // namespace quic
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h
index 5842ac3..85d10cb 100644
--- a/quic/test_tools/quic_test_utils.h
+++ b/quic/test_tools/quic_test_utils.h
@@ -2098,6 +2098,35 @@
   QuicSocketAddress last_write_peer_address_;
 };
 
+// Parses a packet generated by
+// QuicFramer::WriteClientVersionNegotiationProbePacket.
+// |packet_bytes| must point to |packet_length| bytes in memory which represent
+// the packet. This method will fill in |destination_connection_id_bytes|
+// which must point to at least |*destination_connection_id_length_out| bytes in
+// memory. |*destination_connection_id_length_out| will contain the length of
+// the received destination connection ID, which on success will match the
+// contents of the destination connection ID passed in to
+// WriteClientVersionNegotiationProbePacket.
+bool ParseClientVersionNegotiationProbePacket(
+    const char* packet_bytes,
+    QuicByteCount packet_length,
+    char* destination_connection_id_bytes,
+    uint8_t* destination_connection_id_length_out);
+
+// Writes an array of bytes that correspond to a QUIC version negotiation packet
+// that a QUIC server would send in response to a probe created by
+// QuicFramer::WriteClientVersionNegotiationProbePacket.
+// The bytes will be written to |packet_bytes|, which must point to
+// |*packet_length_out| bytes of memory. |*packet_length_out| will contain the
+// length of the created packet. |source_connection_id_bytes| will be sent as
+// the source connection ID, and must point to |source_connection_id_length|
+// bytes of memory.
+bool WriteServerVersionNegotiationProbeResponse(
+    char* packet_bytes,
+    QuicByteCount* packet_length_out,
+    const char* source_connection_id_bytes,
+    uint8_t source_connection_id_length);
+
 }  // namespace test
 }  // namespace quic