Move MoqtSessionPeer into its own file.
PiperOrigin-RevId: 699983051
diff --git a/build/source_list.bzl b/build/source_list.bzl
index 8c72689..2492384 100644
--- a/build/source_list.bzl
+++ b/build/source_list.bzl
@@ -1525,6 +1525,7 @@
"quic/moqt/moqt_session.h",
"quic/moqt/moqt_subscribe_windows.h",
"quic/moqt/moqt_track.h",
+ "quic/moqt/test_tools/moqt_session_peer.h",
"quic/moqt/test_tools/moqt_simulator_harness.h",
"quic/moqt/test_tools/moqt_test_message.h",
"quic/moqt/tools/chat_client.h",
diff --git a/build/source_list.gni b/build/source_list.gni
index 9c50ad0..b9d9c23 100644
--- a/build/source_list.gni
+++ b/build/source_list.gni
@@ -1529,6 +1529,7 @@
"src/quiche/quic/moqt/moqt_session.h",
"src/quiche/quic/moqt/moqt_subscribe_windows.h",
"src/quiche/quic/moqt/moqt_track.h",
+ "src/quiche/quic/moqt/test_tools/moqt_session_peer.h",
"src/quiche/quic/moqt/test_tools/moqt_simulator_harness.h",
"src/quiche/quic/moqt/test_tools/moqt_test_message.h",
"src/quiche/quic/moqt/tools/chat_client.h",
diff --git a/build/source_list.json b/build/source_list.json
index e4b767e..2e77d42 100644
--- a/build/source_list.json
+++ b/build/source_list.json
@@ -1528,6 +1528,7 @@
"quiche/quic/moqt/moqt_session.h",
"quiche/quic/moqt/moqt_subscribe_windows.h",
"quiche/quic/moqt/moqt_track.h",
+ "quiche/quic/moqt/test_tools/moqt_session_peer.h",
"quiche/quic/moqt/test_tools/moqt_simulator_harness.h",
"quiche/quic/moqt/test_tools/moqt_test_message.h",
"quiche/quic/moqt/tools/chat_client.h",
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index 844b4d5..c9f78ca 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -12,7 +12,6 @@
#include <utility>
#include <vector>
-
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
@@ -26,6 +25,7 @@
#include "quiche/quic/moqt/moqt_priority.h"
#include "quiche/quic/moqt/moqt_publisher.h"
#include "quiche/quic/moqt/moqt_track.h"
+#include "quiche/quic/moqt/test_tools/moqt_session_peer.h"
#include "quiche/quic/moqt/tools/moqt_mock_visitor.h"
#include "quiche/quic/platform/api/quic_test.h"
#include "quiche/quic/test_tools/quic_test_utils.h"
@@ -46,7 +46,6 @@
using ::testing::Return;
using ::testing::StrictMock;
-constexpr webtransport::StreamId kControlStreamId = 4;
constexpr webtransport::StreamId kIncomingUniStreamId = 15;
constexpr webtransport::StreamId kOutgoingUniStreamId = 14;
@@ -76,145 +75,6 @@
} // namespace
-class MockFetchTask : public MoqtFetchTask {
- public:
- MOCK_METHOD(MoqtFetchTask::GetNextObjectResult, GetNextObject,
- (PublishedObject & output), (override));
- MOCK_METHOD(absl::Status, GetStatus, (), (override));
- MOCK_METHOD(FullSequence, GetLargestId, (), (const, override));
-
- void SetObjectAvailableCallback(ObjectsAvailableCallback callback) override {
- callback_ = std::move(callback);
- }
-
- ObjectsAvailableCallback callback_;
-};
-
-class MoqtSessionPeer {
- public:
- static std::unique_ptr<MoqtControlParserVisitor> CreateControlStream(
- MoqtSession* session, webtransport::test::MockStream* stream) {
- auto new_stream =
- std::make_unique<MoqtSession::ControlStream>(session, stream);
- session->control_stream_ = kControlStreamId;
- ON_CALL(*stream, visitor()).WillByDefault(Return(new_stream.get()));
- webtransport::test::MockSession* mock_session =
- static_cast<webtransport::test::MockSession*>(session->session());
- EXPECT_CALL(*mock_session, GetStreamById(kControlStreamId))
- .Times(AnyNumber())
- .WillRepeatedly(Return(stream));
- return new_stream;
- }
-
- static std::unique_ptr<MoqtDataParserVisitor> CreateIncomingDataStream(
- MoqtSession* session, webtransport::Stream* stream) {
- auto new_stream =
- std::make_unique<MoqtSession::IncomingDataStream>(session, stream);
- return new_stream;
- }
-
- // In the test OnSessionReady, the session creates a stream and then passes
- // its unique_ptr to the mock webtransport stream. This function casts
- // that unique_ptr into a MoqtSession::Stream*, which is a private class of
- // MoqtSession, and then casts again into MoqtParserVisitor so that the test
- // can inject packets into that stream.
- // This function is useful for any test that wants to inject packets on a
- // stream created by the MoqtSession.
- static MoqtControlParserVisitor*
- FetchParserVisitorFromWebtransportStreamVisitor(
- MoqtSession* session, webtransport::StreamVisitor* visitor) {
- return static_cast<MoqtSession::ControlStream*>(visitor);
- }
-
- static void CreateRemoteTrack(MoqtSession* session, const FullTrackName& name,
- RemoteTrack::Visitor* visitor,
- uint64_t track_alias) {
- session->remote_tracks_.try_emplace(track_alias, name, track_alias,
- visitor);
- session->remote_track_aliases_.try_emplace(name, track_alias);
- }
-
- static void AddActiveSubscribe(MoqtSession* session, uint64_t subscribe_id,
- MoqtSubscribe& subscribe,
- RemoteTrack::Visitor* visitor) {
- session->active_subscribes_[subscribe_id] = {subscribe, visitor};
- }
-
- static MoqtObjectListener* AddSubscription(
- MoqtSession* session, std::shared_ptr<MoqtTrackPublisher> publisher,
- uint64_t subscribe_id, uint64_t track_alias, uint64_t start_group,
- uint64_t start_object) {
- MoqtSubscribe subscribe;
- subscribe.full_track_name = publisher->GetTrackName();
- subscribe.track_alias = track_alias;
- subscribe.subscribe_id = subscribe_id;
- subscribe.start_group = start_group;
- subscribe.start_object = start_object;
- subscribe.subscriber_priority = 0x80;
- session->published_subscriptions_.emplace(
- subscribe_id, std::make_unique<MoqtSession::PublishedSubscription>(
- session, std::move(publisher), subscribe,
- /*monitoring_interface=*/nullptr));
- return session->published_subscriptions_[subscribe_id].get();
- }
-
- static void DeleteSubscription(MoqtSession* session, uint64_t subscribe_id) {
- session->published_subscriptions_.erase(subscribe_id);
- }
-
- static void UpdateSubscriberPriority(MoqtSession* session,
- uint64_t subscribe_id,
- MoqtPriority priority) {
- session->published_subscriptions_[subscribe_id]->set_subscriber_priority(
- priority);
- }
-
- static void set_peer_role(MoqtSession* session, MoqtRole role) {
- session->peer_role_ = role;
- }
-
- static RemoteTrack& remote_track(MoqtSession* session, uint64_t track_alias) {
- return session->remote_tracks_.find(track_alias)->second;
- }
-
- static void set_next_subscribe_id(MoqtSession* session, uint64_t id) {
- session->next_subscribe_id_ = id;
- }
-
- static void set_peer_max_subscribe_id(MoqtSession* session, uint64_t id) {
- session->peer_max_subscribe_id_ = id;
- }
-
- static MockFetchTask* AddFetch(MoqtSession* session, uint64_t fetch_id) {
- auto fetch_task = std::make_unique<MockFetchTask>();
- MockFetchTask* return_ptr = fetch_task.get();
- auto published_fetch = std::make_unique<MoqtSession::PublishedFetch>(
- fetch_id, session, std::move(fetch_task));
- session->incoming_fetches_.emplace(fetch_id, std::move(published_fetch));
- // Add the fetch to the pending stream queue.
- session->UpdateQueuedSendOrder(fetch_id, std::nullopt, 0);
- return return_ptr;
- }
-
- static MoqtSession::PublishedFetch* GetFetch(MoqtSession* session,
- uint64_t fetch_id) {
- auto it = session->incoming_fetches_.find(fetch_id);
- if (it == session->incoming_fetches_.end()) {
- return nullptr;
- }
- return it->second.get();
- }
-
- static void ValidateSubscribeId(MoqtSession* session, uint64_t id) {
- session->ValidateSubscribeId(id);
- }
-
- static FullSequence LargestSentForSubscription(MoqtSession* session,
- uint64_t subscribe_id) {
- return *session->published_subscriptions_[subscribe_id]->largest_sent();
- }
-};
-
class MoqtSessionTest : public quic::test::QuicTest {
public:
MoqtSessionTest()
@@ -2191,7 +2051,7 @@
MoqtDataStreamType::kStreamHeaderFetch));
return absl::OkStatus();
});
- fetch_task->callback_();
+ fetch_task->objects_available_callback()();
EXPECT_TRUE(correct_message);
}
diff --git a/quiche/quic/moqt/test_tools/moqt_session_peer.h b/quiche/quic/moqt/test_tools/moqt_session_peer.h
new file mode 100644
index 0000000..d41edd6
--- /dev/null
+++ b/quiche/quic/moqt/test_tools/moqt_session_peer.h
@@ -0,0 +1,156 @@
+// Copyright 2023 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_SESSION_PEER_H_
+#define QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_SESSION_PEER_H_
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <utility>
+
+#include "quiche/quic/moqt/moqt_messages.h"
+#include "quiche/quic/moqt/moqt_parser.h"
+#include "quiche/quic/moqt/moqt_priority.h"
+#include "quiche/quic/moqt/moqt_publisher.h"
+#include "quiche/quic/moqt/moqt_session.h"
+#include "quiche/quic/moqt/moqt_track.h"
+#include "quiche/quic/moqt/tools/moqt_mock_visitor.h"
+#include "quiche/quic/platform/api/quic_test.h"
+#include "quiche/web_transport/test_tools/mock_web_transport.h"
+#include "quiche/web_transport/web_transport.h"
+
+namespace moqt::test {
+
+class MoqtSessionPeer {
+ public:
+ static constexpr webtransport::StreamId kControlStreamId = 4;
+
+ static std::unique_ptr<MoqtControlParserVisitor> CreateControlStream(
+ MoqtSession* session, webtransport::test::MockStream* stream) {
+ auto new_stream =
+ std::make_unique<MoqtSession::ControlStream>(session, stream);
+ session->control_stream_ = kControlStreamId;
+ ON_CALL(*stream, visitor())
+ .WillByDefault(::testing::Return(new_stream.get()));
+ webtransport::test::MockSession* mock_session =
+ static_cast<webtransport::test::MockSession*>(session->session());
+ EXPECT_CALL(*mock_session, GetStreamById(kControlStreamId))
+ .Times(::testing::AnyNumber())
+ .WillRepeatedly(::testing::Return(stream));
+ return new_stream;
+ }
+
+ static std::unique_ptr<MoqtDataParserVisitor> CreateIncomingDataStream(
+ MoqtSession* session, webtransport::Stream* stream) {
+ auto new_stream =
+ std::make_unique<MoqtSession::IncomingDataStream>(session, stream);
+ return new_stream;
+ }
+
+ // In the test OnSessionReady, the session creates a stream and then passes
+ // its unique_ptr to the mock webtransport stream. This function casts
+ // that unique_ptr into a MoqtSession::Stream*, which is a private class of
+ // MoqtSession, and then casts again into MoqtParserVisitor so that the test
+ // can inject packets into that stream.
+ // This function is useful for any test that wants to inject packets on a
+ // stream created by the MoqtSession.
+ static MoqtControlParserVisitor*
+ FetchParserVisitorFromWebtransportStreamVisitor(
+ MoqtSession* session, webtransport::StreamVisitor* visitor) {
+ return static_cast<MoqtSession::ControlStream*>(visitor);
+ }
+
+ static void CreateRemoteTrack(MoqtSession* session, const FullTrackName& name,
+ RemoteTrack::Visitor* visitor,
+ uint64_t track_alias) {
+ session->remote_tracks_.try_emplace(track_alias, name, track_alias,
+ visitor);
+ session->remote_track_aliases_.try_emplace(name, track_alias);
+ }
+
+ static void AddActiveSubscribe(MoqtSession* session, uint64_t subscribe_id,
+ MoqtSubscribe& subscribe,
+ RemoteTrack::Visitor* visitor) {
+ session->active_subscribes_[subscribe_id] = {subscribe, visitor};
+ }
+
+ static MoqtObjectListener* AddSubscription(
+ MoqtSession* session, std::shared_ptr<MoqtTrackPublisher> publisher,
+ uint64_t subscribe_id, uint64_t track_alias, uint64_t start_group,
+ uint64_t start_object) {
+ MoqtSubscribe subscribe;
+ subscribe.full_track_name = publisher->GetTrackName();
+ subscribe.track_alias = track_alias;
+ subscribe.subscribe_id = subscribe_id;
+ subscribe.start_group = start_group;
+ subscribe.start_object = start_object;
+ subscribe.subscriber_priority = 0x80;
+ session->published_subscriptions_.emplace(
+ subscribe_id, std::make_unique<MoqtSession::PublishedSubscription>(
+ session, std::move(publisher), subscribe,
+ /*monitoring_interface=*/nullptr));
+ return session->published_subscriptions_[subscribe_id].get();
+ }
+
+ static void DeleteSubscription(MoqtSession* session, uint64_t subscribe_id) {
+ session->published_subscriptions_.erase(subscribe_id);
+ }
+
+ static void UpdateSubscriberPriority(MoqtSession* session,
+ uint64_t subscribe_id,
+ MoqtPriority priority) {
+ session->published_subscriptions_[subscribe_id]->set_subscriber_priority(
+ priority);
+ }
+
+ static void set_peer_role(MoqtSession* session, MoqtRole role) {
+ session->peer_role_ = role;
+ }
+
+ static RemoteTrack& remote_track(MoqtSession* session, uint64_t track_alias) {
+ return session->remote_tracks_.find(track_alias)->second;
+ }
+
+ static void set_next_subscribe_id(MoqtSession* session, uint64_t id) {
+ session->next_subscribe_id_ = id;
+ }
+
+ static void set_peer_max_subscribe_id(MoqtSession* session, uint64_t id) {
+ session->peer_max_subscribe_id_ = id;
+ }
+
+ static MockFetchTask* AddFetch(MoqtSession* session, uint64_t fetch_id) {
+ auto fetch_task = std::make_unique<MockFetchTask>();
+ MockFetchTask* return_ptr = fetch_task.get();
+ auto published_fetch = std::make_unique<MoqtSession::PublishedFetch>(
+ fetch_id, session, std::move(fetch_task));
+ session->incoming_fetches_.emplace(fetch_id, std::move(published_fetch));
+ // Add the fetch to the pending stream queue.
+ session->UpdateQueuedSendOrder(fetch_id, std::nullopt, 0);
+ return return_ptr;
+ }
+
+ static MoqtSession::PublishedFetch* GetFetch(MoqtSession* session,
+ uint64_t fetch_id) {
+ auto it = session->incoming_fetches_.find(fetch_id);
+ if (it == session->incoming_fetches_.end()) {
+ return nullptr;
+ }
+ return it->second.get();
+ }
+
+ static void ValidateSubscribeId(MoqtSession* session, uint64_t id) {
+ session->ValidateSubscribeId(id);
+ }
+
+ static FullSequence LargestSentForSubscription(MoqtSession* session,
+ uint64_t subscribe_id) {
+ return *session->published_subscriptions_[subscribe_id]->largest_sent();
+ }
+};
+
+} // namespace moqt::test
+
+#endif // QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_SESSION_PEER_H_
diff --git a/quiche/quic/moqt/tools/moqt_mock_visitor.h b/quiche/quic/moqt/tools/moqt_mock_visitor.h
index 87e3541..972fc0c 100644
--- a/quiche/quic/moqt/tools/moqt_mock_visitor.h
+++ b/quiche/quic/moqt/tools/moqt_mock_visitor.h
@@ -11,6 +11,7 @@
#include <utility>
#include <vector>
+#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "quiche/quic/core/quic_time.h"
@@ -100,6 +101,24 @@
(override));
};
+class MockFetchTask : public MoqtFetchTask {
+ public:
+ MOCK_METHOD(MoqtFetchTask::GetNextObjectResult, GetNextObject,
+ (PublishedObject & output), (override));
+ MOCK_METHOD(absl::Status, GetStatus, (), (override));
+ MOCK_METHOD(FullSequence, GetLargestId, (), (const, override));
+
+ void SetObjectAvailableCallback(ObjectsAvailableCallback callback) override {
+ objects_available_callback_ = std::move(callback);
+ }
+ ObjectsAvailableCallback& objects_available_callback() {
+ return objects_available_callback_;
+ };
+
+ private:
+ ObjectsAvailableCallback objects_available_callback_;
+};
+
} // namespace moqt::test
#endif // QUICHE_QUIC_MOQT_TOOLS_MOQT_MOCK_VISITOR_H_