Add test matchers for writing MoQT control messages into a mock stream. This should simplify writing unit tests for application code built on top of MoqtSession. Initially, this was a more ambitious refactor involving TestMessageBase and framer/parser tests, but it got messy because of the way we mix control and object framing. PiperOrigin-RevId: 700332136
diff --git a/build/source_list.bzl b/build/source_list.bzl index 933ba80..6b07808 100644 --- a/build/source_list.bzl +++ b/build/source_list.bzl
@@ -1524,6 +1524,7 @@ "quic/moqt/moqt_session.h", "quic/moqt/moqt_subscribe_windows.h", "quic/moqt/moqt_track.h", + "quic/moqt/test_tools/moqt_framer_utils.h", "quic/moqt/test_tools/moqt_session_peer.h", "quic/moqt/test_tools/moqt_simulator_harness.h", "quic/moqt/test_tools/moqt_test_message.h", @@ -1560,6 +1561,7 @@ "quic/moqt/moqt_subscribe_windows_test.cc", "quic/moqt/moqt_track.cc", "quic/moqt/moqt_track_test.cc", + "quic/moqt/test_tools/moqt_framer_utils.cc", "quic/moqt/test_tools/moqt_simulator_harness.cc", "quic/moqt/tools/chat_client.cc", "quic/moqt/tools/chat_client_bin.cc",
diff --git a/build/source_list.gni b/build/source_list.gni index db4f4ad..749e132 100644 --- a/build/source_list.gni +++ b/build/source_list.gni
@@ -1528,6 +1528,7 @@ "src/quiche/quic/moqt/moqt_session.h", "src/quiche/quic/moqt/moqt_subscribe_windows.h", "src/quiche/quic/moqt/moqt_track.h", + "src/quiche/quic/moqt/test_tools/moqt_framer_utils.h", "src/quiche/quic/moqt/test_tools/moqt_session_peer.h", "src/quiche/quic/moqt/test_tools/moqt_simulator_harness.h", "src/quiche/quic/moqt/test_tools/moqt_test_message.h", @@ -1564,6 +1565,7 @@ "src/quiche/quic/moqt/moqt_subscribe_windows_test.cc", "src/quiche/quic/moqt/moqt_track.cc", "src/quiche/quic/moqt/moqt_track_test.cc", + "src/quiche/quic/moqt/test_tools/moqt_framer_utils.cc", "src/quiche/quic/moqt/test_tools/moqt_simulator_harness.cc", "src/quiche/quic/moqt/tools/chat_client.cc", "src/quiche/quic/moqt/tools/chat_client_bin.cc",
diff --git a/build/source_list.json b/build/source_list.json index 2f041c7..c09bc72 100644 --- a/build/source_list.json +++ b/build/source_list.json
@@ -1527,6 +1527,7 @@ "quiche/quic/moqt/moqt_session.h", "quiche/quic/moqt/moqt_subscribe_windows.h", "quiche/quic/moqt/moqt_track.h", + "quiche/quic/moqt/test_tools/moqt_framer_utils.h", "quiche/quic/moqt/test_tools/moqt_session_peer.h", "quiche/quic/moqt/test_tools/moqt_simulator_harness.h", "quiche/quic/moqt/test_tools/moqt_test_message.h", @@ -1563,6 +1564,7 @@ "quiche/quic/moqt/moqt_subscribe_windows_test.cc", "quiche/quic/moqt/moqt_track.cc", "quiche/quic/moqt/moqt_track_test.cc", + "quiche/quic/moqt/test_tools/moqt_framer_utils.cc", "quiche/quic/moqt/test_tools/moqt_simulator_harness.cc", "quiche/quic/moqt/tools/chat_client.cc", "quiche/quic/moqt/tools/chat_client_bin.cc",
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index c9f78ca..2e7aed0 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -25,6 +25,7 @@ #include "quiche/quic/moqt/moqt_priority.h" #include "quiche/quic/moqt/moqt_publisher.h" #include "quiche/quic/moqt/moqt_track.h" +#include "quiche/quic/moqt/test_tools/moqt_framer_utils.h" #include "quiche/quic/moqt/test_tools/moqt_session_peer.h" #include "quiche/quic/moqt/tools/moqt_mock_visitor.h" #include "quiche/quic/platform/api/quic_test.h" @@ -49,17 +50,6 @@ constexpr webtransport::StreamId kIncomingUniStreamId = 15; constexpr webtransport::StreamId kOutgoingUniStreamId = 14; -// Returns nullopt if there is not enough in |message| to extract a type -static std::optional<MoqtMessageType> ExtractMessageType( - const absl::string_view message) { - quic::QuicDataReader reader(message); - uint64_t value; - if (!reader.ReadVarInt62(&value)) { - return std::nullopt; - } - return static_cast<MoqtMessageType>(value); -} - static std::shared_ptr<MockTrackPublisher> SetupPublisher( FullTrackName track_name, MoqtForwardingPreference forwarding_preference, FullSequence largest_sequence) { @@ -113,17 +103,10 @@ EXPECT_CALL(mock_stream, GetStreamId()) .WillOnce(Return(webtransport::StreamId(4))); EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream)); - bool correct_message = false; EXPECT_CALL(mock_stream, visitor()).WillOnce([&] { return visitor.get(); }); - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kClientSetup); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kClientSetup), _)); session_.OnSessionReady(); - EXPECT_TRUE(correct_message); // Receive SERVER_SETUP MoqtControlParserVisitor* stream_input = @@ -150,14 +133,8 @@ /*role=*/MoqtRole::kPubSub, /*path=*/std::nullopt, }; - bool correct_message = false; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kServerSetup); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kServerSetup), _)); EXPECT_CALL(mock_stream, GetStreamId()).WillOnce(Return(0)); EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1); stream_input->OnClientSetupMessage(setup); @@ -234,17 +211,10 @@ std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); // Request for track returns SUBSCRIBE_ERROR. - bool correct_message = false; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), - MoqtMessageType::kSubscribeError); - return absl::OkStatus(); - }); + EXPECT_CALL( + mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribeError), _)); stream_input->OnSubscribeMessage(request); - EXPECT_TRUE(correct_message); // Add the track. Now Subscribe should succeed. auto track_publisher = @@ -252,17 +222,10 @@ EXPECT_CALL(*track_publisher, GetTrackStatus()) .WillRepeatedly(Return(MoqtTrackStatusCode::kStatusNotAvailable)); publisher_.Add(track_publisher); - correct_message = true; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribeOk), _)); request.subscribe_id = 2; stream_input->OnSubscribeMessage(request); - EXPECT_TRUE(correct_message); } TEST_F(MoqtSessionTest, AnnounceWithOk) { @@ -274,31 +237,21 @@ std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); - bool correct_message = true; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kAnnounce); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kAnnounce), _)); session_.Announce(FullTrackName{"foo"}, announce_resolved_callback.AsStdFunction()); - EXPECT_TRUE(correct_message); MoqtAnnounceOk ok = { /*track_namespace=*/FullTrackName{"foo"}, }; - correct_message = false; EXPECT_CALL(announce_resolved_callback, Call(_, _)) .WillOnce([&](FullTrackName track_namespace, std::optional<MoqtAnnounceErrorReason> error) { - correct_message = true; EXPECT_EQ(track_namespace, FullTrackName{"foo"}); EXPECT_FALSE(error.has_value()); }); stream_input->OnAnnounceOkMessage(ok); - EXPECT_TRUE(correct_message); } TEST_F(MoqtSessionTest, AnnounceWithError) { @@ -310,35 +263,25 @@ std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); - bool correct_message = true; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kAnnounce); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kAnnounce), _)); session_.Announce(FullTrackName{"foo"}, announce_resolved_callback.AsStdFunction()); - EXPECT_TRUE(correct_message); MoqtAnnounceError error = { /*track_namespace=*/FullTrackName{"foo"}, /*error_code=*/MoqtAnnounceErrorCode::kInternalError, /*reason_phrase=*/"Test error", }; - correct_message = false; EXPECT_CALL(announce_resolved_callback, Call(_, _)) .WillOnce([&](FullTrackName track_namespace, std::optional<MoqtAnnounceErrorReason> error) { - correct_message = true; EXPECT_EQ(track_namespace, FullTrackName{"foo"}); ASSERT_TRUE(error.has_value()); EXPECT_EQ(error->error_code, MoqtAnnounceErrorCode::kInternalError); EXPECT_EQ(error->reason_phrase, "Test error"); }); stream_input->OnAnnounceErrorMessage(error); - EXPECT_TRUE(correct_message); } TEST_F(MoqtSessionTest, SubscribeForPast) { @@ -371,17 +314,10 @@ webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); - bool correct_message = false; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), - MoqtMessageType::kSubscribeError); - return absl::OkStatus(); - }); + EXPECT_CALL( + mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribeError), _)); stream_input->OnSubscribeMessage(request); - EXPECT_TRUE(correct_message); } TEST_F(MoqtSessionTest, TwoSubscribesForTrack) { @@ -414,16 +350,9 @@ webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); - bool correct_message = false; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribeOk), _)); stream_input->OnSubscribeMessage(request); - EXPECT_TRUE(correct_message); request.subscribe_id = 2; request.start_group = 12; @@ -465,46 +394,24 @@ webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); - bool correct_message = false; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribeOk), _)); stream_input->OnSubscribeMessage(request); - EXPECT_TRUE(correct_message); // Peer unsubscribes. MoqtUnsubscribe unsubscribe = { /*subscribe_id=*/1, }; - correct_message = false; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), - MoqtMessageType::kSubscribeDone); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribeDone), _)); stream_input->OnUnsubscribeMessage(unsubscribe); - EXPECT_TRUE(correct_message); // Subscribe again, succeeds. request.subscribe_id = 2; request.start_group = 12; - correct_message = false; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribeOk), _)); stream_input->OnSubscribeMessage(request); - EXPECT_TRUE(correct_message); } TEST_F(MoqtSessionTest, SubscribeIdTooHigh) { @@ -549,17 +456,10 @@ std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); // Request for track returns SUBSCRIBE_ERROR. - bool correct_message = false; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), - MoqtMessageType::kSubscribeError); - return absl::OkStatus(); - }); + EXPECT_CALL( + mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribeError), _)); stream_input->OnSubscribeMessage(request); - EXPECT_TRUE(correct_message); // Second request is a protocol violation. request.subscribe_id = 0; @@ -580,19 +480,12 @@ std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); - bool correct_message = false; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribe); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _)); EXPECT_TRUE(session_.SubscribeCurrentGroup(FullTrackName("foo", "bar"), &remote_track_visitor)); EXPECT_FALSE(session_.SubscribeCurrentGroup(FullTrackName("foo", "bar"), &remote_track_visitor)); - EXPECT_TRUE(correct_message); } TEST_F(MoqtSessionTest, SubscribeWithOk) { @@ -601,14 +494,8 @@ MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); MockRemoteTrackVisitor remote_track_visitor; EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); - bool correct_message = true; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribe); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _)); session_.SubscribeCurrentGroup(FullTrackName("foo", "bar"), &remote_track_visitor); @@ -616,16 +503,13 @@ /*subscribe_id=*/0, /*expires=*/quic::QuicTimeDelta::FromMilliseconds(0), }; - correct_message = false; EXPECT_CALL(remote_track_visitor, OnReply(_, _)) .WillOnce([&](const FullTrackName& ftn, std::optional<absl::string_view> error_message) { - correct_message = true; EXPECT_EQ(ftn, FullTrackName("foo", "bar")); EXPECT_FALSE(error_message.has_value()); }); stream_input->OnSubscribeOkMessage(ok); - EXPECT_TRUE(correct_message); } TEST_F(MoqtSessionTest, MaxSubscribeIdChangesResponse) { @@ -642,17 +526,10 @@ MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); stream_input->OnMaxSubscribeIdMessage(max_subscribe_id); EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); - bool correct_message = true; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribe); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _)); EXPECT_TRUE(session_.SubscribeCurrentGroup(FullTrackName("foo", "bar"), &remote_track_visitor)); - EXPECT_TRUE(correct_message); } TEST_F(MoqtSessionTest, LowerMaxSubscribeIdIsAnError) { @@ -675,17 +552,10 @@ std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); - bool correct_message = true; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), - MoqtMessageType::kMaxSubscribeId); - return absl::OkStatus(); - }); + EXPECT_CALL( + mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kMaxSubscribeId), _)); session_.GrantMoreSubscribes(1); - EXPECT_TRUE(correct_message); // Peer subscribes to (0, 0) MoqtSubscribe request = { /*subscribe_id=*/kDefaultInitialMaxSubscribeId + 1, @@ -699,7 +569,6 @@ /*end_object=*/std::nullopt, /*parameters=*/MoqtSubscribeParameters(), }; - correct_message = false; FullTrackName ftn("foo", "bar"); auto track = std::make_shared<MockTrackPublisher>(ftn); EXPECT_CALL(*track, GetTrackStatus()) @@ -712,15 +581,9 @@ EXPECT_CALL(*track, GetLargestSequence()) .WillRepeatedly(Return(FullSequence(10, 20))); publisher_.Add(track); - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribeOk), _)); stream_input->OnSubscribeMessage(request); - EXPECT_TRUE(correct_message); } TEST_F(MoqtSessionTest, SubscribeWithError) { @@ -729,14 +592,8 @@ MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); MockRemoteTrackVisitor remote_track_visitor; EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); - bool correct_message = true; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribe); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _)); session_.SubscribeCurrentGroup(FullTrackName("foo", "bar"), &remote_track_visitor); @@ -746,16 +603,13 @@ /*reason_phrase=*/"deadbeef", /*track_alias=*/2, }; - correct_message = false; EXPECT_CALL(remote_track_visitor, OnReply(_, _)) .WillOnce([&](const FullTrackName& ftn, std::optional<absl::string_view> error_message) { - correct_message = true; EXPECT_EQ(ftn, FullTrackName("foo", "bar")); EXPECT_EQ(*error_message, "deadbeef"); }); stream_input->OnSubscribeErrorMessage(error); - EXPECT_TRUE(correct_message); } TEST_F(MoqtSessionTest, ReplyToAnnounce) { @@ -765,19 +619,14 @@ MoqtAnnounce announce = { /*track_namespace=*/FullTrackName{"foo"}, }; - bool correct_message = false; EXPECT_CALL(session_callbacks_.incoming_announce_callback, Call(FullTrackName{"foo"})) .WillOnce(Return(std::nullopt)); - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kAnnounceOk); - return absl::OkStatus(); - }); + EXPECT_CALL( + mock_stream, + Writev(SerializedControlMessage(MoqtAnnounceOk{FullTrackName{"foo"}}), + _)); stream_input->OnAnnounceMessage(announce); - EXPECT_TRUE(correct_message); } TEST_F(MoqtSessionTest, IncomingObject) { @@ -1475,17 +1324,10 @@ EXPECT_CALL(mock_stream, GetStreamId()) .WillOnce(Return(webtransport::StreamId(4))); EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream)); - bool correct_message = false; EXPECT_CALL(mock_stream, visitor()).WillOnce([&] { return visitor.get(); }); - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kClientSetup); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kClientSetup), _)); session_.OnSessionReady(); - EXPECT_TRUE(correct_message); // Peer tries to open a bidi stream. bool reported_error = false; @@ -1515,14 +1357,8 @@ /*role=*/MoqtRole::kPubSub, /*path=*/std::nullopt, }; - bool correct_message = false; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kServerSetup); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kServerSetup), _)); EXPECT_CALL(mock_stream, GetStreamId()).WillOnce(Return(0)); EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1); stream_input->OnClientSetupMessage(setup); @@ -1555,17 +1391,9 @@ /*subscribe_id=*/0, }; EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream)); - bool correct_message = false; - EXPECT_CALL(mock_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), - MoqtMessageType::kSubscribeDone); - return absl::OkStatus(); - }); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribeDone), _)); stream_input->OnUnsubscribeMessage(unsubscribe); - EXPECT_TRUE(correct_message); } TEST_F(MoqtSessionTest, SendDatagram) { @@ -1944,7 +1772,6 @@ /*end_object=*/std::nullopt, /*parameters=*/MoqtSubscribeParameters(), }; - bool correct_message = false; auto track = std::make_shared<MockTrackPublisher>(ftn); publisher_.Add(track); @@ -1956,18 +1783,12 @@ EXPECT_CALL(*track, GetDeliveryOrder()) .WillRepeatedly(Return(MoqtDeliveryOrder::kAscending)); EXPECT_CALL(*fetch_task, GetLargestId()).WillOnce(Return(FullSequence(0, 0))); - EXPECT_CALL(control_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kFetchOk); - return absl::OkStatus(); - }); + EXPECT_CALL(control_stream, + Writev(ControlMessageOfType(MoqtMessageType::kFetchOk), _)); // Stream can't open yet. EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream) .WillOnce(Return(false)); stream_input->OnFetchMessage(request); - EXPECT_TRUE(correct_message); } TEST_F(MoqtSessionTest, FetchReturnsOkImmediateOpen) { @@ -1985,7 +1806,6 @@ /*end_object=*/std::nullopt, /*parameters=*/MoqtSubscribeParameters(), }; - bool correct_message = false; auto track = std::make_shared<MockTrackPublisher>(ftn); publisher_.Add(track); @@ -1997,13 +1817,8 @@ EXPECT_CALL(*track, GetDeliveryOrder()) .WillRepeatedly(Return(MoqtDeliveryOrder::kAscending)); EXPECT_CALL(*fetch_task, GetLargestId()).WillOnce(Return(FullSequence(0, 0))); - EXPECT_CALL(control_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kFetchOk); - return absl::OkStatus(); - }); + EXPECT_CALL(control_stream, + Writev(ControlMessageOfType(MoqtMessageType::kFetchOk), _)); // Open stream immediately. EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream) .WillOnce(Return(true)); @@ -2023,10 +1838,8 @@ EXPECT_CALL(*fetch_task, GetNextObject(_)) .WillOnce(Return(MoqtFetchTask::GetNextObjectResult::kPending)); stream_input->OnFetchMessage(request); - EXPECT_TRUE(correct_message); // Signal the stream that pending object is now available. - correct_message = false; EXPECT_CALL(data_stream, CanWrite()).WillRepeatedly(Return(true)); EXPECT_CALL(*fetch_task, GetNextObject(_)) .WillOnce(Invoke([](PublishedObject& output) { @@ -2043,7 +1856,6 @@ EXPECT_CALL(data_stream, Writev(_, _)) .WillOnce([&](absl::Span<const absl::string_view> data, const quiche::StreamWriteOptions& options) { - correct_message = true; quic::QuicDataReader reader(data[0]); uint64_t type; EXPECT_TRUE(reader.ReadVarInt62(&type)); @@ -2052,7 +1864,6 @@ return absl::OkStatus(); }); fetch_task->objects_available_callback()(); - EXPECT_TRUE(correct_message); } TEST_F(MoqtSessionTest, InvalidFetch) { @@ -2094,7 +1905,6 @@ /*end_object=*/std::nullopt, /*parameters=*/MoqtSubscribeParameters(), }; - bool correct_message = false; auto track = std::make_shared<MockTrackPublisher>(ftn); publisher_.Add(track); @@ -2104,15 +1914,9 @@ .WillOnce(Return(std::move(fetch_task_ptr))); EXPECT_CALL(*fetch_task, GetStatus()) .WillRepeatedly(Return(absl::Status(absl::StatusCode::kInternal, "foo"))); - EXPECT_CALL(control_stream, Writev(_, _)) - .WillOnce([&](absl::Span<const absl::string_view> data, - const quiche::StreamWriteOptions& options) { - correct_message = true; - EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kFetchError); - return absl::OkStatus(); - }); + EXPECT_CALL(control_stream, + Writev(ControlMessageOfType(MoqtMessageType::kFetchError), _)); stream_input->OnFetchMessage(request); - EXPECT_TRUE(correct_message); } TEST_F(MoqtSessionTest, FetchDelivery) {
diff --git a/quiche/quic/moqt/test_tools/moqt_framer_utils.cc b/quiche/quic/moqt/test_tools/moqt_framer_utils.cc new file mode 100644 index 0000000..5d0116d --- /dev/null +++ b/quiche/quic/moqt/test_tools/moqt_framer_utils.cc
@@ -0,0 +1,191 @@ +// Copyright 2024 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/moqt/test_tools/moqt_framer_utils.h" + +#include <string> + +#include "absl/types/variant.h" +#include "quiche/quic/moqt/moqt_framer.h" +#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/common/quiche_buffer_allocator.h" +#include "quiche/common/simple_buffer_allocator.h" + +namespace moqt::test { + +struct TypeVisitor { + MoqtMessageType operator()(const MoqtClientSetup&) { + return MoqtMessageType::kClientSetup; + } + MoqtMessageType operator()(const MoqtServerSetup&) { + return MoqtMessageType::kServerSetup; + } + MoqtMessageType operator()(const MoqtSubscribe&) { + return MoqtMessageType::kSubscribe; + } + MoqtMessageType operator()(const MoqtSubscribeOk&) { + return MoqtMessageType::kSubscribeOk; + } + MoqtMessageType operator()(const MoqtSubscribeError&) { + return MoqtMessageType::kSubscribeError; + } + MoqtMessageType operator()(const MoqtUnsubscribe&) { + return MoqtMessageType::kUnsubscribe; + } + MoqtMessageType operator()(const MoqtSubscribeDone&) { + return MoqtMessageType::kSubscribeDone; + } + MoqtMessageType operator()(const MoqtSubscribeUpdate&) { + return MoqtMessageType::kSubscribeUpdate; + } + MoqtMessageType operator()(const MoqtAnnounce&) { + return MoqtMessageType::kAnnounce; + } + MoqtMessageType operator()(const MoqtAnnounceOk&) { + return MoqtMessageType::kAnnounceOk; + } + MoqtMessageType operator()(const MoqtAnnounceError&) { + return MoqtMessageType::kAnnounceError; + } + MoqtMessageType operator()(const MoqtAnnounceCancel&) { + return MoqtMessageType::kAnnounceCancel; + } + MoqtMessageType operator()(const MoqtTrackStatusRequest&) { + return MoqtMessageType::kTrackStatusRequest; + } + MoqtMessageType operator()(const MoqtUnannounce&) { + return MoqtMessageType::kUnannounce; + } + MoqtMessageType operator()(const MoqtTrackStatus&) { + return MoqtMessageType::kTrackStatus; + } + MoqtMessageType operator()(const MoqtGoAway&) { + return MoqtMessageType::kGoAway; + } + MoqtMessageType operator()(const MoqtSubscribeAnnounces&) { + return MoqtMessageType::kSubscribeAnnounces; + } + MoqtMessageType operator()(const MoqtSubscribeAnnouncesOk&) { + return MoqtMessageType::kSubscribeAnnouncesOk; + } + MoqtMessageType operator()(const MoqtSubscribeAnnouncesError&) { + return MoqtMessageType::kSubscribeAnnouncesError; + } + MoqtMessageType operator()(const MoqtUnsubscribeAnnounces&) { + return MoqtMessageType::kUnsubscribeAnnounces; + } + MoqtMessageType operator()(const MoqtMaxSubscribeId&) { + return MoqtMessageType::kMaxSubscribeId; + } + MoqtMessageType operator()(const MoqtFetch&) { + return MoqtMessageType::kFetch; + } + MoqtMessageType operator()(const MoqtFetchCancel&) { + return MoqtMessageType::kFetchCancel; + } + MoqtMessageType operator()(const MoqtFetchOk&) { + return MoqtMessageType::kFetchOk; + } + MoqtMessageType operator()(const MoqtFetchError&) { + return MoqtMessageType::kFetchError; + } + MoqtMessageType operator()(const MoqtObjectAck&) { + return MoqtMessageType::kObjectAck; + } +}; + +MoqtMessageType MessageTypeForGenericMessage(const MoqtGenericFrame& frame) { + return absl::visit(TypeVisitor(), frame); +} + +struct FramingVisitor { + quiche::QuicheBuffer operator()(const MoqtClientSetup& message) { + return framer.SerializeClientSetup(message); + } + quiche::QuicheBuffer operator()(const MoqtServerSetup& message) { + return framer.SerializeServerSetup(message); + } + quiche::QuicheBuffer operator()(const MoqtSubscribe& message) { + return framer.SerializeSubscribe(message); + } + quiche::QuicheBuffer operator()(const MoqtSubscribeOk& message) { + return framer.SerializeSubscribeOk(message); + } + quiche::QuicheBuffer operator()(const MoqtSubscribeError& message) { + return framer.SerializeSubscribeError(message); + } + quiche::QuicheBuffer operator()(const MoqtUnsubscribe& message) { + return framer.SerializeUnsubscribe(message); + } + quiche::QuicheBuffer operator()(const MoqtSubscribeDone& message) { + return framer.SerializeSubscribeDone(message); + } + quiche::QuicheBuffer operator()(const MoqtSubscribeUpdate& message) { + return framer.SerializeSubscribeUpdate(message); + } + quiche::QuicheBuffer operator()(const MoqtAnnounce& message) { + return framer.SerializeAnnounce(message); + } + quiche::QuicheBuffer operator()(const MoqtAnnounceOk& message) { + return framer.SerializeAnnounceOk(message); + } + quiche::QuicheBuffer operator()(const MoqtAnnounceError& message) { + return framer.SerializeAnnounceError(message); + } + quiche::QuicheBuffer operator()(const MoqtAnnounceCancel& message) { + return framer.SerializeAnnounceCancel(message); + } + quiche::QuicheBuffer operator()(const MoqtTrackStatusRequest& message) { + return framer.SerializeTrackStatusRequest(message); + } + quiche::QuicheBuffer operator()(const MoqtUnannounce& message) { + return framer.SerializeUnannounce(message); + } + quiche::QuicheBuffer operator()(const MoqtTrackStatus& message) { + return framer.SerializeTrackStatus(message); + } + quiche::QuicheBuffer operator()(const MoqtGoAway& message) { + return framer.SerializeGoAway(message); + } + quiche::QuicheBuffer operator()(const MoqtSubscribeAnnounces& message) { + return framer.SerializeSubscribeAnnounces(message); + } + quiche::QuicheBuffer operator()(const MoqtSubscribeAnnouncesOk& message) { + return framer.SerializeSubscribeAnnouncesOk(message); + } + quiche::QuicheBuffer operator()(const MoqtSubscribeAnnouncesError& message) { + return framer.SerializeSubscribeAnnouncesError(message); + } + quiche::QuicheBuffer operator()(const MoqtUnsubscribeAnnounces& message) { + return framer.SerializeUnsubscribeAnnounces(message); + } + quiche::QuicheBuffer operator()(const MoqtMaxSubscribeId& message) { + return framer.SerializeMaxSubscribeId(message); + } + quiche::QuicheBuffer operator()(const MoqtFetch& message) { + return framer.SerializeFetch(message); + } + quiche::QuicheBuffer operator()(const MoqtFetchCancel& message) { + return framer.SerializeFetchCancel(message); + } + quiche::QuicheBuffer operator()(const MoqtFetchOk& message) { + return framer.SerializeFetchOk(message); + } + quiche::QuicheBuffer operator()(const MoqtFetchError& message) { + return framer.SerializeFetchError(message); + } + quiche::QuicheBuffer operator()(const MoqtObjectAck& message) { + return framer.SerializeObjectAck(message); + } + + MoqtFramer& framer; +}; + +std::string SerializeGenericMessage(const MoqtGenericFrame& frame, + bool use_webtrans) { + MoqtFramer framer(quiche::SimpleBufferAllocator::Get(), use_webtrans); + return std::string(absl::visit(FramingVisitor{framer}, frame).AsStringView()); +} + +} // namespace moqt::test
diff --git a/quiche/quic/moqt/test_tools/moqt_framer_utils.h b/quiche/quic/moqt/test_tools/moqt_framer_utils.h new file mode 100644 index 0000000..a1ad760 --- /dev/null +++ b/quiche/quic/moqt/test_tools/moqt_framer_utils.h
@@ -0,0 +1,62 @@ +// Copyright 2024 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_FRAMER_UTILS_H_ +#define QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_FRAMER_UTILS_H_ + +#include <cstdint> +#include <string> + +#include "absl/strings/str_join.h" +#include "absl/types/variant.h" +#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/common/quiche_data_reader.h" + +namespace moqt::test { + +// TODO: remove MoqtObject from TestMessageBase::MessageStructuredData and merge +// those two types. +using MoqtGenericFrame = absl::variant< + MoqtClientSetup, MoqtServerSetup, MoqtSubscribe, MoqtSubscribeOk, + MoqtSubscribeError, MoqtUnsubscribe, MoqtSubscribeDone, MoqtSubscribeUpdate, + MoqtAnnounce, MoqtAnnounceOk, MoqtAnnounceError, MoqtAnnounceCancel, + MoqtTrackStatusRequest, MoqtUnannounce, MoqtTrackStatus, MoqtGoAway, + MoqtSubscribeAnnounces, MoqtSubscribeAnnouncesOk, + MoqtSubscribeAnnouncesError, MoqtUnsubscribeAnnounces, MoqtMaxSubscribeId, + MoqtFetch, MoqtFetchCancel, MoqtFetchOk, MoqtFetchError, MoqtObjectAck>; + +MoqtMessageType MessageTypeForGenericMessage(const MoqtGenericFrame& frame); + +std::string SerializeGenericMessage(const MoqtGenericFrame& frame, + bool use_webtrans = false); + +MATCHER_P(SerializedControlMessage, message, + "Matches against a specific expected MoQT message") { + std::string merged_message = absl::StrJoin(arg, ""); + return merged_message == SerializeGenericMessage(message); +} + +MATCHER_P(ControlMessageOfType, expected_type, + "Matches against an MoQT message of a specific type") { + std::string merged_message = absl::StrJoin(arg, ""); + quiche::QuicheDataReader reader(merged_message); + uint64_t type_raw; + if (!reader.ReadVarInt62(&type_raw)) { + *result_listener << "Failed to extract type from the message"; + return false; + } + MoqtMessageType type = static_cast<MoqtMessageType>(type_raw); + if (type != expected_type) { + *result_listener << "Expected message of type " + << MoqtMessageTypeToString(expected_type) << ", got " + << MoqtMessageTypeToString(type); + return false; + } + return true; +} + +} // namespace moqt::test + +#endif // QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_FRAMER_UTILS_H_