Use deterministic replacement connection IDs

This CL removes a DoS attack vector where an attacker could grow QuicDispatcher::connection_id_map_ unboundedly. It does so by no longer using random connection IDs that are saved in connection_id_map_; instead we now generate deterministic replacement connection IDs, removing the need for a map. It should not impact the GFE because the GFE overrides QuicDispatcher::GenerateNewServerConnectionId with an already deterministic method, but is still flag protected just in case.

gfe-relnote: use deterministic replacement connection IDs, protected by new disabled flag gfe2_restart_flag_quic_deterministic_replacement_connection_ids
PiperOrigin-RevId: 264192278
Change-Id: I843bf0d846830d4b13e0bb1b470a71b2428ad7c8
diff --git a/quic/core/quic_dispatcher.cc b/quic/core/quic_dispatcher.cc
index b659bda..553fcd4 100644
--- a/quic/core/quic_dispatcher.cc
+++ b/quic/core/quic_dispatcher.cc
@@ -305,18 +305,29 @@
   }
   DCHECK(QuicUtils::VariableLengthConnectionIdAllowedForVersion(
       version.transport_version));
-  auto it = connection_id_map_.find(server_connection_id);
-  if (it != connection_id_map_.end()) {
-    return it->second;
+
+  if (!GetQuicRestartFlag(quic_deterministic_replacement_connection_ids)) {
+    auto it = connection_id_map_.find(server_connection_id);
+    if (it != connection_id_map_.end()) {
+      return it->second;
+    }
+  } else {
+    // TODO(dschinazi) Remove QuicDispatcher::connection_id_map_ entirely
+    // when quic_deterministic_replacement_connection_ids is deprecated.
+    QUIC_RESTART_FLAG_COUNT_N(quic_deterministic_replacement_connection_ids, 1,
+                              2);
   }
   QuicConnectionId new_connection_id =
       GenerateNewServerConnectionId(version, server_connection_id);
   DCHECK_EQ(expected_server_connection_id_length_, new_connection_id.length());
-  // TODO(dschinazi) Prevent connection_id_map_ from growing indefinitely
-  // before we ship a version that supports variable length connection IDs
-  // to production.
-  connection_id_map_.insert(
-      std::make_pair(server_connection_id, new_connection_id));
+  if (!GetQuicRestartFlag(quic_deterministic_replacement_connection_ids)) {
+    connection_id_map_.insert(
+        std::make_pair(server_connection_id, new_connection_id));
+  } else {
+    // Verify that GenerateNewServerConnectionId is deterministic.
+    DCHECK_EQ(new_connection_id,
+              GenerateNewServerConnectionId(version, server_connection_id));
+  }
   QUIC_DLOG(INFO) << "Replacing incoming connection ID " << server_connection_id
                   << " with " << new_connection_id;
   return new_connection_id;
@@ -324,8 +335,15 @@
 
 QuicConnectionId QuicDispatcher::GenerateNewServerConnectionId(
     ParsedQuicVersion /*version*/,
-    QuicConnectionId /*connection_id*/) const {
-  return QuicUtils::CreateRandomConnectionId();
+    QuicConnectionId connection_id) const {
+  if (!GetQuicRestartFlag(quic_deterministic_replacement_connection_ids)) {
+    return QuicUtils::CreateRandomConnectionId();
+  }
+
+  QUIC_RESTART_FLAG_COUNT_N(quic_deterministic_replacement_connection_ids, 2,
+                            2);
+
+  return QuicUtils::CreateReplacementConnectionId(connection_id);
 }
 
 bool QuicDispatcher::MaybeDispatchPacket(
diff --git a/quic/core/quic_dispatcher.h b/quic/core/quic_dispatcher.h
index f7b0450..566d3bc 100644
--- a/quic/core/quic_dispatcher.h
+++ b/quic/core/quic_dispatcher.h
@@ -153,6 +153,8 @@
   virtual bool MaybeDispatchPacket(const ReceivedPacketInfo& packet_info);
 
   // Generate a connection ID with a length that is expected by the dispatcher.
+  // Note that this MUST produce a deterministic result (calling this method
+  // with two connection IDs that are equal must produce the same result).
   virtual QuicConnectionId GenerateNewServerConnectionId(
       ParsedQuicVersion version,
       QuicConnectionId connection_id) const;
diff --git a/quic/core/quic_dispatcher_test.cc b/quic/core/quic_dispatcher_test.cc
index b41c57f..b8cc470 100644
--- a/quic/core/quic_dispatcher_test.cc
+++ b/quic/core/quic_dispatcher_test.cc
@@ -134,8 +134,13 @@
 
   QuicConnectionId GenerateNewServerConnectionId(
       ParsedQuicVersion /*version*/,
-      QuicConnectionId /*connection_id*/) const override {
-    return QuicUtils::CreateRandomConnectionId(random_);
+      QuicConnectionId connection_id) const override {
+    if (!GetQuicRestartFlag(quic_deterministic_replacement_connection_ids)) {
+      return QuicUtils::CreateRandomConnectionId(random_);
+    }
+    // TODO(dschinazi) Remove this override entirely when
+    // quic_deterministic_replacement_connection_ids is deprecated.
+    return QuicUtils::CreateReplacementConnectionId(connection_id);
   }
 
   struct TestQuicPerPacketContext : public QuicPerPacketContext {
@@ -755,8 +760,14 @@
   QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1);
 
   QuicConnectionId bad_connection_id = TestConnectionIdNineBytesLong(2);
-  QuicConnectionId fixed_connection_id =
-      QuicUtils::CreateRandomConnectionId(mock_helper_.GetRandomGenerator());
+  QuicConnectionId fixed_connection_id;
+  if (!GetQuicRestartFlag(quic_deterministic_replacement_connection_ids)) {
+    fixed_connection_id =
+        QuicUtils::CreateRandomConnectionId(mock_helper_.GetRandomGenerator());
+  } else {
+    fixed_connection_id =
+        QuicUtils::CreateReplacementConnectionId(bad_connection_id);
+  }
 
   EXPECT_CALL(*dispatcher_,
               CreateQuicSession(fixed_connection_id, client_address,
@@ -788,8 +799,14 @@
   QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1);
 
   QuicConnectionId bad_connection_id = EmptyQuicConnectionId();
-  QuicConnectionId fixed_connection_id =
-      QuicUtils::CreateRandomConnectionId(mock_helper_.GetRandomGenerator());
+  QuicConnectionId fixed_connection_id;
+  if (!GetQuicRestartFlag(quic_deterministic_replacement_connection_ids)) {
+    fixed_connection_id =
+        QuicUtils::CreateRandomConnectionId(mock_helper_.GetRandomGenerator());
+  } else {
+    fixed_connection_id =
+        QuicUtils::CreateReplacementConnectionId(bad_connection_id);
+  }
 
   // Disable validation of invalid short connection IDs.
   dispatcher_->SetAllowShortInitialServerConnectionIds(true);
@@ -825,8 +842,14 @@
 
   QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1);
   QuicConnectionId bad_connection_id = TestConnectionIdNineBytesLong(2);
-  QuicConnectionId fixed_connection_id =
-      QuicUtils::CreateRandomConnectionId(mock_helper_.GetRandomGenerator());
+  QuicConnectionId fixed_connection_id;
+  if (!GetQuicRestartFlag(quic_deterministic_replacement_connection_ids)) {
+    fixed_connection_id =
+        QuicUtils::CreateRandomConnectionId(mock_helper_.GetRandomGenerator());
+  } else {
+    fixed_connection_id =
+        QuicUtils::CreateReplacementConnectionId(bad_connection_id);
+  }
 
   EXPECT_CALL(*dispatcher_,
               CreateQuicSession(TestConnectionId(1), client_address,
diff --git a/quic/core/quic_utils.cc b/quic/core/quic_utils.cc
index 1e2c5fc..a3366da 100644
--- a/quic/core/quic_utils.cc
+++ b/quic/core/quic_utils.cc
@@ -485,6 +485,15 @@
 }
 
 // static
+QuicConnectionId QuicUtils::CreateReplacementConnectionId(
+    QuicConnectionId connection_id) {
+  const uint64_t connection_id_hash = FNV1a_64_Hash(
+      QuicStringPiece(connection_id.data(), connection_id.length()));
+  return QuicConnectionId(reinterpret_cast<const char*>(&connection_id_hash),
+                          sizeof(connection_id_hash));
+}
+
+// static
 QuicConnectionId QuicUtils::CreateRandomConnectionId() {
   return CreateRandomConnectionId(kQuicDefaultConnectionIdLength,
                                   QuicRandom::GetInstance());
diff --git a/quic/core/quic_utils.h b/quic/core/quic_utils.h
index 11154a5..de855a0 100644
--- a/quic/core/quic_utils.h
+++ b/quic/core/quic_utils.h
@@ -11,6 +11,7 @@
 
 #include "net/third_party/quiche/src/quic/core/crypto/quic_random.h"
 #include "net/third_party/quiche/src/quic/core/frames/quic_frame.h"
+#include "net/third_party/quiche/src/quic/core/quic_connection_id.h"
 #include "net/third_party/quiche/src/quic/core/quic_error_codes.h"
 #include "net/third_party/quiche/src/quic/core/quic_types.h"
 #include "net/third_party/quiche/src/quic/core/quic_versions.h"
@@ -162,6 +163,12 @@
       QuicTransportVersion version,
       Perspective perspective);
 
+  // Generates a 64bit connection ID derived from the input connection ID.
+  // This is guaranteed to be deterministic (calling this method with two
+  // connection IDs that are equal is guaranteed to produce the same result).
+  static QuicConnectionId CreateReplacementConnectionId(
+      QuicConnectionId connection_id);
+
   // Generates a random 64bit connection ID.
   static QuicConnectionId CreateRandomConnectionId();
 
diff --git a/quic/core/quic_utils_test.cc b/quic/core/quic_utils_test.cc
index d0ce7e8..dc46133 100644
--- a/quic/core/quic_utils_test.cc
+++ b/quic/core/quic_utils_test.cc
@@ -165,6 +165,47 @@
   EXPECT_FALSE(QuicUtils::IsIetfPacketShortHeader(first_byte));
 }
 
+TEST_F(QuicUtilsTest, ReplacementConnectionIdIsDeterministic) {
+  // Verify that two equal connection IDs get the same replacement.
+  QuicConnectionId connection_id64a = TestConnectionId(33);
+  QuicConnectionId connection_id64b = TestConnectionId(33);
+  EXPECT_EQ(connection_id64a, connection_id64b);
+  EXPECT_EQ(QuicUtils::CreateReplacementConnectionId(connection_id64a),
+            QuicUtils::CreateReplacementConnectionId(connection_id64b));
+  QuicConnectionId connection_id72a = TestConnectionIdNineBytesLong(42);
+  QuicConnectionId connection_id72b = TestConnectionIdNineBytesLong(42);
+  EXPECT_EQ(connection_id72a, connection_id72b);
+  EXPECT_EQ(QuicUtils::CreateReplacementConnectionId(connection_id72a),
+            QuicUtils::CreateReplacementConnectionId(connection_id72b));
+}
+
+TEST_F(QuicUtilsTest, ReplacementConnectionIdLengthIsCorrect) {
+  // Verify that all lengths get replaced by kQuicDefaultConnectionIdLength.
+  const char connection_id_bytes[kQuicMaxConnectionIdAllVersionsLength] = {};
+  for (uint8_t i = 0; i < sizeof(connection_id_bytes) - 1; ++i) {
+    QuicConnectionId connection_id(connection_id_bytes, i);
+    QuicConnectionId replacement_connection_id =
+        QuicUtils::CreateReplacementConnectionId(connection_id);
+    EXPECT_EQ(kQuicDefaultConnectionIdLength,
+              replacement_connection_id.length());
+  }
+}
+
+TEST_F(QuicUtilsTest, ReplacementConnectionIdHasEntropy) {
+  // Make sure all these test connection IDs have different replacements.
+  for (uint64_t i = 0; i < 256; ++i) {
+    QuicConnectionId connection_id_i = TestConnectionId(i);
+    EXPECT_NE(connection_id_i,
+              QuicUtils::CreateReplacementConnectionId(connection_id_i));
+    for (uint64_t j = i + 1; j <= 256; ++j) {
+      QuicConnectionId connection_id_j = TestConnectionId(j);
+      EXPECT_NE(connection_id_i, connection_id_j);
+      EXPECT_NE(QuicUtils::CreateReplacementConnectionId(connection_id_i),
+                QuicUtils::CreateReplacementConnectionId(connection_id_j));
+    }
+  }
+}
+
 TEST_F(QuicUtilsTest, RandomConnectionId) {
   MockRandom random(33);
   QuicConnectionId connection_id = QuicUtils::CreateRandomConnectionId(&random);