Update MoqtSession and applications to use MoqtNamespaceStream. Required significant rework of RelayNamespaceTree. UNSUBSCRIBE_NAMESPACE now deleted; simply reset the namespace stream instead. I don't know what broke with fuzz test visibility, but this wouldn't build without the quic/BUILD change. PiperOrigin-RevId: 866645127
diff --git a/build/source_list.bzl b/build/source_list.bzl index 3ce12e5..c465704 100644 --- a/build/source_list.bzl +++ b/build/source_list.bzl
@@ -1633,6 +1633,7 @@ "quic/moqt/moqt_subscribe_windows.cc", "quic/moqt/moqt_trace_recorder.cc", "quic/moqt/moqt_track.cc", + "quic/moqt/relay_namespace_tree.cc", "quic/moqt/tools/chat_client.cc", "quic/moqt/tools/moq_chat.cc", "quic/moqt/tools/moqt_client.cc",
diff --git a/build/source_list.gni b/build/source_list.gni index 182019a..55fd86b 100644 --- a/build/source_list.gni +++ b/build/source_list.gni
@@ -1637,6 +1637,7 @@ "src/quiche/quic/moqt/moqt_subscribe_windows.cc", "src/quiche/quic/moqt/moqt_trace_recorder.cc", "src/quiche/quic/moqt/moqt_track.cc", + "src/quiche/quic/moqt/relay_namespace_tree.cc", "src/quiche/quic/moqt/tools/chat_client.cc", "src/quiche/quic/moqt/tools/moq_chat.cc", "src/quiche/quic/moqt/tools/moqt_client.cc",
diff --git a/build/source_list.json b/build/source_list.json index c4a622d..cd25bcd 100644 --- a/build/source_list.json +++ b/build/source_list.json
@@ -1636,6 +1636,7 @@ "quiche/quic/moqt/moqt_subscribe_windows.cc", "quiche/quic/moqt/moqt_trace_recorder.cc", "quiche/quic/moqt/moqt_track.cc", + "quiche/quic/moqt/relay_namespace_tree.cc", "quiche/quic/moqt/tools/chat_client.cc", "quiche/quic/moqt/tools/moq_chat.cc", "quiche/quic/moqt/tools/moqt_client.cc",
diff --git a/quiche/quic/moqt/moqt_bidi_stream.h b/quiche/quic/moqt/moqt_bidi_stream.h index 2568c52..be5451d 100644 --- a/quiche/quic/moqt/moqt_bidi_stream.h +++ b/quiche/quic/moqt/moqt_bidi_stream.h
@@ -120,10 +120,6 @@ const MoqtSubscribeNamespace& message) override { OnParsingError(wrong_message_error_, wrong_message_reason_); } - virtual void OnUnsubscribeNamespaceMessage( - const MoqtUnsubscribeNamespace& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } virtual void OnMaxRequestIdMessage(const MoqtMaxRequestId& message) override { OnParsingError(wrong_message_error_, wrong_message_reason_); }
diff --git a/quiche/quic/moqt/moqt_bidi_stream_test.cc b/quiche/quic/moqt/moqt_bidi_stream_test.cc index 83c92ea..c47a231 100644 --- a/quiche/quic/moqt/moqt_bidi_stream_test.cc +++ b/quiche/quic/moqt/moqt_bidi_stream_test.cc
@@ -148,13 +148,6 @@ EXPECT_CALL(error_callback_, Call(MoqtError::kProtocolViolation, "Message not allowed for this stream type")); - stream_->OnUnsubscribeNamespaceMessage(MoqtUnsubscribeNamespace{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); stream_->OnMaxRequestIdMessage(MoqtMaxRequestId{}); stream_ = std::make_unique<MoqtBidiStreamBase>( &framer_, deleted_callback_.AsStdFunction(),
diff --git a/quiche/quic/moqt/moqt_fetch_task.h b/quiche/quic/moqt/moqt_fetch_task.h index adc5615..f18d21f 100644 --- a/quiche/quic/moqt/moqt_fetch_task.h +++ b/quiche/quic/moqt/moqt_fetch_task.h
@@ -11,6 +11,7 @@ #include <utility> #include <variant> +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "quiche/quic/moqt/moqt_error.h" #include "quiche/quic/moqt/moqt_messages.h" @@ -92,6 +93,10 @@ public: virtual ~MoqtNamespaceTask() = default; + // The provided callback may be immediately invoked. + virtual void SetObjectsAvailableCallback(ObjectsAvailableCallback + absl_nullable callback) = 0; + // Returns the state of the message queue. If available, writes the suffix // into |suffix|. If |type| is kAdd, it is from a NAMESPACE message. If |type| // is kDelete, it is from a NAMESPACE_DONE message.
diff --git a/quiche/quic/moqt/moqt_framer.cc b/quiche/quic/moqt/moqt_framer.cc index 95197a8..508e72d 100644 --- a/quiche/quic/moqt/moqt_framer.cc +++ b/quiche/quic/moqt/moqt_framer.cc
@@ -642,12 +642,6 @@ WireKeyValuePairList(message.parameters.ToKeyValuePairList())); } -quiche::QuicheBuffer MoqtFramer::SerializeUnsubscribeNamespace( - const MoqtUnsubscribeNamespace& message) { - return SerializeControlMessage(MoqtMessageType::kUnsubscribeNamespace, - WireTrackNamespace(message.track_namespace)); -} - quiche::QuicheBuffer MoqtFramer::SerializeMaxRequestId( const MoqtMaxRequestId& message) { return SerializeControlMessage(MoqtMessageType::kMaxRequestId,
diff --git a/quiche/quic/moqt/moqt_framer.h b/quiche/quic/moqt/moqt_framer.h index b5ff0ad..9d3e797 100644 --- a/quiche/quic/moqt/moqt_framer.h +++ b/quiche/quic/moqt/moqt_framer.h
@@ -69,8 +69,6 @@ quiche::QuicheBuffer SerializeGoAway(const MoqtGoAway& message); quiche::QuicheBuffer SerializeSubscribeNamespace( const MoqtSubscribeNamespace& message); - quiche::QuicheBuffer SerializeUnsubscribeNamespace( - const MoqtUnsubscribeNamespace& message); quiche::QuicheBuffer SerializeMaxRequestId(const MoqtMaxRequestId& message); quiche::QuicheBuffer SerializeFetch(const MoqtFetch& message); quiche::QuicheBuffer SerializeFetchCancel(const MoqtFetchCancel& message);
diff --git a/quiche/quic/moqt/moqt_framer_test.cc b/quiche/quic/moqt/moqt_framer_test.cc index acdc70b..b90cf2a 100644 --- a/quiche/quic/moqt/moqt_framer_test.cc +++ b/quiche/quic/moqt/moqt_framer_test.cc
@@ -50,7 +50,6 @@ MoqtMessageType::kTrackStatus, MoqtMessageType::kGoAway, MoqtMessageType::kSubscribeNamespace, - MoqtMessageType::kUnsubscribeNamespace, MoqtMessageType::kMaxRequestId, MoqtMessageType::kFetch, MoqtMessageType::kFetchCancel, @@ -176,10 +175,6 @@ auto data = std::get<MoqtSubscribeNamespace>(structured_data); return framer_.SerializeSubscribeNamespace(data); } - case moqt::MoqtMessageType::kUnsubscribeNamespace: { - auto data = std::get<MoqtUnsubscribeNamespace>(structured_data); - return framer_.SerializeUnsubscribeNamespace(data); - } case moqt::MoqtMessageType::kMaxRequestId: { auto data = std::get<MoqtMaxRequestId>(structured_data); return framer_.SerializeMaxRequestId(data);
diff --git a/quiche/quic/moqt/moqt_messages.cc b/quiche/quic/moqt/moqt_messages.cc index a64ff21..ae97049 100644 --- a/quiche/quic/moqt/moqt_messages.cc +++ b/quiche/quic/moqt/moqt_messages.cc
@@ -133,8 +133,6 @@ return "GOAWAY"; case MoqtMessageType::kSubscribeNamespace: return "SUBSCRIBE_NAMESPACE"; - case MoqtMessageType::kUnsubscribeNamespace: - return "UNSUBSCRIBE_NAMESPACE"; case MoqtMessageType::kMaxRequestId: return "MAX_REQUEST_ID"; case MoqtMessageType::kPublish:
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h index 3f006d5..ece3be9 100644 --- a/quiche/quic/moqt/moqt_messages.h +++ b/quiche/quic/moqt/moqt_messages.h
@@ -258,7 +258,6 @@ kNamespaceDone = 0x0e, kGoAway = 0x10, kSubscribeNamespace = 0x11, - kUnsubscribeNamespace = 0x14, kMaxRequestId = 0x15, kFetch = 0x16, kFetchCancel = 0x17, @@ -446,11 +445,6 @@ MessageParameters parameters; }; -// TODO(martinduke): Delete this -struct QUICHE_EXPORT MoqtUnsubscribeNamespace { - TrackNamespace track_namespace; -}; - struct QUICHE_EXPORT MoqtNamespace { TrackNamespace track_namespace_suffix; };
diff --git a/quiche/quic/moqt/moqt_names.cc b/quiche/quic/moqt/moqt_names.cc index 510e740..93b7eea 100644 --- a/quiche/quic/moqt/moqt_names.cc +++ b/quiche/quic/moqt/moqt_names.cc
@@ -36,6 +36,25 @@ } } +TrackNamespace::TrackNamespace(absl::Span<const std::string> elements) + : tuple_(elements.begin(), elements.end()) { + if (std::size(elements) > kMaxNamespaceElements) { + tuple_.clear(); + QUICHE_BUG(Moqt_namespace_too_large_01) + << "Constructing a namespace that is too large."; + return; + } + for (const auto& it : elements) { + length_ += it.size(); + if (length_ > kMaxFullTrackNameSize) { + tuple_.clear(); + QUICHE_BUG(Moqt_namespace_too_large_02) + << "Constructing a namespace that is too large."; + return; + } + } +} + bool TrackNamespace::InNamespace(const TrackNamespace& other) const { if (tuple_.size() < other.tuple_.size()) { return false;
diff --git a/quiche/quic/moqt/moqt_names.h b/quiche/quic/moqt/moqt_names.h index 0102eed..37a7fb2 100644 --- a/quiche/quic/moqt/moqt_names.h +++ b/quiche/quic/moqt/moqt_names.h
@@ -13,6 +13,8 @@ #include <string> #include <vector> +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -26,6 +28,7 @@ class TrackNamespace { public: explicit TrackNamespace(absl::Span<const absl::string_view> elements); + explicit TrackNamespace(absl::Span<const std::string> elements); explicit TrackNamespace( std::initializer_list<const absl::string_view> elements) : TrackNamespace(absl::Span<const absl::string_view>( @@ -46,13 +49,32 @@ } void AddElement(absl::string_view element); bool PopElement() { - if (tuple_.size() == 1) { + if (tuple_.size() == 0) { return false; } length_ -= tuple_.back().length(); tuple_.pop_back(); return true; } + absl::StatusOr<TrackNamespace> AddSuffix(const TrackNamespace& suffix) const { + TrackNamespace result = *this; + result.tuple_.reserve(tuple_.size() + suffix.tuple_.size()); + for (const auto& element : suffix.tuple()) { + if (!result.CanAddElement(element)) { + return absl::OutOfRangeError("Combined namespace is too large"); + } + result.AddElement(element); + } + return result; + } + absl::StatusOr<TrackNamespace> ExtractSuffix( + const TrackNamespace& prefix) const { + if (!InNamespace(prefix)) { + return absl::InvalidArgumentError("Prefix is not in namespace"); + } + return TrackNamespace( + absl::MakeSpan(tuple_).subspan(prefix.number_of_elements())); + } std::string ToString() const; // Returns the number of elements in the tuple. size_t number_of_elements() const { return tuple_.size(); }
diff --git a/quiche/quic/moqt/moqt_names_test.cc b/quiche/quic/moqt/moqt_names_test.cc index cffdaf0..2d843ee 100644 --- a/quiche/quic/moqt/moqt_names_test.cc +++ b/quiche/quic/moqt/moqt_names_test.cc
@@ -7,9 +7,11 @@ #include <vector> #include "absl/hash/hash.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "quiche/common/platform/api/quiche_expect_bug.h" #include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/test_tools/quiche_test_utils.h" namespace moqt::test { namespace { @@ -49,6 +51,8 @@ EXPECT_FALSE(original.InNamespace(name)); EXPECT_TRUE(name.PopElement()); EXPECT_EQ(name, original); + EXPECT_TRUE(name.PopElement()); + EXPECT_EQ(name.number_of_elements(), 0); EXPECT_FALSE(name.PopElement()); } @@ -65,6 +69,22 @@ EXPECT_EQ(name1.ToString(), R"({"a"::"b"}::c)"); } +TEST(MoqtNamesTest, TrackNamespaceSuffixes) { + using quiche::test::IsOkAndHolds; + TrackNamespace name1({"a", "b"}); + TrackNamespace name2({"c", "d"}); + TrackNamespace name3({"a", "b", "c", "d"}); + EXPECT_THAT(name1.AddSuffix(name2), IsOkAndHolds(name3)); + EXPECT_THAT(name3.ExtractSuffix(name1), IsOkAndHolds(name2)); + EXPECT_THAT(name1.AddSuffix(TrackNamespace()), IsOkAndHolds(name1)); + EXPECT_THAT(TrackNamespace().AddSuffix(name1), IsOkAndHolds(name1)); + EXPECT_THAT(name1.ExtractSuffix(TrackNamespace()), IsOkAndHolds(name1)); + EXPECT_THAT(name1.ExtractSuffix(name1), IsOkAndHolds(TrackNamespace())); + TrackNamespace name4({"c", "b"}); + EXPECT_EQ(name1.ExtractSuffix(name4).status(), + absl::InvalidArgumentError("Prefix is not in namespace")); +} + TEST(MoqtNamesTest, TooManyNamespaceElements) { // 32 elements work. TrackNamespace name1({"a", "b", "c", "d", "e", "f", "g", "h",
diff --git a/quiche/quic/moqt/moqt_namespace_stream.cc b/quiche/quic/moqt/moqt_namespace_stream.cc index 1fe45cf..ca59fd6 100644 --- a/quiche/quic/moqt/moqt_namespace_stream.cc +++ b/quiche/quic/moqt/moqt_namespace_stream.cc
@@ -102,6 +102,8 @@ "Two NAMESPACE messages for the same track namespace"); return; } + QUIC_DLOG(INFO) << "Received NAMESPACE message for " + << message.track_namespace_suffix; task->AddPendingSuffix(message.track_namespace_suffix, TransactionType::kAdd); } @@ -121,15 +123,15 @@ "NAMESPACE_DONE with no active namespace"); return; } + QUIC_DLOG(INFO) << "Received NAMESPACE_DONE message for " + << message.track_namespace_suffix; task->AddPendingSuffix(message.track_namespace_suffix, TransactionType::kDelete); } std::unique_ptr<MoqtNamespaceTask> MoqtNamespaceSubscriberStream::CreateTask( - const TrackNamespace& prefix, - ObjectsAvailableCallback absl_nonnull callback) { - auto task = - std::make_unique<NamespaceTask>(this, prefix, std::move(callback)); + const TrackNamespace& prefix) { + auto task = std::make_unique<NamespaceTask>(this, prefix); QUICHE_DCHECK(task != nullptr); task_ = task->GetWeakPtr(); QUICHE_DCHECK(task_.IsValid()); @@ -142,6 +144,14 @@ } } +void MoqtNamespaceSubscriberStream::NamespaceTask::SetObjectsAvailableCallback( + ObjectsAvailableCallback absl_nullable callback) { + callback_ = std::move(callback); + if (!pending_suffixes_.empty() && callback_ != nullptr) { + callback_(); + } +} + GetNextResult MoqtNamespaceSubscriberStream::NamespaceTask::GetNextSuffix( TrackNamespace& suffix, TransactionType& type) { if (pending_suffixes_.empty()) { @@ -187,45 +197,57 @@ MoqtNamespacePublisherStream::MoqtNamespacePublisherStream( MoqtFramer* framer, webtransport::Stream* stream, - SessionErrorCallback session_error_callback, SessionNamespaceTree& tree, - MoqtIncomingSubscribeNamespaceCallbackNew& application) + SessionErrorCallback session_error_callback, + SessionNamespaceTree* absl_nonnull tree, + MoqtIncomingSubscribeNamespaceCallback& application) // No stream_deleted_callback because there's no state yet. : MoqtBidiStreamBase( framer, []() {}, std::move(session_error_callback)), - tree_(tree), + tree_(tree->GetWeakPtr()), application_(application) { // TODO(martinduke): Set the priority for this stream. MoqtBidiStreamBase::set_stream(stream, MoqtMessageType::kSubscribeNamespace); } MoqtNamespacePublisherStream::~MoqtNamespacePublisherStream() { - if (task_ != nullptr) { + if (task_ == nullptr) { + return; + } + SessionNamespaceTree* tree = tree_.GetIfAvailable(); + if (tree != nullptr) { // Could be null if the stream died early. - tree_.UnsubscribeNamespace(task_->prefix()); + tree->UnsubscribeNamespace(task_->prefix()); } } void MoqtNamespacePublisherStream::OnSubscribeNamespaceMessage( const MoqtSubscribeNamespace& message) { request_id_ = message.request_id; - if (!tree_.SubscribeNamespace(message.track_namespace_prefix)) { + SessionNamespaceTree* tree = tree_.GetIfAvailable(); + if (tree == nullptr) { + SendRequestError(request_id_, RequestErrorCode::kInternalError, + std::nullopt, "Session is gone", /*fin=*/true); + return; + } + if (!tree->SubscribeNamespace(message.track_namespace_prefix)) { SendRequestError(request_id_, RequestErrorCode::kPrefixOverlap, - std::nullopt, ""); + std::nullopt, "", /*fin=*/true); return; } QUICHE_DCHECK(task_ == nullptr); - task_ = application_( - message.track_namespace_prefix, message.parameters, - // Response callback - [this](std::optional<MoqtRequestErrorInfo> error) { - if (error.has_value()) { - SendRequestError(request_id_, *error, /*fin=*/true); - } else { - SendRequestOk(request_id_, MessageParameters()); - } - }, - // Objects available callback - [this]() { ProcessNamespaces(); }); + task_ = application_(message.track_namespace_prefix, + message.subscribe_options, message.parameters, + // Response callback + [this](std::optional<MoqtRequestErrorInfo> error) { + if (error.has_value()) { + SendRequestError(request_id_, *error, /*fin=*/true); + } else { + SendRequestOk(request_id_, MessageParameters()); + } + }); + if (task_ != nullptr) { + task_->SetObjectsAvailableCallback([this]() { ProcessNamespaces(); }); + } } void MoqtNamespacePublisherStream::ProcessNamespaces() {
diff --git a/quiche/quic/moqt/moqt_namespace_stream.h b/quiche/quic/moqt/moqt_namespace_stream.h index 919a4d9..94e9aea 100644 --- a/quiche/quic/moqt/moqt_namespace_stream.h +++ b/quiche/quic/moqt/moqt_namespace_stream.h
@@ -49,9 +49,7 @@ void OnNamespaceDoneMessage(const MoqtNamespaceDone& message) override; // Send the prefix now so it is only stored in one place (the task). - std::unique_ptr<MoqtNamespaceTask> CreateTask(const TrackNamespace& prefix, - ObjectsAvailableCallback - absl_nonnull callback); + std::unique_ptr<MoqtNamespaceTask> CreateTask(const TrackNamespace& prefix); private: // The class that will be passed to the application to consume namespace @@ -59,14 +57,16 @@ class NamespaceTask : public MoqtNamespaceTask { public: NamespaceTask(MoqtNamespaceSubscriberStream* absl_nonnull state, - const TrackNamespace& prefix, - ObjectsAvailableCallback absl_nonnull callback) + const TrackNamespace& prefix) : MoqtNamespaceTask(), prefix_(prefix), state_(state), - callback_(std::move(callback)), weak_ptr_factory_(this) {} ~NamespaceTask() override; + + void SetObjectsAvailableCallback(ObjectsAvailableCallback + absl_nullable callback) override; + // MoqtNamespaceTask methods. A return value of kEof implies // NAMESPACE_DONE for all outstanding namespaces. GetNextResult GetNextSuffix(TrackNamespace& suffix, @@ -97,7 +97,7 @@ // Must be nonnull initially, will be nullptr if the stream is closed. MoqtNamespaceSubscriberStream* state_; quiche::QuicheCircularDeque<PendingSuffix> pending_suffixes_; - ObjectsAvailableCallback callback_; + ObjectsAvailableCallback absl_nullable callback_ = nullptr; std::optional<webtransport::StreamErrorCode> error_; bool eof_ = false; // Must be last. @@ -115,15 +115,16 @@ // Constructor for the publisher side. MoqtNamespacePublisherStream( MoqtFramer* framer, webtransport::Stream* stream, - SessionErrorCallback session_error_callback, SessionNamespaceTree& tree, - MoqtIncomingSubscribeNamespaceCallbackNew& application); + SessionErrorCallback session_error_callback, + SessionNamespaceTree* absl_nonnull tree, + MoqtIncomingSubscribeNamespaceCallback& application); ~MoqtNamespacePublisherStream() override; // MoqtBidiStreamBase overrides. void OnSubscribeNamespaceMessage( const MoqtSubscribeNamespace& message) override; // TODO(martinduke): Implement this. - void OnSubscribeUpdateMessage(const MoqtSubscribeUpdate& message) override { + void OnSubscribeUpdateMessage(const MoqtSubscribeUpdate&) override { QUICHE_DLOG(INFO) << "Got SUBSCRIBE_UPDATE on Namespace stream"; } @@ -131,8 +132,8 @@ void ProcessNamespaces(); uint64_t request_id_; - SessionNamespaceTree& tree_; - MoqtIncomingSubscribeNamespaceCallbackNew& application_; + quiche::QuicheWeakPtr<SessionNamespaceTree> tree_; + MoqtIncomingSubscribeNamespaceCallback& application_; std::unique_ptr<MoqtNamespaceTask> task_; absl::flat_hash_set<TrackNamespace> published_suffixes_; };
diff --git a/quiche/quic/moqt/moqt_namespace_stream_test.cc b/quiche/quic/moqt/moqt_namespace_stream_test.cc index a49b626..97a0cb1 100644 --- a/quiche/quic/moqt/moqt_namespace_stream_test.cc +++ b/quiche/quic/moqt/moqt_namespace_stream_test.cc
@@ -22,11 +22,11 @@ #include "quiche/quic/moqt/moqt_session_callbacks.h" #include "quiche/quic/moqt/session_namespace_tree.h" #include "quiche/quic/moqt/test_tools/moqt_framer_utils.h" +#include "quiche/quic/moqt/test_tools/moqt_mock_visitor.h" #include "quiche/common/platform/api/quiche_test.h" #include "quiche/common/quiche_mem_slice.h" #include "quiche/common/quiche_stream.h" #include "quiche/web_transport/test_tools/mock_web_transport.h" -#include "quiche/web_transport/web_transport.h" namespace moqt::test { namespace { @@ -38,20 +38,6 @@ constexpr uint64_t kRequestId = 3; const TrackNamespace kPrefix({"foo"}); -class MockNamespaceTask : public MoqtNamespaceTask { - public: - MockNamespaceTask(TrackNamespace& prefix) : prefix_(prefix) {} - MOCK_METHOD(GetNextResult, GetNextSuffix, - (TrackNamespace & whole_namespace, TransactionType& type), - (override)); - MOCK_METHOD(std::optional<webtransport::StreamErrorCode>, GetStatus, (), - (override)); - const TrackNamespace& prefix() override { return prefix_; } - - private: - TrackNamespace prefix_; -}; - class MoqtNamespaceSubscriberStreamTest : public quiche::test::QuicheTest { public: MoqtNamespaceSubscriberStreamTest() @@ -59,7 +45,8 @@ stream_(&framer_, kRequestId, deleted_callback_.AsStdFunction(), error_callback_.AsStdFunction(), response_callback_.AsStdFunction()), - task_(stream_.CreateTask(kPrefix, [this]() { ++objects_available_; })) { + task_(stream_.CreateTask(kPrefix)) { + task_->SetObjectsAvailableCallback([this]() { ++objects_available_; }); stream_.set_stream(&mock_stream_); } @@ -213,8 +200,9 @@ auto stream = std::make_unique<MoqtNamespaceSubscriberStream>( &framer_, kRequestId, deleted_callback_.AsStdFunction(), error_callback_.AsStdFunction(), response_callback_.AsStdFunction()); - std::unique_ptr<MoqtNamespaceTask> task = - stream->CreateTask(kPrefix, [this]() { ++objects_available_; }); + std::unique_ptr<MoqtNamespaceTask> task = stream->CreateTask(kPrefix); + ASSERT_TRUE(task != nullptr); + task->SetObjectsAvailableCallback([this]() { ++objects_available_; }); EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); stream->OnRequestOkMessage({kRequestId}); stream->OnNamespaceMessage({TrackNamespace({"bar"})}); @@ -235,8 +223,8 @@ : framer_(false), tree_(), application_callback_(mock_application_.AsStdFunction()), - stream_(&framer_, &mock_stream_, error_callback_.AsStdFunction(), tree_, - application_callback_) { + stream_(&framer_, &mock_stream_, error_callback_.AsStdFunction(), + &tree_, application_callback_) { EXPECT_CALL(mock_stream_, CanWrite()).WillRepeatedly(Return(true)); } @@ -245,10 +233,10 @@ webtransport::test::MockStream mock_stream_; SessionNamespaceTree tree_; testing::MockFunction<std::unique_ptr<MoqtNamespaceTask>( - const TrackNamespace&, const MessageParameters&, MoqtResponseCallback, - ObjectsAvailableCallback)> + const TrackNamespace&, SubscribeNamespaceOption, const MessageParameters&, + MoqtResponseCallback)> mock_application_; - MoqtIncomingSubscribeNamespaceCallbackNew application_callback_; + MoqtIncomingSubscribeNamespaceCallback application_callback_; MoqtNamespacePublisherStream stream_; }; @@ -260,21 +248,21 @@ MessageParameters(), }; ObjectsAvailableCallback callback; - MockNamespaceTask* task_ptr; + MockNamespaceTask* task_ptr = nullptr; EXPECT_CALL(mock_application_, Call) - .WillOnce([&](const TrackNamespace&, const MessageParameters&, - MoqtResponseCallback response_callback, - ObjectsAvailableCallback available_callback) { + .WillOnce([&](const TrackNamespace&, SubscribeNamespaceOption, + const MessageParameters&, + MoqtResponseCallback response_callback) { std::move(response_callback)(std::nullopt); auto task = std::make_unique<MockNamespaceTask>(message.track_namespace_prefix); - callback = std::move(available_callback); task_ptr = task.get(); return task; }); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _)); stream_.OnSubscribeNamespaceMessage(message); + ASSERT_TRUE(task_ptr != nullptr); EXPECT_EQ(task_ptr->prefix(), message.track_namespace_prefix); // Deliver NAMESPACE. @@ -293,7 +281,7 @@ EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kNamespace), _)) .Times(2); - callback(); + task_ptr->InvokeCallback(); // Deliver NAMESPACE_DONE and FIN. EXPECT_CALL(*task_ptr, GetNextSuffix) @@ -318,7 +306,7 @@ EXPECT_TRUE(options.send_fin()); return absl::OkStatus(); }); - callback(); + task_ptr->InvokeCallback(); } TEST_F(MoqtNamespacePublisherStreamTest, RequestError) { @@ -329,9 +317,9 @@ MessageParameters(), }; EXPECT_CALL(mock_application_, Call) - .WillOnce([&](const TrackNamespace&, const MessageParameters&, - MoqtResponseCallback response_callback, - ObjectsAvailableCallback) { + .WillOnce([&](const TrackNamespace&, SubscribeNamespaceOption, + const MessageParameters&, + MoqtResponseCallback response_callback) { std::move(response_callback)(MoqtRequestErrorInfo{ RequestErrorCode::kInternalError, quic::QuicTimeDelta::FromMilliseconds(100), "bar"});
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc index a55455d..d0d2258 100644 --- a/quiche/quic/moqt/moqt_parser.cc +++ b/quiche/quic/moqt/moqt_parser.cc
@@ -617,9 +617,6 @@ case MoqtMessageType::kSubscribeNamespace: bytes_read = ProcessSubscribeNamespace(reader); break; - case MoqtMessageType::kUnsubscribeNamespace: - bytes_read = ProcessUnsubscribeNamespace(reader); - break; case MoqtMessageType::kMaxRequestId: bytes_read = ProcessMaxRequestId(reader); break; @@ -922,16 +919,6 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtControlParser::ProcessUnsubscribeNamespace( - quic::QuicDataReader& reader) { - MoqtUnsubscribeNamespace unsubscribe_namespace; - if (!ReadTrackNamespace(reader, unsubscribe_namespace.track_namespace)) { - return 0; - } - visitor_.OnUnsubscribeNamespaceMessage(unsubscribe_namespace); - return reader.PreviouslyReadPayload().length(); -} - size_t MoqtControlParser::ProcessMaxRequestId(quic::QuicDataReader& reader) { MoqtMaxRequestId max_request_id; if (!reader.ReadVarInt62(&max_request_id.max_request_id)) {
diff --git a/quiche/quic/moqt/moqt_parser.h b/quiche/quic/moqt/moqt_parser.h index 34f6d41..a12bb5f 100644 --- a/quiche/quic/moqt/moqt_parser.h +++ b/quiche/quic/moqt/moqt_parser.h
@@ -57,8 +57,6 @@ virtual void OnGoAwayMessage(const MoqtGoAway& message) = 0; virtual void OnSubscribeNamespaceMessage( const MoqtSubscribeNamespace& message) = 0; - virtual void OnUnsubscribeNamespaceMessage( - const MoqtUnsubscribeNamespace& message) = 0; virtual void OnMaxRequestIdMessage(const MoqtMaxRequestId& message) = 0; virtual void OnFetchMessage(const MoqtFetch& message) = 0; virtual void OnFetchCancelMessage(const MoqtFetchCancel& message) = 0;
diff --git a/quiche/quic/moqt/moqt_parser_test.cc b/quiche/quic/moqt/moqt_parser_test.cc index 3a257ac..d422b86 100644 --- a/quiche/quic/moqt/moqt_parser_test.cc +++ b/quiche/quic/moqt/moqt_parser_test.cc
@@ -51,7 +51,6 @@ MoqtMessageType::kServerSetup, MoqtMessageType::kGoAway, MoqtMessageType::kSubscribeNamespace, - MoqtMessageType::kUnsubscribeNamespace, MoqtMessageType::kMaxRequestId, MoqtMessageType::kFetch, MoqtMessageType::kFetchCancel,
diff --git a/quiche/quic/moqt/moqt_relay_publisher.cc b/quiche/quic/moqt/moqt_relay_publisher.cc index 2be7deb..478341c 100644 --- a/quiche/quic/moqt/moqt_relay_publisher.cc +++ b/quiche/quic/moqt/moqt_relay_publisher.cc
@@ -66,7 +66,8 @@ void MoqtRelayPublisher::OnPublishNamespace( const TrackNamespace& track_namespace, const VersionSpecificParameters& /*parameters*/, - MoqtSessionInterface* session, MoqtResponseCallback callback) { + MoqtSessionInterface* session, + MoqtResponseCallback absl_nullable callback) { if (session == nullptr) { return; } @@ -74,7 +75,9 @@ namespace_publishers_.AddPublisher(track_namespace, session); // TODO(martinduke): Notify subscribers listening for this namespace. // Send PUBLISH_NAMESPACE_OK. - std::move(callback)(std::nullopt); + if (callback != nullptr) { + std::move(callback)(std::nullopt); + } } void MoqtRelayPublisher::OnPublishNamespaceDone(
diff --git a/quiche/quic/moqt/moqt_relay_publisher.h b/quiche/quic/moqt/moqt_relay_publisher.h index aeb2109..37b1fcf 100644 --- a/quiche/quic/moqt/moqt_relay_publisher.h +++ b/quiche/quic/moqt/moqt_relay_publisher.h
@@ -6,10 +6,13 @@ #define QUICHE_QUIC_MOQT_MOQT_RELAY_PUBLISHER_H_ #include <memory> +#include <utility> #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" -#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" +#include "quiche/quic/moqt/moqt_key_value_pair.h" +#include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_publisher.h" #include "quiche/quic/moqt/moqt_relay_track_publisher.h" #include "quiche/quic/moqt/moqt_session_callbacks.h" @@ -33,13 +36,10 @@ absl_nullable std::shared_ptr<MoqtTrackPublisher> GetTrack( const FullTrackName& track_name) override; - void AddNamespaceSubscriber(const TrackNamespace& track_namespace, - MoqtSessionInterface* session) { - namespace_publishers_.AddSubscriber(track_namespace, session); - } - void RemoveNamespaceSubscriber(const TrackNamespace& track_namespace, - MoqtSessionInterface* session) { - namespace_publishers_.RemoveSubscriber(track_namespace, session); + std::unique_ptr<MoqtNamespaceTask> AddNamespaceSubscriber( + const TrackNamespace& track_namespace, + MoqtSessionInterface* absl_nullable session) { + return namespace_publishers_.AddSubscriber(track_namespace, session); } // There is a new default upstream session. When there is no other namespace @@ -55,7 +55,7 @@ void OnPublishNamespace(const TrackNamespace& track_namespace, const VersionSpecificParameters& parameters, MoqtSessionInterface* session, - MoqtResponseCallback callback); + MoqtResponseCallback absl_nullable callback); void OnPublishNamespaceDone(const TrackNamespace& track_namespace, MoqtSessionInterface* session);
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index d42810e..32a0c45 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -36,6 +36,7 @@ #include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" +#include "quiche/quic/moqt/moqt_namespace_stream.h" #include "quiche/quic/moqt/moqt_object.h" #include "quiche/quic/moqt/moqt_parser.h" #include "quiche/quic/moqt/moqt_priority.h" @@ -257,15 +258,15 @@ CleanUpState(); } -bool MoqtSession::SubscribeNamespace( - TrackNamespace track_namespace, - MoqtOutgoingSubscribeNamespaceCallback callback, - MessageParameters parameters) { - QUICHE_DCHECK(track_namespace.IsValid()); +std::unique_ptr<MoqtNamespaceTask> MoqtSession::SubscribeNamespace( + TrackNamespace& prefix, SubscribeNamespaceOption option, + const MessageParameters& parameters, + MoqtResponseCallback response_callback) { + QUICHE_DCHECK(prefix.IsValid()); if (received_goaway_ || sent_goaway_) { QUIC_DLOG(INFO) << ENDPOINT << "Tried to send SUBSCRIBE_NAMESPACE after GOAWAY"; - return false; + return nullptr; } if (next_request_id_ >= peer_max_request_id_) { if (!last_requests_blocked_sent_.has_value() || @@ -279,44 +280,55 @@ << next_request_id_ << " which is greater than the maximum ID " << peer_max_request_id_; - return false; + return nullptr; } - if (outgoing_subscribe_namespaces_.contains(track_namespace)) { - std::move(callback)( - track_namespace, - MoqtRequestErrorInfo{ - RequestErrorCode::kInternalError, std::nullopt, - "SUBSCRIBE_NAMESPACE already outstanding for namespace"}); - return false; + // Sanitize the option. + switch (option) { + case SubscribeNamespaceOption::kNamespace: + break; + case SubscribeNamespaceOption::kPublish: + // TODO(martinduke): Support PUBLISH. + return nullptr; + case SubscribeNamespaceOption::kBoth: + option = SubscribeNamespaceOption::kNamespace; + break; + } + QUICHE_DCHECK(option == SubscribeNamespaceOption::kNamespace); + if (!outgoing_subscribe_namespace_.SubscribeNamespace(prefix)) { + std::move(response_callback)(MoqtRequestErrorInfo{ + RequestErrorCode::kInternalError, std::nullopt, + "SUBSCRIBE_NAMESPACE already outstanding for namespace"}); + return nullptr; + } + std::unique_ptr<MoqtNamespaceSubscriberStream> state = + std::make_unique<MoqtNamespaceSubscriberStream>( + &framer_, next_request_id_, + [this, prefix]() { + if (!is_closing_) { + outgoing_subscribe_namespace_.UnsubscribeNamespace(prefix); + } + }, + [this](MoqtError error, absl::string_view reason) { + Error(error, reason); + }, + std::move(response_callback)); + MoqtNamespaceSubscriberStream* state_ptr = state.get(); + if (session_->CanOpenNextOutgoingBidirectionalStream()) { + webtransport::Stream* stream = session_->OpenOutgoingBidirectionalStream(); + state->set_stream(stream); + stream->SetVisitor(std::move(state)); + } else { + pending_bidi_streams_.push_back(std::move(state)); } MoqtSubscribeNamespace message; message.request_id = next_request_id_; - next_request_id_ += 2; - message.track_namespace_prefix = track_namespace; - // We don't support PUBLISH, so don't ask for it. + message.track_namespace_prefix = prefix; message.subscribe_options = SubscribeNamespaceOption::kNamespace; message.parameters = parameters; - SendControlMessage(framer_.SerializeSubscribeNamespace(message)); + state_ptr->SendOrBufferMessage(framer_.SerializeSubscribeNamespace(message)); QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE_NAMESPACE message for " << message.track_namespace_prefix; - pending_outgoing_subscribe_namespaces_[message.request_id] = - PendingSubscribeNamespaceData{track_namespace, std::move(callback)}; - outgoing_subscribe_namespaces_.emplace(track_namespace); - return true; -} - -bool MoqtSession::UnsubscribeNamespace(TrackNamespace track_namespace) { - QUICHE_DCHECK(track_namespace.IsValid()); - if (!outgoing_subscribe_namespaces_.contains(track_namespace)) { - return false; - } - MoqtUnsubscribeNamespace message; - message.track_namespace = track_namespace; - SendControlMessage(framer_.SerializeUnsubscribeNamespace(message)); - QUIC_DLOG(INFO) << ENDPOINT << "Sent UNSUBSCRIBE_NAMESPACE message for " - << message.track_namespace; - outgoing_subscribe_namespaces_.erase(track_namespace); - return true; + return state_ptr->CreateTask(prefix); } void MoqtSession::PublishNamespace( @@ -929,6 +941,21 @@ temp_stream->OnCanRead(); break; } + case MoqtMessageType::kSubscribeNamespace: { + auto namespace_stream = std::make_unique<MoqtNamespacePublisherStream>( + &session_->framer_, stream_, + [session = session_](MoqtError code, absl::string_view reason) { + session->Error(code, reason); + }, + &session_->incoming_subscribe_namespace_, + session_->callbacks_.incoming_subscribe_namespace_callback); + MoqtNamespacePublisherStream* temp_stream = namespace_stream.get(); + stream_->SetVisitor(std::move(namespace_stream)); + // The UnknownBidiStream object is deleted; no class access after this + // point. + temp_stream->OnCanRead(); + break; + } default: session_->Error(MoqtError::kProtocolViolation, "Unexpected message type received to start bidi stream"); @@ -1117,15 +1144,7 @@ std::move(callback_it->second)(track_namespace, std::nullopt); return; } - // Response to SUBSCRIBE_NAMESPACE. - auto sn_it = - session_->pending_outgoing_subscribe_namespaces_.find(message.request_id); - if (sn_it != session_->pending_outgoing_subscribe_namespaces_.end()) { - std::move(sn_it->second.callback)(sn_it->second.track_namespace, - std::nullopt); - session_->pending_outgoing_subscribe_namespaces_.erase(sn_it); - return; - } + // Response to SUBSCRIBE_NAMESPACE is handled in the NamespaceStream. // TRACK_STATUS response would go here, but we don't support upstream // TRACK_STATUS. // If it doesn't match any state, it might be because the local application @@ -1188,17 +1207,7 @@ session_->outgoing_publish_namespaces_.erase(it2); return; } - // Response to SUBSCRIBE_NAMESPACE. - auto sn_it = - session_->pending_outgoing_subscribe_namespaces_.find(message.request_id); - if (sn_it != session_->pending_outgoing_subscribe_namespaces_.end()) { - std::move(sn_it->second.callback)(sn_it->second.track_namespace, - error_info); - session_->outgoing_subscribe_namespaces_.erase( - sn_it->second.track_namespace); - session_->pending_outgoing_subscribe_namespaces_.erase(sn_it); - return; - } + // Response to SUBSCRIBE_NAMESPACE is handled in the NamespaceStream. // TRACK_STATUS response would go here, but we don't support upstream // TRACK_STATUS. // If it doesn't match any state, it might be because the local application @@ -1346,51 +1355,6 @@ } } -void MoqtSession::ControlStream::OnSubscribeNamespaceMessage( - const MoqtSubscribeNamespace& message) { - if (!session_->ValidateRequestId(message.request_id)) { - return; - } - // TODO(martinduke): Handle authentication. - if (session_->sent_goaway_) { - QUIC_DLOG(INFO) << ENDPOINT - << "Received a SUBSCRIBE_NAMESPACE after GOAWAY"; - SendRequestError(message.request_id, RequestErrorCode::kUnauthorized, - std::nullopt, "SUBSCRIBE_NAMESPACE after GOAWAY"); - return; - } - if (!session_->incoming_subscribe_namespace_.SubscribeNamespace( - message.track_namespace_prefix)) { - QUIC_DLOG(INFO) << ENDPOINT << "Received a SUBSCRIBE_NAMESPACE for " - << message.track_namespace_prefix - << " that is already subscribed to"; - SendRequestError(message.request_id, - RequestErrorCode::kNamespacePrefixOverlap, std::nullopt, - "SUBSCRIBE_NAMESPACE for similar subscribed namespace"); - return; - } - (session_->callbacks_.incoming_subscribe_namespace_callback)( - message.track_namespace_prefix, message.parameters, - [&](std::optional<MoqtRequestErrorInfo> error) { - if (error.has_value()) { - SendRequestError(message.request_id, *error); - session_->incoming_subscribe_namespace_.UnsubscribeNamespace( - message.track_namespace_prefix); - } else { - SendRequestOk(message.request_id, MessageParameters()); - } - }); -} - -void MoqtSession::ControlStream::OnUnsubscribeNamespaceMessage( - const MoqtUnsubscribeNamespace& message) { - // MoqtSession keeps no state here, so just tell the application. - session_->incoming_subscribe_namespace_.UnsubscribeNamespace( - message.track_namespace); - session_->callbacks_.incoming_subscribe_namespace_callback( - message.track_namespace, std::nullopt, nullptr); -} - void MoqtSession::ControlStream::OnMaxRequestIdMessage( const MoqtMaxRequestId& message) { if (message.max_request_id < session_->peer_max_request_id_) { @@ -2413,14 +2377,9 @@ if (goaway_timeout_alarm_ != nullptr) { goaway_timeout_alarm_->PermanentCancel(); } - // When the session closes, report to the application implied receipt of - // UNSUBSCRIBE_NAMESPACE, PUBLISH_NAMESPACE_DONE, PUBLISH_NAMESPACE_CANCEL, - // PUBLISH_DONE, and UNSUBSCRIBE. - for (const TrackNamespace& track_namespace : - incoming_subscribe_namespace_.GetSubscribedNamespaces()) { - callbacks_.incoming_subscribe_namespace_callback(track_namespace, - std::nullopt, nullptr); - } + // Incoming SUBSCRIBE_NAMESPACE is automatically cleaned up; the destroyed + // session owns the webtransport stream, which owns the StreamVisitor, which + // owns the task. Destroying the task notifies the application. published_subscriptions_.clear(); for (const TrackNamespace& track_namespace : incoming_publish_namespaces_) { callbacks_.incoming_publish_namespace_callback(track_namespace,
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index 22ef266..69a9a10 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -25,6 +25,7 @@ #include "quiche/quic/core/quic_types.h" #include "quiche/quic/moqt/moqt_bidi_stream.h" #include "quiche/quic/moqt/moqt_error.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_framer.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_messages.h" @@ -98,12 +99,6 @@ quic::Perspective perspective() const { return parameters_.perspective; } - // Returns true if message was sent. - bool SubscribeNamespace(TrackNamespace track_namespace, - MoqtOutgoingSubscribeNamespaceCallback callback, - MessageParameters parameters); - bool UnsubscribeNamespace(TrackNamespace track_namespace); - // Allows the subscriber to declare it will not subscribe to |track_namespace| // anymore. void CancelPublishNamespace(TrackNamespace track_namespace, @@ -141,6 +136,15 @@ MoqtOutgoingPublishNamespaceCallback callback, VersionSpecificParameters parameters) override; bool PublishNamespaceDone(TrackNamespace track_namespace) override; + // TODO(martinduke): Support PUBLISH. For now, PUBLISH-only requests will be + // rejected with nullptr, and kBoth requests will change to kNamespace. + // After receiving MoqtNamespaceTask, call + // MoqtNamespaceTask::SetObjectsAvailableCallback() to actually retrieve + // namespaces. + std::unique_ptr<MoqtNamespaceTask> SubscribeNamespace( + TrackNamespace& prefix, SubscribeNamespaceOption option, + const MessageParameters& parameters, + MoqtResponseCallback response_callback) override; quiche::QuicheWeakPtr<MoqtSessionInterface> GetWeakPtr() override { return weak_ptr_factory_.Create(); } @@ -268,10 +272,6 @@ const MoqtPublishNamespaceCancel& message) override; void OnTrackStatusMessage(const MoqtTrackStatus& message) override; void OnGoAwayMessage(const MoqtGoAway& /*message*/) override; - void OnSubscribeNamespaceMessage( - const MoqtSubscribeNamespace& message) override; - void OnUnsubscribeNamespaceMessage( - const MoqtUnsubscribeNamespace& message) override; void OnMaxRequestIdMessage(const MoqtMaxRequestId& message) override; void OnFetchMessage(const MoqtFetch& message) override; void OnFetchCancelMessage(const MoqtFetchCancel& /*message*/) override {} @@ -853,19 +853,9 @@ outgoing_publish_namespaces_; absl::flat_hash_set<TrackNamespace> incoming_publish_namespaces_; - // The value is nullptr after OK or ERROR is received. The entry is deleted - // when sending UNSUBSCRIBE_NAMESPACE, to make sure the application doesn't - // unsubscribe from something that it isn't subscribed to. PUBLISH_NAMESPACEs - // that result from this subscription use incoming_publish_namespace_callback. - struct PendingSubscribeNamespaceData { - TrackNamespace track_namespace; - MoqtOutgoingSubscribeNamespaceCallback callback; - }; - absl::flat_hash_map<uint64_t, PendingSubscribeNamespaceData> - pending_outgoing_subscribe_namespaces_; - absl::flat_hash_set<TrackNamespace> outgoing_subscribe_namespaces_; // It's an error if the namespaces overlap, so keep track of them. SessionNamespaceTree incoming_subscribe_namespace_; + SessionNamespaceTree outgoing_subscribe_namespace_; // The minimum request ID the peer can use that is monotonically increasing. uint64_t next_incoming_request_id_ = 0;
diff --git a/quiche/quic/moqt/moqt_session_callbacks.h b/quiche/quic/moqt/moqt_session_callbacks.h index 03c6f54..d43befb 100644 --- a/quiche/quic/moqt/moqt_session_callbacks.h +++ b/quiche/quic/moqt/moqt_session_callbacks.h
@@ -15,6 +15,7 @@ #include "quiche/quic/moqt/moqt_error.h" #include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" +#include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" #include "quiche/common/quiche_callbacks.h" @@ -49,20 +50,15 @@ const std::optional<VersionSpecificParameters>& parameters, MoqtResponseCallback callback)>; -// Called whenever SUBSCRIBE_NAMESPACE or UNSUBSCRIBE_NAMESPACE is received from -// the peer. SUBSCRIBE_NAMESPACE sets a value for |parameters|, -// UNSUBSCRIBE_NAMESPACE does not. For UNSUBSCRIBE_NAMESPACE, |callback| is -// null. -// TODO(martinduke): Remove this callback once the new one is in use. +// Called whenever SUBSCRIBE_NAMESPACE is received from the peer. Unsubscribe +// is signalled by destroying MoqtNamespaceTask. +// Calling MoqtNamespaceTask::SetObjectsAvailableCallback() will get all the +// tracks and namespaces, as appropriate, that are already present. using MoqtIncomingSubscribeNamespaceCallback = - quiche::MultiUseCallback<void(const TrackNamespace& track_namespace, - std::optional<MessageParameters> parameters, - MoqtResponseCallback callback)>; -using MoqtIncomingSubscribeNamespaceCallbackNew = quiche::MultiUseCallback<std::unique_ptr<MoqtNamespaceTask>( - const TrackNamespace& prefix, const MessageParameters& parameters, - MoqtResponseCallback response_callback, - ObjectsAvailableCallback objects_available_callback)>; + const TrackNamespace& prefix, SubscribeNamespaceOption option, + const MessageParameters& parameters, + MoqtResponseCallback response_callback)>; inline void DefaultIncomingPublishNamespaceCallback( const TrackNamespace&, const std::optional<VersionSpecificParameters>&, @@ -75,16 +71,10 @@ "This endpoint does not support incoming SUBSCRIBE_NAMESPACE messages"}); }; -// TODO(martinduke): Remove this callback once the new one is in use. -inline void DefaultIncomingSubscribeNamespaceCallback( - const TrackNamespace& track_namespace, std::optional<MessageParameters>, - MoqtResponseCallback callback) { - std::move(callback)(std::nullopt); -} inline std::unique_ptr<MoqtNamespaceTask> -DefaultIncomingSubscribeNamespaceCallbackNew( - const TrackNamespace&, const MessageParameters&, - MoqtResponseCallback response_callback, ObjectsAvailableCallback) { +DefaultIncomingSubscribeNamespaceCallback( + const TrackNamespace&, SubscribeNamespaceOption, const MessageParameters&, + MoqtResponseCallback response_callback) { std::move(response_callback)( MoqtRequestErrorInfo{RequestErrorCode::kNotSupported, std::nullopt, "This endpoint cannot publish."});
diff --git a/quiche/quic/moqt/moqt_session_interface.h b/quiche/quic/moqt/moqt_session_interface.h index 34006de..d683665 100644 --- a/quiche/quic/moqt/moqt_session_interface.h +++ b/quiche/quic/moqt/moqt_session_interface.h
@@ -71,8 +71,7 @@ using FetchResponseCallback = quiche::SingleUseCallback<void(std::unique_ptr<MoqtFetchTask> fetch_task)>; -// TODO(martinduke): MoqtOutgoingPublishNamespaceCallback and -// MoqtOutgoingSubscribeNamespaceCallback are deprecated. Remove. +// TODO(martinduke): MoqtOutgoingPublishNamespaceCallback is deprecated. Remove. // If |error| is nullopt, this is triggered by a PUBLISH_NAMESPACE_OK. // Otherwise, it is triggered by REQUEST_ERROR or PUBLISH_NAMESPACE_CANCEL. For @@ -84,9 +83,6 @@ quiche::MultiUseCallback<void(const TrackNamespace& track_namespace, std::optional<MoqtRequestErrorInfo> error)>; -using MoqtOutgoingSubscribeNamespaceCallback = quiche::SingleUseCallback<void( - TrackNamespace track_namespace, std::optional<MoqtRequestErrorInfo> info)>; - class MoqtSessionInterface { public: virtual ~MoqtSessionInterface() = default; @@ -154,13 +150,21 @@ // cancel. virtual bool PublishNamespaceDone(TrackNamespace track_namespace) = 0; + // Sends a SUBSCRIBE_NAMESPACE message for |prefix| and returns a + // MoqtNamespaceTask that can be used to process the response. + // Returns nullptr if the message cannot be sent. + // To unsubscribe, simply destroy the returned MoqtNamespaceTask. + virtual std::unique_ptr<MoqtNamespaceTask> SubscribeNamespace( + TrackNamespace& prefix, SubscribeNamespaceOption option, + const MessageParameters& parameters, + MoqtResponseCallback response_callback) = 0; + // TODO(martinduke): Add an API for absolute joining fetch. // TODO: Add SubscribeNamespace, UnsubscribeNamespace method. // TODO: Add PublishNamespaceCancel method. // TODO: Add TrackStatusRequest method. // TODO: Add SubscribeUpdate, PublishDone method. - virtual quiche::QuicheWeakPtr<MoqtSessionInterface> GetWeakPtr() = 0; };
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 17022d3..66b18fc 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -14,6 +14,7 @@ #include <variant> #include "absl/base/casts.h" +#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" @@ -28,6 +29,7 @@ #include "quiche/quic/moqt/moqt_known_track_publisher.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" +#include "quiche/quic/moqt/moqt_namespace_stream.h" #include "quiche/quic/moqt/moqt_object.h" #include "quiche/quic/moqt/moqt_parser.h" #include "quiche/quic/moqt/moqt_priority.h" @@ -35,6 +37,7 @@ #include "quiche/quic/moqt/moqt_session_callbacks.h" #include "quiche/quic/moqt/moqt_session_interface.h" #include "quiche/quic/moqt/moqt_track.h" +#include "quiche/quic/moqt/session_namespace_tree.h" #include "quiche/quic/moqt/test_tools/moqt_framer_utils.h" #include "quiche/quic/moqt/test_tools/moqt_mock_visitor.h" #include "quiche/quic/moqt/test_tools/moqt_session_peer.h" @@ -43,6 +46,7 @@ #include "quiche/common/quiche_buffer_allocator.h" #include "quiche/common/quiche_mem_slice.h" #include "quiche/common/quiche_stream.h" +#include "quiche/common/quiche_weak_ptr.h" #include "quiche/web_transport/test_tools/in_memory_stream.h" #include "quiche/web_transport/test_tools/mock_web_transport.h" #include "quiche/web_transport/web_transport.h" @@ -1066,56 +1070,95 @@ } TEST_F(MoqtSessionTest, SubscribeNamespaceLifeCycle) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = - MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); - TrackNamespace track_namespace("foo"); + TrackNamespace prefix("foo"); bool got_callback = false; + EXPECT_CALL(mock_session_, CanOpenNextOutgoingBidirectionalStream()) + .WillOnce(Return(true)); + EXPECT_CALL(mock_session_, OpenOutgoingBidirectionalStream()) + .WillOnce(Return(&mock_stream_)); + std::unique_ptr<MoqtNamespaceSubscriberStream> stream_input; + EXPECT_CALL(mock_stream_, SetVisitor) + .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> visitor) { + stream_input = absl::WrapUnique( + absl::down_cast<MoqtNamespaceSubscriberStream*>(visitor.release())); + ASSERT_NE(stream_input, nullptr); + }); + EXPECT_CALL(mock_stream_, CanWrite).WillRepeatedly(Return(true)); EXPECT_CALL( mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kSubscribeNamespace), _)); - session_.SubscribeNamespace( - track_namespace, - [&](const TrackNamespace& ns, std::optional<MoqtRequestErrorInfo> error) { + std::unique_ptr<MoqtNamespaceTask> task = session_.SubscribeNamespace( + prefix, SubscribeNamespaceOption::kNamespace, MessageParameters(), + [&](std::optional<MoqtRequestErrorInfo> error) { got_callback = true; - EXPECT_EQ(track_namespace, ns); EXPECT_FALSE(error.has_value()); - }, - MessageParameters()); + }); MoqtRequestOk ok = {kDefaultLocalRequestId, MessageParameters()}; stream_input->OnRequestOkMessage(ok); EXPECT_TRUE(got_callback); - EXPECT_CALL( - mock_stream_, - Writev(ControlMessageOfType(MoqtMessageType::kUnsubscribeNamespace), _)); - EXPECT_TRUE(session_.UnsubscribeNamespace(track_namespace)); - EXPECT_FALSE(session_.UnsubscribeNamespace(track_namespace)); + EXPECT_CALL(mock_stream_, ResetWithUserCode); } TEST_F(MoqtSessionTest, SubscribeNamespaceError) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = - MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); - TrackNamespace track_namespace("foo"); + TrackNamespace prefix("foo"); bool got_callback = false; + EXPECT_CALL(mock_session_, CanOpenNextOutgoingBidirectionalStream()) + .WillOnce(Return(true)); + EXPECT_CALL(mock_session_, OpenOutgoingBidirectionalStream()) + .WillOnce(Return(&mock_stream_)); + std::unique_ptr<MoqtNamespaceSubscriberStream> stream_input; + EXPECT_CALL(mock_stream_, SetVisitor) + .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> visitor) { + stream_input = std::unique_ptr<MoqtNamespaceSubscriberStream>( + absl::down_cast<MoqtNamespaceSubscriberStream*>(visitor.release())); + ASSERT_NE(stream_input, nullptr); + }); + EXPECT_CALL(mock_stream_, CanWrite).WillRepeatedly(Return(true)); EXPECT_CALL( mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kSubscribeNamespace), _)); - session_.SubscribeNamespace( - track_namespace, - [&](const TrackNamespace& ns, std::optional<MoqtRequestErrorInfo> error) { + std::unique_ptr<MoqtNamespaceTask> task = session_.SubscribeNamespace( + prefix, SubscribeNamespaceOption::kNamespace, MessageParameters(), + [&](std::optional<MoqtRequestErrorInfo> error) { got_callback = true; - EXPECT_EQ(track_namespace, ns); ASSERT_TRUE(error.has_value()); EXPECT_EQ(error->error_code, RequestErrorCode::kInvalidRange); EXPECT_EQ(error->reason_phrase, "deadbeef"); - }, - MessageParameters()); + }); MoqtRequestError error = {kDefaultLocalRequestId, RequestErrorCode::kInvalidRange, std::nullopt, "deadbeef"}; stream_input->OnRequestErrorMessage(error); EXPECT_TRUE(got_callback); - // Entry is immediately gone. - EXPECT_FALSE(session_.UnsubscribeNamespace(track_namespace)); +} + +TEST_F(MoqtSessionTest, SubscribeNamespacePublishOnly) { + TrackNamespace prefix("foo"); + // kPublish is not allowed. + EXPECT_EQ(session_.SubscribeNamespace( + prefix, SubscribeNamespaceOption::kPublish, MessageParameters(), + [&](std::optional<MoqtRequestErrorInfo>) {}), + nullptr); + // kBoth is treated as kNamespace. + EXPECT_CALL(mock_session_, CanOpenNextOutgoingBidirectionalStream()) + .WillOnce(Return(true)); + EXPECT_CALL(mock_session_, OpenOutgoingBidirectionalStream()) + .WillOnce(Return(&mock_stream_)); + std::unique_ptr<webtransport::StreamVisitor> stream_visitor; + EXPECT_CALL(mock_stream_, SetVisitor) + .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> visitor) { + stream_visitor = std::move(visitor); + }); + EXPECT_CALL(mock_stream_, CanWrite).WillRepeatedly(Return(true)); + EXPECT_CALL(mock_stream_, + Writev(SerializedControlMessage(MoqtSubscribeNamespace{ + 0, prefix, SubscribeNamespaceOption::kNamespace, + MessageParameters()}), + _)); + EXPECT_NE(session_.SubscribeNamespace( + prefix, SubscribeNamespaceOption::kBoth, MessageParameters(), + [&](std::optional<MoqtRequestErrorInfo>) {}), + nullptr); } TEST_F(MoqtSessionTest, IncomingObject) { @@ -2820,145 +2863,125 @@ } TEST_F(MoqtSessionTest, IncomingSubscribeNamespace) { - TrackNamespace track_namespace{"foo"}; - auto parameters = std::make_optional<MessageParameters>(); - parameters->authorization_tokens.emplace_back(AuthTokenType::kOutOfBand, - "foo"); + TrackNamespace prefix{"foo"}; + MessageParameters parameters; + parameters.authorization_tokens.emplace_back(AuthTokenType::kOutOfBand, + "foo"); + auto bidi_stream = std::make_unique<webtransport::test::InMemoryStream>(4); + MoqtFramer framer(true); MoqtSubscribeNamespace subscribe_namespace = { - /*request_id=*/1, - track_namespace, - SubscribeNamespaceOption::kBoth, - *parameters, - }; - webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = - MoqtSessionPeer::CreateControlStream(&session_, &control_stream); + /*request_id=*/1, prefix, SubscribeNamespaceOption::kBoth, parameters}; + bidi_stream->Receive( + framer.SerializeSubscribeNamespace(subscribe_namespace).AsStringView(), + /*fin=*/false); + EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream()) + .WillOnce(Return(bidi_stream.get())) + .WillOnce(Return(nullptr)); + quiche::QuicheWeakPtr<MockNamespaceTask> task; EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback, - Call(track_namespace, parameters, _)) - .WillOnce([](const TrackNamespace&, std::optional<MessageParameters>, - MoqtResponseCallback callback) { - std::move(callback)(std::nullopt); + Call(prefix, SubscribeNamespaceOption::kBoth, parameters, _)) + .WillOnce([&](const TrackNamespace& prefix, SubscribeNamespaceOption, + const MessageParameters&, + MoqtResponseCallback response_callback) { + std::move(response_callback)(std::nullopt); + auto task_ptr = std::make_unique<MockNamespaceTask>(prefix); + task = task_ptr->GetWeakPtr(); + return task_ptr; }); - EXPECT_CALL(control_stream, - Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _)); - stream_input->OnSubscribeNamespaceMessage(subscribe_namespace); - MoqtUnsubscribeNamespace unsubscribe_namespace{track_namespace}; - EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback, - Call(track_namespace, std::optional<MessageParameters>(), _)) - .WillOnce( - [](const TrackNamespace&, std::optional<MessageParameters>, - MoqtResponseCallback callback) { EXPECT_EQ(callback, nullptr); }); - stream_input->OnUnsubscribeNamespaceMessage(unsubscribe_namespace); + session_.OnIncomingBidirectionalStreamAvailable(); + EXPECT_EQ(static_cast<uint8_t>(bidi_stream->last_data_sent().data()[0]), + static_cast<uint8_t>(MoqtMessageType::kRequestOk)); + + // Deliver a NAMESPACE + ASSERT_TRUE(task.IsValid()); + EXPECT_CALL(*task.GetIfAvailable(), GetNextSuffix) + .WillOnce([](TrackNamespace& prefix, TransactionType& type) { + prefix = TrackNamespace({"bar"}); + type = TransactionType::kAdd; + return GetNextResult::kSuccess; + }) + .WillOnce(Return(GetNextResult::kPending)); + task.GetIfAvailable()->InvokeCallback(); + char expected_data[] = {0x08, 0x00, 0x05, 0x01, 0x03, 'b', 'a', 'r'}; + absl::string_view expected_data_view(expected_data, sizeof(expected_data)); + EXPECT_EQ(expected_data_view, bidi_stream->last_data_sent().substr( + 0, expected_data_view.length())); + + // Unsubscribe + bidi_stream.reset(); + EXPECT_FALSE(task.IsValid()); } -TEST_F(MoqtSessionTest, IncomingSubscribeNamespaceWithError) { - TrackNamespace track_namespace{"foo"}; - auto parameters = std::make_optional<MessageParameters>(); - parameters->authorization_tokens.emplace_back(AuthTokenType::kOutOfBand, - "foo"); +TEST_F(MoqtSessionTest, IncomingSubscribeNamespaceWithSynchronousError) { + TrackNamespace prefix{"foo"}; + MessageParameters parameters; + parameters.authorization_tokens.emplace_back(AuthTokenType::kOutOfBand, + "foo"); + webtransport::test::InMemoryStream bidi_stream(4); + MoqtFramer framer(true); MoqtSubscribeNamespace subscribe_namespace = { - /*request_id=*/1, - track_namespace, - SubscribeNamespaceOption::kBoth, - *parameters, - }; - webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = - MoqtSessionPeer::CreateControlStream(&session_, &control_stream); + /*request_id=*/1, prefix, SubscribeNamespaceOption::kBoth, parameters}; + bidi_stream.Receive( + framer.SerializeSubscribeNamespace(subscribe_namespace).AsStringView(), + /*fin=*/false); + EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream()) + .WillOnce(Return(&bidi_stream)) + .WillOnce(Return(nullptr)); EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback, - Call(track_namespace, parameters, _)) - .WillOnce([](const TrackNamespace&, std::optional<MessageParameters>, - MoqtResponseCallback callback) { - std::move(callback)(MoqtRequestErrorInfo{ + Call(prefix, SubscribeNamespaceOption::kBoth, parameters, _)) + .WillOnce([&](const TrackNamespace&, SubscribeNamespaceOption, + const MessageParameters&, + MoqtResponseCallback response_callback) { + std::move(response_callback)(MoqtRequestErrorInfo{ RequestErrorCode::kUnauthorized, std::nullopt, "foo"}); + return nullptr; }); - EXPECT_CALL(control_stream, - Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); - stream_input->OnSubscribeNamespaceMessage(subscribe_namespace); - - // Try again, to verify that it was purged from the tree. - subscribe_namespace.request_id += 2; - EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback, - Call(track_namespace, parameters, _)) - .WillOnce([](const TrackNamespace&, std::optional<MessageParameters>, - MoqtResponseCallback callback) { - std::move(callback)(std::nullopt); - }); - EXPECT_CALL(control_stream, - Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _)); - stream_input->OnSubscribeNamespaceMessage(subscribe_namespace); - - // Cleanup. - MoqtUnsubscribeNamespace unsubscribe_namespace{track_namespace}; - EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback, - Call(track_namespace, std::optional<MessageParameters>(), _)) - .WillOnce( - [](const TrackNamespace&, std::optional<MessageParameters>, - MoqtResponseCallback callback) { EXPECT_EQ(callback, nullptr); }); - stream_input->OnUnsubscribeNamespaceMessage(unsubscribe_namespace); + session_.OnIncomingBidirectionalStreamAvailable(); + EXPECT_EQ(static_cast<uint8_t>(bidi_stream.last_data_sent().data()[0]), + static_cast<uint8_t>(MoqtMessageType::kRequestError)); + EXPECT_TRUE(bidi_stream.fin_sent()); } TEST_F(MoqtSessionTest, IncomingSubscribeNamespaceWithPrefixOverlap) { TrackNamespace foo{"foo"}, foobar{"foo", "bar"}; - - auto parameters = std::make_optional<MessageParameters>(); - parameters->authorization_tokens.emplace_back(AuthTokenType::kOutOfBand, - "foo"); + MessageParameters parameters; + parameters.authorization_tokens.emplace_back(AuthTokenType::kOutOfBand, + "foo"); + webtransport::test::InMemoryStream bidi_stream1(4), bidi_stream2(8); + MoqtFramer framer(true); MoqtSubscribeNamespace subscribe_namespace = { - /*request_id=*/1, - foo, - SubscribeNamespaceOption::kBoth, - *parameters, - }; - webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = - MoqtSessionPeer::CreateControlStream(&session_, &control_stream); + /*request_id=*/1, foo, SubscribeNamespaceOption::kBoth, parameters}; + bidi_stream1.Receive( + framer.SerializeSubscribeNamespace(subscribe_namespace).AsStringView(), + /*fin=*/false); + EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream()) + .WillOnce(Return(&bidi_stream1)) + .WillOnce(Return(nullptr)); EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback, - Call(foo, parameters, _)) - .WillOnce([](const TrackNamespace&, std::optional<MessageParameters>, - MoqtResponseCallback callback) { - std::move(callback)(std::nullopt); + Call(foo, SubscribeNamespaceOption::kBoth, parameters, _)) + .WillOnce([&](const TrackNamespace& prefix, SubscribeNamespaceOption, + const MessageParameters&, + MoqtResponseCallback response_callback) { + std::move(response_callback)(std::nullopt); + auto task_ptr = std::make_unique<MockNamespaceTask>(prefix); + return task_ptr; }); - EXPECT_CALL(control_stream, - Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _)); - stream_input->OnSubscribeNamespaceMessage(subscribe_namespace); + session_.OnIncomingBidirectionalStreamAvailable(); + EXPECT_EQ(static_cast<uint8_t>(bidi_stream1.last_data_sent().data()[0]), + static_cast<uint8_t>(MoqtMessageType::kRequestOk)); - // Overlapping request is rejected. subscribe_namespace.request_id += 2; subscribe_namespace.track_namespace_prefix = foobar; - EXPECT_CALL(control_stream, - Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); - stream_input->OnSubscribeNamespaceMessage(subscribe_namespace); - - // Remove the subscription. Now a later one will work. - MoqtUnsubscribeNamespace unsubscribe_namespace{foo}; - EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback, - Call(foo, std::optional<MessageParameters>(), _)) - .WillOnce( - [](const TrackNamespace&, std::optional<MessageParameters>, - MoqtResponseCallback callback) { EXPECT_EQ(callback, nullptr); }); - stream_input->OnUnsubscribeNamespaceMessage(unsubscribe_namespace); - - // Try again, it will work. - subscribe_namespace.request_id += 2; - EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback, - Call(foobar, parameters, _)) - .WillOnce([](const TrackNamespace&, std::optional<MessageParameters>, - MoqtResponseCallback callback) { - std::move(callback)(std::nullopt); - }); - EXPECT_CALL(control_stream, - Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _)); - stream_input->OnSubscribeNamespaceMessage(subscribe_namespace); - - // Cleanup. - unsubscribe_namespace.track_namespace = foobar; - EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback, - Call(foobar, std::optional<MessageParameters>(), _)) - .WillOnce( - [](const TrackNamespace&, std::optional<MessageParameters>, - MoqtResponseCallback callback) { EXPECT_EQ(callback, nullptr); }); - stream_input->OnUnsubscribeNamespaceMessage(unsubscribe_namespace); + bidi_stream2.Receive( + framer.SerializeSubscribeNamespace(subscribe_namespace).AsStringView(), + /*fin=*/false); + EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream()) + .WillOnce(Return(&bidi_stream2)) + .WillOnce(Return(nullptr)); + session_.OnIncomingBidirectionalStreamAvailable(); + EXPECT_EQ(static_cast<uint8_t>(bidi_stream2.last_data_sent().data()[0]), + static_cast<uint8_t>(MoqtMessageType::kRequestError)); + EXPECT_TRUE(bidi_stream2.fin_sent()); } TEST_F(MoqtSessionTest, FetchThenOkThenCancel) { @@ -3517,10 +3540,12 @@ parameters.subscription_filter.emplace(MoqtFilterType::kLargestObject); EXPECT_FALSE(session_.Subscribe(FullTrackName("foo", "bar"), &remote_track_visitor_, parameters)); - EXPECT_FALSE(session_.SubscribeNamespace( - TrackNamespace{"foo"}, - +[](TrackNamespace, std::optional<MoqtRequestErrorInfo>) {}, - MessageParameters())); + TrackNamespace prefix({"foo"}); + EXPECT_EQ( + session_.SubscribeNamespace( + prefix, SubscribeNamespaceOption::kNamespace, MessageParameters(), + +[](std::optional<MoqtRequestErrorInfo>) {}), + nullptr); session_.PublishNamespace( TrackNamespace{"foo"}, +[](TrackNamespace, std::optional<MoqtRequestErrorInfo>) {}, @@ -3562,11 +3587,18 @@ MoqtFetch fetch = DefaultFetch(); fetch.request_id = 5; stream_input->OnFetchMessage(fetch); + + MoqtFramer framer(true); + SessionNamespaceTree tree; + MoqtIncomingSubscribeNamespaceCallback callback = + DefaultIncomingSubscribeNamespaceCallback; + MoqtNamespacePublisherStream namespace_stream(&framer, &mock_stream_, nullptr, + &tree, callback); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); - stream_input->OnSubscribeNamespaceMessage(MoqtSubscribeNamespace(7)); + namespace_stream.OnSubscribeNamespaceMessage(MoqtSubscribeNamespace(7)); MoqtTrackStatus track_status = DefaultSubscribe(); - track_status.request_id = 9; + track_status.request_id = 7; EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); stream_input->OnTrackStatusMessage(track_status); @@ -3576,10 +3608,12 @@ parameters.subscription_filter.emplace(MoqtFilterType::kLargestObject); EXPECT_FALSE(session_.Subscribe(FullTrackName(TrackNamespace("foo"), "bar"), &remote_track_visitor_, parameters)); - EXPECT_FALSE(session_.SubscribeNamespace( - TrackNamespace{"foo"}, - +[](TrackNamespace, std::optional<MoqtRequestErrorInfo>) {}, - MessageParameters())); + TrackNamespace prefix({"foo"}); + EXPECT_EQ( + session_.SubscribeNamespace( + prefix, SubscribeNamespaceOption::kNamespace, MessageParameters(), + +[](std::optional<MoqtRequestErrorInfo>) {}), + nullptr); session_.PublishNamespace( TrackNamespace{"foo"}, +[](TrackNamespace, std::optional<MoqtRequestErrorInfo>) {},
diff --git a/quiche/quic/moqt/relay_namespace_tree.cc b/quiche/quic/moqt/relay_namespace_tree.cc new file mode 100644 index 0000000..8c989ed --- /dev/null +++ b/quiche/quic/moqt/relay_namespace_tree.cc
@@ -0,0 +1,257 @@ +// Copyright 2026 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/moqt/relay_namespace_tree.h" + +#include <memory> +#include <optional> +#include <string> +#include <utility> + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/moqt/moqt_error.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" +#include "quiche/quic/moqt/moqt_names.h" +#include "quiche/quic/moqt/moqt_session_interface.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/quiche_circular_deque.h" +#include "quiche/common/quiche_weak_ptr.h" + +namespace moqt { + +RelayNamespaceTree::RelayNamespaceListener::~RelayNamespaceListener() { + tree_.RemoveSubscriber(prefix_, this); +} + +void RelayNamespaceTree::RelayNamespaceListener::SetObjectsAvailableCallback( + ObjectsAvailableCallback absl_nullable callback) { + bool first_callback = callback_ == nullptr; + callback_ = std::move(callback); + if (first_callback) { + // If this is the first callback, we need to notify the listener of all + // published namespaces and tracks in this namespace. Even if |callback| is + // nullptr, run this to notify of published tracks. + // A track that published after the task was created and before this + // function is called will be actually cause two Publish() calls, but the + // session will ignore the second one. + TrackNamespace suffix; + tree_.NotifyOfAllChildren(tree_.FindNode(prefix_), suffix, this); + } +} + +GetNextResult RelayNamespaceTree::RelayNamespaceListener::GetNextSuffix( + TrackNamespace& suffix, TransactionType& type) { + if (eof_) { + return kEof; + } + if (pending_suffixes_.empty()) { + if (error_.has_value()) { + return kError; + } + return kPending; + } + suffix = pending_suffixes_.front().suffix; + type = pending_suffixes_.front().type; + pending_suffixes_.pop_front(); + return kSuccess; +} + +void RelayNamespaceTree::RelayNamespaceListener::AddPendingSuffix( + TrackNamespace suffix, TransactionType type) { + if (callback_ == nullptr) { + return; // Not interested in namespaces. + } + if (eof_) { + return; + } + if (pending_suffixes_.size() == kMaxPendingSuffixes) { + error_ = kResetCodeTooFarBehind; + return; + } + pending_suffixes_.push_back({std::move(suffix), type}); + callback_(); +} + +void RelayNamespaceTree::RelayNamespaceListener::Publish(TrackNamespace, + absl::string_view) { + if (session_ == nullptr) { + return; // Not interested in tracks. + } + // TODO(martinduke): Build a full track name from prefix_, suffix, and name, + // then call session_->Publish(). +} + +void RelayNamespaceTree::RelayNamespaceListener::DeclareEof() { + if (eof_ || error_.has_value()) { + return; + } + eof_ = true; + callback_(); +} + +void RelayNamespaceTree::AddPublisher( + TrackNamespace prefix, MoqtSessionInterface* absl_nonnull session) { + Node* node = FindOrCreateNode(prefix); + if (node->publishers.empty()) { + NotifyAllParents(prefix, TransactionType::kAdd); + } + node->publishers[session] = session->GetWeakPtr(); +} + +void RelayNamespaceTree::RemovePublisher( + const TrackNamespace& prefix, MoqtSessionInterface* absl_nonnull session) { + Node* node = FindNode(prefix); + if (node == nullptr) { + return; + } + node->publishers.erase(session); + // Tell all the namespace listeners. + if (node->publishers.empty()) { + TrackNamespace mutable_namespace = prefix; + NotifyAllParents(prefix, TransactionType::kDelete); + MaybePrune(node, mutable_namespace); + } +} + +std::unique_ptr<MoqtNamespaceTask> RelayNamespaceTree::AddSubscriber( + const TrackNamespace& prefix, + MoqtSessionInterface* absl_nullable track_listener) { + Node* node = FindOrCreateNode(prefix); + auto task = + std::make_unique<RelayNamespaceListener>(*this, prefix, track_listener); + node->listeners[task.get()] = task->GetWeakPtr(); + return std::move(task); +} + +void RelayNamespaceTree::RemoveSubscriber( + TrackNamespace prefix, MoqtNamespaceTask* absl_nonnull listener) { + Node* node = FindNode(prefix); + if (node == nullptr) { + return; + } + node->listeners.erase(listener); + MaybePrune(node, prefix); +} + +MoqtSessionInterface* absl_nullable RelayNamespaceTree::GetValidPublisher( + TrackNamespace track_namespace) { + Node* node; + do { + node = FindNode(track_namespace); + // Remove invalid publishers. + while (node != nullptr && !node->publishers.empty() && + !node->publishers.begin()->second.IsValid()) { + node->publishers.erase(node->publishers.begin()); + } + if (node != nullptr && !node->publishers.empty()) { + return node->publishers.begin()->second.GetIfAvailable(); + } + MaybePrune(node, track_namespace); + } while (track_namespace.PopElement()); + return nullptr; +} + +bool RelayNamespaceTree::Node::CanPrune() const { + return children.empty() && publishers.empty() && published_tracks.empty() && + listeners.empty(); +} + +RelayNamespaceTree::Node* RelayNamespaceTree::FindNode( + const TrackNamespace& track_namespace) const { + auto it = nodes_.find(track_namespace); + if (it == nodes_.end()) { + return nullptr; + } + return it->second.get(); +} + +RelayNamespaceTree::Node* RelayNamespaceTree::FindOrCreateNode( + TrackNamespace track_namespace) { + if (track_namespace.number_of_elements() == 0) { // Root node. + auto [it, inserted] = + nodes_.emplace(track_namespace, std::make_unique<Node>()); + return it->second.get(); + } + auto [it, inserted] = nodes_.emplace( + track_namespace, std::make_unique<Node>(track_namespace.tuple().back())); + if (!inserted) { + return it->second.get(); + } + Node* node = it->second.get(); // store it in case it moves. + if (track_namespace.PopElement()) { + Node* parent = FindOrCreateNode(track_namespace); + parent->children.insert(node); + } + return node; +} + +void RelayNamespaceTree::NotifyOfAllChildren( + Node* node, TrackNamespace& suffix, + RelayNamespaceListener* absl_nonnull listener) { + if (!node->publishers.empty()) { + listener->AddPendingSuffix(suffix, TransactionType::kAdd); + } + for (const std::string& track : node->published_tracks) { + listener->Publish(suffix, track); + } + for (auto child = node->children.begin(); child != node->children.end(); + ++child) { + if (std::optional<absl::string_view> element = (*child)->element) { + suffix.AddElement(*element); + NotifyOfAllChildren(*child, suffix, listener); + suffix.PopElement(); + } + } +} + +void RelayNamespaceTree::NotifyAllParents(const TrackNamespace& prefix, + TransactionType type) { + TrackNamespace mutable_namespace = prefix; + do { + Node* node = FindNode(mutable_namespace); + if (node == nullptr) { + continue; + } + for (const auto& it : node->listeners) { + RelayNamespaceListener* listener = it.second.GetIfAvailable(); + if (listener == nullptr) { + QUICHE_BUG(subscriber_is_invalid) + << "Subscriber WeakPtr is invalid but not removed from the set"; + continue; + } + absl::StatusOr<TrackNamespace> suffix = + prefix.ExtractSuffix(mutable_namespace); + if (!suffix.ok()) { + QUICHE_BUG(cannot_extract_suffix) << "Namespace tuple is mangled"; + continue; + } + listener->AddPendingSuffix(*suffix, type); + } + } while (mutable_namespace.PopElement()); +} + +void RelayNamespaceTree::MaybePrune(Node* node, + TrackNamespace track_namespace) { + if (node == nullptr || !node->CanPrune()) { + return; + } + Node* child = node; // Save the pointer before erasing. + nodes_.erase(track_namespace); + // child is now gone, do not dereference! + if (track_namespace.PopElement()) { + Node* parent = FindNode(track_namespace); + QUICHE_BUG_IF(quiche_bug_no_parent_namespace, parent == nullptr) + << "Parent namespace not found for " << track_namespace; + if (parent != nullptr) { + parent->children.erase(child); + MaybePrune(parent, track_namespace); + } + } +} + +} // namespace moqt
diff --git a/quiche/quic/moqt/relay_namespace_tree.h b/quiche/quic/moqt/relay_namespace_tree.h index 7c0c027..8add33a 100644 --- a/quiche/quic/moqt/relay_namespace_tree.h +++ b/quiche/quic/moqt/relay_namespace_tree.h
@@ -5,19 +5,23 @@ #ifndef QUICHE_QUIC_MOQT_RELAY_NAMESPACE_TREE_H_ #define QUICHE_QUIC_MOQT_RELAY_NAMESPACE_TREE_H_ +#include <cstddef> #include <cstdint> #include <memory> #include <optional> #include <string> +#include <utility> #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" -#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" +#include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_session_interface.h" -#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/quiche_circular_deque.h" #include "quiche/common/quiche_weak_ptr.h" +#include "quiche/web_transport/web_transport.h" namespace moqt { @@ -29,195 +33,131 @@ // Therefore, this is a tree structure to easily and scalably move up and down // the hierarchy to find parents or children. class RelayNamespaceTree { + private: + class RelayNamespaceListener : public MoqtNamespaceTask { + public: + // If |tracks| is nullptr, the listener is not interested in PUBLISH + // messages. + RelayNamespaceListener(RelayNamespaceTree& tree, + const TrackNamespace& prefix, + MoqtSessionInterface* absl_nullable tracks) + : prefix_(prefix), + tree_(tree), + session_(tracks), + weak_ptr_factory_(this) {} + ~RelayNamespaceListener() override; + // MoqtNamespaceTask methods. + void SetObjectsAvailableCallback(ObjectsAvailableCallback + absl_nullable callback) override; + GetNextResult GetNextSuffix(TrackNamespace& suffix, + TransactionType& type) override; + std::optional<webtransport::StreamErrorCode> GetStatus() override { + return error_; + } + const TrackNamespace& prefix() override { return prefix_; } + + // Queues a suffix corresponding to a NAMESPACE (if |type| is kAdd) or a + // NAMESPACE_DONE (if |type| is kDelete). + void AddPendingSuffix(TrackNamespace suffix, TransactionType type); + // Publishes a track in this namespace. + void Publish(TrackNamespace suffix, absl::string_view name); + void DeclareEof(); + quiche::QuicheWeakPtr<RelayNamespaceListener> GetWeakPtr() { + return weak_ptr_factory_.Create(); + } + + private: + struct PendingSuffix { + TrackNamespace suffix; + TransactionType type; + }; + + static constexpr size_t kMaxPendingSuffixes = 100; + + const TrackNamespace prefix_; + RelayNamespaceTree& tree_; + std::optional<webtransport::StreamErrorCode> error_; + quiche::QuicheCircularDeque<PendingSuffix> pending_suffixes_; + MoqtSessionInterface* absl_nullable session_; + ObjectsAvailableCallback absl_nullable callback_ = nullptr; + bool eof_ = false; + bool got_first_pull_ = false; + // Must be last. + quiche::QuicheWeakPtrFactory<RelayNamespaceListener> weak_ptr_factory_; + }; + public: // Adds a publisher to the namespace tree. The caller is responsible to call // RemovePublisher if it goes away. |session| is stored as a WeakPtr. - void AddPublisher(const TrackNamespace& track_namespace, - MoqtSessionInterface* absl_nonnull session) { - Node* node = FindOrCreateNode(track_namespace); - if (node->publishers.empty()) { - NotifyAllParents(track_namespace, /*adding=*/true); - } - node->publishers.emplace(session->GetWeakPtr()); - } + void AddPublisher(TrackNamespace prefix, + MoqtSessionInterface* absl_nonnull session); - void RemovePublisher(const TrackNamespace& track_namespace, - MoqtSessionInterface* absl_nonnull session) { - Node* node = FindNode(track_namespace); - if (node == nullptr) { - return; - } - node->publishers.erase(session->GetWeakPtr()); - // Tell all the namespace listeners. - if (node->publishers.empty()) { - TrackNamespace mutable_namespace = track_namespace; - NotifyAllParents(track_namespace, /*adding=*/false); - MaybePrune(node, mutable_namespace); - } - } + void RemovePublisher(const TrackNamespace& prefix, + MoqtSessionInterface* absl_nonnull session); - // The caller is responsible to call RemoveNamespaceListener if it goes away. - // Thus, it is safe to store it as a raw pointer. - void AddSubscriber(const TrackNamespace& track_namespace, - MoqtSessionInterface* absl_nonnull subscriber) { - Node* node = FindOrCreateNode(track_namespace); - node->subscribers.insert(subscriber->GetWeakPtr()); - // Notify the listener of every published namespace and track in this - // namespace. - TrackNamespace mutable_namespace = track_namespace; - NotifyOfAllChildren(node, mutable_namespace, subscriber); - } - - void RemoveSubscriber(const TrackNamespace& track_namespace, - MoqtSessionInterface* absl_nonnull subscriber) { - Node* node = FindNode(track_namespace); - if (node == nullptr) { - return; - } - node->subscribers.erase(subscriber->GetWeakPtr()); - TrackNamespace mutable_namespace = track_namespace; - MaybePrune(node, mutable_namespace); - } + // Called on incoming SUBSCRIBE_NAMESPACE messages. If track_subscriber is + // nullptr, it is not interested in PUBLISH messages. If callback is nullptr, + // it is not interested in NAMESPACE messages. If not interested in, + // namespaces, will return nullptr. Otherwise, will return a task to flow + // control published namespaces. + std::unique_ptr<MoqtNamespaceTask> AddSubscriber( + const TrackNamespace& prefix, + MoqtSessionInterface* absl_nullable track_listener); // Returns a raw pointer to the session that publishes the smallest namespace // that contains |track_namespace|. If a WeakPtr is found to be invalid, - // deletes them from the tree. - MoqtSessionInterface* GetValidPublisher( - const TrackNamespace& track_namespace) const { - Node* node = FindNode(track_namespace); - TrackNamespace mutable_namespace = track_namespace; - while ((node == nullptr || node->publishers.empty()) && - mutable_namespace.PopElement()) { - node = FindNode(mutable_namespace); - } - if (node == nullptr || node->publishers.empty()) { - return nullptr; - } - MoqtSessionInterface* upstream = node->publishers.begin()->GetIfAvailable(); - if (!upstream) { - QUICHE_BUG(publisher_is_invalid) - << "Publisher WeakPtr is invalid but not removed from the set"; - return nullptr; - } - return upstream; - } + // deletes it from the tree. + MoqtSessionInterface* absl_nullable GetValidPublisher( + TrackNamespace track_namespace); protected: uint64_t NumNamespaces() const { return nodes_.size(); } private: struct Node { + Node() = default; // The root node has no element. explicit Node(absl::string_view element) : element(element) {} - - const std::string element; + std::optional<const std::string> element; absl::flat_hash_set<Node*> children; + // For all of the maps below, the key is a raw pointer to the type in the + // value. This is declared as void* because these raw pointers should NEVER + // be dereferenced. They are present so that the session or listener can + // delete itself from the tree quickly by passing a raw pointer to itself. + // Publishers of this namespace. - absl::flat_hash_set<quiche::QuicheWeakPtr<MoqtSessionInterface>> publishers; + absl::flat_hash_map<void*, quiche::QuicheWeakPtr<MoqtSessionInterface>> + publishers; + // The use of a QuicheWeakPtr is out of an abundance of caution. + // RelayNamespaceListeners should delete themselves from the tree when they + // go away. + absl::flat_hash_map<void*, quiche::QuicheWeakPtr<RelayNamespaceListener>> + listeners; // Just store the track name. Additional information will be in the // TrackPublisher. absl::flat_hash_set<std::string> published_tracks; - absl::flat_hash_set<quiche::QuicheWeakPtr<MoqtSessionInterface>> - subscribers; - bool CanPrune() const { - return children.empty() && publishers.empty() && - published_tracks.empty() && subscribers.empty(); - } + bool CanPrune() const; }; - Node* FindNode(const TrackNamespace& track_namespace) const { - auto it = nodes_.find(track_namespace); - if (it == nodes_.end()) { - return nullptr; - } - return it->second.get(); - } + Node* FindNode(const TrackNamespace& track_namespace) const; - Node* FindOrCreateNode(const TrackNamespace& track_namespace) { - auto [it, inserted] = - nodes_.emplace(track_namespace, - std::make_unique<Node>(track_namespace.tuple().back())); - if (!inserted) { - return it->second.get(); - } - Node* node = it->second.get(); // store it in case it moves. - TrackNamespace mutable_namespace = track_namespace; - if (mutable_namespace.PopElement()) { - Node* parent = FindOrCreateNode(mutable_namespace); - parent->children.insert(node); - } - return node; - } + Node* FindOrCreateNode(TrackNamespace track_namespace); // Recursive function to notify |listener| of all published namespaces and // tracks in and below |node|. - void NotifyOfAllChildren(Node* node, TrackNamespace& track_namespace, - MoqtSessionInterface* subscriber) { - // TODO(martinduke): Publish everything in node->published_tracks. - if (!node->publishers.empty()) { - subscriber->PublishNamespace( - track_namespace, - [](const TrackNamespace&, std::optional<MoqtRequestErrorInfo>) {}, - // TODO(martinduke): Add parameters. - VersionSpecificParameters()); - } - for (auto child = node->children.begin(); child != node->children.end(); - ++child) { - track_namespace.AddElement((*child)->element); - NotifyOfAllChildren(*child, track_namespace, subscriber); - track_namespace.PopElement(); - } - } + void NotifyOfAllChildren(Node* node, TrackNamespace& suffix, + RelayNamespaceListener* absl_nonnull listener); - // If |adding| is true, sends PUBLISH_NAMESPACE to all subscribers to a - // parent namespace. If |adding| is false, sends PUBLISH_NAMESPACE_DONE. - void NotifyAllParents(const TrackNamespace& track_namespace, bool adding) { - TrackNamespace mutable_namespace = track_namespace; - do { - Node* node = FindNode(mutable_namespace); - if (node == nullptr) { - continue; - } - for (const quiche::QuicheWeakPtr<MoqtSessionInterface>& subscriber_ptr : - node->subscribers) { - MoqtSessionInterface* subscriber = subscriber_ptr.GetIfAvailable(); - if (subscriber == nullptr) { - QUICHE_BUG(subscriber_is_invalid) - << "Subscriber WeakPtr is invalid but not removed from the set"; - continue; - } - if (adding) { - subscriber->PublishNamespace( - track_namespace, - [](const TrackNamespace&, std::optional<MoqtRequestErrorInfo>) {}, - // TODO(martinduke): Add parameters. - VersionSpecificParameters()); - } else { - subscriber->PublishNamespaceDone(track_namespace); - } - } - } while (mutable_namespace.PopElement()); - } + // If |adding| is true, sends NAMESPACE to all subscribers to a + // parent namespace. If |adding| is false, sends NAMESPACE_DONE. + void NotifyAllParents(const TrackNamespace& prefix, TransactionType type); // If a node has no children, publishers, or subscribers, remove it and see // if the same applies to its parent. - void MaybePrune(Node* node, TrackNamespace& track_namespace) { - if (node == nullptr || !node->CanPrune()) { - return; - } - Node* child = node; // Save the pointer before erasing. - nodes_.erase(track_namespace); - // child is now gone, do not dereference! - if (track_namespace.PopElement()) { - Node* parent = FindNode(track_namespace); - QUICHE_BUG_IF(quiche_bug_no_parent_namespace, parent == nullptr) - << "Parent namespace not found for " << track_namespace; - if (parent != nullptr) { - parent->children.erase(child); - MaybePrune(parent, track_namespace); - } - } - } + void MaybePrune(Node* node, TrackNamespace track_namespace); + + void RemoveSubscriber(TrackNamespace prefix, + MoqtNamespaceTask* absl_nonnull namespace_listener); // A map that allows quick access to any namespace without traversing the // tree. Use unique_ptr so that it's pointer stable.
diff --git a/quiche/quic/moqt/relay_namespace_tree_test.cc b/quiche/quic/moqt/relay_namespace_tree_test.cc index 7a45169..324e003 100644 --- a/quiche/quic/moqt/relay_namespace_tree_test.cc +++ b/quiche/quic/moqt/relay_namespace_tree_test.cc
@@ -5,8 +5,10 @@ #include "quiche/quic/moqt/relay_namespace_tree.h" #include <memory> +#include <utility> -#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" +#include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/test_tools/mock_moqt_session.h" #include "quiche/common/platform/api/quiche_expect_bug.h" #include "quiche/common/platform/api/quiche_test.h" @@ -28,13 +30,24 @@ TestRelayNamespaceTree tree_; TrackNamespace a_{"a"}, ab_{"a", "b"}, abc_{"a", "b", "c"}; std::unique_ptr<MockMoqtSession> session_; + int objects_available_ = 0; + ObjectsAvailableCallback callback_ = [&]() { ++objects_available_; }; + + void CheckNextSuffix(MoqtNamespaceTask* task, TrackNamespace& full, + bool add = true) { + TrackNamespace suffix; + TransactionType type; + EXPECT_EQ(task->GetNextSuffix(suffix, type), GetNextResult::kSuccess); + EXPECT_EQ(*task->prefix().AddSuffix(suffix), full); + EXPECT_EQ(type, add ? TransactionType::kAdd : TransactionType::kDelete); + } }; TEST_F(RelayNamespaceTreeTest, AddGetRemovePublisher) { EXPECT_EQ(tree_.NumNamespaces(), 0u); EXPECT_EQ(tree_.GetValidPublisher(ab_), nullptr); tree_.AddPublisher(ab_, session_.get()); - EXPECT_EQ(tree_.NumNamespaces(), 2u); + EXPECT_EQ(tree_.NumNamespaces(), 3u); EXPECT_EQ(tree_.GetValidPublisher(a_), nullptr); EXPECT_EQ(tree_.GetValidPublisher(ab_), session_.get()); EXPECT_EQ(tree_.GetValidPublisher(abc_), session_.get()); @@ -46,60 +59,76 @@ TEST_F(RelayNamespaceTreeTest, AddGetRemoveListener) { // Add a listener to a namespace that has no publishers. EXPECT_EQ(tree_.NumNamespaces(), 0u); - tree_.AddSubscriber(ab_, session_.get()); - EXPECT_EQ(tree_.NumNamespaces(), 2u); - EXPECT_CALL(*session_, PublishNamespace).Times(0); - tree_.AddPublisher(a_, session_.get()); - EXPECT_CALL(*session_, PublishNamespace(ab_, _, _)); - tree_.AddPublisher(ab_, session_.get()); - EXPECT_CALL(*session_, PublishNamespace(abc_, _, _)); - tree_.AddPublisher(abc_, session_.get()); + std::unique_ptr<MoqtNamespaceTask> task = + tree_.AddSubscriber(ab_, session_.get()); + task->SetObjectsAvailableCallback(std::move(callback_)); EXPECT_EQ(tree_.NumNamespaces(), 3u); + tree_.AddPublisher(a_, session_.get()); + EXPECT_EQ(objects_available_, 0); + tree_.AddPublisher(ab_, session_.get()); + EXPECT_EQ(objects_available_, 1); + CheckNextSuffix(task.get(), ab_); + tree_.AddPublisher(abc_, session_.get()); + EXPECT_EQ(objects_available_, 2); + CheckNextSuffix(task.get(), abc_); + EXPECT_EQ(tree_.NumNamespaces(), 4u); // Second publisher creates no new notifications, and delays OnNamespaceDone. auto session2 = std::make_unique<MockMoqtSession>(); - EXPECT_CALL(*session_, PublishNamespace).Times(0); tree_.AddPublisher(ab_, session2.get()); - EXPECT_CALL(*session_, PublishNamespaceDone).Times(0); + EXPECT_EQ(objects_available_, 2); tree_.RemovePublisher(ab_, session_.get()); - EXPECT_CALL(*session_, PublishNamespaceDone(ab_)); + EXPECT_EQ(objects_available_, 2); tree_.RemovePublisher(ab_, session2.get()); + EXPECT_EQ(objects_available_, 3); + CheckNextSuffix(task.get(), ab_, /*add=*/false); // Removing the listener disables notifications. - tree_.RemoveSubscriber(ab_, session_.get()); - EXPECT_CALL(*session_, PublishNamespace).Times(0); + task.reset(); tree_.AddPublisher(ab_, session2.get()); + EXPECT_EQ(objects_available_, 3); } TEST_F(RelayNamespaceTreeTest, SessionDestroyed) { - tree_.AddSubscriber(ab_, session_.get()); - EXPECT_CALL(*session_, PublishNamespace(ab_, _, _)); + std::unique_ptr<MoqtNamespaceTask> task = + tree_.AddSubscriber(ab_, session_.get()); + task->SetObjectsAvailableCallback(std::move(callback_)); tree_.AddPublisher(ab_, session_.get()); + EXPECT_EQ(objects_available_, 1); + CheckNextSuffix(task.get(), ab_); EXPECT_NE(tree_.GetValidPublisher(ab_), nullptr); - // First session dies. It should have removed the namespace! + // First session dies. In real life, it would have destroyed the stream and + // therefore the task, removing the entry. But verify that the WeakPtr works. session_.reset(); - EXPECT_QUICHE_BUG( - tree_.GetValidPublisher(ab_), - "Publisher WeakPtr is invalid but not removed from the set"); + EXPECT_EQ(tree_.GetValidPublisher(ab_), nullptr); } TEST_F(RelayNamespaceTreeTest, AddListenerToExistingPublisher) { tree_.AddPublisher(a_, session_.get()); tree_.AddPublisher(ab_, session_.get()); tree_.AddPublisher(abc_, session_.get()); - EXPECT_CALL(*session_, PublishNamespace(ab_, _, _)); - EXPECT_CALL(*session_, PublishNamespace(abc_, _, _)); - tree_.AddSubscriber(ab_, session_.get()); + std::unique_ptr<MoqtNamespaceTask> task = + tree_.AddSubscriber(ab_, session_.get()); + task->SetObjectsAvailableCallback(std::move(callback_)); + EXPECT_EQ(objects_available_, 2); + CheckNextSuffix(task.get(), ab_); + CheckNextSuffix(task.get(), abc_); } TEST_F(RelayNamespaceTreeTest, MaxSizeNamespace) { - tree_.AddSubscriber(a_, session_.get()); + std::unique_ptr<MoqtNamespaceTask> task = + tree_.AddSubscriber(a_, session_.get()); + task->SetObjectsAvailableCallback(std::move(callback_)); TrackNamespace big_namespace{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "1", "2", "3", "4", "5", "6"}; - EXPECT_CALL(*session_, PublishNamespace(big_namespace, _, _)); tree_.AddPublisher(big_namespace, session_.get()); + EXPECT_EQ(objects_available_, 1); + TrackNamespace suffix; + TransactionType type; + CheckNextSuffix(task.get(), big_namespace); + EXPECT_EQ(task->GetNextSuffix(suffix, type), GetNextResult::kPending); } // TODO(martinduke): Add tests for published tracks.
diff --git a/quiche/quic/moqt/session_namespace_tree.h b/quiche/quic/moqt/session_namespace_tree.h index b68f382..ccbdd24 100644 --- a/quiche/quic/moqt/session_namespace_tree.h +++ b/quiche/quic/moqt/session_namespace_tree.h
@@ -9,7 +9,8 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_names.h" +#include "quiche/common/quiche_weak_ptr.h" namespace moqt { @@ -22,7 +23,7 @@ // a/b and a/b/c/d would not be. class SessionNamespaceTree { public: - SessionNamespaceTree() = default; + SessionNamespaceTree() : weak_ptr_factory_(this) {} ~SessionNamespaceTree() {} // Returns false if the namespace was not subscribed. @@ -69,6 +70,10 @@ return subscribed_namespaces_; } + quiche::QuicheWeakPtr<SessionNamespaceTree> GetWeakPtr() { + return weak_ptr_factory_.Create(); + } + protected: uint64_t NumSubscriptions() const { return subscribed_namespaces_.size(); } @@ -77,6 +82,7 @@ // Namespaces that cannot be subscribed to because they intersect with an // existing subscription. The value is a ref count. absl::flat_hash_map<TrackNamespace, int> prohibited_namespaces_; + quiche::QuicheWeakPtrFactory<SessionNamespaceTree> weak_ptr_factory_; }; } // namespace moqt
diff --git a/quiche/quic/moqt/test_tools/mock_moqt_session.h b/quiche/quic/moqt/test_tools/mock_moqt_session.h index bae22ee..3ba898a 100644 --- a/quiche/quic/moqt/test_tools/mock_moqt_session.h +++ b/quiche/quic/moqt/test_tools/mock_moqt_session.h
@@ -6,10 +6,14 @@ #define QUICHE_QUIC_MOQT_TOOLS_MOCK_MOQT_SESSION_H_ #include <cstdint> +#include <memory> #include <optional> #include "absl/strings/string_view.h" +#include "quiche/quic/moqt/moqt_error.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" +#include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_priority.h" #include "quiche/quic/moqt/moqt_session_callbacks.h" @@ -63,6 +67,10 @@ (override)); MOCK_METHOD(bool, PublishNamespaceDone, (TrackNamespace track_namespace), (override)); + MOCK_METHOD(std::unique_ptr<MoqtNamespaceTask>, SubscribeNamespace, + (TrackNamespace&, SubscribeNamespaceOption, + const MessageParameters&, MoqtResponseCallback), + (override)); quiche::QuicheWeakPtr<MoqtSessionInterface> GetWeakPtr() override { return weak_factory_.Create();
diff --git a/quiche/quic/moqt/test_tools/moqt_framer_utils.cc b/quiche/quic/moqt/test_tools/moqt_framer_utils.cc index 50f8ae2..a16e56b 100644 --- a/quiche/quic/moqt/test_tools/moqt_framer_utils.cc +++ b/quiche/quic/moqt/test_tools/moqt_framer_utils.cc
@@ -76,9 +76,6 @@ quiche::QuicheBuffer operator()(const MoqtSubscribeNamespace& message) { return framer.SerializeSubscribeNamespace(message); } - quiche::QuicheBuffer operator()(const MoqtUnsubscribeNamespace& message) { - return framer.SerializeUnsubscribeNamespace(message); - } quiche::QuicheBuffer operator()(const MoqtMaxRequestId& message) { return framer.SerializeMaxRequestId(message); } @@ -165,9 +162,6 @@ void OnSubscribeNamespaceMessage(const MoqtSubscribeNamespace& message) { frames_.push_back(message); } - void OnUnsubscribeNamespaceMessage(const MoqtUnsubscribeNamespace& message) { - frames_.push_back(message); - } void OnMaxRequestIdMessage(const MoqtMaxRequestId& message) { frames_.push_back(message); }
diff --git a/quiche/quic/moqt/test_tools/moqt_framer_utils.h b/quiche/quic/moqt/test_tools/moqt_framer_utils.h index 77bb4c7..d2007a7 100644 --- a/quiche/quic/moqt/test_tools/moqt_framer_utils.h +++ b/quiche/quic/moqt/test_tools/moqt_framer_utils.h
@@ -31,10 +31,9 @@ MoqtUnsubscribe, MoqtPublishDone, MoqtSubscribeUpdate, MoqtPublishNamespace, MoqtPublishNamespaceDone, MoqtNamespace, MoqtNamespaceDone, MoqtPublishNamespaceCancel, MoqtTrackStatus, - MoqtGoAway, MoqtSubscribeNamespace, MoqtUnsubscribeNamespace, - MoqtMaxRequestId, MoqtFetch, MoqtFetchCancel, MoqtFetchOk, - MoqtRequestsBlocked, MoqtPublish, MoqtPublishOk, - MoqtObjectAck>; + MoqtGoAway, MoqtSubscribeNamespace, MoqtMaxRequestId, + MoqtFetch, MoqtFetchCancel, MoqtFetchOk, MoqtRequestsBlocked, + MoqtPublish, MoqtPublishOk, MoqtObjectAck>; std::string SerializeGenericMessage(const MoqtGenericFrame& frame, bool use_webtrans = false);
diff --git a/quiche/quic/moqt/test_tools/moqt_mock_visitor.h b/quiche/quic/moqt/test_tools/moqt_mock_visitor.h index 90807d3..b568a78 100644 --- a/quiche/quic/moqt/test_tools/moqt_mock_visitor.h +++ b/quiche/quic/moqt/test_tools/moqt_mock_visitor.h
@@ -16,6 +16,7 @@ #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "quiche/quic/core/quic_time.h" +#include "quiche/quic/moqt/moqt_error.h" #include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_messages.h" @@ -28,6 +29,7 @@ #include "quiche/quic/moqt/moqt_session_interface.h" #include "quiche/common/platform/api/quiche_test.h" #include "quiche/common/quiche_mem_slice.h" +#include "quiche/common/quiche_weak_ptr.h" #include "quiche/web_transport/web_transport.h" namespace moqt::test { @@ -41,9 +43,9 @@ std::optional<VersionSpecificParameters>, MoqtResponseCallback)> incoming_publish_namespace_callback; - testing::MockFunction<void(const TrackNamespace&, - std::optional<MessageParameters>, - MoqtResponseCallback)> + testing::MockFunction<std::unique_ptr<MoqtNamespaceTask>( + const TrackNamespace&, SubscribeNamespaceOption, const MessageParameters&, + MoqtResponseCallback)> incoming_subscribe_namespace_callback; MockSessionCallbacks() { @@ -266,6 +268,36 @@ bool synchronous_object_available_ = false; }; +class MockNamespaceTask : public MoqtNamespaceTask { + public: + explicit MockNamespaceTask(const TrackNamespace& prefix) + : prefix_(prefix), weak_ptr_factory_(this) {} + void SetObjectsAvailableCallback(ObjectsAvailableCallback + absl_nullable callback) override { + callback_ = std::move(callback); + } + MOCK_METHOD(GetNextResult, GetNextSuffix, + (TrackNamespace & whole_namespace, TransactionType& type), + (override)); + MOCK_METHOD(std::optional<webtransport::StreamErrorCode>, GetStatus, (), + (override)); + const TrackNamespace& prefix() override { return prefix_; } + + void InvokeCallback() { + if (callback_ != nullptr) { + callback_(); + } + } + quiche::QuicheWeakPtr<MockNamespaceTask> GetWeakPtr() { + return weak_ptr_factory_.Create(); + } + + private: + ObjectsAvailableCallback callback_; + TrackNamespace prefix_; + quiche::QuicheWeakPtrFactory<MockNamespaceTask> weak_ptr_factory_; +}; + class MockMoqtObjectListener : public MoqtObjectListener { public: MOCK_METHOD(void, OnSubscribeAccepted, (), (override));
diff --git a/quiche/quic/moqt/test_tools/moqt_parser_test_visitor.h b/quiche/quic/moqt/test_tools/moqt_parser_test_visitor.h index 4c0072b..68f16b1 100644 --- a/quiche/quic/moqt/test_tools/moqt_parser_test_visitor.h +++ b/quiche/quic/moqt/test_tools/moqt_parser_test_visitor.h
@@ -97,10 +97,6 @@ const MoqtSubscribeNamespace& message) override { OnControlMessage(message); } - void OnUnsubscribeNamespaceMessage( - const MoqtUnsubscribeNamespace& message) override { - OnControlMessage(message); - } void OnMaxRequestIdMessage(const MoqtMaxRequestId& message) override { OnControlMessage(message); }
diff --git a/quiche/quic/moqt/test_tools/moqt_test_message.h b/quiche/quic/moqt/test_tools/moqt_test_message.h index 6d4eb73..051b99a 100644 --- a/quiche/quic/moqt/test_tools/moqt_test_message.h +++ b/quiche/quic/moqt/test_tools/moqt_test_message.h
@@ -90,16 +90,14 @@ public: virtual ~TestMessageBase() = default; - using MessageStructuredData = - std::variant<MoqtClientSetup, MoqtServerSetup, MoqtObject, MoqtRequestOk, - MoqtRequestError, MoqtSubscribe, MoqtSubscribeOk, - MoqtUnsubscribe, MoqtPublishDone, MoqtSubscribeUpdate, - MoqtPublishNamespace, MoqtPublishNamespaceDone, - MoqtPublishNamespaceCancel, MoqtTrackStatus, MoqtGoAway, - MoqtSubscribeNamespace, MoqtUnsubscribeNamespace, - MoqtMaxRequestId, MoqtFetch, MoqtFetchCancel, MoqtFetchOk, - MoqtRequestsBlocked, MoqtPublish, MoqtPublishOk, - MoqtNamespace, MoqtNamespaceDone, MoqtObjectAck>; + using MessageStructuredData = std::variant< + MoqtClientSetup, MoqtServerSetup, MoqtObject, MoqtRequestOk, + MoqtRequestError, MoqtSubscribe, MoqtSubscribeOk, MoqtUnsubscribe, + MoqtPublishDone, MoqtSubscribeUpdate, MoqtPublishNamespace, + MoqtPublishNamespaceDone, MoqtPublishNamespaceCancel, MoqtTrackStatus, + MoqtGoAway, MoqtSubscribeNamespace, MoqtMaxRequestId, MoqtFetch, + MoqtFetchCancel, MoqtFetchOk, MoqtRequestsBlocked, MoqtPublish, + MoqtPublishOk, MoqtNamespace, MoqtNamespaceDone, MoqtObjectAck>; // The total actual size of the message. size_t total_message_size() const { return wire_image_size_; } @@ -1259,37 +1257,6 @@ }; }; -class QUICHE_NO_EXPORT UnsubscribeNamespaceMessage : public TestMessageBase { - public: - UnsubscribeNamespaceMessage() : TestMessageBase() { - SetWireImage(raw_packet_, sizeof(raw_packet_)); - } - - bool EqualFieldValues(MessageStructuredData& values) const override { - auto cast = std::get<MoqtUnsubscribeNamespace>(values); - if (cast.track_namespace != unsubscribe_namespace_.track_namespace) { - QUIC_LOG(INFO) << "UNSUBSCRIBE_NAMESPACE track namespace mismatch"; - return false; - } - return true; - } - - void ExpandVarints() override { ExpandVarintsImpl("vv---"); } - - MessageStructuredData structured_data() const override { - return TestMessageBase::MessageStructuredData(unsubscribe_namespace_); - } - - private: - uint8_t raw_packet_[8] = { - 0x14, 0x00, 0x05, 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace - }; - - MoqtUnsubscribeNamespace unsubscribe_namespace_ = { - TrackNamespace("foo"), - }; -}; - class QUICHE_NO_EXPORT MaxRequestIdMessage : public TestMessageBase { public: MaxRequestIdMessage() : TestMessageBase() { @@ -1880,8 +1847,6 @@ return std::make_unique<GoAwayMessage>(); case MoqtMessageType::kSubscribeNamespace: return std::make_unique<SubscribeNamespaceMessage>(); - case MoqtMessageType::kUnsubscribeNamespace: - return std::make_unique<UnsubscribeNamespaceMessage>(); case MoqtMessageType::kMaxRequestId: return std::make_unique<MaxRequestIdMessage>(); case MoqtMessageType::kFetch:
diff --git a/quiche/quic/moqt/tools/chat_client.cc b/quiche/quic/moqt/tools/chat_client.cc index d5900b7..0baf539 100644 --- a/quiche/quic/moqt/tools/chat_client.cc +++ b/quiche/quic/moqt/tools/chat_client.cc
@@ -14,8 +14,10 @@ #include <utility> #include <variant> +#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/functional/bind_front.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "quiche/quic/core/crypto/proof_verifier.h" #include "quiche/quic/core/io/quic_default_event_loop.h" @@ -23,8 +25,10 @@ #include "quiche/quic/core/quic_default_clock.h" #include "quiche/quic/core/quic_server_id.h" #include "quiche/quic/moqt/moqt_error.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_known_track_publisher.h" +#include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_object.h" #include "quiche/quic/moqt/moqt_outgoing_queue.h" @@ -46,13 +50,13 @@ void ChatClient::OnIncomingPublishNamespace( const moqt::TrackNamespace& track_namespace, std::optional<VersionSpecificParameters> parameters, - moqt::MoqtResponseCallback callback) { + moqt::MoqtResponseCallback absl_nullable callback) { if (!session_is_open_) { return; } if (track_namespace == GetUserNamespace(my_track_name_)) { // Ignore PUBLISH_NAMESPACE for my own track. - if (parameters.has_value()) { // callback exists. + if (parameters.has_value() && callback != nullptr) { // callback exists. std::move(callback)(std::nullopt); } return; @@ -70,20 +74,26 @@ std::cout << "PUBLISH_NAMESPACE for " << track_namespace.ToString() << "\n"; if (!track_name.has_value()) { std::cout << "PUBLISH_NAMESPACE rejected, invalid namespace\n"; - std::move(callback)(std::make_optional<MoqtRequestErrorInfo>( - RequestErrorCode::kTrackDoesNotExist, std::nullopt, - "Not a subscribed namespace")); + if (callback != nullptr) { + std::move(callback)(std::make_optional<MoqtRequestErrorInfo>( + RequestErrorCode::kTrackDoesNotExist, std::nullopt, + "Not a subscribed namespace")); + } return; } if (other_users_.contains(*track_name)) { std::cout << "Duplicate PUBLISH_NAMESPACE, send OK and ignore\n"; - std::move(callback)(std::nullopt); + if (callback != nullptr) { + std::move(callback)(std::nullopt); + } return; } if (GetUsername(my_track_name_) == GetUsername(*track_name)) { std::cout << "PUBLISH_NAMESPACE for a previous instance of my track, " "do not subscribe\n"; - std::move(callback)(std::nullopt); + if (callback != nullptr) { + std::move(callback)(std::nullopt); + } return; } MessageParameters subscribe_parameters(MoqtFilterType::kLargestObject); @@ -97,7 +107,9 @@ subscribe_parameters)) { ++subscribes_to_make_; } - std::move(callback)(std::nullopt); // Send PUBLISH_NAMESPACE_OK. + if (callback != nullptr) { + std::move(callback)(std::nullopt); // Send PUBLISH_NAMESPACE_OK. + } } ChatClient::ChatClient(const quic::QuicServerId& server_id, @@ -170,7 +182,7 @@ } if (input_message == "/exit") { // Clean teardown of SUBSCRIBE_NAMESPACE, PUBLISH_NAMESPACE, SUBSCRIBE. - session_->UnsubscribeNamespace(GetChatNamespace(my_track_name_)); + namespace_task_.reset(); // TODO(martinduke): Add a session API to send PUBLISH_DONE. session_->PublishNamespaceDone(GetUserNamespace(my_track_name_)); for (const auto& track_name : other_users_) { @@ -272,9 +284,9 @@ // Send SUBSCRIBE_NAMESPACE. Pop 3 levels of namespace to get to // {moq-chat, chat-id} bool subscribe_response_received = false; - MoqtOutgoingSubscribeNamespaceCallback subscribe_namespace_callback = - [&, this](TrackNamespace track_namespace, - std::optional<MoqtRequestErrorInfo> error) { + TrackNamespace prefix = GetChatNamespace(my_track_name_); + MoqtResponseCallback response_callback = + [&, this, prefix](std::optional<MoqtRequestErrorInfo> error) { subscribe_response_received = true; if (error.has_value()) { std::cout << "SUBSCRIBE_NAMESPACE rejected, " << error->reason_phrase @@ -283,16 +295,50 @@ "Local SUBSCRIBE_NAMESPACE rejected"); return; } - std::cout << "SUBSCRIBE_NAMESPACE for " << track_namespace.ToString() + std::cout << "SUBSCRIBE_NAMESPACE for " << prefix.ToString() << " accepted\n"; return; }; MessageParameters parameters; parameters.authorization_tokens.emplace_back( AuthTokenType::kOutOfBand, std::string(GetUsername(my_track_name_))); - session_->SubscribeNamespace(GetChatNamespace(my_track_name_), - std::move(subscribe_namespace_callback), - parameters); + namespace_task_ = + session_->SubscribeNamespace(prefix, SubscribeNamespaceOption::kNamespace, + parameters, std::move(response_callback)); + if (namespace_task_ != nullptr) { + namespace_task_->SetObjectsAvailableCallback( + [this]() { + TrackNamespace suffix; + TransactionType type; + for (;;) { + GetNextResult result = namespace_task_->GetNextSuffix(suffix, type); + switch (result) { + case GetNextResult::kPending: + return; + case GetNextResult::kEof: + return; + case GetNextResult::kError: + std::cerr << "Error: received error from namespace task\n"; + return; + case GetNextResult::kSuccess: + absl::StatusOr<TrackNamespace> track_namespace = + namespace_task_->prefix().AddSuffix(suffix); + if (!track_namespace.ok()) { + std::cerr + << "Error: received invalid suffix from namespace task\n"; + return; + } + OnIncomingPublishNamespace( + *track_namespace, + (type == TransactionType::kAdd) + ? std::make_optional(VersionSpecificParameters()) + : std::nullopt, + /*callback=*/nullptr); + break; + } + } + }); + } while (session_is_open_ && !subscribe_response_received) { RunEventLoop();
diff --git a/quiche/quic/moqt/tools/chat_client.h b/quiche/quic/moqt/tools/chat_client.h index 2c9ffd8..a042e8d 100644 --- a/quiche/quic/moqt/tools/chat_client.h +++ b/quiche/quic/moqt/tools/chat_client.h
@@ -9,13 +9,18 @@ #include <optional> #include <variant> +#include "absl/base/nullability.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "quiche/quic/core/io/quic_event_loop.h" #include "quiche/quic/core/quic_server_id.h" #include "quiche/quic/core/quic_time.h" +#include "quiche/quic/moqt/moqt_error.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" +#include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_known_track_publisher.h" #include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_object.h" #include "quiche/quic/moqt/moqt_outgoing_queue.h" #include "quiche/quic/moqt/moqt_session.h" @@ -131,11 +136,13 @@ private: void RunEventLoop() { event_loop_->RunEventLoopOnce(kChatEventLoopDuration); } - // Callback for incoming publish_namespaces. + // Callback for incoming publish_namespaces. If |parameters| is not nullopt, + // it's a PUBLISH_NAMESPACE or NAMESPACE. If |callback| is not nullptr, it's + // a PUBLISH_NAMESPACE. void OnIncomingPublishNamespace( const moqt::TrackNamespace& track_namespace, std::optional<VersionSpecificParameters> parameters, - moqt::MoqtResponseCallback callback); + moqt::MoqtResponseCallback absl_nullable callback); // Basic session information FullTrackName my_track_name_; @@ -155,6 +162,7 @@ moqt::MoqtKnownTrackPublisher publisher_; std::unique_ptr<moqt::MoqtClient> client_; moqt::MoqtSessionCallbacks session_callbacks_; + std::unique_ptr<moqt::MoqtNamespaceTask> namespace_task_; // Related to syncing. absl::flat_hash_set<FullTrackName> other_users_;
diff --git a/quiche/quic/moqt/tools/moq_chat_end_to_end_test.cc b/quiche/quic/moqt/tools/moq_chat_end_to_end_test.cc index 81ce8ac..18d915c 100644 --- a/quiche/quic/moqt/tools/moq_chat_end_to_end_test.cc +++ b/quiche/quic/moqt/tools/moq_chat_end_to_end_test.cc
@@ -10,8 +10,8 @@ #include "absl/strings/string_view.h" #include "quiche/quic/core/io/quic_event_loop.h" #include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_names.h" -#include "quiche/quic/moqt/test_tools/mock_moqt_session.h" #include "quiche/quic/moqt/tools/chat_client.h" #include "quiche/quic/moqt/tools/moq_chat.h" #include "quiche/quic/moqt/tools/moqt_relay.h" @@ -146,27 +146,34 @@ SendAndWaitForOutput(interface2_, interface1_, "client2", "Hi"); // Add a probe to see how many publishers are tracked on the relay. - MockMoqtSession namespace_probe; - EXPECT_CALL(namespace_probe, PublishNamespace).Times(3); - relay_.publisher()->AddNamespaceSubscriber( - TrackNamespace(moq_chat::kBasePath), &namespace_probe); + int namespaces_announced = 0; + TrackNamespace last_suffix; + TransactionType last_type; + std::unique_ptr<MoqtNamespaceTask> namespace_probe = + relay_.publisher()->AddNamespaceSubscriber( + TrackNamespace(moq_chat::kBasePath), nullptr); + namespace_probe->SetObjectsAvailableCallback([&]() { + while (namespace_probe->GetNextSuffix(last_suffix, last_type) == kSuccess) { + if (last_type == TransactionType::kAdd) { + ++namespaces_announced; + } else { + --namespaces_announced; + } + } + }); + EXPECT_EQ(namespaces_announced, 2); - bool namespace_done = false; - EXPECT_CALL(namespace_probe, PublishNamespaceDone) - .WillOnce([&](const TrackNamespace&) { - namespace_done = true; - return true; - }); interface1_->SendMessage("/exit"); while (client1_->session_is_open()) { relay_.server()->WaitForEvents(); } client1_.reset(); - while (!namespace_done) { + while (last_type != TransactionType::kDelete) { // Wait for the relay's session cleanup to send PUBLISH_NAMESPACE_DONE // and PUBLISH_DONE. relay_.server()->WaitForEvents(); } + EXPECT_EQ(namespaces_announced, 1); // Create a new client with the same username and Reconnect. auto if1bptr = std::make_unique<MockChatUserInterface>(); MockChatUserInterface* interface1b_ = if1bptr.get(); @@ -182,9 +189,9 @@ } SendAndWaitForOutput(interface1b_, interface2_, "client1", "Hello again"); SendAndWaitForOutput(interface2_, interface1b_, "client2", "Hi again"); + EXPECT_EQ(namespaces_announced, 2); // Cleanup the probe. - relay_.publisher()->RemoveNamespaceSubscriber( - TrackNamespace(moq_chat::kBasePath), &namespace_probe); + namespace_probe.reset(); } } // namespace test
diff --git a/quiche/quic/moqt/tools/moqt_relay.cc b/quiche/quic/moqt/tools/moqt_relay.cc index 45e84d0..cc7dd54 100644 --- a/quiche/quic/moqt/tools/moqt_relay.cc +++ b/quiche/quic/moqt/tools/moqt_relay.cc
@@ -16,7 +16,9 @@ #include "quiche/quic/core/crypto/proof_verifier.h" #include "quiche/quic/core/io/quic_event_loop.h" #include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" +#include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_session.h" #include "quiche/quic/moqt/moqt_session_callbacks.h" @@ -123,19 +125,27 @@ } }; session->callbacks().incoming_subscribe_namespace_callback = - [this, session](const TrackNamespace& track_namespace, - const std::optional<MessageParameters>& parameters, - MoqtResponseCallback callback) { - if (is_closing_) { - return; - } - if (parameters.has_value()) { - publisher_.AddNamespaceSubscriber(track_namespace, session); - std::move(callback)(std::nullopt); - } else { - publisher_.RemoveNamespaceSubscriber(track_namespace, session); - } - }; + [this, session](const TrackNamespace& prefix, + SubscribeNamespaceOption option, + const MessageParameters& parameters, + MoqtResponseCallback response_callback) + -> std::unique_ptr<MoqtNamespaceTask> { + if (is_closing_) { + return nullptr; + } + std::unique_ptr<MoqtNamespaceTask> task; + switch (option) { + case SubscribeNamespaceOption::kNamespace: + task = publisher_.AddNamespaceSubscriber(prefix, nullptr); + break; + case SubscribeNamespaceOption::kBoth: + case SubscribeNamespaceOption::kPublish: + task = publisher_.AddNamespaceSubscriber(prefix, session); + break; + } + std::move(response_callback)(std::nullopt); + return task; + }; } absl::StatusOr<MoqtConfigureSessionCallback> MoqtRelay::IncomingSessionHandler(
diff --git a/quiche/quic/moqt/tools/moqt_relay_test.cc b/quiche/quic/moqt/tools/moqt_relay_test.cc index 83210cb..8aa2581 100644 --- a/quiche/quic/moqt/tools/moqt_relay_test.cc +++ b/quiche/quic/moqt/tools/moqt_relay_test.cc
@@ -16,7 +16,9 @@ #include "quiche/quic/core/io/quic_event_loop.h" #include "quiche/quic/core/quic_time.h" #include "quiche/quic/moqt/moqt_error.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" +#include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_publisher.h" #include "quiche/quic/moqt/moqt_relay_publisher.h" @@ -154,36 +156,37 @@ TEST_F(MoqtRelayTest, SubscribeNamespace) { TrackNamespace foo({"foo"}), foobar({"foo", "bar"}), foobaz({"foo", "baz"}); - // These will be used to ascertain the namespace state. - MockMoqtSession relay_probe, upstream_probe; std::set<TrackNamespace> relay_published_namespaces, upstream_published_namespaces; - EXPECT_CALL(relay_probe, PublishNamespace) - .WillRepeatedly([&](TrackNamespace track_namespace, - MoqtOutgoingPublishNamespaceCallback callback, - VersionSpecificParameters) { - relay_published_namespaces.insert(track_namespace); - (callback)(track_namespace, std::nullopt); - }); - EXPECT_CALL(relay_probe, PublishNamespaceDone) - .WillRepeatedly([&](TrackNamespace track_namespace) { - relay_published_namespaces.erase(track_namespace); - return true; - }); - EXPECT_CALL(upstream_probe, PublishNamespace) - .WillRepeatedly([&](TrackNamespace track_namespace, - MoqtOutgoingPublishNamespaceCallback callback, - VersionSpecificParameters) { - upstream_published_namespaces.insert(track_namespace); - (callback)(track_namespace, std::nullopt); - }); - EXPECT_CALL(upstream_probe, PublishNamespaceDone) - .WillRepeatedly([&](TrackNamespace track_namespace) { - upstream_published_namespaces.erase(track_namespace); - return true; - }); - relay_.publisher()->AddNamespaceSubscriber(foo, &relay_probe); - upstream_.publisher()->AddNamespaceSubscriber(foo, &upstream_probe); + // These will be used to ascertain the namespace state. + TrackNamespace suffix; + TransactionType type; + std::unique_ptr<MoqtNamespaceTask> relay_probe = + relay_.publisher()->AddNamespaceSubscriber(foo, nullptr); + relay_probe->SetObjectsAvailableCallback([&]() { + while (relay_probe->GetNextSuffix(suffix, type) == kSuccess) { + if (type == TransactionType::kAdd) { + relay_published_namespaces.insert( + *relay_probe->prefix().AddSuffix(suffix)); + } else { + relay_published_namespaces.erase( + *relay_probe->prefix().AddSuffix(suffix)); + } + } + }); + std::unique_ptr<MoqtNamespaceTask> upstream_probe = + upstream_.publisher()->AddNamespaceSubscriber(foo, nullptr); + upstream_probe->SetObjectsAvailableCallback([&]() { + while (upstream_probe->GetNextSuffix(suffix, type) == kSuccess) { + if (type == TransactionType::kAdd) { + upstream_published_namespaces.insert( + *upstream_probe->prefix().AddSuffix(suffix)); + } else { + upstream_published_namespaces.erase( + *upstream_probe->prefix().AddSuffix(suffix)); + } + } + }); MoqtSession* upstream_session = absl::down_cast<MoqtSession*>(upstream_.last_server_session); // Downstream publishes a namespace. It's stored in relay_ but upstream_ @@ -197,9 +200,23 @@ EXPECT_TRUE(upstream_published_namespaces.empty()); // Upstream subscribes. Now it's notified and forwards it to the probe. - upstream_session->SubscribeNamespace( - foo, [](TrackNamespace, std::optional<MoqtRequestErrorInfo>) {}, - MessageParameters()); + std::unique_ptr<MoqtNamespaceTask> task = + upstream_session->SubscribeNamespace( + foo, SubscribeNamespaceOption::kNamespace, MessageParameters(), + [](std::optional<MoqtRequestErrorInfo>) {}); + EXPECT_NE(task, nullptr); + task->SetObjectsAvailableCallback([&]() { + while (task->GetNextSuffix(suffix, type) == kSuccess) { + if (type == TransactionType::kAdd) { + upstream_.publisher()->OnPublishNamespace( + *task->prefix().AddSuffix(suffix), VersionSpecificParameters(), + upstream_session, nullptr); + } else { + upstream_.publisher()->OnPublishNamespaceDone( + *task->prefix().AddSuffix(suffix), upstream_session); + } + } + }); upstream_.RunOneEvent(); upstream_.RunOneEvent(); EXPECT_THAT(upstream_published_namespaces, ElementsAre(foobar)); @@ -221,7 +238,9 @@ EXPECT_THAT(upstream_published_namespaces, ElementsAre(foobaz)); // upstream_ unsubscribes. New PUBLISH_NAMESPACE_DONE doesn't arrive. - upstream_session->UnsubscribeNamespace(foo); + task.reset(); + upstream_.RunOneEvent(); + relay_.RunOneEvent(); downstream_.client_session()->PublishNamespaceDone(foobaz); upstream_.RunOneEvent(); relay_.RunOneEvent(); @@ -229,8 +248,8 @@ EXPECT_THAT(upstream_published_namespaces, ElementsAre(foobaz)); // Remove the probes to avoid accessing an invalid WeakPtr on teardown. - relay_.publisher()->RemoveNamespaceSubscriber(foo, &relay_probe); - upstream_.publisher()->RemoveNamespaceSubscriber(foo, &upstream_probe); + relay_probe.reset(); + upstream_probe.reset(); } #if 0 // TODO(martinduke): Re-enable these tests when GOAWAY support exists.