Update SUBSCRIBE_NAMESPACE handling in the session to support Relays.
The response to SUBSCRIBE_NAMESPACE is now asynchronous in the session, in case the application needs to delay it.
This CL does not tie in the Session and the Relay; that will be a followon CL that installs the current IncomingSubscribeNamespaceCallback.
PiperOrigin-RevId: 819256838
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index 1a22208..26026bd 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -312,13 +312,13 @@
void MoqtSession::PublishNamespace(
TrackNamespace track_namespace,
- MoqtOutgoingPublishNamespaceCallback publish_namespace_callback,
+ MoqtOutgoingPublishNamespaceCallback callback,
VersionSpecificParameters parameters) {
QUICHE_DCHECK(track_namespace.IsValid());
if (outgoing_publish_namespaces_.contains(track_namespace)) {
- std::move(publish_namespace_callback)(
+ std::move(callback)(
track_namespace,
- MoqtPublishNamespaceErrorReason{
+ MoqtRequestError{
RequestErrorCode::kInternalError,
"PUBLISH_NAMESPACE already outstanding for namespace"});
return;
@@ -351,8 +351,7 @@
QUIC_DLOG(INFO) << ENDPOINT << "Sent PUBLISH_NAMESPACE message for "
<< message.track_namespace;
pending_outgoing_publish_namespaces_[message.request_id] = track_namespace;
- outgoing_publish_namespaces_[track_namespace] =
- std::move(publish_namespace_callback);
+ outgoing_publish_namespaces_[track_namespace] = std::move(callback);
}
bool MoqtSession::PublishNamespaceDone(TrackNamespace track_namespace) {
@@ -1282,8 +1281,7 @@
}
std::move(it2->second)(
track_namespace,
- MoqtPublishNamespaceErrorReason{message.error_code,
- std::string(message.error_reason)});
+ MoqtRequestError{message.error_code, std::string(message.error_reason)});
session_->outgoing_publish_namespaces_.erase(it2);
}
@@ -1307,8 +1305,7 @@
}
std::move(it->second)(
message.track_namespace,
- MoqtPublishNamespaceErrorReason{message.error_code,
- std::string(message.error_reason)});
+ MoqtRequestError{message.error_code, std::string(message.error_reason)});
session_->outgoing_publish_namespaces_.erase(it);
}
@@ -1380,7 +1377,7 @@
session_->framer_.SerializeSubscribeNamespaceError(error));
return;
}
- if (!session_->incoming_subscribe_namespace_.AddNamespace(
+ if (!session_->incoming_subscribe_namespace_.SubscribeNamespace(
message.track_namespace)) {
QUIC_DLOG(INFO) << ENDPOINT << "Received a SUBSCRIBE_NAMESPACE for "
<< message.track_namespace
@@ -1393,23 +1390,25 @@
session_->framer_.SerializeSubscribeNamespaceError(error));
return;
}
- std::optional<MoqtSubscribeErrorReason> result =
- session_->callbacks_.incoming_subscribe_namespace_callback(
- message.track_namespace, message.parameters);
- if (result.has_value()) {
- MoqtSubscribeNamespaceError error;
- error.request_id = message.request_id;
- error.error_code = result->error_code;
- error.error_reason = result->reason_phrase;
- SendOrBufferMessage(
- session_->framer_.SerializeSubscribeNamespaceError(error));
- session_->incoming_subscribe_namespace_.RemoveNamespace(
- message.track_namespace);
- return;
- }
- MoqtSubscribeNamespaceOk ok;
- ok.request_id = message.request_id;
- SendOrBufferMessage(session_->framer_.SerializeSubscribeNamespaceOk(ok));
+ (session_->callbacks_.incoming_subscribe_namespace_callback)(
+ message.track_namespace, message.parameters,
+ [&](std::optional<MoqtRequestError> error) {
+ if (error.has_value()) {
+ MoqtSubscribeNamespaceError reply;
+ reply.request_id = message.request_id;
+ reply.error_code = error->error_code;
+ reply.error_reason = error->reason_phrase;
+ SendOrBufferMessage(
+ session_->framer_.SerializeSubscribeNamespaceError(reply));
+ session_->incoming_subscribe_namespace_.UnsubscribeNamespace(
+ message.track_namespace);
+ } else {
+ MoqtSubscribeNamespaceOk ok;
+ ok.request_id = message.request_id;
+ SendOrBufferMessage(
+ session_->framer_.SerializeSubscribeNamespaceOk(ok));
+ }
+ });
}
void MoqtSession::ControlStream::OnSubscribeNamespaceOkMessage(
@@ -1444,11 +1443,10 @@
void MoqtSession::ControlStream::OnUnsubscribeNamespaceMessage(
const MoqtUnsubscribeNamespace& message) {
// MoqtSession keeps no state here, so just tell the application.
- std::optional<MoqtSubscribeErrorReason> result =
- session_->callbacks_.incoming_subscribe_namespace_callback(
- message.track_namespace, std::nullopt);
- session_->incoming_subscribe_namespace_.RemoveNamespace(
+ session_->incoming_subscribe_namespace_.UnsubscribeNamespace(
message.track_namespace);
+ session_->callbacks_.incoming_subscribe_namespace_callback(
+ message.track_namespace, std::nullopt, nullptr);
}
void MoqtSession::ControlStream::OnMaxRequestIdMessage(
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
index 87422a8..40b7a70 100644
--- a/quiche/quic/moqt/moqt_session.h
+++ b/quiche/quic/moqt/moqt_session.h
@@ -81,6 +81,11 @@
if (goaway_timeout_alarm_ != nullptr) {
goaway_timeout_alarm_->PermanentCancel();
}
+ for (const TrackNamespace& track_namespace :
+ incoming_subscribe_namespace_.GetSubscribedNamespaces()) {
+ callbacks_.incoming_subscribe_namespace_callback(track_namespace,
+ std::nullopt, nullptr);
+ }
for (const TrackNamespace& track_namespace : incoming_publish_namespaces_) {
callbacks_.incoming_publish_namespace_callback(track_namespace,
std::nullopt, nullptr);
@@ -148,10 +153,9 @@
uint64_t num_previous_groups, MoqtPriority priority,
std::optional<MoqtDeliveryOrder> delivery_order,
VersionSpecificParameters parameters) override;
- void PublishNamespace(
- TrackNamespace track_namespace,
- MoqtOutgoingPublishNamespaceCallback publish_namespace_callback,
- VersionSpecificParameters parameters) override;
+ void PublishNamespace(TrackNamespace track_namespace,
+ MoqtOutgoingPublishNamespaceCallback callback,
+ VersionSpecificParameters parameters) override;
bool PublishNamespaceDone(TrackNamespace track_namespace) override;
quiche::QuicheWeakPtr<MoqtSessionInterface> GetWeakPtr() override {
return weak_ptr_factory_.Create();
diff --git a/quiche/quic/moqt/moqt_session_callbacks.h b/quiche/quic/moqt/moqt_session_callbacks.h
index bb43103..e8a238d 100644
--- a/quiche/quic/moqt/moqt_session_callbacks.h
+++ b/quiche/quic/moqt/moqt_session_callbacks.h
@@ -45,14 +45,13 @@
MoqtResponseCallback callback)>;
// Called whenever SUBSCRIBE_NAMESPACE or UNSUBSCRIBE_NAMESPACE is received from
-// the peer. For SUBSCRIBE_NAMESPACE, the return value indicates whether to
-// return an OK or an ERROR; for UNSUBSCRIBE_NAMESPACE, the return value is
-// ignored. SUBSCRIBE_NAMESPACE sets a value for |parameters|,
-// UNSUBSCRIBE_NAMESPACE does not.
-using MoqtIncomingSubscribeNamespaceCallback =
- quiche::MultiUseCallback<std::optional<MoqtSubscribeErrorReason>(
- const TrackNamespace& track_namespace,
- std::optional<VersionSpecificParameters> parameters)>;
+// the peer. SUBSCRIBE_NAMESPACE sets a value for |parameters|,
+// UNSUBSCRIBE_NAMESPACE does not. For UNSUBSCRIBE_NAMESPACE, |callback| is
+// null.
+using MoqtIncomingSubscribeNamespaceCallback = quiche::MultiUseCallback<void(
+ const TrackNamespace& track_namespace,
+ std::optional<VersionSpecificParameters> parameters,
+ MoqtResponseCallback callback)>;
inline void DefaultIncomingPublishNamespaceCallback(
const TrackNamespace&, const std::optional<VersionSpecificParameters>&,
@@ -62,16 +61,13 @@
}
return std::move(callback)(MoqtRequestError{
RequestErrorCode::kNotSupported,
- "This endpoint does not accept incoming PUBLISH_NAMESPACE messages"});
+ "This endpoint does not support incoming SUBSCRIBE_NAMESPACE messages"});
};
-inline std::optional<MoqtSubscribeErrorReason>
-DefaultIncomingSubscribeNamespaceCallback(
+inline void DefaultIncomingSubscribeNamespaceCallback(
const TrackNamespace& track_namespace,
- std::optional<VersionSpecificParameters> /*parameters*/) {
- return MoqtSubscribeErrorReason{
- RequestErrorCode::kNotSupported,
- "This endpoint does not support incoming SUBSCRIBE_NAMESPACE messages"};
+ std::optional<VersionSpecificParameters>, MoqtResponseCallback callback) {
+ std::move(callback)(std::nullopt);
}
// Callbacks for session-level events.
diff --git a/quiche/quic/moqt/moqt_session_interface.h b/quiche/quic/moqt/moqt_session_interface.h
index 00d290c..677ba61 100644
--- a/quiche/quic/moqt/moqt_session_interface.h
+++ b/quiche/quic/moqt/moqt_session_interface.h
@@ -73,7 +73,7 @@
// TODO(martinduke): MoqtOutgoingPublishNamespaceCallback and
// MoqtOutgoingSubscribeNamespaceCallback are deprecated. Remove.
-// If |error_message| is nullopt, this is triggered by a PUBLISH_NAMESPACE_OK.
+// If |error| is nullopt, this is triggered by a PUBLISH_NAMESPACE_OK.
// Otherwise, it is triggered by PUBLISH_NAMESPACE_ERROR or
// PUBLISH_NAMESPACE_CANCEL. For ERROR or CANCEL, MoqtSession is deleting all
// PUBLISH_NAMESPACE state immediately after calling this callback.
@@ -167,10 +167,9 @@
// |publish_namespace_callback| when the response arrives. Will fail
// immediately if there is already an unresolved PUBLISH_NAMESPACE for that
// namespace.
- virtual void PublishNamespace(
- TrackNamespace track_namespace,
- MoqtOutgoingPublishNamespaceCallback publish_namespace_callback,
- VersionSpecificParameters parameters) = 0;
+ virtual void PublishNamespace(TrackNamespace track_namespace,
+ MoqtOutgoingPublishNamespaceCallback callback,
+ VersionSpecificParameters parameters) = 0;
// Returns true if message was sent, false if there is no PUBLISH_NAMESPACE to
// cancel.
virtual bool PublishNamespaceDone(TrackNamespace track_namespace) = 0;
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index 0ecfbd8..6ef3121 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -2773,10 +2773,10 @@
}
TEST_F(MoqtSessionTest, IncomingSubscribeNamespace) {
- TrackNamespace track_namespace = TrackNamespace{"foo"};
+ TrackNamespace track_namespace{"foo"};
auto parameters = std::make_optional<VersionSpecificParameters>(
AuthTokenType::kOutOfBand, "foo");
- MoqtSubscribeNamespace publish_namespaces = {
+ MoqtSubscribeNamespace subscribe_namespace = {
/*request_id=*/1,
track_namespace,
*parameters,
@@ -2785,26 +2785,31 @@
std::unique_ptr<MoqtControlParserVisitor> stream_input =
MoqtSessionPeer::CreateControlStream(&session_, &control_stream);
EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback,
- Call(_, parameters))
- .WillOnce(Return(std::nullopt));
+ Call(track_namespace, parameters, _))
+ .WillOnce([](const TrackNamespace&,
+ std::optional<VersionSpecificParameters>,
+ MoqtResponseCallback callback) {
+ std::move(callback)(std::nullopt);
+ });
EXPECT_CALL(
control_stream,
Writev(ControlMessageOfType(MoqtMessageType::kSubscribeNamespaceOk), _));
- stream_input->OnSubscribeNamespaceMessage(publish_namespaces);
- MoqtUnsubscribeNamespace ununsubscribe_namespaces = {
- TrackNamespace{"foo"},
- };
- EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback,
- Call(track_namespace, std::optional<VersionSpecificParameters>()))
- .WillOnce(Return(std::nullopt));
- stream_input->OnUnsubscribeNamespaceMessage(ununsubscribe_namespaces);
+ stream_input->OnSubscribeNamespaceMessage(subscribe_namespace);
+ MoqtUnsubscribeNamespace unsubscribe_namespace{track_namespace};
+ EXPECT_CALL(
+ session_callbacks_.incoming_subscribe_namespace_callback,
+ Call(track_namespace, std::optional<VersionSpecificParameters>(), _))
+ .WillOnce(
+ [](const TrackNamespace&, std::optional<VersionSpecificParameters>,
+ MoqtResponseCallback callback) { EXPECT_EQ(callback, nullptr); });
+ stream_input->OnUnsubscribeNamespaceMessage(unsubscribe_namespace);
}
TEST_F(MoqtSessionTest, IncomingSubscribeNamespaceWithError) {
TrackNamespace track_namespace{"foo"};
auto parameters = std::make_optional<VersionSpecificParameters>(
AuthTokenType::kOutOfBand, "foo");
- MoqtSubscribeNamespace publish_namespaces = {
+ MoqtSubscribeNamespace subscribe_namespace = {
/*request_id=*/1,
track_namespace,
*parameters,
@@ -2813,73 +2818,109 @@
std::unique_ptr<MoqtControlParserVisitor> stream_input =
MoqtSessionPeer::CreateControlStream(&session_, &control_stream);
EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback,
- Call(_, parameters))
- .WillOnce(Return(
- MoqtSubscribeErrorReason{RequestErrorCode::kUnauthorized, "foo"}));
+ Call(track_namespace, parameters, _))
+ .WillOnce([](const TrackNamespace&,
+ std::optional<VersionSpecificParameters>,
+ MoqtResponseCallback callback) {
+ std::move(callback)(
+ MoqtSubscribeErrorReason{RequestErrorCode::kUnauthorized, "foo"});
+ });
EXPECT_CALL(
control_stream,
Writev(ControlMessageOfType(MoqtMessageType::kSubscribeNamespaceError),
_));
- stream_input->OnSubscribeNamespaceMessage(publish_namespaces);
+ stream_input->OnSubscribeNamespaceMessage(subscribe_namespace);
// Try again, to verify that it was purged from the tree.
- publish_namespaces.request_id += 2;
+ subscribe_namespace.request_id += 2;
EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback,
- Call(_, parameters))
- .WillOnce(Return(std::nullopt));
+ Call(track_namespace, parameters, _))
+ .WillOnce([](const TrackNamespace&,
+ std::optional<VersionSpecificParameters>,
+ MoqtResponseCallback callback) {
+ std::move(callback)(std::nullopt);
+ });
EXPECT_CALL(
control_stream,
Writev(ControlMessageOfType(MoqtMessageType::kSubscribeNamespaceOk), _));
- stream_input->OnSubscribeNamespaceMessage(publish_namespaces);
+ stream_input->OnSubscribeNamespaceMessage(subscribe_namespace);
+
+ // Cleanup.
+ MoqtUnsubscribeNamespace unsubscribe_namespace{track_namespace};
+ EXPECT_CALL(
+ session_callbacks_.incoming_subscribe_namespace_callback,
+ Call(track_namespace, std::optional<VersionSpecificParameters>(), _))
+ .WillOnce(
+ [](const TrackNamespace&, std::optional<VersionSpecificParameters>,
+ MoqtResponseCallback callback) { EXPECT_EQ(callback, nullptr); });
+ stream_input->OnUnsubscribeNamespaceMessage(unsubscribe_namespace);
}
TEST_F(MoqtSessionTest, IncomingSubscribeNamespaceWithPrefixOverlap) {
- TrackNamespace track_namespace{"foo"};
+ TrackNamespace foo{"foo"}, foobar{"foo", "bar"};
+
auto parameters = std::make_optional<VersionSpecificParameters>(
AuthTokenType::kOutOfBand, "foo");
- MoqtSubscribeNamespace publish_namespaces = {
+ MoqtSubscribeNamespace subscribe_namespace = {
/*request_id=*/1,
- track_namespace,
+ foo,
*parameters,
};
webtransport::test::MockStream control_stream;
std::unique_ptr<MoqtControlParserVisitor> stream_input =
MoqtSessionPeer::CreateControlStream(&session_, &control_stream);
EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback,
- Call(_, parameters))
- .WillOnce(Return(std::nullopt));
+ Call(foo, parameters, _))
+ .WillOnce([](const TrackNamespace&,
+ std::optional<VersionSpecificParameters>,
+ MoqtResponseCallback callback) {
+ std::move(callback)(std::nullopt);
+ });
EXPECT_CALL(
control_stream,
Writev(ControlMessageOfType(MoqtMessageType::kSubscribeNamespaceOk), _));
- stream_input->OnSubscribeNamespaceMessage(publish_namespaces);
+ stream_input->OnSubscribeNamespaceMessage(subscribe_namespace);
// Overlapping request is rejected.
- publish_namespaces.request_id += 2;
- publish_namespaces.track_namespace = TrackNamespace{"foo", "bar"};
+ subscribe_namespace.request_id += 2;
+ subscribe_namespace.track_namespace = foobar;
EXPECT_CALL(
control_stream,
Writev(ControlMessageOfType(MoqtMessageType::kSubscribeNamespaceError),
_));
- stream_input->OnSubscribeNamespaceMessage(publish_namespaces);
+ stream_input->OnSubscribeNamespaceMessage(subscribe_namespace);
// Remove the subscription. Now a later one will work.
- MoqtUnsubscribeNamespace ununsubscribe_namespaces = {
- TrackNamespace{"foo"},
- };
+ MoqtUnsubscribeNamespace unsubscribe_namespace{foo};
EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback,
- Call(track_namespace, std::optional<VersionSpecificParameters>()))
- .WillOnce(Return(std::nullopt));
- stream_input->OnUnsubscribeNamespaceMessage(ununsubscribe_namespaces);
+ Call(foo, std::optional<VersionSpecificParameters>(), _))
+ .WillOnce(
+ [](const TrackNamespace&, std::optional<VersionSpecificParameters>,
+ MoqtResponseCallback callback) { EXPECT_EQ(callback, nullptr); });
+ stream_input->OnUnsubscribeNamespaceMessage(unsubscribe_namespace);
// Try again, it will work.
- publish_namespaces.request_id += 2;
+ subscribe_namespace.request_id += 2;
EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback,
- Call(_, parameters))
- .WillOnce(Return(std::nullopt));
+ Call(foobar, parameters, _))
+ .WillOnce([](const TrackNamespace&,
+ std::optional<VersionSpecificParameters>,
+ MoqtResponseCallback callback) {
+ std::move(callback)(std::nullopt);
+ });
EXPECT_CALL(
control_stream,
Writev(ControlMessageOfType(MoqtMessageType::kSubscribeNamespaceOk), _));
- stream_input->OnSubscribeNamespaceMessage(publish_namespaces);
+ stream_input->OnSubscribeNamespaceMessage(subscribe_namespace);
+
+ // Cleanup.
+ unsubscribe_namespace.track_namespace = foobar;
+ EXPECT_CALL(session_callbacks_.incoming_subscribe_namespace_callback,
+ Call(foobar, std::optional<VersionSpecificParameters>(), _))
+ .WillOnce(
+ [](const TrackNamespace&, std::optional<VersionSpecificParameters>,
+ MoqtResponseCallback callback) { EXPECT_EQ(callback, nullptr); });
+ stream_input->OnUnsubscribeNamespaceMessage(unsubscribe_namespace);
}
TEST_F(MoqtSessionTest, FetchThenOkThenCancel) {
diff --git a/quiche/quic/moqt/session_namespace_tree.h b/quiche/quic/moqt/session_namespace_tree.h
index 399f51b..b68f382 100644
--- a/quiche/quic/moqt/session_namespace_tree.h
+++ b/quiche/quic/moqt/session_namespace_tree.h
@@ -5,10 +5,10 @@
#ifndef QUICHE_QUIC_MOQT_SESSION_NAMESPACE_TREE_H_
#define QUICHE_QUIC_MOQT_SESSION_NAMESPACE_TREE_H_
-#include <string>
+#include <cstdint>
#include "absl/container/flat_hash_map.h"
-#include "absl/types/span.h"
+#include "absl/container/flat_hash_set.h"
#include "quiche/quic/moqt/moqt_messages.h"
namespace moqt {
@@ -23,87 +23,60 @@
class SessionNamespaceTree {
public:
SessionNamespaceTree() = default;
- ~SessionNamespaceTree() = default;
+ ~SessionNamespaceTree() {}
- // Returns false if the namespace can't be added because it intersects with an
- // existing namespace.
- bool AddNamespace(const TrackNamespace& track_namespace) {
- if (root_.children.empty()) {
- AddToTree(track_namespace.tuple(), root_);
- return true;
+ // Returns false if the namespace was not subscribed.
+ bool SubscribeNamespace(const TrackNamespace& track_namespace) {
+ if (prohibited_namespaces_.contains(track_namespace)) {
+ return false;
}
- return TraverseTree(track_namespace.tuple(), root_);
- }
- // Called when UNSUBSCRIBE_NAMESPACE is received.
- void RemoveNamespace(const TrackNamespace& track_namespace) {
- DeleteUniqueBranches(track_namespace.tuple(), root_);
+ TrackNamespace higher_namespace = track_namespace;
+ do {
+ if (subscribed_namespaces_.contains(higher_namespace)) {
+ return false;
+ }
+ } while (higher_namespace.PopElement());
+ subscribed_namespaces_.insert(track_namespace);
+ // Add a reference to every higher namespace to block future subscriptions.
+ higher_namespace = track_namespace;
+ while (higher_namespace.PopElement()) {
+ ++prohibited_namespaces_[higher_namespace];
+ }
+ return true;
}
- private:
- struct Node {
- absl::flat_hash_map<std::string, struct Node> children;
- };
- // Recursively add new elements of the tuple to the tree. |start_index| is the
- // element of |tuple| that is added directly to |parent_node|.
- void AddToTree(absl::Span<const std::string> tuple, Node& parent_node) {
- if (tuple.empty()) {
+ void UnsubscribeNamespace(const TrackNamespace& track_namespace) {
+ if (subscribed_namespaces_.erase(track_namespace) == 0) {
return;
}
- auto [it, success] = parent_node.children.emplace(tuple[0], Node());
- AddToTree(tuple.subspan(1), it->second);
+ // Delete one ref from prohibited_namespaces_.
+ TrackNamespace higher_namespace = track_namespace;
+ while (higher_namespace.PopElement()) {
+ auto it2 = prohibited_namespaces_.find(higher_namespace);
+ if (it2 == prohibited_namespaces_.end()) {
+ continue;
+ }
+ if (it2->second == 1) {
+ prohibited_namespaces_.erase(it2);
+ } else {
+ --it2->second;
+ }
+ }
}
- bool TraverseTree(absl::Span<const std::string> tuple, Node& node) {
- if (node.children.empty()) {
- // The new namespace would be a child of an existing namespace.
- return false;
- }
- if (tuple.empty()) {
- // The new namespace would be a parent of an existing namespace.
- return false;
- }
- auto it = node.children.find(tuple[0]);
- if (it == node.children.end()) {
- // The new namespace would be a cousin of an existing namespace. This is
- // allowed.
- AddToTree(tuple, node);
- return true;
- }
- return TraverseTree(tuple.subspan(1), it->second);
+ // Used only when the SessionNamespaceTree is being destroyed.
+ const absl::flat_hash_set<TrackNamespace>& GetSubscribedNamespaces() const {
+ return subscribed_namespaces_;
}
- // This recursive function finds the deepest leaf node for this namespace. It
- // then keeps deleting towards the root until it finds a parent node with
- // multiple children.
- // Returns false if there are other children of parent_node, so that it's not
- // safe to keep deleting.
- bool DeleteUniqueBranches(absl::Span<const std::string> tuple,
- Node& parent_node) {
- if (tuple.empty()) {
- // We've reached the end of the namespace, it's unique if there are no
- // children.
- return parent_node.children.empty();
- }
- if (parent_node.children.empty()) {
- // Ran out of leaves too early. The namespace is not present.
- return false;
- }
- auto it = parent_node.children.find(tuple[0]);
- if (it == parent_node.children.end()) {
- // The namespace was not present.
- return false;
- }
- // Go to the next leaf node.
- if (!DeleteUniqueBranches(tuple.subspan(1), it->second)) {
- // Do no more deletion.
- return false;
- }
- parent_node.children.erase(it);
- // If there other children at this level, stop deleting.
- return parent_node.children.empty();
- }
+ protected:
+ uint64_t NumSubscriptions() const { return subscribed_namespaces_.size(); }
- Node root_; // Not a legal namespace. It's the root of the tree.
+ private:
+ absl::flat_hash_set<TrackNamespace> subscribed_namespaces_;
+ // Namespaces that cannot be subscribed to because they intersect with an
+ // existing subscription. The value is a ref count.
+ absl::flat_hash_map<TrackNamespace, int> prohibited_namespaces_;
};
} // namespace moqt
diff --git a/quiche/quic/moqt/session_namespace_tree_test.cc b/quiche/quic/moqt/session_namespace_tree_test.cc
index 0bd9101..fbc0144 100644
--- a/quiche/quic/moqt/session_namespace_tree_test.cc
+++ b/quiche/quic/moqt/session_namespace_tree_test.cc
@@ -4,44 +4,91 @@
#include "quiche/quic/moqt/session_namespace_tree.h"
+#include <cstdint>
+
#include "quiche/quic/moqt/moqt_messages.h"
+#include "quiche/quic/moqt/test_tools/mock_moqt_session.h"
+#include "quiche/quic/platform/api/quic_test.h"
#include "quiche/common/platform/api/quiche_test.h"
namespace moqt {
namespace test {
-TEST(SessionNamespaceTreeTest, AddNamespaces) {
- SessionNamespaceTree tree;
- EXPECT_TRUE(tree.AddNamespace(TrackNamespace({"a", "b", "c"})));
+class TestSessionNamespaceTree : public SessionNamespaceTree {
+ public:
+ TestSessionNamespaceTree() = default;
+ using SessionNamespaceTree::NumSubscriptions;
+};
+
+class SessionNamespaceTreeTest : public quic::test::QuicTest {
+ public:
+ void TestAddSucceeds(TrackNamespace track_namespace) {
+ uint64_t num_subscriptions_before = tree_.NumSubscriptions();
+ EXPECT_TRUE(tree_.SubscribeNamespace(track_namespace));
+ EXPECT_EQ(tree_.NumSubscriptions(), num_subscriptions_before + 1);
+ }
+
+ void TestAddFails(TrackNamespace track_namespace) {
+ uint64_t num_subscriptions_before = tree_.NumSubscriptions();
+ EXPECT_FALSE(tree_.SubscribeNamespace(track_namespace));
+ EXPECT_EQ(tree_.NumSubscriptions(), num_subscriptions_before);
+ }
+
+ void TestRemoveSucceeds(TrackNamespace track_namespace) {
+ uint64_t num_subscriptions_before = tree_.NumSubscriptions();
+ tree_.UnsubscribeNamespace(track_namespace);
+ EXPECT_EQ(tree_.NumSubscriptions(), num_subscriptions_before - 1);
+ }
+
+ void TestRemoveFails(TrackNamespace track_namespace) {
+ uint64_t num_subscriptions_before = tree_.NumSubscriptions();
+ tree_.UnsubscribeNamespace(track_namespace);
+ EXPECT_EQ(tree_.NumSubscriptions(), num_subscriptions_before);
+ }
+
+ MockMoqtSession session_;
+ TestSessionNamespaceTree tree_;
+ TrackNamespace ab_{"a", "b"}, abc_{"a", "b", "c"}, abcd_{"a", "b", "c", "d"};
+};
+
+TEST_F(SessionNamespaceTreeTest, AddNamespaces) {
+ TestAddSucceeds(abc_);
// No parents, children, or equivalents of what's already there.
- EXPECT_FALSE(tree.AddNamespace(TrackNamespace({"a", "b", "c"})));
- EXPECT_FALSE(tree.AddNamespace(TrackNamespace({"a", "b", "c", "d"})));
- EXPECT_FALSE(tree.AddNamespace(TrackNamespace({"a", "b"})));
+ TestAddFails(ab_);
+ TestAddFails(abc_);
+ TestAddFails(abcd_);
// Siblings are fine.
- EXPECT_TRUE(tree.AddNamespace(TrackNamespace({"a", "b", "d"})));
+ TestAddSucceeds(TrackNamespace{"a", "b", "d"});
+
// Totally different root is fine.
- EXPECT_TRUE(tree.AddNamespace(TrackNamespace({"b", "c"})));
- EXPECT_FALSE(tree.AddNamespace(TrackNamespace({"b"})));
- EXPECT_FALSE(tree.AddNamespace(TrackNamespace({"b", "c", "e"})));
+ TestAddSucceeds(TrackNamespace{"b", "c"});
}
-TEST(NamespaceTreeTest, RemoveNamespaces) {
- SessionNamespaceTree tree;
+TEST_F(SessionNamespaceTreeTest, RemoveNamespaces) {
// Removing from an empty tree doesn't do anything.
- tree.RemoveNamespace(TrackNamespace({"a", "b", "c"}));
- // RemoveNamespace doesn't do anything if the namespace isn't present.
- EXPECT_TRUE(tree.AddNamespace(TrackNamespace({"a", "b", "c"})));
- EXPECT_FALSE(tree.AddNamespace(TrackNamespace({"a", "b", "c"})));
+ TestRemoveFails(abc_);
- tree.RemoveNamespace(TrackNamespace({"a", "b", "c"}));
- EXPECT_TRUE(tree.AddNamespace(TrackNamespace({"a", "b", "c"})));
- tree.RemoveNamespace(TrackNamespace({"a", "b"}));
- // Inexact match didn't delete anything.
- EXPECT_FALSE(tree.AddNamespace(TrackNamespace({"a", "b", "c"})));
- tree.RemoveNamespace(TrackNamespace({"a", "b", "c", "d"}));
- // Inexact match didn't delete anything..
- EXPECT_FALSE(tree.AddNamespace(TrackNamespace({"a", "b", "c"})));
+ // UnsubscribeNamespace doesn't do anything if the namespace isn't present.
+ TestAddSucceeds(abc_);
+ TestRemoveFails(ab_);
+ TestRemoveFails(abcd_);
+
+ // Exact match works. Now re-adding can succeed.
+ TestRemoveSucceeds(abc_);
+ TestAddSucceeds(abc_);
+
+ // Add another ref count on ab_.
+ TrackNamespace abd{"a", "b", "d"};
+ TestAddSucceeds(abd);
+ // Higher namespace is blocked.
+ TestAddFails(ab_);
+ // Removing one doesn't remove the block.
+ TestRemoveSucceeds(abc_);
+ TestAddFails(ab_);
+ // Removing both allows add.
+ TestRemoveSucceeds(abd);
+ TestAddSucceeds(ab_);
}
} // namespace test
diff --git a/quiche/quic/moqt/test_tools/mock_moqt_session.h b/quiche/quic/moqt/test_tools/mock_moqt_session.h
index eeabb32..b52940f 100644
--- a/quiche/quic/moqt/test_tools/mock_moqt_session.h
+++ b/quiche/quic/moqt/test_tools/mock_moqt_session.h
@@ -71,7 +71,7 @@
(override));
MOCK_METHOD(void, PublishNamespace,
(TrackNamespace track_namespace,
- MoqtOutgoingPublishNamespaceCallback publish_namespace_callback,
+ MoqtOutgoingPublishNamespaceCallback callback,
VersionSpecificParameters parameters),
(override));
MOCK_METHOD(bool, PublishNamespaceDone, (TrackNamespace track_namespace),
diff --git a/quiche/quic/moqt/test_tools/moqt_mock_visitor.h b/quiche/quic/moqt/test_tools/moqt_mock_visitor.h
index f1bd0d4..afddd23 100644
--- a/quiche/quic/moqt/test_tools/moqt_mock_visitor.h
+++ b/quiche/quic/moqt/test_tools/moqt_mock_visitor.h
@@ -32,18 +32,18 @@
testing::MockFunction<void(absl::string_view)> session_terminated_callback;
testing::MockFunction<void()> session_deleted_callback;
testing::MockFunction<void(const TrackNamespace&,
- const std::optional<VersionSpecificParameters>&,
+ std::optional<VersionSpecificParameters>,
MoqtResponseCallback)>
incoming_publish_namespace_callback;
- testing::MockFunction<std::optional<MoqtSubscribeErrorReason>(
- TrackNamespace, std::optional<VersionSpecificParameters>)>
+ testing::MockFunction<void(const TrackNamespace&,
+ std::optional<VersionSpecificParameters>,
+ MoqtResponseCallback)>
incoming_subscribe_namespace_callback;
MockSessionCallbacks() {
- ON_CALL(incoming_publish_namespace_callback,
- Call(testing::_, testing::_, testing::_))
+ ON_CALL(incoming_publish_namespace_callback, Call)
.WillByDefault(DefaultIncomingPublishNamespaceCallback);
- ON_CALL(incoming_subscribe_namespace_callback, Call(testing::_, testing::_))
+ ON_CALL(incoming_subscribe_namespace_callback, Call)
.WillByDefault(DefaultIncomingSubscribeNamespaceCallback);
}
diff --git a/quiche/quic/moqt/tools/chat_server.cc b/quiche/quic/moqt/tools/chat_server.cc
index 8cc05d9..fea8b36 100644
--- a/quiche/quic/moqt/tools/chat_server.cc
+++ b/quiche/quic/moqt/tools/chat_server.cc
@@ -93,7 +93,8 @@
};
session_->callbacks().incoming_subscribe_namespace_callback =
[this](const moqt::TrackNamespace& chat_namespace,
- std::optional<VersionSpecificParameters> parameters) {
+ std::optional<VersionSpecificParameters> parameters,
+ MoqtResponseCallback callback) {
if (parameters.has_value()) {
subscribed_namespaces_.insert(chat_namespace);
std::cout << "Received SUBSCRIBE_NAMESPACE for ";
@@ -104,12 +105,13 @@
std::cout << chat_namespace.ToString() << "\n";
if (!IsValidChatNamespace(chat_namespace)) {
std::cout << "Not a valid moq-chat namespace.\n";
- return std::make_optional(
- MoqtSubscribeErrorReason{RequestErrorCode::kTrackDoesNotExist,
- "Not a valid namespace for this chat."});
+ std::move(callback)(
+ MoqtRequestError{RequestErrorCode::kTrackDoesNotExist,
+ "Not a valid namespace for this chat."});
+ return;
}
- if (!parameters.has_value()) {
- return std::optional<MoqtSubscribeErrorReason>();
+ if (!parameters.has_value()) { // UNSUBSCRIBE_NAMESPACE
+ return;
}
// Send all PUBLISH_NAMESPACE.
for (auto& [track_name, queue] : server_->user_queues_) {
@@ -127,7 +129,7 @@
this),
moqt::VersionSpecificParameters());
}
- return std::optional<MoqtSubscribeErrorReason>();
+ std::move(callback)(std::nullopt);
};
session_->set_publisher(server_->publisher());
}