Incoming MoQT SubscribeAnnounces lifecycle.

PiperOrigin-RevId: 705588295
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h
index 0f8c26f..4b189f9 100644
--- a/quiche/quic/moqt/moqt_messages.h
+++ b/quiche/quic/moqt/moqt_messages.h
@@ -158,6 +158,20 @@
   kAnnounceNotSupported = 1,
 };
 
+enum class QUICHE_EXPORT SubscribeErrorCode : uint64_t {
+  kInternalError = 0x0,
+  kInvalidRange = 0x1,
+  kRetryTrackAlias = 0x2,
+  kTrackDoesNotExist = 0x3,
+  kUnauthorized = 0x4,
+  kTimeout = 0x5,
+};
+
+struct MoqtSubscribeErrorReason {
+  SubscribeErrorCode error_code;
+  std::string reason_phrase;
+};
+
 struct MoqtAnnounceErrorReason {
   MoqtAnnounceErrorCode error_code;
   std::string reason_phrase;
@@ -397,15 +411,6 @@
   MoqtSubscribeParameters parameters;
 };
 
-enum class QUICHE_EXPORT SubscribeErrorCode : uint64_t {
-  kInternalError = 0x0,
-  kInvalidRange = 0x1,
-  kRetryTrackAlias = 0x2,
-  kTrackDoesNotExist = 0x3,
-  kUnauthorized = 0x4,
-  kTimeout = 0x5,
-};
-
 struct QUICHE_EXPORT MoqtSubscribeError {
   uint64_t subscribe_id;
   SubscribeErrorCode error_code;
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index 538f081..9f01c32 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -239,9 +239,10 @@
   std::move(callbacks_.session_terminated_callback)(error);
 }
 
-bool MoqtSession::SubscribeAnnounces(FullTrackName track_namespace,
-                                     MoqtSubscribeAnnouncesCallback callback,
-                                     MoqtSubscribeParameters parameters) {
+bool MoqtSession::SubscribeAnnounces(
+    FullTrackName track_namespace,
+    MoqtOutgoingSubscribeAnnouncesCallback callback,
+    MoqtSubscribeParameters parameters) {
   if (peer_role_ == MoqtRole::kSubscriber) {
     std::move(callback)(track_namespace, SubscribeErrorCode::kInternalError,
                         "SUBSCRIBE_ANNOUNCES cannot be sent to subscriber");
@@ -1035,6 +1036,26 @@
   session_->outgoing_announces_.erase(it);
 }
 
+void MoqtSession::ControlStream::OnSubscribeAnnouncesMessage(
+    const MoqtSubscribeAnnounces& message) {
+  // TODO(martinduke): Handle authentication.
+  std::optional<MoqtSubscribeErrorReason> result =
+      session_->callbacks_.incoming_subscribe_announces_callback(
+          message.track_namespace, SubscribeType::kSubscribe);
+  if (result.has_value()) {
+    MoqtSubscribeAnnouncesError error;
+    error.track_namespace = message.track_namespace;
+    error.error_code = result->error_code;
+    error.reason_phrase = result->reason_phrase;
+    SendOrBufferMessage(
+        session_->framer_.SerializeSubscribeAnnouncesError(error));
+    return;
+  }
+  MoqtSubscribeAnnouncesOk ok;
+  ok.track_namespace = message.track_namespace;
+  SendOrBufferMessage(session_->framer_.SerializeSubscribeAnnouncesOk(ok));
+}
+
 void MoqtSession::ControlStream::OnSubscribeAnnouncesOkMessage(
     const MoqtSubscribeAnnouncesOk& message) {
   auto it =
@@ -1068,6 +1089,14 @@
   session_->outgoing_subscribe_announces_.erase(it);
 }
 
+void MoqtSession::ControlStream::OnUnsubscribeAnnouncesMessage(
+    const MoqtUnsubscribeAnnounces& message) {
+  // MoqtSession keeps no state here, so just tell the application.
+  std::optional<MoqtSubscribeErrorReason> result =
+      session_->callbacks_.incoming_subscribe_announces_callback(
+          message.track_namespace, SubscribeType::kUnsubscribe);
+}
+
 void MoqtSession::ControlStream::OnMaxSubscribeIdMessage(
     const MoqtMaxSubscribeId& message) {
   if (session_->peer_role_ == MoqtRole::kSubscriber) {
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
index 4bcba1c..98eb244 100644
--- a/quiche/quic/moqt/moqt_session.h
+++ b/quiche/quic/moqt/moqt_session.h
@@ -43,6 +43,8 @@
     quiche::SingleUseCallback<void(absl::string_view error_message)>;
 using MoqtSessionDeletedCallback = quiche::SingleUseCallback<void()>;
 
+enum class SubscribeType { kSubscribe, kUnsubscribe };
+
 // If |error_message| is nullopt, this is triggered by an ANNOUNCE_OK.
 // Otherwise, it is triggered by ANNOUNCE_ERROR or ANNOUNCE_CANCEL. For
 // ERROR or CANCEL, MoqtSession is deleting all ANNOUNCE state immediately
@@ -54,9 +56,17 @@
 using MoqtIncomingAnnounceCallback =
     quiche::MultiUseCallback<std::optional<MoqtAnnounceErrorReason>(
         FullTrackName track_namespace)>;
-using MoqtSubscribeAnnouncesCallback = quiche::SingleUseCallback<void(
+using MoqtOutgoingSubscribeAnnouncesCallback = quiche::SingleUseCallback<void(
     FullTrackName track_namespace, std::optional<SubscribeErrorCode> error,
     absl::string_view reason)>;
+// If the return value is nullopt, the Session will respond with
+// SUBSCRIBE_ANNOUNCES_OK. Otherwise, it will respond with
+// SUBSCRIBE_ANNOUNCES_ERROR.
+// If |subscribe_type| is kUnsubscribe, this is an UNSUBSCRIBE_ANNOUNCES message
+// and the return value will be ignored.
+using MoqtIncomingSubscribeAnnouncesCallback =
+    quiche::MultiUseCallback<std::optional<MoqtSubscribeErrorReason>(
+        const FullTrackName& track_namespace, SubscribeType subscribe_type)>;
 
 inline std::optional<MoqtAnnounceErrorReason> DefaultIncomingAnnounceCallback(
     FullTrackName /*track_namespace*/) {
@@ -65,6 +75,14 @@
       "This endpoint does not accept incoming ANNOUNCE messages"});
 };
 
+inline std::optional<MoqtSubscribeErrorReason>
+DefaultIncomingSubscribeAnnouncesCallback(const FullTrackName& track_namespace,
+                                          SubscribeType /*subscribe_type*/) {
+  return MoqtSubscribeErrorReason{
+      SubscribeErrorCode::kUnauthorized,
+      "This endpoint does not support incoming SUBSCRIBE_ANNOUNCES messages"};
+}
+
 // Callbacks for session-level events.
 struct MoqtSessionCallbacks {
   MoqtSessionEstablishedCallback session_established_callback = +[] {};
@@ -74,6 +92,8 @@
 
   MoqtIncomingAnnounceCallback incoming_announce_callback =
       DefaultIncomingAnnounceCallback;
+  MoqtIncomingSubscribeAnnouncesCallback incoming_subscribe_announces_callback =
+      DefaultIncomingSubscribeAnnouncesCallback;
 };
 
 struct SubscriptionWithQueuedStream {
@@ -116,7 +136,8 @@
 
   // Returns true if message was sent.
   bool SubscribeAnnounces(
-      FullTrackName track_namespace, MoqtSubscribeAnnouncesCallback callback,
+      FullTrackName track_namespace,
+      MoqtOutgoingSubscribeAnnouncesCallback callback,
       MoqtSubscribeParameters parameters = MoqtSubscribeParameters());
   bool UnsubscribeAnnounces(FullTrackName track_namespace);
 
@@ -228,13 +249,13 @@
     void OnTrackStatusMessage(const MoqtTrackStatus& message) override {}
     void OnGoAwayMessage(const MoqtGoAway& /*message*/) override {}
     void OnSubscribeAnnouncesMessage(
-        const MoqtSubscribeAnnounces& message) override {}
+        const MoqtSubscribeAnnounces& message) override;
     void OnSubscribeAnnouncesOkMessage(
         const MoqtSubscribeAnnouncesOk& message) override;
     void OnSubscribeAnnouncesErrorMessage(
         const MoqtSubscribeAnnouncesError& message) override;
     void OnUnsubscribeAnnouncesMessage(
-        const MoqtUnsubscribeAnnounces& message) override {}
+        const MoqtUnsubscribeAnnounces& message) override;
     void OnMaxSubscribeIdMessage(const MoqtMaxSubscribeId& message) override;
     void OnFetchMessage(const MoqtFetch& message) override;
     void OnFetchCancelMessage(const MoqtFetchCancel& message) override {}
@@ -631,7 +652,7 @@
   // when sending UNSUBSCRIBE_ANNOUNCES, to make sure the application doesn't
   // unsubscribe from something that it isn't subscribed to. ANNOUNCEs that
   // result from this subscription use incoming_announce_callback.
-  absl::flat_hash_map<FullTrackName, MoqtSubscribeAnnouncesCallback>
+  absl::flat_hash_map<FullTrackName, MoqtOutgoingSubscribeAnnouncesCallback>
       outgoing_subscribe_announces_;
 
   // The role the peer advertised in its SETUP message. Initialize it to avoid
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index 2d91180..44021ca 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -2103,6 +2103,51 @@
   EXPECT_EQ(objects_received, 2);
 }
 
+TEST_F(MoqtSessionTest, IncomingSubscribeAnnounces) {
+  FullTrackName track_namespace = FullTrackName{"foo"};
+  MoqtSubscribeAnnounces announces = {
+      track_namespace,
+      /*parameters=*/MoqtSubscribeParameters(),
+  };
+  webtransport::test::MockStream control_stream;
+  std::unique_ptr<MoqtControlParserVisitor> stream_input =
+      MoqtSessionPeer::CreateControlStream(&session_, &control_stream);
+  EXPECT_CALL(session_callbacks_.incoming_subscribe_announces_callback,
+              Call(_, SubscribeType::kSubscribe))
+      .WillOnce(Return(std::nullopt));
+  EXPECT_CALL(
+      control_stream,
+      Writev(ControlMessageOfType(MoqtMessageType::kSubscribeAnnouncesOk), _));
+  stream_input->OnSubscribeAnnouncesMessage(announces);
+  MoqtUnsubscribeAnnounces unsubscribe_announces = {
+      /*track_namespace=*/FullTrackName{"foo"},
+  };
+  EXPECT_CALL(session_callbacks_.incoming_subscribe_announces_callback,
+              Call(track_namespace, SubscribeType::kUnsubscribe))
+      .WillOnce(Return(std::nullopt));
+  stream_input->OnUnsubscribeAnnouncesMessage(unsubscribe_announces);
+}
+
+TEST_F(MoqtSessionTest, IncomingSubscribeAnnouncesWithError) {
+  FullTrackName track_namespace = FullTrackName{"foo"};
+  MoqtSubscribeAnnounces announces = {
+      track_namespace,
+      /*parameters=*/MoqtSubscribeParameters(),
+  };
+  webtransport::test::MockStream control_stream;
+  std::unique_ptr<MoqtControlParserVisitor> stream_input =
+      MoqtSessionPeer::CreateControlStream(&session_, &control_stream);
+  EXPECT_CALL(session_callbacks_.incoming_subscribe_announces_callback,
+              Call(_, SubscribeType::kSubscribe))
+      .WillOnce(Return(
+          MoqtSubscribeErrorReason{SubscribeErrorCode::kUnauthorized, "foo"}));
+  EXPECT_CALL(
+      control_stream,
+      Writev(ControlMessageOfType(MoqtMessageType::kSubscribeAnnouncesError),
+             _));
+  stream_input->OnSubscribeAnnouncesMessage(announces);
+}
+
 // TODO: re-enable this test once this behavior is re-implemented.
 #if 0
 TEST_F(MoqtSessionTest, SubscribeUpdateClosesSubscription) {
diff --git a/quiche/quic/moqt/tools/moqt_mock_visitor.h b/quiche/quic/moqt/tools/moqt_mock_visitor.h
index 7031696..8eb22a7 100644
--- a/quiche/quic/moqt/tools/moqt_mock_visitor.h
+++ b/quiche/quic/moqt/tools/moqt_mock_visitor.h
@@ -30,17 +30,24 @@
   testing::MockFunction<void()> session_deleted_callback;
   testing::MockFunction<std::optional<MoqtAnnounceErrorReason>(FullTrackName)>
       incoming_announce_callback;
+  testing::MockFunction<std::optional<MoqtSubscribeErrorReason>(FullTrackName,
+                                                                SubscribeType)>
+      incoming_subscribe_announces_callback;
 
   MockSessionCallbacks() {
     ON_CALL(incoming_announce_callback, Call(testing::_))
         .WillByDefault(DefaultIncomingAnnounceCallback);
+    ON_CALL(incoming_subscribe_announces_callback, Call(testing::_, testing::_))
+        .WillByDefault(DefaultIncomingSubscribeAnnouncesCallback);
   }
 
   MoqtSessionCallbacks AsSessionCallbacks() {
-    return MoqtSessionCallbacks{session_established_callback.AsStdFunction(),
-                                session_terminated_callback.AsStdFunction(),
-                                session_deleted_callback.AsStdFunction(),
-                                incoming_announce_callback.AsStdFunction()};
+    return MoqtSessionCallbacks{
+        session_established_callback.AsStdFunction(),
+        session_terminated_callback.AsStdFunction(),
+        session_deleted_callback.AsStdFunction(),
+        incoming_announce_callback.AsStdFunction(),
+        incoming_subscribe_announces_callback.AsStdFunction()};
   }
 };