Handle `PublishNamespace` messages in `MoqtRelay`.
Make the response to PublishNamespace potentially asynchronous, though MoqtRelay will always respond synchronously.
Delete broadcast mode, as discussion with the editors indicated that this mode is not appropriate and unnecessary.
PiperOrigin-RevId: 808579176
diff --git a/build/source_list.bzl b/build/source_list.bzl
index 0117713..ae403e1 100644
--- a/build/source_list.bzl
+++ b/build/source_list.bzl
@@ -1572,6 +1572,7 @@
"quic/moqt/moqt_session_interface.h",
"quic/moqt/moqt_subscribe_windows.h",
"quic/moqt/moqt_track.h",
+ "quic/moqt/namespace_publisher_multimap.h",
"quic/moqt/namespace_tree.h",
"quic/moqt/tools/chat_client.h",
"quic/moqt/tools/chat_server.h",
@@ -1621,6 +1622,7 @@
"quic/moqt/moqt_session_test.cc",
"quic/moqt/moqt_subscribe_windows_test.cc",
"quic/moqt/moqt_track_test.cc",
+ "quic/moqt/namespace_publisher_multimap_test.cc",
"quic/moqt/namespace_tree_test.cc",
"quic/moqt/tools/moq_chat_end_to_end_test.cc",
"quic/moqt/tools/moq_chat_test.cc",
diff --git a/build/source_list.gni b/build/source_list.gni
index af40db2..dfd4993 100644
--- a/build/source_list.gni
+++ b/build/source_list.gni
@@ -1576,6 +1576,7 @@
"src/quiche/quic/moqt/moqt_session_interface.h",
"src/quiche/quic/moqt/moqt_subscribe_windows.h",
"src/quiche/quic/moqt/moqt_track.h",
+ "src/quiche/quic/moqt/namespace_publisher_multimap.h",
"src/quiche/quic/moqt/namespace_tree.h",
"src/quiche/quic/moqt/tools/chat_client.h",
"src/quiche/quic/moqt/tools/chat_server.h",
@@ -1626,6 +1627,7 @@
"src/quiche/quic/moqt/moqt_session_test.cc",
"src/quiche/quic/moqt/moqt_subscribe_windows_test.cc",
"src/quiche/quic/moqt/moqt_track_test.cc",
+ "src/quiche/quic/moqt/namespace_publisher_multimap_test.cc",
"src/quiche/quic/moqt/namespace_tree_test.cc",
"src/quiche/quic/moqt/tools/moq_chat_end_to_end_test.cc",
"src/quiche/quic/moqt/tools/moq_chat_test.cc",
diff --git a/build/source_list.json b/build/source_list.json
index 0ebf99e..ee36c44 100644
--- a/build/source_list.json
+++ b/build/source_list.json
@@ -1575,6 +1575,7 @@
"quiche/quic/moqt/moqt_session_interface.h",
"quiche/quic/moqt/moqt_subscribe_windows.h",
"quiche/quic/moqt/moqt_track.h",
+ "quiche/quic/moqt/namespace_publisher_multimap.h",
"quiche/quic/moqt/namespace_tree.h",
"quiche/quic/moqt/tools/chat_client.h",
"quiche/quic/moqt/tools/chat_server.h",
@@ -1625,6 +1626,7 @@
"quiche/quic/moqt/moqt_session_test.cc",
"quiche/quic/moqt/moqt_subscribe_windows_test.cc",
"quiche/quic/moqt/moqt_track_test.cc",
+ "quiche/quic/moqt/namespace_publisher_multimap_test.cc",
"quiche/quic/moqt/namespace_tree_test.cc",
"quiche/quic/moqt/tools/moq_chat_end_to_end_test.cc",
"quiche/quic/moqt/tools/moq_chat_test.cc",
diff --git a/quiche/quic/moqt/moqt_integration_test.cc b/quiche/quic/moqt/moqt_integration_test.cc
index 3c3a07a..e5d3f8d 100644
--- a/quiche/quic/moqt/moqt_integration_test.cc
+++ b/quiche/quic/moqt/moqt_integration_test.cc
@@ -23,6 +23,7 @@
#include "quiche/quic/moqt/moqt_probe_manager.h"
#include "quiche/quic/moqt/moqt_publisher.h"
#include "quiche/quic/moqt/moqt_session.h"
+#include "quiche/quic/moqt/moqt_session_callbacks.h"
#include "quiche/quic/moqt/moqt_session_interface.h"
#include "quiche/quic/moqt/moqt_track.h"
#include "quiche/quic/moqt/test_tools/moqt_mock_visitor.h"
@@ -153,8 +154,12 @@
auto parameters = std::make_optional<VersionSpecificParameters>(
AuthTokenType::kOutOfBand, "foo");
EXPECT_CALL(server_callbacks_.incoming_publish_namespace_callback,
- Call(TrackNamespace{"foo"}, parameters))
- .WillOnce(Return(std::nullopt));
+ Call(TrackNamespace{"foo"}, parameters, _))
+ .WillOnce([](const TrackNamespace&,
+ const std::optional<VersionSpecificParameters>&,
+ MoqtResponseCallback callback) {
+ std::move(callback)(std::nullopt);
+ });
testing::MockFunction<void(
TrackNamespace track_namespace,
std::optional<MoqtPublishNamespaceErrorReason> error_message)>
@@ -174,13 +179,14 @@
test_harness_.RunUntilWithDefaultTimeout([&]() { return matches; });
EXPECT_TRUE(success);
matches = false;
- EXPECT_CALL(server_callbacks_.incoming_publish_namespace_callback, Call(_, _))
- .WillOnce([&](TrackNamespace name,
- std::optional<VersionSpecificParameters> parameters) {
+ EXPECT_CALL(server_callbacks_.incoming_publish_namespace_callback,
+ Call(TrackNamespace{"foo"},
+ std::optional<VersionSpecificParameters>(), _))
+ .WillOnce([&](const TrackNamespace& name,
+ const std::optional<VersionSpecificParameters>& parameters,
+ MoqtResponseCallback callback) {
matches = true;
- EXPECT_EQ(name, TrackNamespace{"foo"});
- EXPECT_FALSE(parameters.has_value());
- return std::nullopt;
+ EXPECT_EQ(callback, nullptr);
});
client_->session()->PublishNamespaceDone(TrackNamespace{"foo"});
success = test_harness_.RunUntilWithDefaultTimeout([&]() { return matches; });
@@ -192,8 +198,12 @@
auto parameters = std::make_optional<VersionSpecificParameters>(
AuthTokenType::kOutOfBand, "foo");
EXPECT_CALL(server_callbacks_.incoming_publish_namespace_callback,
- Call(TrackNamespace{"foo"}, parameters))
- .WillOnce(Return(std::nullopt));
+ Call(TrackNamespace{"foo"}, parameters, _))
+ .WillOnce([](const TrackNamespace&,
+ const std::optional<VersionSpecificParameters>&,
+ MoqtResponseCallback callback) {
+ std::move(callback)(std::nullopt);
+ });
testing::MockFunction<void(
TrackNamespace track_namespace,
std::optional<MoqtPublishNamespaceErrorReason> error_message)>
@@ -234,8 +244,12 @@
auto parameters = std::make_optional<VersionSpecificParameters>(
AuthTokenType::kOutOfBand, "foo");
EXPECT_CALL(server_callbacks_.incoming_publish_namespace_callback,
- Call(TrackNamespace{"foo"}, parameters))
- .WillOnce(Return(std::nullopt));
+ Call(TrackNamespace{"foo"}, parameters, _))
+ .WillOnce([](const TrackNamespace&,
+ const std::optional<VersionSpecificParameters>&,
+ MoqtResponseCallback callback) {
+ std::move(callback)(std::nullopt);
+ });
MockSubscribeRemoteTrackVisitor server_visitor;
testing::MockFunction<void(
TrackNamespace track_namespace,
@@ -269,14 +283,15 @@
AuthTokenType::kOutOfBand, "foo");
MockSubscribeRemoteTrackVisitor server_visitor;
EXPECT_CALL(server_callbacks_.incoming_publish_namespace_callback,
- Call(_, parameters))
+ Call(_, parameters, _))
.WillOnce([&](const TrackNamespace& track_namespace,
- std::optional<VersionSpecificParameters> /*parameters*/) {
+ const std::optional<VersionSpecificParameters>&,
+ MoqtResponseCallback callback) {
FullTrackName track_name(track_namespace, "data");
server_->session()->SubscribeAbsolute(
track_name, /*start_group=*/0, /*start_object=*/0, &server_visitor,
VersionSpecificParameters());
- return std::optional<MoqtPublishNamespaceErrorReason>();
+ std::move(callback)(std::nullopt);
});
auto queue = std::make_shared<MoqtOutgoingQueue>(
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h
index 5330c30..6e64394 100644
--- a/quiche/quic/moqt/moqt_messages.h
+++ b/quiche/quic/moqt/moqt_messages.h
@@ -419,6 +419,14 @@
length_ + element.length() <= kMaxFullTrackNameSize);
}
void AddElement(absl::string_view element);
+ bool PopElement() {
+ if (tuple_.size() == 1) {
+ return false;
+ }
+ length_ -= tuple_.back().length();
+ tuple_.pop_back();
+ return true;
+ }
std::string ToString() const;
// Returns the number of elements in the tuple.
size_t number_of_elements() const { return tuple_.size(); }
diff --git a/quiche/quic/moqt/moqt_messages_test.cc b/quiche/quic/moqt/moqt_messages_test.cc
index beb131d..4372ffd 100644
--- a/quiche/quic/moqt/moqt_messages_test.cc
+++ b/quiche/quic/moqt/moqt_messages_test.cc
@@ -42,6 +42,17 @@
EXPECT_FALSE(name2.InNamespace(name3));
}
+TEST(MoqtMessagesTest, TrackNamespacePushPop) {
+ TrackNamespace name({"a"});
+ TrackNamespace original = name;
+ name.AddElement("b");
+ EXPECT_TRUE(name.InNamespace(original));
+ EXPECT_FALSE(original.InNamespace(name));
+ EXPECT_TRUE(name.PopElement());
+ EXPECT_EQ(name, original);
+ EXPECT_FALSE(name.PopElement());
+}
+
TEST(MoqtMessagesTest, TrackNamespaceToString) {
TrackNamespace name1({"a", "b"});
EXPECT_EQ(name1.ToString(), R"({"a"::"b"})");
diff --git a/quiche/quic/moqt/moqt_publisher.h b/quiche/quic/moqt/moqt_publisher.h
index 0ad158b..3eca34e 100644
--- a/quiche/quic/moqt/moqt_publisher.h
+++ b/quiche/quic/moqt/moqt_publisher.h
@@ -138,7 +138,7 @@
public:
virtual ~MoqtPublisher() = default;
- // These are all called by MoqtSession based on messages arriving on the wire.
+ // Called by MoqtSession based on messages arriving on the wire.
virtual absl_nullable std::shared_ptr<MoqtTrackPublisher> GetTrack(
const FullTrackName& track_name) = 0;
virtual void AddNamespaceListener(NamespaceListener* listener) = 0;
diff --git a/quiche/quic/moqt/moqt_relay_publisher.cc b/quiche/quic/moqt/moqt_relay_publisher.cc
index bb8831e..ebe76ea 100644
--- a/quiche/quic/moqt/moqt_relay_publisher.cc
+++ b/quiche/quic/moqt/moqt_relay_publisher.cc
@@ -9,6 +9,7 @@
#include <utility>
#include "absl/base/nullability.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_publisher.h"
@@ -27,8 +28,9 @@
if (it != tracks_.end()) {
return it->second;
}
- QuicheWeakPtr<MoqtSessionInterface> upstream =
- GetUpstream(track_name.track_namespace());
+ // Make a copy, because this namespace might be truncated.
+ TrackNamespace track_namespace = track_name.track_namespace();
+ QuicheWeakPtr<MoqtSessionInterface> upstream = GetUpstream(track_namespace);
if (!upstream.IsValid()) {
return nullptr;
}
@@ -59,23 +61,48 @@
<< error_message;
default_upstream_session_ = QuicheWeakPtr<MoqtSessionInterface>();
};
- AddNamespaceCallbacks(default_upstream_session);
default_upstream_session_ = default_upstream_session->GetWeakPtr();
}
-void MoqtRelayPublisher::AddNamespaceCallbacks(
- MoqtSessionInterface* /*session*/) {
- // TODO(martinduke): Implement this.
+void MoqtRelayPublisher::OnPublishNamespace(
+ const TrackNamespace& track_namespace,
+ const VersionSpecificParameters& /*parameters*/,
+ MoqtSessionInterface* session, MoqtResponseCallback callback) {
+ if (session == nullptr) {
+ return;
+ }
+ // TODO(martinduke): Handle parameters.
+ namespace_publishers_.AddPublisher(track_namespace, session);
+ // TODO(martinduke): Notify subscribers listening for this namespace.
+ // Send PUBLISH_NAMESPACE_OK.
+ std::move(callback)(std::nullopt);
+}
+
+void MoqtRelayPublisher::OnPublishNamespaceDone(
+ const TrackNamespace& track_namespace, MoqtSessionInterface* session) {
+ if (session == nullptr) {
+ return;
+ }
+ namespace_publishers_.RemovePublisher(track_namespace, session);
+ // TODO(martinduke): Notify subscribers listening for this namespace.
}
QuicheWeakPtr<MoqtSessionInterface> MoqtRelayPublisher::GetUpstream(
- const TrackNamespace& /*track_namespace*/) {
- // TODO(martinduke): Find a published namespace that contains
- // |track_namespace|.
- if (default_upstream_session_.IsValid()) {
- return default_upstream_session_.GetIfAvailable()->GetWeakPtr();
+ TrackNamespace& track_namespace) {
+ quiche::QuicheWeakPtr<MoqtSessionInterface> upstream;
+ upstream = namespace_publishers_.GetValidPublisher(track_namespace);
+ if (upstream.IsValid()) {
+ return upstream;
}
- return QuicheWeakPtr<MoqtSessionInterface>();
+ if (!track_namespace.PopElement()) {
+ // This the last element; send the default upstream if valid.
+ if (default_upstream_session_.IsValid()) {
+ return default_upstream_session_.GetIfAvailable()->GetWeakPtr();
+ }
+ return QuicheWeakPtr<MoqtSessionInterface>();
+ }
+ // See if there's a subscriber for a parent namespace.
+ return GetUpstream(track_namespace);
}
} // namespace moqt
diff --git a/quiche/quic/moqt/moqt_relay_publisher.h b/quiche/quic/moqt/moqt_relay_publisher.h
index f841279..bf3f3da 100644
--- a/quiche/quic/moqt/moqt_relay_publisher.h
+++ b/quiche/quic/moqt/moqt_relay_publisher.h
@@ -12,7 +12,9 @@
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_publisher.h"
#include "quiche/quic/moqt/moqt_relay_track_publisher.h"
+#include "quiche/quic/moqt/moqt_session_callbacks.h"
#include "quiche/quic/moqt/moqt_session_interface.h"
+#include "quiche/quic/moqt/namespace_publisher_multimap.h"
#include "quiche/common/quiche_weak_ptr.h"
namespace moqt {
@@ -21,8 +23,7 @@
// and namespaces with upstream sessions that can deliver those things.
class MoqtRelayPublisher : public MoqtPublisher {
public:
- explicit MoqtRelayPublisher(bool broadcast_mode)
- : broadcast_mode_(broadcast_mode) {}
+ MoqtRelayPublisher() = default;
MoqtRelayPublisher(const MoqtRelayPublisher&) = delete;
MoqtRelayPublisher(MoqtRelayPublisher&&) = delete;
MoqtRelayPublisher& operator=(const MoqtRelayPublisher&) = delete;
@@ -31,7 +32,7 @@
// MoqtPublisher implementation.
absl_nullable std::shared_ptr<MoqtTrackPublisher> GetTrack(
const FullTrackName& track_name) override;
- // TODO(martinduke): Implement namespace support.
+
void AddNamespaceListener(NamespaceListener* /*listener*/) override {}
void RemoveNamespaceListener(NamespaceListener* /*listener*/) override {}
@@ -39,28 +40,35 @@
// information, requests will route here.
void SetDefaultUpstreamSession(
MoqtSessionInterface* default_upstream_session);
- // There is a new incoming session. MoqtRelayPublisher will set the callbacks
- // for this session, but need not keep any state at this time.
- virtual void AddNamespaceCallbacks(MoqtSessionInterface* session);
// Returns the default upstream session.
quiche::QuicheWeakPtr<MoqtSessionInterface>& GetDefaultUpstreamSession() {
return default_upstream_session_;
}
+ void OnPublishNamespace(const TrackNamespace& track_namespace,
+ const VersionSpecificParameters& parameters,
+ MoqtSessionInterface* session,
+ MoqtResponseCallback callback);
+
+ void OnPublishNamespaceDone(const TrackNamespace& track_namespace,
+ MoqtSessionInterface* session);
+
private:
quiche::QuicheWeakPtr<MoqtSessionInterface> GetUpstream(
- const TrackNamespace& track_namespace);
+ TrackNamespace& track_namespace);
absl::flat_hash_map<FullTrackName, std::shared_ptr<MoqtRelayTrackPublisher>>
tracks_;
- // TODO(martinduke): Add a map of Namespaces to source sessions and
- // namespace listeners.
+
+ // An indexed map of namespace to a map of sessions. The key to the inner map
+ // is indexed by a raw pointer, to make it easier to find entries when
+ // deleting.
+ NamespacePublisherMultimap namespace_publishers_;
+
+ // TODO(martinduke): Add a map of Namespaces to namespace listeners.
quiche::QuicheWeakPtr<MoqtSessionInterface> default_upstream_session_;
- // If true, PUBLISH_NAMESPACE messages will be forwarded to all sessions,
- // whether or not they are subscribed.
- bool broadcast_mode_;
};
} // namespace moqt
diff --git a/quiche/quic/moqt/moqt_relay_publisher_test.cc b/quiche/quic/moqt/moqt_relay_publisher_test.cc
index 4188042..b844726 100644
--- a/quiche/quic/moqt/moqt_relay_publisher_test.cc
+++ b/quiche/quic/moqt/moqt_relay_publisher_test.cc
@@ -5,6 +5,7 @@
#include "quiche/quic/moqt/moqt_relay_publisher.h"
#include <memory>
+#include <optional>
#include <utility>
#include "absl/strings/string_view.h"
@@ -12,6 +13,7 @@
#include "quiche/quic/moqt/moqt_publisher.h"
#include "quiche/quic/moqt/moqt_session_callbacks.h"
#include "quiche/quic/moqt/test_tools/mock_moqt_session.h"
+#include "quiche/quic/moqt/test_tools/moqt_mock_visitor.h"
#include "quiche/common/platform/api/quiche_test.h"
#include "quiche/common/quiche_weak_ptr.h"
@@ -20,11 +22,10 @@
class MoqtRelayPublisherTest : public quiche::test::QuicheTest {
public:
- MoqtRelayPublisherTest() : publisher_(false) {}
-
MoqtSessionCallbacks callbacks_;
MockMoqtSession session_;
MoqtRelayPublisher publisher_;
+ MockMoqtObjectListener object_listener_;
};
TEST_F(MoqtRelayPublisherTest, SetDefaultUpstreamSession) {
@@ -72,5 +73,24 @@
EXPECT_EQ(track->GetTrackName(), FullTrackName("foo", "bar"));
}
+TEST_F(MoqtRelayPublisherTest, PublishNamespaceLifecycle) {
+ EXPECT_EQ(publisher_.GetTrack(FullTrackName("foo", "bar")), nullptr);
+ std::optional<MoqtRequestError> response;
+ publisher_.OnPublishNamespace(
+ TrackNamespace({"foo"}), VersionSpecificParameters(), &session_,
+ [&](std::optional<MoqtRequestError> error_response) {
+ response = error_response;
+ });
+ EXPECT_EQ(response, std::nullopt);
+ std::shared_ptr<MoqtTrackPublisher> track =
+ publisher_.GetTrack(FullTrackName("foo", "bar"));
+ EXPECT_NE(track, nullptr);
+ EXPECT_CALL(session_, SubscribeCurrentObject);
+ track->AddObjectListener(&object_listener_);
+ track->RemoveObjectListener(&object_listener_);
+ publisher_.OnPublishNamespaceDone(TrackNamespace({"foo"}), &session_);
+ EXPECT_EQ(publisher_.GetTrack(FullTrackName("foo", "bar")), nullptr);
+}
+
} // namespace test
} // namespace moqt
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index 79632c8..3cc511f 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -47,6 +47,7 @@
#include "quiche/common/quiche_buffer_allocator.h"
#include "quiche/common/quiche_mem_slice.h"
#include "quiche/common/quiche_stream.h"
+#include "quiche/common/quiche_weak_ptr.h"
#include "quiche/common/simple_buffer_allocator.h"
#include "quiche/web_transport/web_transport.h"
@@ -1216,21 +1217,29 @@
session_->framer_.SerializePublishNamespaceError(error));
return;
}
- std::optional<MoqtPublishNamespaceErrorReason> error =
- session_->callbacks_.incoming_publish_namespace_callback(
- message.track_namespace, message.parameters);
- if (error.has_value()) {
- MoqtPublishNamespaceError reply;
- reply.request_id = message.request_id;
- reply.error_code = error->error_code;
- reply.error_reason = error->reason_phrase;
- SendOrBufferMessage(
- session_->framer_.SerializePublishNamespaceError(reply));
- return;
- }
- MoqtPublishNamespaceOk ok;
- ok.request_id = message.request_id;
- SendOrBufferMessage(session_->framer_.SerializePublishNamespaceOk(ok));
+ quiche::QuicheWeakPtr<MoqtSessionInterface> session_weakptr =
+ session_->GetWeakPtr();
+ session_->callbacks_.incoming_publish_namespace_callback(
+ message.track_namespace, message.parameters,
+ [&](std::optional<MoqtRequestError> error) {
+ MoqtSession* session =
+ static_cast<MoqtSession*>(session_weakptr.GetIfAvailable());
+ if (session == nullptr) {
+ return;
+ }
+ if (error.has_value()) {
+ MoqtPublishNamespaceError reply;
+ reply.request_id = message.request_id;
+ reply.error_code = error->error_code;
+ reply.error_reason = error->reason_phrase;
+ SendOrBufferMessage(
+ session->framer_.SerializePublishNamespaceError(reply));
+ } else {
+ MoqtPublishNamespaceOk ok;
+ ok.request_id = message.request_id;
+ SendOrBufferMessage(session->framer_.SerializePublishNamespaceOk(ok));
+ }
+ });
}
// Do not enforce that there is only one of OK or ERROR per PUBLISH_NAMESPACE.
@@ -1280,7 +1289,7 @@
void MoqtSession::ControlStream::OnPublishNamespaceDoneMessage(
const MoqtPublishNamespaceDone& message) {
session_->callbacks_.incoming_publish_namespace_callback(
- message.track_namespace, std::nullopt);
+ message.track_namespace, std::nullopt, nullptr);
}
void MoqtSession::ControlStream::OnPublishNamespaceCancelMessage(
diff --git a/quiche/quic/moqt/moqt_session_callbacks.h b/quiche/quic/moqt/moqt_session_callbacks.h
index f104de0..bb43103 100644
--- a/quiche/quic/moqt/moqt_session_callbacks.h
+++ b/quiche/quic/moqt/moqt_session_callbacks.h
@@ -6,6 +6,7 @@
#define QUICHE_QUIC_MOQT_MOQT_SESSION_CALLBACKS_H_
#include <optional>
+#include <utility>
#include "absl/strings/string_view.h"
#include "quiche/quic/core/quic_clock.h"
@@ -15,6 +16,12 @@
namespace moqt {
+// The callback we'll use for all request types going forward. Can only be used
+// once; if the argument is nullopt, an OK response was received. Otherwise, an
+// ERROR response was received.
+using MoqtResponseCallback =
+ quiche::SingleUseCallback<void(std::optional<MoqtRequestError>)>;
+
// Called when the SETUP message from the peer is received.
using MoqtSessionEstablishedCallback = quiche::SingleUseCallback<void()>;
@@ -31,11 +38,11 @@
// Called whenever a PUBLISH_NAMESPACE or PUBLISH_NAMESPACE_DONE message is
// received from the peer. PUBLISH_NAMESPACE sets a value for |parameters|,
-// PUBLISH_NAMESPACE_DONE does not.
-using MoqtIncomingPublishNamespaceCallback =
- quiche::MultiUseCallback<std::optional<MoqtPublishNamespaceErrorReason>(
- const TrackNamespace& track_namespace,
- const std::optional<VersionSpecificParameters>& parameters)>;
+// PUBLISH_NAMESPACE_DONE does not..
+using MoqtIncomingPublishNamespaceCallback = quiche::MultiUseCallback<void(
+ const TrackNamespace& track_namespace,
+ const std::optional<VersionSpecificParameters>& parameters,
+ MoqtResponseCallback callback)>;
// Called whenever SUBSCRIBE_NAMESPACE or UNSUBSCRIBE_NAMESPACE is received from
// the peer. For SUBSCRIBE_NAMESPACE, the return value indicates whether to
@@ -47,11 +54,13 @@
const TrackNamespace& track_namespace,
std::optional<VersionSpecificParameters> parameters)>;
-inline std::optional<MoqtPublishNamespaceErrorReason>
-DefaultIncomingPublishNamespaceCallback(
- const TrackNamespace& /*track_namespace*/,
- std::optional<VersionSpecificParameters> /*parameters*/) {
- return std::optional(MoqtPublishNamespaceErrorReason{
+inline void DefaultIncomingPublishNamespaceCallback(
+ const TrackNamespace&, const std::optional<VersionSpecificParameters>&,
+ MoqtResponseCallback callback) {
+ if (callback == nullptr) {
+ return;
+ }
+ return std::move(callback)(MoqtRequestError{
RequestErrorCode::kNotSupported,
"This endpoint does not accept incoming PUBLISH_NAMESPACE messages"});
};
diff --git a/quiche/quic/moqt/moqt_session_interface.h b/quiche/quic/moqt/moqt_session_interface.h
index 186d968..3ffe975 100644
--- a/quiche/quic/moqt/moqt_session_interface.h
+++ b/quiche/quic/moqt/moqt_session_interface.h
@@ -22,12 +22,6 @@
namespace moqt {
-// The callback we'll use for all request types going forward. Can only be used
-// once; if the argument is nullopt, an OK response was received. Otherwise, an
-// ERROR response was received.
-using MoqtResponseCallback =
- quiche::SingleUseCallback<void(std::optional<MoqtRequestError>)>;
-
using MoqtObjectAckFunction =
quiche::MultiUseCallback<void(uint64_t group_id, uint64_t object_id,
quic::QuicTimeDelta delta_from_deadline)>;
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index ba966cd..b604dff 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -27,6 +27,7 @@
#include "quiche/quic/moqt/moqt_parser.h"
#include "quiche/quic/moqt/moqt_priority.h"
#include "quiche/quic/moqt/moqt_publisher.h"
+#include "quiche/quic/moqt/moqt_session_callbacks.h"
#include "quiche/quic/moqt/moqt_session_interface.h"
#include "quiche/quic/moqt/moqt_track.h"
#include "quiche/quic/moqt/test_tools/moqt_framer_utils.h"
@@ -386,9 +387,8 @@
}
TEST_F(MoqtSessionTest, PublishNamespaceWithOkAndCancel) {
- testing::MockFunction<void(
- TrackNamespace track_namespace,
- std::optional<MoqtPublishNamespaceErrorReason> error_message)>
+ testing::MockFunction<void(TrackNamespace track_namespace,
+ std::optional<MoqtRequestError> error_message)>
publish_namespace_resolved_callback;
std::unique_ptr<MoqtControlParserVisitor> stream_input =
MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_);
@@ -405,7 +405,7 @@
};
EXPECT_CALL(publish_namespace_resolved_callback, Call(_, _))
.WillOnce([&](TrackNamespace track_namespace,
- std::optional<MoqtPublishNamespaceErrorReason> error) {
+ std::optional<MoqtRequestError> error) {
EXPECT_EQ(track_namespace, TrackNamespace("foo"));
EXPECT_FALSE(error.has_value());
});
@@ -418,7 +418,7 @@
};
EXPECT_CALL(publish_namespace_resolved_callback, Call(_, _))
.WillOnce([&](TrackNamespace track_namespace,
- std::optional<MoqtPublishNamespaceErrorReason> error) {
+ std::optional<MoqtRequestError> error) {
EXPECT_EQ(track_namespace, TrackNamespace("foo"));
ASSERT_TRUE(error.has_value());
EXPECT_EQ(error->error_code, RequestErrorCode::kInternalError);
@@ -430,9 +430,8 @@
}
TEST_F(MoqtSessionTest, PublishNamespaceWithOkAndPublishNamespaceDone) {
- testing::MockFunction<void(
- TrackNamespace track_namespace,
- std::optional<MoqtPublishNamespaceErrorReason> error_message)>
+ testing::MockFunction<void(TrackNamespace track_namespace,
+ std::optional<MoqtRequestError> error_message)>
publish_namespace_resolved_callback;
std::unique_ptr<MoqtControlParserVisitor> stream_input =
MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_);
@@ -449,7 +448,7 @@
};
EXPECT_CALL(publish_namespace_resolved_callback, Call(_, _))
.WillOnce([&](TrackNamespace track_namespace,
- std::optional<MoqtPublishNamespaceErrorReason> error) {
+ std::optional<MoqtRequestError> error) {
EXPECT_EQ(track_namespace, TrackNamespace{"foo"});
EXPECT_FALSE(error.has_value());
});
@@ -465,9 +464,8 @@
}
TEST_F(MoqtSessionTest, PublishNamespaceWithError) {
- testing::MockFunction<void(
- TrackNamespace track_namespace,
- std::optional<MoqtPublishNamespaceErrorReason> error_message)>
+ testing::MockFunction<void(TrackNamespace track_namespace,
+ std::optional<MoqtRequestError> error_message)>
publish_namespace_resolved_callback;
std::unique_ptr<MoqtControlParserVisitor> stream_input =
MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_);
@@ -486,7 +484,7 @@
};
EXPECT_CALL(publish_namespace_resolved_callback, Call(_, _))
.WillOnce([&](TrackNamespace track_namespace,
- std::optional<MoqtPublishNamespaceErrorReason> error) {
+ std::optional<MoqtRequestError> error) {
EXPECT_EQ(track_namespace, TrackNamespace{"foo"});
ASSERT_TRUE(error.has_value());
EXPECT_EQ(error->error_code, RequestErrorCode::kInternalError);
@@ -956,8 +954,12 @@
*parameters,
};
EXPECT_CALL(session_callbacks_.incoming_publish_namespace_callback,
- Call(track_namespace, parameters))
- .WillOnce(Return(std::nullopt));
+ Call(track_namespace, parameters, _))
+ .WillOnce([](const TrackNamespace&,
+ const std::optional<VersionSpecificParameters>&,
+ MoqtResponseCallback callback) {
+ std::move(callback)(std::nullopt);
+ });
EXPECT_CALL(mock_stream_,
Writev(SerializedControlMessage(
MoqtPublishNamespaceOk{kDefaultPeerRequestId}),
@@ -966,9 +968,13 @@
MoqtPublishNamespaceDone unpublish_namespace = {
track_namespace,
};
- EXPECT_CALL(session_callbacks_.incoming_publish_namespace_callback,
- Call(track_namespace, std::optional<VersionSpecificParameters>()))
- .WillOnce(Return(std::nullopt));
+ EXPECT_CALL(
+ session_callbacks_.incoming_publish_namespace_callback,
+ Call(track_namespace, std::optional<VersionSpecificParameters>(), _))
+ .WillOnce(
+ [](const TrackNamespace&,
+ const std::optional<VersionSpecificParameters>&,
+ MoqtResponseCallback callback) { EXPECT_EQ(callback, nullptr); });
stream_input->OnPublishNamespaceDoneMessage(unpublish_namespace);
}
@@ -986,8 +992,12 @@
*parameters,
};
EXPECT_CALL(session_callbacks_.incoming_publish_namespace_callback,
- Call(track_namespace, parameters))
- .WillOnce(Return(std::nullopt));
+ Call(track_namespace, parameters, _))
+ .WillOnce([](const TrackNamespace&,
+ const std::optional<VersionSpecificParameters>&,
+ MoqtResponseCallback callback) {
+ std::move(callback)(std::nullopt);
+ });
EXPECT_CALL(mock_stream_,
Writev(SerializedControlMessage(
MoqtPublishNamespaceOk{kDefaultPeerRequestId}),
@@ -1014,13 +1024,16 @@
track_namespace,
*parameters,
};
- MoqtPublishNamespaceErrorReason error = {
+ MoqtRequestError error = {
RequestErrorCode::kNotSupported,
"deadbeef",
};
EXPECT_CALL(session_callbacks_.incoming_publish_namespace_callback,
- Call(track_namespace, parameters))
- .WillOnce(Return(error));
+ Call(track_namespace, parameters, _))
+ .WillOnce(
+ [&](const TrackNamespace&,
+ const std::optional<VersionSpecificParameters>&,
+ MoqtResponseCallback callback) { std::move(callback)(error); });
EXPECT_CALL(
mock_stream_,
Writev(SerializedControlMessage(MoqtPublishNamespaceError{
@@ -3419,7 +3432,7 @@
session_.PublishNamespace(
TrackNamespace{"foo"},
+[](TrackNamespace /*track_namespace*/,
- std::optional<MoqtPublishNamespaceErrorReason> /*error*/) {},
+ std::optional<MoqtRequestError> /*error*/) {},
VersionSpecificParameters());
EXPECT_FALSE(session_.Fetch(
FullTrackName{TrackNamespace("foo"), "bar"},
@@ -3486,7 +3499,7 @@
session_.PublishNamespace(
TrackNamespace{"foo"},
+[](TrackNamespace /*track_namespace*/,
- std::optional<MoqtPublishNamespaceErrorReason> /*error*/) {},
+ std::optional<MoqtRequestError> /*error*/) {},
VersionSpecificParameters());
EXPECT_FALSE(session_.Fetch(
FullTrackName(TrackNamespace("foo"), "bar"),
diff --git a/quiche/quic/moqt/namespace_publisher_multimap.h b/quiche/quic/moqt/namespace_publisher_multimap.h
new file mode 100644
index 0000000..d75092b
--- /dev/null
+++ b/quiche/quic/moqt/namespace_publisher_multimap.h
@@ -0,0 +1,63 @@
+// Copyright 2025 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef QUICHE_QUIC_MOQT_NAMESPACE_PUBLISHER_MULTIMAP_H_
+#define QUICHE_QUIC_MOQT_NAMESPACE_PUBLISHER_MULTIMAP_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "quiche/quic/moqt/moqt_messages.h"
+#include "quiche/quic/moqt/moqt_session_interface.h"
+#include "quiche/common/quiche_weak_ptr.h"
+
+namespace moqt {
+
+class NamespacePublisherMultimap {
+ public:
+ void AddPublisher(const TrackNamespace& track_namespace,
+ MoqtSessionInterface* session) {
+ absl::flat_hash_map<MoqtSessionInterface*,
+ quiche::QuicheWeakPtr<MoqtSessionInterface>>&
+ publisher_map = namespace_map_[track_namespace];
+ publisher_map.emplace(session, session->GetWeakPtr());
+ }
+
+ void RemovePublisher(const TrackNamespace& track_namespace,
+ MoqtSessionInterface* session) {
+ auto it = namespace_map_.find(track_namespace);
+ if (it == namespace_map_.end()) {
+ return;
+ }
+ it->second.erase(session);
+ if (it->second.empty()) { // Last publisher for this namespace is gone.
+ namespace_map_.erase(it);
+ }
+ }
+
+ // Requires a precise match for |track_namespace|.
+ quiche::QuicheWeakPtr<MoqtSessionInterface> GetValidPublisher(
+ const TrackNamespace& track_namespace) {
+ auto it = namespace_map_.find(track_namespace);
+ if (it == namespace_map_.end()) {
+ return quiche::QuicheWeakPtr<MoqtSessionInterface>();
+ }
+ for (const auto& [session, publisher] : it->second) {
+ if (publisher.IsValid()) {
+ return publisher.GetIfAvailable()->GetWeakPtr();
+ }
+ }
+ return quiche::QuicheWeakPtr<MoqtSessionInterface>();
+ }
+
+ private:
+ absl::flat_hash_map<
+ TrackNamespace,
+ absl::flat_hash_map<MoqtSessionInterface*,
+ quiche::QuicheWeakPtr<MoqtSessionInterface>>>
+ namespace_map_;
+};
+
+} // namespace moqt
+
+#endif // QUICHE_QUIC_MOQT_NAMESPACE_PUBLISHER_MULTIMAP_H_
diff --git a/quiche/quic/moqt/namespace_publisher_multimap_test.cc b/quiche/quic/moqt/namespace_publisher_multimap_test.cc
new file mode 100644
index 0000000..7a961c6
--- /dev/null
+++ b/quiche/quic/moqt/namespace_publisher_multimap_test.cc
@@ -0,0 +1,45 @@
+// Copyright 2025 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "quiche/quic/moqt/namespace_publisher_multimap.h"
+
+#include <memory>
+
+#include "quiche/quic/moqt/moqt_messages.h"
+#include "quiche/quic/moqt/test_tools/mock_moqt_session.h"
+#include "quiche/common/platform/api/quiche_test.h"
+
+namespace moqt {
+namespace test {
+
+class NamespacePublisherMultimapTest : public quiche::test::QuicheTest {
+ public:
+ NamespacePublisherMultimapTest()
+ : session_(std::make_unique<MockMoqtSession>()) {}
+
+ NamespacePublisherMultimap multimap_;
+ TrackNamespace ns1_{"foo", "bar"}, ns2_{"foo"}, ns3_{"foo", "bar", "baz"};
+ std::unique_ptr<MockMoqtSession> session_;
+};
+
+TEST_F(NamespacePublisherMultimapTest, AddGetRemovePublisher) {
+ EXPECT_EQ(multimap_.GetValidPublisher(ns1_).GetIfAvailable(), nullptr);
+ multimap_.AddPublisher(ns1_, session_.get());
+ EXPECT_EQ(multimap_.GetValidPublisher(ns1_).GetIfAvailable(), session_.get());
+ EXPECT_EQ(multimap_.GetValidPublisher(ns2_).GetIfAvailable(), nullptr);
+ EXPECT_EQ(multimap_.GetValidPublisher(ns3_).GetIfAvailable(), nullptr);
+ multimap_.RemovePublisher(ns1_, session_.get());
+ EXPECT_EQ(multimap_.GetValidPublisher(ns1_).GetIfAvailable(), nullptr);
+}
+
+TEST_F(NamespacePublisherMultimapTest, SessionDestroyed) {
+ EXPECT_EQ(multimap_.GetValidPublisher(ns1_).GetIfAvailable(), nullptr);
+ multimap_.AddPublisher(ns1_, session_.get());
+ EXPECT_EQ(multimap_.GetValidPublisher(ns1_).GetIfAvailable(), session_.get());
+ session_.reset();
+ EXPECT_EQ(multimap_.GetValidPublisher(ns1_).GetIfAvailable(), nullptr);
+}
+
+} // namespace test
+} // namespace moqt
diff --git a/quiche/quic/moqt/namespace_tree.h b/quiche/quic/moqt/namespace_tree.h
index 103e0f9..32296e3 100644
--- a/quiche/quic/moqt/namespace_tree.h
+++ b/quiche/quic/moqt/namespace_tree.h
@@ -5,9 +5,7 @@
#ifndef QUICHE_QUIC_MOQT_NAMESPACE_TREE_H_
#define QUICHE_QUIC_MOQT_NAMESPACE_TREE_H_
-#include <cstddef>
#include <string>
-#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
diff --git a/quiche/quic/moqt/test_tools/moqt_mock_visitor.h b/quiche/quic/moqt/test_tools/moqt_mock_visitor.h
index 10f10b1..27a2586 100644
--- a/quiche/quic/moqt/test_tools/moqt_mock_visitor.h
+++ b/quiche/quic/moqt/test_tools/moqt_mock_visitor.h
@@ -22,6 +22,7 @@
#include "quiche/quic/moqt/moqt_session_callbacks.h"
#include "quiche/quic/moqt/moqt_session_interface.h"
#include "quiche/common/platform/api/quiche_test.h"
+#include "quiche/web_transport/web_transport.h"
namespace moqt::test {
@@ -30,15 +31,17 @@
testing::MockFunction<void(absl::string_view)> goaway_received_callback;
testing::MockFunction<void(absl::string_view)> session_terminated_callback;
testing::MockFunction<void()> session_deleted_callback;
- testing::MockFunction<std::optional<MoqtPublishNamespaceErrorReason>(
- const TrackNamespace&, std::optional<VersionSpecificParameters>)>
+ testing::MockFunction<void(const TrackNamespace&,
+ const std::optional<VersionSpecificParameters>&,
+ MoqtResponseCallback)>
incoming_publish_namespace_callback;
testing::MockFunction<std::optional<MoqtSubscribeErrorReason>(
TrackNamespace, std::optional<VersionSpecificParameters>)>
incoming_subscribe_namespace_callback;
MockSessionCallbacks() {
- ON_CALL(incoming_publish_namespace_callback, Call(testing::_, testing::_))
+ ON_CALL(incoming_publish_namespace_callback,
+ Call(testing::_, testing::_, testing::_))
.WillByDefault(DefaultIncomingPublishNamespaceCallback);
ON_CALL(incoming_subscribe_namespace_callback, Call(testing::_, testing::_))
.WillByDefault(DefaultIncomingSubscribeNamespaceCallback);
@@ -172,6 +175,19 @@
bool synchronous_object_available_ = false;
};
+class MockMoqtObjectListener : public MoqtObjectListener {
+ public:
+ MOCK_METHOD(void, OnSubscribeAccepted, (), (override));
+ MOCK_METHOD(void, OnSubscribeRejected, (MoqtRequestError), (override));
+ MOCK_METHOD(void, OnNewObjectAvailable, (Location, uint64_t, MoqtPriority),
+ (override));
+ MOCK_METHOD(void, OnNewFinAvailable, (Location, uint64_t), (override));
+ MOCK_METHOD(void, OnSubgroupAbandoned,
+ (uint64_t, uint64_t, webtransport::StreamErrorCode), (override));
+ MOCK_METHOD(void, OnGroupAbandoned, (uint64_t), (override));
+ MOCK_METHOD(void, OnTrackPublisherGone, (), (override));
+};
+
} // namespace moqt::test
#endif // QUICHE_QUIC_MOQT_TOOLS_MOQT_MOCK_VISITOR_H_
diff --git a/quiche/quic/moqt/tools/chat_client.cc b/quiche/quic/moqt/tools/chat_client.cc
index 27d663b..111e83b 100644
--- a/quiche/quic/moqt/tools/chat_client.cc
+++ b/quiche/quic/moqt/tools/chat_client.cc
@@ -27,6 +27,7 @@
#include "quiche/quic/moqt/moqt_object.h"
#include "quiche/quic/moqt/moqt_outgoing_queue.h"
#include "quiche/quic/moqt/moqt_session.h"
+#include "quiche/quic/moqt/moqt_session_callbacks.h"
#include "quiche/quic/moqt/moqt_session_interface.h"
#include "quiche/quic/moqt/tools/moq_chat.h"
#include "quiche/quic/moqt/tools/moqt_client.h"
@@ -40,13 +41,14 @@
namespace moqt::moq_chat {
-std::optional<MoqtPublishNamespaceErrorReason>
-ChatClient::OnIncomingPublishNamespace(
+void ChatClient::OnIncomingPublishNamespace(
const moqt::TrackNamespace& track_namespace,
- std::optional<VersionSpecificParameters> parameters) {
+ std::optional<VersionSpecificParameters> parameters,
+ moqt::MoqtResponseCallback callback) {
if (track_namespace == GetUserNamespace(my_track_name_)) {
// Ignore PUBLISH_NAMESPACE for my own track.
- return std::optional<MoqtPublishNamespaceErrorReason>();
+ std::move(callback)(std::nullopt);
+ return;
}
std::optional<FullTrackName> track_name = ConstructTrackNameFromNamespace(
track_namespace, GetChatId(my_track_name_));
@@ -57,22 +59,25 @@
session_->Unsubscribe(*track_name);
other_users_.erase(*track_name);
}
- return std::nullopt;
+ return;
}
std::cout << "PUBLISH_NAMESPACE for " << track_namespace.ToString() << "\n";
if (!track_name.has_value()) {
std::cout << "PUBLISH_NAMESPACE rejected, invalid namespace\n";
- return std::make_optional<MoqtPublishNamespaceErrorReason>(
- RequestErrorCode::kTrackDoesNotExist, "Not a subscribed namespace");
+ std::move(callback)(std::make_optional<MoqtPublishNamespaceErrorReason>(
+ RequestErrorCode::kTrackDoesNotExist, "Not a subscribed namespace"));
+ return;
}
if (other_users_.contains(*track_name)) {
std::cout << "Duplicate PUBLISH_NAMESPACE, send OK and ignore\n";
- return std::nullopt;
+ std::move(callback)(std::nullopt);
+ return;
}
if (GetUsername(my_track_name_) == GetUsername(*track_name)) {
std::cout << "PUBLISH_NAMESPACE for a previous instance of my track, "
"do not subscribe\n";
- return std::nullopt;
+ std::move(callback)(std::nullopt);
+ return;
}
VersionSpecificParameters subscribe_parameters(
AuthTokenType::kOutOfBand, std::string(GetUsername(my_track_name_)));
@@ -81,7 +86,7 @@
++subscribes_to_make_;
other_users_.emplace(*track_name);
}
- return std::nullopt; // Send PUBLISH_NAMESPACE_OK.
+ std::move(callback)(std::nullopt); // Send PUBLISH_NAMESPACE_OK.
}
ChatClient::ChatClient(const quic::QuicServerId& server_id,
diff --git a/quiche/quic/moqt/tools/chat_client.h b/quiche/quic/moqt/tools/chat_client.h
index 21ff5b0..d5d8c95 100644
--- a/quiche/quic/moqt/tools/chat_client.h
+++ b/quiche/quic/moqt/tools/chat_client.h
@@ -128,9 +128,10 @@
private:
void RunEventLoop() { event_loop_->RunEventLoopOnce(kChatEventLoopDuration); }
// Callback for incoming publish_namespaces.
- std::optional<MoqtPublishNamespaceErrorReason> OnIncomingPublishNamespace(
+ void OnIncomingPublishNamespace(
const moqt::TrackNamespace& track_namespace,
- std::optional<VersionSpecificParameters> parameters);
+ std::optional<VersionSpecificParameters> parameters,
+ moqt::MoqtResponseCallback callback);
// Basic session information
FullTrackName my_track_name_;
diff --git a/quiche/quic/moqt/tools/chat_server.cc b/quiche/quic/moqt/tools/chat_server.cc
index 9b21513..8cc05d9 100644
--- a/quiche/quic/moqt/tools/chat_server.cc
+++ b/quiche/quic/moqt/tools/chat_server.cc
@@ -22,37 +22,40 @@
#include "quiche/quic/moqt/moqt_object.h"
#include "quiche/quic/moqt/moqt_priority.h"
#include "quiche/quic/moqt/moqt_session.h"
+#include "quiche/quic/moqt/moqt_session_callbacks.h"
#include "quiche/quic/moqt/moqt_session_interface.h"
#include "quiche/quic/moqt/tools/moq_chat.h"
#include "quiche/quic/moqt/tools/moqt_server.h"
namespace moqt::moq_chat {
-std::optional<MoqtPublishNamespaceErrorReason>
-ChatServer::ChatServerSessionHandler::OnIncomingPublishNamespace(
+void ChatServer::ChatServerSessionHandler::OnIncomingPublishNamespace(
const moqt::TrackNamespace& track_namespace,
- std::optional<VersionSpecificParameters> parameters) {
+ std::optional<VersionSpecificParameters> parameters,
+ MoqtResponseCallback callback) {
if (track_name_.has_value() &&
GetUserNamespace(*track_name_) != track_namespace) {
// ChatServer only supports one track per client session at a time. Return
// PUBLISH_NAMESPACE_OK and exit.
- return std::nullopt;
+ std::move(callback)(std::nullopt);
+ return;
}
// Accept the PUBLISH_NAMESPACE regardless of the chat_id.
track_name_ = ConstructTrackNameFromNamespace(track_namespace,
GetChatId(track_namespace));
if (!track_name_.has_value()) {
std::cout << "Malformed PUBLISH_NAMESPACE namespace\n";
- return MoqtPublishNamespaceErrorReason(
+ std::move(callback)(MoqtPublishNamespaceErrorReason(
RequestErrorCode::kTrackDoesNotExist,
- "Not a valid namespace for this chat.");
+ "Not a valid namespace for this chat."));
+ return;
}
if (!parameters.has_value()) {
std::cout << "Received PUBLISH_NAMESPACE_DONE for "
<< track_namespace.ToString() << "\n";
server_->DeleteUser(*track_name_);
track_name_.reset();
- return std::nullopt;
+ return;
}
std::cout << "Received PUBLISH_NAMESPACE for " << track_namespace.ToString()
<< "\n";
@@ -60,7 +63,7 @@
server_->remote_track_visitor(),
moqt::VersionSpecificParameters());
server_->AddUser(*track_name_);
- return std::nullopt;
+ std::move(callback)(std::nullopt);
}
void ChatServer::ChatServerSessionHandler::OnOutgoingPublishNamespaceReply(
diff --git a/quiche/quic/moqt/tools/chat_server.h b/quiche/quic/moqt/tools/chat_server.h
index a9d5e04..69b2e2c 100644
--- a/quiche/quic/moqt/tools/chat_server.h
+++ b/quiche/quic/moqt/tools/chat_server.h
@@ -24,6 +24,7 @@
#include "quiche/quic/moqt/moqt_object.h"
#include "quiche/quic/moqt/moqt_publisher.h"
#include "quiche/quic/moqt/moqt_session.h"
+#include "quiche/quic/moqt/moqt_session_callbacks.h"
#include "quiche/quic/moqt/moqt_session_interface.h"
#include "quiche/quic/moqt/tools/moqt_server.h"
@@ -92,9 +93,10 @@
private:
// Callback for incoming publish_namespaces.
- std::optional<MoqtPublishNamespaceErrorReason> OnIncomingPublishNamespace(
- const moqt::TrackNamespace& track_namespace,
- std::optional<VersionSpecificParameters> parameters);
+ void OnIncomingPublishNamespace(
+ const TrackNamespace& track_namespace,
+ std::optional<VersionSpecificParameters> parameters,
+ MoqtResponseCallback callback);
void OnOutgoingPublishNamespaceReply(
TrackNamespace track_namespace,
std::optional<MoqtPublishNamespaceErrorReason> error_message);
diff --git a/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc b/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc
index 42dc43f..2ef3773 100644
--- a/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc
+++ b/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc
@@ -32,6 +32,7 @@
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_object.h"
#include "quiche/quic/moqt/moqt_session.h"
+#include "quiche/quic/moqt/moqt_session_callbacks.h"
#include "quiche/quic/moqt/moqt_session_interface.h"
#include "quiche/quic/moqt/tools/moqt_server.h"
#include "quiche/quic/platform/api/quic_socket_address.h"
@@ -118,9 +119,9 @@
// TODO(martinduke): Handle when |publish_namespace| is false
// (PUBLISH_NAMESPACE_DONE).
- std::optional<MoqtPublishNamespaceErrorReason> OnPublishNamespaceReceived(
- TrackNamespace track_namespace,
- std::optional<VersionSpecificParameters> /*parameters*/) {
+ void OnPublishNamespaceReceived(TrackNamespace track_namespace,
+ std::optional<VersionSpecificParameters>,
+ MoqtResponseCallback callback) {
if (!IsValidTrackNamespace(track_namespace) &&
!quiche::GetQuicheCommandLineFlag(
FLAGS_allow_invalid_track_namespaces)) {
@@ -128,9 +129,10 @@
<< "Rejected remote publish_namespace as it contained "
"disallowed characters; namespace: "
<< track_namespace;
- return MoqtPublishNamespaceErrorReason{
+ std::move(callback)(MoqtPublishNamespaceErrorReason{
RequestErrorCode::kInternalError,
- "Track namespace contains disallowed characters"};
+ "Track namespace contains disallowed characters"});
+ return;
}
std::string directory_name = absl::StrCat(
@@ -141,16 +143,18 @@
track_namespace, NamespaceHandler(directory_path));
if (!added) {
// Received before; should be handled by already existing subscriptions.
- return std::nullopt;
+ std::move(callback)(std::nullopt);
+ return;
}
if (absl::Status status = MakeDirectory(directory_path); !status.ok()) {
subscribed_namespaces_.erase(it);
QUICHE_LOG(ERROR) << "Failed to create directory " << directory_path
<< "; " << status;
- return MoqtPublishNamespaceErrorReason{
- RequestErrorCode::kInternalError,
- "Failed to create output directory"};
+ std::move(callback)(
+ MoqtPublishNamespaceErrorReason{RequestErrorCode::kInternalError,
+ "Failed to create output directory"});
+ return;
}
std::string track_list = quiche::GetQuicheCommandLineFlag(FLAGS_tracks);
@@ -161,8 +165,7 @@
session_->RelativeJoiningFetch(full_track_name, &it->second, 0,
VersionSpecificParameters());
}
-
- return std::nullopt;
+ std::move(callback)(std::nullopt);
}
private:
diff --git a/quiche/quic/moqt/tools/moqt_relay.cc b/quiche/quic/moqt/tools/moqt_relay.cc
index 993c46d..96e345d 100644
--- a/quiche/quic/moqt/tools/moqt_relay.cc
+++ b/quiche/quic/moqt/tools/moqt_relay.cc
@@ -6,6 +6,7 @@
#include <cstdint>
#include <memory>
+#include <optional>
#include <string>
#include <utility>
@@ -15,8 +16,10 @@
#include "quiche/quic/core/crypto/proof_verifier.h"
#include "quiche/quic/core/io/quic_event_loop.h"
#include "quiche/quic/core/quic_server_id.h"
+#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_session.h"
#include "quiche/quic/moqt/moqt_session_callbacks.h"
+#include "quiche/quic/moqt/moqt_session_interface.h"
#include "quiche/quic/moqt/tools/moqt_client.h"
#include "quiche/quic/moqt/tools/moqt_server.h"
#include "quiche/quic/platform/api/quic_default_proof_providers.h"
@@ -32,16 +35,15 @@
MoqtRelay::MoqtRelay(std::unique_ptr<quic::ProofSource> proof_source,
std::string bind_address, uint16_t bind_port,
absl::string_view default_upstream,
- bool ignore_certificate, bool broadcast_mode)
+ bool ignore_certificate)
: MoqtRelay(std::move(proof_source), bind_address, bind_port,
- default_upstream, ignore_certificate, broadcast_mode, nullptr) {
-}
+ default_upstream, ignore_certificate, nullptr) {}
// protected members.
MoqtRelay::MoqtRelay(std::unique_ptr<quic::ProofSource> proof_source,
std::string bind_address, uint16_t bind_port,
absl::string_view default_upstream,
- bool ignore_certificate, bool broadcast_mode,
+ bool ignore_certificate,
quic::QuicEventLoop* client_event_loop)
: ignore_certificate_(ignore_certificate),
client_event_loop_(client_event_loop),
@@ -51,8 +53,7 @@
[this](absl::string_view path) {
return IncomingSessionHandler(
path);
- })),
- publisher_(broadcast_mode) {
+ })) {
quiche::QuicheIpAddress bind_ip_address;
QUICHE_CHECK(bind_ip_address.FromString(bind_address));
// CreateUDPSocketAndListen() creates the event loop that we will pass to
@@ -91,8 +92,10 @@
MoqtSessionCallbacks MoqtRelay::CreateClientCallbacks() {
MoqtSessionCallbacks callbacks;
callbacks.session_established_callback = [this]() {
- default_upstream_client_->session()->set_publisher(&publisher_);
- publisher_.SetDefaultUpstreamSession(default_upstream_client_->session());
+ MoqtSession* session = default_upstream_client_->session();
+ session->set_publisher(&publisher_);
+ publisher_.SetDefaultUpstreamSession(session);
+ SetPublishNamespaceCallback(session);
};
callbacks.goaway_received_callback = [](absl::string_view new_session_uri) {
QUICHE_LOG(INFO) << "GoAway received, new session uri = "
@@ -103,13 +106,28 @@
return callbacks;
}
+void MoqtRelay::SetPublishNamespaceCallback(MoqtSessionInterface* session) {
+ session->callbacks().incoming_publish_namespace_callback =
+ [this, session](
+ const TrackNamespace& track_namespace,
+ const std::optional<VersionSpecificParameters>& parameters,
+ MoqtResponseCallback callback) {
+ if (parameters.has_value()) {
+ return publisher_.OnPublishNamespace(track_namespace, *parameters,
+ session, std::move(callback));
+ } else {
+ return publisher_.OnPublishNamespaceDone(track_namespace, session);
+ }
+ };
+}
+
absl::StatusOr<MoqtConfigureSessionCallback> MoqtRelay::IncomingSessionHandler(
absl::string_view /*path*/) {
return [this](MoqtSession* session) {
- session->set_publisher(&publisher_);
session->callbacks().session_established_callback = [this, session]() {
- publisher_.AddNamespaceCallbacks(session);
+ session->set_publisher(&publisher_);
};
+ SetPublishNamespaceCallback(session);
};
}
diff --git a/quiche/quic/moqt/tools/moqt_relay.h b/quiche/quic/moqt/tools/moqt_relay.h
index 3398685..f653b3d 100644
--- a/quiche/quic/moqt/tools/moqt_relay.h
+++ b/quiche/quic/moqt/tools/moqt_relay.h
@@ -15,6 +15,7 @@
#include "quiche/quic/core/io/quic_event_loop.h"
#include "quiche/quic/moqt/moqt_relay_publisher.h"
#include "quiche/quic/moqt/moqt_session_callbacks.h"
+#include "quiche/quic/moqt/moqt_session_interface.h"
#include "quiche/quic/moqt/tools/moqt_client.h"
#include "quiche/quic/moqt/tools/moqt_server.h"
#include "quiche/quic/tools/quic_url.h"
@@ -34,8 +35,8 @@
// If |default_upstream| is empty, no default upstream session is created.
MoqtRelay(std::unique_ptr<quic::ProofSource> proof_source,
std::string bind_address, uint16_t bind_port,
- absl::string_view default_upstream, bool ignore_certificate,
- bool broadcast_mode);
+ absl::string_view default_upstream, bool ignore_certificate);
+ virtual ~MoqtRelay() = default;
void HandleEventsForever() { server_->quic_server().HandleEventsForever(); }
@@ -47,12 +48,14 @@
MoqtRelay(std::unique_ptr<quic::ProofSource> proof_source,
std::string bind_address, uint16_t bind_port,
absl::string_view default_upstream, bool ignore_certificate,
- bool broadcast_mode, quic::QuicEventLoop* client_event_loop);
+ quic::QuicEventLoop* client_event_loop);
// Other functions for MoqtTestRelay.
MoqtServer* server() { return server_.get(); }
MoqtClient* client() { return default_upstream_client_.get(); }
MoqtRelayPublisher* publisher() { return &publisher_; }
+ virtual void SetPublishNamespaceCallback(MoqtSessionInterface* session);
+
private:
std::unique_ptr<moqt::MoqtClient> CreateClient(
quic::QuicUrl url, bool ignore_certificate,
diff --git a/quiche/quic/moqt/tools/moqt_relay_bin.cc b/quiche/quic/moqt/tools/moqt_relay_bin.cc
index e281370..241a74d 100644
--- a/quiche/quic/moqt/tools/moqt_relay_bin.cc
+++ b/quiche/quic/moqt/tools/moqt_relay_bin.cc
@@ -29,11 +29,6 @@
"If set, connect to the upstream URL and forward all requests there if "
"there is no explicitly advertised source.");
-DEFINE_QUICHE_COMMAND_LINE_FLAG(
- bool, broadcast_mode, false,
- "If set, PUBLISH_NAMESPACE messages will be forwarded to all sessions, "
- "whether or not they are subscribed.");
-
// A pure MoQT relay. Accepts connections. Will try to route requests from a
// session to a different appropriate upstream session. If the namespace for the
// request has not been advertised, it will reject the request. If
@@ -52,8 +47,7 @@
quiche::GetQuicheCommandLineFlag(FLAGS_bind_address),
quiche::GetQuicheCommandLineFlag(FLAGS_port),
quiche::GetQuicheCommandLineFlag(FLAGS_default_upstream),
- quiche::GetQuicheCommandLineFlag(FLAGS_disable_certificate_verification),
- quiche::GetQuicheCommandLineFlag(FLAGS_broadcast_mode));
+ quiche::GetQuicheCommandLineFlag(FLAGS_disable_certificate_verification));
relay.HandleEventsForever();
return 0;
}
diff --git a/quiche/quic/moqt/tools/moqt_relay_test.cc b/quiche/quic/moqt/tools/moqt_relay_test.cc
index 8e23ebf..45ab00e 100644
--- a/quiche/quic/moqt/tools/moqt_relay_test.cc
+++ b/quiche/quic/moqt/tools/moqt_relay_test.cc
@@ -5,14 +5,20 @@
#include "quiche/quic/moqt/tools/moqt_relay.h"
#include <cstdint>
+#include <memory>
+#include <optional>
#include <string>
#include <utility>
#include "absl/strings/string_view.h"
#include "quiche/quic/core/io/quic_event_loop.h"
#include "quiche/quic/core/quic_time.h"
+#include "quiche/quic/moqt/moqt_messages.h"
+#include "quiche/quic/moqt/moqt_publisher.h"
#include "quiche/quic/moqt/moqt_relay_publisher.h"
#include "quiche/quic/moqt/moqt_session.h"
+#include "quiche/quic/moqt/moqt_session_interface.h"
+#include "quiche/quic/moqt/test_tools/moqt_mock_visitor.h"
#include "quiche/quic/moqt/tools/moqt_client.h"
#include "quiche/quic/moqt/tools/moqt_server.h"
#include "quiche/quic/test_tools/crypto_test_utils.h"
@@ -29,10 +35,10 @@
public:
TestMoqtRelay(std::string bind_address, uint16_t bind_port,
absl::string_view default_upstream, bool ignore_certificate,
- bool promiscuous_mode, quic::QuicEventLoop* event_loop)
+ quic::QuicEventLoop* event_loop)
: MoqtRelay(quic::test::crypto_test_utils::ProofSourceForTesting(),
bind_address, bind_port, default_upstream, ignore_certificate,
- promiscuous_mode, event_loop) {}
+ event_loop) {}
quic::QuicEventLoop* server_event_loop() {
return server()->quic_server().event_loop();
@@ -47,15 +53,23 @@
}
MoqtRelayPublisher* publisher() { return MoqtRelay::publisher(); }
+
+ virtual void SetPublishNamespaceCallback(
+ MoqtSessionInterface* session) override {
+ last_server_session = session;
+ MoqtRelay::SetPublishNamespaceCallback(session);
+ }
+
+ MoqtSessionInterface* last_server_session;
};
class MoqtRelayTest : public quiche::test::QuicheTest {
public:
MoqtRelayTest()
- : upstream_("127.0.0.1", 9991, "", true, false, nullptr), // no client.
- relay_("127.0.0.1", 9992, "https://127.0.0.1:9991", true, false,
+ : upstream_("127.0.0.1", 9991, "", true, nullptr), // no client.
+ relay_("127.0.0.1", 9992, "https://127.0.0.1:9991", true,
upstream_.server_event_loop()),
- downstream_("127.0.0.1", 9993, "https://127.0.0.1:9992", true, false,
+ downstream_("127.0.0.1", 9993, "https://127.0.0.1:9992", true,
relay_.server_event_loop()) {
RunUntilConnected(relay_, upstream_);
RunUntilConnected(downstream_, relay_);
@@ -106,6 +120,34 @@
EXPECT_FALSE(relay_.publisher()->GetDefaultUpstreamSession().IsValid());
}
+TEST_F(MoqtRelayTest, PublishNamespace) {
+ MockMoqtObjectListener object_listener;
+ // No path to a subscribe. Test the upstream_ publisher because it doesn't
+ // have a default upstream.
+ EXPECT_EQ(upstream_.publisher()->GetTrack(FullTrackName("foo", "bar")),
+ nullptr);
+ // relay_ publishes a namespace, so upstream_ will route to relay_.
+ relay_.client_session()->PublishNamespace(
+ TrackNamespace({"foo"}),
+ [](TrackNamespace, std::optional<MoqtPublishNamespaceErrorReason>) {},
+ VersionSpecificParameters());
+ upstream_.RunOneEvent();
+ // There is now an upstream session for "Foo".
+ std::shared_ptr<MoqtTrackPublisher> track =
+ upstream_.publisher()->GetTrack(FullTrackName("foo", "bar"));
+ EXPECT_NE(track, nullptr);
+ track->AddObjectListener(&object_listener);
+ track->RemoveObjectListener(&object_listener);
+ // Track should have been destroyed.
+
+ // Send PUBLISH_NAMESPACE_DONE
+ relay_.client_session()->PublishNamespaceDone(TrackNamespace({"foo"}));
+ upstream_.RunOneEvent();
+ // Now there's nowhere to route for "foo".
+ EXPECT_EQ(upstream_.publisher()->GetTrack(FullTrackName("foo", "bar")),
+ nullptr);
+}
+
#if 0 // TODO(martinduke): Re-enable these tests when GOAWAY support exists.
TEST_F(MoqtRelayTest, GoAwayAtClient) {
ASSERT_NE(relay_.client_session(), nullptr);