Handle MoQT ANNOUNCE_CANCEL messages. PiperOrigin-RevId: 647675746
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index 9843ab6..6832f25 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -9,13 +9,14 @@ #include <cstdint> #include <memory> #include <optional> -#include <set> #include <string> #include <utility> #include <vector> #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -192,6 +193,18 @@ return (it != local_tracks_.end() && it->second.HasSubscriber()); } +void MoqtSession::CancelAnnounce(absl::string_view track_namespace) { + for (auto it = local_tracks_.begin(); it != local_tracks_.end(); ++it) { + if (it->first.track_namespace == track_namespace) { + it->second.set_announce_cancel(); + } + } + absl::erase_if(local_tracks_, [&](const auto& it) { + return it.first.track_namespace == track_namespace && + !it.second.HasSubscriber(); + }); +} + bool MoqtSession::SubscribeAbsolute(absl::string_view track_namespace, absl::string_view name, uint64_t start_group, uint64_t start_object, @@ -322,6 +335,9 @@ // Clean up the subscription track.DeleteWindow(subscribe_id); local_track_by_subscribe_id_.erase(name_it); + if (track.canceled() && !track.HasSubscriber()) { + local_tracks_.erase(track_it); + } return true; } @@ -724,6 +740,14 @@ return; } LocalTrack& track = it->second; + if (it->second.canceled()) { + // Note that if the track has already been deleted, there will not be a + // protocol violation, which the spec says there SHOULD be. It's not worth + // keeping state on deleted tracks. + session_->Error(MoqtError::kProtocolViolation, + "Received SUBSCRIBE for canceled track"); + return; + } if ((track.track_alias().has_value() && message.track_alias != *track.track_alias()) || session_->used_track_aliases_.contains(message.track_alias)) { @@ -967,6 +991,11 @@ session_->pending_outgoing_announces_.erase(it); } +void MoqtSession::Stream::OnAnnounceCancelMessage( + const MoqtAnnounceCancel& message) { + session_->CancelAnnounce(message.track_namespace); +} + void MoqtSession::Stream::OnParsingError(MoqtError error_code, absl::string_view reason) { session_->Error(error_code, absl::StrCat("Parse error: ", reason));
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index c41bae5..9d88da9 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -98,6 +98,9 @@ void Announce(absl::string_view track_namespace, MoqtOutgoingAnnounceCallback announce_callback); bool HasSubscribers(const FullTrackName& full_track_name) const; + // Send an ANNOUNCE_CANCEL and delete local tracks in that namespace when all + // subscriptions are closed for that track. + void CancelAnnounce(absl::string_view track_namespace); // Returns true if SUBSCRIBE was sent. If there is already a subscription to // the track, the message will still be sent. However, the visitor will be @@ -177,13 +180,14 @@ void OnSubscribeOkMessage(const MoqtSubscribeOk& message) override; void OnSubscribeErrorMessage(const MoqtSubscribeError& message) override; void OnUnsubscribeMessage(const MoqtUnsubscribe& message) override; + // There is no state to update for SUBSCRIBE_DONE. void OnSubscribeDoneMessage(const MoqtSubscribeDone& /*message*/) override { } void OnSubscribeUpdateMessage(const MoqtSubscribeUpdate& message) override; void OnAnnounceMessage(const MoqtAnnounce& message) override; void OnAnnounceOkMessage(const MoqtAnnounceOk& message) override; void OnAnnounceErrorMessage(const MoqtAnnounceError& message) override; - void OnAnnounceCancelMessage(const MoqtAnnounceCancel& message) override {}; + void OnAnnounceCancelMessage(const MoqtAnnounceCancel& message) override; void OnTrackStatusRequestMessage( const MoqtTrackStatusRequest& message) override {}; void OnUnannounceMessage(const MoqtUnannounce& /*message*/) override {}
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index e943999..a002462 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -108,19 +108,23 @@ session->active_subscribes_[subscribe_id] = {subscribe, visitor}; } - static LocalTrack& local_track(MoqtSession* session, FullTrackName& name) { - return session->local_tracks_.find(name)->second; + static LocalTrack* local_track(MoqtSession* session, FullTrackName& name) { + auto it = session->local_tracks_.find(name); + if (it == session->local_tracks_.end()) { + return nullptr; + } + return &it->second; } static void AddSubscription(MoqtSession* session, FullTrackName& name, uint64_t subscribe_id, uint64_t track_alias, uint64_t start_group, uint64_t start_object) { - LocalTrack& track = local_track(session, name); - track.set_track_alias(track_alias); - track.AddWindow(subscribe_id, start_group, start_object); + LocalTrack* track = local_track(session, name); + track->set_track_alias(track_alias); + track->AddWindow(subscribe_id, start_group, start_object); session->used_track_aliases_.emplace(track_alias); session->local_track_by_subscribe_id_.emplace(subscribe_id, - track.full_track_name()); + track->full_track_name()); } static FullSequence next_sequence(MoqtSession* session, FullTrackName& name) { @@ -1243,9 +1247,9 @@ session_.AddLocalTrack(ftn, MoqtForwardingPreference::kTrack, &track_visitor); MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0); // Get the window, set the maximum delivered. - LocalTrack& track = MoqtSessionPeer::local_track(&session_, ftn); - track.GetWindow(0)->OnObjectSent(FullSequence(7, 3), - MoqtObjectStatus::kNormal); + LocalTrack* track = MoqtSessionPeer::local_track(&session_, ftn); + track->GetWindow(0)->OnObjectSent(FullSequence(7, 3), + MoqtObjectStatus::kNormal); // Update the end to fall at the last delivered object. MoqtSubscribeUpdate update = { /*subscribe_id=*/0, @@ -1272,6 +1276,108 @@ EXPECT_FALSE(session_.HasSubscribers(ftn)); } +TEST_F(MoqtSessionTest, ProcessAnnounceCancelNoSubscribes) { + MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kSubscriber); + FullTrackName ftn("foo", "bar"); + MockLocalTrackVisitor track_visitor; + session_.AddLocalTrack(ftn, MoqtForwardingPreference::kTrack, &track_visitor); + EXPECT_NE(MoqtSessionPeer::local_track(&session_, ftn), nullptr); + StrictMock<webtransport::test::MockStream> mock_stream; + MoqtAnnounceCancel cancel = { + /*track_namespace=*/"foo", + }; + std::unique_ptr<MoqtParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + stream_input->OnAnnounceCancelMessage(cancel); + EXPECT_EQ(MoqtSessionPeer::local_track(&session_, ftn), nullptr); +} + +TEST_F(MoqtSessionTest, ProcessAnnounceCancelActiveSubscribes) { + MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kSubscriber); + FullTrackName ftn("foo", "bar"); + MockLocalTrackVisitor track_visitor; + session_.AddLocalTrack(ftn, MoqtForwardingPreference::kTrack, &track_visitor); + EXPECT_NE(MoqtSessionPeer::local_track(&session_, ftn), nullptr); + MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0); + MoqtSessionPeer::AddSubscription(&session_, ftn, 1, 2, 7, 0); + StrictMock<webtransport::test::MockStream> mock_stream; + MoqtAnnounceCancel cancel = { + /*track_namespace=*/"foo", + }; + std::unique_ptr<MoqtParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + stream_input->OnAnnounceCancelMessage(cancel); + // The track is still there because there is a subscribe. + EXPECT_NE(MoqtSessionPeer::local_track(&session_, ftn), nullptr); + // Unsubscribe from 0. + MoqtUnsubscribe unsubscribe = { + /*subscribe_id=*/0, + }; + EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream)); + bool correct_message = false; + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + correct_message = true; + EXPECT_EQ(*ExtractMessageType(data[0]), + MoqtMessageType::kSubscribeDone); + return absl::OkStatus(); + }); + stream_input->OnUnsubscribeMessage(unsubscribe); + EXPECT_TRUE(correct_message); + EXPECT_NE(MoqtSessionPeer::local_track(&session_, ftn), nullptr); + // Unsubscribe from 1. + unsubscribe.subscribe_id = 1; + EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream)); + correct_message = false; + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + correct_message = true; + EXPECT_EQ(*ExtractMessageType(data[0]), + MoqtMessageType::kSubscribeDone); + return absl::OkStatus(); + }); + stream_input->OnUnsubscribeMessage(unsubscribe); + EXPECT_TRUE(correct_message); + + EXPECT_EQ(MoqtSessionPeer::local_track(&session_, ftn), nullptr); +} + +TEST_F(MoqtSessionTest, AnnounceCancelThenSubscribe) { + MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kSubscriber); + FullTrackName ftn("foo", "bar"); + MockLocalTrackVisitor track_visitor; + session_.AddLocalTrack(ftn, MoqtForwardingPreference::kTrack, &track_visitor); + EXPECT_NE(MoqtSessionPeer::local_track(&session_, ftn), nullptr); + MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0); + StrictMock<webtransport::test::MockStream> mock_stream; + MoqtAnnounceCancel cancel = { + /*track_namespace=*/"foo", + }; + std::unique_ptr<MoqtParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + stream_input->OnAnnounceCancelMessage(cancel); + // The track is still there because there is a subscribe. + EXPECT_NE(MoqtSessionPeer::local_track(&session_, ftn), nullptr); + MoqtSubscribe subscribe = { + /*subscribe_id=*/1, + /*track_alias=*/2, + /*track_namespace=*/"foo", + /*track_name=*/"bar", + /*start_group=*/4, + /*start_object=*/0, + /*end_group=*/std::nullopt, + /*end_object=*/std::nullopt, + /*authorization_info=*/std::nullopt, + }; + EXPECT_CALL(mock_session_, + CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation), + "Received SUBSCRIBE for canceled track")) + .Times(1); + stream_input->OnSubscribeMessage(subscribe); +} + // TODO: Cover more error cases in the above } // namespace test
diff --git a/quiche/quic/moqt/moqt_track.h b/quiche/quic/moqt/moqt_track.h index b813eb3..e61dc8b 100644 --- a/quiche/quic/moqt/moqt_track.h +++ b/quiche/quic/moqt/moqt_track.h
@@ -14,6 +14,7 @@ #include "absl/strings/string_view.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_subscribe_windows.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" #include "quiche/common/quiche_callbacks.h" namespace moqt { @@ -69,11 +70,15 @@ void AddWindow(uint64_t subscribe_id, uint64_t start_group, uint64_t start_object) { + QUIC_BUG_IF(quic_bug_subscribe_to_canceled_track, announce_canceled_) + << "Canceled track got subscription"; windows_.AddWindow(subscribe_id, next_sequence_, start_group, start_object); } void AddWindow(uint64_t subscribe_id, uint64_t start_group, uint64_t start_object, uint64_t end_group) { + QUIC_BUG_IF(quic_bug_subscribe_to_canceled_track, announce_canceled_) + << "Canceled track got subscription"; // The end object might be unknown. auto it = max_object_ids_.find(end_group); if (end_group >= next_sequence_.group) { @@ -89,6 +94,8 @@ void AddWindow(uint64_t subscribe_id, uint64_t start_group, uint64_t start_object, uint64_t end_group, uint64_t end_object) { + QUIC_BUG_IF(quic_bug_subscribe_to_canceled_track, announce_canceled_) + << "Canceled track got subscription"; windows_.AddWindow(subscribe_id, next_sequence_, start_group, start_object, end_group, end_object); } @@ -142,6 +149,9 @@ return forwarding_preference_; } + void set_announce_cancel() { announce_canceled_ = true; } + bool canceled() const { return announce_canceled_; } + private: // This only needs to track subscriptions to current and future objects; // requests for objects in the past are forwarded to the application. @@ -160,6 +170,11 @@ // EndOfTrack has been received for that group. absl::flat_hash_map<uint64_t, uint64_t> max_object_ids_; Visitor* visitor_; + + // If true, the session has received ANNOUNCE_CANCELED for this namespace. + // Additional subscribes will be a protocol error, and the track can be + // destroyed once all active subscribes end. + bool announce_canceled_ = false; }; // A track on the peer to which the session has subscribed.