Support REQUEST_UPDATE for SUBSCRIBE_NAMESPACE

doesn't actually do anything because only Auth matters.

PiperOrigin-RevId: 866694952
diff --git a/quiche/quic/moqt/moqt_fetch_task.h b/quiche/quic/moqt/moqt_fetch_task.h
index f18d21f..0b97b82 100644
--- a/quiche/quic/moqt/moqt_fetch_task.h
+++ b/quiche/quic/moqt/moqt_fetch_task.h
@@ -14,6 +14,7 @@
 #include "absl/base/nullability.h"
 #include "absl/status/status.h"
 #include "quiche/quic/moqt/moqt_error.h"
+#include "quiche/quic/moqt/moqt_key_value_pair.h"
 #include "quiche/quic/moqt/moqt_messages.h"
 #include "quiche/quic/moqt/moqt_names.h"
 #include "quiche/quic/moqt/moqt_object.h"
@@ -22,6 +23,12 @@
 
 namespace moqt {
 
+// The callback we'll use for all request types going forward. Can only be used
+// once; if the argument is nullopt, an OK response was received. Otherwise, an
+// ERROR response was received.
+using MoqtResponseCallback =
+    quiche::SingleUseCallback<void(std::optional<MoqtRequestErrorInfo>)>;
+
 // TODO(martinduke): There are will be multiple instances of flow-controlled
 // "pull" data retrieval tasks. It might be worthwhile to extract some common
 // features into a base class.
@@ -106,6 +113,10 @@
   // Returns the error if request has completely failed, and nullopt otherwise.
   virtual std::optional<webtransport::StreamErrorCode> GetStatus() = 0;
 
+  // Handle a REQUEST_UPDATE message.
+  virtual void Update(const MessageParameters& parameters,
+                      MoqtResponseCallback response_callback) = 0;
+
   // Returns the prefix for this task.
   virtual const TrackNamespace& prefix() = 0;
 };
diff --git a/quiche/quic/moqt/moqt_namespace_stream.cc b/quiche/quic/moqt/moqt_namespace_stream.cc
index ca59fd6..206d04c 100644
--- a/quiche/quic/moqt/moqt_namespace_stream.cc
+++ b/quiche/quic/moqt/moqt_namespace_stream.cc
@@ -4,6 +4,7 @@
 
 #include "quiche/quic/moqt/moqt_namespace_stream.h"
 
+#include <cstdint>
 #include <memory>
 #include <optional>
 #include <utility>
@@ -40,33 +41,57 @@
 
 void MoqtNamespaceSubscriberStream::OnRequestOkMessage(
     const MoqtRequestOk& message) {
-  if (message.request_id != request_id_) {
+  if (message.request_id == request_id_) {
+    // Response to the initial SUBSCRIBE_NAMESPACE.
+    if (response_callback_ == nullptr) {
+      OnParsingError(MoqtError::kProtocolViolation, "Two responses");
+      return;
+    }
+    std::move(response_callback_)(std::nullopt);
+    response_callback_ = nullptr;
+    return;
+  }
+  NamespaceTask* task = task_.GetIfAvailable();
+  if (task == nullptr) {
+    // The application has already unsubscribed, and the stream has been reset.
+    // This is irrelevant.
+    return;
+  }
+  MoqtResponseCallback callback = task->GetResponseCallback(message.request_id);
+  if (callback == nullptr) {
     OnParsingError(MoqtError::kProtocolViolation,
                    "Unexpected request ID in response");
     return;
   }
-  if (response_callback_ == nullptr) {
-    OnParsingError(MoqtError::kProtocolViolation, "Two responses");
-    return;
-  }
-  std::move(response_callback_)(std::nullopt);
-  response_callback_ = nullptr;
+  std::move(callback)(std::nullopt);
 }
 
 void MoqtNamespaceSubscriberStream::OnRequestErrorMessage(
     const MoqtRequestError& message) {
-  if (message.request_id != request_id_) {
+  if (message.request_id == request_id_) {
+    if (response_callback_ == nullptr) {
+      OnParsingError(MoqtError::kProtocolViolation, "Two responses");
+      return;
+    }
+    std::move(response_callback_)(MoqtRequestErrorInfo{
+        message.error_code, message.retry_interval, message.reason_phrase});
+    response_callback_ = nullptr;
+    return;
+  }
+  NamespaceTask* task = task_.GetIfAvailable();
+  if (task == nullptr) {
+    // The application has already unsubscribed, and the stream has been reset.
+    // This is irrelevant.
+    return;
+  }
+  MoqtResponseCallback callback = task->GetResponseCallback(message.request_id);
+  if (callback == nullptr) {
     OnParsingError(MoqtError::kProtocolViolation,
-                   "Unexpected request ID in error");
+                   "Unexpected request ID in response");
     return;
   }
-  if (response_callback_ == nullptr) {
-    OnParsingError(MoqtError::kProtocolViolation, "Two responses");
-    return;
-  }
-  std::move(response_callback_)(MoqtRequestErrorInfo{
+  std::move(callback)(MoqtRequestErrorInfo{
       message.error_code, message.retry_interval, message.reason_phrase});
-  response_callback_ = nullptr;
 }
 
 void MoqtNamespaceSubscriberStream::OnNamespaceMessage(
@@ -152,6 +177,21 @@
   }
 }
 
+void MoqtNamespaceSubscriberStream::NamespaceTask::Update(
+    const MessageParameters& parameters,
+    MoqtResponseCallback response_callback) {
+  if (state_ == nullptr) {
+    std::move(response_callback)(
+        MoqtRequestErrorInfo{RequestErrorCode::kInternalError, std::nullopt,
+                             "Stream has been reset"});
+    return;
+  }
+  MoqtRequestUpdate message{next_request_id_, state_->request_id_, parameters};
+  pending_updates_[message.request_id] = std::move(response_callback);
+  state_->SendOrBufferMessage(state_->framer_->SerializeRequestUpdate(message));
+  next_request_id_ += 2;
+}
+
 GetNextResult MoqtNamespaceSubscriberStream::NamespaceTask::GetNextSuffix(
     TrackNamespace& suffix, TransactionType& type) {
   if (pending_suffixes_.empty()) {
@@ -195,6 +235,18 @@
   }
 }
 
+MoqtResponseCallback
+MoqtNamespaceSubscriberStream::NamespaceTask::GetResponseCallback(
+    uint64_t request_id) {
+  auto it = pending_updates_.find(request_id);
+  if (it == pending_updates_.end()) {
+    return nullptr;
+  }
+  MoqtResponseCallback callback = std::move(it->second);
+  pending_updates_.erase(it);
+  return callback;
+}
+
 MoqtNamespacePublisherStream::MoqtNamespacePublisherStream(
     MoqtFramer* framer, webtransport::Stream* stream,
     SessionErrorCallback session_error_callback,
@@ -250,6 +302,23 @@
   }
 }
 
+void MoqtNamespacePublisherStream::OnRequestUpdateMessage(
+    const MoqtRequestUpdate& message) {
+  if (task_ == nullptr) {
+    // This stream is dying.
+    return;
+  }
+  task_->Update(message.parameters,
+                [this, request_id = message.request_id](
+                    std::optional<MoqtRequestErrorInfo> error) {
+                  if (error.has_value()) {
+                    SendRequestError(request_id, *error, /*fin=*/false);
+                  } else {
+                    SendRequestOk(request_id, MessageParameters());
+                  }
+                });
+}
+
 void MoqtNamespacePublisherStream::ProcessNamespaces() {
   if (task_ == nullptr) {
     return;
diff --git a/quiche/quic/moqt/moqt_namespace_stream.h b/quiche/quic/moqt/moqt_namespace_stream.h
index 3034cab..ddc63e2 100644
--- a/quiche/quic/moqt/moqt_namespace_stream.h
+++ b/quiche/quic/moqt/moqt_namespace_stream.h
@@ -12,10 +12,12 @@
 #include <utility>
 
 #include "absl/base/nullability.h"
+#include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "quiche/quic/moqt/moqt_bidi_stream.h"
 #include "quiche/quic/moqt/moqt_fetch_task.h"
 #include "quiche/quic/moqt/moqt_framer.h"
+#include "quiche/quic/moqt/moqt_key_value_pair.h"
 #include "quiche/quic/moqt/moqt_messages.h"
 #include "quiche/quic/moqt/moqt_names.h"
 #include "quiche/quic/moqt/moqt_session_callbacks.h"
@@ -61,6 +63,7 @@
         : MoqtNamespaceTask(),
           prefix_(prefix),
           state_(state),
+          next_request_id_(state->request_id_ + 2),
           weak_ptr_factory_(this) {}
     ~NamespaceTask() override;
 
@@ -75,6 +78,8 @@
       return error_;
     }
     const TrackNamespace& prefix() override { return prefix_; }
+    void Update(const MessageParameters& parameters,
+                MoqtResponseCallback response_callback) override;
 
     // Queues a suffix corresponding to a NAMESPACE (if |type| is kAdd) or a
     // NAMESPACE_DONE (if |type| is kDelete).
@@ -82,6 +87,7 @@
     // The stream is closed, so no more NAMESPACE messages are forthcoming.
     // This is an implicit NAMESPACE_DONE for all published namespaces.
     void DeclareEof();
+    MoqtResponseCallback GetResponseCallback(uint64_t request_id);
     quiche::QuicheWeakPtr<NamespaceTask> GetWeakPtr() {
       return weak_ptr_factory_.Create();
     }
@@ -100,6 +106,8 @@
     ObjectsAvailableCallback absl_nullable callback_ = nullptr;
     std::optional<webtransport::StreamErrorCode> error_;
     bool eof_ = false;
+    uint64_t next_request_id_;
+    absl::flat_hash_map<uint64_t, MoqtResponseCallback> pending_updates_;
     // Must be last.
     quiche::QuicheWeakPtrFactory<NamespaceTask> weak_ptr_factory_;
   };
@@ -123,10 +131,7 @@
   // MoqtBidiStreamBase overrides.
   void OnSubscribeNamespaceMessage(
       const MoqtSubscribeNamespace& message) override;
-  // TODO(martinduke): Implement this.
-  void OnRequestUpdateMessage(const MoqtRequestUpdate&) override {
-    QUICHE_DLOG(INFO) << "Got REQUEST_UPDATE on Namespace stream";
-  }
+  void OnRequestUpdateMessage(const MoqtRequestUpdate&) override;
 
  private:
   void ProcessNamespaces();
diff --git a/quiche/quic/moqt/moqt_namespace_stream_test.cc b/quiche/quic/moqt/moqt_namespace_stream_test.cc
index 97a0cb1..e538c27 100644
--- a/quiche/quic/moqt/moqt_namespace_stream_test.cc
+++ b/quiche/quic/moqt/moqt_namespace_stream_test.cc
@@ -48,6 +48,7 @@
         task_(stream_.CreateTask(kPrefix)) {
     task_->SetObjectsAvailableCallback([this]() { ++objects_available_; });
     stream_.set_stream(&mock_stream_);
+    ON_CALL(mock_stream_, CanWrite()).WillByDefault(Return(true));
   }
 
   void CheckNumberOfObjectsAvailable(int expected_count) {
@@ -85,7 +86,7 @@
 
 TEST_F(MoqtNamespaceSubscriberStreamTest, RequestErrorWrongId) {
   EXPECT_CALL(error_callback_, Call(MoqtError::kProtocolViolation,
-                                    "Unexpected request ID in error"));
+                                    "Unexpected request ID in response"));
   stream_.OnRequestErrorMessage(
       {kRequestId + 1, RequestErrorCode::kInternalError,
        quic::QuicTimeDelta::FromMilliseconds(100), "bar"});
@@ -217,6 +218,36 @@
   EXPECT_EQ(task->GetNextSuffix(received_namespace, type), kEof);
 }
 
+TEST_F(MoqtNamespaceSubscriberStreamTest, UpdateAndRequestOk) {
+  EXPECT_CALL(response_callback_, Call(Eq(std::nullopt)));
+  stream_.OnRequestOkMessage({kRequestId});
+  EXPECT_CALL(mock_stream_,
+              Writev(ControlMessageOfType(MoqtMessageType::kRequestUpdate), _));
+  MessageParameters update_params;
+  update_params.subscriber_priority = 10;
+  testing::MockFunction<void(std::optional<MoqtRequestErrorInfo>)>
+      update_response_callback;
+  task_->Update(update_params, update_response_callback.AsStdFunction());
+  EXPECT_CALL(update_response_callback, Call(Eq(std::nullopt)));
+  stream_.OnRequestOkMessage({kRequestId + 2});
+}
+
+TEST_F(MoqtNamespaceSubscriberStreamTest, UpdateAndRequestError) {
+  EXPECT_CALL(response_callback_, Call(Eq(std::nullopt)));
+  stream_.OnRequestOkMessage({kRequestId});
+  EXPECT_CALL(mock_stream_,
+              Writev(ControlMessageOfType(MoqtMessageType::kRequestUpdate), _));
+  MessageParameters update_params;
+  update_params.subscriber_priority = 10;
+  testing::MockFunction<void(std::optional<MoqtRequestErrorInfo>)>
+      update_response_callback;
+  task_->Update(update_params, update_response_callback.AsStdFunction());
+  EXPECT_CALL(update_response_callback, Call(_));
+  stream_.OnRequestErrorMessage(
+      {kRequestId + 2, RequestErrorCode::kInternalError,
+       quic::QuicTimeDelta::FromMilliseconds(100), "bar"});
+}
+
 class MoqtNamespacePublisherStreamTest : public quiche::test::QuicheTest {
  public:
   MoqtNamespacePublisherStreamTest()
@@ -330,6 +361,88 @@
   stream_.OnSubscribeNamespaceMessage(message);
 }
 
+TEST_F(MoqtNamespacePublisherStreamTest, RequestUpdateOk) {
+  MoqtSubscribeNamespace message = {
+      kRequestId,
+      TrackNamespace({"foo"}),
+      SubscribeNamespaceOption::kNamespace,
+      MessageParameters(),
+  };
+  MockNamespaceTask* task_ptr = nullptr;
+  EXPECT_CALL(mock_application_, Call)
+      .WillOnce([&](const TrackNamespace&, SubscribeNamespaceOption,
+                    const MessageParameters&,
+                    MoqtResponseCallback response_callback) {
+        std::move(response_callback)(std::nullopt);
+        auto task =
+            std::make_unique<MockNamespaceTask>(message.track_namespace_prefix);
+        task_ptr = task.get();
+        return task;
+      });
+  EXPECT_CALL(mock_stream_,
+              Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _));
+  stream_.OnSubscribeNamespaceMessage(message);
+  ASSERT_TRUE(task_ptr != nullptr);
+
+  // Now send RequestUpdate
+  MoqtRequestUpdate update_message = {
+      kRequestId + 2,
+      kRequestId,
+      MessageParameters(),
+  };
+  update_message.parameters.subscriber_priority = 10;
+  EXPECT_CALL(*task_ptr, Update(_, _))
+      .WillOnce([&](const MessageParameters& params, MoqtResponseCallback cb) {
+        EXPECT_EQ(params.subscriber_priority, 10);
+        std::move(cb)(std::nullopt);
+      });
+  EXPECT_CALL(mock_stream_,
+              Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _));
+  stream_.OnRequestUpdateMessage(update_message);
+}
+
+TEST_F(MoqtNamespacePublisherStreamTest, RequestUpdateError) {
+  MoqtSubscribeNamespace message = {
+      kRequestId,
+      TrackNamespace({"foo"}),
+      SubscribeNamespaceOption::kNamespace,
+      MessageParameters(),
+  };
+  MockNamespaceTask* task_ptr = nullptr;
+  EXPECT_CALL(mock_application_, Call)
+      .WillOnce([&](const TrackNamespace&, SubscribeNamespaceOption,
+                    const MessageParameters&,
+                    MoqtResponseCallback response_callback) {
+        std::move(response_callback)(std::nullopt);
+        auto task =
+            std::make_unique<MockNamespaceTask>(message.track_namespace_prefix);
+        task_ptr = task.get();
+        return task;
+      });
+  EXPECT_CALL(mock_stream_,
+              Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _));
+  stream_.OnSubscribeNamespaceMessage(message);
+  ASSERT_TRUE(task_ptr != nullptr);
+
+  // Now send RequestUpdate
+  MoqtRequestUpdate update_message = {
+      kRequestId + 2,
+      kRequestId,
+      MessageParameters(),
+  };
+  update_message.parameters.subscriber_priority = 10;
+  EXPECT_CALL(*task_ptr, Update(_, _))
+      .WillOnce([&](const MessageParameters& params, MoqtResponseCallback cb) {
+        EXPECT_EQ(params.subscriber_priority, 10);
+        std::move(cb)(MoqtRequestErrorInfo{
+            RequestErrorCode::kInternalError,
+            quic::QuicTimeDelta::FromMilliseconds(100), "bar"});
+      });
+  EXPECT_CALL(mock_stream_,
+              Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _));
+  stream_.OnRequestUpdateMessage(update_message);
+}
+
 TEST_F(MoqtNamespacePublisherStreamTest, SubscribePrefixOverlap) {
   MoqtSubscribeNamespace message = {
       kRequestId,
diff --git a/quiche/quic/moqt/moqt_session_callbacks.h b/quiche/quic/moqt/moqt_session_callbacks.h
index d43befb..2d070b7 100644
--- a/quiche/quic/moqt/moqt_session_callbacks.h
+++ b/quiche/quic/moqt/moqt_session_callbacks.h
@@ -21,12 +21,6 @@
 
 namespace moqt {
 
-// The callback we'll use for all request types going forward. Can only be used
-// once; if the argument is nullopt, an OK response was received. Otherwise, an
-// ERROR response was received.
-using MoqtResponseCallback =
-    quiche::SingleUseCallback<void(std::optional<MoqtRequestErrorInfo>)>;
-
 // Called when the SETUP message from the peer is received.
 using MoqtSessionEstablishedCallback = quiche::SingleUseCallback<void()>;
 
diff --git a/quiche/quic/moqt/relay_namespace_tree.cc b/quiche/quic/moqt/relay_namespace_tree.cc
index 8c989ed..2ac4c07 100644
--- a/quiche/quic/moqt/relay_namespace_tree.cc
+++ b/quiche/quic/moqt/relay_namespace_tree.cc
@@ -16,6 +16,7 @@
 #include "absl/strings/string_view.h"
 #include "quiche/quic/moqt/moqt_error.h"
 #include "quiche/quic/moqt/moqt_fetch_task.h"
+#include "quiche/quic/moqt/moqt_key_value_pair.h"
 #include "quiche/quic/moqt/moqt_names.h"
 #include "quiche/quic/moqt/moqt_session_interface.h"
 #include "quiche/common/platform/api/quiche_bug_tracker.h"
@@ -44,6 +45,12 @@
   }
 }
 
+void RelayNamespaceTree::RelayNamespaceListener::Update(
+    const MessageParameters&, MoqtResponseCallback response_callback) {
+  // Don't do anything!
+  std::move(response_callback)(std::nullopt);
+}
+
 GetNextResult RelayNamespaceTree::RelayNamespaceListener::GetNextSuffix(
     TrackNamespace& suffix, TransactionType& type) {
   if (eof_) {
diff --git a/quiche/quic/moqt/relay_namespace_tree.h b/quiche/quic/moqt/relay_namespace_tree.h
index 8add33a..2b58715 100644
--- a/quiche/quic/moqt/relay_namespace_tree.h
+++ b/quiche/quic/moqt/relay_namespace_tree.h
@@ -10,13 +10,13 @@
 #include <memory>
 #include <optional>
 #include <string>
-#include <utility>
 
 #include "absl/base/nullability.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/strings/string_view.h"
 #include "quiche/quic/moqt/moqt_fetch_task.h"
+#include "quiche/quic/moqt/moqt_key_value_pair.h"
 #include "quiche/quic/moqt/moqt_names.h"
 #include "quiche/quic/moqt/moqt_session_interface.h"
 #include "quiche/common/quiche_circular_deque.h"
@@ -55,6 +55,8 @@
       return error_;
     }
     const TrackNamespace& prefix() override { return prefix_; }
+    void Update(const MessageParameters& parameters,
+                MoqtResponseCallback response_callback) override;
 
     // Queues a suffix corresponding to a NAMESPACE (if |type| is kAdd) or a
     // NAMESPACE_DONE (if |type| is kDelete).
diff --git a/quiche/quic/moqt/test_tools/moqt_mock_visitor.h b/quiche/quic/moqt/test_tools/moqt_mock_visitor.h
index b568a78..74dcf57 100644
--- a/quiche/quic/moqt/test_tools/moqt_mock_visitor.h
+++ b/quiche/quic/moqt/test_tools/moqt_mock_visitor.h
@@ -11,6 +11,7 @@
 #include <utility>
 #include <variant>
 
+#include "absl/base/nullability.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/status/status.h"
@@ -282,6 +283,10 @@
   MOCK_METHOD(std::optional<webtransport::StreamErrorCode>, GetStatus, (),
               (override));
   const TrackNamespace& prefix() override { return prefix_; }
+  MOCK_METHOD(void, Update,
+              (const MessageParameters& parameters,
+               MoqtResponseCallback response_callback),
+              (override));
 
   void InvokeCallback() {
     if (callback_ != nullptr) {