Two Subscriptions to the same MoQT Track in a session is an error.

New to draft-07

PiperOrigin-RevId: 689433570
diff --git a/quiche/quic/moqt/moqt_live_relay_queue.h b/quiche/quic/moqt/moqt_live_relay_queue.h
index f6c0ca8..6b23e7a 100644
--- a/quiche/quic/moqt/moqt_live_relay_queue.h
+++ b/quiche/quic/moqt/moqt_live_relay_queue.h
@@ -94,6 +94,15 @@
 
   bool HasSubscribers() const { return !listeners_.empty(); }
 
+  // Since MoqtTrackPublisher is generally held in a shared_ptr, an explicit
+  // call allows all the listeners to delete their reference and actually
+  // destroy the object.
+  void RemoveAllSubscriptions() {
+    for (MoqtObjectListener* listener : listeners_) {
+      listener->OnTrackPublisherGone();
+    }
+  }
+
  private:
   // The number of recent groups to keep around for newly joined subscribers.
   static constexpr size_t kMaxQueuedGroups = 3;
diff --git a/quiche/quic/moqt/moqt_live_relay_queue_test.cc b/quiche/quic/moqt/moqt_live_relay_queue_test.cc
index 228cc81..4e0bcbf 100644
--- a/quiche/quic/moqt/moqt_live_relay_queue_test.cc
+++ b/quiche/quic/moqt/moqt_live_relay_queue_test.cc
@@ -77,6 +77,7 @@
   MOCK_METHOD(void, SkipObject, (uint64_t group_id, uint64_t object_id), ());
   MOCK_METHOD(void, SkipGroup, (uint64_t group_id), ());
   MOCK_METHOD(void, CloseTrack, (), ());
+  MOCK_METHOD(void, OnTrackPublisherGone, (), (override));
 };
 
 // Duplicates of MoqtOutgoingQueue test cases.
diff --git a/quiche/quic/moqt/moqt_outgoing_queue.h b/quiche/quic/moqt/moqt_outgoing_queue.h
index 57c0183..24d547c 100644
--- a/quiche/quic/moqt/moqt_outgoing_queue.h
+++ b/quiche/quic/moqt/moqt_outgoing_queue.h
@@ -81,6 +81,15 @@
     delivery_order_ = order;
   }
 
+  // Since MoqtTrackPublisher is generally held in a shared_ptr, an explicit
+  // call allows all the listeners to delete their reference and actually
+  // destroy the object.
+  void RemoveAllSubscriptions() {
+    for (MoqtObjectListener* listener : listeners_) {
+      listener->OnTrackPublisherGone();
+    }
+  }
+
  private:
   // The number of recent groups to keep around for newly joined subscribers.
   static constexpr size_t kMaxQueuedGroups = 3;
diff --git a/quiche/quic/moqt/moqt_outgoing_queue_test.cc b/quiche/quic/moqt/moqt_outgoing_queue_test.cc
index 3dc745e..059f549 100644
--- a/quiche/quic/moqt/moqt_outgoing_queue_test.cc
+++ b/quiche/quic/moqt/moqt_outgoing_queue_test.cc
@@ -71,6 +71,7 @@
               (uint64_t group_id, uint64_t object_id,
                absl::string_view payload),
               ());
+  MOCK_METHOD(void, OnTrackPublisherGone, (), (override));
 };
 
 absl::StatusOr<std::vector<std::string>> FetchToVector(
diff --git a/quiche/quic/moqt/moqt_publisher.h b/quiche/quic/moqt/moqt_publisher.h
index 1f125c1..1ed270d 100644
--- a/quiche/quic/moqt/moqt_publisher.h
+++ b/quiche/quic/moqt/moqt_publisher.h
@@ -37,6 +37,10 @@
   // available.  The object payload itself may be retrieved via GetCachedObject
   // method of the associated track publisher.
   virtual void OnNewObjectAvailable(FullSequence sequence) = 0;
+
+  // Notifies that the Publisher is being destroyed, so no more objects are
+  // coming.
+  virtual void OnTrackPublisherGone() = 0;
 };
 
 // A handle representing a fetch in progress.  The fetch in question can be
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index 0da1fbc..63dbbfb 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -691,6 +691,11 @@
     session_->monitoring_interfaces_for_published_tracks_.erase(monitoring_it);
   }
 
+  if (session_->subscribed_track_names_.contains(track_name)) {
+    session_->Error(MoqtError::kProtocolViolation,
+                    "Duplicate subscribe for track");
+    return;
+  }
   auto subscription = std::make_unique<MoqtSession::PublishedSubscription>(
       session_, *std::move(track_publisher), message, monitoring);
   auto [it, success] = session_->published_subscriptions_.emplace(
@@ -698,6 +703,7 @@
   if (!success) {
     SendSubscribeError(message, SubscribeErrorCode::kInternalError,
                        "Duplicate subscribe ID", message.track_alias);
+    return;
   }
 
   MoqtSubscribeOk subscribe_ok;
@@ -966,10 +972,12 @@
   }
   QUIC_DLOG(INFO) << ENDPOINT << "Created subscription for "
                   << subscribe.full_track_name;
+  session_->subscribed_track_names_.insert(subscribe.full_track_name);
 }
 
 MoqtSession::PublishedSubscription::~PublishedSubscription() {
   track_publisher_->RemoveObjectListener(this);
+  session_->subscribed_track_names_.erase(track_publisher_->GetTrackName());
 }
 
 SendStreamMap& MoqtSession::PublishedSubscription::stream_map() {
@@ -1043,6 +1051,11 @@
   stream->SendObjects(*this);
 }
 
+void MoqtSession::PublishedSubscription::OnTrackPublisherGone() {
+  session_->SubscribeIsDone(subscription_id_, SubscribeDoneCode::kGoingAway,
+                            "Publisher is gone");
+}
+
 void MoqtSession::PublishedSubscription::Backfill() {
   const FullSequence start = window_.start();
   const FullSequence end = track_publisher_->GetLargestSequence();
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
index 34e0aaf..2fad2ac 100644
--- a/quiche/quic/moqt/moqt_session.h
+++ b/quiche/quic/moqt/moqt_session.h
@@ -312,6 +312,7 @@
 
     // This is only called for objects that have just arrived.
     void OnNewObjectAvailable(FullSequence sequence) override;
+    void OnTrackPublisherGone() override;
     void ProcessObjectAck(const MoqtObjectAck& message) {
       if (monitoring_interface_ == nullptr) {
         return;
@@ -494,6 +495,9 @@
   absl::flat_hash_map<FullTrackName, uint64_t> remote_track_aliases_;
   uint64_t next_remote_track_alias_ = 0;
 
+  // All open incoming subscriptions, indexed by track name, used to check for
+  // duplicates.
+  absl::flat_hash_set<FullTrackName> subscribed_track_names_;
   // Application object representing the publisher for all of the tracks that
   // can be subscribed to via this connection.  Must outlive this object.
   MoqtPublisher* publisher_;
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index b1b2db5..e4af115 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -12,6 +12,7 @@
 #include <utility>
 #include <vector>
 
+
 #include "absl/status/status.h"
 #include "absl/strings/match.h"
 #include "absl/strings/string_view.h"
@@ -81,9 +82,12 @@
     auto new_stream =
         std::make_unique<MoqtSession::ControlStream>(session, stream);
     session->control_stream_ = kControlStreamId;
-    EXPECT_CALL(*stream, visitor())
+    ON_CALL(*stream, visitor()).WillByDefault(Return(new_stream.get()));
+    webtransport::test::MockSession* mock_session =
+        static_cast<webtransport::test::MockSession*>(session->session());
+    EXPECT_CALL(*mock_session, GetStreamById(kControlStreamId))
         .Times(AnyNumber())
-        .WillRepeatedly(Return(new_stream.get()));
+        .WillRepeatedly(Return(stream));
     return new_stream;
   }
 
@@ -475,6 +479,129 @@
   EXPECT_TRUE(correct_message);
 }
 
+TEST_F(MoqtSessionTest, TwoSubscribesForTrack) {
+  FullTrackName ftn("foo", "bar");
+  auto track = std::make_shared<MockTrackPublisher>(ftn);
+  EXPECT_CALL(*track, GetTrackStatus())
+      .WillRepeatedly(Return(MoqtTrackStatusCode::kInProgress));
+  EXPECT_CALL(*track, GetCachedObject(_)).WillRepeatedly([] {
+    return std::optional<PublishedObject>();
+  });
+  EXPECT_CALL(*track, GetCachedObjectsInRange(_, _))
+      .WillRepeatedly(Return(std::vector<FullSequence>()));
+  EXPECT_CALL(*track, GetLargestSequence())
+      .WillRepeatedly(Return(FullSequence(10, 20)));
+  publisher_.Add(track);
+
+  // Peer subscribes to (11, 0)
+  MoqtSubscribe request = {
+      /*subscribe_id=*/1,
+      /*track_alias=*/2,
+      /*full_track_name=*/FullTrackName({"foo", "bar"}),
+      /*subscriber_priority=*/0x80,
+      /*group_order=*/std::nullopt,
+      /*start_group=*/11,
+      /*start_object=*/0,
+      /*end_group=*/std::nullopt,
+      /*end_object=*/std::nullopt,
+      /*parameters=*/MoqtSubscribeParameters(),
+  };
+  webtransport::test::MockStream mock_stream;
+  std::unique_ptr<MoqtControlParserVisitor> stream_input =
+      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+  bool correct_message = false;
+  EXPECT_CALL(mock_stream, Writev(_, _))
+      .WillOnce([&](absl::Span<const absl::string_view> data,
+                    const quiche::StreamWriteOptions& options) {
+        correct_message = true;
+        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk);
+        return absl::OkStatus();
+      });
+  stream_input->OnSubscribeMessage(request);
+  EXPECT_TRUE(correct_message);
+
+  request.subscribe_id = 2;
+  request.start_group = 12;
+  EXPECT_CALL(mock_session_,
+              CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
+                           "Duplicate subscribe for track"))
+      .Times(1);
+  stream_input->OnSubscribeMessage(request);
+  ;
+}
+
+TEST_F(MoqtSessionTest, UnsubscribeAllowsSecondSubscribe) {
+  FullTrackName ftn("foo", "bar");
+  auto track = std::make_shared<MockTrackPublisher>(ftn);
+  EXPECT_CALL(*track, GetTrackStatus())
+      .WillRepeatedly(Return(MoqtTrackStatusCode::kInProgress));
+  EXPECT_CALL(*track, GetCachedObject(_)).WillRepeatedly([] {
+    return std::optional<PublishedObject>();
+  });
+  EXPECT_CALL(*track, GetCachedObjectsInRange(_, _))
+      .WillRepeatedly(Return(std::vector<FullSequence>()));
+  EXPECT_CALL(*track, GetLargestSequence())
+      .WillRepeatedly(Return(FullSequence(10, 20)));
+  publisher_.Add(track);
+
+  // Peer subscribes to (11, 0)
+  MoqtSubscribe request = {
+      /*subscribe_id=*/1,
+      /*track_alias=*/2,
+      /*full_track_name=*/FullTrackName({"foo", "bar"}),
+      /*subscriber_priority=*/0x80,
+      /*group_order=*/std::nullopt,
+      /*start_group=*/11,
+      /*start_object=*/0,
+      /*end_group=*/std::nullopt,
+      /*end_object=*/std::nullopt,
+      /*parameters=*/MoqtSubscribeParameters(),
+  };
+  webtransport::test::MockStream mock_stream;
+  std::unique_ptr<MoqtControlParserVisitor> stream_input =
+      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+  bool correct_message = false;
+  EXPECT_CALL(mock_stream, Writev(_, _))
+      .WillOnce([&](absl::Span<const absl::string_view> data,
+                    const quiche::StreamWriteOptions& options) {
+        correct_message = true;
+        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk);
+        return absl::OkStatus();
+      });
+  stream_input->OnSubscribeMessage(request);
+  EXPECT_TRUE(correct_message);
+
+  // Peer unsubscribes.
+  MoqtUnsubscribe unsubscribe = {
+      /*subscribe_id=*/1,
+  };
+  correct_message = false;
+  EXPECT_CALL(mock_stream, Writev(_, _))
+      .WillOnce([&](absl::Span<const absl::string_view> data,
+                    const quiche::StreamWriteOptions& options) {
+        correct_message = true;
+        EXPECT_EQ(*ExtractMessageType(data[0]),
+                  MoqtMessageType::kSubscribeDone);
+        return absl::OkStatus();
+      });
+  stream_input->OnUnsubscribeMessage(unsubscribe);
+  EXPECT_TRUE(correct_message);
+
+  // Subscribe again, succeeds.
+  request.subscribe_id = 2;
+  request.start_group = 12;
+  correct_message = false;
+  EXPECT_CALL(mock_stream, Writev(_, _))
+      .WillOnce([&](absl::Span<const absl::string_view> data,
+                    const quiche::StreamWriteOptions& options) {
+        correct_message = true;
+        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk);
+        return absl::OkStatus();
+      });
+  stream_input->OnSubscribeMessage(request);
+  EXPECT_TRUE(correct_message);
+}
+
 TEST_F(MoqtSessionTest, SubscribeIdTooHigh) {
   // Peer subscribes to (0, 0)
   MoqtSubscribe request = {
diff --git a/quiche/quic/moqt/tools/chat_server.cc b/quiche/quic/moqt/tools/chat_server.cc
index 5b32185..d098450 100644
--- a/quiche/quic/moqt/tools/chat_server.cc
+++ b/quiche/quic/moqt/tools/chat_server.cc
@@ -4,7 +4,6 @@
 
 #include "quiche/quic/moqt/tools/chat_server.h"
 
-#include <cstdint>
 #include <iostream>
 #include <memory>
 #include <optional>
@@ -169,6 +168,7 @@
   catalog_->AddObject(quiche::QuicheMemSlice(quiche::QuicheBuffer::Copy(
                           quiche::SimpleBufferAllocator::Get(), catalog_data)),
                       /*key=*/false);
+  user_queues_[username]->RemoveAllSubscriptions();
   user_queues_.erase(username);
   publisher_.Delete(strings_.GetFullTrackNameFromUsername(username));
 }
diff --git a/quiche/quic/moqt/tools/moq_chat_end_to_end_test.cc b/quiche/quic/moqt/tools/moq_chat_end_to_end_test.cc
index d68f742..3460aea 100644
--- a/quiche/quic/moqt/tools/moq_chat_end_to_end_test.cc
+++ b/quiche/quic/moqt/tools/moq_chat_end_to_end_test.cc
@@ -24,7 +24,7 @@
 
 using ::testing::_;
 
-constexpr std::string kChatHostname = "127.0.0.1";
+constexpr absl::string_view kChatHostname = "127.0.0.1";
 
 class MockChatUserInterface : public ChatUserInterface {
  public:
@@ -55,7 +55,8 @@
       : server_(quic::test::crypto_test_utils::ProofSourceForTesting(),
                 "test_chat", "") {
     quiche::QuicheIpAddress bind_address;
-    bind_address.FromString(kChatHostname);
+    std::string hostname(kChatHostname);
+    bind_address.FromString(hostname);
     EXPECT_TRUE(server_.moqt_server().quic_server().CreateUDPSocketAndListen(
         quic::QuicSocketAddress(bind_address, 0)));
     auto if1ptr = std::make_unique<MockChatUserInterface>();
@@ -64,10 +65,10 @@
     interface2_ = if2ptr.get();
     uint16_t port = server_.moqt_server().quic_server().port();
     client1_ = std::make_unique<ChatClient>(
-        quic::QuicServerId(kChatHostname, port), true, std::move(if1ptr),
+        quic::QuicServerId(hostname, port), true, std::move(if1ptr),
         server_.moqt_server().quic_server().event_loop());
     client2_ = std::make_unique<ChatClient>(
-        quic::QuicServerId(kChatHostname, port), true, std::move(if2ptr),
+        quic::QuicServerId(hostname, port), true, std::move(if2ptr),
         server_.moqt_server().quic_server().event_loop());
   }
 
@@ -129,12 +130,12 @@
   MockChatUserInterface* interface1b_ = if1bptr.get();
   uint16_t port = server_.moqt_server().quic_server().port();
   client1_ = std::make_unique<ChatClient>(
-      quic::QuicServerId(kChatHostname, port), true, std::move(if1bptr),
-      server_.moqt_server().quic_server().event_loop());
+      quic::QuicServerId(std::string(kChatHostname), port), true,
+      std::move(if1bptr), server_.moqt_server().quic_server().event_loop());
   EXPECT_TRUE(client1_->Connect("/moq-chat", "client1", "test_chat"));
   EXPECT_TRUE(client1_->AnnounceAndSubscribe());
-  SendAndWaitForOutput(interface1b_, interface2_, "client1", "Hello");
-  SendAndWaitForOutput(interface2_, interface1b_, "client2", "Hi");
+  SendAndWaitForOutput(interface1b_, interface2_, "client1", "Hello again");
+  SendAndWaitForOutput(interface2_, interface1b_, "client2", "Hi again");
 }
 
 }  // namespace test