Add new & retire ConnectionId interface to QuicSession::visitor and implement them in (Gfe)QuicDispatcher. Protected by FLAGS_quic_restart_flag_quic_dispatcher_support_multiple_cid_per_connection. PiperOrigin-RevId: 350347306 Change-Id: I4f0292f0d56cbcfe19854f09285d9389f729ce08
diff --git a/quic/core/quic_dispatcher.cc b/quic/core/quic_dispatcher.cc index 8003dbb..28f3923 100644 --- a/quic/core/quic_dispatcher.cc +++ b/quic/core/quic_dispatcher.cc
@@ -13,6 +13,7 @@ #include "quic/core/chlo_extractor.h" #include "quic/core/crypto/crypto_protocol.h" #include "quic/core/crypto/quic_random.h" +#include "quic/core/quic_connection_id.h" #include "quic/core/quic_error_codes.h" #include "quic/core/quic_session.h" #include "quic/core/quic_time_wait_list_manager.h" @@ -154,14 +155,16 @@ // |error_code| and |error_details| and add the connection to time wait. void CloseConnection(QuicErrorCode error_code, const std::string& error_details, - bool ietf_quic) { + bool ietf_quic, + std::vector<QuicConnectionId> active_connection_ids) { SerializeConnectionClosePacket(error_code, error_details); time_wait_list_manager_->AddConnectionIdToTimeWait( server_connection_id_, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, TimeWaitConnectionInfo(ietf_quic, collector_.packets(), - {server_connection_id_})); + std::move(active_connection_ids), + /*srtt=*/QuicTime::Delta::Zero())); } private: @@ -323,6 +326,10 @@ if (use_reference_counted_session_map_) { QUIC_RESTART_FLAG_COUNT(quic_use_reference_counted_sesssion_map); } + if (support_multiple_cid_per_connection_) { + QUIC_RESTART_FLAG_COUNT( + quic_dispatcher_support_multiple_cid_per_connection); + } QUIC_BUG_IF(GetSupportedVersions().empty()) << "Trying to create dispatcher without any supported versions"; QUIC_DLOG(INFO) << "Created QuicDispatcher with versions: " @@ -333,6 +340,9 @@ if (use_reference_counted_session_map_) { reference_counted_session_map_.clear(); closed_ref_counted_session_list_.clear(); + if (support_multiple_cid_per_connection_) { + num_sessions_in_session_map_ = 0; + } } else { session_map_.clear(); closed_session_list_.clear(); @@ -807,22 +817,35 @@ } else { QUIC_CODE_COUNT(quic_v44_add_to_time_wait_list_with_handshake_failed); } - action = QuicTimeWaitListManager::SEND_TERMINATION_PACKETS; - // This serializes a connection close termination packet with error code - // QUIC_HANDSHAKE_FAILED and adds the connection to the time wait list. - StatelesslyTerminateConnection( - connection->connection_id(), - connection->version().HasIetfInvariantHeader() - ? IETF_QUIC_LONG_HEADER_PACKET - : GOOGLE_QUIC_PACKET, - /*version_flag=*/true, - connection->version().HasLengthPrefixedConnectionIds(), - connection->version(), QUIC_HANDSHAKE_FAILED, - "Connection is closed by server before handshake confirmed", - // Although it is our intention to send termination packets, the - // |action| argument is not used by this call to - // StatelesslyTerminateConnection(). - action); + if (support_multiple_cid_per_connection_) { + // This serializes a connection close termination packet with error code + // QUIC_HANDSHAKE_FAILED and adds the connection to the time wait list. + StatelessConnectionTerminator terminator( + server_connection_id, connection->version(), helper_.get(), + time_wait_list_manager_.get()); + terminator.CloseConnection( + QUIC_HANDSHAKE_FAILED, + "Connection is closed by server before handshake confirmed", + connection->version().HasIetfInvariantHeader(), + connection->GetActiveServerConnectionIds()); + } else { + action = QuicTimeWaitListManager::SEND_TERMINATION_PACKETS; + // This serializes a connection close termination packet with error code + // QUIC_HANDSHAKE_FAILED and adds the connection to the time wait list. + StatelesslyTerminateConnection( + connection->connection_id(), + connection->version().HasIetfInvariantHeader() + ? IETF_QUIC_LONG_HEADER_PACKET + : GOOGLE_QUIC_PACKET, + /*version_flag=*/true, + connection->version().HasLengthPrefixedConnectionIds(), + connection->version(), QUIC_HANDSHAKE_FAILED, + "Connection is closed by server before handshake confirmed", + // Although it is our intention to send termination packets, the + // |action| argument is not used by this call to + // StatelesslyTerminateConnection(). + action); + } return; } QUIC_CODE_COUNT(quic_v44_add_to_time_wait_list_with_stateless_reset); @@ -1001,7 +1024,15 @@ closed_ref_counted_session_list_.push_back(std::move(it->second)); } CleanUpSession(it->first, connection, source); - reference_counted_session_map_.erase(it); + if (support_multiple_cid_per_connection_) { + for (const QuicConnectionId& cid : + connection->GetActiveServerConnectionIds()) { + reference_counted_session_map_.erase(cid); + } + --num_sessions_in_session_map_; + } else { + reference_counted_session_map_.erase(it); + } } else { auto it = session_map_.find(server_connection_id); if (it == session_map_.end()) { @@ -1052,6 +1083,29 @@ void QuicDispatcher::OnStopSendingReceived( const QuicStopSendingFrame& /*frame*/) {} +void QuicDispatcher::OnNewConnectionIdSent( + const QuicConnectionId& server_connection_id, + const QuicConnectionId& new_connection_id) { + DCHECK(support_multiple_cid_per_connection_); + auto it = reference_counted_session_map_.find(server_connection_id); + if (it == reference_counted_session_map_.end()) { + QUIC_BUG << "Couldn't locate the session that issues the connection ID in " + "reference_counted_session_map_. server_connection_id:" + << server_connection_id + << " new_connection_id: " << new_connection_id; + return; + } + auto insertion_result = reference_counted_session_map_.insert( + std::make_pair(new_connection_id, it->second)); + DCHECK(insertion_result.second); +} + +void QuicDispatcher::OnConnectionIdRetired( + const QuicConnectionId& server_connection_id) { + DCHECK(support_multiple_cid_per_connection_); + reference_counted_session_map_.erase(server_connection_id); +} + void QuicDispatcher::OnConnectionAddedToTimeWaitList( QuicConnectionId server_connection_id) { QUIC_DLOG(INFO) << "Connection " << server_connection_id @@ -1091,8 +1145,9 @@ helper_.get(), time_wait_list_manager_.get()); // This also adds the connection to time wait list. - terminator.CloseConnection(error_code, error_details, - format != GOOGLE_QUIC_PACKET); + terminator.CloseConnection( + error_code, error_details, format != GOOGLE_QUIC_PACKET, + /*active_connection_ids=*/{server_connection_id}); return; } @@ -1164,10 +1219,14 @@ auto insertion_result = reference_counted_session_map_.insert( std::make_pair(server_connection_id, std::shared_ptr<QuicSession>(std::move(session)))); - QUIC_BUG_IF(!insertion_result.second) - << "Tried to add a session to session_map with existing connection " - "id: " - << server_connection_id; + if (!insertion_result.second) { + QUIC_BUG + << "Tried to add a session to session_map with existing connection " + "id: " + << server_connection_id; + } else if (support_multiple_cid_per_connection_) { + ++num_sessions_in_session_map_; + } DeliverPacketsToSession(packets, insertion_result.first->second.get()); } else { auto insertion_result = session_map_.insert( @@ -1279,9 +1338,13 @@ reference_counted_session_map_.insert(std::make_pair( packet_info->destination_connection_id, std::shared_ptr<QuicSession>(std::move(session.release())))); - QUIC_BUG_IF(!insertion_result.second) - << "Tried to add a session to session_map with existing connection id: " - << packet_info->destination_connection_id; + if (!insertion_result.second) { + QUIC_BUG << "Tried to add a session to session_map with existing " + "connection id: " + << packet_info->destination_connection_id; + } else if (support_multiple_cid_per_connection_) { + ++num_sessions_in_session_map_; + } session_ptr = insertion_result.first->second.get(); } else { auto insertion_result = session_map_.insert(std::make_pair( @@ -1365,4 +1428,13 @@ packet_info.form != GOOGLE_QUIC_PACKET, GetPerPacketContext()); } +size_t QuicDispatcher::NumSessions() const { + if (support_multiple_cid_per_connection_) { + return num_sessions_in_session_map_; + } + return use_reference_counted_session_map_ + ? reference_counted_session_map_.size() + : session_map_.size(); +} + } // namespace quic
diff --git a/quic/core/quic_dispatcher.h b/quic/core/quic_dispatcher.h index 756f7f6..75cdacf 100644 --- a/quic/core/quic_dispatcher.h +++ b/quic/core/quic_dispatcher.h
@@ -100,6 +100,19 @@ // Collects reset error code received on streams. void OnStopSendingReceived(const QuicStopSendingFrame& frame) override; + // QuicSession::Visitor interface implementation (via inheritance of + // QuicTimeWaitListManager::Visitor): + // Add the newly issued connection ID to the session map. + void OnNewConnectionIdSent( + const QuicConnectionId& server_connection_id, + const QuicConnectionId& new_connection_id) override; + + // QuicSession::Visitor interface implementation (via inheritance of + // QuicTimeWaitListManager::Visitor): + // Remove the retired connection ID from the session map. + void OnConnectionIdRetired( + const QuicConnectionId& server_connection_id) override; + // QuicTimeWaitListManager::Visitor interface implementation // Called whenever the time wait list manager adds a new connection to the // time-wait list. @@ -114,13 +127,7 @@ std::shared_ptr<QuicSession>, QuicConnectionIdHash>; - // TODO(haoyuewang) Update this function when multiple CIDs per connection are - // supported. - size_t NumSessions() const { - return use_reference_counted_session_map_ - ? reference_counted_session_map_.size() - : session_map_.size(); - } + size_t NumSessions() const; const SessionMap& session_map() const { return session_map_; } @@ -161,6 +168,10 @@ return use_reference_counted_session_map_; } + bool support_multiple_cid_per_connection() const { + return support_multiple_cid_per_connection_; + } + protected: virtual std::unique_ptr<QuicSession> CreateQuicSession( QuicConnectionId server_connection_id, @@ -412,6 +423,9 @@ // TODO(fayang): consider removing last_error_. QuicErrorCode last_error_; + // Number of unique session in session map. + size_t num_sessions_in_session_map_ = 0; + // A backward counter of how many new sessions can be create within current // event loop. When reaches 0, it means can't create sessions for now. int16_t new_sessions_allowed_per_event_loop_; @@ -439,6 +453,10 @@ const bool use_reference_counted_session_map_ = GetQuicRestartFlag(quic_use_reference_counted_sesssion_map); + const bool support_multiple_cid_per_connection_ = + use_reference_counted_session_map_ && + GetQuicRestartFlag(quic_time_wait_list_support_multiple_cid) && + GetQuicRestartFlag(quic_dispatcher_support_multiple_cid_per_connection); }; } // namespace quic
diff --git a/quic/core/quic_dispatcher_test.cc b/quic/core/quic_dispatcher_test.cc index a28dca4..a614199 100644 --- a/quic/core/quic_dispatcher_test.cc +++ b/quic/core/quic_dispatcher_test.cc
@@ -16,7 +16,9 @@ #include "quic/core/crypto/crypto_protocol.h" #include "quic/core/crypto/quic_crypto_server_config.h" #include "quic/core/crypto/quic_random.h" +#include "quic/core/frames/quic_new_connection_id_frame.h" #include "quic/core/quic_config.h" +#include "quic/core/quic_connection.h" #include "quic/core/quic_connection_id.h" #include "quic/core/quic_crypto_stream.h" #include "quic/core/quic_packet_writer_wrapper.h" @@ -33,6 +35,7 @@ #include "quic/test_tools/first_flight.h" #include "quic/test_tools/mock_quic_time_wait_list_manager.h" #include "quic/test_tools/quic_buffered_packet_store_peer.h" +#include "quic/test_tools/quic_connection_peer.h" #include "quic/test_tools/quic_crypto_server_config_peer.h" #include "quic/test_tools/quic_dispatcher_peer.h" #include "quic/test_tools/quic_test_utils.h" @@ -183,7 +186,26 @@ helper, alarm_factory, Perspective::IS_SERVER), - dispatcher_(dispatcher) {} + dispatcher_(dispatcher), + active_connection_ids_({connection_id}) {} + + void AddNewConnectionId(QuicConnectionId id) { + dispatcher_->OnNewConnectionIdSent(active_connection_ids_.back(), id); + QuicConnectionPeer::SetServerConnectionId(this, id); + active_connection_ids_.push_back(id); + } + + void RetireConnectionId(QuicConnectionId id) { + auto it = std::find(active_connection_ids_.begin(), + active_connection_ids_.end(), id); + DCHECK(it != active_connection_ids_.end()); + dispatcher_->OnConnectionIdRetired(id); + active_connection_ids_.erase(it); + } + + std::vector<QuicConnectionId> GetActiveServerConnectionIds() const override { + return active_connection_ids_; + } void UnregisterOnConnectionClosed() { QUIC_LOG(ERROR) << "Unregistering " << connection_id(); @@ -194,6 +216,7 @@ private: QuicDispatcher* dispatcher_; + std::vector<QuicConnectionId> active_connection_ids_; }; class QuicDispatcherTestBase : public QuicTestWithParam<ParsedQuicVersion> { @@ -1945,6 +1968,190 @@ MarkSession1Deleted(); } +class QuicDispatcherSupportMultipleConnectionIdPerConnectionTest + : public QuicDispatcherTestBase { + public: + QuicDispatcherSupportMultipleConnectionIdPerConnectionTest() + : QuicDispatcherTestBase(crypto_test_utils::ProofSourceForTesting()) { + SetQuicRestartFlag(quic_use_reference_counted_sesssion_map, true); + SetQuicRestartFlag(quic_time_wait_list_support_multiple_cid, true); + SetQuicRestartFlag(quic_dispatcher_support_multiple_cid_per_connection, + true); + dispatcher_ = std::make_unique<NiceMock<TestDispatcher>>( + &config_, &crypto_config_, &version_manager_, + mock_helper_.GetRandomGenerator()); + } + void AddConnection1() { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + EXPECT_CALL(*dispatcher_, + CreateQuicSession(_, _, client_address, Eq(ExpectedAlpn()), _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(1), client_address, + &helper_, &alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast<MockQuicConnection*>(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(1), packet); + }))); + EXPECT_CALL(*dispatcher_, + ShouldCreateOrBufferPacketForConnection( + ReceivedPacketInfoConnectionIdEquals(TestConnectionId(1)))); + ProcessFirstFlight(client_address, TestConnectionId(1)); + } + + void AddConnection2() { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 2); + EXPECT_CALL(*dispatcher_, + CreateQuicSession(_, _, client_address, Eq(ExpectedAlpn()), _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(2), client_address, + &helper_, &alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session2_)))); + EXPECT_CALL(*reinterpret_cast<MockQuicConnection*>(session2_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(2), packet); + }))); + EXPECT_CALL(*dispatcher_, + ShouldCreateOrBufferPacketForConnection( + ReceivedPacketInfoConnectionIdEquals(TestConnectionId(2)))); + ProcessFirstFlight(client_address, TestConnectionId(2)); + } + + protected: + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; +}; + +INSTANTIATE_TEST_SUITE_P( + QuicDispatcherSupportMultipleConnectionIdPerConnectionTests, + QuicDispatcherSupportMultipleConnectionIdPerConnectionTest, + ::testing::Values(CurrentSupportedVersions().front()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicDispatcherSupportMultipleConnectionIdPerConnectionTest, + OnNewConnectionIdSent) { + AddConnection1(); + ASSERT_EQ(dispatcher_->NumSessions(), 1u); + ASSERT_THAT(session1_, testing::NotNull()); + MockServerConnection* mock_server_connection1 = + reinterpret_cast<MockServerConnection*>(connection1()); + + { + mock_server_connection1->AddNewConnectionId(TestConnectionId(3)); + EXPECT_EQ(dispatcher_->NumSessions(), 1u); + auto* session = + QuicDispatcherPeer::FindSession(dispatcher_.get(), TestConnectionId(3)); + ASSERT_EQ(session, session1_); + } + + { + mock_server_connection1->AddNewConnectionId(TestConnectionId(4)); + EXPECT_EQ(dispatcher_->NumSessions(), 1u); + auto* session = + QuicDispatcherPeer::FindSession(dispatcher_.get(), TestConnectionId(4)); + ASSERT_EQ(session, session1_); + } + + EXPECT_CALL(*connection1(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _)); + // Would timed out unless all sessions have been removed from the session map. + dispatcher_->Shutdown(); +} + +TEST_P(QuicDispatcherSupportMultipleConnectionIdPerConnectionTest, + RetireConnectionIdFromSingleConnection) { + AddConnection1(); + ASSERT_EQ(dispatcher_->NumSessions(), 1u); + ASSERT_THAT(session1_, testing::NotNull()); + MockServerConnection* mock_server_connection1 = + reinterpret_cast<MockServerConnection*>(connection1()); + + // Adds 1 new connection id every turn and retires 2 connection ids every + // other turn. + for (int i = 2; i < 10; ++i) { + mock_server_connection1->AddNewConnectionId(TestConnectionId(i)); + ASSERT_EQ( + QuicDispatcherPeer::FindSession(dispatcher_.get(), TestConnectionId(i)), + session1_); + ASSERT_EQ(QuicDispatcherPeer::FindSession(dispatcher_.get(), + TestConnectionId(i - 1)), + session1_); + EXPECT_EQ(dispatcher_->NumSessions(), 1u); + if (i % 2 == 1) { + mock_server_connection1->RetireConnectionId(TestConnectionId(i - 2)); + mock_server_connection1->RetireConnectionId(TestConnectionId(i - 1)); + } + } + + EXPECT_CALL(*connection1(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _)); + // Would timed out unless all sessions have been removed from the session map. + dispatcher_->Shutdown(); +} + +TEST_P(QuicDispatcherSupportMultipleConnectionIdPerConnectionTest, + RetireConnectionIdFromMultipleConnections) { + AddConnection1(); + AddConnection2(); + ASSERT_EQ(dispatcher_->NumSessions(), 2u); + MockServerConnection* mock_server_connection1 = + reinterpret_cast<MockServerConnection*>(connection1()); + MockServerConnection* mock_server_connection2 = + reinterpret_cast<MockServerConnection*>(connection2()); + + for (int i = 2; i < 10; ++i) { + mock_server_connection1->AddNewConnectionId(TestConnectionId(2 * i - 1)); + mock_server_connection2->AddNewConnectionId(TestConnectionId(2 * i)); + ASSERT_EQ(QuicDispatcherPeer::FindSession(dispatcher_.get(), + TestConnectionId(2 * i - 1)), + session1_); + ASSERT_EQ(QuicDispatcherPeer::FindSession(dispatcher_.get(), + TestConnectionId(2 * i)), + session2_); + EXPECT_EQ(dispatcher_->NumSessions(), 2u); + mock_server_connection1->RetireConnectionId(TestConnectionId(2 * i - 3)); + mock_server_connection2->RetireConnectionId(TestConnectionId(2 * i - 2)); + } + + mock_server_connection1->AddNewConnectionId(TestConnectionId(19)); + mock_server_connection2->AddNewConnectionId(TestConnectionId(20)); + EXPECT_CALL(*connection1(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _)); + EXPECT_CALL(*connection2(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _)); + // Would timed out unless all sessions have been removed from the session map. + dispatcher_->Shutdown(); +} + +TEST_P(QuicDispatcherSupportMultipleConnectionIdPerConnectionTest, + TimeWaitListPoplulateCorrectly) { + QuicTimeWaitListManager* time_wait_list_manager = + QuicDispatcherPeer::GetTimeWaitListManager(dispatcher_.get()); + AddConnection1(); + MockServerConnection* mock_server_connection1 = + reinterpret_cast<MockServerConnection*>(connection1()); + + mock_server_connection1->AddNewConnectionId(TestConnectionId(2)); + mock_server_connection1->AddNewConnectionId(TestConnectionId(3)); + mock_server_connection1->AddNewConnectionId(TestConnectionId(4)); + mock_server_connection1->RetireConnectionId(TestConnectionId(1)); + mock_server_connection1->RetireConnectionId(TestConnectionId(2)); + + EXPECT_CALL(*connection1(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _)); + connection1()->CloseConnection( + QUIC_PEER_GOING_AWAY, "Close for testing", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + + EXPECT_FALSE( + time_wait_list_manager->IsConnectionIdInTimeWait(TestConnectionId(1))); + EXPECT_FALSE( + time_wait_list_manager->IsConnectionIdInTimeWait(TestConnectionId(2))); + EXPECT_TRUE( + time_wait_list_manager->IsConnectionIdInTimeWait(TestConnectionId(3))); + EXPECT_TRUE( + time_wait_list_manager->IsConnectionIdInTimeWait(TestConnectionId(4))); + + dispatcher_->Shutdown(); +} + class BufferedPacketStoreTest : public QuicDispatcherTestBase { public: BufferedPacketStoreTest()
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h index fd5da05..0c9dd37 100644 --- a/quic/core/quic_flags_list.h +++ b/quic/core/quic_flags_list.h
@@ -67,6 +67,7 @@ QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_use_write_or_buffer_data_at_level, false) QUIC_FLAG(FLAGS_quic_reloadable_flag_send_quic_fallback_server_config_on_leto_error, false) QUIC_FLAG(FLAGS_quic_restart_flag_dont_fetch_quic_private_keys_from_leto, false) +QUIC_FLAG(FLAGS_quic_restart_flag_quic_dispatcher_support_multiple_cid_per_connection, false) QUIC_FLAG(FLAGS_quic_restart_flag_quic_enable_zero_rtt_for_tls_v2, true) QUIC_FLAG(FLAGS_quic_restart_flag_quic_offload_pacing_to_usps2, false) QUIC_FLAG(FLAGS_quic_restart_flag_quic_server_temporarily_retain_tls_zero_rtt_keys, true)
diff --git a/quic/core/quic_session.h b/quic/core/quic_session.h index 5903830..3eb2867 100644 --- a/quic/core/quic_session.h +++ b/quic/core/quic_session.h
@@ -81,6 +81,15 @@ // Called when the session receives a STOP_SENDING for a stream from the // peer. virtual void OnStopSendingReceived(const QuicStopSendingFrame& frame) = 0; + + // Called when a NewConnectionId frame has been sent. + virtual void OnNewConnectionIdSent( + const QuicConnectionId& server_connection_id, + const QuicConnectionId& new_connection_id) = 0; + + // Called when a ConnectionId has been retired. + virtual void OnConnectionIdRetired( + const QuicConnectionId& server_connection_id) = 0; }; // Does not take ownership of |connection| or |visitor|.
diff --git a/quic/test_tools/mock_quic_session_visitor.h b/quic/test_tools/mock_quic_session_visitor.h index e04bc0f..77b1763 100644 --- a/quic/test_tools/mock_quic_session_visitor.h +++ b/quic/test_tools/mock_quic_session_visitor.h
@@ -35,6 +35,15 @@ (const QuicStopSendingFrame& frame), (override)); MOCK_METHOD(void, + OnNewConnectionIdSent, + (const QuicConnectionId& server_connection_id, + const QuicConnectionId& new_connection_id), + (override)); + MOCK_METHOD(void, + OnConnectionIdRetired, + (const quic::QuicConnectionId& server_connection_id), + (override)); + MOCK_METHOD(void, OnConnectionAddedToTimeWaitList, (QuicConnectionId connection_id), (override));
diff --git a/quic/test_tools/quic_dispatcher_peer.cc b/quic/test_tools/quic_dispatcher_peer.cc index 846755f..5736773 100644 --- a/quic/test_tools/quic_dispatcher_peer.cc +++ b/quic/test_tools/quic_dispatcher_peer.cc
@@ -129,5 +129,19 @@ } } +// static +const QuicSession* QuicDispatcherPeer::FindSession( + const QuicDispatcher* dispatcher, + QuicConnectionId id) { + if (dispatcher->use_reference_counted_session_map()) { + auto it = dispatcher->reference_counted_session_map_.find(id); + return (it == dispatcher->reference_counted_session_map_.end()) + ? nullptr + : it->second.get(); + } + auto it = dispatcher->session_map_.find(id); + return (it == dispatcher->session_map_.end()) ? nullptr : it->second.get(); +} + } // namespace test } // namespace quic
diff --git a/quic/test_tools/quic_dispatcher_peer.h b/quic/test_tools/quic_dispatcher_peer.h index 64c8a83..f3be472 100644 --- a/quic/test_tools/quic_dispatcher_peer.h +++ b/quic/test_tools/quic_dispatcher_peer.h
@@ -5,6 +5,7 @@ #ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_DISPATCHER_PEER_H_ #define QUICHE_QUIC_TEST_TOOLS_QUIC_DISPATCHER_PEER_H_ +#include "quic/core/quic_connection_id.h" #include "quic/core/quic_dispatcher.h" namespace quic { @@ -70,6 +71,10 @@ // Get the first session in the session map. Returns nullptr if the map is // empty. static QuicSession* GetFirstSessionIfAny(QuicDispatcher* dispatcher); + + // Find the corresponding session if exsits. + static const QuicSession* FindSession(const QuicDispatcher* dispatcher, + QuicConnectionId id); }; } // namespace test