Handle MoQT ANNOUNCE_CANCEL messages.

PiperOrigin-RevId: 647675746
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index 9843ab6..6832f25 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -9,13 +9,14 @@
 #include <cstdint>
 #include <memory>
 #include <optional>
-#include <set>
 #include <string>
 #include <utility>
 #include <vector>
 
 
 #include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/node_hash_map.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
@@ -192,6 +193,18 @@
   return (it != local_tracks_.end() && it->second.HasSubscriber());
 }
 
+void MoqtSession::CancelAnnounce(absl::string_view track_namespace) {
+  for (auto it = local_tracks_.begin(); it != local_tracks_.end(); ++it) {
+    if (it->first.track_namespace == track_namespace) {
+      it->second.set_announce_cancel();
+    }
+  }
+  absl::erase_if(local_tracks_, [&](const auto& it) {
+    return it.first.track_namespace == track_namespace &&
+           !it.second.HasSubscriber();
+  });
+}
+
 bool MoqtSession::SubscribeAbsolute(absl::string_view track_namespace,
                                     absl::string_view name,
                                     uint64_t start_group, uint64_t start_object,
@@ -322,6 +335,9 @@
   // Clean up the subscription
   track.DeleteWindow(subscribe_id);
   local_track_by_subscribe_id_.erase(name_it);
+  if (track.canceled() && !track.HasSubscriber()) {
+    local_tracks_.erase(track_it);
+  }
   return true;
 }
 
@@ -724,6 +740,14 @@
     return;
   }
   LocalTrack& track = it->second;
+  if (it->second.canceled()) {
+    // Note that if the track has already been deleted, there will not be a
+    // protocol violation, which the spec says there SHOULD be. It's not worth
+    // keeping state on deleted tracks.
+    session_->Error(MoqtError::kProtocolViolation,
+                    "Received SUBSCRIBE for canceled track");
+    return;
+  }
   if ((track.track_alias().has_value() &&
        message.track_alias != *track.track_alias()) ||
       session_->used_track_aliases_.contains(message.track_alias)) {
@@ -967,6 +991,11 @@
   session_->pending_outgoing_announces_.erase(it);
 }
 
+void MoqtSession::Stream::OnAnnounceCancelMessage(
+    const MoqtAnnounceCancel& message) {
+  session_->CancelAnnounce(message.track_namespace);
+}
+
 void MoqtSession::Stream::OnParsingError(MoqtError error_code,
                                          absl::string_view reason) {
   session_->Error(error_code, absl::StrCat("Parse error: ", reason));
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
index c41bae5..9d88da9 100644
--- a/quiche/quic/moqt/moqt_session.h
+++ b/quiche/quic/moqt/moqt_session.h
@@ -98,6 +98,9 @@
   void Announce(absl::string_view track_namespace,
                 MoqtOutgoingAnnounceCallback announce_callback);
   bool HasSubscribers(const FullTrackName& full_track_name) const;
+  // Send an ANNOUNCE_CANCEL and delete local tracks in that namespace when all
+  // subscriptions are closed for that track.
+  void CancelAnnounce(absl::string_view track_namespace);
 
   // Returns true if SUBSCRIBE was sent. If there is already a subscription to
   // the track, the message will still be sent. However, the visitor will be
@@ -177,13 +180,14 @@
     void OnSubscribeOkMessage(const MoqtSubscribeOk& message) override;
     void OnSubscribeErrorMessage(const MoqtSubscribeError& message) override;
     void OnUnsubscribeMessage(const MoqtUnsubscribe& message) override;
+    // There is no state to update for SUBSCRIBE_DONE.
     void OnSubscribeDoneMessage(const MoqtSubscribeDone& /*message*/) override {
     }
     void OnSubscribeUpdateMessage(const MoqtSubscribeUpdate& message) override;
     void OnAnnounceMessage(const MoqtAnnounce& message) override;
     void OnAnnounceOkMessage(const MoqtAnnounceOk& message) override;
     void OnAnnounceErrorMessage(const MoqtAnnounceError& message) override;
-    void OnAnnounceCancelMessage(const MoqtAnnounceCancel& message) override {};
+    void OnAnnounceCancelMessage(const MoqtAnnounceCancel& message) override;
     void OnTrackStatusRequestMessage(
         const MoqtTrackStatusRequest& message) override {};
     void OnUnannounceMessage(const MoqtUnannounce& /*message*/) override {}
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index e943999..a002462 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -108,19 +108,23 @@
     session->active_subscribes_[subscribe_id] = {subscribe, visitor};
   }
 
-  static LocalTrack& local_track(MoqtSession* session, FullTrackName& name) {
-    return session->local_tracks_.find(name)->second;
+  static LocalTrack* local_track(MoqtSession* session, FullTrackName& name) {
+    auto it = session->local_tracks_.find(name);
+    if (it == session->local_tracks_.end()) {
+      return nullptr;
+    }
+    return &it->second;
   }
 
   static void AddSubscription(MoqtSession* session, FullTrackName& name,
                               uint64_t subscribe_id, uint64_t track_alias,
                               uint64_t start_group, uint64_t start_object) {
-    LocalTrack& track = local_track(session, name);
-    track.set_track_alias(track_alias);
-    track.AddWindow(subscribe_id, start_group, start_object);
+    LocalTrack* track = local_track(session, name);
+    track->set_track_alias(track_alias);
+    track->AddWindow(subscribe_id, start_group, start_object);
     session->used_track_aliases_.emplace(track_alias);
     session->local_track_by_subscribe_id_.emplace(subscribe_id,
-                                                  track.full_track_name());
+                                                  track->full_track_name());
   }
 
   static FullSequence next_sequence(MoqtSession* session, FullTrackName& name) {
@@ -1243,9 +1247,9 @@
   session_.AddLocalTrack(ftn, MoqtForwardingPreference::kTrack, &track_visitor);
   MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0);
   // Get the window, set the maximum delivered.
-  LocalTrack& track = MoqtSessionPeer::local_track(&session_, ftn);
-  track.GetWindow(0)->OnObjectSent(FullSequence(7, 3),
-                                   MoqtObjectStatus::kNormal);
+  LocalTrack* track = MoqtSessionPeer::local_track(&session_, ftn);
+  track->GetWindow(0)->OnObjectSent(FullSequence(7, 3),
+                                    MoqtObjectStatus::kNormal);
   // Update the end to fall at the last delivered object.
   MoqtSubscribeUpdate update = {
       /*subscribe_id=*/0,
@@ -1272,6 +1276,108 @@
   EXPECT_FALSE(session_.HasSubscribers(ftn));
 }
 
+TEST_F(MoqtSessionTest, ProcessAnnounceCancelNoSubscribes) {
+  MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kSubscriber);
+  FullTrackName ftn("foo", "bar");
+  MockLocalTrackVisitor track_visitor;
+  session_.AddLocalTrack(ftn, MoqtForwardingPreference::kTrack, &track_visitor);
+  EXPECT_NE(MoqtSessionPeer::local_track(&session_, ftn), nullptr);
+  StrictMock<webtransport::test::MockStream> mock_stream;
+  MoqtAnnounceCancel cancel = {
+      /*track_namespace=*/"foo",
+  };
+  std::unique_ptr<MoqtParserVisitor> stream_input =
+      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+  stream_input->OnAnnounceCancelMessage(cancel);
+  EXPECT_EQ(MoqtSessionPeer::local_track(&session_, ftn), nullptr);
+}
+
+TEST_F(MoqtSessionTest, ProcessAnnounceCancelActiveSubscribes) {
+  MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kSubscriber);
+  FullTrackName ftn("foo", "bar");
+  MockLocalTrackVisitor track_visitor;
+  session_.AddLocalTrack(ftn, MoqtForwardingPreference::kTrack, &track_visitor);
+  EXPECT_NE(MoqtSessionPeer::local_track(&session_, ftn), nullptr);
+  MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0);
+  MoqtSessionPeer::AddSubscription(&session_, ftn, 1, 2, 7, 0);
+  StrictMock<webtransport::test::MockStream> mock_stream;
+  MoqtAnnounceCancel cancel = {
+      /*track_namespace=*/"foo",
+  };
+  std::unique_ptr<MoqtParserVisitor> stream_input =
+      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+  stream_input->OnAnnounceCancelMessage(cancel);
+  // The track is still there because there is a subscribe.
+  EXPECT_NE(MoqtSessionPeer::local_track(&session_, ftn), nullptr);
+  // Unsubscribe from 0.
+  MoqtUnsubscribe unsubscribe = {
+      /*subscribe_id=*/0,
+  };
+  EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream));
+  bool correct_message = false;
+  EXPECT_CALL(mock_stream, Writev(_, _))
+      .WillOnce([&](absl::Span<const absl::string_view> data,
+                    const quiche::StreamWriteOptions& options) {
+        correct_message = true;
+        EXPECT_EQ(*ExtractMessageType(data[0]),
+                  MoqtMessageType::kSubscribeDone);
+        return absl::OkStatus();
+      });
+  stream_input->OnUnsubscribeMessage(unsubscribe);
+  EXPECT_TRUE(correct_message);
+  EXPECT_NE(MoqtSessionPeer::local_track(&session_, ftn), nullptr);
+  // Unsubscribe from 1.
+  unsubscribe.subscribe_id = 1;
+  EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream));
+  correct_message = false;
+  EXPECT_CALL(mock_stream, Writev(_, _))
+      .WillOnce([&](absl::Span<const absl::string_view> data,
+                    const quiche::StreamWriteOptions& options) {
+        correct_message = true;
+        EXPECT_EQ(*ExtractMessageType(data[0]),
+                  MoqtMessageType::kSubscribeDone);
+        return absl::OkStatus();
+      });
+  stream_input->OnUnsubscribeMessage(unsubscribe);
+  EXPECT_TRUE(correct_message);
+
+  EXPECT_EQ(MoqtSessionPeer::local_track(&session_, ftn), nullptr);
+}
+
+TEST_F(MoqtSessionTest, AnnounceCancelThenSubscribe) {
+  MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kSubscriber);
+  FullTrackName ftn("foo", "bar");
+  MockLocalTrackVisitor track_visitor;
+  session_.AddLocalTrack(ftn, MoqtForwardingPreference::kTrack, &track_visitor);
+  EXPECT_NE(MoqtSessionPeer::local_track(&session_, ftn), nullptr);
+  MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0);
+  StrictMock<webtransport::test::MockStream> mock_stream;
+  MoqtAnnounceCancel cancel = {
+      /*track_namespace=*/"foo",
+  };
+  std::unique_ptr<MoqtParserVisitor> stream_input =
+      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+  stream_input->OnAnnounceCancelMessage(cancel);
+  // The track is still there because there is a subscribe.
+  EXPECT_NE(MoqtSessionPeer::local_track(&session_, ftn), nullptr);
+  MoqtSubscribe subscribe = {
+      /*subscribe_id=*/1,
+      /*track_alias=*/2,
+      /*track_namespace=*/"foo",
+      /*track_name=*/"bar",
+      /*start_group=*/4,
+      /*start_object=*/0,
+      /*end_group=*/std::nullopt,
+      /*end_object=*/std::nullopt,
+      /*authorization_info=*/std::nullopt,
+  };
+  EXPECT_CALL(mock_session_,
+              CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
+                           "Received SUBSCRIBE for canceled track"))
+      .Times(1);
+  stream_input->OnSubscribeMessage(subscribe);
+}
+
 // TODO: Cover more error cases in the above
 
 }  // namespace test
diff --git a/quiche/quic/moqt/moqt_track.h b/quiche/quic/moqt/moqt_track.h
index b813eb3..e61dc8b 100644
--- a/quiche/quic/moqt/moqt_track.h
+++ b/quiche/quic/moqt/moqt_track.h
@@ -14,6 +14,7 @@
 #include "absl/strings/string_view.h"
 #include "quiche/quic/moqt/moqt_messages.h"
 #include "quiche/quic/moqt/moqt_subscribe_windows.h"
+#include "quiche/quic/platform/api/quic_bug_tracker.h"
 #include "quiche/common/quiche_callbacks.h"
 
 namespace moqt {
@@ -69,11 +70,15 @@
 
   void AddWindow(uint64_t subscribe_id, uint64_t start_group,
                  uint64_t start_object) {
+    QUIC_BUG_IF(quic_bug_subscribe_to_canceled_track, announce_canceled_)
+        << "Canceled track got subscription";
     windows_.AddWindow(subscribe_id, next_sequence_, start_group, start_object);
   }
 
   void AddWindow(uint64_t subscribe_id, uint64_t start_group,
                  uint64_t start_object, uint64_t end_group) {
+    QUIC_BUG_IF(quic_bug_subscribe_to_canceled_track, announce_canceled_)
+        << "Canceled track got subscription";
     // The end object might be unknown.
     auto it = max_object_ids_.find(end_group);
     if (end_group >= next_sequence_.group) {
@@ -89,6 +94,8 @@
   void AddWindow(uint64_t subscribe_id, uint64_t start_group,
                  uint64_t start_object, uint64_t end_group,
                  uint64_t end_object) {
+    QUIC_BUG_IF(quic_bug_subscribe_to_canceled_track, announce_canceled_)
+        << "Canceled track got subscription";
     windows_.AddWindow(subscribe_id, next_sequence_, start_group, start_object,
                        end_group, end_object);
   }
@@ -142,6 +149,9 @@
     return forwarding_preference_;
   }
 
+  void set_announce_cancel() { announce_canceled_ = true; }
+  bool canceled() const { return announce_canceled_; }
+
  private:
   // This only needs to track subscriptions to current and future objects;
   // requests for objects in the past are forwarded to the application.
@@ -160,6 +170,11 @@
   // EndOfTrack has been received for that group.
   absl::flat_hash_map<uint64_t, uint64_t> max_object_ids_;
   Visitor* visitor_;
+
+  // If true, the session has received ANNOUNCE_CANCELED for this namespace.
+  // Additional subscribes will be a protocol error, and the track can be
+  // destroyed once all active subscribes end.
+  bool announce_canceled_ = false;
 };
 
 // A track on the peer to which the session has subscribed.