| // Copyright 2023 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "quiche/quic/moqt/moqt_session.h" |
| |
| #include <cstdint> |
| #include <cstring> |
| #include <memory> |
| #include <optional> |
| #include <string> |
| #include <utility> |
| |
| #include "absl/status/status.h" |
| #include "absl/strings/string_view.h" |
| #include "absl/types/span.h" |
| #include "quiche/quic/core/quic_data_reader.h" |
| #include "quiche/quic/core/quic_time.h" |
| #include "quiche/quic/core/quic_types.h" |
| #include "quiche/quic/moqt/moqt_messages.h" |
| #include "quiche/quic/moqt/moqt_parser.h" |
| #include "quiche/quic/moqt/moqt_subscribe_windows.h" |
| #include "quiche/quic/moqt/moqt_track.h" |
| #include "quiche/quic/moqt/tools/moqt_mock_visitor.h" |
| #include "quiche/quic/platform/api/quic_test.h" |
| #include "quiche/common/quiche_stream.h" |
| #include "quiche/web_transport/test_tools/mock_web_transport.h" |
| #include "quiche/web_transport/web_transport.h" |
| |
| namespace moqt { |
| |
| namespace test { |
| |
| namespace { |
| |
| using ::testing::_; |
| using ::testing::Return; |
| using ::testing::StrictMock; |
| |
// Stream IDs that the tests hand out from mocked GetStreamId() calls.
// kControlStreamId is also the value MoqtSessionPeer::CreateControlStream
// records as the session's control stream.
constexpr webtransport::StreamId kControlStreamId = 4;
constexpr webtransport::StreamId kIncomingUniStreamId = 15;
constexpr webtransport::StreamId kOutgoingUniStreamId = 14;

// Parameters used by the MoqtSessionTest fixture's session: a draft-02 client
// running over WebTransport, with an empty path and partial-object delivery
// disabled.
constexpr MoqtSessionParameters default_parameters = {
    /*version=*/MoqtVersion::kDraft02,
    /*perspective=*/quic::Perspective::IS_CLIENT,
    /*using_webtrans=*/true,
    /*path=*/std::string(),
    /*deliver_partial_objects=*/false,
};
| |
| // Returns nullopt if there is not enough in |message| to extract a type |
| static std::optional<MoqtMessageType> ExtractMessageType( |
| const absl::string_view message) { |
| quic::QuicDataReader reader(message); |
| uint64_t value; |
| if (!reader.ReadVarInt62(&value)) { |
| return std::nullopt; |
| } |
| return static_cast<MoqtMessageType>(value); |
| } |
| |
| } // namespace |
| |
| class MoqtSessionPeer { |
| public: |
| static std::unique_ptr<MoqtParserVisitor> CreateControlStream( |
| MoqtSession* session, webtransport::Stream* stream) { |
| auto new_stream = std::make_unique<MoqtSession::Stream>( |
| session, stream, /*is_control_stream=*/true); |
| session->control_stream_ = kControlStreamId; |
| return new_stream; |
| } |
| |
| static std::unique_ptr<MoqtParserVisitor> CreateUniStream( |
| MoqtSession* session, webtransport::Stream* stream) { |
| auto new_stream = std::make_unique<MoqtSession::Stream>( |
| session, stream, /*is_control_stream=*/false); |
| return new_stream; |
| } |
| |
| // In the test OnSessionReady, the session creates a stream and then passes |
| // its unique_ptr to the mock webtransport stream. This function casts |
| // that unique_ptr into a MoqtSession::Stream*, which is a private class of |
| // MoqtSession, and then casts again into MoqtParserVisitor so that the test |
| // can inject packets into that stream. |
| // This function is useful for any test that wants to inject packets on a |
| // stream created by the MoqtSession. |
| static MoqtParserVisitor* FetchParserVisitorFromWebtransportStreamVisitor( |
| MoqtSession* session, webtransport::StreamVisitor* visitor) { |
| return (MoqtSession::Stream*)visitor; |
| } |
| |
| static void CreateRemoteTrack(MoqtSession* session, const FullTrackName& name, |
| RemoteTrack::Visitor* visitor, |
| uint64_t track_alias) { |
| session->remote_tracks_.try_emplace(track_alias, name, track_alias, |
| visitor); |
| session->remote_track_aliases_.try_emplace(name, track_alias); |
| } |
| |
| static void AddSubscription(MoqtSession* session, FullTrackName& name, |
| uint64_t subscribe_id, uint64_t track_alias, |
| uint64_t start_group, uint64_t start_object) { |
| auto it = session->local_tracks_.find(name); |
| ASSERT_NE(it, session->local_tracks_.end()); |
| LocalTrack& track = it->second; |
| track.set_track_alias(track_alias); |
| track.AddWindow(SubscribeWindow(subscribe_id, start_group, start_object)); |
| session->used_track_aliases_.emplace(track_alias); |
| } |
| |
| static FullSequence next_sequence(MoqtSession* session, FullTrackName& name) { |
| auto it = session->local_tracks_.find(name); |
| EXPECT_NE(it, session->local_tracks_.end()); |
| LocalTrack& track = it->second; |
| return track.next_sequence(); |
| } |
| }; |
| |
| class MoqtSessionTest : public quic::test::QuicTest { |
| public: |
| MoqtSessionTest() |
| : session_(&mock_session_, default_parameters, |
| session_callbacks_.AsSessionCallbacks()) {} |
| ~MoqtSessionTest() { |
| EXPECT_CALL(session_callbacks_.session_deleted_callback, Call()); |
| } |
| |
| MockSessionCallbacks session_callbacks_; |
| StrictMock<webtransport::test::MockSession> mock_session_; |
| MoqtSession session_; |
| }; |
| |
| TEST_F(MoqtSessionTest, Queries) { |
| EXPECT_EQ(session_.perspective(), quic::Perspective::IS_CLIENT); |
| } |
| |
| // Verify the session sends CLIENT_SETUP on the control stream. |
| TEST_F(MoqtSessionTest, OnSessionReady) { |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| EXPECT_CALL(mock_session_, OpenOutgoingBidirectionalStream()) |
| .WillOnce(Return(&mock_stream)); |
| std::unique_ptr<webtransport::StreamVisitor> visitor; |
| // Save a reference to MoqtSession::Stream |
| EXPECT_CALL(mock_stream, SetVisitor(_)) |
| .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> new_visitor) { |
| visitor = std::move(new_visitor); |
| }); |
| EXPECT_CALL(mock_stream, GetStreamId()) |
| .WillOnce(Return(webtransport::StreamId(4))); |
| bool correct_message = false; |
| EXPECT_CALL(mock_stream, Writev(_, _)) |
| .WillOnce([&](absl::Span<const absl::string_view> data, |
| const quiche::StreamWriteOptions& options) { |
| correct_message = true; |
| EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kClientSetup); |
| return absl::OkStatus(); |
| }); |
| session_.OnSessionReady(); |
| EXPECT_TRUE(correct_message); |
| |
| // Receive SERVER_SETUP |
| MoqtParserVisitor* stream_input = |
| MoqtSessionPeer::FetchParserVisitorFromWebtransportStreamVisitor( |
| &session_, visitor.get()); |
| // Handle the server setup |
| MoqtServerSetup setup = { |
| MoqtVersion::kDraft02, |
| MoqtRole::kBoth, |
| }; |
| EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1); |
| stream_input->OnServerSetupMessage(setup); |
| } |
| |
| TEST_F(MoqtSessionTest, OnClientSetup) { |
| MoqtSessionParameters server_parameters = { |
| /*version=*/MoqtVersion::kDraft02, |
| /*perspective=*/quic::Perspective::IS_SERVER, |
| /*using_webtrans=*/true, |
| /*path=*/"", |
| /*deliver_partial_objects=*/false, |
| }; |
| MoqtSession server_session(&mock_session_, server_parameters, |
| session_callbacks_.AsSessionCallbacks()); |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| std::unique_ptr<MoqtParserVisitor> stream_input = |
| MoqtSessionPeer::CreateControlStream(&server_session, &mock_stream); |
| MoqtClientSetup setup = { |
| /*supported_versions=*/{MoqtVersion::kDraft02}, |
| /*role=*/MoqtRole::kBoth, |
| /*path=*/std::nullopt, |
| }; |
| bool correct_message = false; |
| EXPECT_CALL(mock_stream, Writev(_, _)) |
| .WillOnce([&](absl::Span<const absl::string_view> data, |
| const quiche::StreamWriteOptions& options) { |
| correct_message = true; |
| EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kServerSetup); |
| return absl::OkStatus(); |
| }); |
| EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1); |
| stream_input->OnClientSetupMessage(setup); |
| } |
| |
| TEST_F(MoqtSessionTest, OnSessionClosed) { |
| bool reported_error = false; |
| EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_)) |
| .WillOnce([&](absl::string_view error_message) { |
| reported_error = true; |
| EXPECT_EQ(error_message, "foo"); |
| }); |
| session_.OnSessionClosed(webtransport::SessionErrorCode(1), "foo"); |
| EXPECT_TRUE(reported_error); |
| } |
| |
| TEST_F(MoqtSessionTest, OnIncomingBidirectionalStream) { |
| ::testing::InSequence seq; |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| StrictMock<webtransport::test::MockStreamVisitor> mock_stream_visitor; |
| EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream()) |
| .WillOnce(Return(&mock_stream)); |
| EXPECT_CALL(mock_stream, SetVisitor(_)).Times(1); |
| EXPECT_CALL(mock_stream, visitor()).WillOnce(Return(&mock_stream_visitor)); |
| EXPECT_CALL(mock_stream_visitor, OnCanRead()).Times(1); |
| EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream()) |
| .WillOnce(Return(nullptr)); |
| session_.OnIncomingBidirectionalStreamAvailable(); |
| } |
| |
| TEST_F(MoqtSessionTest, OnIncomingUnidirectionalStream) { |
| ::testing::InSequence seq; |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| StrictMock<webtransport::test::MockStreamVisitor> mock_stream_visitor; |
| EXPECT_CALL(mock_session_, AcceptIncomingUnidirectionalStream()) |
| .WillOnce(Return(&mock_stream)); |
| EXPECT_CALL(mock_stream, SetVisitor(_)).Times(1); |
| EXPECT_CALL(mock_stream, visitor()).WillOnce(Return(&mock_stream_visitor)); |
| EXPECT_CALL(mock_stream_visitor, OnCanRead()).Times(1); |
| EXPECT_CALL(mock_session_, AcceptIncomingUnidirectionalStream()) |
| .WillOnce(Return(nullptr)); |
| session_.OnIncomingUnidirectionalStreamAvailable(); |
| } |
| |
| TEST_F(MoqtSessionTest, Error) { |
| bool reported_error = false; |
| EXPECT_CALL( |
| mock_session_, |
| CloseSession(static_cast<uint64_t>(MoqtError::kParameterLengthMismatch), |
| "foo")) |
| .Times(1); |
| EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_)) |
| .WillOnce([&](absl::string_view error_message) { |
| reported_error = (error_message == "foo"); |
| }); |
| session_.Error(MoqtError::kParameterLengthMismatch, "foo"); |
| EXPECT_TRUE(reported_error); |
| } |
| |
| TEST_F(MoqtSessionTest, AddLocalTrack) { |
| MoqtSubscribe request = { |
| /*subscribe_id=*/1, |
| /*track_alias=*/2, |
| /*track_namespace=*/"foo", |
| /*track_name=*/"bar", |
| /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), |
| /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), |
| /*end_group=*/std::nullopt, |
| /*end_object=*/std::nullopt, |
| #ifdef MOQT_AUTH_INFO |
| /*authorization_info=*/std::nullopt, |
| #endif |
| }; |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| std::unique_ptr<MoqtParserVisitor> stream_input = |
| MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); |
| // Request for track returns SUBSCRIBE_ERROR. |
| bool correct_message = false; |
| EXPECT_CALL(mock_stream, Writev(_, _)) |
| .WillOnce([&](absl::Span<const absl::string_view> data, |
| const quiche::StreamWriteOptions& options) { |
| correct_message = true; |
| EXPECT_EQ(*ExtractMessageType(data[0]), |
| MoqtMessageType::kSubscribeError); |
| return absl::OkStatus(); |
| }); |
| stream_input->OnSubscribeMessage(request); |
| EXPECT_TRUE(correct_message); |
| |
| // Add the track. Now Subscribe should succeed. |
| MockLocalTrackVisitor local_track_visitor; |
| session_.AddLocalTrack(FullTrackName("foo", "bar"), &local_track_visitor); |
| correct_message = true; |
| EXPECT_CALL(mock_stream, Writev(_, _)) |
| .WillOnce([&](absl::Span<const absl::string_view> data, |
| const quiche::StreamWriteOptions& options) { |
| correct_message = true; |
| EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk); |
| return absl::OkStatus(); |
| }); |
| stream_input->OnSubscribeMessage(request); |
| EXPECT_TRUE(correct_message); |
| } |
| |
| TEST_F(MoqtSessionTest, AnnounceWithOk) { |
| testing::MockFunction<void(absl::string_view track_namespace, |
| std::optional<absl::string_view> error_message)> |
| announce_resolved_callback; |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| std::unique_ptr<MoqtParserVisitor> stream_input = |
| MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); |
| EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); |
| bool correct_message = true; |
| EXPECT_CALL(mock_stream, Writev(_, _)) |
| .WillOnce([&](absl::Span<const absl::string_view> data, |
| const quiche::StreamWriteOptions& options) { |
| correct_message = true; |
| EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kAnnounce); |
| return absl::OkStatus(); |
| }); |
| session_.Announce("foo", announce_resolved_callback.AsStdFunction()); |
| EXPECT_TRUE(correct_message); |
| |
| MoqtAnnounceOk ok = { |
| /*track_namespace=*/"foo", |
| }; |
| correct_message = false; |
| EXPECT_CALL(announce_resolved_callback, Call(_, _)) |
| .WillOnce([&](absl::string_view track_namespace, |
| std::optional<absl::string_view> error_message) { |
| correct_message = true; |
| EXPECT_EQ(track_namespace, "foo"); |
| EXPECT_FALSE(error_message.has_value()); |
| }); |
| stream_input->OnAnnounceOkMessage(ok); |
| EXPECT_TRUE(correct_message); |
| } |
| |
| TEST_F(MoqtSessionTest, AnnounceWithError) { |
| testing::MockFunction<void(absl::string_view track_namespace, |
| std::optional<absl::string_view> error_message)> |
| announce_resolved_callback; |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| std::unique_ptr<MoqtParserVisitor> stream_input = |
| MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); |
| EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); |
| bool correct_message = true; |
| EXPECT_CALL(mock_stream, Writev(_, _)) |
| .WillOnce([&](absl::Span<const absl::string_view> data, |
| const quiche::StreamWriteOptions& options) { |
| correct_message = true; |
| EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kAnnounce); |
| return absl::OkStatus(); |
| }); |
| session_.Announce("foo", announce_resolved_callback.AsStdFunction()); |
| EXPECT_TRUE(correct_message); |
| |
| MoqtAnnounceError error = { |
| /*track_namespace=*/"foo", |
| }; |
| correct_message = false; |
| EXPECT_CALL(announce_resolved_callback, Call(_, _)) |
| .WillOnce([&](absl::string_view track_namespace, |
| std::optional<absl::string_view> error_message) { |
| correct_message = true; |
| EXPECT_EQ(track_namespace, "foo"); |
| EXPECT_TRUE(error_message.has_value()); |
| }); |
| stream_input->OnAnnounceErrorMessage(error); |
| EXPECT_TRUE(correct_message); |
| } |
| |
| TEST_F(MoqtSessionTest, HasSubscribers) { |
| MockLocalTrackVisitor local_track_visitor; |
| FullTrackName ftn("foo", "bar"); |
| EXPECT_FALSE(session_.HasSubscribers(ftn)); |
| session_.AddLocalTrack(ftn, &local_track_visitor); |
| EXPECT_FALSE(session_.HasSubscribers(ftn)); |
| |
| // Peer subscribes. |
| MoqtSubscribe request = { |
| /*subscribe_id=*/1, |
| /*track_alias=*/2, |
| /*track_namespace=*/"foo", |
| /*track_name=*/"bar", |
| /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), |
| /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), |
| /*end_group=*/std::nullopt, |
| /*end_object=*/std::nullopt, |
| #ifdef MOQT_AUTH_INFO |
| /*authorization_info=*/std::nullopt, |
| #endif |
| }; |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| std::unique_ptr<MoqtParserVisitor> stream_input = |
| MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); |
| bool correct_message = true; |
| EXPECT_CALL(mock_stream, Writev(_, _)) |
| .WillOnce([&](absl::Span<const absl::string_view> data, |
| const quiche::StreamWriteOptions& options) { |
| correct_message = true; |
| EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk); |
| return absl::OkStatus(); |
| }); |
| stream_input->OnSubscribeMessage(request); |
| EXPECT_TRUE(correct_message); |
| EXPECT_TRUE(session_.HasSubscribers(ftn)); |
| } |
| |
| TEST_F(MoqtSessionTest, SubscribeForPast) { |
| MockLocalTrackVisitor local_track_visitor; |
| FullTrackName ftn("foo", "bar"); |
| session_.AddLocalTrack(ftn, &local_track_visitor); |
| |
| // Send Sequence (2, 0) so that next_sequence is set correctly. |
| session_.PublishObject(ftn, 2, 0, 0, MoqtForwardingPreference::kObject, "foo", |
| std::nullopt, true); |
| // Peer subscribes to (0, 0) |
| MoqtSubscribe request = { |
| /*subscribe_id=*/1, |
| /*track_alias=*/2, |
| /*track_namespace=*/"foo", |
| /*track_name=*/"bar", |
| /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), |
| /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), |
| /*end_group=*/std::nullopt, |
| /*end_object=*/std::nullopt, |
| #ifdef MOQT_AUTH_INFO |
| /*authorization_info=*/std::nullopt, |
| #endif |
| }; |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| std::unique_ptr<MoqtParserVisitor> stream_input = |
| MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); |
| bool correct_message = true; |
| EXPECT_CALL(local_track_visitor, OnSubscribeForPast(_)) |
| .WillOnce(Return(std::nullopt)); |
| EXPECT_CALL(mock_stream, Writev(_, _)) |
| .WillOnce([&](absl::Span<const absl::string_view> data, |
| const quiche::StreamWriteOptions& options) { |
| correct_message = true; |
| EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk); |
| return absl::OkStatus(); |
| }); |
| stream_input->OnSubscribeMessage(request); |
| EXPECT_TRUE(correct_message); |
| } |
| |
| TEST_F(MoqtSessionTest, SubscribeWithOk) { |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| std::unique_ptr<MoqtParserVisitor> stream_input = |
| MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); |
| MockRemoteTrackVisitor remote_track_visitor; |
| EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); |
| bool correct_message = true; |
| EXPECT_CALL(mock_stream, Writev(_, _)) |
| .WillOnce([&](absl::Span<const absl::string_view> data, |
| const quiche::StreamWriteOptions& options) { |
| correct_message = true; |
| EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribe); |
| return absl::OkStatus(); |
| }); |
| session_.SubscribeCurrentGroup("foo", "bar", &remote_track_visitor, ""); |
| |
| MoqtSubscribeOk ok = { |
| /*subscribe_id=*/0, |
| /*expires=*/quic::QuicTimeDelta::FromMilliseconds(0), |
| }; |
| correct_message = false; |
| EXPECT_CALL(remote_track_visitor, OnReply(_, _)) |
| .WillOnce([&](const FullTrackName& ftn, |
| std::optional<absl::string_view> error_message) { |
| correct_message = true; |
| EXPECT_EQ(ftn, FullTrackName("foo", "bar")); |
| EXPECT_FALSE(error_message.has_value()); |
| }); |
| stream_input->OnSubscribeOkMessage(ok); |
| EXPECT_TRUE(correct_message); |
| } |
| |
| TEST_F(MoqtSessionTest, SubscribeWithError) { |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| std::unique_ptr<MoqtParserVisitor> stream_input = |
| MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); |
| MockRemoteTrackVisitor remote_track_visitor; |
| EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); |
| bool correct_message = true; |
| EXPECT_CALL(mock_stream, Writev(_, _)) |
| .WillOnce([&](absl::Span<const absl::string_view> data, |
| const quiche::StreamWriteOptions& options) { |
| correct_message = true; |
| EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribe); |
| return absl::OkStatus(); |
| }); |
| session_.SubscribeCurrentGroup("foo", "bar", &remote_track_visitor, ""); |
| |
| MoqtSubscribeError error = { |
| /*subscribe_id=*/0, |
| /*error_code=*/SubscribeErrorCode::kInvalidRange, |
| /*reason_phrase=*/"deadbeef", |
| /*track_alias=*/2, |
| }; |
| correct_message = false; |
| EXPECT_CALL(remote_track_visitor, OnReply(_, _)) |
| .WillOnce([&](const FullTrackName& ftn, |
| std::optional<absl::string_view> error_message) { |
| correct_message = true; |
| EXPECT_EQ(ftn, FullTrackName("foo", "bar")); |
| EXPECT_EQ(*error_message, "deadbeef"); |
| }); |
| stream_input->OnSubscribeErrorMessage(error); |
| EXPECT_TRUE(correct_message); |
| } |
| |
| TEST_F(MoqtSessionTest, ReplyToAnnounce) { |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| std::unique_ptr<MoqtParserVisitor> stream_input = |
| MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); |
| MoqtAnnounce announce = { |
| /*track_namespace=*/"foo", |
| }; |
| bool correct_message = false; |
| EXPECT_CALL(mock_stream, Writev(_, _)) |
| .WillOnce([&](absl::Span<const absl::string_view> data, |
| const quiche::StreamWriteOptions& options) { |
| correct_message = true; |
| EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kAnnounceOk); |
| return absl::OkStatus(); |
| }); |
| stream_input->OnAnnounceMessage(announce); |
| EXPECT_TRUE(correct_message); |
| } |
| |
| TEST_F(MoqtSessionTest, IncomingObject) { |
| MockRemoteTrackVisitor visitor_; |
| FullTrackName ftn("foo", "bar"); |
| std::string payload = "deadbeef"; |
| MoqtSessionPeer::CreateRemoteTrack(&session_, ftn, &visitor_, 2); |
| MoqtObject object = { |
| /*subscribe_id=*/1, |
| /*track_alias=*/2, |
| /*group_sequence=*/0, |
| /*object_sequence=*/0, |
| /*object_send_order=*/0, |
| /*forwarding_preference=*/MoqtForwardingPreference::kGroup, |
| /*payload_length=*/8, |
| }; |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| std::unique_ptr<MoqtParserVisitor> object_stream = |
| MoqtSessionPeer::CreateUniStream(&session_, &mock_stream); |
| |
| EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _, _)).Times(1); |
| EXPECT_CALL(mock_stream, GetStreamId()) |
| .WillRepeatedly(Return(kIncomingUniStreamId)); |
| object_stream->OnObjectMessage(object, payload, true); |
| } |
| |
| TEST_F(MoqtSessionTest, IncomingPartialObject) { |
| MockRemoteTrackVisitor visitor_; |
| FullTrackName ftn("foo", "bar"); |
| std::string payload = "deadbeef"; |
| MoqtSessionPeer::CreateRemoteTrack(&session_, ftn, &visitor_, 2); |
| MoqtObject object = { |
| /*subscribe_id=*/1, |
| /*track_alias=*/2, |
| /*group_sequence=*/0, |
| /*object_sequence=*/0, |
| /*object_send_order=*/0, |
| /*forwarding_preference=*/MoqtForwardingPreference::kGroup, |
| /*payload_length=*/16, |
| }; |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| std::unique_ptr<MoqtParserVisitor> object_stream = |
| MoqtSessionPeer::CreateUniStream(&session_, &mock_stream); |
| |
| EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _, _)).Times(1); |
| EXPECT_CALL(mock_stream, GetStreamId()) |
| .WillRepeatedly(Return(kIncomingUniStreamId)); |
| object_stream->OnObjectMessage(object, payload, false); |
| object_stream->OnObjectMessage(object, payload, true); // complete the object |
| } |
| |
| TEST_F(MoqtSessionTest, IncomingPartialObjectNoBuffer) { |
| MoqtSessionParameters parameters = { |
| /*version=*/MoqtVersion::kDraft02, |
| /*perspective=*/quic::Perspective::IS_CLIENT, |
| /*using_webtrans=*/true, |
| /*path=*/"", |
| /*deliver_partial_objects=*/true, |
| }; |
| MoqtSession session(&mock_session_, parameters, |
| session_callbacks_.AsSessionCallbacks()); |
| MockRemoteTrackVisitor visitor_; |
| FullTrackName ftn("foo", "bar"); |
| std::string payload = "deadbeef"; |
| MoqtSessionPeer::CreateRemoteTrack(&session, ftn, &visitor_, 2); |
| MoqtObject object = { |
| /*subscribe_id=*/1, |
| /*track_alias=*/2, |
| /*group_sequence=*/0, |
| /*object_sequence=*/0, |
| /*object_send_order=*/0, |
| /*forwarding_preference=*/MoqtForwardingPreference::kGroup, |
| /*payload_length=*/16, |
| }; |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| std::unique_ptr<MoqtParserVisitor> object_stream = |
| MoqtSessionPeer::CreateUniStream(&session, &mock_stream); |
| |
| EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _, _)).Times(2); |
| EXPECT_CALL(mock_stream, GetStreamId()) |
| .WillRepeatedly(Return(kIncomingUniStreamId)); |
| object_stream->OnObjectMessage(object, payload, false); |
| object_stream->OnObjectMessage(object, payload, true); // complete the object |
| } |
| |
| TEST_F(MoqtSessionTest, CreateUniStreamAndSend) { |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| FullTrackName ftn("foo", "bar"); |
| MockLocalTrackVisitor track_visitor; |
| session_.AddLocalTrack(ftn, &track_visitor); |
| MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0); |
| |
| // No subscription; this is a no-op except to update next_sequence. |
| EXPECT_CALL(mock_stream, Writev(_, _)).Times(0); |
| session_.PublishObject(ftn, 4, 1, 0, MoqtForwardingPreference::kObject, |
| "deadbeef", std::nullopt, true); |
| EXPECT_EQ(MoqtSessionPeer::next_sequence(&session_, ftn), FullSequence(4, 2)); |
| |
| // Publish in window. |
| EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream()) |
| .WillOnce(Return(true)); |
| EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream()) |
| .WillOnce(Return(&mock_stream)); |
| EXPECT_CALL(mock_stream, SetVisitor(_)).Times(1); |
| EXPECT_CALL(mock_stream, GetStreamId()) |
| .WillRepeatedly(Return(kOutgoingUniStreamId)); |
| // Send on the stream |
| EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId)) |
| .WillOnce(Return(&mock_stream)); |
| bool correct_message = false; |
| // Verify first six message fields are sent correctly |
| uint8_t kExpectedMessage[] = {0x00, 0x00, 0x02, 0x05, 0x00, 0x00}; |
| EXPECT_CALL(mock_stream, Writev(_, _)) |
| .WillOnce([&](absl::Span<const absl::string_view> data, |
| const quiche::StreamWriteOptions& options) { |
| correct_message = (0 == memcmp(data.data()->data(), kExpectedMessage, |
| sizeof(kExpectedMessage))); |
| return absl::OkStatus(); |
| }); |
| session_.PublishObject(ftn, 5, 0, 0, MoqtForwardingPreference::kObject, |
| "deadbeef", std::nullopt, true); |
| EXPECT_TRUE(correct_message); |
| } |
| |
| // TODO: Test operation with multiple streams. |
| |
| // Error cases |
| |
| TEST_F(MoqtSessionTest, CannotOpenUniStream) { |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| FullTrackName ftn("foo", "bar"); |
| MockLocalTrackVisitor track_visitor; |
| session_.AddLocalTrack(ftn, &track_visitor); |
| MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0); |
| ; |
| EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream()) |
| .WillOnce(Return(false)); |
| EXPECT_FALSE(session_.PublishObject(ftn, 5, 0, 0, |
| MoqtForwardingPreference::kObject, |
| "deadbeef", std::nullopt, true)); |
| } |
| |
| TEST_F(MoqtSessionTest, GetStreamByIdFails) { |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| FullTrackName ftn("foo", "bar"); |
| MockLocalTrackVisitor track_visitor; |
| session_.AddLocalTrack(ftn, &track_visitor); |
| MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0); |
| EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream()) |
| .WillOnce(Return(true)); |
| EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream()) |
| .WillOnce(Return(&mock_stream)); |
| EXPECT_CALL(mock_stream, SetVisitor(_)).Times(1); |
| EXPECT_CALL(mock_stream, GetStreamId()) |
| .WillRepeatedly(Return(kOutgoingUniStreamId)); |
| EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId)) |
| .WillOnce(Return(nullptr)); |
| EXPECT_FALSE(session_.PublishObject(ftn, 5, 0, 0, |
| MoqtForwardingPreference::kObject, |
| "deadbeef", std::nullopt, true)); |
| } |
| |
| TEST_F(MoqtSessionTest, SubscribeProposesBadTrackAlias) { |
| MockLocalTrackVisitor local_track_visitor; |
| FullTrackName ftn("foo", "bar"); |
| session_.AddLocalTrack(ftn, &local_track_visitor); |
| MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0); |
| |
| // Peer subscribes. |
| MoqtSubscribe request = { |
| /*subscribe_id=*/1, |
| /*track_alias=*/3, // Doesn't match 2. |
| /*track_namespace=*/"foo", |
| /*track_name=*/"bar", |
| /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), |
| /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), |
| /*end_group=*/std::nullopt, |
| /*end_object=*/std::nullopt, |
| #ifdef MOQT_AUTH_INFO |
| /*authorization_info=*/std::nullopt, |
| #endif |
| }; |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| std::unique_ptr<MoqtParserVisitor> stream_input = |
| MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); |
| bool correct_message = true; |
| EXPECT_CALL(mock_stream, Writev(_, _)) |
| .WillOnce([&](absl::Span<const absl::string_view> data, |
| const quiche::StreamWriteOptions& options) { |
| correct_message = true; |
| EXPECT_EQ(*ExtractMessageType(data[0]), |
| MoqtMessageType::kSubscribeError); |
| return absl::OkStatus(); |
| }); |
| stream_input->OnSubscribeMessage(request); |
| EXPECT_TRUE(correct_message); |
| } |
| |
| TEST_F(MoqtSessionTest, OneBidirectionalStreamClient) { |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| EXPECT_CALL(mock_session_, OpenOutgoingBidirectionalStream()) |
| .WillOnce(Return(&mock_stream)); |
| std::unique_ptr<webtransport::StreamVisitor> visitor; |
| // Save a reference to MoqtSession::Stream |
| EXPECT_CALL(mock_stream, SetVisitor(_)) |
| .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> new_visitor) { |
| visitor = std::move(new_visitor); |
| }); |
| EXPECT_CALL(mock_stream, GetStreamId()) |
| .WillOnce(Return(webtransport::StreamId(4))); |
| bool correct_message = false; |
| EXPECT_CALL(mock_stream, Writev(_, _)) |
| .WillOnce([&](absl::Span<const absl::string_view> data, |
| const quiche::StreamWriteOptions& options) { |
| correct_message = true; |
| EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kClientSetup); |
| return absl::OkStatus(); |
| }); |
| session_.OnSessionReady(); |
| EXPECT_TRUE(correct_message); |
| |
| // Peer tries to open a bidi stream. |
| bool reported_error = false; |
| EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream()) |
| .WillOnce(Return(&mock_stream)); |
| EXPECT_CALL(mock_session_, |
| CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation), |
| "Bidirectional stream already open")) |
| .Times(1); |
| EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_)) |
| .WillOnce([&](absl::string_view error_message) { |
| reported_error = (error_message == "Bidirectional stream already open"); |
| }); |
| session_.OnIncomingBidirectionalStreamAvailable(); |
| EXPECT_TRUE(reported_error); |
| } |
| |
| TEST_F(MoqtSessionTest, OneBidirectionalStreamServer) { |
| MoqtSessionParameters server_parameters = { |
| /*version=*/MoqtVersion::kDraft02, |
| /*perspective=*/quic::Perspective::IS_SERVER, |
| /*using_webtrans=*/true, |
| /*path=*/"", |
| /*deliver_partial_objects=*/false, |
| }; |
| MoqtSession server_session(&mock_session_, server_parameters, |
| session_callbacks_.AsSessionCallbacks()); |
| StrictMock<webtransport::test::MockStream> mock_stream; |
| std::unique_ptr<MoqtParserVisitor> stream_input = |
| MoqtSessionPeer::CreateControlStream(&server_session, &mock_stream); |
| MoqtClientSetup setup = { |
| /*supported_versions*/ {MoqtVersion::kDraft02}, |
| /*role=*/MoqtRole::kBoth, |
| /*path=*/std::nullopt, |
| }; |
| bool correct_message = false; |
| EXPECT_CALL(mock_stream, Writev(_, _)) |
| .WillOnce([&](absl::Span<const absl::string_view> data, |
| const quiche::StreamWriteOptions& options) { |
| correct_message = true; |
| EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kServerSetup); |
| return absl::OkStatus(); |
| }); |
| EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1); |
| stream_input->OnClientSetupMessage(setup); |
| |
| // Peer tries to open a bidi stream. |
| bool reported_error = false; |
| EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream()) |
| .WillOnce(Return(&mock_stream)); |
| EXPECT_CALL(mock_session_, |
| CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation), |
| "Bidirectional stream already open")) |
| .Times(1); |
| EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_)) |
| .WillOnce([&](absl::string_view error_message) { |
| reported_error = (error_message == "Bidirectional stream already open"); |
| }); |
| server_session.OnIncomingBidirectionalStreamAvailable(); |
| EXPECT_TRUE(reported_error); |
| } |
| |
// TODO: Cover more error cases in the tests above.
| |
| } // namespace test |
| |
| } // namespace moqt |