Add multiple CID per connection support to time_wait_list.

Protected by FLAGS_quic_restart_flag_quic_time_wait_list_support_multiple_cid.

PiperOrigin-RevId: 350245863
Change-Id: Ida67aeabe3554d6229859fa0179d66c560da53d3
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc
index 3ea815b..0a093f1 100644
--- a/quic/core/quic_connection.cc
+++ b/quic/core/quic_connection.cc
@@ -5562,6 +5562,11 @@
   OnSuccessfulMigration();
 }
 
+std::vector<QuicConnectionId> QuicConnection::GetActiveServerConnectionIds()
+    const {
+  return {server_connection_id_};
+}
+
 void QuicConnection::SetUnackedMapInitialCapacity() {
   sent_packet_manager_.ReserveUnackedPacketsInitialCapacity(
       GetUnackedMapInitialCapacity());
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h
index 48b560f..64e83b0 100644
--- a/quic/core/quic_connection.h
+++ b/quic/core/quic_connection.h
@@ -1178,6 +1178,8 @@
                         : encryption_level_);
   }
 
+  virtual std::vector<QuicConnectionId> GetActiveServerConnectionIds() const;
+
  protected:
   // Calls cancel() on all the alarms owned by this connection.
   void CancelAllAlarms();
diff --git a/quic/core/quic_dispatcher.cc b/quic/core/quic_dispatcher.cc
index 9e17539..8003dbb 100644
--- a/quic/core/quic_dispatcher.cc
+++ b/quic/core/quic_dispatcher.cc
@@ -160,7 +160,8 @@
     time_wait_list_manager_->AddConnectionIdToTimeWait(
         server_connection_id_,
         QuicTimeWaitListManager::SEND_TERMINATION_PACKETS,
-        TimeWaitConnectionInfo(ietf_quic, collector_.packets()));
+        TimeWaitConnectionInfo(ietf_quic, collector_.packets(),
+                               {server_connection_id_}));
   }
 
  private:
@@ -831,6 +832,7 @@
       TimeWaitConnectionInfo(
           connection->version().HasIetfInvariantHeader(),
           connection->termination_packets(),
+          connection->GetActiveServerConnectionIds(),
           connection->sent_packet_manager().GetRttStats()->smoothed_rtt()));
 }
 
@@ -1072,7 +1074,8 @@
                   << ", error_details:" << error_details;
     time_wait_list_manager_->AddConnectionIdToTimeWait(
         server_connection_id, action,
-        TimeWaitConnectionInfo(format != GOOGLE_QUIC_PACKET, nullptr));
+        TimeWaitConnectionInfo(format != GOOGLE_QUIC_PACKET, nullptr,
+                               {server_connection_id}));
     return;
   }
 
@@ -1108,7 +1111,7 @@
   time_wait_list_manager()->AddConnectionIdToTimeWait(
       server_connection_id, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS,
       TimeWaitConnectionInfo(/*ietf_quic=*/format != GOOGLE_QUIC_PACKET,
-                             &termination_packets));
+                             &termination_packets, {server_connection_id}));
 }
 
 bool QuicDispatcher::ShouldCreateSessionForUnknownVersion(
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h
index f64344c..fd5da05 100644
--- a/quic/core/quic_flags_list.h
+++ b/quic/core/quic_flags_list.h
@@ -75,4 +75,5 @@
 QUIC_FLAG(FLAGS_quic_restart_flag_quic_support_release_time_for_gso, false)
 QUIC_FLAG(FLAGS_quic_restart_flag_quic_testonly_default_false, false)
 QUIC_FLAG(FLAGS_quic_restart_flag_quic_testonly_default_true, true)
+QUIC_FLAG(FLAGS_quic_restart_flag_quic_time_wait_list_support_multiple_cid, false)
 QUIC_FLAG(FLAGS_quic_restart_flag_quic_use_reference_counted_sesssion_map, false)
diff --git a/quic/core/quic_time_wait_list_manager.cc b/quic/core/quic_time_wait_list_manager.cc
index c82fb4f..11fca70 100644
--- a/quic/core/quic_time_wait_list_manager.cc
+++ b/quic/core/quic_time_wait_list_manager.cc
@@ -49,16 +49,21 @@
 
 TimeWaitConnectionInfo::TimeWaitConnectionInfo(
     bool ietf_quic,
-    std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets)
+    std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets,
+    std::vector<QuicConnectionId> active_connection_ids)
     : TimeWaitConnectionInfo(ietf_quic,
                              termination_packets,
+                             std::move(active_connection_ids),
                              QuicTime::Delta::Zero()) {}
 
 TimeWaitConnectionInfo::TimeWaitConnectionInfo(
     bool ietf_quic,
     std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets,
+    std::vector<QuicConnectionId> active_connection_ids,
     QuicTime::Delta srtt)
-    : ietf_quic(ietf_quic), srtt(srtt) {
+    : ietf_quic(ietf_quic),
+      active_connection_ids(std::move(active_connection_ids)),
+      srtt(srtt) {
   if (termination_packets != nullptr) {
     this->termination_packets.swap(*termination_packets);
   }
@@ -76,6 +81,9 @@
       clock_(clock),
       writer_(writer),
       visitor_(visitor) {
+  if (use_indirect_connection_id_map_) {
+    QUIC_RESTART_FLAG_COUNT(quic_time_wait_list_support_multiple_cid);
+  }
   SetConnectionIdCleanUpAlarm();
 }
 
@@ -83,35 +91,80 @@
   connection_id_clean_up_alarm_->Cancel();
 }
 
+QuicTimeWaitListManager::ConnectionIdMap::iterator
+QuicTimeWaitListManager::FindConnectionIdDataInMap(
+    const QuicConnectionId& connection_id) {
+  if (!use_indirect_connection_id_map_) {
+    return connection_id_map_.find(connection_id);
+  }
+  auto it = indirect_connection_id_map_.find(connection_id);
+  if (it == indirect_connection_id_map_.end()) {
+    return connection_id_map_.end();
+  }
+  return connection_id_map_.find(it->second);
+}
+
+void QuicTimeWaitListManager::AddConnectionIdDataToMap(
+    const QuicConnectionId& canonical_connection_id,
+    int num_packets,
+    TimeWaitAction action,
+    TimeWaitConnectionInfo info) {
+  if (use_indirect_connection_id_map_) {
+    for (const auto& cid : info.active_connection_ids) {
+      indirect_connection_id_map_[cid] = canonical_connection_id;
+    }
+  }
+  ConnectionIdData data(num_packets, clock_->ApproximateNow(), action,
+                        std::move(info));
+  connection_id_map_.emplace(
+      std::make_pair(canonical_connection_id, std::move(data)));
+}
+
+void QuicTimeWaitListManager::RemoveConnectionDataFromMap(
+    ConnectionIdMap::iterator it) {
+  if (use_indirect_connection_id_map_) {
+    for (const auto& cid : it->second.info.active_connection_ids) {
+      indirect_connection_id_map_.erase(cid);
+    }
+  }
+  connection_id_map_.erase(it);
+}
+
 void QuicTimeWaitListManager::AddConnectionIdToTimeWait(
     QuicConnectionId connection_id,
     TimeWaitAction action,
     TimeWaitConnectionInfo info) {
+  DCHECK(!info.active_connection_ids.empty());
+  const QuicConnectionId& canonical_connection_id =
+      use_indirect_connection_id_map_ ? info.active_connection_ids.front()
+                                      : connection_id;
   DCHECK(action != SEND_TERMINATION_PACKETS ||
          !info.termination_packets.empty());
   DCHECK(action != DO_NOTHING || info.ietf_quic);
   int num_packets = 0;
-  auto it = connection_id_map_.find(connection_id);
+  auto it = FindConnectionIdDataInMap(canonical_connection_id);
   const bool new_connection_id = it == connection_id_map_.end();
   if (!new_connection_id) {  // Replace record if it is reinserted.
     num_packets = it->second.num_packets;
-    connection_id_map_.erase(it);
+    RemoveConnectionDataFromMap(it);
   }
   TrimTimeWaitListIfNeeded();
   int64_t max_connections =
       GetQuicFlag(FLAGS_quic_time_wait_list_max_connections);
   DCHECK(connection_id_map_.empty() ||
          num_connections() < static_cast<size_t>(max_connections));
-  ConnectionIdData data(num_packets, clock_->ApproximateNow(), action,
-                        std::move(info));
-  connection_id_map_.emplace(std::make_pair(connection_id, std::move(data)));
+  AddConnectionIdDataToMap(canonical_connection_id, num_packets, action,
+                           std::move(info));
   if (new_connection_id) {
-    visitor_->OnConnectionAddedToTimeWaitList(connection_id);
+    visitor_->OnConnectionAddedToTimeWaitList(canonical_connection_id);
   }
 }
 
 bool QuicTimeWaitListManager::IsConnectionIdInTimeWait(
     QuicConnectionId connection_id) const {
+  if (use_indirect_connection_id_map_) {
+    return indirect_connection_id_map_.contains(connection_id);
+  }
   return QuicContainsKey(connection_id_map_, connection_id);
 }
 
@@ -135,7 +188,7 @@
   DCHECK(IsConnectionIdInTimeWait(connection_id));
   // TODO(satyamshekhar): Think about handling packets from different peer
   // addresses.
-  auto it = connection_id_map_.find(connection_id);
+  auto it = FindConnectionIdDataInMap(connection_id);
   DCHECK(it != connection_id_map_.end());
   // Increment the received packet count.
   ConnectionIdData* connection_data = &it->second;
@@ -388,7 +441,7 @@
   // This connection_id has lived its age, retire it now.
   QUIC_DLOG(INFO) << "Connection " << it->first
                   << " expired from time wait list";
-  connection_id_map_.erase(it);
+  RemoveConnectionDataFromMap(it);
   if (expiration_time == QuicTime::Infinite()) {
     QUIC_CODE_COUNT(quic_time_wait_list_trim_full);
   } else {
diff --git a/quic/core/quic_time_wait_list_manager.h b/quic/core/quic_time_wait_list_manager.h
index 1bb3a68..ad6d45c 100644
--- a/quic/core/quic_time_wait_list_manager.h
+++ b/quic/core/quic_time_wait_list_manager.h
@@ -12,6 +12,7 @@
 #include <memory>
 
 #include "quic/core/quic_blocked_writer_interface.h"
+#include "quic/core/quic_connection_id.h"
 #include "quic/core/quic_framer.h"
 #include "quic/core/quic_packet_writer.h"
 #include "quic/core/quic_packets.h"
@@ -31,10 +32,12 @@
 struct QUIC_NO_EXPORT TimeWaitConnectionInfo {
   TimeWaitConnectionInfo(
       bool ietf_quic,
-      std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets);
+      std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets,
+      std::vector<QuicConnectionId> active_connection_ids);
   TimeWaitConnectionInfo(
       bool ietf_quic,
       std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets,
+      std::vector<QuicConnectionId> active_connection_ids,
       QuicTime::Delta srtt);
 
   TimeWaitConnectionInfo(const TimeWaitConnectionInfo& other) = delete;
@@ -44,6 +47,7 @@
 
   bool ietf_quic;
   std::vector<std::unique_ptr<QuicEncryptedPacket>> termination_packets;
+  std::vector<QuicConnectionId> active_connection_ids;
   QuicTime::Delta srtt;
 };
 
@@ -278,8 +282,33 @@
   using ConnectionIdMap = QuicLinkedHashMap<QuicConnectionId,
                                             ConnectionIdData,
                                             QuicConnectionIdHash>;
+  // Do not use find/emplace/erase on this map directly. Use
+  // FindConnectionIdDataInMap, AddConnectionIdDateToMap,
+  // RemoveConnectionDataFromMap instead.
   ConnectionIdMap connection_id_map_;
 
+  // TODO(haoyuewang) Consider making connection_id_map_ a map of shared pointer
+  // and remove the indirect map.
+  // A connection can have multiple unretired ConnectionIds when it is closed.
+  // These Ids have the same ConnectionIdData entry in connection_id_map_. To
+  // find the entry, look up the cannoical ConnectionId in
+  // indirect_connection_id_map_ first, and look up connection_id_map_ with the
+  // cannoical ConnectionId.
+  QuicHashMap<QuicConnectionId, QuicConnectionId, QuicConnectionIdHash>
+      indirect_connection_id_map_;
+
+  // Find an iterator for the given connection_id. Returns
+  // connection_id_map_.end() if none found.
+  ConnectionIdMap::iterator FindConnectionIdDataInMap(
+      const QuicConnectionId& connection_id);
+  // Inserts a ConnectionIdData entry to connection_id_map_.
+  void AddConnectionIdDataToMap(const QuicConnectionId& canonical_connection_id,
+                                int num_packets,
+                                TimeWaitAction action,
+                                TimeWaitConnectionInfo info);
+  // Removes a ConnectionIdData entry in connection_id_map_.
+  void RemoveConnectionDataFromMap(ConnectionIdMap::iterator it);
+
   // Pending termination packets that need to be sent out to the peer when we
   // are given a chance to write by the dispatcher.
   QuicCircularDeque<std::unique_ptr<QueuedPacket>> pending_packets_queue_;
@@ -299,6 +328,11 @@
 
   // Interface that manages blocked writers.
   Visitor* visitor_;
+
+  // When this is default true, remove the connection_id argument of
+  // AddConnectionIdToTimeWait.
+  bool use_indirect_connection_id_map_ =
+      GetQuicRestartFlag(quic_time_wait_list_support_multiple_cid);
 };
 
 }  // namespace quic
diff --git a/quic/core/quic_time_wait_list_manager_test.cc b/quic/core/quic_time_wait_list_manager_test.cc
index cf40777..5ece763 100644
--- a/quic/core/quic_time_wait_list_manager_test.cc
+++ b/quic/core/quic_time_wait_list_manager_test.cc
@@ -13,6 +13,7 @@
 #include "quic/core/crypto/null_encrypter.h"
 #include "quic/core/crypto/quic_decrypter.h"
 #include "quic/core/crypto/quic_encrypter.h"
+#include "quic/core/quic_connection_id.h"
 #include "quic/core/quic_data_reader.h"
 #include "quic/core/quic_framer.h"
 #include "quic/core/quic_packet_writer.h"
@@ -154,7 +155,7 @@
         new QuicEncryptedPacket(nullptr, 0, false)));
     time_wait_list_manager_.AddConnectionIdToTimeWait(
         connection_id, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS,
-        TimeWaitConnectionInfo(false, &termination_packets));
+        TimeWaitConnectionInfo(false, &termination_packets, {connection_id}));
   }
 
   void AddConnectionId(
@@ -164,7 +165,8 @@
       std::vector<std::unique_ptr<QuicEncryptedPacket>>* packets) {
     time_wait_list_manager_.AddConnectionIdToTimeWait(
         connection_id, action,
-        TimeWaitConnectionInfo(version.HasIetfInvariantHeader(), packets));
+        TimeWaitConnectionInfo(version.HasIetfInvariantHeader(), packets,
+                               {connection_id}));
   }
 
   bool IsConnectionIdInTimeWait(QuicConnectionId connection_id) {
@@ -443,6 +445,45 @@
             time_wait_list_manager_.num_connections());
 }
 
+TEST_F(QuicTimeWaitListManagerTest,
+       CleanUpOldConnectionIdsForMultipleConnectionIdsPerConnection) {
+  if (!GetQuicRestartFlag(quic_time_wait_list_support_multiple_cid)) {
+    return;
+  }
+
+  connection_id_ = TestConnectionId(7);
+  const size_t kConnectionCloseLength = 100;
+  EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_));
+  std::vector<std::unique_ptr<QuicEncryptedPacket>> termination_packets;
+  termination_packets.push_back(
+      std::unique_ptr<QuicEncryptedPacket>(new QuicEncryptedPacket(
+          new char[kConnectionCloseLength], kConnectionCloseLength, true)));
+
+  // Add a CONNECTION_CLOSE termination packet.
+  std::vector<QuicConnectionId> active_connection_ids{connection_id_,
+                                                      TestConnectionId(8)};
+  time_wait_list_manager_.AddConnectionIdToTimeWait(
+      connection_id_, QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS,
+      TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets,
+                             active_connection_ids, QuicTime::Delta::Zero()));
+
+  EXPECT_TRUE(
+      time_wait_list_manager_.IsConnectionIdInTimeWait(TestConnectionId(7)));
+  EXPECT_TRUE(
+      time_wait_list_manager_.IsConnectionIdInTimeWait(TestConnectionId(8)));
+
+  // Remove these IDs.
+  const QuicTime::Delta time_wait_period =
+      QuicTimeWaitListManagerPeer::time_wait_period(&time_wait_list_manager_);
+  clock_.AdvanceTime(time_wait_period);
+  time_wait_list_manager_.CleanUpOldConnectionIds();
+
+  EXPECT_FALSE(
+      time_wait_list_manager_.IsConnectionIdInTimeWait(TestConnectionId(7)));
+  EXPECT_FALSE(
+      time_wait_list_manager_.IsConnectionIdInTimeWait(TestConnectionId(8)));
+}
+
 TEST_F(QuicTimeWaitListManagerTest, SendQueuedPackets) {
   QuicConnectionId connection_id = TestConnectionId(1);
   EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id));
@@ -631,7 +672,8 @@
           new char[kConnectionCloseLength], kConnectionCloseLength, true)));
   time_wait_list_manager_.AddConnectionIdToTimeWait(
       connection_id_, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS,
-      TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets));
+      TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets,
+                             {connection_id_}));
 
   // Termination packet is not encrypted, instead, send stateless reset.
   EXPECT_CALL(writer_,
@@ -655,7 +697,8 @@
   // Add a CONNECTION_CLOSE termination packet.
   time_wait_list_manager_.AddConnectionIdToTimeWait(
       connection_id_, QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS,
-      TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets));
+      TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets,
+                             {connection_id_}));
   EXPECT_CALL(writer_, WritePacket(_, kConnectionCloseLength,
                                    self_address_.host(), peer_address_, _))
       .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1)));
@@ -666,6 +709,40 @@
       IETF_QUIC_SHORT_HEADER_PACKET, std::make_unique<QuicPerPacketContext>());
 }
 
+TEST_F(QuicTimeWaitListManagerTest,
+       SendConnectionClosePacketsForMultipleConnectionIds) {
+  if (!GetQuicRestartFlag(quic_time_wait_list_support_multiple_cid)) {
+    return;
+  }
+
+  connection_id_ = TestConnectionId(7);
+  const size_t kConnectionCloseLength = 100;
+  EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_));
+  std::vector<std::unique_ptr<QuicEncryptedPacket>> termination_packets;
+  termination_packets.push_back(
+      std::unique_ptr<QuicEncryptedPacket>(new QuicEncryptedPacket(
+          new char[kConnectionCloseLength], kConnectionCloseLength, true)));
+
+  // Add a CONNECTION_CLOSE termination packet.
+  std::vector<QuicConnectionId> active_connection_ids{connection_id_,
+                                                      TestConnectionId(8)};
+  time_wait_list_manager_.AddConnectionIdToTimeWait(
+      connection_id_, QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS,
+      TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets,
+                             active_connection_ids, QuicTime::Delta::Zero()));
+
+  EXPECT_CALL(writer_, WritePacket(_, kConnectionCloseLength,
+                                   self_address_.host(), peer_address_, _))
+      .Times(2)
+      .WillRepeatedly(Return(WriteResult(WRITE_STATUS_OK, 1)));
+  // Processes IETF short header packet.
+  for (auto const& cid : active_connection_ids) {
+    time_wait_list_manager_.ProcessPacket(
+        self_address_, peer_address_, cid, IETF_QUIC_SHORT_HEADER_PACKET,
+        std::make_unique<QuicPerPacketContext>());
+  }
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic