Publisher-side MoQT FETCH handling.
Also:
a. Enforce that subscribe_id is monotonically increasing.
b. Use the cached object publisher_priority instead of the track-level pub priority.
PiperOrigin-RevId: 698869614
diff --git a/quiche/quic/moqt/moqt_failed_fetch.h b/quiche/quic/moqt/moqt_failed_fetch.h
index a1e93bf..40e8d67 100644
--- a/quiche/quic/moqt/moqt_failed_fetch.h
+++ b/quiche/quic/moqt/moqt_failed_fetch.h
@@ -8,6 +8,7 @@
#include <utility>
#include "absl/status/status.h"
+#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_publisher.h"
namespace moqt {
@@ -23,6 +24,7 @@
absl::Status GetStatus() override { return status_; }
void SetObjectAvailableCallback(
ObjectsAvailableCallback /*callback*/) override {}
+ FullSequence GetLargestId() const override { return FullSequence(); }
private:
absl::Status status_;
diff --git a/quiche/quic/moqt/moqt_outgoing_queue.h b/quiche/quic/moqt/moqt_outgoing_queue.h
index 2bc46ee..83d7773 100644
--- a/quiche/quic/moqt/moqt_outgoing_queue.h
+++ b/quiche/quic/moqt/moqt_outgoing_queue.h
@@ -102,6 +102,7 @@
GetNextObjectResult GetNextObject(PublishedObject&) override;
absl::Status GetStatus() override { return status_; }
+ FullSequence GetLargestId() const override { return objects_.back(); }
void SetObjectAvailableCallback(
ObjectsAvailableCallback /*callback*/) override {
diff --git a/quiche/quic/moqt/moqt_publisher.h b/quiche/quic/moqt/moqt_publisher.h
index a0897f4..010db7f 100644
--- a/quiche/quic/moqt/moqt_publisher.h
+++ b/quiche/quic/moqt/moqt_publisher.h
@@ -84,8 +84,9 @@
// Returns the error if fetch has completely failed, and OK otherwise.
virtual absl::Status GetStatus() = 0;
- // TODO: expose the largest sequence and the end of track bit returned in
- // the FETCH_OK.
+ // Returns the highest sequence number that will be delivered by the fetch.
+ // It is the minimum of the end of the fetch range and the live edge.
+ virtual FullSequence GetLargestId() const = 0;
};
// MoqtTrackPublisher is an application-side API for an MoQT publisher
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index 0f8fad7..ec7a867 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -339,6 +339,46 @@
return Subscribe(message, visitor);
}
+void MoqtSession::PublishedFetch::FetchStreamVisitor::OnCanWrite() {
+ std::shared_ptr<PublishedFetch> fetch = fetch_.lock();
+ if (fetch == nullptr) {
+ return;
+ }
+ PublishedObject object;
+ while (stream_->CanWrite()) {
+ MoqtFetchTask::GetNextObjectResult result =
+ fetch->fetch_task()->GetNextObject(object);
+ switch (result) {
+ case MoqtFetchTask::GetNextObjectResult::kSuccess:
+ // Skip ObjectDoesNotExist in FETCH.
+ if (object.status == MoqtObjectStatus::kObjectDoesNotExist) {
+ continue;
+ }
+ if (fetch->session_->WriteObjectToStream(
+ stream_, fetch->fetch_id_, object,
+ MoqtDataStreamType::kStreamHeaderFetch, !stream_header_written_,
+ /*fin=*/false)) {
+ stream_header_written_ = true;
+ }
+ break;
+ case MoqtFetchTask::GetNextObjectResult::kPending:
+ return;
+ case MoqtFetchTask::GetNextObjectResult::kEof:
+ // TODO(martinduke): Either prefetch the next object, or alter the API
+ // so that we're not sending FIN in a separate frame.
+ if (!quiche::SendFinOnStream(*stream_).ok()) {
+ QUIC_DVLOG(1) << "Sending FIN onStream " << stream_->GetStreamId()
+ << " failed";
+ }
+ return;
+ case MoqtFetchTask::GetNextObjectResult::kError:
+ stream_->ResetWithUserCode(static_cast<webtransport::StreamErrorCode>(
+ fetch->fetch_task()->GetStatus().code()));
+ return;
+ }
+ }
+}
+
bool MoqtSession::SubscribeIsDone(uint64_t subscribe_id, SubscribeDoneCode code,
absl::string_view reason_phrase) {
auto it = published_subscriptions_.find(subscribe_id);
@@ -446,13 +486,36 @@
return new_stream;
}
+bool MoqtSession::OpenDataStream(std::shared_ptr<PublishedFetch> fetch) {
+ webtransport::Stream* new_stream =
+ session_->OpenOutgoingUnidirectionalStream();
+ if (new_stream == nullptr) {
+ QUICHE_BUG(MoqtSession_OpenDataStream_blocked)
+ << "OpenDataStream called when creation of new streams is blocked.";
+ return false;
+ }
+ fetch->SetStreamId(new_stream->GetStreamId());
+ new_stream->SetVisitor(
+ std::make_unique<PublishedFetch::FetchStreamVisitor>(fetch, new_stream));
+ if (new_stream->CanWrite()) {
+ new_stream->visitor()->OnCanWrite();
+ }
+ return true;
+}
+
void MoqtSession::OnCanCreateNewOutgoingUnidirectionalStream() {
while (!subscribes_with_queued_outgoing_data_streams_.empty() &&
session_->CanOpenNextOutgoingUnidirectionalStream()) {
auto next = subscribes_with_queued_outgoing_data_streams_.rbegin();
auto subscription = published_subscriptions_.find(next->subscription_id);
if (subscription == published_subscriptions_.end()) {
- // Subscription no longer exists; delete the entry.
+ auto fetch = incoming_fetches_.find(next->subscription_id);
+ // Create the stream if the fetch still exists.
+ if (fetch != incoming_fetches_.end() && !OpenDataStream(fetch->second)) {
+ return; // A QUIC_BUG has fired because this shouldn't happen.
+ }
+ // FETCH needs only one stream, and can be deleted from the queue. Or,
+ // there is no subscribe and no fetch; the entry in the queue is invalid.
subscribes_with_queued_outgoing_data_streams_.erase((++next).base());
continue;
}
@@ -535,6 +598,28 @@
return std::make_pair(track.full_track_name(), track.visitor());
}
+bool MoqtSession::ValidateSubscribeId(uint64_t subscribe_id) {
+ if (peer_role_ == MoqtRole::kPublisher) {
+ QUIC_DLOG(INFO) << ENDPOINT << "Publisher peer sent SUBSCRIBE";
+ Error(MoqtError::kProtocolViolation, "Received SUBSCRIBE from publisher");
+ return false;
+ }
+ if (subscribe_id > local_max_subscribe_id_) {
+ QUIC_DLOG(INFO) << ENDPOINT << "Received SUBSCRIBE with too large ID";
+ Error(MoqtError::kTooManySubscribes,
+ "Received SUBSCRIBE with too large ID");
+ return false;
+ }
+ if (subscribe_id < next_incoming_subscribe_id_) {
+ QUIC_DLOG(INFO) << ENDPOINT << "Subscribe ID not monotonically increasing";
+ Error(MoqtError::kProtocolViolation,
+ "Subscribe ID not monotonically increasing");
+ return false;
+ }
+ next_incoming_subscribe_id_ = subscribe_id + 1;
+ return true;
+}
+
template <class Parser>
static void ForwardStreamDataToParser(webtransport::Stream& stream,
Parser& parser) {
@@ -647,18 +732,19 @@
session_->framer_.SerializeSubscribeError(subscribe_error));
}
+void MoqtSession::ControlStream::SendFetchError(
+ uint64_t subscribe_id, SubscribeErrorCode error_code,
+ absl::string_view reason_phrase) {
+ MoqtFetchError fetch_error;
+ fetch_error.subscribe_id = subscribe_id;
+ fetch_error.error_code = error_code;
+ fetch_error.reason_phrase = reason_phrase;
+ SendOrBufferMessage(session_->framer_.SerializeFetchError(fetch_error));
+}
+
void MoqtSession::ControlStream::OnSubscribeMessage(
const MoqtSubscribe& message) {
- if (session_->peer_role_ == MoqtRole::kPublisher) {
- QUIC_DLOG(INFO) << ENDPOINT << "Publisher peer sent SUBSCRIBE";
- session_->Error(MoqtError::kProtocolViolation,
- "Received SUBSCRIBE from publisher");
- return;
- }
- if (message.subscribe_id > session_->local_max_subscribe_id_) {
- QUIC_DLOG(INFO) << ENDPOINT << "Received SUBSCRIBE with too large ID";
- session_->Error(MoqtError::kTooManySubscribes,
- "Received SUBSCRIBE with too large ID");
+ if (!session_->ValidateSubscribeId(message.subscribe_id)) {
return;
}
QUIC_DLOG(INFO) << ENDPOINT << "Received a SUBSCRIBE for "
@@ -883,6 +969,66 @@
session_->peer_max_subscribe_id_ = message.max_subscribe_id;
}
+void MoqtSession::ControlStream::OnFetchMessage(const MoqtFetch& message) {
+ if (!session_->ValidateSubscribeId(message.subscribe_id)) {
+ return;
+ }
+ QUIC_DLOG(INFO) << ENDPOINT << "Received a FETCH for "
+ << message.full_track_name;
+
+ const FullTrackName& track_name = message.full_track_name;
+ absl::StatusOr<std::shared_ptr<MoqtTrackPublisher>> track_publisher =
+ session_->publisher_->GetTrack(track_name);
+ if (!track_publisher.ok()) {
+ QUIC_DLOG(INFO) << ENDPOINT << "FETCH for " << track_name
+ << " rejected by the application: "
+ << track_publisher.status();
+ SendFetchError(message.subscribe_id, SubscribeErrorCode::kTrackDoesNotExist,
+ track_publisher.status().message());
+ return;
+ }
+ std::unique_ptr<MoqtFetchTask> fetch =
+ (*track_publisher)
+ ->Fetch(message.start_object, message.end_group, message.end_object,
+ message.group_order.value_or(
+ (*track_publisher)->GetDeliveryOrder()));
+ if (!fetch->GetStatus().ok()) {
+ QUIC_DLOG(INFO) << ENDPOINT << "FETCH for " << track_name
+ << " could not initialize the task";
+ SendFetchError(message.subscribe_id, SubscribeErrorCode::kInvalidRange,
+ fetch->GetStatus().message());
+ return;
+ }
+ auto published_fetch = std::make_unique<PublishedFetch>(
+ message.subscribe_id, session_, std::move(fetch));
+ auto result = session_->incoming_fetches_.emplace(message.subscribe_id,
+ std::move(published_fetch));
+ if (!result.second) { // Emplace failed.
+ QUIC_DLOG(INFO) << ENDPOINT << "FETCH for " << track_name
+ << " could not be added to the session";
+ SendFetchError(message.subscribe_id, SubscribeErrorCode::kInternalError,
+ "Could not initialize FETCH state");
+ return;
+ }
+ MoqtFetchOk fetch_ok;
+ fetch_ok.subscribe_id = message.subscribe_id;
+ fetch_ok.group_order =
+ message.group_order.value_or((*track_publisher)->GetDeliveryOrder());
+ fetch_ok.largest_id = result.first->second->fetch_task()->GetLargestId();
+ SendOrBufferMessage(session_->framer_.SerializeFetchOk(fetch_ok));
+ if (!session_->session()->CanOpenNextOutgoingUnidirectionalStream() ||
+ !session_->OpenDataStream(result.first->second)) {
+ // Put the FETCH in the queue for a new stream.
+ session_->UpdateQueuedSendOrder(
+ message.subscribe_id, std::nullopt,
+ SendOrderForStream(message.subscriber_priority,
+ (*track_publisher)->GetPublisherPriority(),
+ /*group_id=*/0,
+ message.group_order.value_or(
+ (*track_publisher)->GetDeliveryOrder())));
+ }
+}
+
void MoqtSession::ControlStream::OnParsingError(MoqtError error_code,
absl::string_view reason) {
session_->Error(error_code, absl::StrCat("Parse error: ", reason));
@@ -1290,7 +1436,38 @@
<< "Writing FIN failed despite CanWrite() being true.";
return;
}
- SendNextObject(subscription, *std::move(object));
+ QUICHE_DCHECK(next_object_ <= object->sequence);
+ MoqtTrackPublisher& publisher = subscription.publisher();
+ QUICHE_DCHECK(DoesTrackStatusImplyHavingData(*publisher.GetTrackStatus()));
+ MoqtForwardingPreference forwarding_preference =
+ publisher.GetForwardingPreference();
+ UpdateSendOrder(subscription);
+ switch (forwarding_preference) {
+ case MoqtForwardingPreference::kTrack:
+ if (object->status == MoqtObjectStatus::kEndOfGroup ||
+ object->status == MoqtObjectStatus::kGroupDoesNotExist) {
+ ++next_object_.group;
+ next_object_.object = 0;
+ } else {
+ next_object_.object = object->sequence.object + 1;
+ }
+ break;
+
+ case MoqtForwardingPreference::kSubgroup:
+ next_object_.object = object->sequence.object + 1;
+ break;
+
+ case MoqtForwardingPreference::kDatagram:
+ QUICHE_NOTREACHED();
+ break;
+ }
+ if (session_->WriteObjectToStream(
+ stream_, subscription.track_alias(), *object,
+ GetMessageTypeForForwardingPreference(forwarding_preference),
+ !stream_header_written_, object->fin_after_this)) {
+ stream_header_written_ = true;
+ subscription.OnObjectSent(object->sequence);
+ }
}
}
@@ -1305,78 +1482,41 @@
<< "Writing pure FIN failed.";
}
-void MoqtSession::OutgoingDataStream::SendNextObject(
- PublishedSubscription& subscription, PublishedObject object) {
- QUICHE_DCHECK(next_object_ <= object.sequence);
- QUICHE_DCHECK(stream_->CanWrite());
-
- MoqtTrackPublisher& publisher = subscription.publisher();
- QUICHE_DCHECK(DoesTrackStatusImplyHavingData(*publisher.GetTrackStatus()));
- MoqtForwardingPreference forwarding_preference =
- publisher.GetForwardingPreference();
-
- UpdateSendOrder(subscription);
-
+bool MoqtSession::WriteObjectToStream(webtransport::Stream* stream, uint64_t id,
+ const PublishedObject& object,
+ MoqtDataStreamType type,
+ bool is_first_on_stream, bool fin) {
+ QUICHE_DCHECK(stream->CanWrite());
MoqtObject header;
- header.track_alias = subscription.track_alias();
+ header.track_alias = id;
header.group_id = object.sequence.group;
+ header.subgroup_id = object.sequence.subgroup;
header.object_id = object.sequence.object;
- header.publisher_priority = publisher.GetPublisherPriority();
+ header.publisher_priority = object.publisher_priority;
header.object_status = object.status;
- header.forwarding_preference = forwarding_preference;
- // TODO(martinduke): send values other than 0.
- header.subgroup_id =
- (forwarding_preference == MoqtForwardingPreference::kSubgroup)
- ? 0
- : std::optional<uint64_t>();
header.payload_length = object.payload.length();
quiche::QuicheBuffer serialized_header =
- session_->framer_.SerializeObjectHeader(
- header, GetMessageTypeForForwardingPreference(forwarding_preference),
- !stream_header_written_);
- switch (forwarding_preference) {
- case MoqtForwardingPreference::kTrack:
- if (object.status == MoqtObjectStatus::kEndOfGroup ||
- object.status == MoqtObjectStatus::kGroupDoesNotExist) {
- ++next_object_.group;
- next_object_.object = 0;
- } else {
- next_object_.object = header.object_id + 1;
- }
- break;
-
- case MoqtForwardingPreference::kSubgroup:
- next_object_.object = header.object_id + 1;
- break;
-
- case MoqtForwardingPreference::kDatagram:
- QUICHE_NOTREACHED();
- break;
- }
-
+ framer_.SerializeObjectHeader(header, type, is_first_on_stream);
// TODO(vasilvv): add a version of WebTransport write API that accepts
// memslices so that we can avoid a copy here.
std::array<absl::string_view, 2> write_vector = {
serialized_header.AsStringView(), object.payload.AsStringView()};
quiche::StreamWriteOptions options;
- options.set_send_fin(object.fin_after_this);
- absl::Status write_status = stream_->Writev(write_vector, options);
+ options.set_send_fin(fin);
+ absl::Status write_status = stream->Writev(write_vector, options);
if (!write_status.ok()) {
- QUICHE_BUG(MoqtSession_SendNextObject_write_failed)
+ QUICHE_BUG(MoqtSession_WriteObjectToStream_write_failed)
<< "Writing into MoQT stream failed despite CanWrite() being true "
"before; status: "
<< write_status;
- session_->Error(MoqtError::kInternalError, "Data stream write error");
- return;
+ Error(MoqtError::kInternalError, "Data stream write error");
+ return false;
}
- QUIC_DVLOG(1) << "Stream " << stream_->GetStreamId() << " successfully wrote "
- << object.sequence << ", fin = " << object.fin_after_this
- << ", next: " << next_object_;
-
- stream_header_written_ = true;
- subscription.OnObjectSent(object.sequence);
+ QUIC_DVLOG(1) << "Stream " << stream->GetStreamId() << " successfully wrote "
+ << object.sequence << ", fin = " << fin;
+ return true;
}
void MoqtSession::PublishedSubscription::SendDatagram(FullSequence sequence) {
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
index 04972b1..ffae8d7 100644
--- a/quiche/quic/moqt/moqt_session.h
+++ b/quiche/quic/moqt/moqt_session.h
@@ -172,6 +172,8 @@
private:
friend class test::MoqtSessionPeer;
+ struct Empty {};
+
class QUICHE_EXPORT ControlStream : public webtransport::StreamVisitor,
public MoqtControlParserVisitor {
public:
@@ -213,7 +215,7 @@
void OnUnsubscribeAnnouncesMessage(
const MoqtUnsubscribeAnnounces& message) override {}
void OnMaxSubscribeIdMessage(const MoqtMaxSubscribeId& message) override;
- void OnFetchMessage(const MoqtFetch& message) override {}
+ void OnFetchMessage(const MoqtFetch& message) override;
void OnFetchCancelMessage(const MoqtFetchCancel& message) override {}
void OnFetchOkMessage(const MoqtFetchOk& message) override {}
void OnFetchErrorMessage(const MoqtFetchError& message) override {}
@@ -244,6 +246,8 @@
SubscribeErrorCode error_code,
absl::string_view reason_phrase,
uint64_t track_alias);
+ void SendFetchError(uint64_t subscribe_id, SubscribeErrorCode error_code,
+ absl::string_view reason_phrase);
MoqtSession* session_;
webtransport::Stream* stream_;
@@ -418,11 +422,6 @@
// the stream and returns nullptr.
PublishedSubscription* GetSubscriptionIfValid();
- // Actually sends an object on the stream; the object MUST have object ID
- // of at least `next_object_`.
- void SendNextObject(PublishedSubscription& subscription,
- PublishedObject object);
-
MoqtSession* session_;
webtransport::Stream* stream_;
uint64_t subscription_id_;
@@ -436,6 +435,54 @@
std::weak_ptr<void> session_liveness_;
};
+ class QUICHE_EXPORT PublishedFetch {
+ public:
+ PublishedFetch(uint64_t fetch_id, MoqtSession* session,
+ std::unique_ptr<MoqtFetchTask> fetch)
+ : session_(session), fetch_(std::move(fetch)), fetch_id_(fetch_id) {}
+
+ class FetchStreamVisitor : public webtransport::StreamVisitor {
+ public:
+ FetchStreamVisitor(std::shared_ptr<PublishedFetch> fetch,
+ webtransport::Stream* stream)
+ : fetch_(fetch), stream_(stream) {
+ fetch->fetch_task()->SetObjectAvailableCallback(
+ [this]() { this->OnCanWrite(); });
+ }
+ ~FetchStreamVisitor() {
+ std::shared_ptr<PublishedFetch> fetch = fetch_.lock();
+ if (fetch != nullptr) {
+ fetch->session()->incoming_fetches_.erase(fetch->fetch_id_);
+ }
+ }
+ // webtransport::StreamVisitor implementation.
+ void OnCanRead() override {} // Write-only stream.
+ void OnCanWrite() override;
+ void OnResetStreamReceived(webtransport::StreamErrorCode error) override {
+ } // Write-only stream
+ void OnStopSendingReceived(webtransport::StreamErrorCode error) override {
+ }
+ void OnWriteSideInDataRecvdState() override {}
+
+ private:
+ std::weak_ptr<PublishedFetch> fetch_;
+ webtransport::Stream* stream_;
+ bool stream_header_written_ = false;
+ };
+
+ MoqtFetchTask* fetch_task() { return fetch_.get(); }
+ MoqtSession* session() { return session_; }
+ uint64_t fetch_id() const { return fetch_id_; }
+ void SetStreamId(webtransport::StreamId id) { stream_id_ = id; }
+
+ private:
+ MoqtSession* session_;
+ std::unique_ptr<MoqtFetchTask> fetch_;
+ uint64_t fetch_id_;
+ // Store the stream ID in case a FETCH_CANCEL requires a reset.
+ std::optional<webtransport::StreamId> stream_id_;
+ };
+
// Private members of MoqtSession.
// Returns true if SUBSCRIBE_DONE was sent.
@@ -458,12 +505,26 @@
// blocked.
webtransport::Stream* OpenDataStream(PublishedSubscription& subscription,
FullSequence first_object);
+ // Returns false if creation failed.
+ [[nodiscard]] bool OpenDataStream(std::shared_ptr<PublishedFetch> fetch);
// Get FullTrackName and visitor for a subscribe_id and track_alias. Returns
// an empty FullTrackName tuple and nullptr if not present.
std::pair<FullTrackName, RemoteTrack::Visitor*> TrackPropertiesFromAlias(
const MoqtObject& message);
+ // Checks that a subscribe ID from a SUBSCRIBE or FETCH is valid, and throws
+ // a session error if is not.
+ bool ValidateSubscribeId(uint64_t subscribe_id);
+
+ // Actually sends an object on |stream| with track alias or fetch ID |id|
+ // and metadata in |object|. Not for use with datagrams. Returns |true| if
+ // the write was successful.
+ bool WriteObjectToStream(webtransport::Stream* stream, uint64_t id,
+ const PublishedObject& object,
+ MoqtDataStreamType type, bool is_first_on_stream,
+ bool fin);
+
// Sends an OBJECT_ACK message for a specific subscribe ID.
void SendObjectAck(uint64_t subscribe_id, uint64_t group_id,
uint64_t object_id,
@@ -515,6 +576,12 @@
absl::flat_hash_set<uint64_t> used_track_aliases_;
uint64_t next_local_track_alias_ = 0;
+ // Incoming FETCHes, indexed by fetch ID. There will be other pointers to
+ // PublishedFetch, so storing a shared_ptr in the map provides pointer
+ // stability for the value.
+ absl::flat_hash_map<uint64_t, std::shared_ptr<PublishedFetch>>
+ incoming_fetches_;
+
// Indexed by subscribe_id.
struct ActiveSubscribe {
MoqtSubscribe message;
@@ -547,10 +614,12 @@
uint64_t peer_max_subscribe_id_ = 0;
// The maximum subscribe ID sent to the peer.
uint64_t local_max_subscribe_id_ = 0;
+ // The minimum subscribe ID the peer can use that is monotonically increasing.
+ uint64_t next_incoming_subscribe_id_ = 0;
// Must be last. Token used to make sure that the streams do not call into
// the session when the session has already been destroyed.
- struct Empty {};
+
std::shared_ptr<Empty> liveness_token_;
};
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index bc31674..5e40775 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -42,6 +42,7 @@
using ::quic::test::MemSliceFromString;
using ::testing::_;
using ::testing::AnyNumber;
+using ::testing::Invoke;
using ::testing::Return;
using ::testing::StrictMock;
@@ -75,6 +76,20 @@
} // namespace
+class MockFetchTask : public MoqtFetchTask {
+ public:
+ MOCK_METHOD(MoqtFetchTask::GetNextObjectResult, GetNextObject,
+ (PublishedObject & output), (override));
+ MOCK_METHOD(absl::Status, GetStatus, (), (override));
+ MOCK_METHOD(FullSequence, GetLargestId, (), (const, override));
+
+ void SetObjectAvailableCallback(ObjectsAvailableCallback callback) override {
+ callback_ = std::move(callback);
+ }
+
+ ObjectsAvailableCallback callback_;
+};
+
class MoqtSessionPeer {
public:
static std::unique_ptr<MoqtControlParserVisitor> CreateControlStream(
@@ -169,6 +184,35 @@
static void set_peer_max_subscribe_id(MoqtSession* session, uint64_t id) {
session->peer_max_subscribe_id_ = id;
}
+
+ static MockFetchTask* AddFetch(MoqtSession* session, uint64_t fetch_id) {
+ auto fetch_task = std::make_unique<MockFetchTask>();
+ MockFetchTask* return_ptr = fetch_task.get();
+ auto published_fetch = std::make_unique<MoqtSession::PublishedFetch>(
+ fetch_id, session, std::move(fetch_task));
+ session->incoming_fetches_.emplace(fetch_id, std::move(published_fetch));
+ // Add the fetch to the pending stream queue.
+ session->UpdateQueuedSendOrder(fetch_id, std::nullopt, 0);
+ return return_ptr;
+ }
+
+ static MoqtSession::PublishedFetch* GetFetch(MoqtSession* session,
+ uint64_t fetch_id) {
+ auto it = session->incoming_fetches_.find(fetch_id);
+ if (it == session->incoming_fetches_.end()) {
+ return nullptr;
+ }
+ return it->second.get();
+ }
+
+ static void ValidateSubscribeId(MoqtSession* session, uint64_t id) {
+ session->ValidateSubscribeId(id);
+ }
+
+ static FullSequence LargestSentForSubscription(MoqtSession* session,
+ uint64_t subscribe_id) {
+ return *session->published_subscriptions_[subscribe_id]->largest_sent();
+ }
};
class MoqtSessionTest : public quic::test::QuicTest {
@@ -356,6 +400,7 @@
EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk);
return absl::OkStatus();
});
+ request.subscribe_id = 2;
stream_input->OnSubscribeMessage(request);
EXPECT_TRUE(correct_message);
}
@@ -626,6 +671,47 @@
stream_input->OnSubscribeMessage(request);
}
+TEST_F(MoqtSessionTest, SubscribeIdNotIncreasing) {
+ // Peer subscribes to (0, 0)
+ MoqtSubscribe request = {
+ /*subscribe_id=*/1,
+ /*track_alias=*/2,
+ /*full_track_name=*/FullTrackName({"foo", "bar"}),
+ /*subscriber_priority=*/0x80,
+ /*group_order=*/std::nullopt,
+ /*start_group=*/0,
+ /*start_object=*/0,
+ /*end_group=*/std::nullopt,
+ /*end_object=*/std::nullopt,
+ /*parameters=*/MoqtSubscribeParameters(),
+ };
+ webtransport::test::MockStream mock_stream;
+ std::unique_ptr<MoqtControlParserVisitor> stream_input =
+ MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+ // Request for track returns SUBSCRIBE_ERROR.
+ bool correct_message = false;
+ EXPECT_CALL(mock_stream, Writev(_, _))
+ .WillOnce([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ correct_message = true;
+ EXPECT_EQ(*ExtractMessageType(data[0]),
+ MoqtMessageType::kSubscribeError);
+ return absl::OkStatus();
+ });
+ stream_input->OnSubscribeMessage(request);
+ EXPECT_TRUE(correct_message);
+
+ // Second request is a protocol violation.
+ request.subscribe_id = 0;
+ request.track_alias = 3;
+ request.full_track_name = FullTrackName({"dead", "beef"});
+ EXPECT_CALL(mock_session_,
+ CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
+ "Subscribe ID not monotonically increasing"))
+ .Times(1);
+ stream_input->OnSubscribeMessage(request);
+}
+
TEST_F(MoqtSessionTest, TooManySubscribes) {
MoqtSessionPeer::set_next_subscribe_id(&session_,
kDefaultInitialMaxSubscribeId);
@@ -1148,7 +1234,7 @@
control_stream->OnSubscribeOkMessage(ok);
}
-TEST_F(MoqtSessionTest, CreateIncomingDataStreamAndSend) {
+TEST_F(MoqtSessionTest, CreateOutgoingDataStreamAndSend) {
FullTrackName ftn("foo", "bar");
auto track = SetupPublisher(ftn, MoqtForwardingPreference::kSubgroup,
FullSequence(4, 2));
@@ -1177,7 +1263,7 @@
// Verify first six message fields are sent correctly
bool correct_message = false;
- const std::string kExpectedMessage = {0x04, 0x02, 0x05, 0x00, 0x00};
+ const std::string kExpectedMessage = {0x04, 0x02, 0x05, 0x00, 0x7f};
EXPECT_CALL(mock_stream, Writev(_, _))
.WillOnce([&](absl::Span<const absl::string_view> data,
const quiche::StreamWriteOptions& options) {
@@ -1195,6 +1281,8 @@
subscription->OnNewObjectAvailable(FullSequence(5, 0));
EXPECT_TRUE(correct_message);
EXPECT_FALSE(fin);
+ EXPECT_EQ(MoqtSessionPeer::LargestSentForSubscription(&session_, 0),
+ FullSequence(5, 0));
}
TEST_F(MoqtSessionTest, FinDataStreamFromCache) {
@@ -1224,9 +1312,9 @@
EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId))
.WillRepeatedly(Return(&mock_stream));
- // Verify first six message fields are sent correctly
+ // Verify first five message fields are sent correctly
bool correct_message = false;
- const std::string kExpectedMessage = {0x04, 0x02, 0x05, 0x00, 0x00};
+ const std::string kExpectedMessage = {0x04, 0x02, 0x05, 0x00, 0x7f};
EXPECT_CALL(mock_stream, Writev(_, _))
.WillOnce([&](absl::Span<const absl::string_view> data,
const quiche::StreamWriteOptions& options) {
@@ -1275,7 +1363,7 @@
// Verify first six message fields are sent correctly
bool correct_message = false;
- const std::string kExpectedMessage = {0x04, 0x02, 0x05, 0x00, 0x00};
+ const std::string kExpectedMessage = {0x04, 0x02, 0x05, 0x00, 0x7f};
EXPECT_CALL(mock_stream, Writev(_, _))
.WillOnce([&](absl::Span<const absl::string_view> data,
const quiche::StreamWriteOptions& options) {
@@ -1327,7 +1415,7 @@
// Verify first six message fields are sent correctly
bool correct_message = false;
- const std::string kExpectedMessage = {0x04, 0x02, 0x05, 0x00, 0x00};
+ const std::string kExpectedMessage = {0x04, 0x02, 0x05, 0x00, 0x7f};
EXPECT_CALL(mock_stream, Writev(_, _))
.WillOnce([&](absl::Span<const absl::string_view> data,
const quiche::StreamWriteOptions& options) {
@@ -1994,6 +2082,316 @@
session_.OnCanCreateNewOutgoingUnidirectionalStream();
}
+TEST_F(MoqtSessionTest, FetchReturnsOk) {
+ webtransport::test::MockStream control_stream;
+ std::unique_ptr<MoqtControlParserVisitor> stream_input =
+ MoqtSessionPeer::CreateControlStream(&session_, &control_stream);
+ FullTrackName ftn("foo", "bar");
+ MoqtFetch request = {
+ /*subscribe_id=*/0,
+ /*full_track_name=*/ftn,
+ /*subscriber_priority=*/0x80,
+ /*group_order=*/std::nullopt,
+ /*start=*/FullSequence(0, 0),
+ /*end_group=*/1,
+ /*end_object=*/std::nullopt,
+ /*parameters=*/MoqtSubscribeParameters(),
+ };
+ bool correct_message = false;
+ auto track = std::make_shared<MockTrackPublisher>(ftn);
+ publisher_.Add(track);
+
+ auto fetch_task_ptr = std::make_unique<MockFetchTask>();
+ MockFetchTask* fetch_task = fetch_task_ptr.get();
+ EXPECT_CALL(*track, Fetch(_, _, _, _))
+ .WillOnce(Return(std::move(fetch_task_ptr)));
+ // Compose and send the FETCH_OK.
+ EXPECT_CALL(*track, GetDeliveryOrder())
+ .WillRepeatedly(Return(MoqtDeliveryOrder::kAscending));
+ EXPECT_CALL(*fetch_task, GetLargestId()).WillOnce(Return(FullSequence(0, 0)));
+ EXPECT_CALL(control_stream, Writev(_, _))
+ .WillOnce([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ correct_message = true;
+ EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kFetchOk);
+ return absl::OkStatus();
+ });
+ // Stream can't open yet.
+ EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream)
+ .WillOnce(Return(false));
+ stream_input->OnFetchMessage(request);
+ EXPECT_TRUE(correct_message);
+}
+
+TEST_F(MoqtSessionTest, FetchReturnsOkImmediateOpen) {
+ webtransport::test::MockStream control_stream;
+ std::unique_ptr<MoqtControlParserVisitor> stream_input =
+ MoqtSessionPeer::CreateControlStream(&session_, &control_stream);
+ FullTrackName ftn("foo", "bar");
+ MoqtFetch request = {
+ /*subscribe_id=*/0,
+ /*full_track_name=*/ftn,
+ /*subscriber_priority=*/0x80,
+ /*group_order=*/std::nullopt,
+ /*start=*/FullSequence(0, 0),
+ /*end_group=*/1,
+ /*end_object=*/std::nullopt,
+ /*parameters=*/MoqtSubscribeParameters(),
+ };
+ bool correct_message = false;
+ auto track = std::make_shared<MockTrackPublisher>(ftn);
+ publisher_.Add(track);
+
+ auto fetch_task_ptr = std::make_unique<MockFetchTask>();
+ MockFetchTask* fetch_task = fetch_task_ptr.get();
+ EXPECT_CALL(*track, Fetch(_, _, _, _))
+ .WillOnce(Return(std::move(fetch_task_ptr)));
+ // Compose and send the FETCH_OK.
+ EXPECT_CALL(*track, GetDeliveryOrder())
+ .WillRepeatedly(Return(MoqtDeliveryOrder::kAscending));
+ EXPECT_CALL(*fetch_task, GetLargestId()).WillOnce(Return(FullSequence(0, 0)));
+ EXPECT_CALL(control_stream, Writev(_, _))
+ .WillOnce([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ correct_message = true;
+ EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kFetchOk);
+ return absl::OkStatus();
+ });
+ // Open stream immediately.
+ EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream)
+ .WillOnce(Return(true));
+ webtransport::test::MockStream data_stream;
+ EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream())
+ .WillOnce(Return(&data_stream));
+ std::unique_ptr<webtransport::StreamVisitor> stream_visitor;
+ EXPECT_CALL(data_stream, SetVisitor(_))
+ .WillOnce(
+ Invoke([&](std::unique_ptr<webtransport::StreamVisitor> visitor) {
+ stream_visitor = std::move(visitor);
+ }));
+ EXPECT_CALL(data_stream, CanWrite()).WillRepeatedly(Return(true));
+ EXPECT_CALL(data_stream, visitor()).WillOnce(Invoke([&]() {
+ return stream_visitor.get();
+ }));
+ EXPECT_CALL(*fetch_task, GetNextObject(_))
+ .WillOnce(Return(MoqtFetchTask::GetNextObjectResult::kPending));
+ stream_input->OnFetchMessage(request);
+ EXPECT_TRUE(correct_message);
+
+ // Signal the stream that pending object is now available.
+ correct_message = false;
+ EXPECT_CALL(data_stream, CanWrite()).WillRepeatedly(Return(true));
+ EXPECT_CALL(*fetch_task, GetNextObject(_))
+ .WillOnce(Invoke([](PublishedObject& output) {
+ output.sequence = FullSequence(0, 0, 0);
+ output.status = MoqtObjectStatus::kNormal;
+ output.publisher_priority = 128;
+ output.payload = MemSliceFromString("foo");
+ output.fin_after_this = true;
+ return MoqtFetchTask::GetNextObjectResult::kSuccess;
+ }))
+ .WillOnce(Invoke([](PublishedObject& /*output*/) {
+ return MoqtFetchTask::GetNextObjectResult::kPending;
+ }));
+ EXPECT_CALL(data_stream, Writev(_, _))
+ .WillOnce([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ correct_message = true;
+ quic::QuicDataReader reader(data[0]);
+ uint64_t type;
+ EXPECT_TRUE(reader.ReadVarInt62(&type));
+ EXPECT_EQ(type, static_cast<uint64_t>(
+ MoqtDataStreamType::kStreamHeaderFetch));
+ return absl::OkStatus();
+ });
+ fetch_task->callback_();
+ EXPECT_TRUE(correct_message);
+}
+
+TEST_F(MoqtSessionTest, InvalidFetch) {
+ // Update the state so that it expects ID > 0 next time.
+ MoqtSessionPeer::ValidateSubscribeId(&session_, 0);
+ webtransport::test::MockStream control_stream;
+ std::unique_ptr<MoqtControlParserVisitor> stream_input =
+ MoqtSessionPeer::CreateControlStream(&session_, &control_stream);
+ FullTrackName ftn("foo", "bar");
+ MoqtFetch request = {
+ /*subscribe_id=*/0, // Subscribe ID is too low.
+ /*full_track_name=*/ftn,
+ /*subscriber_priority=*/0x80,
+ /*group_order=*/std::nullopt,
+ /*start=*/FullSequence(0, 0),
+ /*end_group=*/1,
+ /*end_object=*/std::nullopt,
+ /*parameters=*/MoqtSubscribeParameters(),
+ };
+ EXPECT_CALL(mock_session_,
+ CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
+ "Subscribe ID not monotonically increasing"))
+ .Times(1);
+ stream_input->OnFetchMessage(request);
+}
+
+TEST_F(MoqtSessionTest, FetchFails) {
+ webtransport::test::MockStream control_stream;
+ std::unique_ptr<MoqtControlParserVisitor> stream_input =
+ MoqtSessionPeer::CreateControlStream(&session_, &control_stream);
+ FullTrackName ftn("foo", "bar");
+ MoqtFetch request = {
+ /*subscribe_id=*/0,
+ /*full_track_name=*/ftn,
+ /*subscriber_priority=*/0x80,
+ /*group_order=*/std::nullopt,
+ /*start=*/FullSequence(0, 0),
+ /*end_group=*/1,
+ /*end_object=*/std::nullopt,
+ /*parameters=*/MoqtSubscribeParameters(),
+ };
+ bool correct_message = false;
+ auto track = std::make_shared<MockTrackPublisher>(ftn);
+ publisher_.Add(track);
+
+ auto fetch_task_ptr = std::make_unique<MockFetchTask>();
+ MockFetchTask* fetch_task = fetch_task_ptr.get();
+ EXPECT_CALL(*track, Fetch(_, _, _, _))
+ .WillOnce(Return(std::move(fetch_task_ptr)));
+ EXPECT_CALL(*fetch_task, GetStatus())
+ .WillRepeatedly(Return(absl::Status(absl::StatusCode::kInternal, "foo")));
+ EXPECT_CALL(control_stream, Writev(_, _))
+ .WillOnce([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ correct_message = true;
+ EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kFetchError);
+ return absl::OkStatus();
+ });
+ stream_input->OnFetchMessage(request);
+ EXPECT_TRUE(correct_message);
+}
+
+TEST_F(MoqtSessionTest, FetchDelivery) {
+ constexpr uint64_t kFetchId = 0;
+ MockFetchTask* fetch = MoqtSessionPeer::AddFetch(&session_, kFetchId);
+ // Stream creation started out blocked. Allow its creation, but data is
+ // blocked.
+ webtransport::test::MockStream data_stream;
+ EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream())
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream())
+ .WillOnce(Return(&data_stream));
+ std::unique_ptr<webtransport::StreamVisitor> stream_visitor;
+ EXPECT_CALL(data_stream, GetStreamId()).WillOnce(Return(4));
+ EXPECT_CALL(data_stream, SetVisitor(_))
+ .WillOnce(
+ Invoke([&](std::unique_ptr<webtransport::StreamVisitor> visitor) {
+ stream_visitor = std::move(visitor);
+ }));
+ EXPECT_CALL(data_stream, CanWrite()).WillOnce(Return(false));
+ session_.OnCanCreateNewOutgoingUnidirectionalStream();
+ // Unblock the stream. Provide one object and an EOF.
+ EXPECT_CALL(data_stream, CanWrite()).WillRepeatedly(Return(true));
+ EXPECT_CALL(*fetch, GetNextObject(_))
+ .WillOnce(Invoke([](PublishedObject& output) {
+ output.sequence = FullSequence(0, 0, 0);
+ output.status = MoqtObjectStatus::kNormal;
+ output.publisher_priority = 128;
+ output.payload = MemSliceFromString("foo");
+ output.fin_after_this = true;
+ return MoqtFetchTask::GetNextObjectResult::kSuccess;
+ }))
+ .WillOnce(Invoke([](PublishedObject& /*output*/) {
+ return MoqtFetchTask::GetNextObjectResult::kEof;
+ }));
+
+ int objects_received = 0;
+ EXPECT_CALL(data_stream, Writev(_, _))
+ .WillOnce(Invoke([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ ++objects_received;
+ quic::QuicDataReader reader(data[0]);
+ uint64_t type;
+ EXPECT_TRUE(reader.ReadVarInt62(&type));
+ EXPECT_EQ(type, static_cast<uint64_t>(
+ MoqtDataStreamType::kStreamHeaderFetch));
+ EXPECT_FALSE(options.send_fin()); // fin_after_this is ignored.
+ return absl::OkStatus();
+ }))
+ .WillOnce(Invoke([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ ++objects_received;
+ EXPECT_TRUE(data.empty());
+ EXPECT_TRUE(options.send_fin());
+ return absl::OkStatus();
+ }));
+ stream_visitor->OnCanWrite();
+ EXPECT_EQ(objects_received, 2);
+}
+
+TEST_F(MoqtSessionTest, FetchNonNormalObjects) {
+ constexpr uint64_t kFetchId = 0;
+ MockFetchTask* fetch = MoqtSessionPeer::AddFetch(&session_, kFetchId);
+ // Stream creation started out blocked. Allow its creation, but data is
+ // blocked.
+ webtransport::test::MockStream data_stream;
+ EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream())
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream())
+ .WillOnce(Return(&data_stream));
+ std::unique_ptr<webtransport::StreamVisitor> stream_visitor;
+ EXPECT_CALL(data_stream, SetVisitor(_))
+ .WillOnce(
+ Invoke([&](std::unique_ptr<webtransport::StreamVisitor> visitor) {
+ stream_visitor = std::move(visitor);
+ }));
+ EXPECT_CALL(data_stream, CanWrite()).WillOnce(Return(false));
+ session_.OnCanCreateNewOutgoingUnidirectionalStream();
+ // Unblock the stream. Provide one object and an EOF.
+ EXPECT_CALL(data_stream, CanWrite()).WillRepeatedly(Return(true));
+ EXPECT_CALL(*fetch, GetNextObject(_))
+ .WillOnce(Invoke([](PublishedObject& output) {
+ // DoesNotExist will be skipped.
+ output.sequence = FullSequence(0, 0, 0);
+ output.status = MoqtObjectStatus::kObjectDoesNotExist;
+ output.publisher_priority = 128;
+ output.payload = MemSliceFromString("");
+ output.fin_after_this = true;
+ return MoqtFetchTask::GetNextObjectResult::kSuccess;
+ }))
+ .WillOnce(Invoke([](PublishedObject& output) {
+ output.sequence = FullSequence(0, 0, 1);
+ output.status = MoqtObjectStatus::kEndOfGroup;
+ output.publisher_priority = 128;
+ output.payload = MemSliceFromString("");
+ output.fin_after_this = true;
+ return MoqtFetchTask::GetNextObjectResult::kSuccess;
+ }))
+ .WillOnce(Invoke([](PublishedObject& /*output*/) {
+ return MoqtFetchTask::GetNextObjectResult::kEof;
+ }));
+
+ int objects_received = 0;
+ EXPECT_CALL(data_stream, Writev(_, _))
+ .WillOnce(Invoke([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ ++objects_received;
+ quic::QuicDataReader reader(data[0]);
+ uint64_t type;
+ EXPECT_TRUE(reader.ReadVarInt62(&type));
+ EXPECT_EQ(type, static_cast<uint64_t>(
+ MoqtDataStreamType::kStreamHeaderFetch));
+ EXPECT_FALSE(options.send_fin());
+ return absl::OkStatus();
+ }))
+ .WillOnce(Invoke([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ ++objects_received;
+ EXPECT_TRUE(data.empty());
+ EXPECT_TRUE(options.send_fin());
+ return absl::OkStatus();
+ }));
+ stream_visitor->OnCanWrite();
+ EXPECT_EQ(objects_received, 2);
+}
+
// TODO: re-enable this test once this behavior is re-implemented.
#if 0
TEST_F(MoqtSessionTest, SubscribeUpdateClosesSubscription) {