Add new & retire ConnectionId interface to QuicSession::visitor and implement them in (Gfe)QuicDispatcher.

Protected by FLAGS_quic_restart_flag_quic_dispatcher_support_multiple_cid_per_connection.

PiperOrigin-RevId: 350347306
Change-Id: I4f0292f0d56cbcfe19854f09285d9389f729ce08
diff --git a/quic/core/quic_dispatcher.cc b/quic/core/quic_dispatcher.cc
index 8003dbb..28f3923 100644
--- a/quic/core/quic_dispatcher.cc
+++ b/quic/core/quic_dispatcher.cc
@@ -13,6 +13,7 @@
 #include "quic/core/chlo_extractor.h"
 #include "quic/core/crypto/crypto_protocol.h"
 #include "quic/core/crypto/quic_random.h"
+#include "quic/core/quic_connection_id.h"
 #include "quic/core/quic_error_codes.h"
 #include "quic/core/quic_session.h"
 #include "quic/core/quic_time_wait_list_manager.h"
@@ -154,14 +155,16 @@
   // |error_code| and |error_details| and add the connection to time wait.
   void CloseConnection(QuicErrorCode error_code,
                        const std::string& error_details,
-                       bool ietf_quic) {
+                       bool ietf_quic,
+                       std::vector<QuicConnectionId> active_connection_ids) {
     SerializeConnectionClosePacket(error_code, error_details);
 
     time_wait_list_manager_->AddConnectionIdToTimeWait(
         server_connection_id_,
         QuicTimeWaitListManager::SEND_TERMINATION_PACKETS,
         TimeWaitConnectionInfo(ietf_quic, collector_.packets(),
-                               {server_connection_id_}));
+                               std::move(active_connection_ids),
+                               /*srtt=*/QuicTime::Delta::Zero()));
   }
 
  private:
@@ -323,6 +326,10 @@
   if (use_reference_counted_session_map_) {
     QUIC_RESTART_FLAG_COUNT(quic_use_reference_counted_sesssion_map);
   }
+  if (support_multiple_cid_per_connection_) {
+    QUIC_RESTART_FLAG_COUNT(
+        quic_dispatcher_support_multiple_cid_per_connection);
+  }
   QUIC_BUG_IF(GetSupportedVersions().empty())
       << "Trying to create dispatcher without any supported versions";
   QUIC_DLOG(INFO) << "Created QuicDispatcher with versions: "
@@ -333,6 +340,9 @@
   if (use_reference_counted_session_map_) {
     reference_counted_session_map_.clear();
     closed_ref_counted_session_list_.clear();
+    if (support_multiple_cid_per_connection_) {
+      num_sessions_in_session_map_ = 0;
+    }
   } else {
     session_map_.clear();
     closed_session_list_.clear();
@@ -807,22 +817,35 @@
       } else {
         QUIC_CODE_COUNT(quic_v44_add_to_time_wait_list_with_handshake_failed);
       }
-      action = QuicTimeWaitListManager::SEND_TERMINATION_PACKETS;
-      // This serializes a connection close termination packet with error code
-      // QUIC_HANDSHAKE_FAILED and adds the connection to the time wait list.
-      StatelesslyTerminateConnection(
-          connection->connection_id(),
-          connection->version().HasIetfInvariantHeader()
-              ? IETF_QUIC_LONG_HEADER_PACKET
-              : GOOGLE_QUIC_PACKET,
-          /*version_flag=*/true,
-          connection->version().HasLengthPrefixedConnectionIds(),
-          connection->version(), QUIC_HANDSHAKE_FAILED,
-          "Connection is closed by server before handshake confirmed",
-          // Although it is our intention to send termination packets, the
-          // |action| argument is not used by this call to
-          // StatelesslyTerminateConnection().
-          action);
+      if (support_multiple_cid_per_connection_) {
+        // This serializes a connection close termination packet with error code
+        // QUIC_HANDSHAKE_FAILED and adds the connection to the time wait list.
+        StatelessConnectionTerminator terminator(
+            server_connection_id, connection->version(), helper_.get(),
+            time_wait_list_manager_.get());
+        terminator.CloseConnection(
+            QUIC_HANDSHAKE_FAILED,
+            "Connection is closed by server before handshake confirmed",
+            connection->version().HasIetfInvariantHeader(),
+            connection->GetActiveServerConnectionIds());
+      } else {
+        action = QuicTimeWaitListManager::SEND_TERMINATION_PACKETS;
+        // This serializes a connection close termination packet with error code
+        // QUIC_HANDSHAKE_FAILED and adds the connection to the time wait list.
+        StatelesslyTerminateConnection(
+            connection->connection_id(),
+            connection->version().HasIetfInvariantHeader()
+                ? IETF_QUIC_LONG_HEADER_PACKET
+                : GOOGLE_QUIC_PACKET,
+            /*version_flag=*/true,
+            connection->version().HasLengthPrefixedConnectionIds(),
+            connection->version(), QUIC_HANDSHAKE_FAILED,
+            "Connection is closed by server before handshake confirmed",
+            // Although it is our intention to send termination packets, the
+            // |action| argument is not used by this call to
+            // StatelesslyTerminateConnection().
+            action);
+      }
       return;
     }
     QUIC_CODE_COUNT(quic_v44_add_to_time_wait_list_with_stateless_reset);
@@ -1001,7 +1024,15 @@
       closed_ref_counted_session_list_.push_back(std::move(it->second));
     }
     CleanUpSession(it->first, connection, source);
-    reference_counted_session_map_.erase(it);
+    if (support_multiple_cid_per_connection_) {
+      for (const QuicConnectionId& cid :
+           connection->GetActiveServerConnectionIds()) {
+        reference_counted_session_map_.erase(cid);
+      }
+      --num_sessions_in_session_map_;
+    } else {
+      reference_counted_session_map_.erase(it);
+    }
   } else {
     auto it = session_map_.find(server_connection_id);
     if (it == session_map_.end()) {
@@ -1052,6 +1083,29 @@
 void QuicDispatcher::OnStopSendingReceived(
     const QuicStopSendingFrame& /*frame*/) {}
 
+void QuicDispatcher::OnNewConnectionIdSent(
+    const QuicConnectionId& server_connection_id,
+    const QuicConnectionId& new_connection_id) {
+  DCHECK(support_multiple_cid_per_connection_);
+  auto it = reference_counted_session_map_.find(server_connection_id);
+  if (it == reference_counted_session_map_.end()) {
+    QUIC_BUG << "Couldn't locate the session that issues the connection ID in "
+                "reference_counted_session_map_.  server_connection_id:"
+             << server_connection_id
+             << " new_connection_id: " << new_connection_id;
+    return;
+  }
+  auto insertion_result = reference_counted_session_map_.insert(
+      std::make_pair(new_connection_id, it->second));
+  DCHECK(insertion_result.second);
+}
+
+void QuicDispatcher::OnConnectionIdRetired(
+    const QuicConnectionId& server_connection_id) {
+  DCHECK(support_multiple_cid_per_connection_);
+  reference_counted_session_map_.erase(server_connection_id);
+}
+
 void QuicDispatcher::OnConnectionAddedToTimeWaitList(
     QuicConnectionId server_connection_id) {
   QUIC_DLOG(INFO) << "Connection " << server_connection_id
@@ -1091,8 +1145,9 @@
                                              helper_.get(),
                                              time_wait_list_manager_.get());
     // This also adds the connection to time wait list.
-    terminator.CloseConnection(error_code, error_details,
-                               format != GOOGLE_QUIC_PACKET);
+    terminator.CloseConnection(
+        error_code, error_details, format != GOOGLE_QUIC_PACKET,
+        /*active_connection_ids=*/{server_connection_id});
     return;
   }
 
@@ -1164,10 +1219,14 @@
       auto insertion_result = reference_counted_session_map_.insert(
           std::make_pair(server_connection_id,
                          std::shared_ptr<QuicSession>(std::move(session))));
-      QUIC_BUG_IF(!insertion_result.second)
-          << "Tried to add a session to session_map with existing connection "
-             "id: "
-          << server_connection_id;
+      if (!insertion_result.second) {
+        QUIC_BUG
+            << "Tried to add a session to session_map with existing connection "
+               "id: "
+            << server_connection_id;
+      } else if (support_multiple_cid_per_connection_) {
+        ++num_sessions_in_session_map_;
+      }
       DeliverPacketsToSession(packets, insertion_result.first->second.get());
     } else {
       auto insertion_result = session_map_.insert(
@@ -1279,9 +1338,13 @@
         reference_counted_session_map_.insert(std::make_pair(
             packet_info->destination_connection_id,
             std::shared_ptr<QuicSession>(std::move(session.release()))));
-    QUIC_BUG_IF(!insertion_result.second)
-        << "Tried to add a session to session_map with existing connection id: "
-        << packet_info->destination_connection_id;
+    if (!insertion_result.second) {
+      QUIC_BUG << "Tried to add a session to session_map with existing "
+                  "connection id: "
+               << packet_info->destination_connection_id;
+    } else if (support_multiple_cid_per_connection_) {
+      ++num_sessions_in_session_map_;
+    }
     session_ptr = insertion_result.first->second.get();
   } else {
     auto insertion_result = session_map_.insert(std::make_pair(
@@ -1365,4 +1428,13 @@
       packet_info.form != GOOGLE_QUIC_PACKET, GetPerPacketContext());
 }
 
+size_t QuicDispatcher::NumSessions() const {
+  if (support_multiple_cid_per_connection_) {
+    return num_sessions_in_session_map_;
+  }
+  return use_reference_counted_session_map_
+             ? reference_counted_session_map_.size()
+             : session_map_.size();
+}
+
 }  // namespace quic
diff --git a/quic/core/quic_dispatcher.h b/quic/core/quic_dispatcher.h
index 756f7f6..75cdacf 100644
--- a/quic/core/quic_dispatcher.h
+++ b/quic/core/quic_dispatcher.h
@@ -100,6 +100,19 @@
   // Collects reset error code received on streams.
   void OnStopSendingReceived(const QuicStopSendingFrame& frame) override;
 
+  // QuicSession::Visitor interface implementation (via inheritance of
+  // QuicTimeWaitListManager::Visitor):
+  // Add the newly issued connection ID to the session map.
+  void OnNewConnectionIdSent(
+      const QuicConnectionId& server_connection_id,
+      const QuicConnectionId& new_connection_id) override;
+
+  // QuicSession::Visitor interface implementation (via inheritance of
+  // QuicTimeWaitListManager::Visitor):
+  // Remove the retired connection ID from the session map.
+  void OnConnectionIdRetired(
+      const QuicConnectionId& server_connection_id) override;
+
   // QuicTimeWaitListManager::Visitor interface implementation
   // Called whenever the time wait list manager adds a new connection to the
   // time-wait list.
@@ -114,13 +127,7 @@
                                                  std::shared_ptr<QuicSession>,
                                                  QuicConnectionIdHash>;
 
-  // TODO(haoyuewang) Update this function when multiple CIDs per connection are
-  // supported.
-  size_t NumSessions() const {
-    return use_reference_counted_session_map_
-               ? reference_counted_session_map_.size()
-               : session_map_.size();
-  }
+  size_t NumSessions() const;
 
   const SessionMap& session_map() const { return session_map_; }
 
@@ -161,6 +168,10 @@
     return use_reference_counted_session_map_;
   }
 
+  bool support_multiple_cid_per_connection() const {
+    return support_multiple_cid_per_connection_;
+  }
+
  protected:
   virtual std::unique_ptr<QuicSession> CreateQuicSession(
       QuicConnectionId server_connection_id,
@@ -412,6 +423,9 @@
   // TODO(fayang): consider removing last_error_.
   QuicErrorCode last_error_;
 
+  // Number of unique session in session map.
+  size_t num_sessions_in_session_map_ = 0;
+
   // A backward counter of how many new sessions can be create within current
   // event loop. When reaches 0, it means can't create sessions for now.
   int16_t new_sessions_allowed_per_event_loop_;
@@ -439,6 +453,10 @@
 
   const bool use_reference_counted_session_map_ =
       GetQuicRestartFlag(quic_use_reference_counted_sesssion_map);
+  const bool support_multiple_cid_per_connection_ =
+      use_reference_counted_session_map_ &&
+      GetQuicRestartFlag(quic_time_wait_list_support_multiple_cid) &&
+      GetQuicRestartFlag(quic_dispatcher_support_multiple_cid_per_connection);
 };
 
 }  // namespace quic
diff --git a/quic/core/quic_dispatcher_test.cc b/quic/core/quic_dispatcher_test.cc
index a28dca4..a614199 100644
--- a/quic/core/quic_dispatcher_test.cc
+++ b/quic/core/quic_dispatcher_test.cc
@@ -16,7 +16,9 @@
 #include "quic/core/crypto/crypto_protocol.h"
 #include "quic/core/crypto/quic_crypto_server_config.h"
 #include "quic/core/crypto/quic_random.h"
+#include "quic/core/frames/quic_new_connection_id_frame.h"
 #include "quic/core/quic_config.h"
+#include "quic/core/quic_connection.h"
 #include "quic/core/quic_connection_id.h"
 #include "quic/core/quic_crypto_stream.h"
 #include "quic/core/quic_packet_writer_wrapper.h"
@@ -33,6 +35,7 @@
 #include "quic/test_tools/first_flight.h"
 #include "quic/test_tools/mock_quic_time_wait_list_manager.h"
 #include "quic/test_tools/quic_buffered_packet_store_peer.h"
+#include "quic/test_tools/quic_connection_peer.h"
 #include "quic/test_tools/quic_crypto_server_config_peer.h"
 #include "quic/test_tools/quic_dispatcher_peer.h"
 #include "quic/test_tools/quic_test_utils.h"
@@ -183,7 +186,26 @@
                            helper,
                            alarm_factory,
                            Perspective::IS_SERVER),
-        dispatcher_(dispatcher) {}
+        dispatcher_(dispatcher),
+        active_connection_ids_({connection_id}) {}
+
+  void AddNewConnectionId(QuicConnectionId id) {
+    dispatcher_->OnNewConnectionIdSent(active_connection_ids_.back(), id);
+    QuicConnectionPeer::SetServerConnectionId(this, id);
+    active_connection_ids_.push_back(id);
+  }
+
+  void RetireConnectionId(QuicConnectionId id) {
+    auto it = std::find(active_connection_ids_.begin(),
+                        active_connection_ids_.end(), id);
+    DCHECK(it != active_connection_ids_.end());
+    dispatcher_->OnConnectionIdRetired(id);
+    active_connection_ids_.erase(it);
+  }
+
+  std::vector<QuicConnectionId> GetActiveServerConnectionIds() const override {
+    return active_connection_ids_;
+  }
 
   void UnregisterOnConnectionClosed() {
     QUIC_LOG(ERROR) << "Unregistering " << connection_id();
@@ -194,6 +216,7 @@
 
  private:
   QuicDispatcher* dispatcher_;
+  std::vector<QuicConnectionId> active_connection_ids_;
 };
 
 class QuicDispatcherTestBase : public QuicTestWithParam<ParsedQuicVersion> {
@@ -1945,6 +1968,190 @@
   MarkSession1Deleted();
 }
 
+class QuicDispatcherSupportMultipleConnectionIdPerConnectionTest
+    : public QuicDispatcherTestBase {
+ public:
+  QuicDispatcherSupportMultipleConnectionIdPerConnectionTest()
+      : QuicDispatcherTestBase(crypto_test_utils::ProofSourceForTesting()) {
+    SetQuicRestartFlag(quic_use_reference_counted_sesssion_map, true);
+    SetQuicRestartFlag(quic_time_wait_list_support_multiple_cid, true);
+    SetQuicRestartFlag(quic_dispatcher_support_multiple_cid_per_connection,
+                       true);
+    dispatcher_ = std::make_unique<NiceMock<TestDispatcher>>(
+        &config_, &crypto_config_, &version_manager_,
+        mock_helper_.GetRandomGenerator());
+  }
+  void AddConnection1() {
+    QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1);
+    EXPECT_CALL(*dispatcher_,
+                CreateQuicSession(_, _, client_address, Eq(ExpectedAlpn()), _))
+        .WillOnce(Return(ByMove(CreateSession(
+            dispatcher_.get(), config_, TestConnectionId(1), client_address,
+            &helper_, &alarm_factory_, &crypto_config_,
+            QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_))));
+    EXPECT_CALL(*reinterpret_cast<MockQuicConnection*>(session1_->connection()),
+                ProcessUdpPacket(_, _, _))
+        .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) {
+          ValidatePacket(TestConnectionId(1), packet);
+        })));
+    EXPECT_CALL(*dispatcher_,
+                ShouldCreateOrBufferPacketForConnection(
+                    ReceivedPacketInfoConnectionIdEquals(TestConnectionId(1))));
+    ProcessFirstFlight(client_address, TestConnectionId(1));
+  }
+
+  void AddConnection2() {
+    QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 2);
+    EXPECT_CALL(*dispatcher_,
+                CreateQuicSession(_, _, client_address, Eq(ExpectedAlpn()), _))
+        .WillOnce(Return(ByMove(CreateSession(
+            dispatcher_.get(), config_, TestConnectionId(2), client_address,
+            &helper_, &alarm_factory_, &crypto_config_,
+            QuicDispatcherPeer::GetCache(dispatcher_.get()), &session2_))));
+    EXPECT_CALL(*reinterpret_cast<MockQuicConnection*>(session2_->connection()),
+                ProcessUdpPacket(_, _, _))
+        .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) {
+          ValidatePacket(TestConnectionId(2), packet);
+        })));
+    EXPECT_CALL(*dispatcher_,
+                ShouldCreateOrBufferPacketForConnection(
+                    ReceivedPacketInfoConnectionIdEquals(TestConnectionId(2))));
+    ProcessFirstFlight(client_address, TestConnectionId(2));
+  }
+
+ protected:
+  MockQuicConnectionHelper helper_;
+  MockAlarmFactory alarm_factory_;
+};
+
+INSTANTIATE_TEST_SUITE_P(
+    QuicDispatcherSupportMultipleConnectionIdPerConnectionTests,
+    QuicDispatcherSupportMultipleConnectionIdPerConnectionTest,
+    ::testing::Values(CurrentSupportedVersions().front()),
+    ::testing::PrintToStringParamName());
+
+TEST_P(QuicDispatcherSupportMultipleConnectionIdPerConnectionTest,
+       OnNewConnectionIdSent) {
+  AddConnection1();
+  ASSERT_EQ(dispatcher_->NumSessions(), 1u);
+  ASSERT_THAT(session1_, testing::NotNull());
+  MockServerConnection* mock_server_connection1 =
+      reinterpret_cast<MockServerConnection*>(connection1());
+
+  {
+    mock_server_connection1->AddNewConnectionId(TestConnectionId(3));
+    EXPECT_EQ(dispatcher_->NumSessions(), 1u);
+    auto* session =
+        QuicDispatcherPeer::FindSession(dispatcher_.get(), TestConnectionId(3));
+    ASSERT_EQ(session, session1_);
+  }
+
+  {
+    mock_server_connection1->AddNewConnectionId(TestConnectionId(4));
+    EXPECT_EQ(dispatcher_->NumSessions(), 1u);
+    auto* session =
+        QuicDispatcherPeer::FindSession(dispatcher_.get(), TestConnectionId(4));
+    ASSERT_EQ(session, session1_);
+  }
+
+  EXPECT_CALL(*connection1(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _));
+  // Would timed out unless all sessions have been removed from the session map.
+  dispatcher_->Shutdown();
+}
+
+TEST_P(QuicDispatcherSupportMultipleConnectionIdPerConnectionTest,
+       RetireConnectionIdFromSingleConnection) {
+  AddConnection1();
+  ASSERT_EQ(dispatcher_->NumSessions(), 1u);
+  ASSERT_THAT(session1_, testing::NotNull());
+  MockServerConnection* mock_server_connection1 =
+      reinterpret_cast<MockServerConnection*>(connection1());
+
+  // Adds 1 new connection id every turn and retires 2 connection ids every
+  // other turn.
+  for (int i = 2; i < 10; ++i) {
+    mock_server_connection1->AddNewConnectionId(TestConnectionId(i));
+    ASSERT_EQ(
+        QuicDispatcherPeer::FindSession(dispatcher_.get(), TestConnectionId(i)),
+        session1_);
+    ASSERT_EQ(QuicDispatcherPeer::FindSession(dispatcher_.get(),
+                                              TestConnectionId(i - 1)),
+              session1_);
+    EXPECT_EQ(dispatcher_->NumSessions(), 1u);
+    if (i % 2 == 1) {
+      mock_server_connection1->RetireConnectionId(TestConnectionId(i - 2));
+      mock_server_connection1->RetireConnectionId(TestConnectionId(i - 1));
+    }
+  }
+
+  EXPECT_CALL(*connection1(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _));
+  // Would timed out unless all sessions have been removed from the session map.
+  dispatcher_->Shutdown();
+}
+
+TEST_P(QuicDispatcherSupportMultipleConnectionIdPerConnectionTest,
+       RetireConnectionIdFromMultipleConnections) {
+  AddConnection1();
+  AddConnection2();
+  ASSERT_EQ(dispatcher_->NumSessions(), 2u);
+  MockServerConnection* mock_server_connection1 =
+      reinterpret_cast<MockServerConnection*>(connection1());
+  MockServerConnection* mock_server_connection2 =
+      reinterpret_cast<MockServerConnection*>(connection2());
+
+  for (int i = 2; i < 10; ++i) {
+    mock_server_connection1->AddNewConnectionId(TestConnectionId(2 * i - 1));
+    mock_server_connection2->AddNewConnectionId(TestConnectionId(2 * i));
+    ASSERT_EQ(QuicDispatcherPeer::FindSession(dispatcher_.get(),
+                                              TestConnectionId(2 * i - 1)),
+              session1_);
+    ASSERT_EQ(QuicDispatcherPeer::FindSession(dispatcher_.get(),
+                                              TestConnectionId(2 * i)),
+              session2_);
+    EXPECT_EQ(dispatcher_->NumSessions(), 2u);
+    mock_server_connection1->RetireConnectionId(TestConnectionId(2 * i - 3));
+    mock_server_connection2->RetireConnectionId(TestConnectionId(2 * i - 2));
+  }
+
+  mock_server_connection1->AddNewConnectionId(TestConnectionId(19));
+  mock_server_connection2->AddNewConnectionId(TestConnectionId(20));
+  EXPECT_CALL(*connection1(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _));
+  EXPECT_CALL(*connection2(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _));
+  // Would timed out unless all sessions have been removed from the session map.
+  dispatcher_->Shutdown();
+}
+
+TEST_P(QuicDispatcherSupportMultipleConnectionIdPerConnectionTest,
+       TimeWaitListPoplulateCorrectly) {
+  QuicTimeWaitListManager* time_wait_list_manager =
+      QuicDispatcherPeer::GetTimeWaitListManager(dispatcher_.get());
+  AddConnection1();
+  MockServerConnection* mock_server_connection1 =
+      reinterpret_cast<MockServerConnection*>(connection1());
+
+  mock_server_connection1->AddNewConnectionId(TestConnectionId(2));
+  mock_server_connection1->AddNewConnectionId(TestConnectionId(3));
+  mock_server_connection1->AddNewConnectionId(TestConnectionId(4));
+  mock_server_connection1->RetireConnectionId(TestConnectionId(1));
+  mock_server_connection1->RetireConnectionId(TestConnectionId(2));
+
+  EXPECT_CALL(*connection1(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _));
+  connection1()->CloseConnection(
+      QUIC_PEER_GOING_AWAY, "Close for testing",
+      ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET);
+
+  EXPECT_FALSE(
+      time_wait_list_manager->IsConnectionIdInTimeWait(TestConnectionId(1)));
+  EXPECT_FALSE(
+      time_wait_list_manager->IsConnectionIdInTimeWait(TestConnectionId(2)));
+  EXPECT_TRUE(
+      time_wait_list_manager->IsConnectionIdInTimeWait(TestConnectionId(3)));
+  EXPECT_TRUE(
+      time_wait_list_manager->IsConnectionIdInTimeWait(TestConnectionId(4)));
+
+  dispatcher_->Shutdown();
+}
+
 class BufferedPacketStoreTest : public QuicDispatcherTestBase {
  public:
   BufferedPacketStoreTest()
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h
index fd5da05..0c9dd37 100644
--- a/quic/core/quic_flags_list.h
+++ b/quic/core/quic_flags_list.h
@@ -67,6 +67,7 @@
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_use_write_or_buffer_data_at_level, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_send_quic_fallback_server_config_on_leto_error, false)
 QUIC_FLAG(FLAGS_quic_restart_flag_dont_fetch_quic_private_keys_from_leto, false)
+QUIC_FLAG(FLAGS_quic_restart_flag_quic_dispatcher_support_multiple_cid_per_connection, false)
 QUIC_FLAG(FLAGS_quic_restart_flag_quic_enable_zero_rtt_for_tls_v2, true)
 QUIC_FLAG(FLAGS_quic_restart_flag_quic_offload_pacing_to_usps2, false)
 QUIC_FLAG(FLAGS_quic_restart_flag_quic_server_temporarily_retain_tls_zero_rtt_keys, true)
diff --git a/quic/core/quic_session.h b/quic/core/quic_session.h
index 5903830..3eb2867 100644
--- a/quic/core/quic_session.h
+++ b/quic/core/quic_session.h
@@ -81,6 +81,15 @@
     // Called when the session receives a STOP_SENDING for a stream from the
     // peer.
     virtual void OnStopSendingReceived(const QuicStopSendingFrame& frame) = 0;
+
+    // Called when a NewConnectionId frame has been sent.
+    virtual void OnNewConnectionIdSent(
+        const QuicConnectionId& server_connection_id,
+        const QuicConnectionId& new_connection_id) = 0;
+
+    // Called when a ConnectionId has been retired.
+    virtual void OnConnectionIdRetired(
+        const QuicConnectionId& server_connection_id) = 0;
   };
 
   // Does not take ownership of |connection| or |visitor|.
diff --git a/quic/test_tools/mock_quic_session_visitor.h b/quic/test_tools/mock_quic_session_visitor.h
index e04bc0f..77b1763 100644
--- a/quic/test_tools/mock_quic_session_visitor.h
+++ b/quic/test_tools/mock_quic_session_visitor.h
@@ -35,6 +35,15 @@
               (const QuicStopSendingFrame& frame),
               (override));
   MOCK_METHOD(void,
+              OnNewConnectionIdSent,
+              (const QuicConnectionId& server_connection_id,
+               const QuicConnectionId& new_connection_id),
+              (override));
+  MOCK_METHOD(void,
+              OnConnectionIdRetired,
+              (const quic::QuicConnectionId& server_connection_id),
+              (override));
+  MOCK_METHOD(void,
               OnConnectionAddedToTimeWaitList,
               (QuicConnectionId connection_id),
               (override));
diff --git a/quic/test_tools/quic_dispatcher_peer.cc b/quic/test_tools/quic_dispatcher_peer.cc
index 846755f..5736773 100644
--- a/quic/test_tools/quic_dispatcher_peer.cc
+++ b/quic/test_tools/quic_dispatcher_peer.cc
@@ -129,5 +129,19 @@
   }
 }
 
+// static
+const QuicSession* QuicDispatcherPeer::FindSession(
+    const QuicDispatcher* dispatcher,
+    QuicConnectionId id) {
+  if (dispatcher->use_reference_counted_session_map()) {
+    auto it = dispatcher->reference_counted_session_map_.find(id);
+    return (it == dispatcher->reference_counted_session_map_.end())
+               ? nullptr
+               : it->second.get();
+  }
+  auto it = dispatcher->session_map_.find(id);
+  return (it == dispatcher->session_map_.end()) ? nullptr : it->second.get();
+}
+
 }  // namespace test
 }  // namespace quic
diff --git a/quic/test_tools/quic_dispatcher_peer.h b/quic/test_tools/quic_dispatcher_peer.h
index 64c8a83..f3be472 100644
--- a/quic/test_tools/quic_dispatcher_peer.h
+++ b/quic/test_tools/quic_dispatcher_peer.h
@@ -5,6 +5,7 @@
 #ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_DISPATCHER_PEER_H_
 #define QUICHE_QUIC_TEST_TOOLS_QUIC_DISPATCHER_PEER_H_
 
+#include "quic/core/quic_connection_id.h"
 #include "quic/core/quic_dispatcher.h"
 
 namespace quic {
@@ -70,6 +71,10 @@
   // Get the first session in the session map. Returns nullptr if the map is
   // empty.
   static QuicSession* GetFirstSessionIfAny(QuicDispatcher* dispatcher);
+
+  // Find the corresponding session if exsits.
+  static const QuicSession* FindSession(const QuicDispatcher* dispatcher,
+                                        QuicConnectionId id);
 };
 
 }  // namespace test