Implement a bidi stream version of SUBSCRIBE_NAMESPACE. Not yet integrated into MoqtSession. PiperOrigin-RevId: 865053873
diff --git a/build/source_list.bzl b/build/source_list.bzl index 7b95eec..3ce12e5 100644 --- a/build/source_list.bzl +++ b/build/source_list.bzl
@@ -1586,6 +1586,7 @@ "quic/moqt/moqt_known_track_publisher.h", "quic/moqt/moqt_messages.h", "quic/moqt/moqt_names.h", + "quic/moqt/moqt_namespace_stream.h", "quic/moqt/moqt_object.h", "quic/moqt/moqt_outgoing_queue.h", "quic/moqt/moqt_outstanding_objects.h", @@ -1618,6 +1619,7 @@ "quic/moqt/moqt_known_track_publisher.cc", "quic/moqt/moqt_messages.cc", "quic/moqt/moqt_names.cc", + "quic/moqt/moqt_namespace_stream.cc", "quic/moqt/moqt_object.cc", "quic/moqt/moqt_outgoing_queue.cc", "quic/moqt/moqt_outstanding_objects.cc", @@ -1647,6 +1649,7 @@ "quic/moqt/moqt_key_value_pair_test.cc", "quic/moqt/moqt_messages_test.cc", "quic/moqt/moqt_names_test.cc", + "quic/moqt/moqt_namespace_stream_test.cc", "quic/moqt/moqt_outgoing_queue_test.cc", "quic/moqt/moqt_outstanding_objects_test.cc", "quic/moqt/moqt_parser_fuzz_test.cc",
diff --git a/build/source_list.gni b/build/source_list.gni index 4b683b3..182019a 100644 --- a/build/source_list.gni +++ b/build/source_list.gni
@@ -1590,6 +1590,7 @@ "src/quiche/quic/moqt/moqt_known_track_publisher.h", "src/quiche/quic/moqt/moqt_messages.h", "src/quiche/quic/moqt/moqt_names.h", + "src/quiche/quic/moqt/moqt_namespace_stream.h", "src/quiche/quic/moqt/moqt_object.h", "src/quiche/quic/moqt/moqt_outgoing_queue.h", "src/quiche/quic/moqt/moqt_outstanding_objects.h", @@ -1622,6 +1623,7 @@ "src/quiche/quic/moqt/moqt_known_track_publisher.cc", "src/quiche/quic/moqt/moqt_messages.cc", "src/quiche/quic/moqt/moqt_names.cc", + "src/quiche/quic/moqt/moqt_namespace_stream.cc", "src/quiche/quic/moqt/moqt_object.cc", "src/quiche/quic/moqt/moqt_outgoing_queue.cc", "src/quiche/quic/moqt/moqt_outstanding_objects.cc", @@ -1652,6 +1654,7 @@ "src/quiche/quic/moqt/moqt_key_value_pair_test.cc", "src/quiche/quic/moqt/moqt_messages_test.cc", "src/quiche/quic/moqt/moqt_names_test.cc", + "src/quiche/quic/moqt/moqt_namespace_stream_test.cc", "src/quiche/quic/moqt/moqt_outgoing_queue_test.cc", "src/quiche/quic/moqt/moqt_outstanding_objects_test.cc", "src/quiche/quic/moqt/moqt_parser_fuzz_test.cc",
diff --git a/build/source_list.json b/build/source_list.json index 802917f..c4a622d 100644 --- a/build/source_list.json +++ b/build/source_list.json
@@ -1589,6 +1589,7 @@ "quiche/quic/moqt/moqt_known_track_publisher.h", "quiche/quic/moqt/moqt_messages.h", "quiche/quic/moqt/moqt_names.h", + "quiche/quic/moqt/moqt_namespace_stream.h", "quiche/quic/moqt/moqt_object.h", "quiche/quic/moqt/moqt_outgoing_queue.h", "quiche/quic/moqt/moqt_outstanding_objects.h", @@ -1621,6 +1622,7 @@ "quiche/quic/moqt/moqt_known_track_publisher.cc", "quiche/quic/moqt/moqt_messages.cc", "quiche/quic/moqt/moqt_names.cc", + "quiche/quic/moqt/moqt_namespace_stream.cc", "quiche/quic/moqt/moqt_object.cc", "quiche/quic/moqt/moqt_outgoing_queue.cc", "quiche/quic/moqt/moqt_outstanding_objects.cc", @@ -1651,6 +1653,7 @@ "quiche/quic/moqt/moqt_key_value_pair_test.cc", "quiche/quic/moqt/moqt_messages_test.cc", "quiche/quic/moqt/moqt_names_test.cc", + "quiche/quic/moqt/moqt_namespace_stream_test.cc", "quiche/quic/moqt/moqt_outgoing_queue_test.cc", "quiche/quic/moqt/moqt_outstanding_objects_test.cc", "quiche/quic/moqt/moqt_parser_fuzz_test.cc",
diff --git a/quiche/quic/moqt/moqt_bidi_stream.h b/quiche/quic/moqt/moqt_bidi_stream.h index b1c3351..10fafcb 100644 --- a/quiche/quic/moqt/moqt_bidi_stream.h +++ b/quiche/quic/moqt/moqt_bidi_stream.h
@@ -214,7 +214,7 @@ void Fin() { fin_queued_ = true; if (pending_messages_.empty()) { - if (!stream_->SendFin()) { + if (stream_ != nullptr && !SendFinOnStream(*stream_).ok()) { std::move(session_error_callback_)(MoqtError::kInternalError, "Failed to send FIN"); }
diff --git a/quiche/quic/moqt/moqt_error.h b/quiche/quic/moqt/moqt_error.h index cb5eba7..a9d6d86 100644 --- a/quiche/quic/moqt/moqt_error.h +++ b/quiche/quic/moqt/moqt_error.h
@@ -50,6 +50,7 @@ // TODO(martinduke): This is not in the spec, but is needed. The number might // change. inline constexpr webtransport::StreamErrorCode kResetCodeMalformedTrack = 0x04; +inline constexpr webtransport::StreamErrorCode kResetCodeTooFarBehind = 0x05; // Used for SUBSCRIBE_ERROR, PUBLISH_NAMESPACE_ERROR, PUBLISH_NAMESPACE_CANCEL, // SUBSCRIBE_NAMESPACE_ERROR, and FETCH_ERROR. @@ -70,6 +71,7 @@ kMalformedTrack = 0x9, kMalformedAuthToken = 0x10, kExpiredAuthToken = 0x12, + kPrefixOverlap = 0x30, }; struct MoqtRequestErrorInfo {
diff --git a/quiche/quic/moqt/moqt_fetch_task.h b/quiche/quic/moqt/moqt_fetch_task.h index bfcf395..56391ab 100644 --- a/quiche/quic/moqt/moqt_fetch_task.h +++ b/quiche/quic/moqt/moqt_fetch_task.h
@@ -5,6 +5,7 @@ #ifndef QUICHE_QUIC_MOQT_MOQT_FETCH_TASK_H_ #define QUICHE_QUIC_MOQT_MOQT_FETCH_TASK_H_ +#include <cstdint> #include <optional> #include <string> #include <utility> @@ -13,22 +14,44 @@ #include "absl/status/status.h" #include "quiche/quic/moqt/moqt_error.h" #include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_object.h" #include "quiche/common/quiche_callbacks.h" +#include "quiche/web_transport/web_transport.h" namespace moqt { +// TODO(martinduke): There are will be multiple instances of flow-controlled +// "pull" data retrieval tasks. It might be worthwhile to extract some common +// features into a base class. + +using ObjectsAvailableCallback = quiche::MultiUseCallback<void()>; +// Potential results of a GetNextObject/GetNextMessage() call. +enum GetNextResult { + // The next object or message is available, and is placed into the reference + // specified by the caller. + kSuccess, + // The next object or message is not yet available (equivalent of EAGAIN). + kPending, + // The end of the response has been reached. + kEof, + // The request has failed; the error is available. + kError, +}; + +enum class TransactionType : uint8_t { kAdd, kDelete }; + // A handle representing a fetch in progress. The fetch in question can be // cancelled by deleting the object. class MoqtFetchTask { public: - using ObjectsAvailableCallback = quiche::MultiUseCallback<void()>; // The request_id field will be ignored. using FetchResponseCallback = quiche::SingleUseCallback<void( std::variant<MoqtFetchOk, MoqtRequestError>)>; virtual ~MoqtFetchTask() = default; + // TODO(martinduke): Replace with GetNextResult above. // Potential results of a GetNextObject() call. enum GetNextObjectResult { // The next object is available, and is placed into the reference specified @@ -65,6 +88,29 @@ virtual absl::Status GetStatus() = 0; }; +class MoqtNamespaceTask { + public: + virtual ~MoqtNamespaceTask() = default; + + // Returns the state of the message queue. If available, writes the suffix + // into |suffix|. If |type| is kAdd, it is from a NAMESPACE message. If |type| + // is kDelete, it is from a NAMESPACE_DONE message. + virtual GetNextResult GetNextSuffix(TrackNamespace& suffix, + TransactionType& type) = 0; + + // Sets the callback that is called when a NAMESPACE or NAMESPACE_DONE message + // is received. If a message is available immediately, the callback will be + // called immediately. + virtual void SetObjectAvailableCallback( + ObjectsAvailableCallback callback) = 0; + + // Returns the error if request has completely failed, and nullopt otherwise. + virtual std::optional<webtransport::StreamErrorCode> GetStatus() = 0; + + // Returns the prefix for this task. + virtual const TrackNamespace& prefix() = 0; +}; + // A fetch that starts out in the failed state. class MoqtFailedFetch : public MoqtFetchTask { public:
diff --git a/quiche/quic/moqt/moqt_namespace_stream.cc b/quiche/quic/moqt/moqt_namespace_stream.cc new file mode 100644 index 0000000..bbacc77 --- /dev/null +++ b/quiche/quic/moqt/moqt_namespace_stream.cc
@@ -0,0 +1,283 @@ +// Copyright (c) 2026 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/moqt/moqt_namespace_stream.h" + +#include <memory> +#include <optional> +#include <utility> + +#include "absl/base/nullability.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/moqt/moqt_bidi_stream.h" +#include "quiche/quic/moqt/moqt_error.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" +#include "quiche/quic/moqt/moqt_framer.h" +#include "quiche/quic/moqt/moqt_key_value_pair.h" +#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_names.h" +#include "quiche/quic/moqt/moqt_session_callbacks.h" +#include "quiche/quic/moqt/session_namespace_tree.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_stream.h" +#include "quiche/web_transport/web_transport.h" + +namespace moqt { + +MoqtNamespaceSubscriberStream::~MoqtNamespaceSubscriberStream() { + NamespaceTask* task = task_.GetIfAvailable(); + if (task != nullptr) { + task->DeclareEof(); + } +} + +void MoqtNamespaceSubscriberStream::set_stream( + webtransport::Stream* absl_nonnull stream) { + // TODO(martinduke): Set the priority for this stream. + MoqtBidiStreamBase::set_stream(stream); +} + +void MoqtNamespaceSubscriberStream::OnRequestOkMessage( + const MoqtRequestOk& message) { + if (message.request_id != request_id_) { + OnParsingError(MoqtError::kProtocolViolation, + "Unexpected request ID in response"); + return; + } + if (response_callback_ == nullptr) { + OnParsingError(MoqtError::kProtocolViolation, "Two responses"); + return; + } + std::move(response_callback_)(std::nullopt); + response_callback_ = nullptr; +} + +void MoqtNamespaceSubscriberStream::OnRequestErrorMessage( + const MoqtRequestError& message) { + if (message.request_id != request_id_) { + OnParsingError(MoqtError::kProtocolViolation, + "Unexpected request ID in error"); + return; + } + if (response_callback_ == nullptr) { + OnParsingError(MoqtError::kProtocolViolation, "Two responses"); + return; + } + std::move(response_callback_)(MoqtRequestErrorInfo{ + message.error_code, message.retry_interval, message.reason_phrase}); + response_callback_ = nullptr; +} + +void MoqtNamespaceSubscriberStream::OnNamespaceMessage( + const MoqtNamespace& message) { + if (response_callback_ != nullptr) { + OnParsingError(MoqtError::kProtocolViolation, + "First message must be REQUEST_OK or REQUEST_ERROR"); + return; + } + NamespaceTask* task = task_.GetIfAvailable(); + if (task == nullptr) { + // The application has already unsubscribed, and the stream has been reset. + // This is irrelevant. + return; + } + if (task->prefix().number_of_elements() + + message.track_namespace_suffix.number_of_elements() > + kMaxNamespaceElements) { + OnParsingError(MoqtError::kProtocolViolation, + "Too many namespace elements"); + return; + } + if (task->prefix().total_length() + + message.track_namespace_suffix.total_length() > + kMaxFullTrackNameSize) { + OnParsingError(MoqtError::kProtocolViolation, "Namespace too large"); + return; + } + auto [it, inserted] = + published_suffixes_.insert(message.track_namespace_suffix); + if (!inserted) { + OnParsingError(MoqtError::kProtocolViolation, + "Two NAMESPACE messages for the same track namespace"); + return; + } + task->AddPendingSuffix(message.track_namespace_suffix, TransactionType::kAdd); +} + +void MoqtNamespaceSubscriberStream::OnNamespaceDoneMessage( + const MoqtNamespaceDone& message) { + if (response_callback_ != nullptr) { + OnParsingError(MoqtError::kProtocolViolation, + "First message must be REQUEST_OK or REQUEST_ERROR"); + return; + } + NamespaceTask* task = task_.GetIfAvailable(); + if (task == nullptr) { + return; + } + if (published_suffixes_.erase(message.track_namespace_suffix) == 0) { + OnParsingError(MoqtError::kProtocolViolation, + "NAMESPACE_DONE with no active namespace"); + return; + } + task->AddPendingSuffix(message.track_namespace_suffix, + TransactionType::kDelete); +} + +std::unique_ptr<MoqtNamespaceTask> MoqtNamespaceSubscriberStream::CreateTask( + const TrackNamespace& prefix) { + auto task = std::make_unique<NamespaceTask>(this, prefix); + QUICHE_DCHECK(task != nullptr); + task_ = task->GetWeakPtr(); + QUICHE_DCHECK(task_.IsValid()); + return std::move(task); +} + +MoqtNamespaceSubscriberStream::NamespaceTask::~NamespaceTask() { + if (state_ != nullptr) { + state_->Reset(kResetCodeCanceled); + } +} + +GetNextResult MoqtNamespaceSubscriberStream::NamespaceTask::GetNextSuffix( + TrackNamespace& suffix, TransactionType& type) { + if (pending_suffixes_.empty()) { + if (error_.has_value()) { + return kError; + } + if (eof_) { + return kEof; + } + return kPending; + } + suffix = pending_suffixes_.front().suffix; + type = pending_suffixes_.front().type; + pending_suffixes_.pop_front(); + return kSuccess; +} + +void MoqtNamespaceSubscriberStream::NamespaceTask::SetObjectAvailableCallback( + ObjectsAvailableCallback callback) { + callback_ = std::move(callback); + if (!pending_suffixes_.empty() || eof_ || error_.has_value()) { + std::move(callback_)(); + } +} + +void MoqtNamespaceSubscriberStream::NamespaceTask::AddPendingSuffix( + TrackNamespace suffix, TransactionType type) { + if (pending_suffixes_.size() == kMaxPendingSuffixes) { + error_ = kResetCodeTooFarBehind; + if (state_ != nullptr) { + state_->Reset(kResetCodeTooFarBehind); + } + return; + } + pending_suffixes_.push_back(PendingSuffix{std::move(suffix), type}); + if (callback_ != nullptr) { + std::move(callback_)(); + } +} + +void MoqtNamespaceSubscriberStream::NamespaceTask::DeclareEof() { + if (eof_) { + return; + } + eof_ = true; + state_ = nullptr; + if (callback_ != nullptr) { + std::move(callback_)(); + } +} + +MoqtNamespacePublisherStream::MoqtNamespacePublisherStream( + MoqtFramer* framer, webtransport::Stream* stream, + SessionErrorCallback session_error_callback, SessionNamespaceTree& tree, + MoqtIncomingSubscribeNamespaceCallbackNew& application) + // No stream_deleted_callback because there's no state yet. + : MoqtBidiStreamBase( + framer, []() {}, std::move(session_error_callback)), + tree_(tree), + application_(application) { + // TODO(martinduke): Set the priority for this stream. + MoqtBidiStreamBase::set_stream(stream, MoqtMessageType::kSubscribeNamespace); +} + +MoqtNamespacePublisherStream::~MoqtNamespacePublisherStream() { + if (task_ != nullptr) { + // Could be null if the stream died early. + tree_.UnsubscribeNamespace(task_->prefix()); + } +} + +void MoqtNamespacePublisherStream::OnSubscribeNamespaceMessage( + const MoqtSubscribeNamespace& message) { + request_id_ = message.request_id; + if (!tree_.SubscribeNamespace(message.track_namespace_prefix)) { + SendRequestError(request_id_, RequestErrorCode::kPrefixOverlap, + std::nullopt, ""); + return; + } + QUICHE_DCHECK(task_ == nullptr); + task_ = application_(message.track_namespace_prefix, message.parameters, + [this](std::optional<MoqtRequestErrorInfo> error) { + if (error.has_value()) { + SendRequestError(request_id_, *error, /*fin=*/true); + } else { + SendRequestOk(request_id_, MessageParameters()); + } + }); + if (task_ == nullptr) { + return; + } + task_->SetObjectAvailableCallback([this]() { + if (task_ == nullptr) { + return; + } + TrackNamespace suffix; + TransactionType type; + for (;;) { + GetNextResult result = task_->GetNextSuffix(suffix, type); + switch (result) { + case kPending: + return; + case kEof: + if (!SendFinOnStream(*stream()).ok()) { + OnParsingError(MoqtError::kInternalError, "Failed to send FIN"); + }; + return; + case kError: + Reset(kResetCodeCanceled); + return; + case kSuccess: { + switch (type) { + case TransactionType::kAdd: { + auto [it, inserted] = published_suffixes_.insert(suffix); + if (!inserted) { + // This should never happen. Do not send something that would + // cause a protocol violation. + return; + } + SendOrBufferMessage( + framer_->SerializeNamespace(MoqtNamespace{suffix})); + break; + } + case TransactionType::kDelete: { + if (published_suffixes_.erase(suffix) == 0) { + // This should never happen. Do not send something that would + // cause a protocol violation. + return; + } + SendOrBufferMessage( + framer_->SerializeNamespaceDone(MoqtNamespaceDone{suffix})); + break; + } + } + } + } + } + }); +} + +} // namespace moqt
diff --git a/quiche/quic/moqt/moqt_namespace_stream.h b/quiche/quic/moqt/moqt_namespace_stream.h new file mode 100644 index 0000000..903694c --- /dev/null +++ b/quiche/quic/moqt/moqt_namespace_stream.h
@@ -0,0 +1,137 @@ +// Copyright (c) 2026 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_MOQT_MOQT_NAMESPACE_STREAM_H_ +#define QUICHE_QUIC_MOQT_MOQT_NAMESPACE_STREAM_H_ + +#include <cstddef> +#include <cstdint> +#include <memory> +#include <optional> +#include <utility> + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_set.h" +#include "quiche/quic/moqt/moqt_bidi_stream.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" +#include "quiche/quic/moqt/moqt_framer.h" +#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_names.h" +#include "quiche/quic/moqt/moqt_session_callbacks.h" +#include "quiche/quic/moqt/session_namespace_tree.h" +#include "quiche/common/quiche_circular_deque.h" +#include "quiche/common/quiche_weak_ptr.h" +#include "quiche/web_transport/web_transport.h" + +namespace moqt { + +// This class will be owned by the webtransport stream. +class MoqtNamespaceSubscriberStream : public MoqtBidiStreamBase { + public: + // Assumes the caller will send or queue the SUBSCRIBE_NAMESPACE. + MoqtNamespaceSubscriberStream( + MoqtFramer* framer, uint64_t request_id, + BidiStreamDeletedCallback stream_deleted_callback, + SessionErrorCallback session_error_callback, + MoqtResponseCallback response_callback) + : MoqtBidiStreamBase(framer, std::move(stream_deleted_callback), + std::move(session_error_callback)), + request_id_(request_id), + response_callback_(std::move(response_callback)) {} + ~MoqtNamespaceSubscriberStream() override; + + // MoqtBidiStreamBase overrides. + void set_stream(webtransport::Stream* absl_nonnull stream) override; + void OnRequestOkMessage(const MoqtRequestOk& message) override; + void OnRequestErrorMessage(const MoqtRequestError& message) override; + void OnNamespaceMessage(const MoqtNamespace& message) override; + void OnNamespaceDoneMessage(const MoqtNamespaceDone& message) override; + + // Send the prefix now so it is only stored in one place (the task). + std::unique_ptr<MoqtNamespaceTask> CreateTask(const TrackNamespace& prefix); + + private: + // The class that will be passed to the application to consume namespace + // information. Owned by the application. + class NamespaceTask : public MoqtNamespaceTask { + public: + NamespaceTask(MoqtNamespaceSubscriberStream* absl_nonnull state, + const TrackNamespace& prefix) + : MoqtNamespaceTask(), + prefix_(prefix), + state_(state), + weak_ptr_factory_(this) {} + ~NamespaceTask() override; + // MoqtNamespaceTask methods. A return value of kEof implies + // NAMESPACE_DONE for all outstanding namespaces. + GetNextResult GetNextSuffix(TrackNamespace& suffix, + TransactionType& type) override; + void SetObjectAvailableCallback(ObjectsAvailableCallback callback) override; + std::optional<webtransport::StreamErrorCode> GetStatus() override { + return error_; + } + const TrackNamespace& prefix() override { return prefix_; } + + // Queues a suffix corresponding to a NAMESPACE (if |type| is kAdd) or a + // NAMESPACE_DONE (if |type| is kDelete). + void AddPendingSuffix(TrackNamespace suffix, TransactionType type); + // The stream is closed, so no more NAMESPACE messages are forthcoming. + // This is an implicit NAMESPACE_DONE for all published namespaces. + void DeclareEof(); + quiche::QuicheWeakPtr<NamespaceTask> GetWeakPtr() { + return weak_ptr_factory_.Create(); + } + + private: + struct PendingSuffix { + TrackNamespace suffix; + TransactionType type; + }; + + static constexpr size_t kMaxPendingSuffixes = 100; + const TrackNamespace prefix_; + // Must be nonnull initially, will be nullptr if the stream is closed. + MoqtNamespaceSubscriberStream* state_; + quiche::QuicheCircularDeque<PendingSuffix> pending_suffixes_; + ObjectsAvailableCallback callback_; + std::optional<webtransport::StreamErrorCode> error_; + bool eof_ = false; + // Must be last. + quiche::QuicheWeakPtrFactory<NamespaceTask> weak_ptr_factory_; + }; + + const uint64_t request_id_; + MoqtResponseCallback response_callback_; + absl::flat_hash_set<TrackNamespace> published_suffixes_; + quiche::QuicheWeakPtr<NamespaceTask> task_; +}; + +class MoqtNamespacePublisherStream : public MoqtBidiStreamBase { + public: + // Constructor for the publisher side. + MoqtNamespacePublisherStream( + MoqtFramer* framer, webtransport::Stream* stream, + SessionErrorCallback session_error_callback, SessionNamespaceTree& tree, + MoqtIncomingSubscribeNamespaceCallbackNew& application); + ~MoqtNamespacePublisherStream() override; + + // MoqtBidiStreamBase overrides. + void OnSubscribeNamespaceMessage( + const MoqtSubscribeNamespace& message) override; + // TODO(martinduke): Implement this. + void OnSubscribeUpdateMessage(const MoqtSubscribeUpdate& message) override { + QUICHE_DLOG(INFO) << "Got SUBSCRIBE_UPDATE on Namespace stream"; + } + + private: + uint64_t request_id_; + SessionNamespaceTree& tree_; + MoqtIncomingSubscribeNamespaceCallbackNew& application_; + std::unique_ptr<MoqtNamespaceTask> task_; + absl::flat_hash_set<TrackNamespace> published_suffixes_; +}; + +} // namespace moqt + +#endif // QUICHE_QUIC_MOQT_MOQT_NAMESPACE_STREAM_H_
diff --git a/quiche/quic/moqt/moqt_namespace_stream_test.cc b/quiche/quic/moqt/moqt_namespace_stream_test.cc new file mode 100644 index 0000000..eb31193 --- /dev/null +++ b/quiche/quic/moqt/moqt_namespace_stream_test.cc
@@ -0,0 +1,341 @@ +// Copyright (c) 2026 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/moqt/moqt_namespace_stream.h" + +#include <cstdint> +#include <memory> +#include <optional> +#include <utility> + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/moqt/moqt_error.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" +#include "quiche/quic/moqt/moqt_framer.h" +#include "quiche/quic/moqt/moqt_key_value_pair.h" +#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_names.h" +#include "quiche/quic/moqt/moqt_session_callbacks.h" +#include "quiche/quic/moqt/session_namespace_tree.h" +#include "quiche/quic/moqt/test_tools/moqt_framer_utils.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_mem_slice.h" +#include "quiche/common/quiche_stream.h" +#include "quiche/web_transport/test_tools/mock_web_transport.h" +#include "quiche/web_transport/web_transport.h" + +namespace moqt::test { +namespace { + +using ::testing::_; +using ::testing::Eq; +using ::testing::Return; + +constexpr uint64_t kRequestId = 3; +const TrackNamespace kPrefix({"foo"}); + +class MockNamespaceTask : public MoqtNamespaceTask { + public: + MockNamespaceTask(TrackNamespace& prefix) : prefix_(prefix) {} + MOCK_METHOD(GetNextResult, GetNextSuffix, + (TrackNamespace & whole_namespace, TransactionType& type), + (override)); + MOCK_METHOD(void, SetObjectAvailableCallback, + (ObjectsAvailableCallback callback), (override)); + MOCK_METHOD(std::optional<webtransport::StreamErrorCode>, GetStatus, (), + (override)); + const TrackNamespace& prefix() override { return prefix_; } + + private: + TrackNamespace prefix_; +}; + +class MoqtNamespaceSubscriberStreamTest : public quiche::test::QuicheTest { + public: + MoqtNamespaceSubscriberStreamTest() + : framer_(true), + stream_(&framer_, kRequestId, deleted_callback_.AsStdFunction(), + error_callback_.AsStdFunction(), + response_callback_.AsStdFunction()), + task_(stream_.CreateTask(kPrefix)) { + stream_.set_stream(&mock_stream_); + } + + MoqtFramer framer_; + testing::MockFunction<void()> deleted_callback_; + testing::MockFunction<void(MoqtError, absl::string_view)> error_callback_; + testing::MockFunction<void(std::optional<MoqtRequestErrorInfo>)> + response_callback_; + webtransport::test::MockStream mock_stream_; + MoqtNamespaceSubscriberStream stream_; + std::unique_ptr<MoqtNamespaceTask> task_ = stream_.CreateTask(kPrefix); +}; + +TEST_F(MoqtNamespaceSubscriberStreamTest, RequestOk) { + EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); + stream_.OnRequestOkMessage({kRequestId}); +} + +TEST_F(MoqtNamespaceSubscriberStreamTest, RequestOkWrongId) { + EXPECT_CALL(error_callback_, Call(MoqtError::kProtocolViolation, + "Unexpected request ID in response")); + stream_.OnRequestOkMessage({kRequestId + 1}); +} + +TEST_F(MoqtNamespaceSubscriberStreamTest, RequestError) { + EXPECT_CALL(response_callback_, Call); + stream_.OnRequestErrorMessage({kRequestId, RequestErrorCode::kInternalError, + quic::QuicTimeDelta::FromMilliseconds(100), + "bar"}); +} + +TEST_F(MoqtNamespaceSubscriberStreamTest, RequestErrorWrongId) { + EXPECT_CALL(error_callback_, Call(MoqtError::kProtocolViolation, + "Unexpected request ID in error")); + stream_.OnRequestErrorMessage( + {kRequestId + 1, RequestErrorCode::kInternalError, + quic::QuicTimeDelta::FromMilliseconds(100), "bar"}); +} + +TEST_F(MoqtNamespaceSubscriberStreamTest, NamespaceBeforeResponse) { + EXPECT_CALL(error_callback_, + Call(MoqtError::kProtocolViolation, + "First message must be REQUEST_OK or REQUEST_ERROR")); + stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); +} + +TEST_F(MoqtNamespaceSubscriberStreamTest, NamespaceDoneBeforeResponse) { + EXPECT_CALL(error_callback_, + Call(MoqtError::kProtocolViolation, + "First message must be REQUEST_OK or REQUEST_ERROR")); + stream_.OnNamespaceDoneMessage({TrackNamespace({"bar"})}); +} + +TEST_F(MoqtNamespaceSubscriberStreamTest, NamespaceAfterResponse) { + EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); + stream_.OnRequestOkMessage({kRequestId}); + stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + TrackNamespace received_namespace; + TransactionType type; + EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kSuccess); + EXPECT_EQ(received_namespace, TrackNamespace({"bar"})); + EXPECT_EQ(type, TransactionType::kAdd); + EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kPending); +} + +TEST_F(MoqtNamespaceSubscriberStreamTest, NamespaceDoneAfterResponse) { + EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); + stream_.OnRequestOkMessage({kRequestId}); + stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + stream_.OnNamespaceDoneMessage({TrackNamespace({"bar"})}); + TrackNamespace received_namespace; + TransactionType type; + EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kSuccess); + EXPECT_EQ(received_namespace, TrackNamespace({"bar"})); + EXPECT_EQ(type, TransactionType::kAdd); + EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kSuccess); + EXPECT_EQ(received_namespace, TrackNamespace({"bar"})); + EXPECT_EQ(type, TransactionType::kDelete); + EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kPending); +} + +TEST_F(MoqtNamespaceSubscriberStreamTest, DuplicateNamespace) { + EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); + stream_.OnRequestOkMessage({kRequestId}); + stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + EXPECT_CALL(error_callback_, + Call(MoqtError::kProtocolViolation, + "Two NAMESPACE messages for the same track namespace")); + stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); +} + +TEST_F(MoqtNamespaceSubscriberStreamTest, NamespaceDoneWithoutNamespace) { + EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); + stream_.OnRequestOkMessage({kRequestId}); + EXPECT_CALL(error_callback_, Call(MoqtError::kProtocolViolation, + "NAMESPACE_DONE with no active namespace")); + stream_.OnNamespaceDoneMessage({TrackNamespace({"bar"})}); +} + +TEST_F(MoqtNamespaceSubscriberStreamTest, NamespaceDoneThenNamespace) { + EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); + stream_.OnRequestOkMessage({kRequestId}); + EXPECT_CALL(error_callback_, Call).Times(0); + stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + stream_.OnNamespaceDoneMessage({TrackNamespace({"bar"})}); + stream_.OnNamespaceMessage({TrackNamespace({"buzz"})}); +} + +TEST_F(MoqtNamespaceSubscriberStreamTest, TaskGetNextSuffix) { + EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); + stream_.OnRequestOkMessage({kRequestId}); + stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + stream_.OnNamespaceMessage({TrackNamespace({"buzz"})}); + stream_.OnNamespaceDoneMessage({TrackNamespace({"bar"})}); + TrackNamespace received_namespace; + TransactionType type; + bool object_available = false; + task_->SetObjectAvailableCallback([&]() { object_available = true; }); + EXPECT_TRUE(object_available); + EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kSuccess); + EXPECT_EQ(received_namespace, TrackNamespace({"bar"})); + EXPECT_EQ(type, TransactionType::kAdd); + EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kSuccess); + EXPECT_EQ(received_namespace, TrackNamespace({"buzz"})); + EXPECT_EQ(type, TransactionType::kAdd); + EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kSuccess); + EXPECT_EQ(received_namespace, TrackNamespace({"bar"})); + EXPECT_EQ(type, TransactionType::kDelete); + EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kPending); + object_available = false; + stream_.OnNamespaceMessage({TrackNamespace({"another"})}); + EXPECT_TRUE(object_available); + object_available = false; + EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kSuccess); + EXPECT_EQ(received_namespace, TrackNamespace({"another"})); + EXPECT_EQ(type, TransactionType::kAdd); + EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kPending); +} + +class MoqtNamespacePublisherStreamTest : public quiche::test::QuicheTest { + public: + MoqtNamespacePublisherStreamTest() + : framer_(false), + tree_(), + application_callback_(mock_application_.AsStdFunction()), + stream_(&framer_, &mock_stream_, error_callback_.AsStdFunction(), tree_, + application_callback_) { + EXPECT_CALL(mock_stream_, CanWrite()).WillRepeatedly(Return(true)); + } + + MoqtFramer framer_; + testing::MockFunction<void(MoqtError, absl::string_view)> error_callback_; + webtransport::test::MockStream mock_stream_; + SessionNamespaceTree tree_; + testing::MockFunction<std::unique_ptr<MoqtNamespaceTask>( + const TrackNamespace&, std::optional<MessageParameters>, + MoqtResponseCallback)> + mock_application_; + MoqtIncomingSubscribeNamespaceCallbackNew application_callback_; + MoqtNamespacePublisherStream stream_; +}; + +TEST_F(MoqtNamespacePublisherStreamTest, Subscribe) { + MoqtSubscribeNamespace message = { + kRequestId, + TrackNamespace({"foo"}), + SubscribeNamespaceOption::kNamespace, + MessageParameters(), + }; + ObjectsAvailableCallback callback; + MockNamespaceTask* task_ptr; + EXPECT_CALL(mock_application_, Call) + .WillOnce([&](const TrackNamespace&, std::optional<MessageParameters>, + MoqtResponseCallback response_callback) { + std::move(response_callback)(std::nullopt); + auto task = + std::make_unique<MockNamespaceTask>(message.track_namespace_prefix); + EXPECT_CALL(*task, SetObjectAvailableCallback) + .WillOnce([&](ObjectsAvailableCallback oa_callback) { + callback = std::move(oa_callback); + }); + task_ptr = task.get(); + return task; + }); + EXPECT_CALL(mock_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _)); + stream_.OnSubscribeNamespaceMessage(message); + EXPECT_EQ(task_ptr->prefix(), message.track_namespace_prefix); + + // Deliver NAMESPACE. + EXPECT_CALL(*task_ptr, GetNextSuffix) + .WillOnce([](TrackNamespace& ns, TransactionType& type) { + ns = TrackNamespace({"bar"}); + type = TransactionType::kAdd; + return kSuccess; + }) + .WillOnce([](TrackNamespace& ns, TransactionType& type) { + ns = TrackNamespace({"beef"}); + type = TransactionType::kAdd; + return kSuccess; + }) + .WillOnce(Return(kPending)); + EXPECT_CALL(mock_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kNamespace), _)) + .Times(2); + callback(); + + // Deliver NAMESPACE_DONE and FIN. + EXPECT_CALL(*task_ptr, GetNextSuffix) + .WillOnce([](TrackNamespace& ns, TransactionType& type) { + ns = TrackNamespace({"bar"}); + type = TransactionType::kDelete; + return kSuccess; + }) + .WillOnce([](TrackNamespace& ns, TransactionType& type) { return kEof; }); + EXPECT_CALL(mock_stream_, Writev) + .WillOnce([&](absl::Span<quiche::QuicheMemSlice> slices, + const quiche::StreamWriteOptions& options) { + EXPECT_EQ(slices.size(), 1); + EXPECT_EQ(slices[0].data()[0], + static_cast<uint8_t>(MoqtMessageType::kNamespaceDone)); + EXPECT_FALSE(options.send_fin()); + return absl::OkStatus(); + }) + .WillOnce([&](absl::Span<quiche::QuicheMemSlice> slices, + const quiche::StreamWriteOptions& options) { + EXPECT_EQ(slices.size(), 0); + EXPECT_TRUE(options.send_fin()); + return absl::OkStatus(); + }); + callback(); +} + +TEST_F(MoqtNamespacePublisherStreamTest, RequestError) { + MoqtSubscribeNamespace message = { + kRequestId, + TrackNamespace({"foo"}), + SubscribeNamespaceOption::kNamespace, + MessageParameters(), + }; + EXPECT_CALL(mock_application_, Call) + .WillOnce([&](const TrackNamespace&, std::optional<MessageParameters>, + MoqtResponseCallback response_callback) { + std::move(response_callback)(MoqtRequestErrorInfo{ + RequestErrorCode::kInternalError, + quic::QuicTimeDelta::FromMilliseconds(100), "bar"}); + auto task = + std::make_unique<MockNamespaceTask>(message.track_namespace_prefix); + return task; + }); + EXPECT_CALL(mock_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); + stream_.OnSubscribeNamespaceMessage(message); +} + +TEST_F(MoqtNamespacePublisherStreamTest, SubscribePrefixOverlap) { + MoqtSubscribeNamespace message = { + kRequestId, + TrackNamespace({"foo", "bar", "baz"}), + SubscribeNamespaceOption::kNamespace, + MessageParameters(), + }; + // The namespace tree already has a subscriber for a prefix of "foo". + tree_.SubscribeNamespace(TrackNamespace({"foo", "bar"})); + EXPECT_CALL(mock_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); + stream_.OnSubscribeNamespaceMessage(message); + // Try to subscribe to the parent. Also not allowed. + message.track_namespace_prefix.PopElement(); + message.track_namespace_prefix.PopElement(); + EXPECT_CALL(mock_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); + stream_.OnSubscribeNamespaceMessage(message); +} + +} // namespace +} // namespace moqt::test
diff --git a/quiche/quic/moqt/moqt_session_callbacks.h b/quiche/quic/moqt/moqt_session_callbacks.h index 7e1c610..cc477f8 100644 --- a/quiche/quic/moqt/moqt_session_callbacks.h +++ b/quiche/quic/moqt/moqt_session_callbacks.h
@@ -5,6 +5,7 @@ #ifndef QUICHE_QUIC_MOQT_MOQT_SESSION_CALLBACKS_H_ #define QUICHE_QUIC_MOQT_MOQT_SESSION_CALLBACKS_H_ +#include <memory> #include <optional> #include <utility> @@ -12,6 +13,7 @@ #include "quiche/quic/core/quic_clock.h" #include "quiche/quic/core/quic_default_clock.h" #include "quiche/quic/moqt/moqt_error.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_names.h" #include "quiche/common/quiche_callbacks.h" @@ -40,7 +42,8 @@ // Called whenever a PUBLISH_NAMESPACE or PUBLISH_NAMESPACE_DONE message is // received from the peer. PUBLISH_NAMESPACE sets a value for |parameters|, -// PUBLISH_NAMESPACE_DONE does not.. +// PUBLISH_NAMESPACE_DONE does not. This callback is not invoked by NAMESPACE or +// NAMESPACE_DONE messages that arrive on a SUBSCRIBE_NAMESPACE stream. using MoqtIncomingPublishNamespaceCallback = quiche::MultiUseCallback<void( const TrackNamespace& track_namespace, const std::optional<VersionSpecificParameters>& parameters, @@ -50,10 +53,16 @@ // the peer. SUBSCRIBE_NAMESPACE sets a value for |parameters|, // UNSUBSCRIBE_NAMESPACE does not. For UNSUBSCRIBE_NAMESPACE, |callback| is // null. +// TODO(martinduke): Remove this callback once the new one is in use. using MoqtIncomingSubscribeNamespaceCallback = quiche::MultiUseCallback<void(const TrackNamespace& track_namespace, std::optional<MessageParameters> parameters, MoqtResponseCallback callback)>; +using MoqtIncomingSubscribeNamespaceCallbackNew = + quiche::MultiUseCallback<std::unique_ptr<MoqtNamespaceTask>( + const TrackNamespace& prefix, + std::optional<MessageParameters> parameters, + MoqtResponseCallback callback)>; inline void DefaultIncomingPublishNamespaceCallback( const TrackNamespace&, const std::optional<VersionSpecificParameters>&, @@ -66,11 +75,21 @@ "This endpoint does not support incoming SUBSCRIBE_NAMESPACE messages"}); }; +// TODO(martinduke): Remove this callback once the new one is in use. inline void DefaultIncomingSubscribeNamespaceCallback( const TrackNamespace& track_namespace, std::optional<MessageParameters>, MoqtResponseCallback callback) { std::move(callback)(std::nullopt); } +inline std::unique_ptr<MoqtNamespaceTask> +DefaultIncomingSubscribeNamespaceCallbackNew( + const TrackNamespace& track_namespace, std::optional<MessageParameters>, + MoqtResponseCallback callback) { + std::move(callback)(MoqtRequestErrorInfo{RequestErrorCode::kNotSupported, + std::nullopt, + "This endpoint cannot publish."}); + return nullptr; +} // Callbacks for session-level events. struct MoqtSessionCallbacks {