Add multiple CID per connection support to time_wait_list.
Protected by FLAGS_quic_restart_flag_quic_time_wait_list_support_multiple_cid.
PiperOrigin-RevId: 350245863
Change-Id: Ida67aeabe3554d6229859fa0179d66c560da53d3
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc
index 3ea815b..0a093f1 100644
--- a/quic/core/quic_connection.cc
+++ b/quic/core/quic_connection.cc
@@ -5562,6 +5562,11 @@
OnSuccessfulMigration();
}
+std::vector<QuicConnectionId> QuicConnection::GetActiveServerConnectionIds()
+ const {
+ return {server_connection_id_};
+}
+
void QuicConnection::SetUnackedMapInitialCapacity() {
sent_packet_manager_.ReserveUnackedPacketsInitialCapacity(
GetUnackedMapInitialCapacity());
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h
index 48b560f..64e83b0 100644
--- a/quic/core/quic_connection.h
+++ b/quic/core/quic_connection.h
@@ -1178,6 +1178,8 @@
: encryption_level_);
}
+ virtual std::vector<QuicConnectionId> GetActiveServerConnectionIds() const;
+
protected:
// Calls cancel() on all the alarms owned by this connection.
void CancelAllAlarms();
diff --git a/quic/core/quic_dispatcher.cc b/quic/core/quic_dispatcher.cc
index 9e17539..8003dbb 100644
--- a/quic/core/quic_dispatcher.cc
+++ b/quic/core/quic_dispatcher.cc
@@ -160,7 +160,8 @@
time_wait_list_manager_->AddConnectionIdToTimeWait(
server_connection_id_,
QuicTimeWaitListManager::SEND_TERMINATION_PACKETS,
- TimeWaitConnectionInfo(ietf_quic, collector_.packets()));
+ TimeWaitConnectionInfo(ietf_quic, collector_.packets(),
+ {server_connection_id_}));
}
private:
@@ -831,6 +832,7 @@
TimeWaitConnectionInfo(
connection->version().HasIetfInvariantHeader(),
connection->termination_packets(),
+ connection->GetActiveServerConnectionIds(),
connection->sent_packet_manager().GetRttStats()->smoothed_rtt()));
}
@@ -1072,7 +1074,8 @@
<< ", error_details:" << error_details;
time_wait_list_manager_->AddConnectionIdToTimeWait(
server_connection_id, action,
- TimeWaitConnectionInfo(format != GOOGLE_QUIC_PACKET, nullptr));
+ TimeWaitConnectionInfo(format != GOOGLE_QUIC_PACKET, nullptr,
+ {server_connection_id}));
return;
}
@@ -1108,7 +1111,7 @@
time_wait_list_manager()->AddConnectionIdToTimeWait(
server_connection_id, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS,
TimeWaitConnectionInfo(/*ietf_quic=*/format != GOOGLE_QUIC_PACKET,
- &termination_packets));
+ &termination_packets, {server_connection_id}));
}
bool QuicDispatcher::ShouldCreateSessionForUnknownVersion(
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h
index f64344c..fd5da05 100644
--- a/quic/core/quic_flags_list.h
+++ b/quic/core/quic_flags_list.h
@@ -75,4 +75,5 @@
QUIC_FLAG(FLAGS_quic_restart_flag_quic_support_release_time_for_gso, false)
QUIC_FLAG(FLAGS_quic_restart_flag_quic_testonly_default_false, false)
QUIC_FLAG(FLAGS_quic_restart_flag_quic_testonly_default_true, true)
+QUIC_FLAG(FLAGS_quic_restart_flag_quic_time_wait_list_support_multiple_cid, false)
QUIC_FLAG(FLAGS_quic_restart_flag_quic_use_reference_counted_sesssion_map, false)
diff --git a/quic/core/quic_time_wait_list_manager.cc b/quic/core/quic_time_wait_list_manager.cc
index c82fb4f..11fca70 100644
--- a/quic/core/quic_time_wait_list_manager.cc
+++ b/quic/core/quic_time_wait_list_manager.cc
@@ -49,16 +49,21 @@
TimeWaitConnectionInfo::TimeWaitConnectionInfo(
bool ietf_quic,
- std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets)
+ std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets,
+ std::vector<QuicConnectionId> active_connection_ids)
: TimeWaitConnectionInfo(ietf_quic,
termination_packets,
+ std::move(active_connection_ids),
QuicTime::Delta::Zero()) {}
TimeWaitConnectionInfo::TimeWaitConnectionInfo(
bool ietf_quic,
std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets,
+ std::vector<QuicConnectionId> active_connection_ids,
QuicTime::Delta srtt)
- : ietf_quic(ietf_quic), srtt(srtt) {
+ : ietf_quic(ietf_quic),
+ active_connection_ids(std::move(active_connection_ids)),
+ srtt(srtt) {
if (termination_packets != nullptr) {
this->termination_packets.swap(*termination_packets);
}
@@ -76,6 +81,9 @@
clock_(clock),
writer_(writer),
visitor_(visitor) {
+ if (use_indirect_connection_id_map_) {
+ QUIC_RESTART_FLAG_COUNT(quic_time_wait_list_support_multiple_cid);
+ }
SetConnectionIdCleanUpAlarm();
}
@@ -83,35 +91,80 @@
connection_id_clean_up_alarm_->Cancel();
}
+QuicTimeWaitListManager::ConnectionIdMap::iterator
+QuicTimeWaitListManager::FindConnectionIdDataInMap(
+ const QuicConnectionId& connection_id) {
+ if (!use_indirect_connection_id_map_) {
+ return connection_id_map_.find(connection_id);
+ }
+ auto it = indirect_connection_id_map_.find(connection_id);
+ if (it == indirect_connection_id_map_.end()) {
+ return connection_id_map_.end();
+ }
+ return connection_id_map_.find(it->second);
+}
+
+void QuicTimeWaitListManager::AddConnectionIdDataToMap(
+ const QuicConnectionId& canonical_connection_id,
+ int num_packets,
+ TimeWaitAction action,
+ TimeWaitConnectionInfo info) {
+ if (use_indirect_connection_id_map_) {
+ for (const auto& cid : info.active_connection_ids) {
+ indirect_connection_id_map_[cid] = canonical_connection_id;
+ }
+ }
+ ConnectionIdData data(num_packets, clock_->ApproximateNow(), action,
+ std::move(info));
+ connection_id_map_.emplace(
+ std::make_pair(canonical_connection_id, std::move(data)));
+}
+
+void QuicTimeWaitListManager::RemoveConnectionDataFromMap(
+ ConnectionIdMap::iterator it) {
+ if (use_indirect_connection_id_map_) {
+ for (const auto& cid : it->second.info.active_connection_ids) {
+ indirect_connection_id_map_.erase(cid);
+ }
+ }
+ connection_id_map_.erase(it);
+}
+
void QuicTimeWaitListManager::AddConnectionIdToTimeWait(
QuicConnectionId connection_id,
TimeWaitAction action,
TimeWaitConnectionInfo info) {
+ DCHECK(!info.active_connection_ids.empty());
+ const QuicConnectionId& canonical_connection_id =
+ use_indirect_connection_id_map_ ? info.active_connection_ids.front()
+ : connection_id;
DCHECK(action != SEND_TERMINATION_PACKETS ||
!info.termination_packets.empty());
DCHECK(action != DO_NOTHING || info.ietf_quic);
int num_packets = 0;
- auto it = connection_id_map_.find(connection_id);
+ auto it = FindConnectionIdDataInMap(canonical_connection_id);
const bool new_connection_id = it == connection_id_map_.end();
if (!new_connection_id) { // Replace record if it is reinserted.
num_packets = it->second.num_packets;
- connection_id_map_.erase(it);
+ RemoveConnectionDataFromMap(it);
}
TrimTimeWaitListIfNeeded();
int64_t max_connections =
GetQuicFlag(FLAGS_quic_time_wait_list_max_connections);
DCHECK(connection_id_map_.empty() ||
num_connections() < static_cast<size_t>(max_connections));
- ConnectionIdData data(num_packets, clock_->ApproximateNow(), action,
- std::move(info));
- connection_id_map_.emplace(std::make_pair(connection_id, std::move(data)));
+ AddConnectionIdDataToMap(canonical_connection_id, num_packets, action,
+ std::move(info));
if (new_connection_id) {
- visitor_->OnConnectionAddedToTimeWaitList(connection_id);
+ visitor_->OnConnectionAddedToTimeWaitList(canonical_connection_id);
}
}
bool QuicTimeWaitListManager::IsConnectionIdInTimeWait(
QuicConnectionId connection_id) const {
+ if (use_indirect_connection_id_map_) {
+ return indirect_connection_id_map_.contains(connection_id);
+ }
return QuicContainsKey(connection_id_map_, connection_id);
}
@@ -135,7 +188,7 @@
DCHECK(IsConnectionIdInTimeWait(connection_id));
// TODO(satyamshekhar): Think about handling packets from different peer
// addresses.
- auto it = connection_id_map_.find(connection_id);
+ auto it = FindConnectionIdDataInMap(connection_id);
DCHECK(it != connection_id_map_.end());
// Increment the received packet count.
ConnectionIdData* connection_data = &it->second;
@@ -388,7 +441,7 @@
// This connection_id has lived its age, retire it now.
QUIC_DLOG(INFO) << "Connection " << it->first
<< " expired from time wait list";
- connection_id_map_.erase(it);
+ RemoveConnectionDataFromMap(it);
if (expiration_time == QuicTime::Infinite()) {
QUIC_CODE_COUNT(quic_time_wait_list_trim_full);
} else {
diff --git a/quic/core/quic_time_wait_list_manager.h b/quic/core/quic_time_wait_list_manager.h
index 1bb3a68..ad6d45c 100644
--- a/quic/core/quic_time_wait_list_manager.h
+++ b/quic/core/quic_time_wait_list_manager.h
@@ -12,6 +12,7 @@
#include <memory>
#include "quic/core/quic_blocked_writer_interface.h"
+#include "quic/core/quic_connection_id.h"
#include "quic/core/quic_framer.h"
#include "quic/core/quic_packet_writer.h"
#include "quic/core/quic_packets.h"
@@ -31,10 +32,12 @@
struct QUIC_NO_EXPORT TimeWaitConnectionInfo {
TimeWaitConnectionInfo(
bool ietf_quic,
- std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets);
+ std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets,
+ std::vector<QuicConnectionId> active_connection_ids);
TimeWaitConnectionInfo(
bool ietf_quic,
std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets,
+ std::vector<QuicConnectionId> active_connection_ids,
QuicTime::Delta srtt);
TimeWaitConnectionInfo(const TimeWaitConnectionInfo& other) = delete;
@@ -44,6 +47,7 @@
bool ietf_quic;
std::vector<std::unique_ptr<QuicEncryptedPacket>> termination_packets;
+ std::vector<QuicConnectionId> active_connection_ids;
QuicTime::Delta srtt;
};
@@ -278,8 +282,33 @@
using ConnectionIdMap = QuicLinkedHashMap<QuicConnectionId,
ConnectionIdData,
QuicConnectionIdHash>;
+ // Do not use find/emplace/erase on this map directly. Use
+ // FindConnectionIdDataInMap, AddConnectionIdDateToMap,
+ // RemoveConnectionDataFromMap instead.
ConnectionIdMap connection_id_map_;
+ // TODO(haoyuewang) Consider making connection_id_map_ a map of shared pointer
+ // and remove the indirect map.
+ // A connection can have multiple unretired ConnectionIds when it is closed.
+ // These Ids have the same ConnectionIdData entry in connection_id_map_. To
+ // find the entry, look up the cannoical ConnectionId in
+ // indirect_connection_id_map_ first, and look up connection_id_map_ with the
+ // cannoical ConnectionId.
+ QuicHashMap<QuicConnectionId, QuicConnectionId, QuicConnectionIdHash>
+ indirect_connection_id_map_;
+
+ // Find an iterator for the given connection_id. Returns
+ // connection_id_map_.end() if none found.
+ ConnectionIdMap::iterator FindConnectionIdDataInMap(
+ const QuicConnectionId& connection_id);
+ // Inserts a ConnectionIdData entry to connection_id_map_.
+ void AddConnectionIdDataToMap(const QuicConnectionId& canonical_connection_id,
+ int num_packets,
+ TimeWaitAction action,
+ TimeWaitConnectionInfo info);
+ // Removes a ConnectionIdData entry in connection_id_map_.
+ void RemoveConnectionDataFromMap(ConnectionIdMap::iterator it);
+
// Pending termination packets that need to be sent out to the peer when we
// are given a chance to write by the dispatcher.
QuicCircularDeque<std::unique_ptr<QueuedPacket>> pending_packets_queue_;
@@ -299,6 +328,11 @@
// Interface that manages blocked writers.
Visitor* visitor_;
+
+ // When this is default true, remove the connection_id argument of
+ // AddConnectionIdToTimeWait.
+ bool use_indirect_connection_id_map_ =
+ GetQuicRestartFlag(quic_time_wait_list_support_multiple_cid);
};
} // namespace quic
diff --git a/quic/core/quic_time_wait_list_manager_test.cc b/quic/core/quic_time_wait_list_manager_test.cc
index cf40777..5ece763 100644
--- a/quic/core/quic_time_wait_list_manager_test.cc
+++ b/quic/core/quic_time_wait_list_manager_test.cc
@@ -13,6 +13,7 @@
#include "quic/core/crypto/null_encrypter.h"
#include "quic/core/crypto/quic_decrypter.h"
#include "quic/core/crypto/quic_encrypter.h"
+#include "quic/core/quic_connection_id.h"
#include "quic/core/quic_data_reader.h"
#include "quic/core/quic_framer.h"
#include "quic/core/quic_packet_writer.h"
@@ -154,7 +155,7 @@
new QuicEncryptedPacket(nullptr, 0, false)));
time_wait_list_manager_.AddConnectionIdToTimeWait(
connection_id, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS,
- TimeWaitConnectionInfo(false, &termination_packets));
+ TimeWaitConnectionInfo(false, &termination_packets, {connection_id}));
}
void AddConnectionId(
@@ -164,7 +165,8 @@
std::vector<std::unique_ptr<QuicEncryptedPacket>>* packets) {
time_wait_list_manager_.AddConnectionIdToTimeWait(
connection_id, action,
- TimeWaitConnectionInfo(version.HasIetfInvariantHeader(), packets));
+ TimeWaitConnectionInfo(version.HasIetfInvariantHeader(), packets,
+ {connection_id}));
}
bool IsConnectionIdInTimeWait(QuicConnectionId connection_id) {
@@ -443,6 +445,45 @@
time_wait_list_manager_.num_connections());
}
+TEST_F(QuicTimeWaitListManagerTest,
+ CleanUpOldConnectionIdsForMultipleConnectionIdsPerConnection) {
+ if (!GetQuicRestartFlag(quic_time_wait_list_support_multiple_cid)) {
+ return;
+ }
+
+ connection_id_ = TestConnectionId(7);
+ const size_t kConnectionCloseLength = 100;
+ EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_));
+ std::vector<std::unique_ptr<QuicEncryptedPacket>> termination_packets;
+ termination_packets.push_back(
+ std::unique_ptr<QuicEncryptedPacket>(new QuicEncryptedPacket(
+ new char[kConnectionCloseLength], kConnectionCloseLength, true)));
+
+ // Add a CONNECTION_CLOSE termination packet.
+ std::vector<QuicConnectionId> active_connection_ids{connection_id_,
+ TestConnectionId(8)};
+ time_wait_list_manager_.AddConnectionIdToTimeWait(
+ connection_id_, QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS,
+ TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets,
+ active_connection_ids, QuicTime::Delta::Zero()));
+
+ EXPECT_TRUE(
+ time_wait_list_manager_.IsConnectionIdInTimeWait(TestConnectionId(7)));
+ EXPECT_TRUE(
+ time_wait_list_manager_.IsConnectionIdInTimeWait(TestConnectionId(8)));
+
+ // Remove these IDs.
+ const QuicTime::Delta time_wait_period =
+ QuicTimeWaitListManagerPeer::time_wait_period(&time_wait_list_manager_);
+ clock_.AdvanceTime(time_wait_period);
+ time_wait_list_manager_.CleanUpOldConnectionIds();
+
+ EXPECT_FALSE(
+ time_wait_list_manager_.IsConnectionIdInTimeWait(TestConnectionId(7)));
+ EXPECT_FALSE(
+ time_wait_list_manager_.IsConnectionIdInTimeWait(TestConnectionId(8)));
+}
+
TEST_F(QuicTimeWaitListManagerTest, SendQueuedPackets) {
QuicConnectionId connection_id = TestConnectionId(1);
EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id));
@@ -631,7 +672,8 @@
new char[kConnectionCloseLength], kConnectionCloseLength, true)));
time_wait_list_manager_.AddConnectionIdToTimeWait(
connection_id_, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS,
- TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets));
+ TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets,
+ {connection_id_}));
// Termination packet is not encrypted, instead, send stateless reset.
EXPECT_CALL(writer_,
@@ -655,7 +697,8 @@
// Add a CONNECTION_CLOSE termination packet.
time_wait_list_manager_.AddConnectionIdToTimeWait(
connection_id_, QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS,
- TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets));
+ TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets,
+ {connection_id_}));
EXPECT_CALL(writer_, WritePacket(_, kConnectionCloseLength,
self_address_.host(), peer_address_, _))
.WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1)));
@@ -666,6 +709,40 @@
IETF_QUIC_SHORT_HEADER_PACKET, std::make_unique<QuicPerPacketContext>());
}
+TEST_F(QuicTimeWaitListManagerTest,
+ SendConnectionClosePacketsForMultipleConnectionIds) {
+ if (!GetQuicRestartFlag(quic_time_wait_list_support_multiple_cid)) {
+ return;
+ }
+
+ connection_id_ = TestConnectionId(7);
+ const size_t kConnectionCloseLength = 100;
+ EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_));
+ std::vector<std::unique_ptr<QuicEncryptedPacket>> termination_packets;
+ termination_packets.push_back(
+ std::unique_ptr<QuicEncryptedPacket>(new QuicEncryptedPacket(
+ new char[kConnectionCloseLength], kConnectionCloseLength, true)));
+
+ // Add a CONNECTION_CLOSE termination packet.
+ std::vector<QuicConnectionId> active_connection_ids{connection_id_,
+ TestConnectionId(8)};
+ time_wait_list_manager_.AddConnectionIdToTimeWait(
+ connection_id_, QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS,
+ TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets,
+ active_connection_ids, QuicTime::Delta::Zero()));
+
+ EXPECT_CALL(writer_, WritePacket(_, kConnectionCloseLength,
+ self_address_.host(), peer_address_, _))
+ .Times(2)
+ .WillRepeatedly(Return(WriteResult(WRITE_STATUS_OK, 1)));
+ // Processes IETF short header packet.
+ for (auto const& cid : active_connection_ids) {
+ time_wait_list_manager_.ProcessPacket(
+ self_address_, peer_address_, cid, IETF_QUIC_SHORT_HEADER_PACKET,
+ std::make_unique<QuicPerPacketContext>());
+ }
+}
+
} // namespace
} // namespace test
} // namespace quic