Move MoqtSessionPeer into its own file. PiperOrigin-RevId: 699983051
diff --git a/build/source_list.bzl b/build/source_list.bzl index 8c72689..2492384 100644 --- a/build/source_list.bzl +++ b/build/source_list.bzl
@@ -1525,6 +1525,7 @@ "quic/moqt/moqt_session.h", "quic/moqt/moqt_subscribe_windows.h", "quic/moqt/moqt_track.h", + "quic/moqt/test_tools/moqt_session_peer.h", "quic/moqt/test_tools/moqt_simulator_harness.h", "quic/moqt/test_tools/moqt_test_message.h", "quic/moqt/tools/chat_client.h",
diff --git a/build/source_list.gni b/build/source_list.gni index 9c50ad0..b9d9c23 100644 --- a/build/source_list.gni +++ b/build/source_list.gni
@@ -1529,6 +1529,7 @@ "src/quiche/quic/moqt/moqt_session.h", "src/quiche/quic/moqt/moqt_subscribe_windows.h", "src/quiche/quic/moqt/moqt_track.h", + "src/quiche/quic/moqt/test_tools/moqt_session_peer.h", "src/quiche/quic/moqt/test_tools/moqt_simulator_harness.h", "src/quiche/quic/moqt/test_tools/moqt_test_message.h", "src/quiche/quic/moqt/tools/chat_client.h",
diff --git a/build/source_list.json b/build/source_list.json index e4b767e..2e77d42 100644 --- a/build/source_list.json +++ b/build/source_list.json
@@ -1528,6 +1528,7 @@ "quiche/quic/moqt/moqt_session.h", "quiche/quic/moqt/moqt_subscribe_windows.h", "quiche/quic/moqt/moqt_track.h", + "quiche/quic/moqt/test_tools/moqt_session_peer.h", "quiche/quic/moqt/test_tools/moqt_simulator_harness.h", "quiche/quic/moqt/test_tools/moqt_test_message.h", "quiche/quic/moqt/tools/chat_client.h",
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 844b4d5..c9f78ca 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -12,7 +12,6 @@ #include <utility> #include <vector> - #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" @@ -26,6 +25,7 @@ #include "quiche/quic/moqt/moqt_priority.h" #include "quiche/quic/moqt/moqt_publisher.h" #include "quiche/quic/moqt/moqt_track.h" +#include "quiche/quic/moqt/test_tools/moqt_session_peer.h" #include "quiche/quic/moqt/tools/moqt_mock_visitor.h" #include "quiche/quic/platform/api/quic_test.h" #include "quiche/quic/test_tools/quic_test_utils.h" @@ -46,7 +46,6 @@ using ::testing::Return; using ::testing::StrictMock; -constexpr webtransport::StreamId kControlStreamId = 4; constexpr webtransport::StreamId kIncomingUniStreamId = 15; constexpr webtransport::StreamId kOutgoingUniStreamId = 14; @@ -76,145 +75,6 @@ } // namespace -class MockFetchTask : public MoqtFetchTask { - public: - MOCK_METHOD(MoqtFetchTask::GetNextObjectResult, GetNextObject, - (PublishedObject & output), (override)); - MOCK_METHOD(absl::Status, GetStatus, (), (override)); - MOCK_METHOD(FullSequence, GetLargestId, (), (const, override)); - - void SetObjectAvailableCallback(ObjectsAvailableCallback callback) override { - callback_ = std::move(callback); - } - - ObjectsAvailableCallback callback_; -}; - -class MoqtSessionPeer { - public: - static std::unique_ptr<MoqtControlParserVisitor> CreateControlStream( - MoqtSession* session, webtransport::test::MockStream* stream) { - auto new_stream = - std::make_unique<MoqtSession::ControlStream>(session, stream); - session->control_stream_ = kControlStreamId; - ON_CALL(*stream, visitor()).WillByDefault(Return(new_stream.get())); - webtransport::test::MockSession* mock_session = - static_cast<webtransport::test::MockSession*>(session->session()); - EXPECT_CALL(*mock_session, GetStreamById(kControlStreamId)) - .Times(AnyNumber()) - .WillRepeatedly(Return(stream)); - return new_stream; - } - - static std::unique_ptr<MoqtDataParserVisitor> CreateIncomingDataStream( - MoqtSession* session, webtransport::Stream* stream) { - auto new_stream = - std::make_unique<MoqtSession::IncomingDataStream>(session, stream); - return new_stream; - } - - // In the test OnSessionReady, the session creates a stream and then passes - // its unique_ptr to the mock webtransport stream. This function casts - // that unique_ptr into a MoqtSession::Stream*, which is a private class of - // MoqtSession, and then casts again into MoqtParserVisitor so that the test - // can inject packets into that stream. - // This function is useful for any test that wants to inject packets on a - // stream created by the MoqtSession. - static MoqtControlParserVisitor* - FetchParserVisitorFromWebtransportStreamVisitor( - MoqtSession* session, webtransport::StreamVisitor* visitor) { - return static_cast<MoqtSession::ControlStream*>(visitor); - } - - static void CreateRemoteTrack(MoqtSession* session, const FullTrackName& name, - RemoteTrack::Visitor* visitor, - uint64_t track_alias) { - session->remote_tracks_.try_emplace(track_alias, name, track_alias, - visitor); - session->remote_track_aliases_.try_emplace(name, track_alias); - } - - static void AddActiveSubscribe(MoqtSession* session, uint64_t subscribe_id, - MoqtSubscribe& subscribe, - RemoteTrack::Visitor* visitor) { - session->active_subscribes_[subscribe_id] = {subscribe, visitor}; - } - - static MoqtObjectListener* AddSubscription( - MoqtSession* session, std::shared_ptr<MoqtTrackPublisher> publisher, - uint64_t subscribe_id, uint64_t track_alias, uint64_t start_group, - uint64_t start_object) { - MoqtSubscribe subscribe; - subscribe.full_track_name = publisher->GetTrackName(); - subscribe.track_alias = track_alias; - subscribe.subscribe_id = subscribe_id; - subscribe.start_group = start_group; - subscribe.start_object = start_object; - subscribe.subscriber_priority = 0x80; - session->published_subscriptions_.emplace( - subscribe_id, std::make_unique<MoqtSession::PublishedSubscription>( - session, std::move(publisher), subscribe, - /*monitoring_interface=*/nullptr)); - return session->published_subscriptions_[subscribe_id].get(); - } - - static void DeleteSubscription(MoqtSession* session, uint64_t subscribe_id) { - session->published_subscriptions_.erase(subscribe_id); - } - - static void UpdateSubscriberPriority(MoqtSession* session, - uint64_t subscribe_id, - MoqtPriority priority) { - session->published_subscriptions_[subscribe_id]->set_subscriber_priority( - priority); - } - - static void set_peer_role(MoqtSession* session, MoqtRole role) { - session->peer_role_ = role; - } - - static RemoteTrack& remote_track(MoqtSession* session, uint64_t track_alias) { - return session->remote_tracks_.find(track_alias)->second; - } - - static void set_next_subscribe_id(MoqtSession* session, uint64_t id) { - session->next_subscribe_id_ = id; - } - - static void set_peer_max_subscribe_id(MoqtSession* session, uint64_t id) { - session->peer_max_subscribe_id_ = id; - } - - static MockFetchTask* AddFetch(MoqtSession* session, uint64_t fetch_id) { - auto fetch_task = std::make_unique<MockFetchTask>(); - MockFetchTask* return_ptr = fetch_task.get(); - auto published_fetch = std::make_unique<MoqtSession::PublishedFetch>( - fetch_id, session, std::move(fetch_task)); - session->incoming_fetches_.emplace(fetch_id, std::move(published_fetch)); - // Add the fetch to the pending stream queue. - session->UpdateQueuedSendOrder(fetch_id, std::nullopt, 0); - return return_ptr; - } - - static MoqtSession::PublishedFetch* GetFetch(MoqtSession* session, - uint64_t fetch_id) { - auto it = session->incoming_fetches_.find(fetch_id); - if (it == session->incoming_fetches_.end()) { - return nullptr; - } - return it->second.get(); - } - - static void ValidateSubscribeId(MoqtSession* session, uint64_t id) { - session->ValidateSubscribeId(id); - } - - static FullSequence LargestSentForSubscription(MoqtSession* session, - uint64_t subscribe_id) { - return *session->published_subscriptions_[subscribe_id]->largest_sent(); - } -}; - class MoqtSessionTest : public quic::test::QuicTest { public: MoqtSessionTest() @@ -2191,7 +2051,7 @@ MoqtDataStreamType::kStreamHeaderFetch)); return absl::OkStatus(); }); - fetch_task->callback_(); + fetch_task->objects_available_callback()(); EXPECT_TRUE(correct_message); }
diff --git a/quiche/quic/moqt/test_tools/moqt_session_peer.h b/quiche/quic/moqt/test_tools/moqt_session_peer.h new file mode 100644 index 0000000..d41edd6 --- /dev/null +++ b/quiche/quic/moqt/test_tools/moqt_session_peer.h
@@ -0,0 +1,156 @@ +// Copyright 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_SESSION_PEER_H_ +#define QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_SESSION_PEER_H_ + +#include <cstdint> +#include <memory> +#include <optional> +#include <utility> + +#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_parser.h" +#include "quiche/quic/moqt/moqt_priority.h" +#include "quiche/quic/moqt/moqt_publisher.h" +#include "quiche/quic/moqt/moqt_session.h" +#include "quiche/quic/moqt/moqt_track.h" +#include "quiche/quic/moqt/tools/moqt_mock_visitor.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/web_transport/test_tools/mock_web_transport.h" +#include "quiche/web_transport/web_transport.h" + +namespace moqt::test { + +class MoqtSessionPeer { + public: + static constexpr webtransport::StreamId kControlStreamId = 4; + + static std::unique_ptr<MoqtControlParserVisitor> CreateControlStream( + MoqtSession* session, webtransport::test::MockStream* stream) { + auto new_stream = + std::make_unique<MoqtSession::ControlStream>(session, stream); + session->control_stream_ = kControlStreamId; + ON_CALL(*stream, visitor()) + .WillByDefault(::testing::Return(new_stream.get())); + webtransport::test::MockSession* mock_session = + static_cast<webtransport::test::MockSession*>(session->session()); + EXPECT_CALL(*mock_session, GetStreamById(kControlStreamId)) + .Times(::testing::AnyNumber()) + .WillRepeatedly(::testing::Return(stream)); + return new_stream; + } + + static std::unique_ptr<MoqtDataParserVisitor> CreateIncomingDataStream( + MoqtSession* session, webtransport::Stream* stream) { + auto new_stream = + std::make_unique<MoqtSession::IncomingDataStream>(session, stream); + return new_stream; + } + + // In the test OnSessionReady, the session creates a stream and then passes + // its unique_ptr to the mock webtransport stream. This function casts + // that unique_ptr into a MoqtSession::Stream*, which is a private class of + // MoqtSession, and then casts again into MoqtParserVisitor so that the test + // can inject packets into that stream. + // This function is useful for any test that wants to inject packets on a + // stream created by the MoqtSession. + static MoqtControlParserVisitor* + FetchParserVisitorFromWebtransportStreamVisitor( + MoqtSession* session, webtransport::StreamVisitor* visitor) { + return static_cast<MoqtSession::ControlStream*>(visitor); + } + + static void CreateRemoteTrack(MoqtSession* session, const FullTrackName& name, + RemoteTrack::Visitor* visitor, + uint64_t track_alias) { + session->remote_tracks_.try_emplace(track_alias, name, track_alias, + visitor); + session->remote_track_aliases_.try_emplace(name, track_alias); + } + + static void AddActiveSubscribe(MoqtSession* session, uint64_t subscribe_id, + MoqtSubscribe& subscribe, + RemoteTrack::Visitor* visitor) { + session->active_subscribes_[subscribe_id] = {subscribe, visitor}; + } + + static MoqtObjectListener* AddSubscription( + MoqtSession* session, std::shared_ptr<MoqtTrackPublisher> publisher, + uint64_t subscribe_id, uint64_t track_alias, uint64_t start_group, + uint64_t start_object) { + MoqtSubscribe subscribe; + subscribe.full_track_name = publisher->GetTrackName(); + subscribe.track_alias = track_alias; + subscribe.subscribe_id = subscribe_id; + subscribe.start_group = start_group; + subscribe.start_object = start_object; + subscribe.subscriber_priority = 0x80; + session->published_subscriptions_.emplace( + subscribe_id, std::make_unique<MoqtSession::PublishedSubscription>( + session, std::move(publisher), subscribe, + /*monitoring_interface=*/nullptr)); + return session->published_subscriptions_[subscribe_id].get(); + } + + static void DeleteSubscription(MoqtSession* session, uint64_t subscribe_id) { + session->published_subscriptions_.erase(subscribe_id); + } + + static void UpdateSubscriberPriority(MoqtSession* session, + uint64_t subscribe_id, + MoqtPriority priority) { + session->published_subscriptions_[subscribe_id]->set_subscriber_priority( + priority); + } + + static void set_peer_role(MoqtSession* session, MoqtRole role) { + session->peer_role_ = role; + } + + static RemoteTrack& remote_track(MoqtSession* session, uint64_t track_alias) { + return session->remote_tracks_.find(track_alias)->second; + } + + static void set_next_subscribe_id(MoqtSession* session, uint64_t id) { + session->next_subscribe_id_ = id; + } + + static void set_peer_max_subscribe_id(MoqtSession* session, uint64_t id) { + session->peer_max_subscribe_id_ = id; + } + + static MockFetchTask* AddFetch(MoqtSession* session, uint64_t fetch_id) { + auto fetch_task = std::make_unique<MockFetchTask>(); + MockFetchTask* return_ptr = fetch_task.get(); + auto published_fetch = std::make_unique<MoqtSession::PublishedFetch>( + fetch_id, session, std::move(fetch_task)); + session->incoming_fetches_.emplace(fetch_id, std::move(published_fetch)); + // Add the fetch to the pending stream queue. + session->UpdateQueuedSendOrder(fetch_id, std::nullopt, 0); + return return_ptr; + } + + static MoqtSession::PublishedFetch* GetFetch(MoqtSession* session, + uint64_t fetch_id) { + auto it = session->incoming_fetches_.find(fetch_id); + if (it == session->incoming_fetches_.end()) { + return nullptr; + } + return it->second.get(); + } + + static void ValidateSubscribeId(MoqtSession* session, uint64_t id) { + session->ValidateSubscribeId(id); + } + + static FullSequence LargestSentForSubscription(MoqtSession* session, + uint64_t subscribe_id) { + return *session->published_subscriptions_[subscribe_id]->largest_sent(); + } +}; + +} // namespace moqt::test + +#endif // QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_SESSION_PEER_H_
diff --git a/quiche/quic/moqt/tools/moqt_mock_visitor.h b/quiche/quic/moqt/tools/moqt_mock_visitor.h index 87e3541..972fc0c 100644 --- a/quiche/quic/moqt/tools/moqt_mock_visitor.h +++ b/quiche/quic/moqt/tools/moqt_mock_visitor.h
@@ -11,6 +11,7 @@ #include <utility> #include <vector> +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "quiche/quic/core/quic_time.h" @@ -100,6 +101,24 @@ (override)); }; +class MockFetchTask : public MoqtFetchTask { + public: + MOCK_METHOD(MoqtFetchTask::GetNextObjectResult, GetNextObject, + (PublishedObject & output), (override)); + MOCK_METHOD(absl::Status, GetStatus, (), (override)); + MOCK_METHOD(FullSequence, GetLargestId, (), (const, override)); + + void SetObjectAvailableCallback(ObjectsAvailableCallback callback) override { + objects_available_callback_ = std::move(callback); + } + ObjectsAvailableCallback& objects_available_callback() { + return objects_available_callback_; + }; + + private: + ObjectsAvailableCallback objects_available_callback_; +}; + } // namespace moqt::test #endif // QUICHE_QUIC_MOQT_TOOLS_MOQT_MOCK_VISITOR_H_