Move Incoming Subscribe tests out of MoqtSessionTest. Here is the mapping of deleted tests: IDST = IncomingDataStreamTest SubgroupStreamObjectAfterGroupEnd -> IDST::ObjectAfterGroupEnd SubgroupStreamObjectAfterTrackEnd -> IDST::ObjectAfterTrackEnd IncomingObject -> IDST::OnObjectMessage IncomingPartialObject -> IDST::OnObjectMessageBufferPartialObject IncomingPartialObjectNoBuffer -> IDST::OnObjectMessageDontBufferPartialObject ObjectBeforeSubscribeOk -> IDST::OnObjectMessageInvalidTrack StreamObjectOutOfWindow -> IDST::OnObjectMessageNotInWindow PiperOrigin-RevId: 922989304
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 662a6bc..c916401 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -922,14 +922,19 @@ TEST_F(MoqtSessionTest, Unsubscribe) { std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); - MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), - /*track_alias=*/2, &remote_track_visitor_); + FullTrackName ftn = FullTrackName("foo", "bar"); + EXPECT_CALL(mock_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _)); + EXPECT_TRUE( + session_.Subscribe(ftn, &remote_track_visitor_, MessageParameters())); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kUnsubscribe), _)); - EXPECT_NE(MoqtSessionPeer::remote_track(&session_, 2), nullptr); - session_.Unsubscribe(FullTrackName("foo", "bar")); - // State is destroyed. - EXPECT_EQ(MoqtSessionPeer::remote_track(&session_, 2), nullptr); + EXPECT_CALL(remote_track_visitor_, OnPublishDone); + session_.Unsubscribe(ftn); + // Verify it was destroyed. + EXPECT_CALL(mock_stream_, Writev).Times(0); + EXPECT_CALL(remote_track_visitor_, OnPublishDone).Times(0); + session_.Unsubscribe(ftn); } TEST_F(MoqtSessionTest, ReplyToPublishNamespaceWithOkThenPublishNamespaceDone) { @@ -1125,144 +1130,6 @@ nullptr); } -TEST_F(MoqtSessionTest, IncomingObject) { - FullTrackName ftn("foo", "bar"); - std::string payload = "deadbeef"; - MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), - /*track_alias=*/2, &remote_track_visitor_); - MoqtObject object = { - /*track_alias=*/2, - /*group_sequence=*/0, - /*object_sequence=*/0, - /*publisher_priority=*/0, - /*extension_headers=*/"foo", - /*object_status=*/MoqtObjectStatus::kNormal, - /*subgroup_id=*/0, - /*payload_length=*/8, - }; - std::unique_ptr<MoqtDataParserVisitor> object_stream = - MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream_, - kDefaultSubgroupStreamType, 2, - &remote_track_visitor_); - - EXPECT_CALL(remote_track_visitor_, OnObjectFragment) - .WillOnce([&](const FullTrackName& track_name, - const PublishedObjectMetadata& metadata, - const absl::string_view received_payload, uint64_t offset) { - EXPECT_EQ(track_name, ftn); - EXPECT_EQ(metadata.location, Location(0, 0)); - EXPECT_EQ(metadata.subgroup, 0); - EXPECT_EQ(metadata.extensions, "foo"); - EXPECT_EQ(metadata.status, MoqtObjectStatus::kNormal); - EXPECT_EQ(metadata.publisher_priority, 0); - EXPECT_EQ(metadata.payload_length, payload.length()); - EXPECT_EQ(payload, received_payload); - EXPECT_EQ(offset, 0); - }); - EXPECT_CALL(mock_stream_, GetStreamId()) - .WillRepeatedly(Return(kIncomingUniStreamId)); - object_stream->OnObjectMessage(object, payload, true); -} - -TEST_F(MoqtSessionTest, IncomingPartialObject) { - FullTrackName ftn("foo", "bar"); - std::string payload = "deadbeef"; - MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), - /*track_alias=*/2, &remote_track_visitor_); - MoqtObject object = { - /*track_alias=*/2, - /*group_sequence=*/0, - /*object_sequence=*/0, - /*publisher_priority=*/0, - /*extension_headers=*/"", - /*object_status=*/MoqtObjectStatus::kNormal, - /*subgroup_id=*/0, - /*payload_length=*/16, - }; - std::unique_ptr<MoqtDataParserVisitor> object_stream = - MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream_, - kDefaultSubgroupStreamType, 2, - &remote_track_visitor_); - - EXPECT_CALL(remote_track_visitor_, OnObjectFragment).Times(1); - EXPECT_CALL(mock_stream_, GetStreamId()) - .WillRepeatedly(Return(kIncomingUniStreamId)); - object_stream->OnObjectMessage(object, payload, false); - object_stream->OnObjectMessage(object, payload, true); // complete the object -} - -TEST_F(MoqtSessionTest, IncomingPartialObjectNoBuffer) { - MoqtSessionParameters parameters(quic::Perspective::IS_CLIENT); - parameters.deliver_partial_objects = true; - MoqtSession session(&mock_session_, parameters, - std::make_unique<quic::test::TestAlarmFactory>(), - session_callbacks_.AsSessionCallbacks()); - FullTrackName ftn("foo", "bar"); - std::string payload = "deadbeef"; - MoqtSessionPeer::CreateRemoteTrack(&session, DefaultSubscribe(), - /*track_alias=*/2, &remote_track_visitor_); - MoqtObject object = { - /*track_alias=*/2, - /*group_sequence=*/0, - /*object_sequence=*/0, - /*publisher_priority=*/0, - /*extension_headers=*/"", - /*object_status=*/MoqtObjectStatus::kNormal, - /*subgroup_id=*/0, - /*payload_length=*/16, - }; - std::unique_ptr<MoqtDataParserVisitor> object_stream = - MoqtSessionPeer::CreateIncomingDataStream(&session, &mock_stream_, - kDefaultSubgroupStreamType, 2, - &remote_track_visitor_); - EXPECT_CALL(mock_stream_, GetStreamId()) - .WillRepeatedly(Return(kIncomingUniStreamId)); - EXPECT_CALL(remote_track_visitor_, OnObjectFragment(ftn, _, payload, 0)); - object_stream->OnObjectMessage(object, payload, false); - EXPECT_CALL(remote_track_visitor_, - OnObjectFragment(ftn, _, payload, payload.length())); - object_stream->OnObjectMessage(object, payload, true); // complete the object - // New object, check the offset was reset. - ++object.object_id; - EXPECT_CALL(remote_track_visitor_, OnObjectFragment(ftn, _, payload, 0)); - object_stream->OnObjectMessage(object, payload, true); // complete the object -} - -TEST_F(MoqtSessionTest, ObjectBeforeSubscribeOk) { - FullTrackName ftn("foo", "bar"); - std::string payload = "deadbeef"; - MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultLocalSubscribe(), - std::nullopt, &remote_track_visitor_); - MoqtObject object = { - /*track_alias=*/2, - /*group_sequence=*/0, - /*object_sequence=*/0, - /*publisher_priority=*/0, - /*extension_headers=*/"", - /*object_status=*/MoqtObjectStatus::kNormal, - /*subgroup_id=*/0, - /*payload_length=*/8, - }; - std::unique_ptr<MoqtDataParserVisitor> object_stream = - MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream_, - kDefaultSubgroupStreamType, 2); - EXPECT_CALL(mock_stream_, SendStopSending); - object_stream->OnObjectMessage(object, payload, true); - - // SUBSCRIBE_OK arrives - MoqtSubscribeOk ok = { - kDefaultLocalRequestId, - /*track_alias=*/2, - MessageParameters(), - TrackExtensions(), - }; - webtransport::test::MockStream mock_control_stream; - std::unique_ptr<MoqtBidiStreamTestWrapper> control_stream = - MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream); - EXPECT_CALL(remote_track_visitor_, OnReply).Times(1); - control_stream->ReceiveMessage(ok); -} - TEST_F(MoqtSessionTest, SubscribeOkWithBadTrackAlias) { // Create open subscription. We cannot use CreateRemoteTrack because that // skips the code that sets the track alias callbacks. @@ -1307,9 +1174,22 @@ TEST_F(MoqtSessionTest, ReceiveDatagram) { FullTrackName ftn("foo", "bar"); + const MoqtPriority kPeerDefaultPriority = 0x20; + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); std::string payload = "deadbeef"; - MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), - /*track_alias=*/2, &remote_track_visitor_); + EXPECT_CALL(mock_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _)); + session_.Subscribe(ftn, &remote_track_visitor_, MessageParameters()); + MoqtSubscribeOk ok; + ok.request_id = 0; + ok.track_alias = 2; + ok.extensions = + TrackExtensions(std::nullopt, std::nullopt, kPeerDefaultPriority, + std::nullopt, std::nullopt, std::nullopt); + EXPECT_CALL(remote_track_visitor_, OnReply); + stream_input->ReceiveMessage(ok); + MoqtObject object = { /*track_alias=*/2, /*group_sequence=*/0, @@ -1468,36 +1348,24 @@ kLocalDefaultPriority + 1); } -TEST_F(MoqtSessionTest, StreamObjectOutOfWindow) { - std::string payload = "deadbeef"; - MoqtSubscribe subscribe = DefaultSubscribe(); - subscribe.parameters.subscription_filter.emplace(Location(1, 0)); - MoqtSessionPeer::CreateRemoteTrack(&session_, subscribe, /*track_alias=*/2, - &remote_track_visitor_); - MoqtObject object = { - /*track_alias=*/2, - /*group_sequence=*/0, - /*object_sequence=*/0, - /*publisher_priority=*/0, - /*extension_headers=*/"", - /*object_status=*/MoqtObjectStatus::kNormal, - /*subgroup_id=*/0, - /*payload_length=*/8, - }; - std::unique_ptr<MoqtDataParserVisitor> object_stream = - MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream_, - kDefaultSubgroupStreamType, 2, - &remote_track_visitor_); - EXPECT_CALL(remote_track_visitor_, OnObjectFragment).Times(0); - object_stream->OnObjectMessage(object, payload, true); -} - TEST_F(MoqtSessionTest, DatagramOutOfWindow) { - std::string payload = "deadbeef"; - MoqtSubscribe subscribe = DefaultSubscribe(); - subscribe.parameters.subscription_filter.emplace(Location(1, 0)); - MoqtSessionPeer::CreateRemoteTrack(&session_, subscribe, /*track_alias=*/2, - &remote_track_visitor_); + FullTrackName ftn("foo", "bar"); + const MoqtPriority kPeerDefaultPriority = 0x20; + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); + EXPECT_CALL(mock_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _)); + MessageParameters params; + params.subscription_filter.emplace(Location(1, 0)); + session_.Subscribe(ftn, &remote_track_visitor_, params); + MoqtSubscribeOk ok; + ok.request_id = 0; + ok.track_alias = 2; + ok.extensions = + TrackExtensions(std::nullopt, std::nullopt, kPeerDefaultPriority, + std::nullopt, std::nullopt, std::nullopt); + EXPECT_CALL(remote_track_visitor_, OnReply); + stream_input->ReceiveMessage(ok); char datagram[] = {0x01, 0x02, 0x00, 0x00, 0x80, 0x00, 0x08, 0x64, 0x65, 0x61, 0x64, 0x62, 0x65, 0x65, 0x66}; EXPECT_CALL(remote_track_visitor_, OnObjectFragment).Times(0); @@ -2744,70 +2612,6 @@ EXPECT_EQ(MoqtSessionPeer::remote_track(&session_, 0), nullptr); } -TEST_F(MoqtSessionTest, SubgroupStreamObjectAfterGroupEnd) { - MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), - /*track_alias=*/2, &remote_track_visitor_); - webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = - MoqtSessionPeer::CreateControlStream(&session_, &control_stream); - std::unique_ptr<MoqtDataParserVisitor> object_stream = - MoqtSessionPeer::CreateIncomingDataStream( - &session_, &mock_stream_, - MoqtDataStreamType::Subgroup(/*subgroup_id=*/0, /*first_object_id=*/0, - /*no_extension_headers=*/true, - /*has_default_priority=*/false), - 2); - object_stream->OnObjectMessage( - MoqtObject(/*track_alias=*/2, /*group_id=*/0, /*object_id=*/0, - /*publisher_priority=*/0x80, /*extension_headers=*/"", - MoqtObjectStatus::kEndOfGroup, /*subgroup_id=*/0, - /*payload_length=*/0), - "", true); - EXPECT_CALL(mock_session_, GetStreamById(_)) - .WillRepeatedly(Return(&control_stream)); - EXPECT_CALL(control_stream, - Writev(ControlMessageOfType(MoqtMessageType::kUnsubscribe), _)); - EXPECT_CALL(remote_track_visitor_, OnMalformedTrack); - object_stream->OnObjectMessage( - MoqtObject(/*track_alias=*/2, /*group_id=*/0, /*object_id=*/1, - /*publisher_priority=*/0x80, /*extension_headers=*/"", - MoqtObjectStatus::kNormal, /*subgroup_id=*/0, - /*payload_length=*/3), - "bar", true); -} - -TEST_F(MoqtSessionTest, SubgroupStreamObjectAfterTrackEnd) { - MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), - /*track_alias=*/2, &remote_track_visitor_); - webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = - MoqtSessionPeer::CreateControlStream(&session_, &control_stream); - std::unique_ptr<MoqtDataParserVisitor> object_stream = - MoqtSessionPeer::CreateIncomingDataStream( - &session_, &mock_stream_, - MoqtDataStreamType::Subgroup(/*subgroup_id=*/0, /*first_object_id=*/0, - /*no_extension_headers=*/true, - /*has_default_priority=*/false), - /*track_alias=*/2); - object_stream->OnObjectMessage( - MoqtObject(/*track_alias=*/2, /*group_id=*/0, /*object_id=*/0, - /*publisher_priority=*/0x80, /*extension_headers=*/"", - MoqtObjectStatus::kEndOfTrack, /*subgroup_id=*/0, - /*payload_length=*/0), - "", true); - EXPECT_CALL(mock_session_, GetStreamById(_)) - .WillRepeatedly(Return(&control_stream)); - EXPECT_CALL(control_stream, - Writev(ControlMessageOfType(MoqtMessageType::kUnsubscribe), _)); - EXPECT_CALL(remote_track_visitor_, OnMalformedTrack); - object_stream->OnObjectMessage( - MoqtObject(/*track_alias=*/2, /*group_id=*/0, /*object_id=*/1, - /*publisher_priority=*/0x80, /*extension_headers=*/"", - MoqtObjectStatus::kNormal, /*subgroup_id=*/0, - /*payload_length=*/3), - "bar", true); -} - TEST_F(MoqtSessionTest, FetchStreamMalformedTrack) { webtransport::test::InMemoryStream stream(kIncomingUniStreamId); std::unique_ptr<MoqtFetchTask> task =
diff --git a/quiche/quic/moqt/moqt_uni_stream_test.cc b/quiche/quic/moqt/moqt_uni_stream_test.cc index 8459c98..4bcf4a1 100644 --- a/quiche/quic/moqt/moqt_uni_stream_test.cc +++ b/quiche/quic/moqt/moqt_uni_stream_test.cc
@@ -569,17 +569,84 @@ "Object delivered without preliminaries"); } +TEST_F(IncomingDataStreamTest, OnObjectMessage) { + ProcessStreamType(MoqtDataStreamType::Subgroup(0, 0, false, 0x80)); + ProcessAlias(2); + MoqtObject object = kDefaultObject; + object.payload_length = 8; + EXPECT_CALL(visitor_, OnObjectFragment) + .WillOnce([&](const FullTrackName& track_name, + const PublishedObjectMetadata& metadata, + const absl::string_view received_payload, uint64_t offset) { + EXPECT_EQ(track_name, ftn_); + EXPECT_EQ(metadata.location, Location(0, 0)); + EXPECT_EQ(metadata.subgroup, 0); + EXPECT_EQ(metadata.extensions, ""); + EXPECT_EQ(metadata.status, MoqtObjectStatus::kNormal); + EXPECT_EQ(metadata.publisher_priority, 0x80); + EXPECT_EQ(metadata.payload_length, 8); + EXPECT_EQ(received_payload, "deadbeef"); + EXPECT_EQ(offset, 0); + }); + stream_->OnObjectMessage(object, "deadbeef", true); +} + TEST_F(IncomingDataStreamTest, OnObjectMessageBufferPartialObject) { ProcessStreamType(MoqtDataStreamType::Subgroup(0, 0, false, 0x80)); ProcessAlias(2); MoqtObject object = kDefaultObject; - object.payload_length = 10; + object.payload_length = 6; EXPECT_CALL(visitor_, OnObjectFragment).Times(0); stream_->OnObjectMessage(object, "foo", false); - EXPECT_CALL(visitor_, OnObjectFragment); + EXPECT_CALL(visitor_, OnObjectFragment) + .WillOnce([&](const FullTrackName& track_name, + const PublishedObjectMetadata& metadata, + const absl::string_view received_payload, uint64_t offset) { + EXPECT_EQ(metadata.payload_length, 6); + EXPECT_EQ(received_payload, "foobar"); + EXPECT_EQ(offset, 0); + }); stream_->OnObjectMessage(object, "bar", true); } +TEST_F(IncomingDataStreamTest, OnObjectMessageDontBufferPartialObject) { + EXPECT_CALL(session_, deliver_partial_objects()).WillRepeatedly(Return(true)); + ProcessStreamType(MoqtDataStreamType::Subgroup(0, 0, false, 0x80)); + ProcessAlias(2); + MoqtObject object = kDefaultObject; + object.payload_length = 6; + EXPECT_CALL(visitor_, OnObjectFragment).Times(0); + EXPECT_CALL(visitor_, OnObjectFragment) + .WillOnce([&](const FullTrackName& track_name, + const PublishedObjectMetadata& metadata, + const absl::string_view received_payload, uint64_t offset) { + EXPECT_EQ(metadata.payload_length, 6); + EXPECT_EQ(received_payload, "foo"); + EXPECT_EQ(offset, 0); + }); + stream_->OnObjectMessage(object, "foo", false); + EXPECT_CALL(visitor_, OnObjectFragment) + .WillOnce([&](const FullTrackName& track_name, + const PublishedObjectMetadata& metadata, + const absl::string_view received_payload, uint64_t offset) { + EXPECT_EQ(metadata.payload_length, 6); + EXPECT_EQ(received_payload, "bar"); + EXPECT_EQ(offset, 3); + }); + stream_->OnObjectMessage(object, "bar", true); + // New object, make sure offset has been reset. + ++object.object_id; + EXPECT_CALL(visitor_, OnObjectFragment) + .WillOnce([&](const FullTrackName& track_name, + const PublishedObjectMetadata& metadata, + const absl::string_view received_payload, uint64_t offset) { + EXPECT_EQ(metadata.payload_length, 6); + EXPECT_EQ(received_payload, "foobaz"); + EXPECT_EQ(offset, 0); + }); + stream_->OnObjectMessage(object, "foobaz", true); +} + TEST_F(IncomingDataStreamTest, OnObjectMessageInvalidTrack) { ProcessStreamType(MoqtDataStreamType::Subgroup(0, 0, false, 0x80)); uint8_t alias = 2; @@ -608,7 +675,7 @@ "Missing subgroup ID on SUBSCRIBE stream"); } -TEST_F(IncomingDataStreamTest, OnObjectMessageMalformedTrack) { +TEST_F(IncomingDataStreamTest, ObjectAfterTrackEnd) { ProcessStreamType(MoqtDataStreamType::Subgroup(0, 0, false, 0x80)); ProcessAlias(2); MoqtObject object = kDefaultObject; @@ -623,6 +690,21 @@ stream_->OnObjectMessage(object2, "", true); } +TEST_F(IncomingDataStreamTest, ObjectAfterGroupEnd) { + ProcessStreamType(MoqtDataStreamType::Subgroup(0, 0, false, 0x80)); + ProcessAlias(2); + MoqtObject object = kDefaultObject; + object.object_status = MoqtObjectStatus::kEndOfGroup; + EXPECT_CALL(visitor_, OnObjectFragment); + stream_->OnObjectMessage(object, "", true); + + EXPECT_CALL(session_, OnMalformedTrack(track_.get())); + MoqtObject object2 = object; + object2.object_id = 1; + object2.object_status = MoqtObjectStatus::kNormal; + stream_->OnObjectMessage(object2, "", true); +} + TEST_F(IncomingDataStreamTest, MaybeReadOneObjectUnexpectedState) { EXPECT_QUICHE_BUG(stream_->MaybeReadOneObject(), "Requesting object, parser in unexpected state");
diff --git a/quiche/quic/moqt/test_tools/moqt_session_peer.h b/quiche/quic/moqt/test_tools/moqt_session_peer.h index d70c396..6aebd20 100644 --- a/quiche/quic/moqt/test_tools/moqt_session_peer.h +++ b/quiche/quic/moqt/test_tools/moqt_session_peer.h
@@ -143,34 +143,6 @@ absl::down_cast<MoqtSession::ControlStream*>(visitor.release()))); } - static void CreateRemoteTrack(MoqtSession* session, - const MoqtSubscribe& subscribe, - const std::optional<uint64_t> track_alias, - SubscribeVisitor* visitor) { - auto track = std::make_unique<SubscribeRemoteTrack>( - subscribe, visitor, - [session = session, ftn = subscribe.full_track_name, - id = subscribe.request_id]() { - session->subscribe_by_name_.erase(ftn); - session->upstream_by_id_.erase(id); - }, - [session = session](uint64_t alias, SubscribeRemoteTrack* track) { - if (track == nullptr) { - session->subscribe_by_alias_.erase(alias); - return true; - } - session->subscribe_by_alias_[alias] = track; - return true; - }); - if (track_alias.has_value()) { - ASSERT_TRUE(track->set_track_alias(*track_alias)); - } - session->subscribe_by_name_.try_emplace(subscribe.full_track_name, - track.get()); - session->upstream_by_id_.try_emplace(subscribe.request_id, - std::move(track)); - } - static SubscribeRemoteTrack* remote_track(MoqtSession* session, uint64_t track_alias) { return session->RemoteTrackByAlias(track_alias);