Make improvements to MoqtNamespaceStream. Apologies for having another CL on this topic, but there are several fixes here: - Setting the ObjectsAvailableCallback is no longer a separate operation in the NamespaceTask. This is an anti-pattern from FETCH that is suboptimal and causes problems when retrieving namespaces from the relay tree. - IncomingSubscribeNamespaceCallback took std::optional<parameters> because it was also invoked on Unsubscribe. Now that is achieved by destroying the task, so Parameters are always present. - There was a mistake in the test where task_ was assigned twice. - There was no actual backpressure in the ObjectsAvailableCallback implementation. - Got rid of std::move() when invoking multiuse callbacks. PiperOrigin-RevId: 865435838
diff --git a/quiche/quic/moqt/moqt_bidi_stream.h b/quiche/quic/moqt/moqt_bidi_stream.h index 10fafcb..2568c52 100644 --- a/quiche/quic/moqt/moqt_bidi_stream.h +++ b/quiche/quic/moqt/moqt_bidi_stream.h
@@ -180,6 +180,10 @@ } } + bool QueueIsFull() const { + return pending_messages_.size() == kMaxPendingMessages; + } + void SendOrBufferMessage(quiche::QuicheBuffer message, bool fin = false) { if (fin_queued_) { return;
diff --git a/quiche/quic/moqt/moqt_bidi_stream_test.cc b/quiche/quic/moqt/moqt_bidi_stream_test.cc index 1cf8beb..83c92ea 100644 --- a/quiche/quic/moqt/moqt_bidi_stream_test.cc +++ b/quiche/quic/moqt/moqt_bidi_stream_test.cc
@@ -257,4 +257,19 @@ stream_.reset(); } +TEST_F(MoqtBidiStreamTest, PendingQueueFull) { + stream_->set_stream(&mock_stream_); + EXPECT_CALL(mock_stream_, CanWrite).WillRepeatedly(Return(false)); + for (int i = 0; i < 100; ++i) { // kMaxPendingMessages = 100. + EXPECT_FALSE(stream_->QueueIsFull()); + stream_->SendOrBufferMessage( + framer_.SerializeSubscribeUpdate(MoqtSubscribeUpdate{})); + } + EXPECT_TRUE(stream_->QueueIsFull()); + EXPECT_CALL(error_callback_, Call(MoqtError::kInternalError, _)); + stream_->SendOrBufferMessage( + framer_.SerializeSubscribeUpdate(MoqtSubscribeUpdate{})); + EXPECT_TRUE(stream_->QueueIsFull()); +} + } // namespace moqt::test
diff --git a/quiche/quic/moqt/moqt_fetch_task.h b/quiche/quic/moqt/moqt_fetch_task.h index 56391ab..adc5615 100644 --- a/quiche/quic/moqt/moqt_fetch_task.h +++ b/quiche/quic/moqt/moqt_fetch_task.h
@@ -98,12 +98,6 @@ virtual GetNextResult GetNextSuffix(TrackNamespace& suffix, TransactionType& type) = 0; - // Sets the callback that is called when a NAMESPACE or NAMESPACE_DONE message - // is received. If a message is available immediately, the callback will be - // called immediately. - virtual void SetObjectAvailableCallback( - ObjectsAvailableCallback callback) = 0; - // Returns the error if request has completely failed, and nullopt otherwise. virtual std::optional<webtransport::StreamErrorCode> GetStatus() = 0;
diff --git a/quiche/quic/moqt/moqt_namespace_stream.cc b/quiche/quic/moqt/moqt_namespace_stream.cc index bbacc77..1fe45cf 100644 --- a/quiche/quic/moqt/moqt_namespace_stream.cc +++ b/quiche/quic/moqt/moqt_namespace_stream.cc
@@ -126,8 +126,10 @@ } std::unique_ptr<MoqtNamespaceTask> MoqtNamespaceSubscriberStream::CreateTask( - const TrackNamespace& prefix) { - auto task = std::make_unique<NamespaceTask>(this, prefix); + const TrackNamespace& prefix, + ObjectsAvailableCallback absl_nonnull callback) { + auto task = + std::make_unique<NamespaceTask>(this, prefix, std::move(callback)); QUICHE_DCHECK(task != nullptr); task_ = task->GetWeakPtr(); QUICHE_DCHECK(task_.IsValid()); @@ -157,14 +159,6 @@ return kSuccess; } -void MoqtNamespaceSubscriberStream::NamespaceTask::SetObjectAvailableCallback( - ObjectsAvailableCallback callback) { - callback_ = std::move(callback); - if (!pending_suffixes_.empty() || eof_ || error_.has_value()) { - std::move(callback_)(); - } -} - void MoqtNamespaceSubscriberStream::NamespaceTask::AddPendingSuffix( TrackNamespace suffix, TransactionType type) { if (pending_suffixes_.size() == kMaxPendingSuffixes) { @@ -176,7 +170,7 @@ } pending_suffixes_.push_back(PendingSuffix{std::move(suffix), type}); if (callback_ != nullptr) { - std::move(callback_)(); + callback_(); } } @@ -187,7 +181,7 @@ eof_ = true; state_ = nullptr; if (callback_ != nullptr) { - std::move(callback_)(); + callback_(); } } @@ -220,64 +214,66 @@ return; } QUICHE_DCHECK(task_ == nullptr); - task_ = application_(message.track_namespace_prefix, message.parameters, - [this](std::optional<MoqtRequestErrorInfo> error) { - if (error.has_value()) { - SendRequestError(request_id_, *error, /*fin=*/true); - } else { - SendRequestOk(request_id_, MessageParameters()); - } - }); + task_ = application_( + message.track_namespace_prefix, message.parameters, + // Response callback + [this](std::optional<MoqtRequestErrorInfo> error) { + if (error.has_value()) { + SendRequestError(request_id_, *error, /*fin=*/true); + } else { + SendRequestOk(request_id_, MessageParameters()); + } + }, + // Objects available callback + [this]() { ProcessNamespaces(); }); +} + +void MoqtNamespacePublisherStream::ProcessNamespaces() { if (task_ == nullptr) { return; } - task_->SetObjectAvailableCallback([this]() { - if (task_ == nullptr) { - return; - } - TrackNamespace suffix; - TransactionType type; - for (;;) { - GetNextResult result = task_->GetNextSuffix(suffix, type); - switch (result) { - case kPending: - return; - case kEof: - if (!SendFinOnStream(*stream()).ok()) { - OnParsingError(MoqtError::kInternalError, "Failed to send FIN"); - }; - return; - case kError: - Reset(kResetCodeCanceled); - return; - case kSuccess: { - switch (type) { - case TransactionType::kAdd: { - auto [it, inserted] = published_suffixes_.insert(suffix); - if (!inserted) { - // This should never happen. Do not send something that would - // cause a protocol violation. - return; - } - SendOrBufferMessage( - framer_->SerializeNamespace(MoqtNamespace{suffix})); - break; + TrackNamespace suffix; + TransactionType type; + while (!QueueIsFull()) { + GetNextResult result = task_->GetNextSuffix(suffix, type); + switch (result) { + case kPending: + return; + case kEof: + if (!SendFinOnStream(*stream()).ok()) { + OnParsingError(MoqtError::kInternalError, "Failed to send FIN"); + }; + return; + case kError: + Reset(kResetCodeCanceled); + return; + case kSuccess: { + switch (type) { + case TransactionType::kAdd: { + auto [it, inserted] = published_suffixes_.insert(suffix); + if (!inserted) { + // This should never happen. Do not send something that would + // cause a protocol violation. + return; } - case TransactionType::kDelete: { - if (published_suffixes_.erase(suffix) == 0) { - // This should never happen. Do not send something that would - // cause a protocol violation. - return; - } - SendOrBufferMessage( - framer_->SerializeNamespaceDone(MoqtNamespaceDone{suffix})); - break; + SendOrBufferMessage( + framer_->SerializeNamespace(MoqtNamespace{suffix})); + break; + } + case TransactionType::kDelete: { + if (published_suffixes_.erase(suffix) == 0) { + // This should never happen. Do not send something that would + // cause a protocol violation. + return; } + SendOrBufferMessage( + framer_->SerializeNamespaceDone(MoqtNamespaceDone{suffix})); + break; } } } } - }); + } } } // namespace moqt
diff --git a/quiche/quic/moqt/moqt_namespace_stream.h b/quiche/quic/moqt/moqt_namespace_stream.h index 903694c..919a4d9 100644 --- a/quiche/quic/moqt/moqt_namespace_stream.h +++ b/quiche/quic/moqt/moqt_namespace_stream.h
@@ -49,7 +49,9 @@ void OnNamespaceDoneMessage(const MoqtNamespaceDone& message) override; // Send the prefix now so it is only stored in one place (the task). - std::unique_ptr<MoqtNamespaceTask> CreateTask(const TrackNamespace& prefix); + std::unique_ptr<MoqtNamespaceTask> CreateTask(const TrackNamespace& prefix, + ObjectsAvailableCallback + absl_nonnull callback); private: // The class that will be passed to the application to consume namespace @@ -57,17 +59,18 @@ class NamespaceTask : public MoqtNamespaceTask { public: NamespaceTask(MoqtNamespaceSubscriberStream* absl_nonnull state, - const TrackNamespace& prefix) + const TrackNamespace& prefix, + ObjectsAvailableCallback absl_nonnull callback) : MoqtNamespaceTask(), prefix_(prefix), state_(state), + callback_(std::move(callback)), weak_ptr_factory_(this) {} ~NamespaceTask() override; // MoqtNamespaceTask methods. A return value of kEof implies // NAMESPACE_DONE for all outstanding namespaces. GetNextResult GetNextSuffix(TrackNamespace& suffix, TransactionType& type) override; - void SetObjectAvailableCallback(ObjectsAvailableCallback callback) override; std::optional<webtransport::StreamErrorCode> GetStatus() override { return error_; } @@ -125,6 +128,8 @@ } private: + void ProcessNamespaces(); + uint64_t request_id_; SessionNamespaceTree& tree_; MoqtIncomingSubscribeNamespaceCallbackNew& application_;
diff --git a/quiche/quic/moqt/moqt_namespace_stream_test.cc b/quiche/quic/moqt/moqt_namespace_stream_test.cc index eb31193..a49b626 100644 --- a/quiche/quic/moqt/moqt_namespace_stream_test.cc +++ b/quiche/quic/moqt/moqt_namespace_stream_test.cc
@@ -44,8 +44,6 @@ MOCK_METHOD(GetNextResult, GetNextSuffix, (TrackNamespace & whole_namespace, TransactionType& type), (override)); - MOCK_METHOD(void, SetObjectAvailableCallback, - (ObjectsAvailableCallback callback), (override)); MOCK_METHOD(std::optional<webtransport::StreamErrorCode>, GetStatus, (), (override)); const TrackNamespace& prefix() override { return prefix_; } @@ -61,10 +59,14 @@ stream_(&framer_, kRequestId, deleted_callback_.AsStdFunction(), error_callback_.AsStdFunction(), response_callback_.AsStdFunction()), - task_(stream_.CreateTask(kPrefix)) { + task_(stream_.CreateTask(kPrefix, [this]() { ++objects_available_; })) { stream_.set_stream(&mock_stream_); } + void CheckNumberOfObjectsAvailable(int expected_count) { + EXPECT_EQ(objects_available_, expected_count); + } + MoqtFramer framer_; testing::MockFunction<void()> deleted_callback_; testing::MockFunction<void(MoqtError, absl::string_view)> error_callback_; @@ -72,7 +74,8 @@ response_callback_; webtransport::test::MockStream mock_stream_; MoqtNamespaceSubscriberStream stream_; - std::unique_ptr<MoqtNamespaceTask> task_ = stream_.CreateTask(kPrefix); + int objects_available_ = 0; + std::unique_ptr<MoqtNamespaceTask> task_; }; TEST_F(MoqtNamespaceSubscriberStreamTest, RequestOk) { @@ -119,6 +122,7 @@ EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); stream_.OnRequestOkMessage({kRequestId}); stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + CheckNumberOfObjectsAvailable(1); TrackNamespace received_namespace; TransactionType type; EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kSuccess); @@ -131,7 +135,9 @@ EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); stream_.OnRequestOkMessage({kRequestId}); stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + CheckNumberOfObjectsAvailable(1); stream_.OnNamespaceDoneMessage({TrackNamespace({"bar"})}); + CheckNumberOfObjectsAvailable(2); TrackNamespace received_namespace; TransactionType type; EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kSuccess); @@ -147,6 +153,7 @@ EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); stream_.OnRequestOkMessage({kRequestId}); stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + CheckNumberOfObjectsAvailable(1); EXPECT_CALL(error_callback_, Call(MoqtError::kProtocolViolation, "Two NAMESPACE messages for the same track namespace")); @@ -166,21 +173,24 @@ stream_.OnRequestOkMessage({kRequestId}); EXPECT_CALL(error_callback_, Call).Times(0); stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + CheckNumberOfObjectsAvailable(1); stream_.OnNamespaceDoneMessage({TrackNamespace({"bar"})}); + CheckNumberOfObjectsAvailable(2); stream_.OnNamespaceMessage({TrackNamespace({"buzz"})}); + CheckNumberOfObjectsAvailable(3); } TEST_F(MoqtNamespaceSubscriberStreamTest, TaskGetNextSuffix) { EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); stream_.OnRequestOkMessage({kRequestId}); stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + CheckNumberOfObjectsAvailable(1); stream_.OnNamespaceMessage({TrackNamespace({"buzz"})}); + CheckNumberOfObjectsAvailable(2); stream_.OnNamespaceDoneMessage({TrackNamespace({"bar"})}); + CheckNumberOfObjectsAvailable(3); TrackNamespace received_namespace; TransactionType type; - bool object_available = false; - task_->SetObjectAvailableCallback([&]() { object_available = true; }); - EXPECT_TRUE(object_available); EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kSuccess); EXPECT_EQ(received_namespace, TrackNamespace({"bar"})); EXPECT_EQ(type, TransactionType::kAdd); @@ -191,16 +201,34 @@ EXPECT_EQ(received_namespace, TrackNamespace({"bar"})); EXPECT_EQ(type, TransactionType::kDelete); EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kPending); - object_available = false; stream_.OnNamespaceMessage({TrackNamespace({"another"})}); - EXPECT_TRUE(object_available); - object_available = false; + CheckNumberOfObjectsAvailable(4); EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kSuccess); EXPECT_EQ(received_namespace, TrackNamespace({"another"})); EXPECT_EQ(type, TransactionType::kAdd); EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kPending); } +TEST_F(MoqtNamespaceSubscriberStreamTest, DeclareEof) { + auto stream = std::make_unique<MoqtNamespaceSubscriberStream>( + &framer_, kRequestId, deleted_callback_.AsStdFunction(), + error_callback_.AsStdFunction(), response_callback_.AsStdFunction()); + std::unique_ptr<MoqtNamespaceTask> task = + stream->CreateTask(kPrefix, [this]() { ++objects_available_; }); + EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); + stream->OnRequestOkMessage({kRequestId}); + stream->OnNamespaceMessage({TrackNamespace({"bar"})}); + CheckNumberOfObjectsAvailable(1); + stream.reset(); + CheckNumberOfObjectsAvailable(2); + TrackNamespace received_namespace; + TransactionType type; + EXPECT_EQ(task->GetNextSuffix(received_namespace, type), kSuccess); + EXPECT_EQ(received_namespace, TrackNamespace({"bar"})); + EXPECT_EQ(type, TransactionType::kAdd); + EXPECT_EQ(task->GetNextSuffix(received_namespace, type), kEof); +} + class MoqtNamespacePublisherStreamTest : public quiche::test::QuicheTest { public: MoqtNamespacePublisherStreamTest() @@ -217,8 +245,8 @@ webtransport::test::MockStream mock_stream_; SessionNamespaceTree tree_; testing::MockFunction<std::unique_ptr<MoqtNamespaceTask>( - const TrackNamespace&, std::optional<MessageParameters>, - MoqtResponseCallback)> + const TrackNamespace&, const MessageParameters&, MoqtResponseCallback, + ObjectsAvailableCallback)> mock_application_; MoqtIncomingSubscribeNamespaceCallbackNew application_callback_; MoqtNamespacePublisherStream stream_; @@ -234,15 +262,13 @@ ObjectsAvailableCallback callback; MockNamespaceTask* task_ptr; EXPECT_CALL(mock_application_, Call) - .WillOnce([&](const TrackNamespace&, std::optional<MessageParameters>, - MoqtResponseCallback response_callback) { + .WillOnce([&](const TrackNamespace&, const MessageParameters&, + MoqtResponseCallback response_callback, + ObjectsAvailableCallback available_callback) { std::move(response_callback)(std::nullopt); auto task = std::make_unique<MockNamespaceTask>(message.track_namespace_prefix); - EXPECT_CALL(*task, SetObjectAvailableCallback) - .WillOnce([&](ObjectsAvailableCallback oa_callback) { - callback = std::move(oa_callback); - }); + callback = std::move(available_callback); task_ptr = task.get(); return task; }); @@ -303,14 +329,13 @@ MessageParameters(), }; EXPECT_CALL(mock_application_, Call) - .WillOnce([&](const TrackNamespace&, std::optional<MessageParameters>, - MoqtResponseCallback response_callback) { + .WillOnce([&](const TrackNamespace&, const MessageParameters&, + MoqtResponseCallback response_callback, + ObjectsAvailableCallback) { std::move(response_callback)(MoqtRequestErrorInfo{ RequestErrorCode::kInternalError, quic::QuicTimeDelta::FromMilliseconds(100), "bar"}); - auto task = - std::make_unique<MockNamespaceTask>(message.track_namespace_prefix); - return task; + return nullptr; }); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _));
diff --git a/quiche/quic/moqt/moqt_session_callbacks.h b/quiche/quic/moqt/moqt_session_callbacks.h index cc477f8..03c6f54 100644 --- a/quiche/quic/moqt/moqt_session_callbacks.h +++ b/quiche/quic/moqt/moqt_session_callbacks.h
@@ -60,9 +60,9 @@ MoqtResponseCallback callback)>; using MoqtIncomingSubscribeNamespaceCallbackNew = quiche::MultiUseCallback<std::unique_ptr<MoqtNamespaceTask>( - const TrackNamespace& prefix, - std::optional<MessageParameters> parameters, - MoqtResponseCallback callback)>; + const TrackNamespace& prefix, const MessageParameters& parameters, + MoqtResponseCallback response_callback, + ObjectsAvailableCallback objects_available_callback)>; inline void DefaultIncomingPublishNamespaceCallback( const TrackNamespace&, const std::optional<VersionSpecificParameters>&, @@ -83,11 +83,11 @@ } inline std::unique_ptr<MoqtNamespaceTask> DefaultIncomingSubscribeNamespaceCallbackNew( - const TrackNamespace& track_namespace, std::optional<MessageParameters>, - MoqtResponseCallback callback) { - std::move(callback)(MoqtRequestErrorInfo{RequestErrorCode::kNotSupported, - std::nullopt, - "This endpoint cannot publish."}); + const TrackNamespace&, const MessageParameters&, + MoqtResponseCallback response_callback, ObjectsAvailableCallback) { + std::move(response_callback)( + MoqtRequestErrorInfo{RequestErrorCode::kNotSupported, std::nullopt, + "This endpoint cannot publish."}); return nullptr; }