Refactor MoQT Upstream SUBSCRIBE processing and data structures.
Streamline and simplify in preparation for also sending FETCH.
- Move subscribe-specific properties to SubscribeRemoteTrack, derived from RemoteTrack.
- Store subscribes in a map keyed by track_alias_, plus a separate map of pointers to those records keyed by subscribe_id_. Pointers to FetchTask will also be stored in the latter.
- Eliminate "active subscribes" and keep that state in SubscribeRemoteTrack instead, since all subscribes are now indexed by ID.
- Expose the stream type from the parser so the session can verify that it is correct for the subscription.
- Based on @vasilvv's suggestion, deleted a misfiring check in quiche_weak_ptr.h.
- Added support for upstream UNSUBSCRIBE, fixing a problem where ChatClient did not clear its subscription state.
Not in production.
PiperOrigin-RevId: 700756286
diff --git a/quiche/common/quiche_weak_ptr.h b/quiche/common/quiche_weak_ptr.h
index c0d7b49..29dc6d6 100644
--- a/quiche/common/quiche_weak_ptr.h
+++ b/quiche/common/quiche_weak_ptr.h
@@ -93,17 +93,7 @@
class QUICHE_NO_EXPORT QuicheWeakPtrFactory final {
public:
+ // NOTE: there is intentionally no debug check here that the factory is a
+ // member of *object. The previous check bounded the object's extent with
+ // sizeof(object), i.e. the size of the pointer rather than sizeof(T), so it
+ // fired for any factory placed beyond the first pointer-width bytes of T
+ // (for example as the last member, which is where factories belong).
+ // TODO(review): consider restoring the check using sizeof(T) if desired.
explicit QuicheWeakPtrFactory(absl::Nonnull<T*> object)
- : control_block_(std::make_shared<ControlBlock>(object)) {
- // Chromium uses a Clang plugin to ensure that WeakPtrFactory objects are
- // always last; QUICHE does not have infrastructure for that, but we can do
- // some basic checks to prevent the API misuse.
- const uintptr_t factory_address = reinterpret_cast<uintptr_t>(this);
- const uintptr_t object_address_start = reinterpret_cast<uintptr_t>(object);
- const uintptr_t object_address_end = object_address_start + sizeof(object);
- QUICHE_DCHECK(factory_address >= object_address_start &&
- factory_address <= object_address_end)
- << "WeakPtrFactory<T> must be a member of T";
- }
+ : control_block_(std::make_shared<ControlBlock>(object)) {}
~QuicheWeakPtrFactory() { control_block_->Clear(); }
QuicheWeakPtrFactory(const QuicheWeakPtrFactory&) = delete;
diff --git a/quiche/common/quiche_weak_ptr_test.cc b/quiche/common/quiche_weak_ptr_test.cc
index 73e9c2d..f07d707 100644
--- a/quiche/common/quiche_weak_ptr_test.cc
+++ b/quiche/common/quiche_weak_ptr_test.cc
@@ -4,7 +4,6 @@
#include "quiche/common/quiche_weak_ptr.h"
-#include <memory>
#include <utility>
#include "quiche/common/platform/api/quiche_test.h"
@@ -63,12 +62,5 @@
EXPECT_FALSE(ptr.IsValid());
}
-TEST(QuicheWeakPtrTest, OutOfClassFactory) {
- bool object = false;
- EXPECT_QUICHE_DEBUG_DEATH(
- std::make_unique<QuicheWeakPtrFactory<bool>>(&object),
- "must be a member of T");
-}
-
} // namespace
} // namespace quiche
diff --git a/quiche/quic/moqt/moqt_integration_test.cc b/quiche/quic/moqt/moqt_integration_test.cc
index c7630a5..dc1820b 100644
--- a/quiche/quic/moqt/moqt_integration_test.cc
+++ b/quiche/quic/moqt/moqt_integration_test.cc
@@ -149,7 +149,7 @@
EXPECT_CALL(server_callbacks_.incoming_announce_callback,
Call(FullTrackName{"foo"}))
.WillOnce(Return(std::nullopt));
- MockRemoteTrackVisitor server_visitor;
+ MockSubscribeRemoteTrackVisitor server_visitor;
testing::MockFunction<void(
FullTrackName track_namespace,
std::optional<MoqtAnnounceErrorReason> error_message)>
@@ -166,7 +166,7 @@
EXPECT_FALSE(error.has_value());
server_->session()->SubscribeCurrentGroup(track_name, &server_visitor);
});
- EXPECT_CALL(server_visitor, OnReply(_, _)).WillOnce([&]() {
+ EXPECT_CALL(server_visitor, OnReply(_, _, _)).WillOnce([&]() {
matches = true;
});
bool success =
@@ -179,7 +179,7 @@
// Set up the server to subscribe to "data" track for the namespace announce
// it receives.
- MockRemoteTrackVisitor server_visitor;
+ MockSubscribeRemoteTrackVisitor server_visitor;
EXPECT_CALL(server_callbacks_.incoming_announce_callback, Call(_))
.WillOnce([&](FullTrackName track_namespace) {
FullTrackName track_name = track_namespace;
@@ -196,7 +196,7 @@
client_->session()->set_publisher(&known_track_publisher);
queue->AddObject(MemSliceFromString("object data"), /*key=*/true);
bool received_subscribe_ok = false;
- EXPECT_CALL(server_visitor, OnReply(_, _)).WillOnce([&]() {
+ EXPECT_CALL(server_visitor, OnReply(_, _, _)).WillOnce([&]() {
received_subscribe_ok = true;
});
client_->session()->Announce(
@@ -232,7 +232,7 @@
{MoqtForwardingPreference::kSubgroup,
MoqtForwardingPreference::kDatagram}) {
SCOPED_TRACE(MoqtForwardingPreferenceToString(forwarding_preference));
- MockRemoteTrackVisitor client_visitor;
+ MockSubscribeRemoteTrackVisitor client_visitor;
std::string name =
absl::StrCat("pref_", static_cast<int>(forwarding_preference));
auto queue = std::make_shared<MoqtOutgoingQueue>(
@@ -295,7 +295,7 @@
{MoqtForwardingPreference::kSubgroup,
MoqtForwardingPreference::kDatagram}) {
SCOPED_TRACE(MoqtForwardingPreferenceToString(forwarding_preference));
- MockRemoteTrackVisitor client_visitor;
+ MockSubscribeRemoteTrackVisitor client_visitor;
std::string name =
absl::StrCat("pref_", static_cast<int>(forwarding_preference));
auto queue = std::make_shared<MoqtOutgoingQueue>(
@@ -379,10 +379,10 @@
auto track_publisher = std::make_shared<MockTrackPublisher>(full_track_name);
publisher.Add(track_publisher);
- MockRemoteTrackVisitor client_visitor;
+ MockSubscribeRemoteTrackVisitor client_visitor;
std::optional<absl::string_view> expected_reason = std::nullopt;
bool received_ok = false;
- EXPECT_CALL(client_visitor, OnReply(full_track_name, expected_reason))
+ EXPECT_CALL(client_visitor, OnReply(full_track_name, _, expected_reason))
.WillOnce([&]() { received_ok = true; });
client_->session()->SubscribeAbsolute(full_track_name, 0, 0, &client_visitor);
bool success =
@@ -399,10 +399,10 @@
auto track_publisher = std::make_shared<MockTrackPublisher>(full_track_name);
publisher.Add(track_publisher);
- MockRemoteTrackVisitor client_visitor;
+ MockSubscribeRemoteTrackVisitor client_visitor;
std::optional<absl::string_view> expected_reason = std::nullopt;
bool received_ok = false;
- EXPECT_CALL(client_visitor, OnReply(full_track_name, expected_reason))
+ EXPECT_CALL(client_visitor, OnReply(full_track_name, _, expected_reason))
.WillOnce([&]() { received_ok = true; });
client_->session()->SubscribeCurrentObject(full_track_name, &client_visitor);
bool success =
@@ -419,10 +419,10 @@
auto track_publisher = std::make_shared<MockTrackPublisher>(full_track_name);
publisher.Add(track_publisher);
- MockRemoteTrackVisitor client_visitor;
+ MockSubscribeRemoteTrackVisitor client_visitor;
std::optional<absl::string_view> expected_reason = std::nullopt;
bool received_ok = false;
- EXPECT_CALL(client_visitor, OnReply(full_track_name, expected_reason))
+ EXPECT_CALL(client_visitor, OnReply(full_track_name, _, expected_reason))
.WillOnce([&]() { received_ok = true; });
client_->session()->SubscribeCurrentGroup(full_track_name, &client_visitor);
bool success =
@@ -433,10 +433,10 @@
TEST_F(MoqtIntegrationTest, SubscribeError) {
EstablishSession();
FullTrackName full_track_name("foo", "bar");
- MockRemoteTrackVisitor client_visitor;
+ MockSubscribeRemoteTrackVisitor client_visitor;
std::optional<absl::string_view> expected_reason = "No tracks published";
bool received_ok = false;
- EXPECT_CALL(client_visitor, OnReply(full_track_name, expected_reason))
+ EXPECT_CALL(client_visitor, OnReply(full_track_name, _, expected_reason))
.WillOnce([&]() { received_ok = true; });
client_->session()->SubscribeCurrentObject(full_track_name, &client_visitor);
bool success =
@@ -452,7 +452,7 @@
ConnectEndpoints();
FullTrackName full_track_name("foo", "bar");
- MockRemoteTrackVisitor client_visitor;
+ MockSubscribeRemoteTrackVisitor client_visitor;
MoqtKnownTrackPublisher publisher;
server_->session()->set_publisher(&publisher);
@@ -468,8 +468,9 @@
.WillOnce([&](MoqtObjectAckFunction new_ack_function) {
ack_function = std::move(new_ack_function);
});
- EXPECT_CALL(client_visitor, OnReply(_, _))
- .WillOnce([&](const FullTrackName&, std::optional<absl::string_view>) {
+ EXPECT_CALL(client_visitor, OnReply(_, _, _))
+ .WillOnce([&](const FullTrackName&, std::optional<FullSequence>,
+ std::optional<absl::string_view>) {
ack_function(10, 20, quic::QuicTimeDelta::FromMicroseconds(-123));
ack_function(100, 200, quic::QuicTimeDelta::FromMicroseconds(456));
});
diff --git a/quiche/quic/moqt/moqt_parser.h b/quiche/quic/moqt/moqt_parser.h
index 58795c4..34953e4 100644
--- a/quiche/quic/moqt/moqt_parser.h
+++ b/quiche/quic/moqt/moqt_parser.h
@@ -20,6 +20,10 @@
namespace moqt {
+namespace test {
+class MoqtDataParserPeer;
+}
+
class QUICHE_EXPORT MoqtControlParserVisitor {
public:
virtual ~MoqtControlParserVisitor() = default;
@@ -195,7 +199,11 @@
// be used for testing.
void set_chunk_size(size_t size) { chunk_size_ = size; }
+ std::optional<MoqtDataStreamType> stream_type() const { return type_; }
+
private:
+ friend class test::MoqtDataParserPeer;
+
// If there is buffered data from the previous attempt at parsing it, new data
// will be added in `chunk_size_`-sized chunks.
constexpr static size_t kDefaultChunkSize = 64;
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index 0bf32bf..10b5939 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -198,11 +198,27 @@
<< message.object_id << " priority "
<< message.publisher_priority << " length "
<< payload.size();
- auto [full_track_name, visitor] =
- TrackPropertiesFromAlias(message, MoqtForwardingPreference::kDatagram);
+ SubscribeRemoteTrack* track = RemoteTrackByAlias(message.track_alias);
+ if (track == nullptr) {
+ return;
+ }
+ if (!track->CheckDataStreamType(MoqtDataStreamType::kObjectDatagram)) {
+ Error(MoqtError::kProtocolViolation,
+ "Received DATAGRAM for non-datagram track");
+ return;
+ }
+ if (!track->InWindow(FullSequence(message.group_id, message.object_id))) {
+ // TODO(martinduke): a recent SUBSCRIBE_UPDATE could put us here, and it's
+ // not an error.
+ return;
+ }
+ QUICHE_CHECK(!track->is_fetch());
+ track->OnObjectOrOk();
+ SubscribeRemoteTrack::Visitor* visitor = track->visitor();
if (visitor != nullptr) {
visitor->OnObjectFragment(
- full_track_name, FullSequence{message.group_id, 0, message.object_id},
+ track->full_track_name(),
+ FullSequence{message.group_id, 0, message.object_id},
message.publisher_priority, message.object_status, payload, true);
}
}
@@ -248,7 +264,7 @@
bool MoqtSession::SubscribeAbsolute(const FullTrackName& name,
uint64_t start_group, uint64_t start_object,
- RemoteTrack::Visitor* visitor,
+ SubscribeRemoteTrack::Visitor* visitor,
MoqtSubscribeParameters parameters) {
MoqtSubscribe message;
message.full_track_name = name;
@@ -265,7 +281,7 @@
bool MoqtSession::SubscribeAbsolute(const FullTrackName& name,
uint64_t start_group, uint64_t start_object,
uint64_t end_group,
- RemoteTrack::Visitor* visitor,
+ SubscribeRemoteTrack::Visitor* visitor,
MoqtSubscribeParameters parameters) {
if (end_group < start_group) {
QUIC_DLOG(ERROR) << "Subscription end is before beginning";
@@ -286,7 +302,7 @@
bool MoqtSession::SubscribeAbsolute(const FullTrackName& name,
uint64_t start_group, uint64_t start_object,
uint64_t end_group, uint64_t end_object,
- RemoteTrack::Visitor* visitor,
+ SubscribeRemoteTrack::Visitor* visitor,
MoqtSubscribeParameters parameters) {
if (end_group < start_group) {
QUIC_DLOG(ERROR) << "Subscription end is before beginning";
@@ -309,7 +325,7 @@
}
bool MoqtSession::SubscribeCurrentObject(const FullTrackName& name,
- RemoteTrack::Visitor* visitor,
+ SubscribeRemoteTrack::Visitor* visitor,
MoqtSubscribeParameters parameters) {
MoqtSubscribe message;
message.full_track_name = name;
@@ -324,7 +340,7 @@
}
bool MoqtSession::SubscribeCurrentGroup(const FullTrackName& name,
- RemoteTrack::Visitor* visitor,
+ SubscribeRemoteTrack::Visitor* visitor,
MoqtSubscribeParameters parameters) {
MoqtSubscribe message;
message.full_track_name = name;
@@ -339,6 +355,21 @@
return Subscribe(message, visitor);
}
+// Tells the peer to stop the subscription to |name| and forgets it locally.
+// Silently ignores names that have no upstream subscription.
+void MoqtSession::Unsubscribe(const FullTrackName& name) {
+ auto* subscribe = static_cast<SubscribeRemoteTrack*>(RemoteTrackByName(name));
+ if (subscribe == nullptr) {
+ return;
+ }
+ const uint64_t subscribe_id = subscribe->subscribe_id();
+ const uint64_t track_alias = subscribe->track_alias();
+ MoqtUnsubscribe message;
+ message.subscribe_id = subscribe_id;
+ SendControlMessage(framer_.SerializeUnsubscribe(message));
+ // Remove the non-owning indexes first: erasing from subscribe_by_alias_
+ // destroys the track, and |name| may refer to the track's own name.
+ upstream_by_name_.erase(name);
+ upstream_by_id_.erase(subscribe_id);
+ subscribe_by_alias_.erase(track_alias);
+}
+
void MoqtSession::PublishedFetch::FetchStreamVisitor::OnCanWrite() {
std::shared_ptr<PublishedFetch> fetch = fetch_.lock();
if (fetch == nullptr) {
@@ -411,7 +442,8 @@
}
bool MoqtSession::Subscribe(MoqtSubscribe& message,
- RemoteTrack::Visitor* visitor) {
+ SubscribeRemoteTrack::Visitor* visitor,
+ std::optional<uint64_t> provided_track_alias) {
if (peer_role_ == MoqtRole::kSubscriber) {
QUIC_DLOG(INFO) << ENDPOINT << "Tried to send SUBSCRIBE to subscriber peer";
return false;
@@ -424,16 +456,20 @@
<< peer_max_subscribe_id_;
return false;
}
- message.subscribe_id = next_subscribe_id_++;
- auto it = remote_track_aliases_.find(message.full_track_name);
- if (it != remote_track_aliases_.end()) {
- message.track_alias = it->second;
- if (message.track_alias >= next_remote_track_alias_) {
- next_remote_track_alias_ = message.track_alias + 1;
- }
- } else {
- message.track_alias = next_remote_track_alias_++;
+ if (upstream_by_name_.contains(message.full_track_name)) {
+ QUIC_DLOG(INFO) << ENDPOINT << "Tried to send SUBSCRIBE for track "
+ << message.full_track_name
+ << " which is already subscribed";
+ return false;
}
+ if (provided_track_alias.has_value() &&
+ subscribe_by_alias_.contains(*provided_track_alias)) {
+ Error(MoqtError::kProtocolViolation, "Provided track alias already in use");
+ return false;
+ }
+ message.subscribe_id = next_subscribe_id_++;
+ message.track_alias =
+ provided_track_alias.value_or(next_remote_track_alias_++);
if (SupportsObjectAck() && visitor != nullptr) {
// Since we do not expose subscribe IDs directly in the API, instead wrap
// the session and subscribe ID in a callback.
@@ -448,7 +484,10 @@
SendControlMessage(framer_.SerializeSubscribe(message));
QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE message for "
<< message.full_track_name;
- active_subscribes_.try_emplace(message.subscribe_id, message, visitor);
+ auto track = std::make_unique<SubscribeRemoteTrack>(message, visitor);
+ upstream_by_name_.emplace(message.full_track_name, track.get());
+ upstream_by_id_.emplace(message.subscribe_id, track.get());
+ subscribe_by_alias_.emplace(message.track_alias, std::move(track));
return true;
}
@@ -503,6 +542,30 @@
return true;
}
+SubscribeRemoteTrack* MoqtSession::RemoteTrackByAlias(uint64_t track_alias) {
+ // Looks up the owning map; nullptr when no subscription uses |track_alias|.
+ const auto found = subscribe_by_alias_.find(track_alias);
+ return found == subscribe_by_alias_.end() ? nullptr : found->second.get();
+}
+
+RemoteTrack* MoqtSession::RemoteTrackById(uint64_t subscribe_id) {
+ // Non-owning index lookup; nullptr when |subscribe_id| is unknown.
+ const auto found = upstream_by_id_.find(subscribe_id);
+ return found == upstream_by_id_.end() ? nullptr : found->second;
+}
+
+RemoteTrack* MoqtSession::RemoteTrackByName(const FullTrackName& name) {
+ // Non-owning index lookup; nullptr when |name| is not subscribed.
+ const auto found = upstream_by_name_.find(name);
+ return found == upstream_by_name_.end() ? nullptr : found->second;
+}
+
void MoqtSession::OnCanCreateNewOutgoingUnidirectionalStream() {
while (!subscribes_with_queued_outgoing_data_streams_.empty() &&
session_->CanOpenNextOutgoingUnidirectionalStream()) {
@@ -555,58 +618,6 @@
SendControlMessage(framer_.SerializeMaxSubscribeId(message));
}
-std::pair<FullTrackName, RemoteTrack::Visitor*>
-MoqtSession::TrackPropertiesFromAlias(
- const MoqtObject& message,
- std::optional<MoqtForwardingPreference> forwarding_preference) {
- auto it = remote_tracks_.find(message.track_alias);
- if (it == remote_tracks_.end()) {
- ActiveSubscribe* subscribe = nullptr;
- // SUBSCRIBE_OK has not arrived yet, but deliver the object. Indexing
- // active_subscribes_ by track alias would make this faster if the
- // subscriber has tons of incomplete subscribes.
- for (auto& open_subscribe : active_subscribes_) {
- if (open_subscribe.second.message.track_alias == message.track_alias) {
- subscribe = &open_subscribe.second;
- break;
- }
- }
- if (subscribe == nullptr) {
- return std::pair<FullTrackName, RemoteTrack::Visitor*>(
- {FullTrackName{}, nullptr});
- }
- subscribe->received_object = true;
- if (forwarding_preference.has_value()) {
- if (subscribe->forwarding_preference.has_value()) {
- if (*forwarding_preference != *subscribe->forwarding_preference) {
- Error(MoqtError::kProtocolViolation,
- "Forwarding preference changes mid-track");
- return std::pair<FullTrackName, RemoteTrack::Visitor*>(
- {FullTrackName{}, nullptr});
- }
- } else {
- subscribe->forwarding_preference = *forwarding_preference;
- }
- } else {
- QUICHE_BUG(quic_subscribe_no_forwarding preference)
- << "Objects from a subscribe should know the forwarding preference";
- }
- return std::make_pair(subscribe->message.full_track_name,
- subscribe->visitor);
- }
- RemoteTrack& track = it->second;
- // Update the forwarding preference if it is present.
- if (forwarding_preference.has_value() &&
- !track.CheckForwardingPreference(*forwarding_preference)) {
- // Incorrect forwarding preference.
- Error(MoqtError::kProtocolViolation,
- "Forwarding preference changes mid-track");
- return std::pair<FullTrackName, RemoteTrack::Visitor*>(
- {FullTrackName{}, nullptr});
- }
- return std::make_pair(track.full_track_name(), track.visitor());
-}
-
bool MoqtSession::ValidateSubscribeId(uint64_t subscribe_id) {
if (peer_role_ == MoqtRole::kPublisher) {
QUIC_DLOG(INFO) << ENDPOINT << "Publisher peer sent SUBSCRIBE";
@@ -820,66 +831,75 @@
void MoqtSession::ControlStream::OnSubscribeOkMessage(
const MoqtSubscribeOk& message) {
- auto it = session_->active_subscribes_.find(message.subscribe_id);
- if (it == session_->active_subscribes_.end()) {
- session_->Error(MoqtError::kProtocolViolation,
- "Received SUBSCRIBE_OK for nonexistent subscribe");
+ RemoteTrack* track = session_->RemoteTrackById(message.subscribe_id);
+ if (track == nullptr) {
+ QUIC_DLOG(INFO) << ENDPOINT << "Received the SUBSCRIBE_OK for "
+ << "subscribe_id = " << message.subscribe_id
+ << " but no track exists";
+ // Subscription state might have been destroyed for internal reasons.
return;
}
- MoqtSubscribe& subscribe = it->second.message;
+ if (track->is_fetch()) {
+ session_->Error(MoqtError::kProtocolViolation,
+ "Received SUBSCRIBE_OK for a FETCH");
+ return;
+ }
QUIC_DLOG(INFO) << ENDPOINT << "Received the SUBSCRIBE_OK for "
<< "subscribe_id = " << message.subscribe_id << " "
- << subscribe.full_track_name;
- // Copy the Remote Track from session_->active_subscribes_ to
- // session_->remote_tracks_.
- RemoteTrack::Visitor* visitor = it->second.visitor;
- auto [track_iter, new_entry] = session_->remote_tracks_.try_emplace(
- subscribe.track_alias, subscribe.full_track_name, subscribe.track_alias,
- visitor);
- if (it->second.forwarding_preference.has_value()) {
- if (!track_iter->second.CheckForwardingPreference(
- *it->second.forwarding_preference)) {
- session_->Error(MoqtError::kProtocolViolation,
- "Forwarding preference different in early objects");
- return;
- }
+ << track->full_track_name();
+ SubscribeRemoteTrack* subscribe = static_cast<SubscribeRemoteTrack*>(track);
+ // Record the OK before notifying the application (once is enough).
+ subscribe->OnObjectOrOk();
+ // TODO(martinduke): Handle expires field.
+ // TODO(martinduke): Resize the window based on largest_id.
+ // OnReply() must be the last use of |subscribe|: the visitor may tear the
+ // subscription down (e.g. via Unsubscribe()) from inside the callback.
+ if (subscribe->visitor() != nullptr) {
+ subscribe->visitor()->OnReply(track->full_track_name(), message.largest_id,
+ std::nullopt);
}
- // TODO: handle expires.
- if (visitor != nullptr) {
- visitor->OnReply(subscribe.full_track_name, std::nullopt);
- }
- session_->active_subscribes_.erase(it);
}
void MoqtSession::ControlStream::OnSubscribeErrorMessage(
const MoqtSubscribeError& message) {
- auto it = session_->active_subscribes_.find(message.subscribe_id);
- if (it == session_->active_subscribes_.end()) {
- session_->Error(MoqtError::kProtocolViolation,
- "Received SUBSCRIBE_ERROR for nonexistent subscribe");
+ RemoteTrack* track = session_->RemoteTrackById(message.subscribe_id);
+ if (track == nullptr) {
+ QUIC_DLOG(INFO) << ENDPOINT << "Received the SUBSCRIBE_ERROR for "
+ << "subscribe_id = " << message.subscribe_id
+ << " but no track exists";
+ // Subscription state might have been destroyed for internal reasons.
return;
}
- if (it->second.received_object) {
+ if (track->is_fetch()) {
session_->Error(MoqtError::kProtocolViolation,
- "Received SUBSCRIBE_ERROR after object");
+ "Received SUBSCRIBE_ERROR for a FETCH");
return;
}
- MoqtSubscribe& subscribe = it->second.message;
+ if (!track->ErrorIsAllowed()) {
+ session_->Error(MoqtError::kProtocolViolation,
+ "Received SUBSCRIBE_ERROR after SUBSCRIBE_OK or objects");
+ return;
+ }
QUIC_DLOG(INFO) << ENDPOINT << "Received the SUBSCRIBE_ERROR for "
<< "subscribe_id = " << message.subscribe_id << " ("
- << subscribe.full_track_name << ")"
+ << track->full_track_name() << ")"
<< ", error = " << static_cast<int>(message.error_code)
<< " (" << message.reason_phrase << ")";
- RemoteTrack::Visitor* visitor = it->second.visitor;
+ SubscribeRemoteTrack* subscribe = static_cast<SubscribeRemoteTrack*>(track);
+ SubscribeRemoteTrack::Visitor* visitor = subscribe->visitor();
+ // Take ownership of the track out of subscribe_by_alias_ and drop every
+ // index entry before doing anything else. |owner| keeps |subscribe| alive
+ // until this function returns; freeing the alias and name up front lets the
+ // resubscribe below succeed even if the peer suggests the same alias, and
+ // keeps the maps consistent if the visitor re-enters the session.
+ std::unique_ptr<SubscribeRemoteTrack> owner;
+ auto owner_it = session_->subscribe_by_alias_.find(subscribe->track_alias());
+ if (owner_it != session_->subscribe_by_alias_.end()) {
+ owner = std::move(owner_it->second);
+ session_->subscribe_by_alias_.erase(owner_it);
+ }
+ session_->upstream_by_id_.erase(subscribe->subscribe_id());
+ session_->upstream_by_name_.erase(subscribe->full_track_name());
if (message.error_code == SubscribeErrorCode::kRetryTrackAlias) {
// Automatically resubscribe with new alias.
- session_->remote_track_aliases_[subscribe.full_track_name] =
- message.track_alias;
- session_->Subscribe(subscribe, visitor);
- } else if (visitor != nullptr) {
- visitor->OnReply(subscribe.full_track_name, message.reason_phrase);
+ // Work on a copy: Subscribe() overwrites the message's ID and alias.
+ MoqtSubscribe subscribe_message = subscribe->GetSubscribe();
+ if (!session_->Subscribe(subscribe_message, visitor, message.track_alias) &&
+ visitor != nullptr) {
+ // The retry was not sent, so this SUBSCRIBE_ERROR is final.
+ visitor->OnReply(subscribe->full_track_name(), std::nullopt,
+ message.reason_phrase);
+ }
+ } else if (visitor != nullptr) {
+ visitor->OnReply(subscribe->full_track_name(), std::nullopt,
+ message.reason_phrase);
}
- session_->active_subscribes_.erase(it);
}
void MoqtSession::ControlStream::OnUnsubscribeMessage(
@@ -1084,11 +1104,37 @@
payload = absl::string_view(partial_object_);
}
}
- auto [full_track_name, visitor] = session_->TrackPropertiesFromAlias(
- message, MoqtForwardingPreference::kSubgroup);
- if (visitor != nullptr) {
- visitor->OnObjectFragment(
- full_track_name,
+ QUICHE_BUG_IF(quic_bug_object_with_no_stream_type,
+ !parser_.stream_type().has_value())
+ << "Object delivered without a stream type";
+ // Get a pointer to the upstream state.
+ RemoteTrack* track = track_.GetIfAvaliable();
+ if (track == nullptr) {
+ track = (*parser_.stream_type() == MoqtDataStreamType::kStreamHeaderFetch)
+ // message.track_alias is actually a fetch ID for fetches.
+ ? session_->RemoteTrackById(message.track_alias)
+ : session_->RemoteTrackByAlias(message.track_alias);
+ if (track == nullptr) {
+ stream_->SendStopSending(kResetCodeSubscriptionGone);
+ // Received object for nonexistent track.
+ return;
+ }
+ track_ = track->weak_ptr();
+ }
+ if (!track->CheckDataStreamType(*parser_.stream_type())) {
+ session_->Error(MoqtError::kProtocolViolation,
+ "Received object for a track with a different stream type");
+ return;
+ }
+ if (!track->InWindow(FullSequence(message.group_id, message.object_id))) {
+ // This is not an error. It can be the result of a recent SUBSCRIBE_UPDATE.
+ return;
+ }
+ track->OnObjectOrOk();
+ SubscribeRemoteTrack* subscribe = static_cast<SubscribeRemoteTrack*>(track);
+ if (subscribe->visitor() != nullptr) {
+ subscribe->visitor()->OnObjectFragment(
+ track->full_track_name(),
FullSequence{message.group_id, message.subgroup_id.value_or(0),
message.object_id},
message.publisher_priority, message.object_status, payload,
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
index c4ae794..628f188 100644
--- a/quiche/quic/moqt/moqt_session.h
+++ b/quiche/quic/moqt/moqt_session.h
@@ -29,6 +29,7 @@
#include "quiche/common/platform/api/quiche_export.h"
#include "quiche/common/quiche_buffer_allocator.h"
#include "quiche/common/quiche_callbacks.h"
+#include "quiche/common/quiche_weak_ptr.h"
#include "quiche/web_transport/web_transport.h"
namespace moqt {
@@ -117,24 +118,28 @@
// Subscribe from (start_group, start_object) to the end of the track.
bool SubscribeAbsolute(
const FullTrackName& name, uint64_t start_group, uint64_t start_object,
- RemoteTrack::Visitor* visitor,
+ SubscribeRemoteTrack::Visitor* visitor,
MoqtSubscribeParameters parameters = MoqtSubscribeParameters());
// Subscribe from (start_group, start_object) to the end of end_group.
bool SubscribeAbsolute(
const FullTrackName& name, uint64_t start_group, uint64_t start_object,
- uint64_t end_group, RemoteTrack::Visitor* visitor,
+ uint64_t end_group, SubscribeRemoteTrack::Visitor* visitor,
MoqtSubscribeParameters parameters = MoqtSubscribeParameters());
// Subscribe from (start_group, start_object) to (end_group, end_object).
bool SubscribeAbsolute(
const FullTrackName& name, uint64_t start_group, uint64_t start_object,
- uint64_t end_group, uint64_t end_object, RemoteTrack::Visitor* visitor,
+ uint64_t end_group, uint64_t end_object,
+ SubscribeRemoteTrack::Visitor* visitor,
MoqtSubscribeParameters parameters = MoqtSubscribeParameters());
bool SubscribeCurrentObject(
- const FullTrackName& name, RemoteTrack::Visitor* visitor,
+ const FullTrackName& name, SubscribeRemoteTrack::Visitor* visitor,
MoqtSubscribeParameters parameters = MoqtSubscribeParameters());
bool SubscribeCurrentGroup(
- const FullTrackName& name, RemoteTrack::Visitor* visitor,
+ const FullTrackName& name, SubscribeRemoteTrack::Visitor* visitor,
MoqtSubscribeParameters parameters = MoqtSubscribeParameters());
+ // Sends UNSUBSCRIBE for the track named |name| and immediately destroys all
+ // local subscription state for it. Does nothing if there is no upstream
+ // subscription with that name.
+ void Unsubscribe(const FullTrackName& name);
webtransport::Session* session() { return session_; }
MoqtSessionCallbacks& callbacks() { return callbacks_; }
@@ -285,6 +290,9 @@
MoqtSession* session_;
webtransport::Stream* stream_;
+ // The upstream track this stream's objects belong to. It is looked up from
+ // the first object's track alias (or fetch ID) and cached here; as a weak
+ // pointer it becomes unavailable if the track is destroyed.
+ quiche::QuicheWeakPtr<RemoteTrack> track_;
MoqtDataParser parser_;
std::string partial_object_;
};
@@ -494,8 +502,10 @@
// is present.
void SendControlMessage(quiche::QuicheBuffer message);
- // Returns false if the SUBSCRIBE isn't sent.
- bool Subscribe(MoqtSubscribe& message, RemoteTrack::Visitor* visitor);
+ // Returns false if the SUBSCRIBE isn't sent. |provided_track_alias| has a
+ // value only if this call is due to a SUBSCRIBE_ERROR.
+ bool Subscribe(MoqtSubscribe& message, SubscribeRemoteTrack::Visitor* visitor,
+ std::optional<uint64_t> provided_track_alias = std::nullopt);
// Opens a new data stream, or queues it if the session is flow control
// blocked.
@@ -508,13 +518,9 @@
// Returns false if creation failed.
[[nodiscard]] bool OpenDataStream(std::shared_ptr<PublishedFetch> fetch);
- // Get FullTrackName and visitor for a subscribe_id and track_alias. Returns
- // an empty FullTrackName tuple and nullptr if not present. If the caller has
- // information about the track's forwarding preference, it can be passed via
- // |forwarding_preference| so that it can be stored in RemoteTrack.
- std::pair<FullTrackName, RemoteTrack::Visitor*> TrackPropertiesFromAlias(
- const MoqtObject& message,
- std::optional<MoqtForwardingPreference> forwarding_preference);
+ // Upstream state lookups; each returns nullptr if nothing matches.
+ // RemoteTrackByAlias() searches the owning subscribe_by_alias_ map, while
+ // RemoteTrackById() and RemoteTrackByName() search non-owning indexes.
+ SubscribeRemoteTrack* RemoteTrackByAlias(uint64_t track_alias);
+ RemoteTrack* RemoteTrackById(uint64_t subscribe_id);
+ RemoteTrack* RemoteTrackByName(const FullTrackName& name);
// Checks that a subscribe ID from a SUBSCRIBE or FETCH is valid, and throws
// a session error if is not.
@@ -557,11 +563,22 @@
bool peer_supports_object_ack_ = false;
std::string error_;
+ // Upstream SUBSCRIBE state.
// All the tracks the session is subscribed to, indexed by track_alias.
- absl::flat_hash_map<uint64_t, RemoteTrack> remote_tracks_;
- // Look up aliases for remote tracks by name
- absl::flat_hash_map<FullTrackName, uint64_t> remote_track_aliases_;
+ absl::flat_hash_map<uint64_t, std::unique_ptr<SubscribeRemoteTrack>>
+ subscribe_by_alias_;
+ // Upstream SUBSCRIBEs indexed by subscribe_id.
+ // TODO(martinduke): Add fetches to this.
+ absl::flat_hash_map<uint64_t, RemoteTrack*> upstream_by_id_;
+ // The application only has track names, so this allows MoqtSession to
+ // quickly find what it's looking for. Also allows a quick check for duplicate
+ // subscriptions.
+ absl::flat_hash_map<FullTrackName, RemoteTrack*> upstream_by_name_;
uint64_t next_remote_track_alias_ = 0;
+ // The next subscribe ID that the local endpoint can send.
+ uint64_t next_subscribe_id_ = 0;
+ // The maximum subscribe ID that the local endpoint can send.
+ uint64_t peer_max_subscribe_id_ = 0;
// All open incoming subscriptions, indexed by track name, used to check for
// duplicates.
@@ -585,21 +602,6 @@
absl::flat_hash_map<uint64_t, std::shared_ptr<PublishedFetch>>
incoming_fetches_;
- // Indexed by subscribe_id.
- struct ActiveSubscribe {
- MoqtSubscribe message;
- RemoteTrack::Visitor* visitor;
- // The forwarding preference of the first received object, which all
- // subsequent objects must match.
- std::optional<MoqtForwardingPreference> forwarding_preference;
- // If true, an object has arrived for the subscription before SUBSCRIBE_OK
- // arrived.
- bool received_object = false;
- };
- // Outgoing SUBSCRIBEs that have not received SUBSCRIBE_OK or SUBSCRIBE_ERROR.
- absl::flat_hash_map<uint64_t, ActiveSubscribe> active_subscribes_;
- uint64_t next_subscribe_id_ = 0;
-
// Monitoring interfaces for expected incoming subscriptions.
absl::flat_hash_map<FullTrackName, MoqtPublishingMonitorInterface*>
monitoring_interfaces_for_published_tracks_;
@@ -613,12 +615,10 @@
// parameter, and other checks have changed/been disabled.
MoqtRole peer_role_ = MoqtRole::kPubSub;
- // The maximum subscribe ID that the local endpoint can send.
- uint64_t peer_max_subscribe_id_ = 0;
- // The maximum subscribe ID sent to the peer.
- uint64_t local_max_subscribe_id_ = 0;
// The minimum subscribe ID the peer can use that is monotonically increasing.
uint64_t next_incoming_subscribe_id_ = 0;
+ // The maximum subscribe ID sent to the peer.
+ uint64_t local_max_subscribe_id_ = 0;
// Must be last. Token used to make sure that the streams do not call into
// the session when the session has already been destroyed.
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index 2e7aed0..3298cfc 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -24,7 +24,6 @@
#include "quiche/quic/moqt/moqt_parser.h"
#include "quiche/quic/moqt/moqt_priority.h"
#include "quiche/quic/moqt/moqt_publisher.h"
-#include "quiche/quic/moqt/moqt_track.h"
#include "quiche/quic/moqt/test_tools/moqt_framer_utils.h"
#include "quiche/quic/moqt/test_tools/moqt_session_peer.h"
#include "quiche/quic/moqt/tools/moqt_mock_visitor.h"
@@ -50,6 +49,21 @@
constexpr webtransport::StreamId kIncomingUniStreamId = 15;
constexpr webtransport::StreamId kOutgoingUniStreamId = 14;
+// Builds the SUBSCRIBE most tests start from: subscribe ID 1, track alias 2,
+// track ("foo", "bar"), priority 0x80, starting at (0, 0) with no end; all
+// remaining fields keep their defaults.
+MoqtSubscribe DefaultSubscribe() {
+ return MoqtSubscribe{
+ /*subscribe_id=*/1,
+ /*track_alias=*/2,
+ /*full_track_name=*/FullTrackName("foo", "bar"),
+ /*subscriber_priority=*/0x80,
+ /*group_order=*/std::nullopt,
+ /*start_group=*/0,
+ /*start_object=*/0,
+ /*end_group=*/std::nullopt,
+ /*end_object=*/std::nullopt,
+ };
+}
+
static std::shared_ptr<MockTrackPublisher> SetupPublisher(
FullTrackName track_name, MoqtForwardingPreference forwarding_preference,
FullSequence largest_sequence) {
@@ -195,18 +209,7 @@
}
TEST_F(MoqtSessionTest, AddLocalTrack) {
- MoqtSubscribe request = {
- /*subscribe_id=*/1,
- /*track_alias=*/2,
- /*full_track_name=*/FullTrackName({"foo", "bar"}),
- /*subscriber_priority=*/0x80,
- /*group_order=*/std::nullopt,
- /*start_group=*/0,
- /*start_object=*/0,
- /*end_group=*/std::nullopt,
- /*end_object=*/std::nullopt,
- /*parameters=*/MoqtSubscribeParameters(),
- };
+ MoqtSubscribe request = DefaultSubscribe();
webtransport::test::MockStream mock_stream;
std::unique_ptr<MoqtControlParserVisitor> stream_input =
MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
@@ -298,26 +301,13 @@
.WillRepeatedly(Return(FullSequence(10, 20)));
publisher_.Add(track);
- // Peer subscribes to (0, 0)
- MoqtSubscribe request = {
- /*subscribe_id=*/1,
- /*track_alias=*/2,
- /*full_track_name=*/FullTrackName({"foo", "bar"}),
- /*subscriber_priority=*/0x80,
- /*group_order=*/std::nullopt,
- /*start_group=*/0,
- /*start_object=*/0,
- /*end_group=*/std::nullopt,
- /*end_object=*/std::nullopt,
- /*parameters=*/MoqtSubscribeParameters(),
- };
webtransport::test::MockStream mock_stream;
std::unique_ptr<MoqtControlParserVisitor> stream_input =
MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
EXPECT_CALL(
mock_stream,
Writev(ControlMessageOfType(MoqtMessageType::kSubscribeError), _));
- stream_input->OnSubscribeMessage(request);
+ stream_input->OnSubscribeMessage(DefaultSubscribe());
}
TEST_F(MoqtSessionTest, TwoSubscribesForTrack) {
@@ -439,19 +429,7 @@
}
TEST_F(MoqtSessionTest, SubscribeIdNotIncreasing) {
- // Peer subscribes to (0, 0)
- MoqtSubscribe request = {
- /*subscribe_id=*/1,
- /*track_alias=*/2,
- /*full_track_name=*/FullTrackName({"foo", "bar"}),
- /*subscriber_priority=*/0x80,
- /*group_order=*/std::nullopt,
- /*start_group=*/0,
- /*start_object=*/0,
- /*end_group=*/std::nullopt,
- /*end_object=*/std::nullopt,
- /*parameters=*/MoqtSubscribeParameters(),
- };
+ MoqtSubscribe request = DefaultSubscribe();
webtransport::test::MockStream mock_stream;
std::unique_ptr<MoqtControlParserVisitor> stream_input =
MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
@@ -475,7 +453,7 @@
TEST_F(MoqtSessionTest, TooManySubscribes) {
MoqtSessionPeer::set_next_subscribe_id(&session_,
kDefaultInitialMaxSubscribeId);
- MockRemoteTrackVisitor remote_track_visitor;
+ MockSubscribeRemoteTrackVisitor remote_track_visitor;
webtransport::test::MockStream mock_stream;
std::unique_ptr<MoqtControlParserVisitor> stream_input =
MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
@@ -488,11 +466,26 @@
&remote_track_visitor));
}
+TEST_F(MoqtSessionTest, SubscribeDuplicateTrackName) {
+  MockSubscribeRemoteTrackVisitor visitor;
+  webtransport::test::MockStream control_stream;
+  std::unique_ptr<MoqtControlParserVisitor> control_input =
+      MoqtSessionPeer::CreateControlStream(&session_, &control_stream);
+  EXPECT_CALL(mock_session_, GetStreamById(_))
+      .WillRepeatedly(Return(&control_stream));
+  // The second subscription to the same name fails, so exactly one SUBSCRIBE
+  // is written to the control stream.
+  EXPECT_CALL(control_stream,
+              Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _));
+  const FullTrackName track_name("foo", "bar");
+  EXPECT_TRUE(session_.SubscribeCurrentGroup(track_name, &visitor));
+  EXPECT_FALSE(session_.SubscribeCurrentGroup(track_name, &visitor));
+}
+
TEST_F(MoqtSessionTest, SubscribeWithOk) {
webtransport::test::MockStream mock_stream;
std::unique_ptr<MoqtControlParserVisitor> stream_input =
MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
- MockRemoteTrackVisitor remote_track_visitor;
+ MockSubscribeRemoteTrackVisitor remote_track_visitor;
EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream));
EXPECT_CALL(mock_stream,
Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _));
@@ -503,8 +496,9 @@
/*subscribe_id=*/0,
/*expires=*/quic::QuicTimeDelta::FromMilliseconds(0),
};
- EXPECT_CALL(remote_track_visitor, OnReply(_, _))
+ EXPECT_CALL(remote_track_visitor, OnReply(_, _, _))
.WillOnce([&](const FullTrackName& ftn,
+ std::optional<FullSequence> /*largest_id*/,
std::optional<absl::string_view> error_message) {
EXPECT_EQ(ftn, FullTrackName("foo", "bar"));
EXPECT_FALSE(error_message.has_value());
@@ -515,7 +509,7 @@
TEST_F(MoqtSessionTest, MaxSubscribeIdChangesResponse) {
MoqtSessionPeer::set_next_subscribe_id(&session_,
kDefaultInitialMaxSubscribeId + 1);
- MockRemoteTrackVisitor remote_track_visitor;
+ MockSubscribeRemoteTrackVisitor remote_track_visitor;
EXPECT_FALSE(session_.SubscribeCurrentGroup(FullTrackName("foo", "bar"),
&remote_track_visitor));
MoqtMaxSubscribeId max_subscribe_id = {
@@ -590,7 +584,7 @@
webtransport::test::MockStream mock_stream;
std::unique_ptr<MoqtControlParserVisitor> stream_input =
MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
- MockRemoteTrackVisitor remote_track_visitor;
+ MockSubscribeRemoteTrackVisitor remote_track_visitor;
EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream));
EXPECT_CALL(mock_stream,
Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _));
@@ -603,8 +597,9 @@
/*reason_phrase=*/"deadbeef",
/*track_alias=*/2,
};
- EXPECT_CALL(remote_track_visitor, OnReply(_, _))
+ EXPECT_CALL(remote_track_visitor, OnReply(_, _, _))
.WillOnce([&](const FullTrackName& ftn,
+ std::optional<FullSequence> /*largest_id*/,
std::optional<absl::string_view> error_message) {
EXPECT_EQ(ftn, FullTrackName("foo", "bar"));
EXPECT_EQ(*error_message, "deadbeef");
@@ -612,6 +607,21 @@
stream_input->OnSubscribeErrorMessage(error);
}
+TEST_F(MoqtSessionTest, Unsubscribe) {
+  webtransport::test::MockStream control_stream;
+  std::unique_ptr<MoqtControlParserVisitor> control_input =
+      MoqtSessionPeer::CreateControlStream(&session_, &control_stream);
+  MockSubscribeRemoteTrackVisitor visitor;
+  const MoqtSubscribe subscribe = DefaultSubscribe();
+  MoqtSessionPeer::CreateRemoteTrack(&session_, subscribe, &visitor);
+  EXPECT_NE(MoqtSessionPeer::remote_track(&session_, subscribe.track_alias),
+            nullptr);
+  EXPECT_CALL(control_stream,
+              Writev(ControlMessageOfType(MoqtMessageType::kUnsubscribe), _));
+  session_.Unsubscribe(subscribe.full_track_name);
+  // Unsubscribing tears down the local state for the track.
+  EXPECT_EQ(MoqtSessionPeer::remote_track(&session_, subscribe.track_alias),
+            nullptr);
+}
+
TEST_F(MoqtSessionTest, ReplyToAnnounce) {
webtransport::test::MockStream mock_stream;
std::unique_ptr<MoqtControlParserVisitor> stream_input =
@@ -630,10 +640,10 @@
}
TEST_F(MoqtSessionTest, IncomingObject) {
- MockRemoteTrackVisitor visitor_;
+ MockSubscribeRemoteTrackVisitor visitor_;
FullTrackName ftn("foo", "bar");
std::string payload = "deadbeef";
- MoqtSessionPeer::CreateRemoteTrack(&session_, ftn, &visitor_, 2);
+ MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), &visitor_);
MoqtObject object = {
/*track_alias=*/2,
/*group_sequence=*/0,
@@ -645,7 +655,8 @@
};
webtransport::test::MockStream mock_stream;
std::unique_ptr<MoqtDataParserVisitor> object_stream =
- MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream);
+ MoqtSessionPeer::CreateIncomingDataStream(
+ &session_, &mock_stream, MoqtDataStreamType::kStreamHeaderSubgroup);
EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _)).Times(1);
EXPECT_CALL(mock_stream, GetStreamId())
@@ -654,10 +665,10 @@
}
TEST_F(MoqtSessionTest, IncomingPartialObject) {
- MockRemoteTrackVisitor visitor_;
+ MockSubscribeRemoteTrackVisitor visitor_;
FullTrackName ftn("foo", "bar");
std::string payload = "deadbeef";
- MoqtSessionPeer::CreateRemoteTrack(&session_, ftn, &visitor_, 2);
+ MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), &visitor_);
MoqtObject object = {
/*track_alias=*/2,
/*group_sequence=*/0,
@@ -669,7 +680,8 @@
};
webtransport::test::MockStream mock_stream;
std::unique_ptr<MoqtDataParserVisitor> object_stream =
- MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream);
+ MoqtSessionPeer::CreateIncomingDataStream(
+ &session_, &mock_stream, MoqtDataStreamType::kStreamHeaderSubgroup);
EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _)).Times(1);
EXPECT_CALL(mock_stream, GetStreamId())
@@ -683,10 +695,10 @@
parameters.deliver_partial_objects = true;
MoqtSession session(&mock_session_, parameters,
session_callbacks_.AsSessionCallbacks());
- MockRemoteTrackVisitor visitor_;
+ MockSubscribeRemoteTrackVisitor visitor_;
FullTrackName ftn("foo", "bar");
std::string payload = "deadbeef";
- MoqtSessionPeer::CreateRemoteTrack(&session, ftn, &visitor_, 2);
+ MoqtSessionPeer::CreateRemoteTrack(&session, DefaultSubscribe(), &visitor_);
MoqtObject object = {
/*track_alias=*/2,
/*group_sequence=*/0,
@@ -698,7 +710,8 @@
};
webtransport::test::MockStream mock_stream;
std::unique_ptr<MoqtDataParserVisitor> object_stream =
- MoqtSessionPeer::CreateIncomingDataStream(&session, &mock_stream);
+ MoqtSessionPeer::CreateIncomingDataStream(
+ &session, &mock_stream, MoqtDataStreamType::kStreamHeaderSubgroup);
EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _)).Times(2);
EXPECT_CALL(mock_stream, GetStreamId())
@@ -708,21 +721,10 @@
}
TEST_F(MoqtSessionTest, ObjectBeforeSubscribeOk) {
- MockRemoteTrackVisitor visitor_;
+ MockSubscribeRemoteTrackVisitor visitor_;
FullTrackName ftn("foo", "bar");
std::string payload = "deadbeef";
- MoqtSubscribe subscribe = {
- /*subscribe_id=*/1,
- /*track_alias=*/2,
- /*full_track_name=*/ftn,
- /*subscriber_priority=*/0x80,
- /*group_order=*/std::nullopt,
- /*start_group=*/0,
- /*start_object=*/0,
- /*end_group=*/std::nullopt,
- /*end_object=*/std::nullopt,
- };
- MoqtSessionPeer::AddActiveSubscribe(&session_, 1, subscribe, &visitor_);
+ MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), &visitor_);
MoqtObject object = {
/*track_alias=*/2,
/*group_sequence=*/0,
@@ -734,7 +736,8 @@
};
webtransport::test::MockStream mock_stream;
std::unique_ptr<MoqtDataParserVisitor> object_stream =
- MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream);
+ MoqtSessionPeer::CreateIncomingDataStream(
+ &session_, &mock_stream, MoqtDataStreamType::kStreamHeaderSubgroup);
EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _))
.WillOnce([&](const FullTrackName& full_track_name, FullSequence sequence,
@@ -758,26 +761,15 @@
webtransport::test::MockStream mock_control_stream;
std::unique_ptr<MoqtControlParserVisitor> control_stream =
MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream);
- EXPECT_CALL(visitor_, OnReply(_, _)).Times(1);
+ EXPECT_CALL(visitor_, OnReply(_, _, _)).Times(1);
control_stream->OnSubscribeOkMessage(ok);
}
TEST_F(MoqtSessionTest, ObjectBeforeSubscribeError) {
- MockRemoteTrackVisitor visitor;
+ MockSubscribeRemoteTrackVisitor visitor;
FullTrackName ftn("foo", "bar");
std::string payload = "deadbeef";
- MoqtSubscribe subscribe = {
- /*subscribe_id=*/1,
- /*track_alias=*/2,
- /*full_track_name=*/ftn,
- /*subscriber_priority=*/0x80,
- /*group_order=*/std::nullopt,
- /*start_group=*/0,
- /*start_object=*/0,
- /*end_group=*/std::nullopt,
- /*end_object=*/std::nullopt,
- };
- MoqtSessionPeer::AddActiveSubscribe(&session_, 1, subscribe, &visitor);
+ MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), &visitor);
MoqtObject object = {
/*track_alias=*/2,
/*group_sequence=*/0,
@@ -789,7 +781,8 @@
};
webtransport::test::MockStream mock_stream;
std::unique_ptr<MoqtDataParserVisitor> object_stream =
- MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream);
+ MoqtSessionPeer::CreateIncomingDataStream(
+ &session_, &mock_stream, MoqtDataStreamType::kStreamHeaderSubgroup);
EXPECT_CALL(visitor, OnObjectFragment(_, _, _, _, _, _))
.WillOnce([&](const FullTrackName& full_track_name, FullSequence sequence,
@@ -813,123 +806,53 @@
webtransport::test::MockStream mock_control_stream;
std::unique_ptr<MoqtControlParserVisitor> control_stream =
MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream);
- EXPECT_CALL(mock_session_,
- CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
- "Received SUBSCRIBE_ERROR after object"))
+ EXPECT_CALL(
+ mock_session_,
+ CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
+ "Received SUBSCRIBE_ERROR after SUBSCRIBE_OK or objects"))
.Times(1);
control_stream->OnSubscribeErrorMessage(subscribe_error);
}
-TEST_F(MoqtSessionTest, TwoEarlyObjectsDifferentForwarding) {
- MockRemoteTrackVisitor visitor;
- FullTrackName ftn("foo", "bar");
- std::string payload = "deadbeef";
- MoqtSubscribe subscribe = {
- /*subscribe_id=*/1,
- /*track_alias=*/2,
- /*full_track_name=*/ftn,
- /*subscriber_priority=*/0x80,
- /*group_order=*/std::nullopt,
- /*start_group=*/0,
- /*start_object=*/0,
- /*end_group=*/std::nullopt,
- /*end_object=*/std::nullopt,
- };
- MoqtSessionPeer::AddActiveSubscribe(&session_, 1, subscribe, &visitor);
- MoqtObject object = {
- /*track_alias=*/2,
- /*group_sequence=*/0,
- /*object_sequence=*/0,
- /*publisher_priority=*/0,
- /*object_status=*/MoqtObjectStatus::kNormal,
- /*subgroup_id=*/0,
- /*payload_length=*/8,
- };
- webtransport::test::MockStream mock_stream;
- std::unique_ptr<MoqtDataParserVisitor> object_stream =
- MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream);
+TEST_F(MoqtSessionTest, SubscribeErrorWithTrackAlias) {
+ MockSubscribeRemoteTrackVisitor visitor;
+ MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), &visitor);
- EXPECT_CALL(visitor, OnObjectFragment(_, _, _, _, _, _))
- .WillOnce([&](const FullTrackName& full_track_name, FullSequence sequence,
- MoqtPriority publisher_priority, MoqtObjectStatus status,
- absl::string_view payload, bool end_of_message) {
- EXPECT_EQ(full_track_name, ftn);
- EXPECT_EQ(sequence.group, object.group_id);
- EXPECT_EQ(sequence.object, object.object_id);
- });
- EXPECT_CALL(mock_stream, GetStreamId())
- .WillRepeatedly(Return(kIncomingUniStreamId));
- object_stream->OnObjectMessage(object, payload, true);
- char datagram[] = {0x01, 0x02, 0x00, 0x00, 0x00, 0x08, 0x64,
- 0x65, 0x61, 0x64, 0x62, 0x65, 0x65, 0x66};
- EXPECT_CALL(mock_session_,
- CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
- "Forwarding preference changes mid-track"))
+ // SUBSCRIBE_ERROR arrives
+ MoqtSubscribeError subscribe_error = {
+ /*subscribe_id=*/1,
+ /*error_code=*/SubscribeErrorCode::kRetryTrackAlias,
+ /*reason_phrase=*/"foo",
+ /*track_alias =*/3,
+ };
+ webtransport::test::MockStream mock_control_stream;
+ std::unique_ptr<MoqtControlParserVisitor> control_stream =
+ MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream);
+ EXPECT_CALL(mock_control_stream,
+ Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _))
.Times(1);
- session_.OnDatagramReceived(absl::string_view(datagram, sizeof(datagram)));
+ control_stream->OnSubscribeErrorMessage(subscribe_error);
}
-TEST_F(MoqtSessionTest, EarlyObjectForwardingDoesNotMatchTrack) {
- MockRemoteTrackVisitor visitor;
- FullTrackName ftn("foo", "bar");
- std::string payload = "deadbeef";
- MoqtSubscribe subscribe = {
+TEST_F(MoqtSessionTest, SubscribeErrorWithBadTrackAlias) {
+ MockSubscribeRemoteTrackVisitor visitor;
+ MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), &visitor);
+
+ // SUBSCRIBE_ERROR arrives
+ MoqtSubscribeError subscribe_error = {
/*subscribe_id=*/1,
- /*track_alias=*/2,
- /*full_track_name=*/ftn,
- /*subscriber_priority=*/0x80,
- /*group_order=*/std::nullopt,
- /*start_group=*/0,
- /*start_object=*/0,
- /*end_group=*/std::nullopt,
- /*end_object=*/std::nullopt,
- };
- MoqtSessionPeer::AddActiveSubscribe(&session_, 1, subscribe, &visitor);
- MoqtObject object = {
- /*track_alias=*/2,
- /*group_sequence=*/0,
- /*object_sequence=*/0,
- /*publisher_priority=*/0,
- /*object_status=*/MoqtObjectStatus::kNormal,
- /*subgroup_id=*/0,
- /*payload_length=*/8,
- };
- webtransport::test::MockStream mock_stream;
- std::unique_ptr<MoqtDataParserVisitor> object_stream =
- MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream);
-
- EXPECT_CALL(visitor, OnObjectFragment(_, _, _, _, _, _))
- .WillOnce([&](const FullTrackName& full_track_name, FullSequence sequence,
- MoqtPriority publisher_priority, MoqtObjectStatus status,
- absl::string_view payload, bool end_of_message) {
- EXPECT_EQ(full_track_name, ftn);
- EXPECT_EQ(sequence.group, object.group_id);
- EXPECT_EQ(sequence.object, object.object_id);
- });
- EXPECT_CALL(mock_stream, GetStreamId())
- .WillRepeatedly(Return(kIncomingUniStreamId));
- object_stream->OnObjectMessage(object, payload, true);
-
- MoqtSessionPeer::CreateRemoteTrack(&session_, ftn, &visitor, 2);
- // The track already exists, and has a different forwarding preference.
- MoqtSessionPeer::remote_track(&session_, 2)
- .CheckForwardingPreference(MoqtForwardingPreference::kDatagram);
-
- // SUBSCRIBE_OK arrives
- MoqtSubscribeOk ok = {
- /*subscribe_id=*/1,
- /*expires=*/quic::QuicTimeDelta::FromMilliseconds(0),
- /*group_order=*/MoqtDeliveryOrder::kAscending,
- /*largest_id=*/std::nullopt,
+ /*error_code=*/SubscribeErrorCode::kRetryTrackAlias,
+ /*reason_phrase=*/"foo",
+ /*track_alias =*/2,
};
webtransport::test::MockStream mock_control_stream;
std::unique_ptr<MoqtControlParserVisitor> control_stream =
MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream);
EXPECT_CALL(mock_session_,
CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
- "Forwarding preference different in early objects"))
+ "Provided track alias already in use"))
.Times(1);
- control_stream->OnSubscribeOkMessage(ok);
+ control_stream->OnSubscribeErrorMessage(subscribe_error);
}
TEST_F(MoqtSessionTest, CreateOutgoingDataStreamAndSend) {
@@ -1428,10 +1351,10 @@
}
TEST_F(MoqtSessionTest, ReceiveDatagram) {
- MockRemoteTrackVisitor visitor_;
+ MockSubscribeRemoteTrackVisitor visitor_;
FullTrackName ftn("foo", "bar");
std::string payload = "deadbeef";
- MoqtSessionPeer::CreateRemoteTrack(&session_, ftn, &visitor_, 2);
+ MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), &visitor_);
MoqtObject object = {
/*track_alias=*/2,
/*group_sequence=*/0,
@@ -1452,11 +1375,10 @@
session_.OnDatagramReceived(absl::string_view(datagram, sizeof(datagram)));
}
-TEST_F(MoqtSessionTest, ForwardingPreferenceMismatch) {
- MockRemoteTrackVisitor visitor_;
- FullTrackName ftn("foo", "bar");
+TEST_F(MoqtSessionTest, DataStreamTypeMismatch) {
+ MockSubscribeRemoteTrackVisitor visitor_;
std::string payload = "deadbeef";
- MoqtSessionPeer::CreateRemoteTrack(&session_, ftn, &visitor_, 2);
+ MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), &visitor_);
MoqtObject object = {
/*track_alias=*/2,
/*group_sequence=*/0,
@@ -1468,7 +1390,8 @@
};
webtransport::test::MockStream mock_stream;
std::unique_ptr<MoqtDataParserVisitor> object_stream =
- MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream);
+ MoqtSessionPeer::CreateIncomingDataStream(
+ &session_, &mock_stream, MoqtDataStreamType::kStreamHeaderSubgroup);
EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _)).Times(1);
EXPECT_CALL(mock_stream, GetStreamId())
@@ -1478,11 +1401,46 @@
0x65, 0x61, 0x64, 0x62, 0x65, 0x65, 0x66};
EXPECT_CALL(mock_session_,
CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
- "Forwarding preference changes mid-track"))
+ "Received DATAGRAM for non-datagram track"))
.Times(1);
session_.OnDatagramReceived(absl::string_view(datagram, sizeof(datagram)));
}
+TEST_F(MoqtSessionTest, StreamObjectOutOfWindow) {
+  MockSubscribeRemoteTrackVisitor visitor;
+  std::string payload = "deadbeef";
+  MoqtSubscribe subscribe = DefaultSubscribe();
+  subscribe.start_group = 1;
+  MoqtSessionPeer::CreateRemoteTrack(&session_, subscribe, &visitor);
+  // Group 0 precedes the subscription's start_group of 1.
+  MoqtObject object = {
+      /*track_alias=*/2,
+      /*group_sequence=*/0,
+      /*object_sequence=*/0,
+      /*publisher_priority=*/0,
+      /*object_status=*/MoqtObjectStatus::kNormal,
+      /*subgroup_id=*/0,
+      /*payload_length=*/8,
+  };
+  webtransport::test::MockStream mock_stream;
+  std::unique_ptr<MoqtDataParserVisitor> object_stream =
+      MoqtSessionPeer::CreateIncomingDataStream(
+          &session_, &mock_stream, MoqtDataStreamType::kStreamHeaderSubgroup);
+  // Same stream-ID stub as the other incoming data stream tests.
+  EXPECT_CALL(mock_stream, GetStreamId())
+      .WillRepeatedly(Return(kIncomingUniStreamId));
+  // The object is outside the subscription window and must not be delivered.
+  EXPECT_CALL(visitor, OnObjectFragment(_, _, _, _, _, _)).Times(0);
+  object_stream->OnObjectMessage(object, payload, true);
+}
+
+TEST_F(MoqtSessionTest, DatagramOutOfWindow) {
+  MockSubscribeRemoteTrackVisitor visitor;
+  MoqtSubscribe subscribe = DefaultSubscribe();
+  subscribe.start_group = 1;
+  MoqtSessionPeer::CreateRemoteTrack(&session_, subscribe, &visitor);
+  // Datagram for track alias 2, object (0, 0), i.e. before start_group 1, with
+  // payload "deadbeef". The 0x80 byte is spelled '\x80': as an int literal it
+  // does not fit in a signed char, making the braced initializer an ill-formed
+  // narrowing conversion on platforms where char is signed.
+  char datagram[] = {0x01, 0x02, 0x00, 0x00, '\x80', 0x08, 0x64,
+                     0x65, 0x61, 0x64, 0x62, 0x65, 0x65, 0x66};
+  EXPECT_CALL(visitor, OnObjectFragment(_, _, _, _, _, _)).Times(0);
+  session_.OnDatagramReceived(absl::string_view(datagram, sizeof(datagram)));
+}
+
TEST_F(MoqtSessionTest, AnnounceToPublisher) {
MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kPublisher);
testing::MockFunction<void(
@@ -1496,18 +1454,6 @@
TEST_F(MoqtSessionTest, SubscribeFromPublisher) {
MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kPublisher);
- MoqtSubscribe request = {
- /*subscribe_id=*/1,
- /*track_alias=*/2,
- /*full_track_name=*/FullTrackName({"foo", "bar"}),
- /*subscriber_priority=*/0x80,
- /*group_order=*/std::nullopt,
- /*start_group=*/0,
- /*start_object=*/0,
- /*end_group=*/std::nullopt,
- /*end_object=*/std::nullopt,
- /*parameters=*/MoqtSubscribeParameters(),
- };
webtransport::test::MockStream mock_stream;
std::unique_ptr<MoqtControlParserVisitor> stream_input =
MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
@@ -1517,7 +1463,7 @@
"Received SUBSCRIBE from publisher"))
.Times(1);
EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_)).Times(1);
- stream_input->OnSubscribeMessage(request);
+ stream_input->OnSubscribeMessage(DefaultSubscribe());
}
TEST_F(MoqtSessionTest, AnnounceFromSubscriber) {
diff --git a/quiche/quic/moqt/moqt_track.cc b/quiche/quic/moqt/moqt_track.cc
index ee85664..07f44d7 100644
--- a/quiche/quic/moqt/moqt_track.cc
+++ b/quiche/quic/moqt/moqt_track.cc
@@ -4,18 +4,15 @@
#include "quiche/quic/moqt/moqt_track.h"
-#include <cstdint>
-
#include "quiche/quic/moqt/moqt_messages.h"
namespace moqt {
-bool RemoteTrack::CheckForwardingPreference(
- MoqtForwardingPreference preference) {
- if (forwarding_preference_.has_value()) {
- return forwarding_preference_.value() == preference;
+// The first call records |type| and returns true. Later calls return whether
+// |type| matches the recorded value and never change it.
+bool RemoteTrack::CheckDataStreamType(MoqtDataStreamType type) {
+ if (data_stream_type_.has_value()) {
+ return data_stream_type_.value() == type;
}
- forwarding_preference_ = preference;
+ data_stream_type_ = type;
return true;
}
diff --git a/quiche/quic/moqt/moqt_track.h b/quiche/quic/moqt/moqt_track.h
index 0da047a..e47464a 100644
--- a/quiche/quic/moqt/moqt_track.h
+++ b/quiche/quic/moqt/moqt_track.h
@@ -6,13 +6,16 @@
#define QUICHE_QUIC_MOQT_MOQT_SUBSCRIPTION_H_
#include <cstdint>
+#include <memory>
#include <optional>
#include "absl/strings/string_view.h"
#include "quiche/quic/core/quic_time.h"
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_priority.h"
+#include "quiche/quic/moqt/moqt_subscribe_windows.h"
#include "quiche/common/quiche_callbacks.h"
+#include "quiche/common/quiche_weak_ptr.h"
namespace moqt {
@@ -20,9 +23,65 @@
quiche::MultiUseCallback<void(uint64_t group_id, uint64_t object_id,
quic::QuicTimeDelta delta_from_deadline)>;
-// A track on the peer to which the session has subscribed.
+// State common to both SUBSCRIBE and FETCH upstream.
class RemoteTrack {
public:
+  RemoteTrack(const FullTrackName& full_track_name, uint64_t id,
+              SubscribeWindow window)
+      : full_track_name_(full_track_name),
+        subscribe_id_(id),
+        window_(window),
+        weak_ptr_factory_(this) {}
+  virtual ~RemoteTrack() = default;
+
+  FullTrackName full_track_name() const { return full_track_name_; }
+  // If FETCH_ERROR or SUBSCRIBE_ERROR arrives after OK or an object, it is a
+  // protocol violation.
+  virtual void OnObjectOrOk() { error_is_allowed_ = false; }
+  // False after RemoteTrack::OnObjectOrOk() has run.
+  bool ErrorIsAllowed() const { return error_is_allowed_; }
+
+  // When called while processing the first object in the track, sets the
+  // data stream type to the value indicated by the incoming encoding.
+  // Otherwise, returns true if the incoming object does not violate the rule
+  // that the type is consistent.
+  bool CheckDataStreamType(MoqtDataStreamType type);
+
+  // True if the recorded data stream type is kStreamHeaderFetch. Always false
+  // until CheckDataStreamType() has been called at least once.
+  bool is_fetch() const {
+    return data_stream_type_.has_value() &&
+           *data_stream_type_ == MoqtDataStreamType::kStreamHeaderFetch;
+  }
+
+  // The SUBSCRIBE/FETCH ID supplied at construction.
+  uint64_t subscribe_id() const { return subscribe_id_; }
+
+  // Is the object one that was requested?
+  bool InWindow(FullSequence sequence) const {
+    return window_.InWindow(sequence);
+  }
+
+  // Replaces the window consulted by InWindow(). |window| is only copied, so
+  // it is taken by const reference; const windows and temporaries are accepted.
+  void ChangeWindow(const SubscribeWindow& window) { window_ = window; }
+
+  // The returned pointer is invalidated when this track is destroyed.
+  quiche::QuicheWeakPtr<RemoteTrack> weak_ptr() {
+    return weak_ptr_factory_.Create();
+  }
+
+ private:
+  const FullTrackName full_track_name_;
+  const uint64_t subscribe_id_;
+  SubscribeWindow window_;
+  // Empty until the first CheckDataStreamType() call records a type.
+  std::optional<MoqtDataStreamType> data_stream_type_;
+  // If false, an object or OK message has been received, so any ERROR message
+  // is a protocol violation.
+  bool error_is_allowed_ = true;
+
+  // Must be the last member of RemoteTrack so that it is destroyed first,
+  // invalidating weak pointers before the other members are torn down.
+  quiche::QuicheWeakPtrFactory<RemoteTrack> weak_ptr_factory_;
+};
+
+// A track on the peer to which the session has subscribed.
+class SubscribeRemoteTrack : public RemoteTrack {
+ public:
+ // TODO: Move the Visitor class out of SubscribeRemoteTrack (it is used by the
+ // application) and give it a name such as MoqtTrackSubscriber.
class Visitor {
public:
virtual ~Visitor() = default;
@@ -31,6 +90,7 @@
// automatically retry.
virtual void OnReply(
const FullTrackName& full_track_name,
+ std::optional<FullSequence> largest_id,
std::optional<absl::string_view> error_reason_phrase) = 0;
// Called when the subscription process is far enough that it is possible to
// send OBJECT_ACK messages; provides a callback to do so. The callback is
@@ -43,35 +103,34 @@
absl::string_view object, bool end_of_message) = 0;
// TODO(martinduke): Add final sequence numbers
};
- RemoteTrack(const FullTrackName& full_track_name, uint64_t track_alias,
- Visitor* visitor)
- : full_track_name_(full_track_name),
- track_alias_(track_alias),
- visitor_(visitor) {}
+ // Copies |subscribe| so it can be re-sent with a new track alias. Unset start
+ // fields default to 0 and unset end fields to UINT64_MAX, leaving the window
+ // unbounded in those directions. |visitor| is stored as a raw pointer and is
+ // not owned; presumably it must outlive this track (confirm with callers).
+ SubscribeRemoteTrack(const MoqtSubscribe& subscribe, Visitor* visitor)
+ : RemoteTrack(subscribe.full_track_name, subscribe.subscribe_id,
+ SubscribeWindow(subscribe.start_group.value_or(0),
+ subscribe.start_object.value_or(0),
+ subscribe.end_group.value_or(UINT64_MAX),
+ subscribe.end_object.value_or(UINT64_MAX))),
+ track_alias_(subscribe.track_alias),
+ visitor_(visitor),
+ subscribe_(std::make_unique<MoqtSubscribe>(subscribe)) {}
- const FullTrackName& full_track_name() { return full_track_name_; }
-
+ // Frees the stored SUBSCRIBE (only needed to retry after SUBSCRIBE_ERROR),
+ // then lets RemoteTrack mark errors as disallowed. GetSubscribe() must not be
+ // called after this.
+ void OnObjectOrOk() override {
+ subscribe_.reset(); // No SUBSCRIBE_ERROR, no need to store this anymore.
+ RemoteTrack::OnObjectOrOk();
+ }
uint64_t track_alias() const { return track_alias_; }
-
Visitor* visitor() { return visitor_; }
-
- // When called while processing the first object in the track, sets the
- // forwarding preference to the value indicated by the incoming encoding.
- // Otherwise, returns true if the incoming object does not violate the rule
- // that the preference is consistent.
- bool CheckForwardingPreference(MoqtForwardingPreference preference);
-
- std::optional<MoqtForwardingPreference> forwarding_preference() const {
- return forwarding_preference_;
+ // Returns the stored SUBSCRIBE so it can be re-sent with a new track alias.
+ // Only valid before OnObjectOrOk(), which releases it. The caller is expected
+ // to destroy this track soon afterwards, so the pointer is not cleared here.
+ MoqtSubscribe& GetSubscribe() {
+ return *subscribe_;
}
private:
- // TODO: There is no accounting for the number of outstanding subscribes,
- // because we can't match track names to individual subscribes.
- const FullTrackName full_track_name_;
const uint64_t track_alias_;
Visitor* visitor_;
- std::optional<MoqtForwardingPreference> forwarding_preference_;
+ // Copy of the SUBSCRIBE, kept so it can be re-sent with a new track alias.
+ // Released (set to null) by OnObjectOrOk().
+ std::unique_ptr<MoqtSubscribe> subscribe_;
};
} // namespace moqt
diff --git a/quiche/quic/moqt/moqt_track_test.cc b/quiche/quic/moqt/moqt_track_test.cc
index cd62dd8..a5717ce 100644
--- a/quiche/quic/moqt/moqt_track_test.cc
+++ b/quiche/quic/moqt/moqt_track_test.cc
@@ -4,6 +4,9 @@
#include "quiche/quic/moqt/moqt_track.h"
+#include <optional>
+
+#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/tools/moqt_mock_visitor.h"
#include "quiche/quic/platform/api/quic_test.h"
@@ -11,27 +14,54 @@
namespace test {
-class RemoteTrackTest : public quic::test::QuicTest {
+class SubscribeRemoteTrackTest : public quic::test::QuicTest {
public:
- RemoteTrackTest()
- : track_(FullTrackName("foo", "bar"), /*track_alias=*/5, &visitor_) {}
- RemoteTrack track_;
- MockRemoteTrackVisitor visitor_;
+  SubscribeRemoteTrackTest() : track_(subscribe_, &visitor_) {}
+
+  MockSubscribeRemoteTrackVisitor visitor_;
+  MoqtSubscribe subscribe_ = {
+      /*subscribe_id=*/1,
+      /*track_alias=*/2,
+      /*full_track_name=*/FullTrackName("foo", "bar"),
+      /*subscriber_priority=*/128,
+      /*group_order=*/std::nullopt,
+      /*start_group=*/2,
+      /*start_object=*/0,
+      /*end_group=*/std::nullopt,
+      /*end_object=*/std::nullopt,
+      /*parameters=*/MoqtSubscribeParameters(),
+  };
+  // Declared after visitor_ and subscribe_, which its constructor reads.
+  SubscribeRemoteTrack track_;
};
-TEST_F(RemoteTrackTest, Queries) {
+TEST_F(SubscribeRemoteTrackTest, Queries) {
EXPECT_EQ(track_.full_track_name(), FullTrackName("foo", "bar"));
- EXPECT_EQ(track_.track_alias(), 5);
+ EXPECT_EQ(track_.subscribe_id(), 1);
+ EXPECT_EQ(track_.track_alias(), 2);
EXPECT_EQ(track_.visitor(), &visitor_);
+ EXPECT_FALSE(track_.is_fetch());
}
-TEST_F(RemoteTrackTest, UpdateForwardingPreference) {
+TEST_F(SubscribeRemoteTrackTest, UpdateDataStreamType) {
EXPECT_TRUE(
- track_.CheckForwardingPreference(MoqtForwardingPreference::kSubgroup));
+ track_.CheckDataStreamType(MoqtDataStreamType::kStreamHeaderSubgroup));
EXPECT_TRUE(
- track_.CheckForwardingPreference(MoqtForwardingPreference::kSubgroup));
- EXPECT_FALSE(
- track_.CheckForwardingPreference(MoqtForwardingPreference::kDatagram));
+ track_.CheckDataStreamType(MoqtDataStreamType::kStreamHeaderSubgroup));
+ EXPECT_FALSE(track_.CheckDataStreamType(MoqtDataStreamType::kObjectDatagram));
+}
+
+TEST_F(SubscribeRemoteTrackTest, AllowError) {
+  // Before any OK or object, the SUBSCRIBE is retained and errors are allowed.
+  EXPECT_EQ(track_.GetSubscribe().subscribe_id, subscribe_.subscribe_id);
+  EXPECT_TRUE(track_.ErrorIsAllowed());
+  track_.OnObjectOrOk();
+  // Afterwards, any ERROR would be a protocol violation.
+  EXPECT_FALSE(track_.ErrorIsAllowed());
+}
+
+TEST_F(SubscribeRemoteTrackTest, Windows) {
+  const FullSequence sequence(2, 0);
+  EXPECT_TRUE(track_.InWindow(sequence));
+  // Narrow the window so that (2, 0) falls outside it.
+  SubscribeWindow narrower_window(2, 1);
+  track_.ChangeWindow(narrower_window);
+  EXPECT_FALSE(track_.InWindow(sequence));
}
// TODO: Write test for GetStreamForSequence.
diff --git a/quiche/quic/moqt/test_tools/moqt_session_peer.h b/quiche/quic/moqt/test_tools/moqt_session_peer.h
index d41edd6..5426ef6 100644
--- a/quiche/quic/moqt/test_tools/moqt_session_peer.h
+++ b/quiche/quic/moqt/test_tools/moqt_session_peer.h
@@ -23,6 +23,13 @@
namespace moqt::test {
+// Test-only access to MoqtDataParser internals.
+class MoqtDataParserPeer {
+ public:
+  // Overwrites the data stream type stored in |parser|.
+  static void SetType(MoqtDataParser* parser, MoqtDataStreamType stream_type) {
+    parser->type_ = stream_type;
+  }
+};
+
class MoqtSessionPeer {
public:
static constexpr webtransport::StreamId kControlStreamId = 4;
@@ -43,9 +50,11 @@
}
+ // Creates an IncomingDataStream on |stream| whose parser type is preset to
+ // |type| (NOTE(review): presumably so tests can inject objects without first
+ // parsing a stream header; confirm).
static std::unique_ptr<MoqtDataParserVisitor> CreateIncomingDataStream(
- MoqtSession* session, webtransport::Stream* stream) {
+ MoqtSession* session, webtransport::Stream* stream,
+ MoqtDataStreamType type) {
auto new_stream =
std::make_unique<MoqtSession::IncomingDataStream>(session, stream);
+ MoqtDataParserPeer::SetType(&new_stream->parser_, type);
return new_stream;
}
@@ -62,18 +71,15 @@
return static_cast<MoqtSession::ControlStream*>(visitor);
}
- static void CreateRemoteTrack(MoqtSession* session, const FullTrackName& name,
- RemoteTrack::Visitor* visitor,
- uint64_t track_alias) {
- session->remote_tracks_.try_emplace(track_alias, name, track_alias,
- visitor);
- session->remote_track_aliases_.try_emplace(name, track_alias);
- }
-
- static void AddActiveSubscribe(MoqtSession* session, uint64_t subscribe_id,
- MoqtSubscribe& subscribe,
- RemoteTrack::Visitor* visitor) {
- session->active_subscribes_[subscribe_id] = {subscribe, visitor};
+  // Creates a SubscribeRemoteTrack for |subscribe| and indexes it by track
+  // alias (owning), subscribe ID, and track name. If the alias is already in
+  // use, nothing is registered. Indexing a track the alias map did not take
+  // ownership of would leave dangling pointers once |track| goes out of scope.
+  static void CreateRemoteTrack(MoqtSession* session,
+                                const MoqtSubscribe& subscribe,
+                                SubscribeRemoteTrack::Visitor* visitor) {
+    auto track = std::make_unique<SubscribeRemoteTrack>(subscribe, visitor);
+    SubscribeRemoteTrack* track_ptr = track.get();
+    if (!session->subscribe_by_alias_
+             .try_emplace(subscribe.track_alias, std::move(track))
+             .second) {
+      return;
+    }
+    session->upstream_by_id_.try_emplace(subscribe.subscribe_id, track_ptr);
+    session->upstream_by_name_.try_emplace(subscribe.full_track_name,
+                                           track_ptr);
}
static MoqtObjectListener* AddSubscription(
@@ -109,8 +115,9 @@
session->peer_role_ = role;
}
-  static RemoteTrack& remote_track(MoqtSession* session, uint64_t track_alias) {
-    return session->remote_tracks_.find(track_alias)->second;
+  static SubscribeRemoteTrack* remote_track(MoqtSession* session,
+                                            uint64_t track_alias) {
+    return session->RemoteTrackByAlias(track_alias);  // TODO: null if absent?
}
static void set_next_subscribe_id(MoqtSession* session, uint64_t id) {
diff --git a/quiche/quic/moqt/tools/chat_client.cc b/quiche/quic/moqt/tools/chat_client.cc
index 28c93c1..e98cf75 100644
--- a/quiche/quic/moqt/tools/chat_client.cc
+++ b/quiche/quic/moqt/tools/chat_client.cc
@@ -109,6 +109,7 @@
void ChatClient::RemoteTrackVisitor::OnReply(
const FullTrackName& full_track_name,
+ std::optional<FullSequence> /*largest_id*/,
std::optional<absl::string_view> reason_phrase) {
client_->subscribes_to_make_--;
if (full_track_name == client_->chat_strings_->GetCatalogName()) {
@@ -200,7 +201,7 @@
}
void ChatClient::ProcessCatalog(absl::string_view object,
- RemoteTrack::Visitor* visitor,
+ SubscribeRemoteTrack::Visitor* visitor,
uint64_t group_sequence,
uint64_t object_sequence) {
std::string message(object);
@@ -245,7 +246,7 @@
continue;
}
if (!add) {
- // TODO: Unsubscribe from the user that's leaving
+ session_->Unsubscribe(chat_strings_->GetFullTrackNameFromUsername(user));
std::cout << user << "left the chat\n";
other_users_.erase(user);
continue;
diff --git a/quiche/quic/moqt/tools/chat_client.h b/quiche/quic/moqt/tools/chat_client.h
index 024b78d..e2c341f 100644
--- a/quiche/quic/moqt/tools/chat_client.h
+++ b/quiche/quic/moqt/tools/chat_client.h
@@ -88,11 +88,13 @@
quic::QuicEventLoop* event_loop() { return event_loop_; }
- class QUICHE_EXPORT RemoteTrackVisitor : public moqt::RemoteTrack::Visitor {
+ class QUICHE_EXPORT RemoteTrackVisitor
+ : public moqt::SubscribeRemoteTrack::Visitor {
public:
RemoteTrackVisitor(ChatClient* client) : client_(client) {}
void OnReply(const moqt::FullTrackName& full_track_name,
+ std::optional<FullSequence> largest_id,
std::optional<absl::string_view> reason_phrase) override;
void OnCanAckObjects(MoqtObjectAckFunction) override {}
@@ -127,7 +129,7 @@
// Objects from the same catalog group arrive on the same stream, and in
// object sequence order.
void ProcessCatalog(absl::string_view object,
- moqt::RemoteTrack::Visitor* visitor,
+ moqt::SubscribeRemoteTrack::Visitor* visitor,
uint64_t group_sequence, uint64_t object_sequence);
struct ChatUser {
diff --git a/quiche/quic/moqt/tools/chat_server.cc b/quiche/quic/moqt/tools/chat_server.cc
index c4f0327..3f31a47 100644
--- a/quiche/quic/moqt/tools/chat_server.cc
+++ b/quiche/quic/moqt/tools/chat_server.cc
@@ -72,6 +72,7 @@
void ChatServer::RemoteTrackVisitor::OnReply(
const moqt::FullTrackName& full_track_name,
+ std::optional<FullSequence> /*largest_id*/,
std::optional<absl::string_view> reason_phrase) {
std::cout << "Subscription to user "
<< server_->strings().GetUsernameFromFullTrackName(full_track_name)
diff --git a/quiche/quic/moqt/tools/chat_server.h b/quiche/quic/moqt/tools/chat_server.h
index 1af6e5c..0c883ed 100644
--- a/quiche/quic/moqt/tools/chat_server.h
+++ b/quiche/quic/moqt/tools/chat_server.h
@@ -35,10 +35,11 @@
absl::string_view chat_id, absl::string_view output_file);
~ChatServer();
- class RemoteTrackVisitor : public RemoteTrack::Visitor {
+ class RemoteTrackVisitor : public SubscribeRemoteTrack::Visitor {
public:
explicit RemoteTrackVisitor(ChatServer* server);
void OnReply(const moqt::FullTrackName& full_track_name,
+ std::optional<FullSequence> largest_id,
std::optional<absl::string_view> reason_phrase) override;
void OnCanAckObjects(MoqtObjectAckFunction) override {}
void OnObjectFragment(
diff --git a/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc b/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc
index f31fff4..4b7aa2d 100644
--- a/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc
+++ b/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc
@@ -160,13 +160,14 @@
}
private:
- class NamespaceHandler : public RemoteTrack::Visitor {
+ class NamespaceHandler : public SubscribeRemoteTrack::Visitor {
public:
explicit NamespaceHandler(absl::string_view directory)
: directory_(directory) {}
void OnReply(
const FullTrackName& full_track_name,
+ std::optional<FullSequence> /*largest_id*/,
std::optional<absl::string_view> error_reason_phrase) override {
if (error_reason_phrase.has_value()) {
QUICHE_LOG(ERROR) << "Failed to subscribe to the peer track "
diff --git a/quiche/quic/moqt/tools/moqt_mock_visitor.h b/quiche/quic/moqt/tools/moqt_mock_visitor.h
index 972fc0c..7031696 100644
--- a/quiche/quic/moqt/tools/moqt_mock_visitor.h
+++ b/quiche/quic/moqt/tools/moqt_mock_visitor.h
@@ -77,10 +77,11 @@
FullTrackName track_name_;
};
-class MockRemoteTrackVisitor : public RemoteTrack::Visitor {
+class MockSubscribeRemoteTrackVisitor : public SubscribeRemoteTrack::Visitor {
public:
MOCK_METHOD(void, OnReply,
(const FullTrackName& full_track_name,
+ std::optional<FullSequence> largest_id,
std::optional<absl::string_view> error_reason_phrase),
(override));
MOCK_METHOD(void, OnCanAckObjects, (MoqtObjectAckFunction ack_function),
diff --git a/quiche/quic/moqt/tools/moqt_simulator_bin.cc b/quiche/quic/moqt/tools/moqt_simulator_bin.cc
index 39cf2c8..b2a30e6 100644
--- a/quiche/quic/moqt/tools/moqt_simulator_bin.cc
+++ b/quiche/quic/moqt/tools/moqt_simulator_bin.cc
@@ -186,12 +186,13 @@
std::vector<QuicBandwidth> bitrate_history_;
};
-class ObjectReceiver : public RemoteTrack::Visitor {
+class ObjectReceiver : public SubscribeRemoteTrack::Visitor {
public:
explicit ObjectReceiver(const QuicClock* clock, QuicTimeDelta deadline)
: clock_(clock), deadline_(deadline) {}
void OnReply(const FullTrackName& full_track_name,
+ std::optional<FullSequence> /*largest_id*/,
std::optional<absl::string_view> error_reason_phrase) override {
QUICHE_CHECK(full_track_name == TrackName());
QUICHE_CHECK(!error_reason_phrase.has_value()) << *error_reason_phrase;