Outgoing MoQT SUBSCRIBE_ANNOUNCES life cycle.

Does not include any actual ANNOUNCE messages, which are actually not related to the state of the SUBSCRIBE_ANNOUNCE.

PiperOrigin-RevId: 704446656
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h
index 6239299..c7bb493 100644
--- a/quiche/quic/moqt/moqt_messages.h
+++ b/quiche/quic/moqt/moqt_messages.h
@@ -517,7 +517,7 @@
 
 struct QUICHE_EXPORT MoqtSubscribeAnnouncesError {
   FullTrackName track_namespace;
-  MoqtAnnounceErrorCode error_code;
+  SubscribeErrorCode error_code;
   std::string reason_phrase;
 };
 
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc
index 800f9eb..7de1149 100644
--- a/quiche/quic/moqt/moqt_parser.cc
+++ b/quiche/quic/moqt/moqt_parser.cc
@@ -734,7 +734,7 @@
     return 0;
   }
   subscribe_namespace_error.error_code =
-      static_cast<MoqtAnnounceErrorCode>(error_code);
+      static_cast<SubscribeErrorCode>(error_code);
   visitor_.OnSubscribeAnnouncesErrorMessage(subscribe_namespace_error);
   return reader.PreviouslyReadPayload().length();
 }
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index 6a06e00..81fb1b7 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -239,6 +239,37 @@
   std::move(callbacks_.session_terminated_callback)(error);
 }
 
+bool MoqtSession::SubscribeAnnounces(FullTrackName track_namespace,
+                                     MoqtSubscribeAnnouncesCallback callback,
+                                     MoqtSubscribeParameters parameters) {
+  if (peer_role_ == MoqtRole::kSubscriber) {
+    std::move(callback)(track_namespace, SubscribeErrorCode::kInternalError,
+                        "SUBSCRIBE_ANNOUNCES cannot be sent to subscriber");
+    return false;
+  }
+  MoqtSubscribeAnnounces message;
+  message.track_namespace = track_namespace;
+  message.parameters = std::move(parameters);
+  SendControlMessage(framer_.SerializeSubscribeAnnounces(message));
+  QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE_ANNOUNCES message for "
+                  << message.track_namespace;
+  outgoing_subscribe_announces_[track_namespace] = std::move(callback);
+  return true;
+}
+
+bool MoqtSession::UnsubscribeAnnounces(FullTrackName track_namespace) {
+  if (!outgoing_subscribe_announces_.contains(track_namespace)) {
+    return false;
+  }
+  MoqtUnsubscribeAnnounces message;
+  message.track_namespace = track_namespace;
+  SendControlMessage(framer_.SerializeUnsubscribeAnnounces(message));
+  QUIC_DLOG(INFO) << ENDPOINT << "Sent UNSUBSCRIBE_ANNOUNCES message for "
+                  << message.track_namespace;
+  outgoing_subscribe_announces_.erase(track_namespace);
+  return true;
+}
+
 // TODO: Create state that allows ANNOUNCE_OK/ERROR on spurious namespaces to
 // trigger session errors.
 void MoqtSession::Announce(FullTrackName track_namespace,
@@ -983,6 +1014,39 @@
   // TODO: notify the application about this.
 }
 
+void MoqtSession::ControlStream::OnSubscribeAnnouncesOkMessage(
+    const MoqtSubscribeAnnouncesOk& message) {
+  auto it =
+      session_->outgoing_subscribe_announces_.find(message.track_namespace);
+  if (it == session_->outgoing_subscribe_announces_.end()) {
+    return;  // UNSUBSCRIBE_ANNOUNCES may already have deleted the entry.
+  }
+  if (it->second == nullptr) {
+    session_->Error(MoqtError::kProtocolViolation,
+                    "Two responses to SUBSCRIBE_ANNOUNCES");
+    return;
+  }
+  std::move(it->second)(message.track_namespace, std::nullopt, "");
+  it->second = nullptr;
+}
+
+void MoqtSession::ControlStream::OnSubscribeAnnouncesErrorMessage(
+    const MoqtSubscribeAnnouncesError& message) {
+  auto it =
+      session_->outgoing_subscribe_announces_.find(message.track_namespace);
+  if (it == session_->outgoing_subscribe_announces_.end()) {
+    return;  // UNSUBSCRIBE_ANNOUNCES may already have deleted the entry.
+  }
+  if (it->second == nullptr) {
+    session_->Error(MoqtError::kProtocolViolation,
+                    "Two responses to SUBSCRIBE_ANNOUNCES");
+    return;
+  }
+  std::move(it->second)(message.track_namespace, message.error_code,
+                        absl::string_view(message.reason_phrase));
+  session_->outgoing_subscribe_announces_.erase(it);
+}
+
 void MoqtSession::ControlStream::OnMaxSubscribeIdMessage(
     const MoqtMaxSubscribeId& message) {
   if (session_->peer_role_ == MoqtRole::kSubscriber) {
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
index 7908c9e..97eb5e5 100644
--- a/quiche/quic/moqt/moqt_session.h
+++ b/quiche/quic/moqt/moqt_session.h
@@ -49,6 +49,9 @@
 using MoqtIncomingAnnounceCallback =
     quiche::MultiUseCallback<std::optional<MoqtAnnounceErrorReason>(
         FullTrackName track_namespace)>;
+using MoqtSubscribeAnnouncesCallback = quiche::SingleUseCallback<void(
+    FullTrackName track_namespace, std::optional<SubscribeErrorCode> error,
+    absl::string_view reason)>;
 
 inline std::optional<MoqtAnnounceErrorReason> DefaultIncomingAnnounceCallback(
     FullTrackName /*track_namespace*/) {
@@ -106,6 +109,12 @@
 
   quic::Perspective perspective() const { return parameters_.perspective; }
 
+  // Returns true if message was sent.
+  bool SubscribeAnnounces(
+      FullTrackName track_namespace, MoqtSubscribeAnnouncesCallback callback,
+      MoqtSubscribeParameters parameters = MoqtSubscribeParameters());
+  bool UnsubscribeAnnounces(FullTrackName track_namespace);
+
   // Send an ANNOUNCE message for |track_namespace|, and call
   // |announce_callback| when the response arrives. Will fail immediately if
   // there is already an unresolved ANNOUNCE for that namespace.
@@ -214,9 +223,9 @@
     void OnSubscribeAnnouncesMessage(
         const MoqtSubscribeAnnounces& message) override {}
     void OnSubscribeAnnouncesOkMessage(
-        const MoqtSubscribeAnnouncesOk& message) override {}
+        const MoqtSubscribeAnnouncesOk& message) override;
     void OnSubscribeAnnouncesErrorMessage(
-        const MoqtSubscribeAnnouncesError& message) override {}
+        const MoqtSubscribeAnnouncesError& message) override;
     void OnUnsubscribeAnnouncesMessage(
         const MoqtUnsubscribeAnnounces& message) override {}
     void OnMaxSubscribeIdMessage(const MoqtMaxSubscribeId& message) override;
@@ -609,6 +618,12 @@
   // Indexed by track namespace.
   absl::flat_hash_map<FullTrackName, MoqtOutgoingAnnounceCallback>
       pending_outgoing_announces_;
+  // The value is nullptr after OK or ERROR is received. The entry is deleted
+  // when sending UNSUBSCRIBE_ANNOUNCES, to make sure the application doesn't
+  // unsubscribe from something that it isn't subscribed to. ANNOUNCEs that
+  // result from this subscription use incoming_announce_callback.
+  absl::flat_hash_map<FullTrackName, MoqtSubscribeAnnouncesCallback>
+      outgoing_subscribe_announces_;
 
   // The role the peer advertised in its SETUP message. Initialize it to avoid
   // an uninitialized value if no SETUP arrives or it arrives with no Role
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index 349c147..c2fb556 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -639,6 +639,68 @@
   stream_input->OnAnnounceMessage(announce);
 }
 
+TEST_F(MoqtSessionTest, SubscribeAnnouncesLifeCycle) {
+  webtransport::test::MockStream mock_stream;
+  std::unique_ptr<MoqtControlParserVisitor> stream_input =
+      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+  FullTrackName track_namespace("foo", "bar");
+  track_namespace.NameToNamespace();
+  bool got_callback = false;
+  EXPECT_CALL(
+      mock_stream,
+      Writev(ControlMessageOfType(MoqtMessageType::kSubscribeAnnounces), _));
+  session_.SubscribeAnnounces(
+      track_namespace,
+      [&](const FullTrackName& ftn, std::optional<SubscribeErrorCode> error,
+          absl::string_view reason) {
+        got_callback = true;
+        EXPECT_EQ(track_namespace, ftn);
+        EXPECT_FALSE(error.has_value());
+        EXPECT_EQ(reason, "");
+      });
+  MoqtSubscribeAnnouncesOk ok = {
+      /*track_namespace=*/track_namespace,
+  };
+  stream_input->OnSubscribeAnnouncesOkMessage(ok);
+  EXPECT_TRUE(got_callback);
+  EXPECT_CALL(
+      mock_stream,
+      Writev(ControlMessageOfType(MoqtMessageType::kUnsubscribeAnnounces), _));
+  EXPECT_TRUE(session_.UnsubscribeAnnounces(track_namespace));
+  EXPECT_FALSE(session_.UnsubscribeAnnounces(track_namespace));
+}
+
+TEST_F(MoqtSessionTest, SubscribeAnnouncesError) {
+  webtransport::test::MockStream mock_stream;
+  std::unique_ptr<MoqtControlParserVisitor> stream_input =
+      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+  FullTrackName track_namespace("foo", "bar");
+  track_namespace.NameToNamespace();
+  bool got_callback = false;
+  EXPECT_CALL(
+      mock_stream,
+      Writev(ControlMessageOfType(MoqtMessageType::kSubscribeAnnounces), _));
+  session_.SubscribeAnnounces(
+      track_namespace,
+      [&](const FullTrackName& ftn, std::optional<SubscribeErrorCode> error,
+          absl::string_view reason) {
+        got_callback = true;
+        EXPECT_EQ(track_namespace, ftn);
+        ASSERT_TRUE(error.has_value());
+        EXPECT_EQ(*error, SubscribeErrorCode::kInvalidRange);
+        EXPECT_EQ(reason, "deadbeef");
+      });
+  MoqtSubscribeAnnouncesError error = {
+      /*track_namespace=*/track_namespace,
+      /*error_code=*/SubscribeErrorCode::kInvalidRange,
+      /*reason_phrase=*/"deadbeef",
+  };
+  stream_input->OnSubscribeAnnouncesErrorMessage(error);
+  EXPECT_TRUE(got_callback);
+  // Entry is immediately gone.
+  EXPECT_FALSE(session_.UnsubscribeAnnounces(track_namespace));
+}
+
 TEST_F(MoqtSessionTest, IncomingObject) {
   MockSubscribeRemoteTrackVisitor visitor_;
   FullTrackName ftn("foo", "bar");
diff --git a/quiche/quic/moqt/test_tools/moqt_test_message.h b/quiche/quic/moqt/test_tools/moqt_test_message.h
index 9488f63..5015bf7 100644
--- a/quiche/quic/moqt/test_tools/moqt_test_message.h
+++ b/quiche/quic/moqt/test_tools/moqt_test_message.h
@@ -1212,13 +1212,13 @@
  private:
   uint8_t raw_packet_[12] = {
       0x13, 0x0a, 0x01, 0x03, 0x66, 0x6f, 0x6f,  // track_namespace = "foo"
-      0x01,                                      // error_code = 1
+      0x04,                                      // error_code = 4
       0x03, 0x62, 0x61, 0x72,                    // reason_phrase = "bar"
   };
 
   MoqtSubscribeAnnouncesError subscribe_namespace_error_ = {
       /*track_namespace=*/FullTrackName{"foo"},
-      /*error_code=*/MoqtAnnounceErrorCode::kAnnounceNotSupported,
+      /*error_code=*/SubscribeErrorCode::kUnauthorized,
       /*reason_phrase=*/"bar",
   };
 };