Factor PublishedSubscription out of MoqtSession, into SubscriptionPublisher. Rationalized the handling of priority and SendOrder: - In the Session, we only care about MoqtTrackPriority (subscriber_priority | max(publisher_priority) for pending streams. - Within the Subscription, we care about everything except subscriber_priority for pending streams. - When actually setting stream priority, all the inputs matter. There are numerous tests in MoqtSessionTest that are duplicated in SubscriptionPublisherTest or the Uni Stream tests. I've left them here to show that this CL is mostly a no-op, but will delete them in a followon CL. PiperOrigin-RevId: 917466914
diff --git a/build/source_list.bzl b/build/source_list.bzl index 8a79fac..b8ca2e7 100644 --- a/build/source_list.bzl +++ b/build/source_list.bzl
@@ -1605,6 +1605,7 @@ "quic/moqt/moqt_session_callbacks.h", "quic/moqt/moqt_session_interface.h", "quic/moqt/moqt_stream_map.h", + "quic/moqt/moqt_subscription.h", "quic/moqt/moqt_trace_recorder.h", "quic/moqt/moqt_track.h", "quic/moqt/moqt_types.h", @@ -1638,6 +1639,7 @@ "quic/moqt/moqt_relay_track_publisher.cc", "quic/moqt/moqt_session.cc", "quic/moqt/moqt_stream_map.cc", + "quic/moqt/moqt_subscription.cc", "quic/moqt/moqt_trace_recorder.cc", "quic/moqt/moqt_track.cc", "quic/moqt/moqt_uni_stream.cc", @@ -1670,6 +1672,7 @@ "quic/moqt/moqt_relay_track_publisher_test.cc", "quic/moqt/moqt_session_test.cc", "quic/moqt/moqt_stream_map_test.cc", + "quic/moqt/moqt_subscription_test.cc", "quic/moqt/moqt_track_test.cc", "quic/moqt/moqt_uni_stream_test.cc", "quic/moqt/relay_namespace_tree_test.cc",
diff --git a/build/source_list.gni b/build/source_list.gni index 7529969..0135e7b 100644 --- a/build/source_list.gni +++ b/build/source_list.gni
@@ -1609,6 +1609,7 @@ "src/quiche/quic/moqt/moqt_session_callbacks.h", "src/quiche/quic/moqt/moqt_session_interface.h", "src/quiche/quic/moqt/moqt_stream_map.h", + "src/quiche/quic/moqt/moqt_subscription.h", "src/quiche/quic/moqt/moqt_trace_recorder.h", "src/quiche/quic/moqt/moqt_track.h", "src/quiche/quic/moqt/moqt_types.h", @@ -1642,6 +1643,7 @@ "src/quiche/quic/moqt/moqt_relay_track_publisher.cc", "src/quiche/quic/moqt/moqt_session.cc", "src/quiche/quic/moqt/moqt_stream_map.cc", + "src/quiche/quic/moqt/moqt_subscription.cc", "src/quiche/quic/moqt/moqt_trace_recorder.cc", "src/quiche/quic/moqt/moqt_track.cc", "src/quiche/quic/moqt/moqt_uni_stream.cc", @@ -1675,6 +1677,7 @@ "src/quiche/quic/moqt/moqt_relay_track_publisher_test.cc", "src/quiche/quic/moqt/moqt_session_test.cc", "src/quiche/quic/moqt/moqt_stream_map_test.cc", + "src/quiche/quic/moqt/moqt_subscription_test.cc", "src/quiche/quic/moqt/moqt_track_test.cc", "src/quiche/quic/moqt/moqt_uni_stream_test.cc", "src/quiche/quic/moqt/relay_namespace_tree_test.cc",
diff --git a/build/source_list.json b/build/source_list.json index 4a17ef1..009d4b2 100644 --- a/build/source_list.json +++ b/build/source_list.json
@@ -1608,6 +1608,7 @@ "quiche/quic/moqt/moqt_session_callbacks.h", "quiche/quic/moqt/moqt_session_interface.h", "quiche/quic/moqt/moqt_stream_map.h", + "quiche/quic/moqt/moqt_subscription.h", "quiche/quic/moqt/moqt_trace_recorder.h", "quiche/quic/moqt/moqt_track.h", "quiche/quic/moqt/moqt_types.h", @@ -1641,6 +1642,7 @@ "quiche/quic/moqt/moqt_relay_track_publisher.cc", "quiche/quic/moqt/moqt_session.cc", "quiche/quic/moqt/moqt_stream_map.cc", + "quiche/quic/moqt/moqt_subscription.cc", "quiche/quic/moqt/moqt_trace_recorder.cc", "quiche/quic/moqt/moqt_track.cc", "quiche/quic/moqt/moqt_uni_stream.cc", @@ -1674,6 +1676,7 @@ "quiche/quic/moqt/moqt_relay_track_publisher_test.cc", "quiche/quic/moqt/moqt_session_test.cc", "quiche/quic/moqt/moqt_stream_map_test.cc", + "quiche/quic/moqt/moqt_subscription_test.cc", "quiche/quic/moqt/moqt_track_test.cc", "quiche/quic/moqt/moqt_uni_stream_test.cc", "quiche/quic/moqt/relay_namespace_tree_test.cc",
diff --git a/quiche/quic/moqt/moqt_priority.h b/quiche/quic/moqt/moqt_priority.h index 04561d5..02bfc0b 100644 --- a/quiche/quic/moqt/moqt_priority.h +++ b/quiche/quic/moqt/moqt_priority.h
@@ -35,6 +35,14 @@ static constexpr uint64_t kMaxMoqtDeliveryOrder = static_cast<uint64_t>(MoqtDeliveryOrder::kDescending); +// The session weighs pending streams solely on the subscriber_priority and the +// highest of all pending publisher_priorities. +struct MoqtTrackPriority { + MoqtPriority subscriber_priority; + MoqtPriority publisher_priority; + auto operator<=>(const MoqtTrackPriority& other) const = default; +}; + // Computes WebTransport send order for an MoQT data stream with the specified // parameters. QUICHE_EXPORT webtransport::SendOrder SendOrderForStream(
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index b5e44ab..55b3398 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -4,8 +4,6 @@ #include "quiche/quic/moqt/moqt_session.h" -#include <algorithm> -#include <array> #include <cstdint> #include <memory> #include <optional> @@ -26,7 +24,6 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "absl/types/span.h" #include "quiche/quic/core/quic_alarm_factory.h" #include "quiche/quic/core/quic_time.h" #include "quiche/quic/core/quic_types.h" @@ -44,19 +41,16 @@ #include "quiche/quic/moqt/moqt_publisher.h" #include "quiche/quic/moqt/moqt_session_callbacks.h" #include "quiche/quic/moqt/moqt_session_interface.h" -#include "quiche/quic/moqt/moqt_stream_map.h" +#include "quiche/quic/moqt/moqt_subscription.h" #include "quiche/quic/moqt/moqt_track.h" #include "quiche/quic/moqt/moqt_types.h" #include "quiche/quic/moqt/moqt_uni_stream.h" #include "quiche/quic/platform/api/quic_logging.h" #include "quiche/common/platform/api/quiche_bug_tracker.h" #include "quiche/common/platform/api/quiche_logging.h" -#include "quiche/common/platform/api/quiche_stack_trace.h" #include "quiche/common/quiche_buffer_allocator.h" -#include "quiche/common/quiche_mem_slice.h" #include "quiche/common/quiche_status_utils.h" #include "quiche/common/quiche_weak_ptr.h" -#include "quiche/web_transport/stream_helpers.h" #include "quiche/web_transport/web_transport.h" #define ENDPOINT \ @@ -655,31 +649,16 @@ "Peer did not close session after GOAWAY"); } -bool MoqtSession::PublishIsDone(uint64_t request_id, PublishDoneCode code, - absl::string_view error_reason) { +void MoqtSession::PublishIsDone(uint64_t request_id) { if (is_closing_) { - return false; + return; } auto it = published_subscriptions_.find(request_id); if (it == published_subscriptions_.end()) { - return false; + return; } - - PublishedSubscription& subscription = *it->second; - - MoqtPublishDone publish_done; - publish_done.request_id = request_id; - publish_done.status_code = code; - publish_done.stream_count = subscription.streams_opened(); - publish_done.error_reason = error_reason; - // TODO(martinduke): It is technically correct, but not good, to simply - // reset all the streams in order to send PUBLISH_DONE. It's better to wait - // until streams FIN naturally, where possible. - QUIC_DLOG(INFO) << ENDPOINT << "Sending PUBLISH_DONE message for " - << subscription.publisher().GetTrackName(); + subscribed_track_names_.erase(it->second->publisher().GetTrackName()); published_subscriptions_.erase(it); - SendControlMessage(framer_.SerializePublishDone(publish_done)); - return true; } void MoqtSession::MaybeDestroySubscription(SubscribeRemoteTrack* subscribe) { @@ -704,48 +683,20 @@ upstream_by_id_.erase(subscribe->request_id()); } -webtransport::Stream* MoqtSession::OpenOrQueueDataStream( - uint64_t subscription_id, const NewStreamParameters& parameters) { - auto it = published_subscriptions_.find(subscription_id); - if (it == published_subscriptions_.end()) { - // It is possible that the subscription has been discarded while the stream - // was in the queue; discard those streams. - return nullptr; +void MoqtSession::UpdateTrackPriority( + uint64_t request_id, std::optional<MoqtTrackPriority> old_priority, + MoqtTrackPriority new_priority) { + if (old_priority.has_value()) { + auto [start, end] = + subscriptions_with_queued_streams_.equal_range(*old_priority); + for (auto it = start; it != end; ++it) { + if (it->second == request_id) { + subscriptions_with_queued_streams_.erase(it); + break; + } + } } - PublishedSubscription& subscription = *it->second; - if (!session_->CanOpenNextOutgoingUnidirectionalStream()) { - subscription.AddQueuedOutgoingSubgroupStream(parameters); - // The subscription will notify the session about how to update the - // session's queue. - // TODO: limit the number of streams in the queue. - return nullptr; - } - return OpenDataStream(subscription, parameters); -} - -webtransport::Stream* MoqtSession::OpenDataStream( - PublishedSubscription& subscription, - const NewStreamParameters& parameters) { - webtransport::Stream* new_stream = - session_->OpenOutgoingUnidirectionalStream(); - if (new_stream == nullptr) { - QUICHE_BUG(MoqtSession_OpenDataStream_blocked) - << "OpenDataStream called when creation of new streams is blocked."; - return nullptr; - } - webtransport::StreamPriority priority{ - kMoqtSendGroupId, - subscription.GetSendOrder( - Location(parameters.index.group, parameters.first_object), - parameters.index.subgroup, - parameters.publisher_priority.value_or( - subscription.default_publisher_priority()))}; - new_stream->SetVisitor(std::make_unique<OutgoingSubgroupStream>( - framer_, new_stream, parameters.index, parameters.first_object, - subscription.GetWeakPtr(), subscription.publisher_shared_ptr(), priority, - subscription.track_alias(), &trace_recorder_)); - subscription.OnDataStreamCreated(new_stream->GetStreamId(), parameters.index); - return new_stream; + subscriptions_with_queued_streams_.emplace(new_priority, request_id); } bool MoqtSession::OpenDataStream(PublishedFetch* fetch, @@ -806,54 +757,27 @@ } void MoqtSession::OnCanCreateNewOutgoingUnidirectionalStream() { - while (!subscribes_with_queued_outgoing_data_streams_.empty() && + while (!subscriptions_with_queued_streams_.empty() && session_->CanOpenNextOutgoingUnidirectionalStream()) { - auto next = subscribes_with_queued_outgoing_data_streams_.rbegin(); - auto subscription = published_subscriptions_.find(next->subscription_id); + auto next = subscriptions_with_queued_streams_.begin(); + auto subscription = published_subscriptions_.find(next->second); if (subscription == published_subscriptions_.end()) { - auto fetch = incoming_fetches_.find(next->subscription_id); + auto fetch = incoming_fetches_.find(next->second); // Create the stream if the fetch still exists. if (fetch != incoming_fetches_.end() && - !OpenDataStream(fetch->second.get(), next->send_order)) { + !OpenDataStream(fetch->second.get(), + SendOrderForFetch(next->first.subscriber_priority))) { return; // A QUIC_BUG has fired because this shouldn't happen. } // FETCH needs only one stream, and can be deleted from the queue. Or, // there is no subscribe and no fetch; the entry in the queue is invalid. - subscribes_with_queued_outgoing_data_streams_.erase((++next).base()); + subscriptions_with_queued_streams_.erase(next); continue; } + subscriptions_with_queued_streams_.erase(next); // Pop the item from the subscription's queue, which might update - // subscribes_with_queued_outgoing_data_streams_. - NewStreamParameters next_queued_stream = - subscription->second->NextQueuedOutgoingSubgroupStream(); - // Check if Group is too old. - if (next_queued_stream.index.group < - subscription->second->first_active_group()) { - // The stream is too old to be sent. - continue; - } - // Open the stream. - webtransport::Stream* stream = - OpenDataStream(*subscription->second, next_queued_stream); - if (stream != nullptr) { - stream->visitor()->OnCanWrite(); - } - } -} - -void MoqtSession::UpdateQueuedSendOrder( - uint64_t request_id, std::optional<webtransport::SendOrder> old_send_order, - std::optional<webtransport::SendOrder> new_send_order) { - if (old_send_order == new_send_order) { - return; - } - if (old_send_order.has_value()) { - subscribes_with_queued_outgoing_data_streams_.erase( - SubscriptionWithQueuedStream{*old_send_order, request_id}); - } - if (new_send_order.has_value()) { - subscribes_with_queued_outgoing_data_streams_.emplace(*new_send_order, - request_id); + // subscriptions_with_queued_streams_ with a second pending stream. + subscription->second->OnCanCreateNewUniStream(); } } @@ -1032,15 +956,17 @@ } MoqtTrackPublisher* track_publisher_ptr = track_publisher.get(); - auto subscription = std::make_unique<PublishedSubscription>( - session_, track_publisher, message, session_->next_local_track_alias_++, - monitoring); - PublishedSubscription* subscription_ptr = subscription.get(); + auto subscription = std::make_unique<SubscriptionPublisher>( + session_->framer_, track_publisher, this, message.request_id, + session_->next_local_track_alias_++, message.parameters, session_, + monitoring, session_->callbacks_.clock, session_->trace_recorder_); + SubscriptionPublisher* subscription_ptr = subscription.get(); auto [it, success] = session_->published_subscriptions_.emplace( message.request_id, std::move(subscription)); if (!success) { QUICHE_NOTREACHED(); // ValidateRequestId() should have caught this. } + session_->subscribed_track_names_.insert(message.full_track_name); track_publisher_ptr->AddObjectListener(subscription_ptr); return absl::OkStatus(); } @@ -1218,7 +1144,7 @@ } QUIC_DLOG(INFO) << ENDPOINT << "Received an UNSUBSCRIBE for " << it->second->publisher().GetTrackName(); - session_->published_subscriptions_.erase(it); + session_->PublishIsDone(message.request_id); return absl::OkStatus(); } @@ -1551,23 +1477,20 @@ // created, the stream visitor will replace this callback. fetch_task->SetObjectAvailableCallback( [this, - send_order = - SendOrderForFetch(message.parameters.subscriber_priority.value_or( - kDefaultSubscriberPriority)), + subscriber_priority = message.parameters.subscriber_priority.value_or( + kDefaultSubscriberPriority), request_id = message.request_id]() { auto it = session_->incoming_fetches_.find(request_id); if (it == session_->incoming_fetches_.end()) { return; } if (!session_->session()->CanOpenNextOutgoingUnidirectionalStream() || - !session_->OpenDataStream(it->second.get(), send_order)) { - if (!session_->subscribes_with_queued_outgoing_data_streams_.contains( - SubscriptionWithQueuedStream(request_id, send_order))) { - // Put the FETCH in the queue for a new stream unless it has already - // done so. - session_->UpdateQueuedSendOrder(request_id, std::nullopt, - send_order); - } + !session_->OpenDataStream(it->second.get(), + SendOrderForFetch(subscriber_priority))) { + session_->UpdateTrackPriority( + request_id, std::nullopt, + MoqtTrackPriority(subscriber_priority, + kDefaultPublisherPriority)); } }); return absl::OkStatus(); @@ -1877,394 +1800,6 @@ session_->Error(error_code, absl::StrCat("Parse error: ", reason)); } -MoqtSession::PublishedSubscription::PublishedSubscription( - MoqtSession* session, std::shared_ptr<MoqtTrackPublisher> track_publisher, - const MoqtSubscribe& subscribe, uint64_t track_alias, - MoqtPublishingMonitorInterface* monitoring_interface) - : session_(session), - track_publisher_(track_publisher), - request_id_(subscribe.request_id), - can_have_joining_fetch_(subscribe.parameters.forward()), - track_alias_(track_alias), - parameters_(subscribe.parameters), - monitoring_interface_(monitoring_interface), - weak_ptr_factory_(this) { - if (monitoring_interface_ != nullptr) { - monitoring_interface_->OnObjectAckSupportKnown( - subscribe.parameters.oack_window_size); - } - QUIC_DLOG(INFO) << ENDPOINT << "Created subscription for " - << subscribe.full_track_name; - session_->subscribed_track_names_.insert(subscribe.full_track_name); - // TODO(martinduke): Handle NEW_GROUP_REQUEST -} - -MoqtSession::PublishedSubscription::~PublishedSubscription() { - track_publisher_->RemoveObjectListener(this); - if (session_->is_closing_) { - return; - } - session_->subscribed_track_names_.erase(track_publisher_->GetTrackName()); - // Reset all streams. - for (const webtransport::StreamId stream_id : stream_map().GetAllStreams()) { - webtransport::Stream* stream = session_->session_->GetStreamById(stream_id); - if (stream != nullptr) { - stream->ResetWithUserCode(kResetCodeCancelled); - } - } -} - -SendStreamMap& MoqtSession::PublishedSubscription::stream_map() { - // The stream map is lazily initialized, since initializing it requires - // knowing the forwarding preference in advance, and it might not be known - // when the subscription is first created. - if (!lazily_initialized_stream_map_.has_value()) { - lazily_initialized_stream_map_.emplace(); - } - return *lazily_initialized_stream_map_; -} - -void MoqtSession::PublishedSubscription::Update( - const MessageParameters& parameters) { - // TODO(martinduke): If there are auth tokens, this probably has to go to the - // application. - MoqtPriority old_priority = - parameters_.subscriber_priority.value_or(kDefaultSubscriberPriority); - parameters_.Update(parameters); - can_have_joining_fetch_ = parameters_.forward(); - if (parameters.subscriber_priority.has_value()) { // priority changed. - // Reprioritize all active streams. - for (const auto stream_id : stream_map().GetAllStreams()) { - webtransport::Stream* stream = - session_->session_->GetStreamById(stream_id); - if (stream == nullptr) { - continue; - } - OutgoingSubgroupStream* outgoing_stream = - absl::down_cast<OutgoingSubgroupStream*>(stream->visitor()); - outgoing_stream->UpdatePriority( - parameters_.subscriber_priority.value_or(kDefaultSubscriberPriority)); - } - if (queued_outgoing_data_streams_.empty()) { - return; - } - webtransport::SendOrder old_send_order = - UpdateSendOrderForSubscriberPriority( - queued_outgoing_data_streams_.rbegin()->first, old_priority); - session_->UpdateQueuedSendOrder(request_id_, old_send_order, - FinalizeSendOrder(old_send_order)); - } -} - -void MoqtSession::PublishedSubscription::OnSubscribeAccepted() { - ControlStream* stream = session_->GetControlStream(); - QUICHE_DCHECK(!established_); - established_ = true; - parameters_.largest_object = track_publisher_->largest_location(); - if (parameters_.subscription_filter.has_value()) { - parameters_.subscription_filter->OnLargestObject( - parameters_.largest_object); - } - MoqtSubscribeOk subscribe_ok; - subscribe_ok.request_id = request_id_; - subscribe_ok.track_alias = track_alias_; - subscribe_ok.parameters.expires = track_publisher_->expiration(); - subscribe_ok.parameters.largest_object = parameters_.largest_object; - subscribe_ok.extensions = track_publisher_->extensions(); - if (!parameters_.group_order.has_value()) { - parameters_.group_order = - subscribe_ok.extensions.default_publisher_group_order(); - } - // TODO(martinduke): Support sending DELIVERY_TIMEOUT parameter as the - // publisher. - default_publisher_priority_ = - subscribe_ok.extensions.default_publisher_priority(); - stream->SendOrBufferMessageOrFatal( - session_->framer_.SerializeSubscribeOk(subscribe_ok)); - // TODO(martinduke): If we buffer objects that arrived previously, the arrival - // of the track alias disambiguates what subscription they belong to. Send - // them."; -} - -void MoqtSession::PublishedSubscription::OnSubscribeRejected( - MoqtRequestErrorInfo info) { - ControlStream* control_stream = session_->GetControlStream(); - control_stream->CheckStatus( - control_stream->SendRequestError(request_id_, info)); - session_->published_subscriptions_.erase(request_id_); - // No class access below this line! -} - -void MoqtSession::PublishedSubscription::OnNewObjectAvailable( - Location location, std::optional<uint64_t> subgroup, - MoqtPriority publisher_priority) { - if (!InWindow(location)) { - return; - } - - if (monitoring_interface_ != nullptr) { - // Notify the monitoring interface about all newly published normal objects. - // Objects with other statuses are not guaranteed to be acknowledged, thus - // passing them into the monitoring interface can lead to confusion. - std::optional<PublishedObject> object = track_publisher_->GetCachedObject( - location.group, subgroup, location.object); - QUICHE_DCHECK(object.has_value()) - << "Object " << absl::StrCat(location) << " on track " - << track_publisher_->GetTrackName().ToString() - << " does not exist, despite OnNewObjectAvailable being called"; - if (object.has_value() && object->metadata.location == location && - object->metadata.status == MoqtObjectStatus::kNormal) { - monitoring_interface_->OnNewObjectEnqueued(location); - } - } - - // TODO(vasilvv): This currently sends UINT64_MAX for datagram subgroups. - // Maybe do something more satisfactory? - session_->trace_recorder_.RecordNewObjectAvaliable( - track_alias_, *track_publisher_, location, subgroup.value_or(UINT64_MAX), - publisher_priority); - - std::optional<webtransport::StreamId> stream_id; - if (subgroup.has_value()) { - DataStreamIndex index(location.group, *subgroup); - if (reset_subgroups_.contains(index)) { - // This subgroup has already been reset, ignore. - return; - } - stream_id = stream_map().GetStreamFor(index); - } - if (session_->alternate_delivery_timeout_ && - !delivery_timeout().IsInfinite() && largest_sent_.has_value() && - location.group >= largest_sent_->group) { - // Start the delivery timeout timer on all previous groups. - for (uint64_t group = first_active_group_; group < location.group; - ++group) { - for (webtransport::StreamId stream_to_update : - stream_map().GetStreamsForGroup(group)) { - webtransport::Stream* raw_stream = - session_->session_->GetStreamById(stream_to_update); - if (raw_stream == nullptr) { - continue; - } - OutgoingSubgroupStream* stream = - absl::down_cast<OutgoingSubgroupStream*>(raw_stream->visitor()); - stream->CreateAndSetAlarm(session_->callbacks_.clock->ApproximateNow() + - delivery_timeout()); - } - } - } - QUICHE_DCHECK_GE(location.group, first_active_group_); - if (!subgroup.has_value()) { - SendDatagram(location); - return; - } - - webtransport::Stream* raw_stream = nullptr; - if (stream_id.has_value()) { - raw_stream = session_->session_->GetStreamById(*stream_id); - } else { - raw_stream = session_->OpenOrQueueDataStream( - request_id_, - NewStreamParameters(location.group, *subgroup, location.object, - (publisher_priority == default_publisher_priority_) - ? std::nullopt - : std::make_optional(publisher_priority))); - } - if (raw_stream == nullptr) { - return; - } - raw_stream->visitor()->OnCanWrite(); -} - -void MoqtSession::PublishedSubscription::OnTrackPublisherGone() { - session_->PublishIsDone(request_id_, PublishDoneCode::kGoingAway, - "Publisher is gone"); -} - -// TODO(martinduke): Revise to check if the last object has been delivered. -void MoqtSession::PublishedSubscription::OnNewFinAvailable(Location location, - uint64_t subgroup) { - if (!InWindow(location.group)) { - return; - } - DataStreamIndex index(location.group, subgroup); - if (reset_subgroups_.contains(index)) { - // This subgroup has already been reset, ignore. - return; - } - QUICHE_DCHECK_GE(location.group, first_active_group_); - std::optional<webtransport::StreamId> stream_id = - stream_map().GetStreamFor(index); - if (!stream_id.has_value()) { - return; - } - webtransport::Stream* raw_stream = - session_->session_->GetStreamById(*stream_id); - if (raw_stream == nullptr) { - return; - } - OutgoingSubgroupStream* stream = - absl::down_cast<OutgoingSubgroupStream*>(raw_stream->visitor()); - stream->Fin(location); -} - -void MoqtSession::PublishedSubscription::OnSubgroupAbandoned( - uint64_t group, uint64_t subgroup, - webtransport::StreamErrorCode error_code) { - if (session_->is_closing_) { - return; - } - if (!InWindow(group)) { - return; - } - DataStreamIndex index(group, subgroup); - if (reset_subgroups_.contains(index)) { - // This subgroup has already been reset, ignore. - return; - } - reset_subgroups_.insert(index); - QUICHE_DCHECK_GE(group, first_active_group_); - std::optional<webtransport::StreamId> stream_id = - stream_map().GetStreamFor(index); - if (!stream_id.has_value()) { - return; - } - webtransport::Stream* raw_stream = - session_->session_->GetStreamById(*stream_id); - if (raw_stream == nullptr) { - return; - } - raw_stream->ResetWithUserCode(error_code); -} - -void MoqtSession::PublishedSubscription::OnGroupAbandoned(uint64_t group_id) { - if (session_->is_closing_) { - return; - } - if (!InWindow(group_id)) { - // The group is not in the window, ignore. - return; - } - std::vector<webtransport::StreamId> streams = - stream_map().GetStreamsForGroup(group_id); - if (delivery_timeout().IsInfinite() && largest_sent_.has_value() && - largest_sent_->group <= group_id) { - session_->PublishIsDone(request_id_, PublishDoneCode::kTooFarBehind, ""); - // No class access below this line! - return; - } - for (webtransport::StreamId stream_id : streams) { - webtransport::Stream* raw_stream = - session_->session_->GetStreamById(stream_id); - if (raw_stream == nullptr) { - continue; - } - raw_stream->ResetWithUserCode(kResetCodeDeliveryTimeout); - // Sending the Reset will call the destructor for OutgoingSubgroupStream, - // which will erase it from the SendStreamMap. - } - first_active_group_ = std::max(first_active_group_, group_id + 1); - absl::erase_if(reset_subgroups_, [&](const DataStreamIndex& index) { - return index.group < first_active_group_; - }); -} - -std::vector<webtransport::StreamId> -MoqtSession::PublishedSubscription::GetAllStreams() const { - if (!lazily_initialized_stream_map_.has_value()) { - return {}; - } - return lazily_initialized_stream_map_->GetAllStreams(); -} - -webtransport::SendOrder MoqtSession::PublishedSubscription::GetSendOrder( - Location sequence, std::optional<uint64_t> subgroup, - MoqtPriority publisher_priority) const { - if (!parameters_.group_order.has_value()) { - QUICHE_BUG(GetSendOrder_no_delivery_order) - << "Can't compute send order without a group order."; - return 0; - } - if (!subgroup.has_value()) { - return SendOrderForDatagram( - parameters_.subscriber_priority.value_or(kDefaultSubscriberPriority), - publisher_priority, sequence.group, sequence.object, - *parameters_.group_order); - } - return SendOrderForStream( - parameters_.subscriber_priority.value_or(kDefaultSubscriberPriority), - publisher_priority, sequence.group, *subgroup, *parameters_.group_order); -} - -// Returns the highest send order in the subscription. -void MoqtSession::PublishedSubscription::AddQueuedOutgoingSubgroupStream( - const NewStreamParameters& parameters) { - std::optional<webtransport::SendOrder> start_send_order = - queued_outgoing_data_streams_.empty() - ? std::optional<webtransport::SendOrder>() - : queued_outgoing_data_streams_.rbegin()->first; - webtransport::SendOrder send_order = GetSendOrder( - Location(parameters.index.group, parameters.first_object), - parameters.index.subgroup, - parameters.publisher_priority.value_or(default_publisher_priority())); - // Zero out the subscriber priority bits, since these will be added when - // updating the session. - queued_outgoing_data_streams_.emplace( - UpdateSendOrderForSubscriberPriority(send_order, 0), parameters); - if (!start_send_order.has_value()) { - session_->UpdateQueuedSendOrder(request_id_, std::nullopt, send_order); - } else if (*start_send_order < send_order) { - session_->UpdateQueuedSendOrder( - request_id_, FinalizeSendOrder(*start_send_order), send_order); - } -} - -MoqtSession::NewStreamParameters -MoqtSession::PublishedSubscription::NextQueuedOutgoingSubgroupStream() { - QUICHE_DCHECK(!queued_outgoing_data_streams_.empty()); - if (queued_outgoing_data_streams_.empty()) { - QUICHE_BUG(NextQueuedOutgoingSubgroupStream_no_stream) - << "NextQueuedOutgoingSubgroupStream called when there are no streams " - "pending."; - return NewStreamParameters(0, 0, 0, 0); - } - auto it = queued_outgoing_data_streams_.rbegin(); - webtransport::SendOrder old_send_order = FinalizeSendOrder(it->first); - NewStreamParameters first_stream = it->second; - // converting a reverse iterator to an iterator involves incrementing it and - // then taking base(). - queued_outgoing_data_streams_.erase((++it).base()); - if (queued_outgoing_data_streams_.empty()) { - session_->UpdateQueuedSendOrder(request_id_, old_send_order, std::nullopt); - } else { - webtransport::SendOrder new_send_order = - FinalizeSendOrder(queued_outgoing_data_streams_.rbegin()->first); - if (old_send_order != new_send_order) { - session_->UpdateQueuedSendOrder(request_id_, old_send_order, - new_send_order); - } - } - return first_stream; -} - -void MoqtSession::PublishedSubscription::OnDataStreamCreated( - webtransport::StreamId id, DataStreamIndex start_sequence) { - ++streams_opened_; - stream_map().AddStream(start_sequence, id); -} -void MoqtSession::PublishedSubscription::OnDataStreamDestroyed( - DataStreamIndex end_sequence) { - stream_map().RemoveStream(end_sequence); -} - -void MoqtSession::PublishedSubscription::OnObjectSent(Location sequence) { - if (largest_sent_.has_value()) { - largest_sent_ = std::max(*largest_sent_, sequence); - } else { - largest_sent_ = sequence; - } - // TODO: send PUBLISH_DONE if the subscription is done. -} void MoqtSession::OnMalformedTrack(RemoteTrack* track) { if (!track->is_fetch()) { @@ -2332,44 +1867,6 @@ // hasn't opened yet. } -void MoqtSession::PublishedSubscription::SendDatagram(Location sequence) { - std::optional<PublishedObject> object = track_publisher_->GetCachedObject( - sequence.group, std::nullopt, sequence.object); - if (!object.has_value()) { - QUICHE_BUG(PublishedSubscription_SendDatagram_object_not_in_cache) - << "Got notification about an object that is not in the cache"; - return; - } - MoqtObject header; - header.track_alias = track_alias_; - header.group_id = object->metadata.location.group; - header.object_id = object->metadata.location.object; - header.publisher_priority = object->metadata.publisher_priority; - header.extension_headers = object->metadata.extensions; - header.object_status = object->metadata.status; - header.subgroup_id = std::nullopt; - header.payload_length = object->metadata.payload_length; - quiche::QuicheBuffer datagram = session_->framer_.SerializeObjectDatagram( - header, object->payload[0].AsStringView(), - default_publisher_priority_.value_or(kDefaultPublisherPriority)); - session_->session_->SendOrQueueDatagram(datagram.AsStringView()); - OnObjectSent(object->metadata.location); -} - -void MoqtSession::PublishedSubscription::ProcessObjectAck( - const MoqtObjectAck& message) { - session_->trace_recorder_.RecordObjectAck( - track_alias_, Location(message.group_id, message.object_id), - message.delta_from_deadline); - - if (monitoring_interface_ == nullptr) { - return; - } - monitoring_interface_->OnObjectAckReceived( - Location(message.group_id, message.object_id), - message.delta_from_deadline); -} - void MoqtSessionParameters::ToSetupParameters(SetupParameters& out) const { if (perspective == quic::Perspective::IS_CLIENT && !using_webtrans) { out.path = path;
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index 5c1998a..664004a 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -5,18 +5,15 @@ #ifndef QUICHE_QUIC_MOQT_MOQT_SESSION_H_ #define QUICHE_QUIC_MOQT_MOQT_SESSION_H_ -#include <algorithm> #include <cstdint> #include <memory> #include <optional> #include <string> #include <utility> -#include <vector> #include "absl/base/nullability.h" #include "absl/cleanup/cleanup.h" #include "absl/container/btree_map.h" -#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" @@ -32,13 +29,12 @@ #include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" -#include "quiche/quic/moqt/moqt_object.h" #include "quiche/quic/moqt/moqt_parser.h" #include "quiche/quic/moqt/moqt_priority.h" #include "quiche/quic/moqt/moqt_publisher.h" #include "quiche/quic/moqt/moqt_session_callbacks.h" #include "quiche/quic/moqt/moqt_session_interface.h" -#include "quiche/quic/moqt/moqt_stream_map.h" +#include "quiche/quic/moqt/moqt_subscription.h" #include "quiche/quic/moqt/moqt_trace_recorder.h" #include "quiche/quic/moqt/moqt_track.h" #include "quiche/quic/moqt/moqt_types.h" @@ -49,7 +45,6 @@ #include "quiche/common/quiche_buffer_allocator.h" #include "quiche/common/quiche_callbacks.h" #include "quiche/common/quiche_circular_deque.h" -#include "quiche/common/quiche_mem_slice.h" #include "quiche/common/quiche_weak_ptr.h" #include "quiche/web_transport/web_transport.h" @@ -62,27 +57,8 @@ inline constexpr quic::QuicTimeDelta kDefaultGoAwayTimeout = quic::QuicTime::Delta::FromSeconds(10); -struct SubscriptionWithQueuedStream { - webtransport::SendOrder send_order; - uint64_t subscription_id; - - auto operator<=>(const SubscriptionWithQueuedStream& other) const = default; -}; - -// MoqtPublishingMonitorInterface allows a publisher monitor the delivery -// progress for a single individual subscriber. -class MoqtPublishingMonitorInterface { - public: - virtual ~MoqtPublishingMonitorInterface() = default; - - virtual void OnObjectAckSupportKnown( - std::optional<quic::QuicTimeDelta> time_window) = 0; - virtual void OnNewObjectEnqueued(Location location) = 0; - virtual void OnObjectAckReceived(Location location, - quic::QuicTimeDelta delta_from_deadline) = 0; -}; - class QUICHE_EXPORT MoqtSession : public MoqtSessionInterface, + public SessionToPublisherInterface, public webtransport::SessionVisitor { public: MoqtSession(webtransport::Session* session, MoqtSessionParameters parameters, @@ -153,12 +129,28 @@ return weak_ptr_factory_.Create(); } + // SessionToPublisherInterface implementation. + bool alternate_delivery_timeout() const override { + return alternate_delivery_timeout_; + } + // If |old_priority| is nullopt, the subscription does not have any pending + // streams. If it has a value, |old_priority| is the old value to be replaced + // by |new_priority|. + void UpdateTrackPriority(uint64_t request_id, + std::optional<MoqtTrackPriority> old_priority, + MoqtTrackPriority new_priority) override; + quic::QuicAlarmFactory* alarm_factory() override { + return alarm_factory_.get(); + } + void PublishIsDone(uint64_t request_id) override; + webtransport::Session* session() override { + return is_closing_ ? nullptr : session_; + } + // Send a GOAWAY message to the peer. |new_session_uri| must be empty if // called by the client. void GoAway(absl::string_view new_session_uri); - webtransport::Session* session() { return session_; } - MoqtPublisher* publisher() { return publisher_; } void set_publisher(MoqtPublisher* publisher) { publisher_ = publisher; } bool support_object_acks() const { return parameters_.support_object_acks; } @@ -182,15 +174,6 @@ CleanUpState(); } - // Tells the session that the highest send order for pending streams in a - // subscription has changed. If |old_send_order| is nullopt, this is the - // first pending stream. If |new_send_order| is nullopt, the subscription - // has no pending streams anymore. - void UpdateQueuedSendOrder( - uint64_t request_id, - std::optional<webtransport::SendOrder> old_send_order, - std::optional<webtransport::SendOrder> new_send_order); - void GrantMoreRequests(uint64_t num_requests); void UseAlternateDeliveryTimeout() { alternate_delivery_timeout_ = true; } @@ -202,20 +185,6 @@ struct Empty {}; - struct NewStreamParameters { - DataStreamIndex index; - uint64_t first_object; - // nullopt if the default priority is used. - std::optional<MoqtPriority> publisher_priority; - - NewStreamParameters(uint64_t group, uint64_t subgroup, - uint64_t first_object, - std::optional<MoqtPriority> publisher_priority) - : index(group, subgroup), - first_object(first_object), - publisher_priority(publisher_priority) {} - }; - // A stream is open, but we don't know the type until we receive a message. class QUICHE_EXPORT UnknownBidiStream : public webtransport::StreamVisitor { public: @@ -367,184 +336,6 @@ std::string partial_object_; uint64_t bytes_received_this_object_ = 0; }; - // Represents a record for a single subscription to a local track that is - // being sent to the peer. - class PublishedSubscription : public MoqtObjectListener, - public SubscriptionPublisherInterface { - public: - PublishedSubscription(MoqtSession* session, - std::shared_ptr<MoqtTrackPublisher> track_publisher, - const MoqtSubscribe& subscribe, uint64_t track_alias, - MoqtPublishingMonitorInterface* monitoring_interface); - // TODO(martinduke): Immediately reset all the streams. - ~PublishedSubscription(); - - PublishedSubscription(const PublishedSubscription&) = delete; - PublishedSubscription(PublishedSubscription&&) = delete; - PublishedSubscription& operator=(const PublishedSubscription&) = delete; - PublishedSubscription& operator=(PublishedSubscription&&) = delete; - - uint64_t request_id() const { return request_id_; } - MoqtTrackPublisher& publisher() { return *track_publisher_; } - std::shared_ptr<MoqtTrackPublisher> publisher_shared_ptr() { - return track_publisher_; - } - uint64_t track_alias() const { return track_alias_; } - MessageParameters& parameters() { return parameters_; } - std::optional<Location> largest_sent() const { return largest_sent_; } - - // MoqtObjectListener implementation. - void OnSubscribeAccepted() override; - void OnSubscribeRejected(MoqtRequestErrorInfo info) override; - // This is only called for objects that have just arrived. - void OnNewObjectAvailable(Location location, - std::optional<uint64_t> subgroup, - MoqtPriority publisher_priority) override; - void OnTrackPublisherGone() override; - void OnNewFinAvailable(Location location, uint64_t subgroup) override; - // also a part of SubscriptionPublisherInterface. - void OnSubgroupAbandoned(uint64_t group, uint64_t subgroup, - webtransport::StreamErrorCode error_code) override; - void OnGroupAbandoned(uint64_t group_id) override; - void ProcessObjectAck(const MoqtObjectAck& message); - - // SubscriptionPublisherInterface implementation. - bool InWindow(Location location) override { - return parameters_.forward() && - (!parameters_.subscription_filter.has_value() || - (parameters_.subscription_filter->WindowKnown() && - parameters_.subscription_filter->InWindow(location))); - }; - bool alternate_delivery_timeout() override { - return session_->alternate_delivery_timeout_; - } - const quic::QuicClock* clock() override { - return session_->callbacks_.clock; - } - quic::QuicTimeDelta delivery_timeout() override { - return std::min( - parameters_.delivery_timeout.value_or(kDefaultDeliveryTimeout), - publisher_delivery_timeout_.value_or(kDefaultDeliveryTimeout)); - } - quic::QuicAlarmFactory* alarm_factory() override { - return session_->alarm_factory_.get(); - } - void OnObjectSent(Location sequence) override; - void OnStreamTimeout(DataStreamIndex index) override { - reset_subgroups_.insert(index); - if (session_->alternate_delivery_timeout_) { - first_active_group_ = std::max(first_active_group_, index.group + 1); - } - } - // OnSubgroupAbandoned() is declared above with MoqtObjectListener. - void OnDataStreamDestroyed(DataStreamIndex) override; - - // Updates the window and other properties of the subscription in question. - void Update(const MessageParameters& parameters); - // Checks if a given Location or Group should be forwarded to the - // subscriber. - bool InWindow(uint64_t group) { - return parameters_.forward() && - (!parameters_.subscription_filter.has_value() || - (parameters_.subscription_filter->WindowKnown() && - parameters_.subscription_filter->InWindow(group))); - } - - void OnDataStreamCreated(webtransport::StreamId id, - DataStreamIndex start_sequence); - - std::vector<webtransport::StreamId> GetAllStreams() const; - - // If subgroup is nullopt, returns the send order for a datagram. - webtransport::SendOrder GetSendOrder(Location sequence, - std::optional<uint64_t> subgroup, - MoqtPriority publisher_priority) const; - - void AddQueuedOutgoingSubgroupStream(const NewStreamParameters& parameters); - // Pops the pending outgoing data stream, with the highest send order. - // The session keeps track of which subscribes have pending streams. This - // function will trigger a QUICHE_DCHECK if called when there are no pending - // streams. - NewStreamParameters NextQueuedOutgoingSubgroupStream(); - - void set_subscriber_delivery_timeout(quic::QuicTimeDelta timeout) { - parameters_.delivery_timeout = timeout; - } - void set_publisher_delivery_timeout(quic::QuicTimeDelta timeout) { - publisher_delivery_timeout_ = timeout; - } - - uint64_t first_active_group() const { return first_active_group_; } - - absl::flat_hash_set<DataStreamIndex>& reset_subgroups() { - return reset_subgroups_; - } - - uint64_t streams_opened() const { return streams_opened_; } - - bool can_have_joining_fetch() const { return can_have_joining_fetch_; } - - MoqtPriority default_publisher_priority() const { - return default_publisher_priority_.value_or(kDefaultPublisherPriority); - } - - bool established() const { return established_; } - - quiche::QuicheWeakPtr<SubscriptionPublisherInterface> GetWeakPtr() { - return weak_ptr_factory_.Create(); - } - - private: - friend class test::MoqtSessionPeer; - SendStreamMap& stream_map(); - quic::Perspective perspective() const { - return session_->parameters_.perspective; - } - - void SendDatagram(Location sequence); - webtransport::SendOrder FinalizeSendOrder( - webtransport::SendOrder send_order) { - return UpdateSendOrderForSubscriberPriority( - send_order, - parameters_.subscriber_priority.value_or(kDefaultSubscriberPriority)); - } - - MoqtSession* session_; - std::shared_ptr<MoqtTrackPublisher> track_publisher_; - uint64_t request_id_; - // Subscription is in the ESTABLISHED state. - bool established_ = false; - bool can_have_joining_fetch_ = false; - const uint64_t track_alias_; - // These are (mostly) the parameters from the SUBSCRIBE message. However, - // group_order and largest_object may be updated by SUBSCRIBE_OK because - // have no effect in a future REQUEST_UPDATE message. - MessageParameters parameters_; - std::optional<quic::QuicTimeDelta> publisher_delivery_timeout_; - std::optional<MoqtPriority> default_publisher_priority_; - uint64_t streams_opened_ = 0; - - // The subscription will ignore any groups with a lower ID, so it doesn't - // need to track reset subgroups. - uint64_t first_active_group_ = 0; - // If a stream has been reset due to delivery timeout, do not open a new - // stream if more object arrive for it. - absl::flat_hash_set<DataStreamIndex> reset_subgroups_; - - MoqtPublishingMonitorInterface* monitoring_interface_; - // Largest sequence number ever sent via this subscription. - std::optional<Location> largest_sent_; - // Should be almost always accessed via `stream_map()`. - std::optional<SendStreamMap> lazily_initialized_stream_map_; - // Store the send order of queued outgoing data streams. Use a - // subscriber_priority_ of zero to avoid having to update it, and call - // FinalizeSendOrder() whenever delivering it to the MoqtSession. - absl::btree_multimap<webtransport::SendOrder, NewStreamParameters> - queued_outgoing_data_streams_; - // Must be last. - quiche::QuicheWeakPtrFactory<SubscriptionPublisherInterface> - weak_ptr_factory_; - }; class QUICHE_EXPORT PublishedFetch { public: @@ -652,10 +443,6 @@ SubscribeRemoteTrack* subscribe_; }; - // Private members of MoqtSession. - // Returns true if PUBLISH_DONE was sent. - bool PublishIsDone(uint64_t request_id, PublishDoneCode code, - absl::string_view error_reason); void MaybeDestroySubscription(SubscribeRemoteTrack* subscribe); void DestroySubscription(SubscribeRemoteTrack* subscribe); @@ -665,14 +452,6 @@ // is present. void SendControlMessage(quiche::QuicheBuffer message); - // Opens a new data stream, or queues it if the session is flow control - // blocked. - webtransport::Stream* OpenOrQueueDataStream( - uint64_t subscription_id, const NewStreamParameters& parameters); - // Same as above, except the session is required to be not flow control - // blocked. - webtransport::Stream* OpenDataStream(PublishedSubscription& subscription, - const NewStreamParameters& parameters); // Returns false if creation failed. [[nodiscard]] bool OpenDataStream(PublishedFetch* fetch, webtransport::SendOrder send_order); @@ -769,11 +548,12 @@ // can be subscribed to via this connection. Must outlive this object. MoqtPublisher* publisher_; // Subscriptions for local tracks by the remote peer, indexed by subscribe ID. - absl::flat_hash_map<uint64_t, std::unique_ptr<PublishedSubscription>> + absl::flat_hash_map<uint64_t, std::unique_ptr<SubscriptionPublisher>> published_subscriptions_; - // Keeps track of all subscribe IDs that have queued outgoing data streams. - absl::btree_set<SubscriptionWithQueuedStream> - subscribes_with_queued_outgoing_data_streams_; + // Keeps track of all request IDs that have queued outgoing data streams. The + // first element is the highest priority (lowest integer). + absl::btree_multimap<MoqtTrackPriority, uint64_t> + subscriptions_with_queued_streams_; // This is only used to check for track_alias collisions. absl::flat_hash_set<uint64_t> used_track_aliases_; uint64_t next_local_track_alias_ = 0;
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index ca4d6bd..8f1f0f4 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -1302,6 +1302,10 @@ control_stream->ReceiveMessage(subscribe_ok); } +// TODO(martinduke): Most of these test cases no longer need to be in +// MoqtSessionTest. Find any useful functionality and put it in +// SubscriptionPublisherTest or OutgoingSubgroupStreamTest. + TEST_F(MoqtSessionTest, CreateOutgoingSubgroupStreamAndSend) { FullTrackName ftn("foo", "bar"); auto track = @@ -1353,8 +1357,9 @@ subscription->OnNewObjectAvailable(Location(5, 0), 0, 127); EXPECT_TRUE(correct_message); EXPECT_FALSE(fin); - EXPECT_EQ(MoqtSessionPeer::LargestSentForSubscription(&session_, 0), - Location(5, 0)); + std::optional<Location> largest_sent = + MoqtSessionPeer::LargestSentForSubscription(&session_, 0); + EXPECT_TRUE(largest_sent.has_value() && *largest_sent == Location(5, 0)); } TEST_F(MoqtSessionTest, FinDataStreamFromCache) { @@ -1486,6 +1491,9 @@ TEST_F(MoqtSessionTest, GroupAbandonedNoDeliveryTimeout) { FullTrackName ftn("foo", "bar"); + webtransport::test::MockStream control_stream; + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &control_stream); auto track = SetupPublisher(ftn, MoqtForwardingPreference::kSubgroup, Location(4, 2)); MoqtObjectListener* subscription = @@ -1544,9 +1552,6 @@ /*error_reason=*/"", }; EXPECT_CALL(mock_stream_, ResetWithUserCode(kResetCodeCancelled)); - webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = - MoqtSessionPeer::CreateControlStream(&session_, &control_stream); EXPECT_CALL(control_stream, Writev(SerializedControlMessage(expected_publish_done), _)); subscription->OnGroupAbandoned(5); @@ -1554,6 +1559,10 @@ TEST_F(MoqtSessionTest, GroupAbandonedDeliveryTimeout) { FullTrackName ftn("foo", "bar"); + webtransport::test::MockStream control_stream; + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &control_stream); + ; auto track = SetupPublisher(ftn, MoqtForwardingPreference::kSubgroup, Location(4, 2)); MoqtObjectListener* subscription = @@ -1612,9 +1621,6 @@ /*error_reason=*/"", }; EXPECT_CALL(mock_stream_, ResetWithUserCode(kResetCodeCancelled)); - webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = - MoqtSessionPeer::CreateControlStream(&session_, &control_stream); EXPECT_CALL(control_stream, Writev(SerializedControlMessage(expected_publish_done), _)); subscription->OnGroupAbandoned(5); @@ -1904,7 +1910,7 @@ // Unblock the session, and cause the queued stream to be sent. EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream()) - .WillOnce(Return(true)); + .WillRepeatedly(Return(true)); bool fin = false; EXPECT_CALL(mock_stream_, CanWrite()).WillRepeatedly([&] { return !fin; }); EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream()) @@ -2436,6 +2442,7 @@ // Allow one stream to be opened. It will be group 0, subscription 0. EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream()) .WillOnce(Return(true)) + .WillOnce(Return(true)) .WillOnce(Return(false)); webtransport::test::MockStream mock_stream0; EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream()) @@ -2473,6 +2480,7 @@ MoqtSessionPeer::UpdateSubscriberPriority(&session_, 1, 0); EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream()) .WillOnce(Return(true)) + .WillOnce(Return(true)) .WillRepeatedly(Return(false)); webtransport::test::MockStream mock_stream1; EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream()) @@ -2515,7 +2523,8 @@ webtransport::test::MockStream& data_stream, std::unique_ptr<webtransport::StreamVisitor>& stream_visitor) { EXPECT_CALL(session, CanOpenNextOutgoingUnidirectionalStream) - .WillOnce(Return(true)); + .WillOnce(Return(true)) + .WillRepeatedly(Return(false)); EXPECT_CALL(session, OpenOutgoingUnidirectionalStream()) .WillOnce(Return(&data_stream)); EXPECT_CALL(data_stream, SetVisitor)
diff --git a/quiche/quic/moqt/moqt_subscription.cc b/quiche/quic/moqt/moqt_subscription.cc new file mode 100644 index 0000000..29b7e7f --- /dev/null +++ b/quiche/quic/moqt/moqt_subscription.cc
@@ -0,0 +1,438 @@ +// Copyright (c) 2026 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/moqt/moqt_subscription.h" + +#include <algorithm> +#include <cstdint> +#include <memory> +#include <optional> +#include <utility> +#include <vector> + +#include "absl/base/casts.h" +#include "absl/base/nullability.h" +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/moqt/moqt_bidi_stream.h" +#include "quiche/quic/moqt/moqt_error.h" +#include "quiche/quic/moqt/moqt_framer.h" +#include "quiche/quic/moqt/moqt_key_value_pair.h" +#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_object.h" +#include "quiche/quic/moqt/moqt_priority.h" +#include "quiche/quic/moqt/moqt_publisher.h" +#include "quiche/quic/moqt/moqt_stream_map.h" +#include "quiche/quic/moqt/moqt_trace_recorder.h" +#include "quiche/quic/moqt/moqt_types.h" +#include "quiche/quic/moqt/moqt_uni_stream.h" +#include "quiche/common/quiche_buffer_allocator.h" +#include "quiche/web_transport/web_transport.h" + +namespace moqt { + +SubscriptionPublisher::SubscriptionPublisher( + MoqtFramer framer, std::shared_ptr<MoqtTrackPublisher> track_publisher, + MoqtBidiStreamBase* absl_nonnull bidi_stream, uint64_t request_id, + uint64_t track_alias, const MessageParameters& parameters, + SessionToPublisherInterface* absl_nonnull visitor, + MoqtPublishingMonitorInterface* monitoring_interface, + const quic::QuicClock* absl_nonnull clock, + MoqtTraceRecorder& trace_recorder) + : track_publisher_(track_publisher), + bidi_stream_(bidi_stream), + visitor_(visitor), + request_id_(request_id), + track_alias_(track_alias), + framer_(framer), + trace_recorder_(trace_recorder), + parameters_(parameters), + monitoring_interface_(monitoring_interface), + clock_(clock), + weak_ptr_factory_(this) { + if (monitoring_interface_ != nullptr) { + monitoring_interface_->OnObjectAckSupportKnown(parameters.oack_window_size); + } + QUIC_DLOG(INFO) << "Created subscription for " + << track_publisher_->GetTrackName(); + // TODO(martinduke): Handle NEW_GROUP_REQUEST +} + +SubscriptionPublisher::~SubscriptionPublisher() { + track_publisher_->RemoveObjectListener(this); + // Reset all streams. + for (const webtransport::StreamId stream_id : stream_map_.GetAllStreams()) { + webtransport::Stream* stream = GetStreamById(stream_id); + if (stream != nullptr) { + stream->ResetWithUserCode(kResetCodeCancelled); + } + } +} + +void SubscriptionPublisher::Update(const MessageParameters& parameters) { + // TODO(martinduke): If there are auth tokens, this probably has to go to the + // application. + MoqtPriority old_priority = + parameters_.subscriber_priority.value_or(kDefaultSubscriberPriority); + parameters_.Update(parameters); + if (parameters.subscriber_priority.has_value()) { // priority changed. + MoqtPriority new_priority = *parameters.subscriber_priority; + // Reprioritize all active streams. + for (const webtransport::StreamId stream_id : stream_map_.GetAllStreams()) { + webtransport::Stream* stream = GetStreamById(stream_id); + if (stream == nullptr) { + continue; + } + OutgoingSubgroupStream* outgoing_stream = + absl::down_cast<OutgoingSubgroupStream*>(stream->visitor()); + outgoing_stream->UpdatePriority(new_priority); + } + if (pending_streams_.empty()) { + return; + } + // Tell the session that pending stream priority has changed. + MoqtPriority publisher_priority = + pending_streams_.rbegin()->second.publisher_priority.value_or( + track_publisher_->extensions().default_publisher_priority()); + MoqtTrackPriority old_track_priority = {old_priority, publisher_priority}; + visitor_->UpdateTrackPriority( + request_id_, old_track_priority, + MoqtTrackPriority{new_priority, publisher_priority}); + // Don't bother to update all the pending stream send orders. + } +} + +void SubscriptionPublisher::OnSubscribeAccepted() { + QUICHE_DCHECK(!established_); + established_ = true; + parameters_.largest_object = track_publisher_->largest_location(); + if (parameters_.subscription_filter.has_value()) { + parameters_.subscription_filter->OnLargestObject( + parameters_.largest_object); + } + MoqtSubscribeOk subscribe_ok; + subscribe_ok.request_id = request_id_; + subscribe_ok.track_alias = track_alias_; + subscribe_ok.parameters.expires = track_publisher_->expiration(); + subscribe_ok.parameters.largest_object = parameters_.largest_object; + subscribe_ok.extensions = track_publisher_->extensions(); + if (!parameters_.group_order.has_value()) { + parameters_.group_order = + subscribe_ok.extensions.default_publisher_group_order(); + } + // TODO(martinduke): Support sending DELIVERY_TIMEOUT parameter as the + // publisher. + default_publisher_priority_ = + subscribe_ok.extensions.default_publisher_priority(); + bidi_stream_->SendOrBufferMessageOrFatal( + framer_.SerializeSubscribeOk(subscribe_ok)); + // TODO(martinduke): If we buffer objects that arrived previously, the arrival + // of the track alias disambiguates what subscription they belong to. Send + // them. +} + +void SubscriptionPublisher::OnSubscribeRejected(MoqtRequestErrorInfo info) { + bidi_stream_->CheckStatus(bidi_stream_->SendRequestError(request_id_, info)); + visitor_->PublishIsDone(request_id_); + // No class access below this line! +} + +void SubscriptionPublisher::OnNewObjectAvailable( + Location location, std::optional<uint64_t> subgroup, + MoqtPriority publisher_priority) { + if (!InWindow(location)) { + return; + } + + if (monitoring_interface_ != nullptr) { + // Notify the monitoring interface about all newly published normal objects. + // Objects with other statuses are not guaranteed to be acknowledged, thus + // passing them into the monitoring interface can lead to confusion. + std::optional<PublishedObject> object = track_publisher_->GetCachedObject( + location.group, subgroup, location.object); + QUICHE_DCHECK(object.has_value()) + << "Object " << absl::StrCat(location) << " on track " + << track_publisher_->GetTrackName().ToString() + << " does not exist, despite OnNewObjectAvailable being called"; + if (object.has_value() && object->metadata.location == location && + object->metadata.status == MoqtObjectStatus::kNormal) { + monitoring_interface_->OnNewObjectEnqueued(location); + } + } + + // TODO(vasilvv): This currently sends UINT64_MAX for datagram subgroups. + // Maybe do something more satisfactory? + trace_recorder_.RecordNewObjectAvaliable( + track_alias_, *track_publisher_, location, subgroup.value_or(UINT64_MAX), + publisher_priority); + + std::optional<webtransport::StreamId> stream_id; + if (subgroup.has_value()) { + DataStreamIndex index(location.group, *subgroup); + if (reset_subgroups_.contains(index)) { + // This subgroup has already been reset, ignore. + return; + } + stream_id = stream_map_.GetStreamFor(index); + } + if (visitor_->alternate_delivery_timeout() && + !delivery_timeout().IsInfinite() && largest_sent_.has_value() && + location.group >= largest_sent_->group) { + // Start the delivery timeout timer on all previous groups. + for (uint64_t group = first_active_group_; group < location.group; + ++group) { + for (webtransport::StreamId stream_to_update : + stream_map_.GetStreamsForGroup(group)) { + webtransport::Stream* raw_stream = GetStreamById(stream_to_update); + if (raw_stream == nullptr) { + continue; + } + OutgoingSubgroupStream* stream = + absl::down_cast<OutgoingSubgroupStream*>(raw_stream->visitor()); + stream->CreateAndSetAlarm(clock_->ApproximateNow() + + delivery_timeout()); + } + } + } + QUICHE_DCHECK_GE(location.group, first_active_group_); + if (!subgroup.has_value()) { + SendDatagram(location); + return; + } + + webtransport::Stream* raw_stream = nullptr; + if (stream_id.has_value()) { + raw_stream = GetStreamById(*stream_id); + if (raw_stream != nullptr) { + raw_stream->visitor()->OnCanWrite(); + } + return; + } + NewDataStreamParameters parameters( + location.group, *subgroup, location.object, + publisher_priority == default_publisher_priority_ + ? std::nullopt + : std::make_optional(publisher_priority)); + raw_stream = OpenDataStream(parameters); + if (raw_stream == nullptr) { + StreamRank rank = StreamRankFor(parameters); + if (pending_streams_.empty() || rank > pending_streams_.rbegin()->first) { + visitor_->UpdateTrackPriority( + request_id_, + /*old_priority=*/pending_streams_.empty() + ? std::optional<MoqtTrackPriority>() + : std::make_optional( + MoqtTrackPriority{subscriber_priority(), + pending_streams_.rbegin() + ->second.publisher_priority.value_or( + default_publisher_priority())}), + MoqtTrackPriority{subscriber_priority(), publisher_priority}); + } + pending_streams_.emplace(rank, parameters); + } +} + +void SubscriptionPublisher::OnTrackPublisherGone() { + PublishIsDone(request_id_, PublishDoneCode::kGoingAway, "Publisher is gone"); +} + +// TODO(martinduke): Revise to check if the last object has been delivered. +void SubscriptionPublisher::OnNewFinAvailable(Location location, + uint64_t subgroup) { + if (!InWindow(location.group)) { + return; + } + DataStreamIndex index(location.group, subgroup); + std::optional<webtransport::StreamId> stream_id = + stream_map_.GetStreamFor(index); + if (!stream_id.has_value()) { + return; + } + webtransport::Stream* raw_stream = GetStreamById(*stream_id); + if (raw_stream == nullptr) { + return; + } + OutgoingSubgroupStream* stream = + absl::down_cast<OutgoingSubgroupStream*>(raw_stream->visitor()); + stream->Fin(location); +} + +void SubscriptionPublisher::OnSubgroupAbandoned( + uint64_t group, uint64_t subgroup, + webtransport::StreamErrorCode error_code) { + if (!InWindow(group)) { + return; + } + DataStreamIndex index(group, subgroup); + if (reset_subgroups_.contains(index)) { + // This subgroup has already been reset, ignore. + return; + } + reset_subgroups_.insert(index); + QUICHE_DCHECK_GE(group, first_active_group_); + std::optional<webtransport::StreamId> stream_id = + stream_map_.GetStreamFor(index); + if (!stream_id.has_value()) { + return; + } + webtransport::Stream* raw_stream = GetStreamById(*stream_id); + if (raw_stream == nullptr) { + return; + } + raw_stream->ResetWithUserCode(error_code); +} + +void SubscriptionPublisher::OnGroupAbandoned(uint64_t group_id) { + if (!InWindow(group_id)) { + // The group is not in the window, ignore. + return; + } + std::vector<webtransport::StreamId> streams = + stream_map_.GetStreamsForGroup(group_id); + if (delivery_timeout().IsInfinite() && largest_sent_.has_value() && + largest_sent_->group <= group_id) { + PublishIsDone(request_id_, PublishDoneCode::kTooFarBehind, ""); + // No class access below this line! + return; + } + for (webtransport::StreamId stream_id : streams) { + webtransport::Stream* raw_stream = GetStreamById(stream_id); + if (raw_stream == nullptr) { + continue; + } + raw_stream->ResetWithUserCode(kResetCodeDeliveryTimeout); + // Sending the Reset will call the destructor for OutgoingSubgroupStream, + // which will erase it from the SendStreamMap. + } + first_active_group_ = std::max(first_active_group_, group_id + 1); + absl::erase_if(reset_subgroups_, [&](const DataStreamIndex& index) { + return index.group < first_active_group_; + }); +} + +void SubscriptionPublisher::SendDatagram(Location sequence) { + std::optional<PublishedObject> object = track_publisher_->GetCachedObject( + sequence.group, std::nullopt, sequence.object); + if (!object.has_value()) { + QUICHE_BUG(PublishedSubscription_SendDatagram_object_not_in_cache) + << "Got notification about an object that is not in the cache"; + return; + } + MoqtObject header; + header.track_alias = track_alias_; + header.group_id = object->metadata.location.group; + header.object_id = object->metadata.location.object; + header.publisher_priority = object->metadata.publisher_priority; + header.extension_headers = object->metadata.extensions; + header.object_status = object->metadata.status; + header.subgroup_id = std::nullopt; + header.payload_length = object->metadata.payload_length; + QUICHE_BUG_IF(SubscriptionPublisher_SendDatagram_partial_payload, + object->payload.size() > 1) + << "Datagram is split into multiple slices"; + quiche::QuicheBuffer datagram = framer_.SerializeObjectDatagram( + header, object->payload[0].AsStringView(), + default_publisher_priority_.value_or(kDefaultPublisherPriority)); + if (visitor_->session() == nullptr) { + return; + } + visitor_->session()->SendOrQueueDatagram(datagram.AsStringView()); + OnObjectSent(object->metadata.location); +} + +void SubscriptionPublisher::ProcessObjectAck(const MoqtObjectAck& message) { + trace_recorder_.RecordObjectAck(track_alias_, + Location(message.group_id, message.object_id), + message.delta_from_deadline); + + if (monitoring_interface_ == nullptr) { + return; + } + monitoring_interface_->OnObjectAckReceived( + Location(message.group_id, message.object_id), + message.delta_from_deadline); +} + +webtransport::Stream* absl_nullable SubscriptionPublisher::OpenDataStream( + const NewDataStreamParameters& parameters) { + if (visitor_->session() == nullptr || + !visitor_->session()->CanOpenNextOutgoingUnidirectionalStream()) { + return nullptr; + } + webtransport::Stream* new_stream = + visitor_->session()->OpenOutgoingUnidirectionalStream(); + if (new_stream == nullptr) { + return nullptr; + } + stream_map_.AddStream(parameters.index, new_stream->GetStreamId()); + new_stream->SetVisitor(std::make_unique<OutgoingSubgroupStream>( + framer_, new_stream, parameters.index, parameters.first_object, + weak_ptr_factory_.Create(), track_publisher_, + StreamPriorityFor(parameters), track_alias_, &trace_recorder_)); + ++streams_opened_; + new_stream->visitor()->OnCanWrite(); + return new_stream; +} + +void SubscriptionPublisher::PublishIsDone(uint64_t request_id, + PublishDoneCode code, + absl::string_view error_reason) { + MoqtPublishDone publish_done; + publish_done.request_id = request_id; + publish_done.status_code = code; + publish_done.stream_count = streams_opened_; + publish_done.error_reason = error_reason; + // TODO(martinduke): It is technically correct, but not good, to simply + // reset all the streams in order to send PUBLISH_DONE. It's better to wait + // until streams FIN naturally, where possible. + QUICHE_DLOG(INFO) << "Sending PUBLISH_DONE message for " + << track_publisher_->GetTrackName(); + bidi_stream_->SendOrBufferMessageOrFatal( + framer_.SerializePublishDone(publish_done)); + visitor_->PublishIsDone(request_id_); + // No class access below this line! +} + +void SubscriptionPublisher::OnDataStreamDestroyed( + DataStreamIndex end_sequence) { + stream_map_.RemoveStream(end_sequence); +} + +void SubscriptionPublisher::OnCanCreateNewUniStream() { + auto it = pending_streams_.rbegin(); + while (it != pending_streams_.rend() && + (it->second.index.group < first_active_group_ || + reset_subgroups_.contains(it->second.index))) { + pending_streams_.erase(--(it.base())); + it = pending_streams_.rbegin(); + } + if (it == pending_streams_.rend()) { + return; + } + if (OpenDataStream(it->second) == nullptr) { + return; + } + pending_streams_.erase(--(it.base())); + if (!pending_streams_.empty()) { + visitor_->UpdateTrackPriority( + request_id_, std::nullopt, + MoqtTrackPriority{ + subscriber_priority(), + pending_streams_.rbegin()->second.publisher_priority.value_or( + default_publisher_priority())}); + } +} + +void SubscriptionPublisher::OnObjectSent(Location sequence) { + if (largest_sent_.has_value()) { + largest_sent_ = std::max(*largest_sent_, sequence); + } else { + largest_sent_ = sequence; + } + // TODO: send PUBLISH_DONE if the subscription is done. +} + +} // namespace moqt
diff --git a/quiche/quic/moqt/moqt_subscription.h b/quiche/quic/moqt/moqt_subscription.h new file mode 100644 index 0000000..37dee22 --- /dev/null +++ b/quiche/quic/moqt/moqt_subscription.h
@@ -0,0 +1,266 @@ +// Copyright (c) 2026 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_MOQT_MOQT_SUBSCRIPTION_H_ +#define QUICHE_QUIC_MOQT_MOQT_SUBSCRIPTION_H_ + +#include <algorithm> +#include <cstdint> +#include <memory> +#include <optional> + +#include "absl/base/nullability.h" +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/moqt/moqt_bidi_stream.h" +#include "quiche/quic/moqt/moqt_error.h" +#include "quiche/quic/moqt/moqt_framer.h" +#include "quiche/quic/moqt/moqt_key_value_pair.h" +#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_priority.h" +#include "quiche/quic/moqt/moqt_publisher.h" +#include "quiche/quic/moqt/moqt_stream_map.h" +#include "quiche/quic/moqt/moqt_trace_recorder.h" +#include "quiche/quic/moqt/moqt_types.h" +#include "quiche/quic/moqt/moqt_uni_stream.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_weak_ptr.h" +#include "quiche/web_transport/web_transport.h" + +namespace moqt { + +namespace test { +class SubscriptionPublisherPeer; +} + +// This is the part of the send order useful for ranking streams within the +// subscription. It sets the subscriber_priority to kDefaultSubscriberPriority +// to avoid constantly updating all pending streams. +using StreamRank = webtransport::SendOrder; + +struct NewDataStreamParameters { + DataStreamIndex index; + uint64_t first_object; + // nullopt if the default priority is used. + std::optional<MoqtPriority> publisher_priority; + + NewDataStreamParameters(uint64_t group, uint64_t subgroup, + uint64_t first_object, + std::optional<MoqtPriority> publisher_priority) + : index(group, subgroup), + first_object(first_object), + publisher_priority(publisher_priority) {} +}; + +// MoqtPublishingMonitorInterface allows a publisher monitor the delivery +// progress for a single individual subscriber. +class MoqtPublishingMonitorInterface { + public: + virtual ~MoqtPublishingMonitorInterface() = default; + + virtual void OnObjectAckSupportKnown( + std::optional<quic::QuicTimeDelta> time_window) = 0; + virtual void OnNewObjectEnqueued(Location location) = 0; + virtual void OnObjectAckReceived(Location location, + quic::QuicTimeDelta delta_from_deadline) = 0; +}; + +// Allows SubscriptionPublisher to get data from the session. +class QUICHE_EXPORT SessionToPublisherInterface { + public: + virtual ~SessionToPublisherInterface() = default; + virtual bool alternate_delivery_timeout() const = 0; + // If |old_priority| is nullopt, the subscription does not have any pending + // streams. If it has a value, |old_priority| is the old value to be replaced + // by |new_priority|. + virtual void UpdateTrackPriority( + uint64_t request_id, std::optional<MoqtTrackPriority> old_priority, + MoqtTrackPriority new_priority) = 0; + virtual quic::QuicAlarmFactory* alarm_factory() = 0; + // Destroy any state associated with the subscription. It is OK destroy + // SubscriptionPublisher in this method. + virtual void PublishIsDone(uint64_t request_id) = 0; + // Returns nullptr if MoqtSession is closing. + virtual webtransport::Session* session() = 0; +}; + +// State for delivery of objects via a subscription, whether initiated by a +// SUBSCRIBE or PUBLISH. +class SubscriptionPublisher : public MoqtObjectListener, + public SubscriptionPublisherInterface { + public: + SubscriptionPublisher(MoqtFramer framer, + std::shared_ptr<MoqtTrackPublisher> track_publisher, + MoqtBidiStreamBase* absl_nonnull bidi_stream, + uint64_t request_id, uint64_t track_alias, + const MessageParameters& parameters, + SessionToPublisherInterface* absl_nonnull visitor, + MoqtPublishingMonitorInterface* monitoring_interface, + const quic::QuicClock* absl_nonnull clock, + MoqtTraceRecorder& trace_recorder); + ~SubscriptionPublisher(); + + SubscriptionPublisher(const SubscriptionPublisher&) = delete; + SubscriptionPublisher(SubscriptionPublisher&&) = delete; + SubscriptionPublisher& operator=(const SubscriptionPublisher&) = delete; + SubscriptionPublisher& operator=(SubscriptionPublisher&&) = delete; + + uint64_t request_id() const { return request_id_; } + MoqtTrackPublisher& publisher() { return *track_publisher_; } + uint64_t track_alias() const { return track_alias_; } + MessageParameters& parameters() { return parameters_; } + + // MoqtObjectListener implementation. + void OnSubscribeAccepted() override; + void OnSubscribeRejected(MoqtRequestErrorInfo info) override; + // This is only called for objects that have just arrived. + void OnNewObjectAvailable(Location location, std::optional<uint64_t> subgroup, + MoqtPriority publisher_priority) override; + void OnTrackPublisherGone() override; + void OnNewFinAvailable(Location location, uint64_t subgroup) override; + // also a part of SubscriptionPublisherInterface. + void OnSubgroupAbandoned(uint64_t group, uint64_t subgroup, + webtransport::StreamErrorCode error_code) override; + void OnGroupAbandoned(uint64_t group_id) override; + void ProcessObjectAck(const MoqtObjectAck& message); + + // SubscriptionPublisherInterface implementation. + bool InWindow(Location location) override { + return parameters_.forward() && + (!parameters_.subscription_filter.has_value() || + (parameters_.subscription_filter->WindowKnown() && + parameters_.subscription_filter->InWindow(location))); + }; + bool alternate_delivery_timeout() override { + return visitor_->alternate_delivery_timeout(); + } + const quic::QuicClock* clock() override { return clock_; } + quic::QuicTimeDelta delivery_timeout() override { + return std::min( + parameters_.delivery_timeout.value_or(kDefaultDeliveryTimeout), + publisher_delivery_timeout_.value_or(kDefaultDeliveryTimeout)); + } + quic::QuicAlarmFactory* alarm_factory() override { + return visitor_->alarm_factory(); + } + void OnObjectSent(Location sequence) override; + void OnStreamTimeout(DataStreamIndex index) override { + reset_subgroups_.insert(index); + if (visitor_->alternate_delivery_timeout()) { + first_active_group_ = std::max(first_active_group_, index.group + 1); + } + } + // OnSubgroupAbandoned() is declared above with MoqtObjectListener. + void OnDataStreamDestroyed(DataStreamIndex) override; + + // Called by MoqtSession when this subscription can open a new stream. + void OnCanCreateNewUniStream(); + + // Called when the parameters_ needs an update. + void Update(const MessageParameters& parameters); + + bool can_have_joining_fetch() const { return parameters_.forward(); } + + bool established() const { return established_; } + + private: + friend class test::SubscriptionPublisherPeer; + + MoqtPriority default_publisher_priority() const { + return default_publisher_priority_.value_or(kDefaultPublisherPriority); + } + + // Checks if a given Location or Group should be forwarded to the + // subscriber. + bool InWindow(uint64_t group) { + return parameters_.forward() && + (!parameters_.subscription_filter.has_value() || + (parameters_.subscription_filter->WindowKnown() && + parameters_.subscription_filter->InWindow(group))); + } + + void SendDatagram(Location sequence); + + // Returns the rank of the stream with respect to other streams in the + // subscription. Higher numbers are higher priority. + StreamRank StreamRankFor(const NewDataStreamParameters& parameters) const { + return SendOrderForStream( + kDefaultSubscriberPriority, + parameters.publisher_priority.value_or(default_publisher_priority()), + parameters.index.group, parameters.index.subgroup, + *parameters_.group_order); + } + + // Returns the stream priority for use at the moment of stream creation. + webtransport::StreamPriority StreamPriorityFor( + const NewDataStreamParameters& parameters) const { + return webtransport::StreamPriority{ + kMoqtSendGroupId, + SendOrderForStream(subscriber_priority(), + parameters.publisher_priority.value_or( + default_publisher_priority()), + parameters.index.group, parameters.index.subgroup, + *parameters_.group_order)}; + } + + webtransport::Stream* absl_nullable OpenDataStream( + const NewDataStreamParameters& parameters); + + void PublishIsDone(uint64_t request_id, PublishDoneCode code, + absl::string_view reason); + + MoqtPriority subscriber_priority() const { + return parameters_.subscriber_priority.value_or(kDefaultSubscriberPriority); + } + + webtransport::Stream* GetStreamById(webtransport::StreamId stream_id) { + return visitor_->session() == nullptr + ? nullptr + : visitor_->session()->GetStreamById(stream_id); + } + + std::shared_ptr<MoqtTrackPublisher> track_publisher_; + MoqtBidiStreamBase* absl_nonnull bidi_stream_; + SessionToPublisherInterface* absl_nonnull visitor_; + uint64_t request_id_; + // Subscription is in the ESTABLISHED state. + bool established_ = false; + const uint64_t track_alias_; + MoqtFramer framer_; + MoqtTraceRecorder& trace_recorder_; + // These are (mostly) the parameters from the SUBSCRIBE message. However, + // group_order and largest_object may be updated by SUBSCRIBE_OK because + // have no effect in a future REQUEST_UPDATE message. + MessageParameters parameters_; + std::optional<quic::QuicTimeDelta> publisher_delivery_timeout_; + std::optional<MoqtPriority> default_publisher_priority_; + uint64_t streams_opened_ = 0; + + // The subscription will ignore any groups with a lower ID, so it doesn't + // need to track reset subgroups. + uint64_t first_active_group_ = 0; + // If a stream has been reset due to delivery timeout, do not open a new + // stream if more object arrive for it. + absl::flat_hash_set<DataStreamIndex> reset_subgroups_; + + MoqtPublishingMonitorInterface* monitoring_interface_; + // Largest sequence number ever sent via this subscription. + std::optional<Location> largest_sent_; + SendStreamMap stream_map_; + // Store the StreamRank of queued outgoing data streams. High StreamRank is + // highest priority, so use rbegin() to get the highest priority pending + // stream. + absl::btree_multimap<StreamRank, NewDataStreamParameters> pending_streams_; + const quic::QuicClock* absl_nonnull clock_; + // Must be last. + quiche::QuicheWeakPtrFactory<SubscriptionPublisherInterface> + weak_ptr_factory_; +}; + +} // namespace moqt + +#endif // QUICHE_QUIC_MOQT_MOQT_SUBSCRIPTION_H_
diff --git a/quiche/quic/moqt/moqt_subscription_test.cc b/quiche/quic/moqt/moqt_subscription_test.cc new file mode 100644 index 0000000..520e709 --- /dev/null +++ b/quiche/quic/moqt/moqt_subscription_test.cc
@@ -0,0 +1,546 @@ +// Copyright (c) 2026 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/moqt/moqt_subscription.h" + +#include <cstddef> +#include <cstdint> +#include <memory> +#include <optional> +#include <string> +#include <utility> + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/moqt/moqt_bidi_stream.h" +#include "quiche/quic/moqt/moqt_error.h" +#include "quiche/quic/moqt/moqt_framer.h" +#include "quiche/quic/moqt/moqt_key_value_pair.h" +#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_names.h" +#include "quiche/quic/moqt/moqt_object.h" +#include "quiche/quic/moqt/moqt_parser.h" +#include "quiche/quic/moqt/moqt_priority.h" +#include "quiche/quic/moqt/moqt_session_callbacks.h" +#include "quiche/quic/moqt/moqt_session_interface.h" +#include "quiche/quic/moqt/moqt_trace_recorder.h" +#include "quiche/quic/moqt/moqt_types.h" +#include "quiche/quic/moqt/test_tools/moqt_framer_utils.h" +#include "quiche/quic/moqt/test_tools/moqt_mock_visitor.h" +#include "quiche/quic/moqt/test_tools/moqt_session_peer.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/quiche_mem_slice.h" +#include "quiche/web_transport/test_tools/mock_web_transport.h" +#include "quiche/web_transport/web_transport.h" + +namespace moqt::test { + +namespace { + +using ::testing::_; +using ::testing::AtLeast; +using ::testing::Return; +using ::testing::ReturnRef; +using ::testing::StrictMock; +using ::webtransport::DatagramStatus; +using ::webtransport::DatagramStatusCode; + +class MockSessionToPublisherInterface : public SessionToPublisherInterface { + public: + ~MockSessionToPublisherInterface() override = default; + MOCK_METHOD(bool, alternate_delivery_timeout, (), (const, override)); + MOCK_METHOD(void, UpdateTrackPriority, + (uint64_t, std::optional<MoqtTrackPriority>, MoqtTrackPriority), + (override)); + MOCK_METHOD(quic::QuicAlarmFactory*, alarm_factory, (), (override)); + MOCK_METHOD(void, PublishIsDone, (uint64_t), (override)); + MOCK_METHOD(webtransport::Session*, session, (), (override)); +}; + +class TestMoqtBidiStream : public MoqtBidiStreamBase { + public: + TestMoqtBidiStream(MoqtFramer* absl_nonnull framer, + const MoqtControlMessageParser& message_parser, + BidiStreamDeletedCallback stream_deleted_callback, + SessionErrorCallback session_error_callback) + : MoqtBidiStreamBase(framer, message_parser, + std::move(stream_deleted_callback), + std::move(session_error_callback)) {} + ~TestMoqtBidiStream() override = default; + void OnStreamBound() override {}; + absl::Status OnRawControlMessage( + const MoqtRawControlMessage& message) override { + return absl::OkStatus(); + } +}; + +std::optional<PublishedObject> DefaultPublishedObject( + Location location, std::optional<uint64_t> subgroup, + MoqtPriority publisher_priority) { + PublishedObject object; + object.metadata.location = location; + object.metadata.subgroup = subgroup; + object.metadata.status = MoqtObjectStatus::kNormal; + object.metadata.publisher_priority = publisher_priority; + object.metadata.extensions = "extensions"; + object.metadata.payload_length = 8; + object.payload.push_back(quiche::QuicheMemSlice::Copy("deadbeef")); + return object; +} + +class SubscriptionPublisherTest : public quic::test::QuicTest { + public: + SubscriptionPublisherTest() + : track_publisher_( + std::make_shared<MockTrackPublisher>(FullTrackName("foo", "bar"))), + bidi_stream_( + &framer_, message_parser_, [] {}, + [](MoqtError, absl::string_view) {}), + trace_recorder_(nullptr) { + bidi_stream_.BindStream(&mock_bidi_stream_); + parameters_.set_forward(true); + parameters_.delivery_timeout = quic::QuicTimeDelta::FromSeconds(1); + parameters_.group_order = MoqtDeliveryOrder::kAscending; + EXPECT_CALL(monitoring_interface_, OnObjectAckSupportKnown) + .Times(AtLeast(0)); + ON_CALL(visitor_, session).WillByDefault(Return(&webtrans_)); + publisher_ = std::make_unique<SubscriptionPublisher>( + framer_, track_publisher_, &bidi_stream_, kRequestId, kTrackAlias, + parameters_, &visitor_, &monitoring_interface_, &mock_clock_, + trace_recorder_); + ON_CALL(visitor_, alternate_delivery_timeout).WillByDefault(Return(false)); + ON_CALL(webtrans_, GetStreamById(kStreamId)) + .WillByDefault(Return(&mock_uni_stream_)); + } + + ~SubscriptionPublisherTest() override { + EXPECT_CALL(*track_publisher_, RemoveObjectListener(publisher_.get())); + size_t num_open_streams = + SubscriptionPublisherPeer::num_open_streams(publisher_.get()); + EXPECT_CALL(mock_uni_stream_, ResetWithUserCode).Times(num_open_streams); + } + + MoqtPriority subscriber_priority() const { + return parameters_.subscriber_priority.value_or(kDefaultSubscriberPriority); + } + + // Create a stream with the given parameters and send the first object. Will + // check that the first bytes written to the stream are equal to + // |opening_bytes|. + void CreateStream(Location location, uint64_t subgroup, + MoqtPriority publisher_priority, + std::string opening_bytes = "") { + EXPECT_CALL( + *track_publisher_, + GetCachedObject(location.group, std::make_optional<uint64_t>(subgroup), + location.object, 0)) + .WillOnce(Return( // Once for monitoring interface. + DefaultPublishedObject(location, subgroup, publisher_priority))) + .WillOnce(Return( // To actually deliver the object. + DefaultPublishedObject(location, subgroup, publisher_priority))); + EXPECT_CALL( + *track_publisher_, + GetCachedObject(location.group, std::make_optional<uint64_t>(subgroup), + location.object + 1, 0)) + .WillOnce(Return(std::nullopt)); + // Additional object retrievals will return nullopt. + EXPECT_CALL(monitoring_interface_, OnNewObjectEnqueued(location)); + EXPECT_CALL(mock_uni_stream_, GetStreamId()) + .WillRepeatedly(Return(kStreamId)); + EXPECT_CALL(webtrans_, CanOpenNextOutgoingUnidirectionalStream()) + .WillOnce(Return(true)); + EXPECT_CALL(webtrans_, OpenOutgoingUnidirectionalStream) + .WillOnce(Return(&mock_uni_stream_)); + EXPECT_CALL(mock_uni_stream_, SetVisitor) + .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> visitor) { + uni_stream_ = std::move(visitor); + }); + EXPECT_CALL(mock_uni_stream_, SetPriority); + EXPECT_CALL(mock_uni_stream_, visitor()).WillRepeatedly([&]() { + return uni_stream_.get(); + }); + EXPECT_CALL(mock_uni_stream_, CanWrite()).WillRepeatedly(Return(true)); + EXPECT_CALL(mock_uni_stream_, Writev) + .WillOnce([&](absl::Span<quiche::QuicheMemSlice> data, + const webtransport::StreamWriteOptions& options) { + EXPECT_TRUE(absl::StartsWith(data[0].AsStringView(), opening_bytes)); + EXPECT_FALSE(options.send_fin()); + return absl::OkStatus(); + }); + publisher_->OnNewObjectAvailable(location, subgroup, publisher_priority); + ++open_streams_; + } + + void CreatePendingStream(Location location, uint64_t subgroup, + MoqtPriority publisher_priority) { + EXPECT_CALL( + *track_publisher_, + GetCachedObject(location.group, std::make_optional<uint64_t>(subgroup), + location.object, 0)) + .WillOnce(Return( // Once for monitoring interface. + DefaultPublishedObject(location, subgroup, publisher_priority))); + // Additional object retrievals will return nullopt. + EXPECT_CALL(monitoring_interface_, OnNewObjectEnqueued(location)); + EXPECT_CALL(webtrans_, CanOpenNextOutgoingUnidirectionalStream()) + .WillOnce(Return(false)); + EXPECT_CALL(visitor_, + UpdateTrackPriority(1, std::optional<MoqtTrackPriority>(), + MoqtTrackPriority{subscriber_priority(), + publisher_priority})); + publisher_->OnNewObjectAvailable(location, subgroup, publisher_priority); + } + + static constexpr webtransport::StreamId kStreamId = 100; + static constexpr uint64_t kTrackAlias = 10; + static constexpr uint64_t kRequestId = 1; + + MoqtFramer framer_{true}; + MoqtControlMessageParser message_parser_{kDefaultMoqtVersion, true}; + webtransport::test::MockSession webtrans_; + StrictMock<webtransport::test::MockStream> mock_bidi_stream_; + StrictMock<webtransport::test::MockStream> mock_uni_stream_; + std::shared_ptr<MockTrackPublisher> track_publisher_; + TestMoqtBidiStream bidi_stream_; + std::unique_ptr<webtransport::StreamVisitor> uni_stream_; + MessageParameters parameters_; + MockSessionToPublisherInterface visitor_; + StrictMock<MockPublishingMonitorInterface> monitoring_interface_; + MoqtTraceRecorder trace_recorder_; + std::unique_ptr<SubscriptionPublisher> publisher_; + const TrackExtensions extensions_; + quic::MockClock mock_clock_; + MoqtSessionCallbacks callbacks_; + quic::test::MockAlarmFactory alarm_factory_; + int open_streams_ = 0; +}; + +TEST_F(SubscriptionPublisherTest, OnSubscribeAcceptedNoFilter) { + EXPECT_CALL(mock_bidi_stream_, CanWrite()).WillRepeatedly(Return(true)); + EXPECT_CALL(*track_publisher_, largest_location()) + .WillOnce(Return(Location(1, 2))); + EXPECT_CALL(*track_publisher_, expiration) + .WillOnce(Return(quic::QuicTimeDelta::FromSeconds(10))); + EXPECT_CALL(*track_publisher_, extensions) + .WillRepeatedly(ReturnRef(extensions_)); + EXPECT_CALL(mock_bidi_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribeOk), _)) + .WillOnce(Return(absl::OkStatus())); + publisher_->OnSubscribeAccepted(); + EXPECT_TRUE(publisher_->established()); + EXPECT_EQ(publisher_->parameters().largest_object, Location(1, 2)); + EXPECT_FALSE(publisher_->parameters().subscription_filter.has_value()); +} + +TEST_F(SubscriptionPublisherTest, OnSubscribeAcceptedWithFilter) { + publisher_->parameters().subscription_filter = + SubscriptionFilter(MoqtFilterType::kLargestObject); + const TrackExtensions extensions(std::nullopt, std::nullopt, + /*default_publisher_priority=*/64, + std::nullopt, std::nullopt, std::nullopt); + EXPECT_CALL(mock_bidi_stream_, CanWrite()).WillRepeatedly(Return(true)); + EXPECT_CALL(*track_publisher_, largest_location()) + .WillOnce(Return(Location(1, 2))); + EXPECT_CALL(*track_publisher_, expiration) + .WillOnce(Return(quic::QuicTimeDelta::FromSeconds(10))); + EXPECT_CALL(*track_publisher_, extensions) + .WillRepeatedly(ReturnRef(extensions)); + EXPECT_CALL(mock_bidi_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribeOk), _)) + .WillOnce(Return(absl::OkStatus())); + publisher_->OnSubscribeAccepted(); + ASSERT_TRUE(publisher_->parameters().subscription_filter.has_value()); + EXPECT_EQ(publisher_->parameters().subscription_filter->start(), + Location(1, 3)); + // Check that default_publisher_priority is set. A datagram set at priority + // 64 should not explicitly encode that. + EXPECT_CALL(*track_publisher_, + GetCachedObject(1, std::optional<uint64_t>(), 3, 0)) + .WillOnce( + Return(DefaultPublishedObject(Location(1, 3), std::nullopt, 64))) + .WillOnce( + Return(DefaultPublishedObject(Location(1, 3), std::nullopt, 64))); + EXPECT_CALL(monitoring_interface_, OnNewObjectEnqueued(Location(1, 3))); + EXPECT_CALL(webtrans_, SendOrQueueDatagram) + .WillOnce([](absl::string_view datagram) { + EXPECT_FALSE(datagram.empty()); + std::optional<MoqtDatagramType> type = + MoqtDatagramType::FromValue(static_cast<uint64_t>(datagram[0])); + EXPECT_TRUE(type.has_value() && type->has_default_priority()); + return DatagramStatus(DatagramStatusCode::kSuccess, ""); + }); + publisher_->OnNewObjectAvailable(Location(1, 3), std::nullopt, 64); +} + +TEST_F(SubscriptionPublisherTest, OnSubscribeRejected) { + EXPECT_CALL(mock_bidi_stream_, CanWrite()).WillRepeatedly(Return(true)); + EXPECT_CALL(mock_bidi_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(visitor_, PublishIsDone(1)); + publisher_->OnSubscribeRejected(MoqtRequestErrorInfo( + RequestErrorCode::kDoesNotExist, std::nullopt, "reason")); +} + +TEST_F(SubscriptionPublisherTest, Update) { + MessageParameters new_params; + new_params.delivery_timeout = quic::QuicTimeDelta::FromSeconds(5); + publisher_->Update(new_params); + + // Changing forward preference updates can_have_joining_fetch_ + new_params.set_forward(false); + publisher_->Update(new_params); + EXPECT_FALSE(publisher_->parameters().forward()); + EXPECT_FALSE(publisher_->can_have_joining_fetch()); +} + +TEST_F(SubscriptionPublisherTest, UpdatePriorityNoStreams) { + MessageParameters new_params; + new_params.subscriber_priority = 20; + publisher_->Update(new_params); + EXPECT_EQ(publisher_->parameters().subscriber_priority, 20); +} + +TEST_F(SubscriptionPublisherTest, UpdatePriorityWithPendingStreams) { + CreatePendingStream(Location(1, 0), 0, 64); + MessageParameters new_params; + new_params.subscriber_priority = 20; + EXPECT_CALL(*track_publisher_, extensions()) + .WillRepeatedly(ReturnRef(extensions_)); + EXPECT_CALL(visitor_, UpdateTrackPriority(1, + std::optional<MoqtTrackPriority>( + {subscriber_priority(), 64}), + MoqtTrackPriority{20, 64})); + publisher_->Update(new_params); +} + +TEST_F(SubscriptionPublisherTest, UpdatePriorityWithActiveStreams) { + CreateStream( + Location(1, 0), 0, 127, + {0x11, static_cast<uint8_t>(kTrackAlias), 0x01, 0x7f, 0x00, 0x0a}); + MessageParameters new_params; + new_params.subscriber_priority = 20; + EXPECT_CALL(mock_uni_stream_, SetPriority); + publisher_->Update(new_params); +} + +TEST_F(SubscriptionPublisherTest, OnNewObjectAvailableNotInWindow) { + MessageParameters params; + params.subscription_filter = SubscriptionFilter(Location(10, 0), 10); + publisher_->Update(params); + EXPECT_CALL(*track_publisher_, GetCachedObject).Times(0); + publisher_->OnNewObjectAvailable(Location(5, 0), 0, 128); +} + +TEST_F(SubscriptionPublisherTest, OnNewObjectAvailableDatagram) { + EXPECT_CALL(*track_publisher_, + GetCachedObject(1, std::optional<uint64_t>(), 0, 0)) + .WillOnce( + Return(DefaultPublishedObject(Location(1, 0), std::nullopt, 128))) + .WillOnce( + Return(DefaultPublishedObject(Location(1, 0), std::nullopt, 128))); + EXPECT_CALL(monitoring_interface_, OnNewObjectEnqueued(Location(1, 0))); + EXPECT_CALL(webtrans_, SendOrQueueDatagram) + .WillOnce(Return(DatagramStatus(DatagramStatusCode::kSuccess, ""))); + EXPECT_CALL(*track_publisher_, extensions()) + .WillRepeatedly(ReturnRef(extensions_)); + publisher_->OnNewObjectAvailable(Location(1, 0), std::nullopt, 128); +} + +TEST_F(SubscriptionPublisherTest, OnNewObjectAvailableStreamCreationBlocked) { + CreatePendingStream(Location(1, 0), 0, 128); +} + +TEST_F(SubscriptionPublisherTest, OnNewFinAvailableNoops) { + // Not in window + MessageParameters params; + params.subscription_filter = SubscriptionFilter(Location(10, 0), 10); + publisher_->Update(params); + EXPECT_CALL(webtrans_, GetStreamById).Times(0); + EXPECT_CALL(mock_uni_stream_, Writev).Times(0); + publisher_->OnNewFinAvailable(Location(0, 0), 0); + + // In window but no stream + publisher_->Update(parameters_); + EXPECT_CALL(webtrans_, GetStreamById).Times(0); + EXPECT_CALL(mock_uni_stream_, Writev).Times(0); + publisher_->OnNewFinAvailable(Location(10, 10), 0); + + EXPECT_CALL(webtrans_, GetStreamById) + .WillRepeatedly(Return(&mock_uni_stream_)); + // Stream hasn't gotten there yet. The cache will tell us when to send FIN. + CreateStream(Location(10, 0), 0, 128); + EXPECT_CALL(mock_uni_stream_, Writev).Times(0); + publisher_->OnNewFinAvailable(Location(10, 1), 0); +} + +TEST_F(SubscriptionPublisherTest, OnNewFinAvailableWithStream) { + CreateStream(Location(1, 0), 0, 128); + EXPECT_CALL(mock_uni_stream_, Writev) + .WillOnce([](absl::Span<quiche::QuicheMemSlice> data, + const webtransport::StreamWriteOptions& options) { + EXPECT_TRUE(data.empty()); + EXPECT_TRUE(options.send_fin()); + return absl::OkStatus(); + }); + quic::test::MockAlarmFactory alarm_factory; + EXPECT_CALL(visitor_, alarm_factory).WillOnce(Return(&alarm_factory)); + publisher_->OnNewFinAvailable(Location(1, 0), 0); +} + +TEST_F(SubscriptionPublisherTest, OnSubgroupAbandoned) { + // Not in window + MessageParameters params; + params.subscription_filter = SubscriptionFilter(Location(10, 0), 10); + publisher_->Update(params); + EXPECT_CALL(webtrans_, GetStreamById).Times(0); + publisher_->OnSubgroupAbandoned(1, 0, 17); + + // In window but no stream + publisher_->Update(parameters_); + EXPECT_CALL(webtrans_, GetStreamById).Times(0); + publisher_->OnSubgroupAbandoned(1, 0, 17); +} + +TEST_F(SubscriptionPublisherTest, OnGroupAbandoned) { + // Not in window + MessageParameters params; + params.subscription_filter = SubscriptionFilter(Location(10, 0), 10); + publisher_->Update(params); + EXPECT_CALL(webtrans_, GetStreamById).Times(0); + publisher_->OnGroupAbandoned(1); + + // In window + publisher_->Update(parameters_); + EXPECT_CALL(webtrans_, GetStreamById).Times(0); + publisher_->OnGroupAbandoned(1); + EXPECT_CALL(*track_publisher_, GetCachedObject).Times(0); + publisher_->OnNewObjectAvailable(Location(1, 0), 0, 128); +} + +TEST_F(SubscriptionPublisherTest, OnGroupAbandonedWithStreams) { + // Create a stream + CreateStream(Location(1, 0), 0, 128); + EXPECT_CALL(mock_uni_stream_, ResetWithUserCode); + publisher_->OnGroupAbandoned(1); +} + +TEST_F(SubscriptionPublisherTest, OnGroupAbandonedTooFarBehind) { + // Create a pending stream + parameters_.delivery_timeout = quic::QuicTimeDelta::Infinite(); + publisher_->Update(parameters_); + CreateStream(Location(5, 0), 0, 128); + struct MoqtPublishDone expected_publish_done = { + /*request_id=*/kRequestId, + PublishDoneCode::kTooFarBehind, + /*stream_count=*/1, + /*error_reason=*/"", + }; + EXPECT_CALL(mock_bidi_stream_, CanWrite()).WillRepeatedly(Return(true)); + EXPECT_CALL(mock_bidi_stream_, + Writev(SerializedControlMessage(expected_publish_done), _)); + EXPECT_CALL(visitor_, PublishIsDone(kRequestId)); + publisher_->OnGroupAbandoned(5); +} + +TEST_F(SubscriptionPublisherTest, OnCanCreateNewUniStreamPendingCleanup) { + CreatePendingStream(Location(1, 0), 0, 128); + // Abandon the group. + publisher_->OnGroupAbandoned(1); + // OnCanCreateNewUniStream should clean it up; no attempt to create a stream. + EXPECT_CALL(webtrans_, CanOpenNextOutgoingUnidirectionalStream()).Times(0); + EXPECT_CALL(webtrans_, OpenOutgoingUnidirectionalStream).Times(0); + publisher_->OnCanCreateNewUniStream(); +} + +TEST_F(SubscriptionPublisherTest, AlternateDeliveryTimeoutSetAlarm) { + ON_CALL(visitor_, alternate_delivery_timeout).WillByDefault(Return(true)); + // Create a stream for group 1. + CreateStream(Location(1, 0), 0, 128); + // Create a pending stream for group 2, which should start the timer but does + // less work than an active stream. + EXPECT_CALL(visitor_, alarm_factory).WillOnce(Return(&alarm_factory_)); + CreatePendingStream(Location(2, 0), 0, 128); +} + +TEST_F(SubscriptionPublisherTest, OnTrackPublisherGone) { + EXPECT_CALL(mock_bidi_stream_, CanWrite()).WillRepeatedly(Return(true)); + EXPECT_CALL(mock_bidi_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kPublishDone), _)) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(visitor_, PublishIsDone(1)); + publisher_->OnTrackPublisherGone(); +} + +TEST_F(SubscriptionPublisherTest, ProcessObjectAck) { + MoqtObjectAck ack; + ack.group_id = 1; + ack.object_id = 2; + ack.delta_from_deadline = quic::QuicTimeDelta::FromMilliseconds(100); + EXPECT_CALL(monitoring_interface_, + OnObjectAckReceived(Location(1, 2), ack.delta_from_deadline)); + publisher_->ProcessObjectAck(ack); +} + +TEST_F(SubscriptionPublisherTest, OnSubgroupAbandonedWithStream) { + CreateStream(Location(1, 0), 0, 128); + EXPECT_CALL(mock_uni_stream_, ResetWithUserCode(17)); + publisher_->OnSubgroupAbandoned(1, 0, 17); +} + +TEST_F(SubscriptionPublisherTest, OnCanCreateNewUniStreamSuccess) { + CreatePendingStream(Location(1, 0), 0, 128); + // Call OnCanCreateNewUniStream and succeed. + EXPECT_CALL(mock_uni_stream_, GetStreamId()) + .WillRepeatedly(Return(kStreamId)); + EXPECT_CALL(webtrans_, CanOpenNextOutgoingUnidirectionalStream()) + .WillOnce(Return(true)); + EXPECT_CALL(webtrans_, OpenOutgoingUnidirectionalStream) + .WillOnce(Return(&mock_uni_stream_)); + EXPECT_CALL(mock_uni_stream_, SetVisitor) + .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> visitor) { + uni_stream_ = std::move(visitor); + }); + EXPECT_CALL(mock_uni_stream_, SetPriority); + EXPECT_CALL(mock_uni_stream_, visitor()).WillRepeatedly([&]() { + return uni_stream_.get(); + }); + EXPECT_CALL(mock_uni_stream_, CanWrite()).WillRepeatedly(Return(true)); + EXPECT_CALL(*track_publisher_, + GetCachedObject(1, std::optional<uint64_t>(0), 0, 0)) + .WillOnce(Return(DefaultPublishedObject(Location(1, 0), 0, 128))); + EXPECT_CALL(*track_publisher_, + GetCachedObject(1, std::optional<uint64_t>(0), 1, 0)) + .WillOnce(Return(std::nullopt)); + EXPECT_CALL(mock_uni_stream_, Writev).WillOnce(Return(absl::OkStatus())); + publisher_->OnCanCreateNewUniStream(); +} + +TEST_F(SubscriptionPublisherTest, OnDataStreamDestroyed) { + CreateStream(Location(1, 0), 0, 128); + DataStreamIndex index(1, 0); + publisher_->OnDataStreamDestroyed(index); + // No entries in the stream map. + EXPECT_CALL(webtrans_, GetStreamById).Times(0); + parameters_.subscriber_priority = 20; + publisher_->Update(parameters_); +} + +TEST_F(SubscriptionPublisherTest, OnObjectSentTwice) { + publisher_->OnObjectSent(Location(1, 0)); + EXPECT_TRUE( + SubscriptionPublisherPeer::largest_sent(publisher_.get()).has_value() && + *SubscriptionPublisherPeer::largest_sent(publisher_.get()) == + Location(1, 0)); +} + +} // namespace + +} // namespace moqt::test
diff --git a/quiche/quic/moqt/moqt_track.h b/quiche/quic/moqt/moqt_track.h index 2e6f66d..0994cee 100644 --- a/quiche/quic/moqt/moqt_track.h +++ b/quiche/quic/moqt/moqt_track.h
@@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef QUICHE_QUIC_MOQT_MOQT_SUBSCRIPTION_H_ -#define QUICHE_QUIC_MOQT_MOQT_SUBSCRIPTION_H_ +#ifndef QUICHE_QUIC_MOQT_MOQT_TRACK_H_ +#define QUICHE_QUIC_MOQT_MOQT_TRACK_H_ #include <cstdint> #include <memory> @@ -371,4 +371,4 @@ } // namespace moqt -#endif // QUICHE_QUIC_MOQT_MOQT_SUBSCRIPTION_H_ +#endif // QUICHE_QUIC_MOQT_MOQT_TRACK_H_
diff --git a/quiche/quic/moqt/moqt_uni_stream_test.cc b/quiche/quic/moqt/moqt_uni_stream_test.cc index 4ebf4c5..7a9cc11 100644 --- a/quiche/quic/moqt/moqt_uni_stream_test.cc +++ b/quiche/quic/moqt/moqt_uni_stream_test.cc
@@ -227,7 +227,12 @@ EXPECT_CALL(*track_publisher_, extensions()) .WillRepeatedly(ReturnRef(track_extensions_)); - EXPECT_CALL(mock_stream_, Writev).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(mock_stream_, Writev) + .WillOnce([&](absl::Span<quiche::QuicheMemSlice> data, + const webtransport::StreamWriteOptions& options) { + EXPECT_TRUE(options.send_fin()); + return absl::OkStatus(); + }); EXPECT_CALL(visitor_, OnObjectSent(Location(0, 0))); ExpectAlarm(); stream_->OnCanWrite();
diff --git a/quiche/quic/moqt/test_tools/moqt_session_peer.h b/quiche/quic/moqt/test_tools/moqt_session_peer.h index 5264255..b3c9e86 100644 --- a/quiche/quic/moqt/test_tools/moqt_session_peer.h +++ b/quiche/quic/moqt/test_tools/moqt_session_peer.h
@@ -5,6 +5,7 @@ #ifndef QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_SESSION_PEER_H_ #define QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_SESSION_PEER_H_ +#include <cstddef> #include <cstdint> #include <memory> #include <optional> @@ -13,6 +14,7 @@ #include "absl/base/casts.h" #include "absl/base/nullability.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" @@ -29,6 +31,7 @@ #include "quiche/quic/moqt/moqt_publisher.h" #include "quiche/quic/moqt/moqt_session.h" #include "quiche/quic/moqt/moqt_session_interface.h" +#include "quiche/quic/moqt/moqt_subscription.h" #include "quiche/quic/moqt/moqt_track.h" #include "quiche/quic/moqt/moqt_types.h" #include "quiche/quic/moqt/moqt_uni_stream.h" @@ -49,6 +52,25 @@ } }; +// TODO(martinduke): When subscription-specific tests are removed from, +// MoqtSessionTest, much of this file can be deleted (including +// SubscriptionPublisherPeer). + +class SubscriptionPublisherPeer { + public: + static size_t num_open_streams(SubscriptionPublisher* publisher) { + return publisher->stream_map_.GetAllStreams().size(); + } + static std::optional<Location> largest_sent( + const SubscriptionPublisher* publisher) { + return publisher->largest_sent_; + } + static const absl::flat_hash_set<DataStreamIndex>& reset_subgroups( + const SubscriptionPublisher* publisher) { + return publisher->reset_subgroups_; + } +}; + // Helper class to interact with MOQT bidi streams in tests. class MoqtBidiStreamTestWrapper { public: @@ -150,17 +172,21 @@ subscribe.parameters.subscriber_priority = 0x80; subscribe.parameters.group_order = MoqtDeliveryOrder::kAscending; session->published_subscriptions_.emplace( - subscribe_id, std::make_unique<MoqtSession::PublishedSubscription>( - session, std::move(publisher), subscribe, track_alias, - /*monitoring_interface=*/nullptr)); + subscribe_id, + std::make_unique<SubscriptionPublisher>( + session->framer_, std::move(publisher), session->GetControlStream(), + subscribe_id, track_alias, subscribe.parameters, session, + /*monitoring_interface=*/nullptr, session->callbacks_.clock, + session->trace_recorder_)); return session->published_subscriptions_[subscribe_id].get(); } static bool InSubscriptionWindow(MoqtObjectListener* subscription, Location sequence) { std::optional<SubscriptionFilter> filter = - absl::down_cast<MoqtSession::PublishedSubscription*>(subscription) - ->parameters_.subscription_filter; + absl::down_cast<SubscriptionPublisher*>(subscription) + ->parameters() + .subscription_filter; return (!filter.has_value() || filter->InWindow(sequence)); } @@ -211,9 +237,10 @@ session->ValidateRequestId(id); } - static Location LargestSentForSubscription(MoqtSession* session, - uint64_t subscribe_id) { - return *session->published_subscriptions_[subscribe_id]->largest_sent(); + static std::optional<Location> LargestSentForSubscription( + MoqtSession* session, uint64_t subscribe_id) { + return SubscriptionPublisherPeer::largest_sent( + session->published_subscriptions_[subscribe_id].get()); } // Adds an upstream fetch and a stream ready to receive data. @@ -274,19 +301,20 @@ static quic::QuicTimeDelta GetDeliveryTimeout( MoqtObjectListener* subscription) { - return absl::down_cast<MoqtSession::PublishedSubscription*>(subscription) + return absl::down_cast<SubscriptionPublisher*>(subscription) ->delivery_timeout(); } static void SetDeliveryTimeout(MoqtObjectListener* subscription, quic::QuicTimeDelta timeout) { - absl::down_cast<MoqtSession::PublishedSubscription*>(subscription) - ->parameters_.delivery_timeout = timeout; + absl::down_cast<SubscriptionPublisher*>(subscription) + ->parameters() + .delivery_timeout = timeout; } static bool SubgroupHasBeenReset(MoqtObjectListener* subscription, DataStreamIndex index) { - return absl::down_cast<MoqtSession::PublishedSubscription*>(subscription) - ->reset_subgroups() + return SubscriptionPublisherPeer::reset_subgroups( + absl::down_cast<SubscriptionPublisher*>(subscription)) .contains(index); }