Support REQUEST_UPDATE for SUBSCRIBE_NAMESPACE doesn't actually do anything because only Auth matters. PiperOrigin-RevId: 866694952
diff --git a/quiche/quic/moqt/moqt_fetch_task.h b/quiche/quic/moqt/moqt_fetch_task.h index f18d21f..0b97b82 100644 --- a/quiche/quic/moqt/moqt_fetch_task.h +++ b/quiche/quic/moqt/moqt_fetch_task.h
@@ -14,6 +14,7 @@ #include "absl/base/nullability.h" #include "absl/status/status.h" #include "quiche/quic/moqt/moqt_error.h" +#include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_object.h" @@ -22,6 +23,12 @@ namespace moqt { +// The callback we'll use for all request types going forward. Can only be used +// once; if the argument is nullopt, an OK response was received. Otherwise, an +// ERROR response was received. +using MoqtResponseCallback = + quiche::SingleUseCallback<void(std::optional<MoqtRequestErrorInfo>)>; + // TODO(martinduke): There are will be multiple instances of flow-controlled // "pull" data retrieval tasks. It might be worthwhile to extract some common // features into a base class. @@ -106,6 +113,10 @@ // Returns the error if request has completely failed, and nullopt otherwise. virtual std::optional<webtransport::StreamErrorCode> GetStatus() = 0; + // Handle a REQUEST_UPDATE message. + virtual void Update(const MessageParameters& parameters, + MoqtResponseCallback response_callback) = 0; + // Returns the prefix for this task. virtual const TrackNamespace& prefix() = 0; };
diff --git a/quiche/quic/moqt/moqt_namespace_stream.cc b/quiche/quic/moqt/moqt_namespace_stream.cc index ca59fd6..206d04c 100644 --- a/quiche/quic/moqt/moqt_namespace_stream.cc +++ b/quiche/quic/moqt/moqt_namespace_stream.cc
@@ -4,6 +4,7 @@ #include "quiche/quic/moqt/moqt_namespace_stream.h" +#include <cstdint> #include <memory> #include <optional> #include <utility> @@ -40,33 +41,57 @@ void MoqtNamespaceSubscriberStream::OnRequestOkMessage( const MoqtRequestOk& message) { - if (message.request_id != request_id_) { + if (message.request_id == request_id_) { + // Response to the initial SUBSCRIBE_NAMESPACE. + if (response_callback_ == nullptr) { + OnParsingError(MoqtError::kProtocolViolation, "Two responses"); + return; + } + std::move(response_callback_)(std::nullopt); + response_callback_ = nullptr; + return; + } + NamespaceTask* task = task_.GetIfAvailable(); + if (task == nullptr) { + // The application has already unsubscribed, and the stream has been reset. + // This is irrelevant. + return; + } + MoqtResponseCallback callback = task->GetResponseCallback(message.request_id); + if (callback == nullptr) { OnParsingError(MoqtError::kProtocolViolation, "Unexpected request ID in response"); return; } - if (response_callback_ == nullptr) { - OnParsingError(MoqtError::kProtocolViolation, "Two responses"); - return; - } - std::move(response_callback_)(std::nullopt); - response_callback_ = nullptr; + std::move(callback)(std::nullopt); } void MoqtNamespaceSubscriberStream::OnRequestErrorMessage( const MoqtRequestError& message) { - if (message.request_id != request_id_) { + if (message.request_id == request_id_) { + if (response_callback_ == nullptr) { + OnParsingError(MoqtError::kProtocolViolation, "Two responses"); + return; + } + std::move(response_callback_)(MoqtRequestErrorInfo{ + message.error_code, message.retry_interval, message.reason_phrase}); + response_callback_ = nullptr; + return; + } + NamespaceTask* task = task_.GetIfAvailable(); + if (task == nullptr) { + // The application has already unsubscribed, and the stream has been reset. + // This is irrelevant. + return; + } + MoqtResponseCallback callback = task->GetResponseCallback(message.request_id); + if (callback == nullptr) { OnParsingError(MoqtError::kProtocolViolation, - "Unexpected request ID in error"); + "Unexpected request ID in response"); return; } - if (response_callback_ == nullptr) { - OnParsingError(MoqtError::kProtocolViolation, "Two responses"); - return; - } - std::move(response_callback_)(MoqtRequestErrorInfo{ + std::move(callback)(MoqtRequestErrorInfo{ message.error_code, message.retry_interval, message.reason_phrase}); - response_callback_ = nullptr; } void MoqtNamespaceSubscriberStream::OnNamespaceMessage( @@ -152,6 +177,21 @@ } } +void MoqtNamespaceSubscriberStream::NamespaceTask::Update( + const MessageParameters& parameters, + MoqtResponseCallback response_callback) { + if (state_ == nullptr) { + std::move(response_callback)( + MoqtRequestErrorInfo{RequestErrorCode::kInternalError, std::nullopt, + "Stream has been reset"}); + return; + } + MoqtRequestUpdate message{next_request_id_, state_->request_id_, parameters}; + pending_updates_[message.request_id] = std::move(response_callback); + state_->SendOrBufferMessage(state_->framer_->SerializeRequestUpdate(message)); + next_request_id_ += 2; +} + GetNextResult MoqtNamespaceSubscriberStream::NamespaceTask::GetNextSuffix( TrackNamespace& suffix, TransactionType& type) { if (pending_suffixes_.empty()) { @@ -195,6 +235,18 @@ } } +MoqtResponseCallback +MoqtNamespaceSubscriberStream::NamespaceTask::GetResponseCallback( + uint64_t request_id) { + auto it = pending_updates_.find(request_id); + if (it == pending_updates_.end()) { + return nullptr; + } + MoqtResponseCallback callback = std::move(it->second); + pending_updates_.erase(it); + return callback; +} + MoqtNamespacePublisherStream::MoqtNamespacePublisherStream( MoqtFramer* framer, webtransport::Stream* stream, SessionErrorCallback session_error_callback, @@ -250,6 +302,23 @@ } } +void MoqtNamespacePublisherStream::OnRequestUpdateMessage( + const MoqtRequestUpdate& message) { + if (task_ == nullptr) { + // This stream is dying. + return; + } + task_->Update(message.parameters, + [this, request_id = message.request_id]( + std::optional<MoqtRequestErrorInfo> error) { + if (error.has_value()) { + SendRequestError(request_id, *error, /*fin=*/false); + } else { + SendRequestOk(request_id, MessageParameters()); + } + }); +} + void MoqtNamespacePublisherStream::ProcessNamespaces() { if (task_ == nullptr) { return;
diff --git a/quiche/quic/moqt/moqt_namespace_stream.h b/quiche/quic/moqt/moqt_namespace_stream.h index 3034cab..ddc63e2 100644 --- a/quiche/quic/moqt/moqt_namespace_stream.h +++ b/quiche/quic/moqt/moqt_namespace_stream.h
@@ -12,10 +12,12 @@ #include <utility> #include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "quiche/quic/moqt/moqt_bidi_stream.h" #include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_framer.h" +#include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_session_callbacks.h" @@ -61,6 +63,7 @@ : MoqtNamespaceTask(), prefix_(prefix), state_(state), + next_request_id_(state->request_id_ + 2), weak_ptr_factory_(this) {} ~NamespaceTask() override; @@ -75,6 +78,8 @@ return error_; } const TrackNamespace& prefix() override { return prefix_; } + void Update(const MessageParameters& parameters, + MoqtResponseCallback response_callback) override; // Queues a suffix corresponding to a NAMESPACE (if |type| is kAdd) or a // NAMESPACE_DONE (if |type| is kDelete). @@ -82,6 +87,7 @@ // The stream is closed, so no more NAMESPACE messages are forthcoming. // This is an implicit NAMESPACE_DONE for all published namespaces. void DeclareEof(); + MoqtResponseCallback GetResponseCallback(uint64_t request_id); quiche::QuicheWeakPtr<NamespaceTask> GetWeakPtr() { return weak_ptr_factory_.Create(); } @@ -100,6 +106,8 @@ ObjectsAvailableCallback absl_nullable callback_ = nullptr; std::optional<webtransport::StreamErrorCode> error_; bool eof_ = false; + uint64_t next_request_id_; + absl::flat_hash_map<uint64_t, MoqtResponseCallback> pending_updates_; // Must be last. quiche::QuicheWeakPtrFactory<NamespaceTask> weak_ptr_factory_; }; @@ -123,10 +131,7 @@ // MoqtBidiStreamBase overrides. void OnSubscribeNamespaceMessage( const MoqtSubscribeNamespace& message) override; - // TODO(martinduke): Implement this. - void OnRequestUpdateMessage(const MoqtRequestUpdate&) override { - QUICHE_DLOG(INFO) << "Got REQUEST_UPDATE on Namespace stream"; - } + void OnRequestUpdateMessage(const MoqtRequestUpdate&) override; private: void ProcessNamespaces();
diff --git a/quiche/quic/moqt/moqt_namespace_stream_test.cc b/quiche/quic/moqt/moqt_namespace_stream_test.cc index 97a0cb1..e538c27 100644 --- a/quiche/quic/moqt/moqt_namespace_stream_test.cc +++ b/quiche/quic/moqt/moqt_namespace_stream_test.cc
@@ -48,6 +48,7 @@ task_(stream_.CreateTask(kPrefix)) { task_->SetObjectsAvailableCallback([this]() { ++objects_available_; }); stream_.set_stream(&mock_stream_); + ON_CALL(mock_stream_, CanWrite()).WillByDefault(Return(true)); } void CheckNumberOfObjectsAvailable(int expected_count) { @@ -85,7 +86,7 @@ TEST_F(MoqtNamespaceSubscriberStreamTest, RequestErrorWrongId) { EXPECT_CALL(error_callback_, Call(MoqtError::kProtocolViolation, - "Unexpected request ID in error")); + "Unexpected request ID in response")); stream_.OnRequestErrorMessage( {kRequestId + 1, RequestErrorCode::kInternalError, quic::QuicTimeDelta::FromMilliseconds(100), "bar"}); @@ -217,6 +218,36 @@ EXPECT_EQ(task->GetNextSuffix(received_namespace, type), kEof); } +TEST_F(MoqtNamespaceSubscriberStreamTest, UpdateAndRequestOk) { + EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); + stream_.OnRequestOkMessage({kRequestId}); + EXPECT_CALL(mock_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kRequestUpdate), _)); + MessageParameters update_params; + update_params.subscriber_priority = 10; + testing::MockFunction<void(std::optional<MoqtRequestErrorInfo>)> + update_response_callback; + task_->Update(update_params, update_response_callback.AsStdFunction()); + EXPECT_CALL(update_response_callback, Call(Eq(std::nullopt))); + stream_.OnRequestOkMessage({kRequestId + 2}); +} + +TEST_F(MoqtNamespaceSubscriberStreamTest, UpdateAndRequestError) { + EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); + stream_.OnRequestOkMessage({kRequestId}); + EXPECT_CALL(mock_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kRequestUpdate), _)); + MessageParameters update_params; + update_params.subscriber_priority = 10; + testing::MockFunction<void(std::optional<MoqtRequestErrorInfo>)> + update_response_callback; + task_->Update(update_params, update_response_callback.AsStdFunction()); + EXPECT_CALL(update_response_callback, Call(_)); + stream_.OnRequestErrorMessage( + {kRequestId + 2, RequestErrorCode::kInternalError, + quic::QuicTimeDelta::FromMilliseconds(100), "bar"}); +} + class MoqtNamespacePublisherStreamTest : public quiche::test::QuicheTest { public: MoqtNamespacePublisherStreamTest() @@ -330,6 +361,88 @@ stream_.OnSubscribeNamespaceMessage(message); } +TEST_F(MoqtNamespacePublisherStreamTest, RequestUpdateOk) { + MoqtSubscribeNamespace message = { + kRequestId, + TrackNamespace({"foo"}), + SubscribeNamespaceOption::kNamespace, + MessageParameters(), + }; + MockNamespaceTask* task_ptr = nullptr; + EXPECT_CALL(mock_application_, Call) + .WillOnce([&](const TrackNamespace&, SubscribeNamespaceOption, + const MessageParameters&, + MoqtResponseCallback response_callback) { + std::move(response_callback)(std::nullopt); + auto task = + std::make_unique<MockNamespaceTask>(message.track_namespace_prefix); + task_ptr = task.get(); + return task; + }); + EXPECT_CALL(mock_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _)); + stream_.OnSubscribeNamespaceMessage(message); + ASSERT_TRUE(task_ptr != nullptr); + + // Now send RequestUpdate + MoqtRequestUpdate update_message = { + kRequestId + 2, + kRequestId, + MessageParameters(), + }; + update_message.parameters.subscriber_priority = 10; + EXPECT_CALL(*task_ptr, Update(_, _)) + .WillOnce([&](const MessageParameters& params, MoqtResponseCallback cb) { + EXPECT_EQ(params.subscriber_priority, 10); + std::move(cb)(std::nullopt); + }); + EXPECT_CALL(mock_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _)); + stream_.OnRequestUpdateMessage(update_message); +} + +TEST_F(MoqtNamespacePublisherStreamTest, RequestUpdateError) { + MoqtSubscribeNamespace message = { + kRequestId, + TrackNamespace({"foo"}), + SubscribeNamespaceOption::kNamespace, + MessageParameters(), + }; + MockNamespaceTask* task_ptr = nullptr; + EXPECT_CALL(mock_application_, Call) + .WillOnce([&](const TrackNamespace&, SubscribeNamespaceOption, + const MessageParameters&, + MoqtResponseCallback response_callback) { + std::move(response_callback)(std::nullopt); + auto task = + std::make_unique<MockNamespaceTask>(message.track_namespace_prefix); + task_ptr = task.get(); + return task; + }); + EXPECT_CALL(mock_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _)); + stream_.OnSubscribeNamespaceMessage(message); + ASSERT_TRUE(task_ptr != nullptr); + + // Now send RequestUpdate + MoqtRequestUpdate update_message = { + kRequestId + 2, + kRequestId, + MessageParameters(), + }; + update_message.parameters.subscriber_priority = 10; + EXPECT_CALL(*task_ptr, Update(_, _)) + .WillOnce([&](const MessageParameters& params, MoqtResponseCallback cb) { + EXPECT_EQ(params.subscriber_priority, 10); + std::move(cb)(MoqtRequestErrorInfo{ + RequestErrorCode::kInternalError, + quic::QuicTimeDelta::FromMilliseconds(100), "bar"}); + }); + EXPECT_CALL(mock_stream_, + Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); + stream_.OnRequestUpdateMessage(update_message); +} + TEST_F(MoqtNamespacePublisherStreamTest, SubscribePrefixOverlap) { MoqtSubscribeNamespace message = { kRequestId,
diff --git a/quiche/quic/moqt/moqt_session_callbacks.h b/quiche/quic/moqt/moqt_session_callbacks.h index d43befb..2d070b7 100644 --- a/quiche/quic/moqt/moqt_session_callbacks.h +++ b/quiche/quic/moqt/moqt_session_callbacks.h
@@ -21,12 +21,6 @@ namespace moqt { -// The callback we'll use for all request types going forward. Can only be used -// once; if the argument is nullopt, an OK response was received. Otherwise, an -// ERROR response was received. -using MoqtResponseCallback = - quiche::SingleUseCallback<void(std::optional<MoqtRequestErrorInfo>)>; - // Called when the SETUP message from the peer is received. using MoqtSessionEstablishedCallback = quiche::SingleUseCallback<void()>;
diff --git a/quiche/quic/moqt/relay_namespace_tree.cc b/quiche/quic/moqt/relay_namespace_tree.cc index 8c989ed..2ac4c07 100644 --- a/quiche/quic/moqt/relay_namespace_tree.cc +++ b/quiche/quic/moqt/relay_namespace_tree.cc
@@ -16,6 +16,7 @@ #include "absl/strings/string_view.h" #include "quiche/quic/moqt/moqt_error.h" #include "quiche/quic/moqt/moqt_fetch_task.h" +#include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_session_interface.h" #include "quiche/common/platform/api/quiche_bug_tracker.h" @@ -44,6 +45,12 @@ } } +void RelayNamespaceTree::RelayNamespaceListener::Update( + const MessageParameters&, MoqtResponseCallback response_callback) { + // Don't do anything! + std::move(response_callback)(std::nullopt); +} + GetNextResult RelayNamespaceTree::RelayNamespaceListener::GetNextSuffix( TrackNamespace& suffix, TransactionType& type) { if (eof_) {
diff --git a/quiche/quic/moqt/relay_namespace_tree.h b/quiche/quic/moqt/relay_namespace_tree.h index 8add33a..2b58715 100644 --- a/quiche/quic/moqt/relay_namespace_tree.h +++ b/quiche/quic/moqt/relay_namespace_tree.h
@@ -10,13 +10,13 @@ #include <memory> #include <optional> #include <string> -#include <utility> #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "quiche/quic/moqt/moqt_fetch_task.h" +#include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_session_interface.h" #include "quiche/common/quiche_circular_deque.h" @@ -55,6 +55,8 @@ return error_; } const TrackNamespace& prefix() override { return prefix_; } + void Update(const MessageParameters& parameters, + MoqtResponseCallback response_callback) override; // Queues a suffix corresponding to a NAMESPACE (if |type| is kAdd) or a // NAMESPACE_DONE (if |type| is kDelete).
diff --git a/quiche/quic/moqt/test_tools/moqt_mock_visitor.h b/quiche/quic/moqt/test_tools/moqt_mock_visitor.h index b568a78..74dcf57 100644 --- a/quiche/quic/moqt/test_tools/moqt_mock_visitor.h +++ b/quiche/quic/moqt/test_tools/moqt_mock_visitor.h
@@ -11,6 +11,7 @@ #include <utility> #include <variant> +#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" @@ -282,6 +283,10 @@ MOCK_METHOD(std::optional<webtransport::StreamErrorCode>, GetStatus, (), (override)); const TrackNamespace& prefix() override { return prefix_; } + MOCK_METHOD(void, Update, + (const MessageParameters& parameters, + MoqtResponseCallback response_callback), + (override)); void InvokeCallback() { if (callback_ != nullptr) {