Two Subscriptions to the same MoQT Track in a session is an error.
New to draft-07
PiperOrigin-RevId: 689433570
diff --git a/quiche/quic/moqt/moqt_live_relay_queue.h b/quiche/quic/moqt/moqt_live_relay_queue.h
index f6c0ca8..6b23e7a 100644
--- a/quiche/quic/moqt/moqt_live_relay_queue.h
+++ b/quiche/quic/moqt/moqt_live_relay_queue.h
@@ -94,6 +94,15 @@
bool HasSubscribers() const { return !listeners_.empty(); }
+ // Since MoqtTrackPublisher is generally held in a shared_ptr, an explicit
+ // call allows all the listeners to delete their reference and actually
+ // destroy the object.
+ void RemoveAllSubscriptions() {
+ for (MoqtObjectListener* listener : listeners_) {
+ listener->OnTrackPublisherGone();
+ }
+ }
+
private:
// The number of recent groups to keep around for newly joined subscribers.
static constexpr size_t kMaxQueuedGroups = 3;
diff --git a/quiche/quic/moqt/moqt_live_relay_queue_test.cc b/quiche/quic/moqt/moqt_live_relay_queue_test.cc
index 228cc81..4e0bcbf 100644
--- a/quiche/quic/moqt/moqt_live_relay_queue_test.cc
+++ b/quiche/quic/moqt/moqt_live_relay_queue_test.cc
@@ -77,6 +77,7 @@
MOCK_METHOD(void, SkipObject, (uint64_t group_id, uint64_t object_id), ());
MOCK_METHOD(void, SkipGroup, (uint64_t group_id), ());
MOCK_METHOD(void, CloseTrack, (), ());
+ MOCK_METHOD(void, OnTrackPublisherGone, (), (override));
};
// Duplicates of MoqtOutgoingQueue test cases.
diff --git a/quiche/quic/moqt/moqt_outgoing_queue.h b/quiche/quic/moqt/moqt_outgoing_queue.h
index 57c0183..24d547c 100644
--- a/quiche/quic/moqt/moqt_outgoing_queue.h
+++ b/quiche/quic/moqt/moqt_outgoing_queue.h
@@ -81,6 +81,15 @@
delivery_order_ = order;
}
+ // Since MoqtTrackPublisher is generally held in a shared_ptr, an explicit
+ // call allows all the listeners to delete their reference and actually
+ // destroy the object.
+ void RemoveAllSubscriptions() {
+ for (MoqtObjectListener* listener : listeners_) {
+ listener->OnTrackPublisherGone();
+ }
+ }
+
private:
// The number of recent groups to keep around for newly joined subscribers.
static constexpr size_t kMaxQueuedGroups = 3;
diff --git a/quiche/quic/moqt/moqt_outgoing_queue_test.cc b/quiche/quic/moqt/moqt_outgoing_queue_test.cc
index 3dc745e..059f549 100644
--- a/quiche/quic/moqt/moqt_outgoing_queue_test.cc
+++ b/quiche/quic/moqt/moqt_outgoing_queue_test.cc
@@ -71,6 +71,7 @@
(uint64_t group_id, uint64_t object_id,
absl::string_view payload),
());
+ MOCK_METHOD(void, OnTrackPublisherGone, (), (override));
};
absl::StatusOr<std::vector<std::string>> FetchToVector(
diff --git a/quiche/quic/moqt/moqt_publisher.h b/quiche/quic/moqt/moqt_publisher.h
index 1f125c1..1ed270d 100644
--- a/quiche/quic/moqt/moqt_publisher.h
+++ b/quiche/quic/moqt/moqt_publisher.h
@@ -37,6 +37,10 @@
// available. The object payload itself may be retrieved via GetCachedObject
// method of the associated track publisher.
virtual void OnNewObjectAvailable(FullSequence sequence) = 0;
+
+ // Notifies that the Publisher is being destroyed, so no more objects are
+ // coming.
+ virtual void OnTrackPublisherGone() = 0;
};
// A handle representing a fetch in progress. The fetch in question can be
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index 0da1fbc..63dbbfb 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -691,6 +691,11 @@
session_->monitoring_interfaces_for_published_tracks_.erase(monitoring_it);
}
+ if (session_->subscribed_track_names_.contains(track_name)) {
+ session_->Error(MoqtError::kProtocolViolation,
+ "Duplicate subscribe for track");
+ return;
+ }
auto subscription = std::make_unique<MoqtSession::PublishedSubscription>(
session_, *std::move(track_publisher), message, monitoring);
auto [it, success] = session_->published_subscriptions_.emplace(
@@ -698,6 +703,7 @@
if (!success) {
SendSubscribeError(message, SubscribeErrorCode::kInternalError,
"Duplicate subscribe ID", message.track_alias);
+ return;
}
MoqtSubscribeOk subscribe_ok;
@@ -966,10 +972,12 @@
}
QUIC_DLOG(INFO) << ENDPOINT << "Created subscription for "
<< subscribe.full_track_name;
+ session_->subscribed_track_names_.insert(subscribe.full_track_name);
}
MoqtSession::PublishedSubscription::~PublishedSubscription() {
track_publisher_->RemoveObjectListener(this);
+ session_->subscribed_track_names_.erase(track_publisher_->GetTrackName());
}
SendStreamMap& MoqtSession::PublishedSubscription::stream_map() {
@@ -1043,6 +1051,11 @@
stream->SendObjects(*this);
}
+void MoqtSession::PublishedSubscription::OnTrackPublisherGone() {
+ session_->SubscribeIsDone(subscription_id_, SubscribeDoneCode::kGoingAway,
+ "Publisher is gone");
+}
+
void MoqtSession::PublishedSubscription::Backfill() {
const FullSequence start = window_.start();
const FullSequence end = track_publisher_->GetLargestSequence();
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
index 34e0aaf..2fad2ac 100644
--- a/quiche/quic/moqt/moqt_session.h
+++ b/quiche/quic/moqt/moqt_session.h
@@ -312,6 +312,7 @@
// This is only called for objects that have just arrived.
void OnNewObjectAvailable(FullSequence sequence) override;
+ void OnTrackPublisherGone() override;
void ProcessObjectAck(const MoqtObjectAck& message) {
if (monitoring_interface_ == nullptr) {
return;
@@ -494,6 +495,9 @@
absl::flat_hash_map<FullTrackName, uint64_t> remote_track_aliases_;
uint64_t next_remote_track_alias_ = 0;
+ // All open incoming subscriptions, indexed by track name, used to check for
+ // duplicates.
+ absl::flat_hash_set<FullTrackName> subscribed_track_names_;
// Application object representing the publisher for all of the tracks that
// can be subscribed to via this connection. Must outlive this object.
MoqtPublisher* publisher_;
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index b1b2db5..e4af115 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -12,6 +12,7 @@
#include <utility>
#include <vector>
+
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
@@ -81,9 +82,12 @@
auto new_stream =
std::make_unique<MoqtSession::ControlStream>(session, stream);
session->control_stream_ = kControlStreamId;
- EXPECT_CALL(*stream, visitor())
+ ON_CALL(*stream, visitor()).WillByDefault(Return(new_stream.get()));
+ webtransport::test::MockSession* mock_session =
+ static_cast<webtransport::test::MockSession*>(session->session());
+ EXPECT_CALL(*mock_session, GetStreamById(kControlStreamId))
.Times(AnyNumber())
- .WillRepeatedly(Return(new_stream.get()));
+ .WillRepeatedly(Return(stream));
return new_stream;
}
@@ -475,6 +479,129 @@
EXPECT_TRUE(correct_message);
}
+TEST_F(MoqtSessionTest, TwoSubscribesForTrack) {
+ FullTrackName ftn("foo", "bar");
+ auto track = std::make_shared<MockTrackPublisher>(ftn);
+ EXPECT_CALL(*track, GetTrackStatus())
+ .WillRepeatedly(Return(MoqtTrackStatusCode::kInProgress));
+ EXPECT_CALL(*track, GetCachedObject(_)).WillRepeatedly([] {
+ return std::optional<PublishedObject>();
+ });
+ EXPECT_CALL(*track, GetCachedObjectsInRange(_, _))
+ .WillRepeatedly(Return(std::vector<FullSequence>()));
+ EXPECT_CALL(*track, GetLargestSequence())
+ .WillRepeatedly(Return(FullSequence(10, 20)));
+ publisher_.Add(track);
+
+ // Peer subscribes to (11, 0)
+ MoqtSubscribe request = {
+ /*subscribe_id=*/1,
+ /*track_alias=*/2,
+ /*full_track_name=*/FullTrackName({"foo", "bar"}),
+ /*subscriber_priority=*/0x80,
+ /*group_order=*/std::nullopt,
+ /*start_group=*/11,
+ /*start_object=*/0,
+ /*end_group=*/std::nullopt,
+ /*end_object=*/std::nullopt,
+ /*parameters=*/MoqtSubscribeParameters(),
+ };
+ webtransport::test::MockStream mock_stream;
+ std::unique_ptr<MoqtControlParserVisitor> stream_input =
+ MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+ bool correct_message = false;
+ EXPECT_CALL(mock_stream, Writev(_, _))
+ .WillOnce([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ correct_message = true;
+ EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk);
+ return absl::OkStatus();
+ });
+ stream_input->OnSubscribeMessage(request);
+ EXPECT_TRUE(correct_message);
+
+ request.subscribe_id = 2;
+ request.start_group = 12;
+ EXPECT_CALL(mock_session_,
+ CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
+ "Duplicate subscribe for track"))
+ .Times(1);
+ stream_input->OnSubscribeMessage(request);
+ ;
+}
+
+TEST_F(MoqtSessionTest, UnsubscribeAllowsSecondSubscribe) {
+ FullTrackName ftn("foo", "bar");
+ auto track = std::make_shared<MockTrackPublisher>(ftn);
+ EXPECT_CALL(*track, GetTrackStatus())
+ .WillRepeatedly(Return(MoqtTrackStatusCode::kInProgress));
+ EXPECT_CALL(*track, GetCachedObject(_)).WillRepeatedly([] {
+ return std::optional<PublishedObject>();
+ });
+ EXPECT_CALL(*track, GetCachedObjectsInRange(_, _))
+ .WillRepeatedly(Return(std::vector<FullSequence>()));
+ EXPECT_CALL(*track, GetLargestSequence())
+ .WillRepeatedly(Return(FullSequence(10, 20)));
+ publisher_.Add(track);
+
+ // Peer subscribes to (11, 0)
+ MoqtSubscribe request = {
+ /*subscribe_id=*/1,
+ /*track_alias=*/2,
+ /*full_track_name=*/FullTrackName({"foo", "bar"}),
+ /*subscriber_priority=*/0x80,
+ /*group_order=*/std::nullopt,
+ /*start_group=*/11,
+ /*start_object=*/0,
+ /*end_group=*/std::nullopt,
+ /*end_object=*/std::nullopt,
+ /*parameters=*/MoqtSubscribeParameters(),
+ };
+ webtransport::test::MockStream mock_stream;
+ std::unique_ptr<MoqtControlParserVisitor> stream_input =
+ MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+ bool correct_message = false;
+ EXPECT_CALL(mock_stream, Writev(_, _))
+ .WillOnce([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ correct_message = true;
+ EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk);
+ return absl::OkStatus();
+ });
+ stream_input->OnSubscribeMessage(request);
+ EXPECT_TRUE(correct_message);
+
+ // Peer unsubscribes.
+ MoqtUnsubscribe unsubscribe = {
+ /*subscribe_id=*/1,
+ };
+ correct_message = false;
+ EXPECT_CALL(mock_stream, Writev(_, _))
+ .WillOnce([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ correct_message = true;
+ EXPECT_EQ(*ExtractMessageType(data[0]),
+ MoqtMessageType::kSubscribeDone);
+ return absl::OkStatus();
+ });
+ stream_input->OnUnsubscribeMessage(unsubscribe);
+ EXPECT_TRUE(correct_message);
+
+ // Subscribe again, succeeds.
+ request.subscribe_id = 2;
+ request.start_group = 12;
+ correct_message = false;
+ EXPECT_CALL(mock_stream, Writev(_, _))
+ .WillOnce([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ correct_message = true;
+ EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk);
+ return absl::OkStatus();
+ });
+ stream_input->OnSubscribeMessage(request);
+ EXPECT_TRUE(correct_message);
+}
+
TEST_F(MoqtSessionTest, SubscribeIdTooHigh) {
// Peer subscribes to (0, 0)
MoqtSubscribe request = {
diff --git a/quiche/quic/moqt/tools/chat_server.cc b/quiche/quic/moqt/tools/chat_server.cc
index 5b32185..d098450 100644
--- a/quiche/quic/moqt/tools/chat_server.cc
+++ b/quiche/quic/moqt/tools/chat_server.cc
@@ -4,7 +4,6 @@
#include "quiche/quic/moqt/tools/chat_server.h"
-#include <cstdint>
#include <iostream>
#include <memory>
#include <optional>
@@ -169,6 +168,7 @@
catalog_->AddObject(quiche::QuicheMemSlice(quiche::QuicheBuffer::Copy(
quiche::SimpleBufferAllocator::Get(), catalog_data)),
/*key=*/false);
+ user_queues_[username]->RemoveAllSubscriptions();
user_queues_.erase(username);
publisher_.Delete(strings_.GetFullTrackNameFromUsername(username));
}
diff --git a/quiche/quic/moqt/tools/moq_chat_end_to_end_test.cc b/quiche/quic/moqt/tools/moq_chat_end_to_end_test.cc
index d68f742..3460aea 100644
--- a/quiche/quic/moqt/tools/moq_chat_end_to_end_test.cc
+++ b/quiche/quic/moqt/tools/moq_chat_end_to_end_test.cc
@@ -24,7 +24,7 @@
using ::testing::_;
-constexpr std::string kChatHostname = "127.0.0.1";
+constexpr absl::string_view kChatHostname = "127.0.0.1";
class MockChatUserInterface : public ChatUserInterface {
public:
@@ -55,7 +55,8 @@
: server_(quic::test::crypto_test_utils::ProofSourceForTesting(),
"test_chat", "") {
quiche::QuicheIpAddress bind_address;
- bind_address.FromString(kChatHostname);
+ std::string hostname(kChatHostname);
+ bind_address.FromString(hostname);
EXPECT_TRUE(server_.moqt_server().quic_server().CreateUDPSocketAndListen(
quic::QuicSocketAddress(bind_address, 0)));
auto if1ptr = std::make_unique<MockChatUserInterface>();
@@ -64,10 +65,10 @@
interface2_ = if2ptr.get();
uint16_t port = server_.moqt_server().quic_server().port();
client1_ = std::make_unique<ChatClient>(
- quic::QuicServerId(kChatHostname, port), true, std::move(if1ptr),
+ quic::QuicServerId(hostname, port), true, std::move(if1ptr),
server_.moqt_server().quic_server().event_loop());
client2_ = std::make_unique<ChatClient>(
- quic::QuicServerId(kChatHostname, port), true, std::move(if2ptr),
+ quic::QuicServerId(hostname, port), true, std::move(if2ptr),
server_.moqt_server().quic_server().event_loop());
}
@@ -129,12 +130,12 @@
MockChatUserInterface* interface1b_ = if1bptr.get();
uint16_t port = server_.moqt_server().quic_server().port();
client1_ = std::make_unique<ChatClient>(
- quic::QuicServerId(kChatHostname, port), true, std::move(if1bptr),
- server_.moqt_server().quic_server().event_loop());
+ quic::QuicServerId(std::string(kChatHostname), port), true,
+ std::move(if1bptr), server_.moqt_server().quic_server().event_loop());
EXPECT_TRUE(client1_->Connect("/moq-chat", "client1", "test_chat"));
EXPECT_TRUE(client1_->AnnounceAndSubscribe());
- SendAndWaitForOutput(interface1b_, interface2_, "client1", "Hello");
- SendAndWaitForOutput(interface2_, interface1b_, "client2", "Hi");
+ SendAndWaitForOutput(interface1b_, interface2_, "client1", "Hello again");
+ SendAndWaitForOutput(interface2_, interface1b_, "client2", "Hi again");
}
} // namespace test