Connect MoqtIncomingSubscribeNamespaceCallback and the Relay namespace manager.

This will enable SUBSCRIBE_NAMESPACE support in the relay.

PiperOrigin-RevId: 820387075
diff --git a/quiche/quic/moqt/moqt_relay_publisher.h b/quiche/quic/moqt/moqt_relay_publisher.h
index 846eb6e..66cb094 100644
--- a/quiche/quic/moqt/moqt_relay_publisher.h
+++ b/quiche/quic/moqt/moqt_relay_publisher.h
@@ -33,6 +33,15 @@
   absl_nullable std::shared_ptr<MoqtTrackPublisher> GetTrack(
       const FullTrackName& track_name) override;
 
+  void AddNamespaceSubscriber(const TrackNamespace& track_namespace,
+                              MoqtSessionInterface* session) {
+    namespace_publishers_.AddSubscriber(track_namespace, session);
+  }
+  void RemoveNamespaceSubscriber(const TrackNamespace& track_namespace,
+                                 MoqtSessionInterface* session) {
+    namespace_publishers_.RemoveSubscriber(track_namespace, session);
+  }
+
   // There is a new default upstream session. When there is no other namespace
   // information, requests will route here.
   void SetDefaultUpstreamSession(
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index 26026bd..cd378fc 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -1216,6 +1216,8 @@
         session_->framer_.SerializePublishNamespaceError(error));
     return;
   }
+  QUIC_DLOG(INFO) << ENDPOINT << "Received a PUBLISH_NAMESPACE for "
+                  << message.track_namespace;
   quiche::QuicheWeakPtr<MoqtSessionInterface> session_weakptr =
       session_->GetWeakPtr();
   session_->callbacks_.incoming_publish_namespace_callback(
diff --git a/quiche/quic/moqt/tools/moqt_relay.cc b/quiche/quic/moqt/tools/moqt_relay.cc
index 96e345d..43f365b 100644
--- a/quiche/quic/moqt/tools/moqt_relay.cc
+++ b/quiche/quic/moqt/tools/moqt_relay.cc
@@ -95,7 +95,7 @@
     MoqtSession* session = default_upstream_client_->session();
     session->set_publisher(&publisher_);
     publisher_.SetDefaultUpstreamSession(session);
-    SetPublishNamespaceCallback(session);
+    SetNamespaceCallbacks(session);
   };
   callbacks.goaway_received_callback = [](absl::string_view new_session_uri) {
     QUICHE_LOG(INFO) << "GoAway received, new session uri = "
@@ -106,7 +106,7 @@
   return callbacks;
 }
 
-void MoqtRelay::SetPublishNamespaceCallback(MoqtSessionInterface* session) {
+void MoqtRelay::SetNamespaceCallbacks(MoqtSessionInterface* session) {
   session->callbacks().incoming_publish_namespace_callback =
       [this, session](
           const TrackNamespace& track_namespace,
@@ -119,6 +119,18 @@
           return publisher_.OnPublishNamespaceDone(track_namespace, session);
         }
       };
+  session->callbacks().incoming_subscribe_namespace_callback =
+      [this, session](
+          const TrackNamespace& track_namespace,
+          const std::optional<VersionSpecificParameters>& parameters,
+          MoqtResponseCallback callback) {
+        if (parameters.has_value()) {
+          publisher_.AddNamespaceSubscriber(track_namespace, session);
+          std::move(callback)(std::nullopt);
+        } else {
+          publisher_.RemoveNamespaceSubscriber(track_namespace, session);
+        }
+      };
 }
 
 absl::StatusOr<MoqtConfigureSessionCallback> MoqtRelay::IncomingSessionHandler(
@@ -127,7 +139,7 @@
     session->callbacks().session_established_callback = [this, session]() {
       session->set_publisher(&publisher_);
     };
-    SetPublishNamespaceCallback(session);
+    SetNamespaceCallbacks(session);
   };
 }
 
diff --git a/quiche/quic/moqt/tools/moqt_relay.h b/quiche/quic/moqt/tools/moqt_relay.h
index f653b3d..ca6fe45 100644
--- a/quiche/quic/moqt/tools/moqt_relay.h
+++ b/quiche/quic/moqt/tools/moqt_relay.h
@@ -54,7 +54,7 @@
   MoqtClient* client() { return default_upstream_client_.get(); }
   MoqtRelayPublisher* publisher() { return &publisher_; }
 
-  virtual void SetPublishNamespaceCallback(MoqtSessionInterface* session);
+  virtual void SetNamespaceCallbacks(MoqtSessionInterface* session);
 
  private:
   std::unique_ptr<moqt::MoqtClient> CreateClient(
@@ -69,10 +69,11 @@
   const bool ignore_certificate_;
   quic::QuicEventLoop* client_event_loop_;
 
+  MoqtRelayPublisher publisher_;
+
   // Pointer to a client that has received GOAWAY.
   std::unique_ptr<MoqtClient> default_upstream_client_;
   std::unique_ptr<MoqtServer> server_;
-  MoqtRelayPublisher publisher_;
 };
 
 }  // namespace moqt
diff --git a/quiche/quic/moqt/tools/moqt_relay_test.cc b/quiche/quic/moqt/tools/moqt_relay_test.cc
index 45ab00e..720b63c 100644
--- a/quiche/quic/moqt/tools/moqt_relay_test.cc
+++ b/quiche/quic/moqt/tools/moqt_relay_test.cc
@@ -7,9 +7,11 @@
 #include <cstdint>
 #include <memory>
 #include <optional>
+#include <set>
 #include <string>
 #include <utility>
 
+
 #include "absl/strings/string_view.h"
 #include "quiche/quic/core/io/quic_event_loop.h"
 #include "quiche/quic/core/quic_time.h"
@@ -18,6 +20,7 @@
 #include "quiche/quic/moqt/moqt_relay_publisher.h"
 #include "quiche/quic/moqt/moqt_session.h"
 #include "quiche/quic/moqt/moqt_session_interface.h"
+#include "quiche/quic/moqt/test_tools/mock_moqt_session.h"
 #include "quiche/quic/moqt/test_tools/moqt_mock_visitor.h"
 #include "quiche/quic/moqt/tools/moqt_client.h"
 #include "quiche/quic/moqt/tools/moqt_server.h"
@@ -28,6 +31,8 @@
 namespace moqt {
 namespace test {
 
+using testing::ElementsAre;
+
 constexpr quic::QuicTime::Delta kEventLoopDuration =
     quic::QuicTime::Delta::FromMilliseconds(50);
 
@@ -54,10 +59,9 @@
 
   MoqtRelayPublisher* publisher() { return MoqtRelay::publisher(); }
 
-  virtual void SetPublishNamespaceCallback(
-      MoqtSessionInterface* session) override {
+  virtual void SetNamespaceCallbacks(MoqtSessionInterface* session) override {
     last_server_session = session;
-    MoqtRelay::SetPublishNamespaceCallback(session);
+    MoqtRelay::SetNamespaceCallbacks(session);
   }
 
   MoqtSessionInterface* last_server_session;
@@ -129,7 +133,7 @@
   // relay_ publishes a namespace, so upstream_ will route to relay_.
   relay_.client_session()->PublishNamespace(
       TrackNamespace({"foo"}),
-      [](TrackNamespace, std::optional<MoqtPublishNamespaceErrorReason>) {},
+      [](TrackNamespace, std::optional<MoqtRequestError>) {},
       VersionSpecificParameters());
   upstream_.RunOneEvent();
   // There is now an upstream session for "Foo".
@@ -148,6 +152,88 @@
             nullptr);
 }
 
+TEST_F(MoqtRelayTest, SubscribeNamespace) {
+  TrackNamespace foo({"foo"}), foobar({"foo", "bar"}), foobaz({"foo", "baz"});
+  // These will be used to ascertain the namespace state.
+  MockMoqtSession relay_probe, upstream_probe;
+  std::set<TrackNamespace> relay_published_namespaces,
+      upstream_published_namespaces;
+  EXPECT_CALL(relay_probe, PublishNamespace)
+      .WillRepeatedly([&](TrackNamespace track_namespace,
+                          MoqtOutgoingPublishNamespaceCallback callback,
+                          VersionSpecificParameters) {
+        relay_published_namespaces.insert(track_namespace);
+        (callback)(track_namespace, std::nullopt);
+      });
+  EXPECT_CALL(relay_probe, PublishNamespaceDone)
+      .WillRepeatedly([&](TrackNamespace track_namespace) {
+        relay_published_namespaces.erase(track_namespace);
+        return true;
+      });
+  EXPECT_CALL(upstream_probe, PublishNamespace)
+      .WillRepeatedly([&](TrackNamespace track_namespace,
+                          MoqtOutgoingPublishNamespaceCallback callback,
+                          VersionSpecificParameters) {
+        upstream_published_namespaces.insert(track_namespace);
+        (callback)(track_namespace, std::nullopt);
+      });
+  EXPECT_CALL(upstream_probe, PublishNamespaceDone)
+      .WillRepeatedly([&](TrackNamespace track_namespace) {
+        upstream_published_namespaces.erase(track_namespace);
+        return true;
+      });
+  relay_.publisher()->AddNamespaceSubscriber(foo, &relay_probe);
+  upstream_.publisher()->AddNamespaceSubscriber(foo, &upstream_probe);
+  MoqtSession* upstream_session =
+      static_cast<MoqtSession*>(upstream_.last_server_session);
+  // Downstream publishes a namespace. It's stored in relay_ but upstream_
+  // hasn't been notified.
+  downstream_.client_session()->PublishNamespace(
+      foobar, [](TrackNamespace, std::optional<MoqtRequestError>) {},
+      VersionSpecificParameters());
+  relay_.RunOneEvent();
+  upstream_.RunOneEvent();
+  EXPECT_THAT(relay_published_namespaces, ElementsAre(foobar));
+  EXPECT_TRUE(upstream_published_namespaces.empty());
+
+  // Upstream subscribes. Now it's notified and forwards it to the probe.
+  upstream_session->SubscribeNamespace(
+      foo,
+      [](TrackNamespace, std::optional<RequestErrorCode>, absl::string_view) {},
+      VersionSpecificParameters());
+  upstream_.RunOneEvent();
+  upstream_.RunOneEvent();
+  EXPECT_THAT(upstream_published_namespaces, ElementsAre(foobar));
+
+  // Downstream publishes another namespace. Everyone is notified.
+  downstream_.client_session()->PublishNamespace(
+      foobaz, [](TrackNamespace, std::optional<MoqtRequestError>) {},
+      VersionSpecificParameters());
+  relay_.RunOneEvent();
+  upstream_.RunOneEvent();
+  EXPECT_THAT(relay_published_namespaces, ElementsAre(foobar, foobaz));
+  EXPECT_THAT(upstream_published_namespaces, ElementsAre(foobar, foobaz));
+
+  // Unpublish the namespace.
+  downstream_.client_session()->PublishNamespaceDone(foobar);
+  relay_.RunOneEvent();
+  upstream_.RunOneEvent();
+  EXPECT_THAT(relay_published_namespaces, ElementsAre(foobaz));
+  EXPECT_THAT(upstream_published_namespaces, ElementsAre(foobaz));
+
+  // upstream_ unsubscribes. New PUBLISH_NAMESPACE_DONE doesn't arrive.
+  upstream_session->UnsubscribeNamespace(foo);
+  downstream_.client_session()->PublishNamespaceDone(foobaz);
+  upstream_.RunOneEvent();
+  relay_.RunOneEvent();
+  EXPECT_TRUE(relay_published_namespaces.empty());
+  EXPECT_THAT(upstream_published_namespaces, ElementsAre(foobaz));
+
+  // Remove the probes to avoid accessing an invalid WeakPtr on teardown.
+  relay_.publisher()->RemoveNamespaceSubscriber(foo, &relay_probe);
+  upstream_.publisher()->RemoveNamespaceSubscriber(foo, &upstream_probe);
+}
+
 #if 0  // TODO(martinduke): Re-enable these tests when GOAWAY support exists.
 TEST_F(MoqtRelayTest, GoAwayAtClient) {
   ASSERT_NE(relay_.client_session(), nullptr);