// blob: 7ac24126eeb77c45d769487440196cb17adb83e9 [file] [log] [blame]
// Copyright 2023 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef QUICHE_QUIC_MOQT_TOOLS_MOQT_MOCK_VISITOR_H_
#define QUICHE_QUIC_MOQT_TOOLS_MOQT_MOCK_VISITOR_H_
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <variant>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "quiche/quic/core/quic_time.h"
#include "quiche/quic/moqt/moqt_fetch_task.h"
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_object.h"
#include "quiche/quic/moqt/moqt_priority.h"
#include "quiche/quic/moqt/moqt_publisher.h"
#include "quiche/quic/moqt/moqt_session.h"
#include "quiche/quic/moqt/moqt_session_callbacks.h"
#include "quiche/quic/moqt/moqt_session_interface.h"
#include "quiche/common/platform/api/quiche_test.h"
#include "quiche/common/quiche_mem_slice.h"
#include "quiche/web_transport/web_transport.h"
namespace moqt::test {
// One gMock function per session-level callback, plus a helper that packages
// them into a MoqtSessionCallbacks. Tests set expectations on the members and
// pass AsSessionCallbacks() to the code under test.
struct MockSessionCallbacks {
  testing::MockFunction<void()> session_established_callback;
  testing::MockFunction<void(absl::string_view)> goaway_received_callback;
  testing::MockFunction<void(absl::string_view)> session_terminated_callback;
  testing::MockFunction<void()> session_deleted_callback;
  testing::MockFunction<void(const TrackNamespace&,
                             std::optional<VersionSpecificParameters>,
                             MoqtResponseCallback)>
      incoming_publish_namespace_callback;
  testing::MockFunction<void(const TrackNamespace&,
                             std::optional<VersionSpecificParameters>,
                             MoqtResponseCallback)>
      incoming_subscribe_namespace_callback;

  // The two namespace callbacks fall through to the Default* handlers unless
  // a test installs its own action. The remaining mocks get no default here.
  MockSessionCallbacks() {
    ON_CALL(incoming_subscribe_namespace_callback, Call)
        .WillByDefault(DefaultIncomingSubscribeNamespaceCallback);
    ON_CALL(incoming_publish_namespace_callback, Call)
        .WillByDefault(DefaultIncomingPublishNamespaceCallback);
  }

  // Builds a MoqtSessionCallbacks whose entries forward to the mocks above.
  // The initializers are positional, so their order must stay in sync with
  // the field order of MoqtSessionCallbacks.
  MoqtSessionCallbacks AsSessionCallbacks() {
    return MoqtSessionCallbacks{
        /*session established*/ session_established_callback.AsStdFunction(),
        /*goaway received*/ goaway_received_callback.AsStdFunction(),
        /*session terminated*/ session_terminated_callback.AsStdFunction(),
        /*session deleted*/ session_deleted_callback.AsStdFunction(),
        /*incoming publish namespace*/
        incoming_publish_namespace_callback.AsStdFunction(),
        /*incoming subscribe namespace*/
        incoming_subscribe_namespace_callback.AsStdFunction()};
  }
};
// gMock implementation of MoqtTrackPublisher. The track name is the only real
// state; every other virtual is a mock method that tests can program.
class MockTrackPublisher : public MoqtTrackPublisher {
 public:
  explicit MockTrackPublisher(FullTrackName name)
      : track_name_(std::move(name)) {
    // Without an explicit expectation, delivery_order() reports kAscending.
    ON_CALL(*this, delivery_order())
        .WillByDefault(testing::Return(MoqtDeliveryOrder::kAscending));
  }

  // Not mocked: always returns the name given at construction.
  const FullTrackName& GetTrackName() const override { return track_name_; }

  // Listener registration.
  MOCK_METHOD(void, AddObjectListener, (MoqtObjectListener * listener),
              (override));
  MOCK_METHOD(void, RemoveObjectListener, (MoqtObjectListener * listener),
              (override));

  // Object lookup.
  MOCK_METHOD(std::optional<PublishedObject>, GetCachedObject,
              (uint64_t, uint64_t, uint64_t), (const, override));

  // Track properties.
  MOCK_METHOD(std::optional<Location>, largest_location, (),
              (const, override));
  MOCK_METHOD(std::optional<MoqtForwardingPreference>, forwarding_preference,
              (), (const, override));
  MOCK_METHOD(std::optional<MoqtDeliveryOrder>, delivery_order, (),
              (const, override));
  MOCK_METHOD(std::optional<quic::QuicTimeDelta>, expiration, (),
              (const, override));

  // Fetch entry points.
  MOCK_METHOD(std::unique_ptr<MoqtFetchTask>, StandaloneFetch,
              (Location, Location, std::optional<MoqtDeliveryOrder>),
              (override));
  MOCK_METHOD(std::unique_ptr<MoqtFetchTask>, RelativeFetch,
              (uint64_t, std::optional<MoqtDeliveryOrder>), (override));
  MOCK_METHOD(std::unique_ptr<MoqtFetchTask>, AbsoluteFetch,
              (uint64_t, std::optional<MoqtDeliveryOrder>), (override));

 private:
  FullTrackName track_name_;
};
// A very simple MoqtTrackPublisher that allows tests to add arbitrary objects.
class TestTrackPublisher : public MoqtTrackPublisher {
public:
explicit TestTrackPublisher(FullTrackName name)
: track_name_(std::move(name)) {}
const FullTrackName& GetTrackName() const override { return track_name_; }
std::optional<PublishedObject> GetCachedObject(
uint64_t group, uint64_t subgroup, uint64_t object) const override {
Location location(group, object);
auto it = objects_.find(location);
if (it == objects_.end()) {
return std::nullopt;
}
return CachedObjectToPublishedObject(it->second);
}
void AddObjectListener(MoqtObjectListener* listener) override {
listeners_.insert(listener);
listener->OnSubscribeAccepted();
}
void RemoveObjectListener(MoqtObjectListener* listener) override {
listeners_.erase(listener);
}
std::optional<Location> largest_location() const override {
return largest_location_;
}
std::optional<MoqtForwardingPreference> forwarding_preference()
const override {
return MoqtForwardingPreference::kSubgroup;
}
std::optional<MoqtDeliveryOrder> delivery_order() const override {
return MoqtDeliveryOrder::kAscending;
}
std::optional<quic::QuicTimeDelta> expiration() const override {
return quic::QuicTimeDelta::Infinite();
}
// TODO(martinduke): Support Fetch
std::unique_ptr<MoqtFetchTask> StandaloneFetch(
Location start, Location end,
std::optional<MoqtDeliveryOrder> delivery_order) override {
return std::make_unique<MoqtFailedFetch>(
absl::UnimplementedError("Fetch not implemented"));
}
std::unique_ptr<MoqtFetchTask> RelativeFetch(
uint64_t offset,
std::optional<MoqtDeliveryOrder> delivery_order) override {
return std::make_unique<MoqtFailedFetch>(
absl::UnimplementedError("Fetch not implemented"));
}
std::unique_ptr<MoqtFetchTask> AbsoluteFetch(
uint64_t offset,
std::optional<MoqtDeliveryOrder> delivery_order) override {
return std::make_unique<MoqtFailedFetch>(
absl::UnimplementedError("Fetch not implemented"));
}
void AddObject(Location location, uint64_t subgroup,
absl::string_view payload, bool fin) {
CachedObject object;
object.metadata.location = location;
object.metadata.subgroup = subgroup;
object.metadata.extensions = "";
object.metadata.status = MoqtObjectStatus::kNormal;
object.metadata.publisher_priority = 128;
object.payload = std::make_shared<quiche::QuicheMemSlice>(
quiche::QuicheMemSlice::Copy(payload));
object.fin_after_this = fin;
objects_[location] = std::move(object);
if (!largest_location_.has_value() || *largest_location_ < location) {
largest_location_ = location;
}
for (MoqtObjectListener* listener : listeners_) {
listener->OnNewObjectAvailable(location, subgroup, 128,
MoqtForwardingPreference::kSubgroup);
}
}
void RemoveAllSubscriptions() {
while (!listeners_.empty()) {
(*listeners_.begin())->OnTrackPublisherGone();
}
}
private:
FullTrackName track_name_;
absl::flat_hash_set<MoqtObjectListener*> listeners_;
absl::flat_hash_map<Location, CachedObject> objects_;
std::optional<Location> largest_location_;
};
// gMock implementation of SubscribeVisitor: every override below is a mock
// method, so tests can set expectations on each visitor notification.
// TODO(martinduke): Rename to MockSubscribeVisitor.
class MockSubscribeRemoteTrackVisitor : public SubscribeVisitor {
public:
// The variant parameter is wrapped in an extra pair of parentheses because its
// type contains a comma, which MOCK_METHOD would otherwise split on.
MOCK_METHOD(void, OnReply,
(const FullTrackName& full_track_name,
(std::variant<SubscribeOkData, MoqtRequestError> response)),
(override));
MOCK_METHOD(void, OnCanAckObjects, (MoqtObjectAckFunction ack_function),
(override));
MOCK_METHOD(void, OnObjectFragment,
(const FullTrackName& full_track_name,
const PublishedObjectMetadata& metadata,
absl::string_view object, bool end_of_message),
(override));
// Note: unlike the other callbacks, this one takes the track name by value.
MOCK_METHOD(void, OnPublishDone, (FullTrackName full_track_name), (override));
MOCK_METHOD(void, OnMalformedTrack, (const FullTrackName& full_track_name),
(override));
MOCK_METHOD(void, OnStreamFin,
(const FullTrackName& full_track_name, DataStreamIndex stream),
(override));
MOCK_METHOD(void, OnStreamReset,
(const FullTrackName& full_track_name, DataStreamIndex stream),
(override));
};
// gMock implementation of MoqtPublishingMonitorInterface; all three overrides
// are mock methods with no default actions configured here.
class MockPublishingMonitorInterface : public MoqtPublishingMonitorInterface {
public:
// `time_window` is optional; presumably nullopt means object ACKs are not
// supported -- confirm against MoqtPublishingMonitorInterface.
MOCK_METHOD(void, OnObjectAckSupportKnown,
(std::optional<quic::QuicTimeDelta> time_window), (override));
MOCK_METHOD(void, OnNewObjectEnqueued, (Location location), (override));
MOCK_METHOD(void, OnObjectAckReceived,
(Location location, quic::QuicTimeDelta delta_from_deadline),
(override));
};
// MoqtFetchTask test double. GetNextObject() and GetStatus() are gMock
// methods; the two callback setters are real code so that a test can either
//  - have the fetch response / "objects available" signal delivered
//    synchronously from inside the setter (three-argument constructor), or
//  - capture the callbacks and fire them later through
//    CallFetchResponseCallback() / CallObjectsAvailableCallback().
class MockFetchTask : public MoqtFetchTask {
 public:
  // No synchronous callbacks.
  MockFetchTask() = default;

  // `fetch_ok` / `fetch_error`: if one of them is set, it is delivered
  // immediately to any callback passed to SetFetchResponseCallback(). At most
  // one of the two may be set.
  // `synchronous_object_available`: if true, the very first callback passed
  // to SetObjectAvailableCallback() is invoked immediately.
  MockFetchTask(std::optional<MoqtFetchOk> fetch_ok,
                std::optional<MoqtFetchError> fetch_error,
                bool synchronous_object_available)
      : synchronous_fetch_ok_(std::move(fetch_ok)),
        synchronous_fetch_error_(std::move(fetch_error)),
        synchronous_object_available_(synchronous_object_available) {
    QUICHE_DCHECK(!synchronous_fetch_ok_.has_value() ||
                  !synchronous_fetch_error_.has_value());
  }

  MOCK_METHOD(MoqtFetchTask::GetNextObjectResult, GetNextObject,
              (PublishedObject & output), (override));
  MOCK_METHOD(absl::Status, GetStatus, (), (override));

  // Stores `callback`, replacing any previous one, and invokes it right away
  // when synchronous delivery is enabled.
  void SetObjectAvailableCallback(ObjectsAvailableCallback callback) override {
    objects_available_callback_ = std::move(callback);
    if (synchronous_object_available_) {
      // The first call is installed by the session to trigger stream creation.
      // An object might not exist yet.
      objects_available_callback_();
    }
    // The second call is a result of the stream replacing the callback, which
    // means there is an object available.
    synchronous_object_available_ = true;
  }

  // If a synchronous response was supplied at construction, `callback` is
  // consumed immediately with that response and is not stored. Otherwise it
  // is kept until CallFetchResponseCallback() is called.
  void SetFetchResponseCallback(FetchResponseCallback callback) override {
    if (synchronous_fetch_ok_.has_value()) {
      std::move(callback)(*synchronous_fetch_ok_);
      return;
    }
    if (synchronous_fetch_error_.has_value()) {
      std::move(callback)(*synchronous_fetch_error_);
      return;
    }
    fetch_response_callback_ = std::move(callback);
  }

  // Invokes the most recently installed objects-available callback. A
  // callback must have been installed via SetObjectAvailableCallback() first.
  void CallObjectsAvailableCallback() { objects_available_callback_(); }

  // Consumes the stored fetch-response callback, handing it `response`. Only
  // valid after SetFetchResponseCallback() stored a callback (i.e. no
  // synchronous response was configured); the stored callback is moved from,
  // so this is presumably one-shot per stored callback.
  void CallFetchResponseCallback(
      std::variant<MoqtFetchOk, MoqtFetchError> response) {
    std::move(fetch_response_callback_)(std::move(response));
  }

 private:
  FetchResponseCallback fetch_response_callback_;
  ObjectsAvailableCallback objects_available_callback_;
  std::optional<MoqtFetchOk> synchronous_fetch_ok_;
  std::optional<MoqtFetchError> synchronous_fetch_error_;
  bool synchronous_object_available_ = false;
};
// gMock implementation of MoqtObjectListener; every listener notification is
// a mock method with no default action configured here.
class MockMoqtObjectListener : public MoqtObjectListener {
public:
MOCK_METHOD(void, OnSubscribeAccepted, (), (override));
MOCK_METHOD(void, OnSubscribeRejected, (MoqtRequestError), (override));
// Parameters are unnamed; presumably (location, subgroup, publisher priority,
// forwarding preference) -- confirm against MoqtObjectListener.
MOCK_METHOD(void, OnNewObjectAvailable,
(Location, uint64_t, MoqtPriority, MoqtForwardingPreference),
(override));
MOCK_METHOD(void, OnNewFinAvailable, (Location, uint64_t), (override));
// Presumably (group, subgroup, stream error code) -- confirm against
// MoqtObjectListener.
MOCK_METHOD(void, OnSubgroupAbandoned,
(uint64_t, uint64_t, webtransport::StreamErrorCode), (override));
MOCK_METHOD(void, OnGroupAbandoned, (uint64_t), (override));
MOCK_METHOD(void, OnTrackPublisherGone, (), (override));
};
} // namespace moqt::test
#endif // QUICHE_QUIC_MOQT_TOOLS_MOQT_MOCK_VISITOR_H_