Refactor QuicDispatcher connection ID replacement

Historically, QuicDispatcher needed the ability to generate a new distinct connection ID for the purpose of QUIC_CRYPTO stateless rejects. Due to how stateless rejects were designed, they needed a new connection ID distinct from the current one, but that still routes to the same server instance. For that reason, QuicDispatcher::GenerateNewServerConnectionId would just hash the connection ID and GfeQuicDispatcher::GenerateNewServerConnectionId would just add 1 to the lowest order bit. However, since then, stateless rejects have been deprecated and entirely removed from the codebase. Separately, we started using GenerateNewServerConnectionId to truncate client-provided connection IDs that were too long. Additionally, with cl/311115671 we no longer increment the connection ID by one because we no longer have the requirement of it being different.

All of this to say that the code no longer matches what it is used for, and this CL refactors it to make it clearer, without changing any behavior.

Refactor, no behavior change, not flag protected

PiperOrigin-RevId: 311788800
Change-Id: I44c16f2fc91babb0607159a4d1f2e580ec17317c
diff --git a/quic/core/quic_dispatcher.cc b/quic/core/quic_dispatcher.cc
index c0d3970..d3f01b4 100644
--- a/quic/core/quic_dispatcher.cc
+++ b/quic/core/quic_dispatcher.cc
@@ -307,30 +307,54 @@
 }
 
 QuicConnectionId QuicDispatcher::MaybeReplaceServerConnectionId(
-    QuicConnectionId server_connection_id,
-    ParsedQuicVersion version) const {
-  if (server_connection_id.length() == expected_server_connection_id_length_) {
+    const QuicConnectionId& server_connection_id,
+    const ParsedQuicVersion& version) const {
+  const uint8_t server_connection_id_length = server_connection_id.length();
+  if (server_connection_id_length == expected_server_connection_id_length_) {
     return server_connection_id;
   }
   DCHECK(version.AllowsVariableLengthConnectionIds());
-
-  QuicConnectionId new_connection_id =
-      GenerateNewServerConnectionId(version, server_connection_id);
+  QuicConnectionId new_connection_id;
+  if (server_connection_id_length < expected_server_connection_id_length_) {
+    new_connection_id = ReplaceShortServerConnectionId(
+        version, server_connection_id, expected_server_connection_id_length_);
+    // Verify that ReplaceShortServerConnectionId is deterministic.
+    DCHECK_EQ(new_connection_id, ReplaceShortServerConnectionId(
+                                     version, server_connection_id,
+                                     expected_server_connection_id_length_));
+  } else {
+    new_connection_id = ReplaceLongServerConnectionId(
+        version, server_connection_id, expected_server_connection_id_length_);
+    // Verify that ReplaceLongServerConnectionId is deterministic.
+    DCHECK_EQ(new_connection_id, ReplaceLongServerConnectionId(
+                                     version, server_connection_id,
+                                     expected_server_connection_id_length_));
+  }
   DCHECK_EQ(expected_server_connection_id_length_, new_connection_id.length());
 
-  // Verify that GenerateNewServerConnectionId is deterministic.
-  DCHECK_EQ(new_connection_id,
-            GenerateNewServerConnectionId(version, server_connection_id));
-
   QUIC_DLOG(INFO) << "Replacing incoming connection ID " << server_connection_id
                   << " with " << new_connection_id;
   return new_connection_id;
 }
 
-QuicConnectionId QuicDispatcher::GenerateNewServerConnectionId(
-    ParsedQuicVersion /*version*/,
-    QuicConnectionId connection_id) const {
-  return QuicUtils::CreateReplacementConnectionId(connection_id);
+QuicConnectionId QuicDispatcher::ReplaceShortServerConnectionId(
+    const ParsedQuicVersion& /*version*/,
+    const QuicConnectionId& server_connection_id,
+    uint8_t expected_server_connection_id_length) const {
+  DCHECK_LT(server_connection_id.length(),
+            expected_server_connection_id_length);
+  return QuicUtils::CreateReplacementConnectionId(
+      server_connection_id, expected_server_connection_id_length);
+}
+
+QuicConnectionId QuicDispatcher::ReplaceLongServerConnectionId(
+    const ParsedQuicVersion& /*version*/,
+    const QuicConnectionId& server_connection_id,
+    uint8_t expected_server_connection_id_length) const {
+  DCHECK_GT(server_connection_id.length(),
+            expected_server_connection_id_length);
+  return QuicUtils::CreateReplacementConnectionId(
+      server_connection_id, expected_server_connection_id_length);
 }
 
 bool QuicDispatcher::MaybeDispatchPacket(
diff --git a/quic/core/quic_dispatcher.h b/quic/core/quic_dispatcher.h
index 73ab3ec..b553e3c 100644
--- a/quic/core/quic_dispatcher.h
+++ b/quic/core/quic_dispatcher.h
@@ -121,7 +121,7 @@
   // send a handshake and then up to 50 or so data packets, and then it may
   // resend the handshake packet up to 10 times.  (Retransmitted packets are
   // sent with unique packet numbers.)
-  static const uint64_t kMaxReasonableInitialPacketNumber = 100;
+  static constexpr uint64_t kMaxReasonableInitialPacketNumber = 100;
   static_assert(kMaxReasonableInitialPacketNumber >=
                     kInitialCongestionWindow + 10,
                 "kMaxReasonableInitialPacketNumber is unreasonably small "
@@ -162,11 +162,29 @@
   virtual bool MaybeDispatchPacket(const ReceivedPacketInfo& packet_info);
 
   // Generate a connection ID with a length that is expected by the dispatcher.
+  // Called only when |server_connection_id| is shorter than
+  // |expected_connection_id_length|.
   // Note that this MUST produce a deterministic result (calling this method
   // with two connection IDs that are equal must produce the same result).
-  virtual QuicConnectionId GenerateNewServerConnectionId(
-      ParsedQuicVersion version,
-      QuicConnectionId connection_id) const;
+  // Note that this is not used in general operation because our default
+  // |expected_server_connection_id_length| is 8, and the IETF specification
+  // requires clients to use an initial length of at least 8. However, we
+  // allow disabling that requirement via
+  // |allow_short_initial_server_connection_ids_|.
+  virtual QuicConnectionId ReplaceShortServerConnectionId(
+      const ParsedQuicVersion& version,
+      const QuicConnectionId& server_connection_id,
+      uint8_t expected_server_connection_id_length) const;
+
+  // Generate a connection ID with a length that is expected by the dispatcher.
+  // Called only when |server_connection_id| is longer than
+  // |expected_connection_id_length|.
+  // Note that this MUST produce a deterministic result (calling this method
+  // with two connection IDs that are equal must produce the same result).
+  virtual QuicConnectionId ReplaceLongServerConnectionId(
+      const ParsedQuicVersion& version,
+      const QuicConnectionId& server_connection_id,
+      uint8_t expected_server_connection_id_length) const;
 
   // Values to be returned by ValidityChecks() to indicate what should be done
   // with a packet. Fates with greater values are considered to be higher
@@ -308,11 +326,12 @@
   std::string SelectAlpn(const std::vector<std::string>& alpns);
 
   // If the connection ID length is different from what the dispatcher expects,
-  // replace the connection ID with a random one of the right length,
-  // and save it to make sure the mapping is persistent.
+  // replace the connection ID with one of the right length.
+  // Note that this MUST produce a deterministic result (calling this method
+  // with two connection IDs that are equal must produce the same result).
   QuicConnectionId MaybeReplaceServerConnectionId(
-      QuicConnectionId server_connection_id,
-      ParsedQuicVersion version) const;
+      const QuicConnectionId& server_connection_id,
+      const ParsedQuicVersion& version) const;
 
  private:
   friend class test::QuicDispatcherPeer;
diff --git a/quic/core/quic_utils.cc b/quic/core/quic_utils.cc
index 60354f3..2d851d8 100644
--- a/quic/core/quic_utils.cc
+++ b/quic/core/quic_utils.cc
@@ -6,6 +6,7 @@
 
 #include <algorithm>
 #include <cstdint>
+#include <cstring>
 #include <string>
 
 #include "net/third_party/quiche/src/quic/core/quic_connection_id.h"
@@ -491,11 +492,37 @@
 
 // static
 QuicConnectionId QuicUtils::CreateReplacementConnectionId(
-    QuicConnectionId connection_id) {
-  const uint64_t connection_id_hash = FNV1a_64_Hash(
+    const QuicConnectionId& connection_id) {
+  return CreateReplacementConnectionId(connection_id,
+                                       kQuicDefaultConnectionIdLength);
+}
+
+// static
+QuicConnectionId QuicUtils::CreateReplacementConnectionId(
+    const QuicConnectionId& connection_id,
+    uint8_t expected_connection_id_length) {
+  if (expected_connection_id_length == 0) {
+    return EmptyQuicConnectionId();
+  }
+  const uint64_t connection_id_hash64 = FNV1a_64_Hash(
       quiche::QuicheStringPiece(connection_id.data(), connection_id.length()));
-  return QuicConnectionId(reinterpret_cast<const char*>(&connection_id_hash),
-                          sizeof(connection_id_hash));
+  if (expected_connection_id_length <= sizeof(uint64_t)) {
+    return QuicConnectionId(
+        reinterpret_cast<const char*>(&connection_id_hash64),
+        expected_connection_id_length);
+  }
+  char new_connection_id_data[255] = {};
+  const QuicUint128 connection_id_hash128 = FNV1a_128_Hash(
+      quiche::QuicheStringPiece(connection_id.data(), connection_id.length()));
+  static_assert(sizeof(connection_id_hash64) + sizeof(connection_id_hash128) <=
+                    sizeof(new_connection_id_data),
+                "bad size");
+  memcpy(new_connection_id_data, &connection_id_hash64,
+         sizeof(connection_id_hash64));
+  memcpy(new_connection_id_data + sizeof(connection_id_hash64),
+         &connection_id_hash128, sizeof(connection_id_hash128));
+  return QuicConnectionId(new_connection_id_data,
+                          expected_connection_id_length);
 }
 
 // static
diff --git a/quic/core/quic_utils.h b/quic/core/quic_utils.h
index eec9a5f..e470e6f 100644
--- a/quic/core/quic_utils.h
+++ b/quic/core/quic_utils.h
@@ -168,11 +168,19 @@
       QuicTransportVersion version,
       Perspective perspective);
 
-  // Generates a 64bit connection ID derived from the input connection ID.
+  // Generates a connection ID of length |expected_connection_id_length|
+  // derived from |connection_id|.
   // This is guaranteed to be deterministic (calling this method with two
   // connection IDs that are equal is guaranteed to produce the same result).
   static QuicConnectionId CreateReplacementConnectionId(
-      QuicConnectionId connection_id);
+      const QuicConnectionId& connection_id,
+      uint8_t expected_connection_id_length);
+
+  // Generates a 64bit connection ID derived from |connection_id|.
+  // This is guaranteed to be deterministic (calling this method with two
+  // connection IDs that are equal is guaranteed to produce the same result).
+  static QuicConnectionId CreateReplacementConnectionId(
+      const QuicConnectionId& connection_id);
 
   // Generates a random 64bit connection ID.
   static QuicConnectionId CreateRandomConnectionId();
diff --git a/quic/core/quic_utils_test.cc b/quic/core/quic_utils_test.cc
index 5b2186a..082a9bd 100644
--- a/quic/core/quic_utils_test.cc
+++ b/quic/core/quic_utils_test.cc
@@ -180,6 +180,23 @@
   EXPECT_EQ(connection_id72a, connection_id72b);
   EXPECT_EQ(QuicUtils::CreateReplacementConnectionId(connection_id72a),
             QuicUtils::CreateReplacementConnectionId(connection_id72b));
+  // Test variant with custom length.
+  EXPECT_EQ(QuicUtils::CreateReplacementConnectionId(connection_id64a, 7),
+            QuicUtils::CreateReplacementConnectionId(connection_id64b, 7));
+  EXPECT_EQ(QuicUtils::CreateReplacementConnectionId(connection_id64a, 9),
+            QuicUtils::CreateReplacementConnectionId(connection_id64b, 9));
+  EXPECT_EQ(QuicUtils::CreateReplacementConnectionId(connection_id64a, 16),
+            QuicUtils::CreateReplacementConnectionId(connection_id64b, 16));
+  EXPECT_EQ(QuicUtils::CreateReplacementConnectionId(connection_id72a, 7),
+            QuicUtils::CreateReplacementConnectionId(connection_id72b, 7));
+  EXPECT_EQ(QuicUtils::CreateReplacementConnectionId(connection_id72a, 9),
+            QuicUtils::CreateReplacementConnectionId(connection_id72b, 9));
+  EXPECT_EQ(QuicUtils::CreateReplacementConnectionId(connection_id72a, 16),
+            QuicUtils::CreateReplacementConnectionId(connection_id72b, 16));
+  EXPECT_EQ(QuicUtils::CreateReplacementConnectionId(connection_id72a, 32),
+            QuicUtils::CreateReplacementConnectionId(connection_id72b, 32));
+  EXPECT_EQ(QuicUtils::CreateReplacementConnectionId(connection_id72a, 255),
+            QuicUtils::CreateReplacementConnectionId(connection_id72b, 255));
 }
 
 TEST_F(QuicUtilsTest, ReplacementConnectionIdLengthIsCorrect) {
@@ -191,6 +208,22 @@
         QuicUtils::CreateReplacementConnectionId(connection_id);
     EXPECT_EQ(kQuicDefaultConnectionIdLength,
               replacement_connection_id.length());
+    // Test variant with custom length.
+    QuicConnectionId replacement_connection_id7 =
+        QuicUtils::CreateReplacementConnectionId(connection_id, 7);
+    EXPECT_EQ(7, replacement_connection_id7.length());
+    QuicConnectionId replacement_connection_id9 =
+        QuicUtils::CreateReplacementConnectionId(connection_id, 9);
+    EXPECT_EQ(9, replacement_connection_id9.length());
+    QuicConnectionId replacement_connection_id16 =
+        QuicUtils::CreateReplacementConnectionId(connection_id, 16);
+    EXPECT_EQ(16, replacement_connection_id16.length());
+    QuicConnectionId replacement_connection_id32 =
+        QuicUtils::CreateReplacementConnectionId(connection_id, 32);
+    EXPECT_EQ(32, replacement_connection_id32.length());
+    QuicConnectionId replacement_connection_id255 =
+        QuicUtils::CreateReplacementConnectionId(connection_id, 255);
+    EXPECT_EQ(255, replacement_connection_id255.length());
   }
 }
 
@@ -205,6 +238,17 @@
       EXPECT_NE(connection_id_i, connection_id_j);
       EXPECT_NE(QuicUtils::CreateReplacementConnectionId(connection_id_i),
                 QuicUtils::CreateReplacementConnectionId(connection_id_j));
+      // Test variant with custom length.
+      EXPECT_NE(QuicUtils::CreateReplacementConnectionId(connection_id_i, 7),
+                QuicUtils::CreateReplacementConnectionId(connection_id_j, 7));
+      EXPECT_NE(QuicUtils::CreateReplacementConnectionId(connection_id_i, 9),
+                QuicUtils::CreateReplacementConnectionId(connection_id_j, 9));
+      EXPECT_NE(QuicUtils::CreateReplacementConnectionId(connection_id_i, 16),
+                QuicUtils::CreateReplacementConnectionId(connection_id_j, 16));
+      EXPECT_NE(QuicUtils::CreateReplacementConnectionId(connection_id_i, 32),
+                QuicUtils::CreateReplacementConnectionId(connection_id_j, 32));
+      EXPECT_NE(QuicUtils::CreateReplacementConnectionId(connection_id_i, 255),
+                QuicUtils::CreateReplacementConnectionId(connection_id_j, 255));
     }
   }
 }
diff --git a/quic/qbone/qbone_client_test.cc b/quic/qbone/qbone_client_test.cc
index cdb611c..0c7c949 100644
--- a/quic/qbone/qbone_client_test.cc
+++ b/quic/qbone/qbone_client_test.cc
@@ -139,13 +139,6 @@
     return session;
   }
 
-  QuicConnectionId GenerateNewServerConnectionId(
-      ParsedQuicVersion version,
-      QuicConnectionId connection_id) const override {
-    char connection_id_bytes[kQuicDefaultConnectionIdLength] = {};
-    return QuicConnectionId(connection_id_bytes, sizeof(connection_id_bytes));
-  }
-
  private:
   QbonePacketWriter* writer_;
 };
diff --git a/quic/quartc/quartc_dispatcher.h b/quic/quartc/quartc_dispatcher.h
index ca4fe2d..b7605f9 100644
--- a/quic/quartc/quartc_dispatcher.h
+++ b/quic/quartc/quartc_dispatcher.h
@@ -46,9 +46,6 @@
       quiche::QuicheStringPiece alpn,
       const ParsedQuicVersion& version) override;
 
-  // TODO(b/124399417): Override GenerateNewServerConnectionId and request a
-  // zero-length connection id when the QUIC server perspective supports it.
-
   // QuartcPacketTransport::Delegate overrides.
   void OnTransportCanWrite() override;
   void OnTransportReceived(const char* data, size_t data_len) override;