Replaces session_map_ with reference_counted_session_map_ in QuicDispatcher. This is to support multiple connection IDs mapping to the same session in follow-up CLs (CID and session is still 1-to-1 in this CL).

Protected by FLAGS_quic_restart_flag_quic_use_reference_counted_sesssion_map.

PiperOrigin-RevId: 348723093
Change-Id: I7760576cd310ec5217787b3ed8f9707b1141e624
diff --git a/quic/core/quic_dispatcher.cc b/quic/core/quic_dispatcher.cc
index 0b0707a..631d8d9 100644
--- a/quic/core/quic_dispatcher.cc
+++ b/quic/core/quic_dispatcher.cc
@@ -8,6 +8,7 @@
 #include <string>
 #include <utility>
 
+#include "absl/container/flat_hash_set.h"
 #include "absl/strings/string_view.h"
 #include "net/third_party/quiche/src/quic/core/chlo_extractor.h"
 #include "net/third_party/quiche/src/quic/core/crypto/crypto_protocol.h"
@@ -318,6 +319,9 @@
       expected_server_connection_id_length_(
           expected_server_connection_id_length),
       should_update_expected_server_connection_id_length_(false) {
+  if (use_reference_counted_session_map_) {
+    QUIC_RESTART_FLAG_COUNT(quic_use_reference_counted_sesssion_map);
+  }
   QUIC_BUG_IF(GetSupportedVersions().empty())
       << "Trying to create dispatcher without any supported versions";
   QUIC_DLOG(INFO) << "Created QuicDispatcher with versions: "
@@ -325,8 +329,13 @@
 }
 
 QuicDispatcher::~QuicDispatcher() {
-  session_map_.clear();
-  closed_session_list_.clear();
+  if (use_reference_counted_session_map_) {
+    reference_counted_session_map_.clear();
+    closed_ref_counted_session_list_.clear();
+  } else {
+    session_map_.clear();
+    closed_session_list_.clear();
+  }
 }
 
 void QuicDispatcher::InitializeWithWriter(QuicPacketWriter* writer) {
@@ -491,29 +500,58 @@
 
   // Packets with connection IDs for active connections are processed
   // immediately.
-  auto it = session_map_.find(server_connection_id);
-  if (it != session_map_.end()) {
-    DCHECK(!buffered_packets_.HasBufferedPackets(server_connection_id));
-    if (packet_info.version_flag &&
-        packet_info.version != it->second->version() &&
-        packet_info.version == LegacyVersionForEncapsulation()) {
-      // This packet is using the Legacy Version Encapsulation version but the
-      // corresponding session isn't, attempt extraction of inner packet.
-      ChloAlpnExtractor alpn_extractor;
-      if (ChloExtractor::Extract(packet_info.packet, packet_info.version,
-                                 config_->create_session_tag_indicators(),
-                                 &alpn_extractor,
-                                 server_connection_id.length())) {
-        if (MaybeHandleLegacyVersionEncapsulation(this, &alpn_extractor,
-                                                  packet_info)) {
-          return true;
+  if (use_reference_counted_session_map_) {
+    auto it = reference_counted_session_map_.find(server_connection_id);
+    if (it != reference_counted_session_map_.end()) {
+      DCHECK(!buffered_packets_.HasBufferedPackets(server_connection_id));
+      if (packet_info.version_flag &&
+          packet_info.version != it->second->version() &&
+          packet_info.version == LegacyVersionForEncapsulation()) {
+        // This packet is using the Legacy Version Encapsulation version but the
+        // corresponding session isn't, attempt extraction of inner packet.
+        ChloAlpnExtractor alpn_extractor;
+        if (ChloExtractor::Extract(packet_info.packet, packet_info.version,
+                                   config_->create_session_tag_indicators(),
+                                   &alpn_extractor,
+                                   server_connection_id.length())) {
+          if (MaybeHandleLegacyVersionEncapsulation(this, &alpn_extractor,
+                                                    packet_info)) {
+            return true;
+          }
         }
       }
+      it->second->ProcessUdpPacket(packet_info.self_address,
+                                   packet_info.peer_address,
+                                   packet_info.packet);
+      return true;
     }
-    it->second->ProcessUdpPacket(packet_info.self_address,
-                                 packet_info.peer_address, packet_info.packet);
-    return true;
-  } else if (packet_info.version.IsKnown()) {
+  } else {
+    auto it = session_map_.find(server_connection_id);
+    if (it != session_map_.end()) {
+      DCHECK(!buffered_packets_.HasBufferedPackets(server_connection_id));
+      if (packet_info.version_flag &&
+          packet_info.version != it->second->version() &&
+          packet_info.version == LegacyVersionForEncapsulation()) {
+        // This packet is using the Legacy Version Encapsulation version but the
+        // corresponding session isn't, attempt extraction of inner packet.
+        ChloAlpnExtractor alpn_extractor;
+        if (ChloExtractor::Extract(packet_info.packet, packet_info.version,
+                                   config_->create_session_tag_indicators(),
+                                   &alpn_extractor,
+                                   server_connection_id.length())) {
+          if (MaybeHandleLegacyVersionEncapsulation(this, &alpn_extractor,
+                                                    packet_info)) {
+            return true;
+          }
+        }
+      }
+      it->second->ProcessUdpPacket(packet_info.self_address,
+                                   packet_info.peer_address,
+                                   packet_info.packet);
+      return true;
+    }
+  }
+  if (packet_info.version.IsKnown()) {
     // We did not find the connection ID, check if we've replaced it.
     // This is only performed for supported versions because packets with
     // unsupported versions can flow through this function in order to send
@@ -523,14 +561,26 @@
     QuicConnectionId replaced_connection_id = MaybeReplaceServerConnectionId(
         server_connection_id, packet_info.version);
     if (replaced_connection_id != server_connection_id) {
-      // Search for the replacement.
-      auto it2 = session_map_.find(replaced_connection_id);
-      if (it2 != session_map_.end()) {
-        DCHECK(!buffered_packets_.HasBufferedPackets(replaced_connection_id));
-        it2->second->ProcessUdpPacket(packet_info.self_address,
-                                      packet_info.peer_address,
-                                      packet_info.packet);
-        return true;
+      if (use_reference_counted_session_map_) {
+        // Search for the replacement.
+        auto it2 = reference_counted_session_map_.find(replaced_connection_id);
+        if (it2 != reference_counted_session_map_.end()) {
+          DCHECK(!buffered_packets_.HasBufferedPackets(replaced_connection_id));
+          it2->second->ProcessUdpPacket(packet_info.self_address,
+                                        packet_info.peer_address,
+                                        packet_info.packet);
+          return true;
+        }
+      } else {
+        // Search for the replacement.
+        auto it2 = session_map_.find(replaced_connection_id);
+        if (it2 != session_map_.end()) {
+          DCHECK(!buffered_packets_.HasBufferedPackets(replaced_connection_id));
+          it2->second->ProcessUdpPacket(packet_info.self_address,
+                                        packet_info.peer_address,
+                                        packet_info.packet);
+          return true;
+        }
       }
     }
   }
@@ -797,26 +847,68 @@
 
 void QuicDispatcher::PerformActionOnActiveSessions(
     std::function<void(QuicSession*)> operation) const {
-  for (auto const& kv : session_map_) {
-    operation(kv.second.get());
+  if (use_reference_counted_session_map_) {
+    absl::flat_hash_set<QuicSession*> visited_session;
+    visited_session.reserve(reference_counted_session_map_.size());
+    for (auto const& kv : reference_counted_session_map_) {
+      QuicSession* session = kv.second.get();
+      if (visited_session.insert(session).second) {
+        operation(session);
+      }
+    }
+  } else {
+    for (auto const& kv : session_map_) {
+      operation(kv.second.get());
+    }
   }
 }
 
+// Get a snapshot of all sessions.
+std::vector<std::shared_ptr<QuicSession>> QuicDispatcher::GetSessionsSnapshot()
+    const {
+  DCHECK(use_reference_counted_session_map_);
+  std::vector<std::shared_ptr<QuicSession>> snapshot;
+  snapshot.reserve(reference_counted_session_map_.size());
+  absl::flat_hash_set<QuicSession*> visited_session;
+  visited_session.reserve(reference_counted_session_map_.size());
+  for (auto const& kv : reference_counted_session_map_) {
+    QuicSession* session = kv.second.get();
+    if (visited_session.insert(session).second) {
+      snapshot.push_back(kv.second);
+    }
+  }
+  return snapshot;
+}
+
 std::unique_ptr<QuicPerPacketContext> QuicDispatcher::GetPerPacketContext()
     const {
   return nullptr;
 }
 
 void QuicDispatcher::DeleteSessions() {
-  if (!write_blocked_list_.empty()) {
-    for (const std::unique_ptr<QuicSession>& session : closed_session_list_) {
-      if (write_blocked_list_.erase(session->connection()) != 0) {
-        QUIC_BUG << "QuicConnection was in WriteBlockedList before destruction "
-                 << session->connection()->connection_id();
+  if (use_reference_counted_session_map_) {
+    if (!write_blocked_list_.empty()) {
+      for (const auto& session : closed_ref_counted_session_list_) {
+        if (write_blocked_list_.erase(session->connection()) != 0) {
+          QUIC_BUG
+              << "QuicConnection was in WriteBlockedList before destruction "
+              << session->connection()->connection_id();
+        }
       }
     }
+    closed_ref_counted_session_list_.clear();
+  } else {
+    if (!write_blocked_list_.empty()) {
+      for (const std::unique_ptr<QuicSession>& session : closed_session_list_) {
+        if (write_blocked_list_.erase(session->connection()) != 0) {
+          QUIC_BUG
+              << "QuicConnection was in WriteBlockedList before destruction "
+              << session->connection()->connection_id();
+        }
+      }
+    }
+    closed_session_list_.clear();
   }
-  closed_session_list_.clear();
 }
 
 void QuicDispatcher::OnCanWrite() {
@@ -852,14 +944,27 @@
 }
 
 void QuicDispatcher::Shutdown() {
-  while (!session_map_.empty()) {
-    QuicSession* session = session_map_.begin()->second.get();
-    session->connection()->CloseConnection(
-        QUIC_PEER_GOING_AWAY, "Server shutdown imminent",
-        ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET);
-    // Validate that the session removes itself from the session map on close.
-    DCHECK(session_map_.empty() ||
-           session_map_.begin()->second.get() != session);
+  if (use_reference_counted_session_map_) {
+    while (!reference_counted_session_map_.empty()) {
+      QuicSession* session =
+          reference_counted_session_map_.begin()->second.get();
+      session->connection()->CloseConnection(
+          QUIC_PEER_GOING_AWAY, "Server shutdown imminent",
+          ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET);
+      // Validate that the session removes itself from the session map on close.
+      DCHECK(reference_counted_session_map_.empty() ||
+             reference_counted_session_map_.begin()->second.get() != session);
+    }
+  } else {
+    while (!session_map_.empty()) {
+      QuicSession* session = session_map_.begin()->second.get();
+      session->connection()->CloseConnection(
+          QUIC_PEER_GOING_AWAY, "Server shutdown imminent",
+          ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET);
+      // Validate that the session removes itself from the session map on close.
+      DCHECK(session_map_.empty() ||
+             session_map_.begin()->second.get() != session);
+    }
   }
   DeleteSessions();
 }
@@ -868,32 +973,61 @@
                                         QuicErrorCode error,
                                         const std::string& error_details,
                                         ConnectionCloseSource source) {
-  auto it = session_map_.find(server_connection_id);
-  if (it == session_map_.end()) {
-    QUIC_BUG << "ConnectionId " << server_connection_id
-             << " does not exist in the session map.  Error: "
-             << QuicErrorCodeToString(error);
-    QUIC_BUG << QuicStackTrace();
-    return;
-  }
-
-  QUIC_DLOG_IF(INFO, error != QUIC_NO_ERROR)
-      << "Closing connection (" << server_connection_id
-      << ") due to error: " << QuicErrorCodeToString(error)
-      << ", with details: " << error_details;
-
-  QuicConnection* connection = it->second->connection();
-  if (ShouldDestroySessionAsynchronously()) {
-    // Set up alarm to fire immediately to bring destruction of this session
-    // out of current call stack.
-    if (closed_session_list_.empty()) {
-      delete_sessions_alarm_->Update(helper()->GetClock()->ApproximateNow(),
-                                     QuicTime::Delta::Zero());
+  if (use_reference_counted_session_map_) {
+    auto it = reference_counted_session_map_.find(server_connection_id);
+    if (it == reference_counted_session_map_.end()) {
+      QUIC_BUG << "ConnectionId " << server_connection_id
+               << " does not exist in the session map.  Error: "
+               << QuicErrorCodeToString(error);
+      QUIC_BUG << QuicStackTrace();
+      return;
     }
-    closed_session_list_.push_back(std::move(it->second));
+
+    QUIC_DLOG_IF(INFO, error != QUIC_NO_ERROR)
+        << "Closing connection (" << server_connection_id
+        << ") due to error: " << QuicErrorCodeToString(error)
+        << ", with details: " << error_details;
+
+    QuicConnection* connection = it->second->connection();
+    if (ShouldDestroySessionAsynchronously()) {
+      // Set up alarm to fire immediately to bring destruction of this session
+      // out of current call stack.
+      if (closed_ref_counted_session_list_.empty()) {
+        delete_sessions_alarm_->Update(helper()->GetClock()->ApproximateNow(),
+                                       QuicTime::Delta::Zero());
+      }
+      closed_ref_counted_session_list_.push_back(std::move(it->second));
+    }
+    CleanUpSession(it->first, connection, source);
+    reference_counted_session_map_.erase(it);
+  } else {
+    auto it = session_map_.find(server_connection_id);
+    if (it == session_map_.end()) {
+      QUIC_BUG << "ConnectionId " << server_connection_id
+               << " does not exist in the session map.  Error: "
+               << QuicErrorCodeToString(error);
+      QUIC_BUG << QuicStackTrace();
+      return;
+    }
+
+    QUIC_DLOG_IF(INFO, error != QUIC_NO_ERROR)
+        << "Closing connection (" << server_connection_id
+        << ") due to error: " << QuicErrorCodeToString(error)
+        << ", with details: " << error_details;
+
+    QuicConnection* connection = it->second->connection();
+    if (ShouldDestroySessionAsynchronously()) {
+      // Set up alarm to fire immediately to bring destruction of this session
+      // out of current call stack.
+      if (closed_session_list_.empty()) {
+        delete_sessions_alarm_->Update(helper()->GetClock()->ApproximateNow(),
+                                       QuicTime::Delta::Zero());
+      }
+      closed_session_list_.push_back(std::move(it->second));
+    }
+    CleanUpSession(it->first, connection, source);
+    session_map_.erase(it);
   }
-  CleanUpSession(it->first, connection, source);
-  session_map_.erase(it);
 }
 
 void QuicDispatcher::OnWriteBlocked(
@@ -1023,12 +1157,24 @@
     }
     QUIC_DLOG(INFO) << "Created new session for " << server_connection_id;
 
-    auto insertion_result = session_map_.insert(
-        std::make_pair(server_connection_id, std::move(session)));
-    QUIC_BUG_IF(!insertion_result.second)
-        << "Tried to add a session to session_map with existing connection id: "
-        << server_connection_id;
-    DeliverPacketsToSession(packets, insertion_result.first->second.get());
+    if (use_reference_counted_session_map_) {
+      auto insertion_result = reference_counted_session_map_.insert(
+          std::make_pair(server_connection_id,
+                         std::shared_ptr<QuicSession>(std::move(session))));
+      QUIC_BUG_IF(!insertion_result.second)
+          << "Tried to add a session to session_map with existing connection "
+             "id: "
+          << server_connection_id;
+      DeliverPacketsToSession(packets, insertion_result.first->second.get());
+    } else {
+      auto insertion_result = session_map_.insert(
+          std::make_pair(server_connection_id, std::move(session)));
+      QUIC_BUG_IF(!insertion_result.second)
+          << "Tried to add a session to session_map with existing connection "
+             "id: "
+          << server_connection_id;
+      DeliverPacketsToSession(packets, insertion_result.first->second.get());
+    }
   }
 }
 
@@ -1124,12 +1270,24 @@
   QUIC_DLOG(INFO) << "Created new session for "
                   << packet_info->destination_connection_id;
 
-  auto insertion_result = session_map_.insert(std::make_pair(
-      packet_info->destination_connection_id, std::move(session)));
-  QUIC_BUG_IF(!insertion_result.second)
-      << "Tried to add a session to session_map with existing connection id: "
-      << packet_info->destination_connection_id;
-  QuicSession* session_ptr = insertion_result.first->second.get();
+  QuicSession* session_ptr;
+  if (use_reference_counted_session_map_) {
+    auto insertion_result =
+        reference_counted_session_map_.insert(std::make_pair(
+            packet_info->destination_connection_id,
+            std::shared_ptr<QuicSession>(std::move(session.release()))));
+    QUIC_BUG_IF(!insertion_result.second)
+        << "Tried to add a session to session_map with existing connection id: "
+        << packet_info->destination_connection_id;
+    session_ptr = insertion_result.first->second.get();
+  } else {
+    auto insertion_result = session_map_.insert(std::make_pair(
+        packet_info->destination_connection_id, std::move(session)));
+    QUIC_BUG_IF(!insertion_result.second)
+        << "Tried to add a session to session_map with existing connection id: "
+        << packet_info->destination_connection_id;
+    session_ptr = insertion_result.first->second.get();
+  }
   std::list<BufferedPacket> packets =
       buffered_packets_.DeliverPackets(packet_info->destination_connection_id)
           .buffered_packets;
diff --git a/quic/core/quic_dispatcher.h b/quic/core/quic_dispatcher.h
index 58d7932..0a920ff 100644
--- a/quic/core/quic_dispatcher.h
+++ b/quic/core/quic_dispatcher.h
@@ -27,6 +27,7 @@
 #include "net/third_party/quiche/src/quic/core/quic_time_wait_list_manager.h"
 #include "net/third_party/quiche/src/quic/core/quic_version_manager.h"
 #include "net/third_party/quiche/src/quic/platform/api/quic_containers.h"
+#include "net/third_party/quiche/src/quic/platform/api/quic_reference_counted.h"
 #include "net/third_party/quiche/src/quic/platform/api/quic_socket_address.h"
 
 namespace quic {
@@ -109,7 +110,17 @@
                                  std::unique_ptr<QuicSession>,
                                  QuicConnectionIdHash>;
 
-  size_t NumSessions() const { return session_map_.size(); }
+  using ReferenceCountedSessionMap = QuicHashMap<QuicConnectionId,
+                                                 std::shared_ptr<QuicSession>,
+                                                 QuicConnectionIdHash>;
+
+  // TODO(haoyuewang) Update this function when multiple CIDs per connection are
+  // supported.
+  size_t NumSessions() const {
+    return use_reference_counted_session_map_
+               ? reference_counted_session_map_.size()
+               : session_map_.size();
+  }
 
   const SessionMap& session_map() const { return session_map_; }
 
@@ -141,8 +152,15 @@
   void PerformActionOnActiveSessions(
       std::function<void(QuicSession*)> operation) const;
 
+  // Get a snapshot of all sessions.
+  std::vector<std::shared_ptr<QuicSession>> GetSessionsSnapshot() const;
+
   bool accept_new_connections() const { return accept_new_connections_; }
 
+  bool use_reference_counted_session_map() const {
+    return use_reference_counted_session_map_;
+  }
+
  protected:
   virtual std::unique_ptr<QuicSession> CreateQuicSession(
       QuicConnectionId server_connection_id,
@@ -359,12 +377,14 @@
   WriteBlockedList write_blocked_list_;
 
   SessionMap session_map_;
+  ReferenceCountedSessionMap reference_counted_session_map_;
 
   // Entity that manages connection_ids in time wait state.
   std::unique_ptr<QuicTimeWaitListManager> time_wait_list_manager_;
 
   // The list of closed but not-yet-deleted sessions.
   std::vector<std::unique_ptr<QuicSession>> closed_session_list_;
+  std::vector<std::shared_ptr<QuicSession>> closed_ref_counted_session_list_;
 
   // The helper used for all connections.
   std::unique_ptr<QuicConnectionHelperInterface> helper_;
@@ -416,6 +436,9 @@
   // If true, change expected_server_connection_id_length_ to be the received
   // destination connection ID length of all IETF long headers.
   bool should_update_expected_server_connection_id_length_;
+
+  const bool use_reference_counted_session_map_ =
+      GetQuicRestartFlag(quic_use_reference_counted_sesssion_map);
 };
 
 }  // namespace quic
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h
index 912455f..8e99f37 100644
--- a/quic/core/quic_flags_list.h
+++ b/quic/core/quic_flags_list.h
@@ -73,3 +73,4 @@
 QUIC_FLAG(FLAGS_quic_restart_flag_quic_support_release_time_for_gso, false)
 QUIC_FLAG(FLAGS_quic_restart_flag_quic_testonly_default_false, false)
 QUIC_FLAG(FLAGS_quic_restart_flag_quic_testonly_default_true, true)
+QUIC_FLAG(FLAGS_quic_restart_flag_quic_use_reference_counted_sesssion_map, false)
diff --git a/quic/test_tools/quic_dispatcher_peer.cc b/quic/test_tools/quic_dispatcher_peer.cc
index 072b801..be428ed 100644
--- a/quic/test_tools/quic_dispatcher_peer.cc
+++ b/quic/test_tools/quic_dispatcher_peer.cc
@@ -116,10 +116,17 @@
 // static
 QuicSession* QuicDispatcherPeer::GetFirstSessionIfAny(
     QuicDispatcher* dispatcher) {
-  if (dispatcher->session_map_.empty()) {
-    return nullptr;
+  if (dispatcher->use_reference_counted_session_map()) {
+    if (dispatcher->reference_counted_session_map_.empty()) {
+      return nullptr;
+    }
+    return dispatcher->reference_counted_session_map_.begin()->second.get();
+  } else {
+    if (dispatcher->session_map_.empty()) {
+      return nullptr;
+    }
+    return dispatcher->session_map_.begin()->second.get();
   }
-  return dispatcher->session_map_.begin()->second.get();
 }
 
 }  // namespace test