blob: d9c2e1bca438e46ae20d6ec00272fc586e290300 [file] [log] [blame]
// Copyright 2025 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "quiche/quic/moqt/test_tools/mock_moqt_session.h"
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "quiche/quic/moqt/moqt_failed_fetch.h"
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_priority.h"
#include "quiche/quic/moqt/moqt_publisher.h"
#include "quiche/quic/moqt/moqt_subscribe_windows.h"
#include "quiche/quic/moqt/moqt_track.h"
#include "quiche/common/platform/api/quiche_logging.h"
#include "quiche/common/platform/api/quiche_test.h"
#include "quiche/web_transport/web_transport.h"
namespace moqt::test {
namespace {
using ::testing::_;
}
// Object listener that forwards all of the objects published on a track to the
// SubscribeRemoteTrack::Visitor provided, skipping objects that fall outside
// the subscription's SubscribeWindow.
//
// The listener attaches itself to the track publisher when constructed and
// detaches when destroyed, so the loopback subscription lasts exactly as long
// as the listener object.
class MockMoqtSession::LoopbackObjectListener : public MoqtObjectListener {
 public:
  // `visitor` is not owned; it is invoked from the listener callbacks for as
  // long as this object is registered with `publisher`.
  LoopbackObjectListener(FullTrackName name,
                         SubscribeRemoteTrack::Visitor* visitor,
                         std::shared_ptr<MoqtTrackPublisher> publisher,
                         SubscribeWindow window)
      : name_(std::move(name)),  // `name` is a by-value sink; avoid a copy.
        visitor_(visitor),
        publisher_(std::move(publisher)),
        window_(std::move(window)) {
    publisher_->AddObjectListener(this);
  }
  ~LoopbackObjectListener() { publisher_->RemoveObjectListener(this); }

  // The publisher is handed `this`, so the listener must stay at a fixed
  // address: neither copyable nor movable.
  LoopbackObjectListener(const LoopbackObjectListener&) = delete;
  LoopbackObjectListener(LoopbackObjectListener&&) = delete;
  LoopbackObjectListener& operator=(const LoopbackObjectListener&) = delete;
  LoopbackObjectListener& operator=(LoopbackObjectListener&&) = delete;

  // Reports a successful subscription to the visitor, including the largest
  // location only when the track status says objects have been published.
  void OnSubscribeAccepted() override {
    visitor_->OnReply(name_,
                      HasObjects()
                          ? std::make_optional(publisher_->GetLargestLocation())
                          : std::nullopt,
                      std::nullopt);
  }

  // Reports a failed subscription to the visitor using the rejection's reason
  // phrase. `track_alias` is not used by the loopback.
  void OnSubscribeRejected(MoqtSubscribeErrorReason reason,
                           std::optional<uint64_t> track_alias) {
    visitor_->OnReply(name_, std::nullopt, reason.reason_phrase);
  }

  // Fetches the newly announced object from the publisher's cache and, if it
  // lies inside the subscription window, hands it to the visitor as a single
  // complete fragment.
  void OnNewObjectAvailable(Location sequence) {
    std::optional<PublishedObject> object =
        publisher_->GetCachedObject(sequence);
    if (!object.has_value()) {
      // The publisher announced an object it cannot return; this is a bug in
      // the publisher under test.
      QUICHE_LOG(FATAL)
          << "GetCachedObject() returned nullopt for a sequence passed into "
             "OnNewObjectAvailable()";
      return;
    }
    if (!window_.InWindow(object->sequence)) {
      return;
    }
    visitor_->OnObjectFragment(name_, sequence, object->publisher_priority,
                               object->status, object->payload.AsStringView(),
                               /*end_of_message=*/true);
  }

  // Fin, subgroup-abandonment and group-abandonment events are intentionally
  // not forwarded by the loopback.
  void OnNewFinAvailable(Location sequence) override {}
  void OnSubgroupAbandoned(Location sequence,
                           webtransport::StreamErrorCode error_code) override {}
  void OnGroupAbandoned(uint64_t group_id) override {}

  // The publisher is going away, so tell the visitor the subscription is done.
  void OnTrackPublisherGone() override { visitor_->OnSubscribeDone(name_); }

 private:
  // Returns true when the publisher reports the track as kInProgress or
  // kFinished. A track-status error is treated as "no objects".
  bool HasObjects() const {
    absl::StatusOr<MoqtTrackStatusCode> status = publisher_->GetTrackStatus();
    if (!status.ok()) {
      return false;
    }
    return *status == MoqtTrackStatusCode::kInProgress ||
           *status == MoqtTrackStatusCode::kFinished;
  }

  FullTrackName name_;
  SubscribeRemoteTrack::Visitor* visitor_;  // Not owned.
  std::shared_ptr<MoqtTrackPublisher> publisher_;
  SubscribeWindow window_;
};
// Loopback subscription shared by the default SubscribeCurrentObject and
// SubscribeAbsolute actions. Looks up `name` in `publisher_` and, on success,
// attaches a LoopbackObjectListener that forwards the track's objects within
// `window` to `visitor`.
//
// Returns false without notifying `visitor` if `name` is already subscribed.
// Returns false after calling visitor->OnReply() with the lookup error if the
// track cannot be found. Returns true otherwise.
bool MockMoqtSession::Subscribe(const FullTrackName& name,
                                SubscribeRemoteTrack::Visitor* visitor,
                                SubscribeWindow window) {
  // Reject duplicates before constructing a listener. The listener registers
  // itself with the track publisher in its constructor; if the publisher
  // reacts synchronously (e.g. by calling OnSubscribeAccepted()), `visitor`
  // would otherwise receive callbacks for a subscription that is immediately
  // discarded because the map insertion fails.
  if (receiving_subscriptions_.find(name) != receiving_subscriptions_.end()) {
    return false;
  }
  auto track_publisher = publisher_->GetTrack(name);
  if (!track_publisher.ok()) {
    visitor->OnReply(name, std::nullopt, track_publisher.status().ToString());
    return false;
  }
  // `.second` still guards against `name` having been inserted re-entrantly
  // while the listener was being constructed.
  return receiving_subscriptions_
      .emplace(name, std::make_unique<LoopbackObjectListener>(
                         name, visitor, *std::move(track_publisher),
                         std::move(window)))
      .second;
}
// Creates a mock session.
//
// Any Error() call that the test does not set its own expectation or action
// for is reported as a test failure.
//
// When `publisher` is non-null, the subscribe, unsubscribe, fetch and
// joining-fetch methods default to a loopback implementation that serves
// tracks directly from `publisher`. When it is null, those methods keep
// gMock's built-in default actions.
MockMoqtSession::MockMoqtSession(MoqtPublisher* publisher)
: publisher_(publisher) {
// Make unexpected fatal session errors fail the test instead of passing
// silently.
ON_CALL(*this, Error)
.WillByDefault([](MoqtError code, absl::string_view error) {
ADD_FAILURE() << "Unhandled MoQT fatal error, with code "
<< static_cast<int>(code) << " and message: " << error;
});
if (publisher_ != nullptr) {
// All subscribe flavors share Subscribe(); they differ only in the
// SubscribeWindow they build. SubscribeCurrentObject uses a
// default-constructed window.
ON_CALL(*this, SubscribeCurrentObject)
.WillByDefault([this](const FullTrackName& name,
SubscribeRemoteTrack::Visitor* visitor,
VersionSpecificParameters) {
return Subscribe(name, visitor, SubscribeWindow());
});
// The number of `_` matchers selects the SubscribeAbsolute overload: five
// arguments for an open-ended window starting at (start_group, start_object).
ON_CALL(*this, SubscribeAbsolute(_, _, _, _, _))
.WillByDefault([this](const FullTrackName& name, uint64_t start_group,
uint64_t start_object,
SubscribeRemoteTrack::Visitor* visitor,
VersionSpecificParameters) {
return Subscribe(
name, visitor,
SubscribeWindow(Location(start_group, start_object)));
});
// Six arguments for a window that also has an end group.
ON_CALL(*this, SubscribeAbsolute(_, _, _, _, _, _))
.WillByDefault([this](const FullTrackName& name, uint64_t start_group,
uint64_t start_object, uint64_t end_group,
SubscribeRemoteTrack::Visitor* visitor,
VersionSpecificParameters) {
return Subscribe(
name, visitor,
SubscribeWindow(Location(start_group, start_object), end_group));
});
// Erasing the entry destroys the LoopbackObjectListener, whose destructor
// detaches it from the track publisher.
ON_CALL(*this, Unsubscribe)
.WillByDefault([this](const FullTrackName& name) {
receiving_subscriptions_.erase(name);
});
// Fetch invokes `callback` before returning, with either a MoqtFailedFetch
// (unknown track) or the track publisher's own fetch task. It always
// returns true. `priority` and `parameters` are ignored, and a missing
// delivery order defaults to kAscending.
ON_CALL(*this, Fetch)
.WillByDefault(
[this](const FullTrackName& name, FetchResponseCallback callback,
Location start, uint64_t end_group,
std::optional<uint64_t> end_object, MoqtPriority priority,
std::optional<MoqtDeliveryOrder> delivery_order,
VersionSpecificParameters parameters) {
auto track_publisher = publisher_->GetTrack(name);
if (!track_publisher.ok()) {
std::move(callback)(std::make_unique<MoqtFailedFetch>(
track_publisher.status()));
return true;
}
std::move(callback)(track_publisher->get()->Fetch(
start, end_group, end_object,
delivery_order.value_or(MoqtDeliveryOrder::kAscending)));
return true;
});
// The four-argument JoiningFetch delegates to the seven-argument overload
// with priority 0x80 and no delivery-order preference. Its callback passes
// fetched objects to `visitor`, each as a complete fragment, until
// GetNextObject() returns something other than kSuccess.
ON_CALL(*this, JoiningFetch(_, _, _, _))
.WillByDefault([this](const FullTrackName& name,
SubscribeRemoteTrack::Visitor* visitor,
uint64_t num_previous_groups,
VersionSpecificParameters parameters) {
return JoiningFetch(
name, visitor,
[name, visitor](std::unique_ptr<MoqtFetchTask> fetch) {
PublishedObject object;
while (fetch->GetNextObject(object) ==
MoqtFetchTask::kSuccess) {
visitor->OnObjectFragment(
name, object.sequence, object.publisher_priority,
object.status, object.payload.AsStringView(), true);
}
},
num_previous_groups, 0x80, std::nullopt, parameters);
});
// The seven-argument JoiningFetch subscribes `visitor` to new objects, then
// fetches earlier groups up to and including the largest location
// published so far.
ON_CALL(*this, JoiningFetch(_, _, _, _, _, _, _))
.WillByDefault([this](const FullTrackName& name,
SubscribeRemoteTrack::Visitor* visitor,
FetchResponseCallback callback,
uint64_t num_previous_groups,
MoqtPriority priority,
std::optional<MoqtDeliveryOrder> delivery_order,
VersionSpecificParameters parameters) {
// The subscribe result is not checked. An unknown track is reported to
// `callback` through MoqtFailedFetch below.
SubscribeCurrentObject(name, visitor, parameters);
auto track_publisher = publisher_->GetTrack(name);
if (!track_publisher.ok()) {
std::move(callback)(
std::make_unique<MoqtFailedFetch>(track_publisher.status()));
return true;
}
// If nothing has been published yet, fetch the single location (0, 0).
// A track-status error maps to kStatusNotAvailable and so takes the
// normal path below.
if (track_publisher->get()->GetTrackStatus().value_or(
MoqtTrackStatusCode::kStatusNotAvailable) ==
MoqtTrackStatusCode::kNotYetBegun) {
return Fetch(name, std::move(callback), Location(0, 0), 0, 0,
priority, delivery_order, std::move(parameters));
}
// Fetch the last `num_previous_groups` groups, counting the largest
// group itself, clamped at group 0.
// NOTE(review): when num_previous_groups is 0, start_group becomes
// largest.group + 1, which is past the end group largest.group.
// Confirm callers never pass 0.
Location largest = track_publisher->get()->GetLargestLocation();
uint64_t start_group = largest.group >= num_previous_groups
? largest.group - num_previous_groups + 1
: 0;
return Fetch(name, std::move(callback), Location(start_group, 0),
largest.group, largest.object, priority, delivery_order,
std::move(parameters));
});
}
}
// Presumably defined out of line because LoopbackObjectListener is defined in
// this file rather than the header, and the unique_ptrs held in
// receiving_subscriptions_ need the complete type to be destroyed. Destroying
// each listener detaches it from its track publisher.
MockMoqtSession::~MockMoqtSession() = default;
} // namespace moqt::test