Add multiple CID per connection support to time_wait_list. Protected by FLAGS_quic_restart_flag_quic_time_wait_list_support_multiple_cid. PiperOrigin-RevId: 350245863 Change-Id: Ida67aeabe3554d6229859fa0179d66c560da53d3
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc index 3ea815b..0a093f1 100644 --- a/quic/core/quic_connection.cc +++ b/quic/core/quic_connection.cc
@@ -5562,6 +5562,11 @@ OnSuccessfulMigration(); } +std::vector<QuicConnectionId> QuicConnection::GetActiveServerConnectionIds() + const { + return {server_connection_id_}; +} + void QuicConnection::SetUnackedMapInitialCapacity() { sent_packet_manager_.ReserveUnackedPacketsInitialCapacity( GetUnackedMapInitialCapacity());
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h index 48b560f..64e83b0 100644 --- a/quic/core/quic_connection.h +++ b/quic/core/quic_connection.h
@@ -1178,6 +1178,8 @@ : encryption_level_); } + virtual std::vector<QuicConnectionId> GetActiveServerConnectionIds() const; + protected: // Calls cancel() on all the alarms owned by this connection. void CancelAllAlarms();
diff --git a/quic/core/quic_dispatcher.cc b/quic/core/quic_dispatcher.cc index 9e17539..8003dbb 100644 --- a/quic/core/quic_dispatcher.cc +++ b/quic/core/quic_dispatcher.cc
@@ -160,7 +160,8 @@ time_wait_list_manager_->AddConnectionIdToTimeWait( server_connection_id_, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, - TimeWaitConnectionInfo(ietf_quic, collector_.packets())); + TimeWaitConnectionInfo(ietf_quic, collector_.packets(), + {server_connection_id_})); } private: @@ -831,6 +832,7 @@ TimeWaitConnectionInfo( connection->version().HasIetfInvariantHeader(), connection->termination_packets(), + connection->GetActiveServerConnectionIds(), connection->sent_packet_manager().GetRttStats()->smoothed_rtt())); } @@ -1072,7 +1074,8 @@ << ", error_details:" << error_details; time_wait_list_manager_->AddConnectionIdToTimeWait( server_connection_id, action, - TimeWaitConnectionInfo(format != GOOGLE_QUIC_PACKET, nullptr)); + TimeWaitConnectionInfo(format != GOOGLE_QUIC_PACKET, nullptr, + {server_connection_id})); return; } @@ -1108,7 +1111,7 @@ time_wait_list_manager()->AddConnectionIdToTimeWait( server_connection_id, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, TimeWaitConnectionInfo(/*ietf_quic=*/format != GOOGLE_QUIC_PACKET, - &termination_packets)); + &termination_packets, {server_connection_id})); } bool QuicDispatcher::ShouldCreateSessionForUnknownVersion(
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h index f64344c..fd5da05 100644 --- a/quic/core/quic_flags_list.h +++ b/quic/core/quic_flags_list.h
@@ -75,4 +75,5 @@ QUIC_FLAG(FLAGS_quic_restart_flag_quic_support_release_time_for_gso, false) QUIC_FLAG(FLAGS_quic_restart_flag_quic_testonly_default_false, false) QUIC_FLAG(FLAGS_quic_restart_flag_quic_testonly_default_true, true) +QUIC_FLAG(FLAGS_quic_restart_flag_quic_time_wait_list_support_multiple_cid, false) QUIC_FLAG(FLAGS_quic_restart_flag_quic_use_reference_counted_sesssion_map, false)
diff --git a/quic/core/quic_time_wait_list_manager.cc b/quic/core/quic_time_wait_list_manager.cc index c82fb4f..11fca70 100644 --- a/quic/core/quic_time_wait_list_manager.cc +++ b/quic/core/quic_time_wait_list_manager.cc
@@ -49,16 +49,21 @@ TimeWaitConnectionInfo::TimeWaitConnectionInfo( bool ietf_quic, - std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets) + std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets, + std::vector<QuicConnectionId> active_connection_ids) : TimeWaitConnectionInfo(ietf_quic, termination_packets, + std::move(active_connection_ids), QuicTime::Delta::Zero()) {} TimeWaitConnectionInfo::TimeWaitConnectionInfo( bool ietf_quic, std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets, + std::vector<QuicConnectionId> active_connection_ids, QuicTime::Delta srtt) - : ietf_quic(ietf_quic), srtt(srtt) { + : ietf_quic(ietf_quic), + active_connection_ids(std::move(active_connection_ids)), + srtt(srtt) { if (termination_packets != nullptr) { this->termination_packets.swap(*termination_packets); } @@ -76,6 +81,9 @@ clock_(clock), writer_(writer), visitor_(visitor) { + if (use_indirect_connection_id_map_) { + QUIC_RESTART_FLAG_COUNT(quic_time_wait_list_support_multiple_cid); + } SetConnectionIdCleanUpAlarm(); } @@ -83,35 +91,80 @@ connection_id_clean_up_alarm_->Cancel(); } +QuicTimeWaitListManager::ConnectionIdMap::iterator +QuicTimeWaitListManager::FindConnectionIdDataInMap( + const QuicConnectionId& connection_id) { + if (!use_indirect_connection_id_map_) { + return connection_id_map_.find(connection_id); + } + auto it = indirect_connection_id_map_.find(connection_id); + if (it == indirect_connection_id_map_.end()) { + return connection_id_map_.end(); + } + return connection_id_map_.find(it->second); +} + +void QuicTimeWaitListManager::AddConnectionIdDataToMap( + const QuicConnectionId& canonical_connection_id, + int num_packets, + TimeWaitAction action, + TimeWaitConnectionInfo info) { + if (use_indirect_connection_id_map_) { + for (const auto& cid : info.active_connection_ids) { + indirect_connection_id_map_[cid] = canonical_connection_id; + } + } + ConnectionIdData data(num_packets, clock_->ApproximateNow(), action, + std::move(info)); + connection_id_map_.emplace( + std::make_pair(canonical_connection_id, std::move(data))); +} + +void QuicTimeWaitListManager::RemoveConnectionDataFromMap( + ConnectionIdMap::iterator it) { + if (use_indirect_connection_id_map_) { + for (const auto& cid : it->second.info.active_connection_ids) { + indirect_connection_id_map_.erase(cid); + } + } + connection_id_map_.erase(it); +} + void QuicTimeWaitListManager::AddConnectionIdToTimeWait( QuicConnectionId connection_id, TimeWaitAction action, TimeWaitConnectionInfo info) { + DCHECK(!info.active_connection_ids.empty()); + const QuicConnectionId& canonical_connection_id = + use_indirect_connection_id_map_ ? info.active_connection_ids.front() + : connection_id; DCHECK(action != SEND_TERMINATION_PACKETS || !info.termination_packets.empty()); DCHECK(action != DO_NOTHING || info.ietf_quic); int num_packets = 0; - auto it = connection_id_map_.find(connection_id); + auto it = FindConnectionIdDataInMap(canonical_connection_id); const bool new_connection_id = it == connection_id_map_.end(); if (!new_connection_id) { // Replace record if it is reinserted. num_packets = it->second.num_packets; - connection_id_map_.erase(it); + RemoveConnectionDataFromMap(it); } TrimTimeWaitListIfNeeded(); int64_t max_connections = GetQuicFlag(FLAGS_quic_time_wait_list_max_connections); DCHECK(connection_id_map_.empty() || num_connections() < static_cast<size_t>(max_connections)); - ConnectionIdData data(num_packets, clock_->ApproximateNow(), action, - std::move(info)); - connection_id_map_.emplace(std::make_pair(connection_id, std::move(data))); + AddConnectionIdDataToMap(canonical_connection_id, num_packets, action, + std::move(info)); if (new_connection_id) { - visitor_->OnConnectionAddedToTimeWaitList(connection_id); + visitor_->OnConnectionAddedToTimeWaitList(canonical_connection_id); } } bool QuicTimeWaitListManager::IsConnectionIdInTimeWait( QuicConnectionId connection_id) const { + if (use_indirect_connection_id_map_) { + return indirect_connection_id_map_.contains(connection_id); + } return QuicContainsKey(connection_id_map_, connection_id); } @@ -135,7 +188,7 @@ DCHECK(IsConnectionIdInTimeWait(connection_id)); // TODO(satyamshekhar): Think about handling packets from different peer // addresses. - auto it = connection_id_map_.find(connection_id); + auto it = FindConnectionIdDataInMap(connection_id); DCHECK(it != connection_id_map_.end()); // Increment the received packet count. ConnectionIdData* connection_data = &it->second; @@ -388,7 +441,7 @@ // This connection_id has lived its age, retire it now. QUIC_DLOG(INFO) << "Connection " << it->first << " expired from time wait list"; - connection_id_map_.erase(it); + RemoveConnectionDataFromMap(it); if (expiration_time == QuicTime::Infinite()) { QUIC_CODE_COUNT(quic_time_wait_list_trim_full); } else {
diff --git a/quic/core/quic_time_wait_list_manager.h b/quic/core/quic_time_wait_list_manager.h index 1bb3a68..ad6d45c 100644 --- a/quic/core/quic_time_wait_list_manager.h +++ b/quic/core/quic_time_wait_list_manager.h
@@ -12,6 +12,7 @@ #include <memory> #include "quic/core/quic_blocked_writer_interface.h" +#include "quic/core/quic_connection_id.h" #include "quic/core/quic_framer.h" #include "quic/core/quic_packet_writer.h" #include "quic/core/quic_packets.h" @@ -31,10 +32,12 @@ struct QUIC_NO_EXPORT TimeWaitConnectionInfo { TimeWaitConnectionInfo( bool ietf_quic, - std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets); + std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets, + std::vector<QuicConnectionId> active_connection_ids); TimeWaitConnectionInfo( bool ietf_quic, std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets, + std::vector<QuicConnectionId> active_connection_ids, QuicTime::Delta srtt); TimeWaitConnectionInfo(const TimeWaitConnectionInfo& other) = delete; @@ -44,6 +47,7 @@ bool ietf_quic; std::vector<std::unique_ptr<QuicEncryptedPacket>> termination_packets; + std::vector<QuicConnectionId> active_connection_ids; QuicTime::Delta srtt; }; @@ -278,8 +282,33 @@ using ConnectionIdMap = QuicLinkedHashMap<QuicConnectionId, ConnectionIdData, QuicConnectionIdHash>; + // Do not use find/emplace/erase on this map directly. Use + // FindConnectionIdDataInMap, AddConnectionIdDateToMap, + // RemoveConnectionDataFromMap instead. ConnectionIdMap connection_id_map_; + // TODO(haoyuewang) Consider making connection_id_map_ a map of shared pointer + // and remove the indirect map. + // A connection can have multiple unretired ConnectionIds when it is closed. + // These Ids have the same ConnectionIdData entry in connection_id_map_. To + // find the entry, look up the cannoical ConnectionId in + // indirect_connection_id_map_ first, and look up connection_id_map_ with the + // cannoical ConnectionId. + QuicHashMap<QuicConnectionId, QuicConnectionId, QuicConnectionIdHash> + indirect_connection_id_map_; + + // Find an iterator for the given connection_id. Returns + // connection_id_map_.end() if none found. + ConnectionIdMap::iterator FindConnectionIdDataInMap( + const QuicConnectionId& connection_id); + // Inserts a ConnectionIdData entry to connection_id_map_. + void AddConnectionIdDataToMap(const QuicConnectionId& canonical_connection_id, + int num_packets, + TimeWaitAction action, + TimeWaitConnectionInfo info); + // Removes a ConnectionIdData entry in connection_id_map_. + void RemoveConnectionDataFromMap(ConnectionIdMap::iterator it); + // Pending termination packets that need to be sent out to the peer when we // are given a chance to write by the dispatcher. QuicCircularDeque<std::unique_ptr<QueuedPacket>> pending_packets_queue_; @@ -299,6 +328,11 @@ // Interface that manages blocked writers. Visitor* visitor_; + + // When this is default true, remove the connection_id argument of + // AddConnectionIdToTimeWait. + bool use_indirect_connection_id_map_ = + GetQuicRestartFlag(quic_time_wait_list_support_multiple_cid); }; } // namespace quic
diff --git a/quic/core/quic_time_wait_list_manager_test.cc b/quic/core/quic_time_wait_list_manager_test.cc index cf40777..5ece763 100644 --- a/quic/core/quic_time_wait_list_manager_test.cc +++ b/quic/core/quic_time_wait_list_manager_test.cc
@@ -13,6 +13,7 @@ #include "quic/core/crypto/null_encrypter.h" #include "quic/core/crypto/quic_decrypter.h" #include "quic/core/crypto/quic_encrypter.h" +#include "quic/core/quic_connection_id.h" #include "quic/core/quic_data_reader.h" #include "quic/core/quic_framer.h" #include "quic/core/quic_packet_writer.h" @@ -154,7 +155,7 @@ new QuicEncryptedPacket(nullptr, 0, false))); time_wait_list_manager_.AddConnectionIdToTimeWait( connection_id, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, - TimeWaitConnectionInfo(false, &termination_packets)); + TimeWaitConnectionInfo(false, &termination_packets, {connection_id})); } void AddConnectionId( @@ -164,7 +165,8 @@ std::vector<std::unique_ptr<QuicEncryptedPacket>>* packets) { time_wait_list_manager_.AddConnectionIdToTimeWait( connection_id, action, - TimeWaitConnectionInfo(version.HasIetfInvariantHeader(), packets)); + TimeWaitConnectionInfo(version.HasIetfInvariantHeader(), packets, + {connection_id})); } bool IsConnectionIdInTimeWait(QuicConnectionId connection_id) { @@ -443,6 +445,45 @@ time_wait_list_manager_.num_connections()); } +TEST_F(QuicTimeWaitListManagerTest, + CleanUpOldConnectionIdsForMultipleConnectionIdsPerConnection) { + if (!GetQuicRestartFlag(quic_time_wait_list_support_multiple_cid)) { + return; + } + + connection_id_ = TestConnectionId(7); + const size_t kConnectionCloseLength = 100; + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_)); + std::vector<std::unique_ptr<QuicEncryptedPacket>> termination_packets; + termination_packets.push_back( + std::unique_ptr<QuicEncryptedPacket>(new QuicEncryptedPacket( + new char[kConnectionCloseLength], kConnectionCloseLength, true))); + + // Add a CONNECTION_CLOSE termination packet. + std::vector<QuicConnectionId> active_connection_ids{connection_id_, + TestConnectionId(8)}; + time_wait_list_manager_.AddConnectionIdToTimeWait( + connection_id_, QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS, + TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets, + active_connection_ids, QuicTime::Delta::Zero())); + + EXPECT_TRUE( + time_wait_list_manager_.IsConnectionIdInTimeWait(TestConnectionId(7))); + EXPECT_TRUE( + time_wait_list_manager_.IsConnectionIdInTimeWait(TestConnectionId(8))); + + // Remove these IDs. + const QuicTime::Delta time_wait_period = + QuicTimeWaitListManagerPeer::time_wait_period(&time_wait_list_manager_); + clock_.AdvanceTime(time_wait_period); + time_wait_list_manager_.CleanUpOldConnectionIds(); + + EXPECT_FALSE( + time_wait_list_manager_.IsConnectionIdInTimeWait(TestConnectionId(7))); + EXPECT_FALSE( + time_wait_list_manager_.IsConnectionIdInTimeWait(TestConnectionId(8))); +} + TEST_F(QuicTimeWaitListManagerTest, SendQueuedPackets) { QuicConnectionId connection_id = TestConnectionId(1); EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id)); @@ -631,7 +672,8 @@ new char[kConnectionCloseLength], kConnectionCloseLength, true))); time_wait_list_manager_.AddConnectionIdToTimeWait( connection_id_, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, - TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets)); + TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets, + {connection_id_})); // Termination packet is not encrypted, instead, send stateless reset. EXPECT_CALL(writer_, @@ -655,7 +697,8 @@ // Add a CONNECTION_CLOSE termination packet. time_wait_list_manager_.AddConnectionIdToTimeWait( connection_id_, QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS, - TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets)); + TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets, + {connection_id_})); EXPECT_CALL(writer_, WritePacket(_, kConnectionCloseLength, self_address_.host(), peer_address_, _)) .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1))); @@ -666,6 +709,40 @@ IETF_QUIC_SHORT_HEADER_PACKET, std::make_unique<QuicPerPacketContext>()); } +TEST_F(QuicTimeWaitListManagerTest, + SendConnectionClosePacketsForMultipleConnectionIds) { + if (!GetQuicRestartFlag(quic_time_wait_list_support_multiple_cid)) { + return; + } + + connection_id_ = TestConnectionId(7); + const size_t kConnectionCloseLength = 100; + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_)); + std::vector<std::unique_ptr<QuicEncryptedPacket>> termination_packets; + termination_packets.push_back( + std::unique_ptr<QuicEncryptedPacket>(new QuicEncryptedPacket( + new char[kConnectionCloseLength], kConnectionCloseLength, true))); + + // Add a CONNECTION_CLOSE termination packet. + std::vector<QuicConnectionId> active_connection_ids{connection_id_, + TestConnectionId(8)}; + time_wait_list_manager_.AddConnectionIdToTimeWait( + connection_id_, QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS, + TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets, + active_connection_ids, QuicTime::Delta::Zero())); + + EXPECT_CALL(writer_, WritePacket(_, kConnectionCloseLength, + self_address_.host(), peer_address_, _)) + .Times(2) + .WillRepeatedly(Return(WriteResult(WRITE_STATUS_OK, 1))); + // Processes IETF short header packet. + for (auto const& cid : active_connection_ids) { + time_wait_list_manager_.ProcessPacket( + self_address_, peer_address_, cid, IETF_QUIC_SHORT_HEADER_PACKET, + std::make_unique<QuicPerPacketContext>()); + } +} + } // namespace } // namespace test } // namespace quic