Send MoQT Fetch and handle FETCH_OK/FETCH_ERROR. The application can induce FETCH_CANCEL by destroying the MoqtFetchTask object it owns. Does not include data stream handling or object processing. Also: - Make upstream_by_id_ the owner of upstream data structures, instead of subscribe_by_alias_. - Avoid track alias collisions by changing the next proposed alias to be larger than any publisher-provided alias. Not in production. PiperOrigin-RevId: 707099356
diff --git a/quiche/quic/moqt/moqt_messages.cc b/quiche/quic/moqt/moqt_messages.cc index e395190..3e0a226 100644 --- a/quiche/quic/moqt/moqt_messages.cc +++ b/quiche/quic/moqt/moqt_messages.cc
@@ -9,13 +9,14 @@ #include <vector> #include "absl/algorithm/container.h" +#include "absl/status/status.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "quiche/quic/platform/api/quic_bug_tracker.h" -#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/web_transport/web_transport.h" namespace moqt { @@ -200,4 +201,16 @@ FullTrackName::FullTrackName(absl::Span<const absl::string_view> elements) : tuple_(elements.begin(), elements.end()) {} +absl::Status MoqtStreamErrorToStatus(webtransport::StreamErrorCode error_code, + absl::string_view reason_phrase) { + switch (error_code) { + case kResetCodeSubscriptionGone: + return absl::NotFoundError(reason_phrase); + case kResetCodeTimedOut: + return absl::DeadlineExceededError(reason_phrase); + default: + return absl::UnknownError(reason_phrase); + } +} + } // namespace moqt
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h index 4b189f9..508d103 100644 --- a/quiche/quic/moqt/moqt_messages.h +++ b/quiche/quic/moqt/moqt_messages.h
@@ -16,6 +16,7 @@ #include <vector> #include "absl/container/inlined_vector.h" +#include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -24,6 +25,7 @@ #include "quiche/quic/core/quic_versions.h" #include "quiche/quic/moqt/moqt_priority.h" #include "quiche/common/platform/api/quiche_export.h" +#include "quiche/web_transport/web_transport.h" namespace moqt { @@ -121,9 +123,10 @@ // Error codes used by MoQT to reset streams. // TODO: update with spec-defined error codes once those are available, see // <https://github.com/moq-wg/moq-transport/issues/481>. -inline constexpr uint64_t kResetCodeUnknown = 0x00; -inline constexpr uint64_t kResetCodeSubscriptionGone = 0x01; -inline constexpr uint64_t kResetCodeTimedOut = 0x02; +inline constexpr webtransport::StreamErrorCode kResetCodeUnknown = 0x00; +inline constexpr webtransport::StreamErrorCode kResetCodeSubscriptionGone = + 0x01; +inline constexpr webtransport::StreamErrorCode kResetCodeTimedOut = 0x02; enum class QUICHE_EXPORT MoqtRole : uint64_t { kPublisher = 0x1, @@ -583,6 +586,9 @@ MoqtDataStreamType GetMessageTypeForForwardingPreference( MoqtForwardingPreference preference); +absl::Status MoqtStreamErrorToStatus(webtransport::StreamErrorCode error_code, + absl::string_view reason_phrase); + } // namespace moqt #endif // QUICHE_QUIC_MOQT_MOQT_MESSAGES_H_
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index c181578..8496188 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -417,7 +417,7 @@ } void MoqtSession::Unsubscribe(const FullTrackName& name) { - RemoteTrack* track = RemoteTrackByName(name); + SubscribeRemoteTrack* track = RemoteTrackByName(name); if (track == nullptr) { return; } @@ -425,10 +425,45 @@ message.subscribe_id = track->subscribe_id(); SendControlMessage(framer_.SerializeUnsubscribe(message)); // Destroy state. - upstream_by_name_.erase(name); + subscribe_by_name_.erase(name); + subscribe_by_alias_.erase(track->track_alias()); upstream_by_id_.erase(track->subscribe_id()); - subscribe_by_alias_.erase( - static_cast<SubscribeRemoteTrack*>(track)->track_alias()); +} + +bool MoqtSession::Fetch(const FullTrackName& name, + FetchResponseCallback callback, FullSequence start, + uint64_t end_group, std::optional<uint64_t> end_object, + MoqtPriority priority, + std::optional<MoqtDeliveryOrder> delivery_order, + MoqtSubscribeParameters parameters) { + if (peer_role_ == MoqtRole::kSubscriber) { + QUIC_DLOG(INFO) << ENDPOINT << "Tried to send FETCH to subscriber peer"; + return false; + } + // TODO(martinduke): support authorization info + if (next_subscribe_id_ >= peer_max_subscribe_id_) { + QUIC_DLOG(INFO) << ENDPOINT << "Tried to send FETCH with ID " + << next_subscribe_id_ + << " which is greater than the maximum ID " + << peer_max_subscribe_id_; + return false; + } + MoqtFetch message; + message.full_track_name = name; + message.subscribe_id = next_subscribe_id_++; + message.start_object = start; + message.end_group = end_group; + message.end_object = end_object; + message.subscriber_priority = priority; + message.group_order = delivery_order; + message.parameters = parameters; + message.parameters.object_ack_window = std::nullopt; + SendControlMessage(framer_.SerializeFetch(message)); + QUIC_DLOG(INFO) << ENDPOINT << "Sent FETCH message for " + << message.full_track_name; + auto fetch = std::make_unique<UpstreamFetch>(message, std::move(callback)); + upstream_by_id_.emplace(message.subscribe_id, std::move(fetch)); + return true; } void MoqtSession::PublishedFetch::FetchStreamVisitor::OnCanWrite() { @@ -517,7 +552,7 @@ << peer_max_subscribe_id_; return false; } - if (upstream_by_name_.contains(message.full_track_name)) { + if (subscribe_by_name_.contains(message.full_track_name)) { QUIC_DLOG(INFO) << ENDPOINT << "Tried to send SUBSCRIBE for track " << message.full_track_name << " which is already subscribed"; @@ -529,6 +564,13 @@ return false; } message.subscribe_id = next_subscribe_id_++; + if (provided_track_alias.has_value()) { + message.track_alias = *provided_track_alias; + next_remote_track_alias_ = + std::max(next_remote_track_alias_, *provided_track_alias) + 1; + } else { + message.track_alias = next_remote_track_alias_++; + } message.track_alias = provided_track_alias.value_or(next_remote_track_alias_++); if (SupportsObjectAck() && visitor != nullptr) { @@ -546,9 +588,9 @@ QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE message for " << message.full_track_name; auto track = std::make_unique<SubscribeRemoteTrack>(message, visitor); - upstream_by_name_.emplace(message.full_track_name, track.get()); - upstream_by_id_.emplace(message.subscribe_id, track.get()); - subscribe_by_alias_.emplace(message.track_alias, std::move(track)); + subscribe_by_name_.emplace(message.full_track_name, track.get()); + subscribe_by_alias_.emplace(message.track_alias, track.get()); + upstream_by_id_.emplace(message.subscribe_id, std::move(track)); return true; } @@ -608,7 +650,7 @@ if (it == subscribe_by_alias_.end()) { return nullptr; } - return it->second.get(); + return it->second; } RemoteTrack* MoqtSession::RemoteTrackById(uint64_t subscribe_id) { @@ -616,12 +658,13 @@ if (it == upstream_by_id_.end()) { return nullptr; } - return it->second; + return it->second.get(); } -RemoteTrack* MoqtSession::RemoteTrackByName(const FullTrackName& name) { - auto it = upstream_by_name_.find(name); - if (it == upstream_by_name_.end()) { +SubscribeRemoteTrack* MoqtSession::RemoteTrackByName( + const FullTrackName& name) { + auto it = subscribe_by_name_.find(name); + if (it == subscribe_by_name_.end()) { return nullptr; } return it->second; @@ -945,12 +988,10 @@ << ", error = " << static_cast<int>(message.error_code) << " (" << message.reason_phrase << ")"; SubscribeRemoteTrack* subscribe = static_cast<SubscribeRemoteTrack*>(track); - // Delete secondary references to the track. Preserve the owner - // (subscribe_by_alias_) to get the original subscribe, if needed. Erasing the - // other references now prevents an error due to a duplicate subscription in - // Subscribe(). - session_->upstream_by_id_.erase(subscribe->subscribe_id()); - session_->upstream_by_name_.erase(subscribe->full_track_name()); + // Delete the by-name entry at this point prevents Subscribe() from throwing + // an error due to a duplicate track name. The other entries for this + // subscribe will be deleted after calling Subscribe(). + session_->subscribe_by_name_.erase(subscribe->full_track_name()); if (message.error_code == SubscribeErrorCode::kRetryTrackAlias) { // Automatically resubscribe with new alias. MoqtSubscribe& subscribe_message = subscribe->GetSubscribe(); @@ -961,6 +1002,7 @@ message.reason_phrase); } session_->subscribe_by_alias_.erase(subscribe->track_alias()); + session_->upstream_by_id_.erase(subscribe->subscribe_id()); } void MoqtSession::ControlStream::OnUnsubscribeMessage( @@ -1196,6 +1238,80 @@ } } +void MoqtSession::ControlStream::OnFetchOkMessage(const MoqtFetchOk& message) { + RemoteTrack* track = session_->RemoteTrackById(message.subscribe_id); + if (track == nullptr) { + QUIC_DLOG(INFO) << ENDPOINT << "Received the FETCH_OK for " + << "subscribe_id = " << message.subscribe_id + << " but no track exists"; + // Subscription state might have been destroyed for internal reasons. + return; + } + if (!track->is_fetch()) { + session_->Error(MoqtError::kProtocolViolation, + "Received FETCH_OK for a SUBSCRIBE"); + return; + } + QUIC_DLOG(INFO) << ENDPOINT << "Received the FETCH_OK for subscribe_id = " + << message.subscribe_id << " " << track->full_track_name(); + UpstreamFetch* fetch = static_cast<UpstreamFetch*>(track); + fetch->OnFetchResult(message.largest_id, absl::OkStatus(), + [=, session = session_]() { + session->CancelFetch(message.subscribe_id); + }); +} + +void MoqtSession::ControlStream::OnFetchErrorMessage( + const MoqtFetchError& message) { + RemoteTrack* track = session_->RemoteTrackById(message.subscribe_id); + if (track == nullptr) { + QUIC_DLOG(INFO) << ENDPOINT << "Received the FETCH_ERROR for " + << "subscribe_id = " << message.subscribe_id + << " but no track exists"; + // Subscription state might have been destroyed for internal reasons. + return; + } + if (!track->is_fetch()) { + session_->Error(MoqtError::kProtocolViolation, + "Received FETCH_ERROR for a SUBSCRIBE"); + return; + } + if (!track->ErrorIsAllowed()) { + session_->Error(MoqtError::kProtocolViolation, + "Received FETCH_ERROR after FETCH_OK or objects"); + return; + } + QUIC_DLOG(INFO) << ENDPOINT << "Received the FETCH_ERROR for " + << "subscribe_id = " << message.subscribe_id << " (" + << track->full_track_name() << ")" + << ", error = " << static_cast<int>(message.error_code) + << " (" << message.reason_phrase << ")"; + UpstreamFetch* fetch = static_cast<UpstreamFetch*>(track); + absl::Status status; + switch (message.error_code) { + case moqt::SubscribeErrorCode::kInternalError: + status = absl::InternalError(message.reason_phrase); + break; + case SubscribeErrorCode::kInvalidRange: + status = absl::OutOfRangeError(message.reason_phrase); + break; + case SubscribeErrorCode::kTrackDoesNotExist: + status = absl::NotFoundError(message.reason_phrase); + break; + case SubscribeErrorCode::kUnauthorized: + status = absl::UnauthenticatedError(message.reason_phrase); + break; + case SubscribeErrorCode::kTimeout: + status = absl::DeadlineExceededError(message.reason_phrase); + break; + default: + status = absl::UnknownError(message.reason_phrase); + break; + } + fetch->OnFetchResult(FullSequence(0, 0), status, nullptr); + session_->upstream_by_id_.erase(message.subscribe_id); +} + void MoqtSession::ControlStream::OnParsingError(MoqtError error_code, absl::string_view reason) { session_->Error(error_code, absl::StrCat("Parse error: ", reason)); @@ -1688,6 +1804,22 @@ return true; } +void MoqtSession::CancelFetch(uint64_t subscribe_id) { + // This is only called from the callback where UpstreamFetchTask has been + // destroyed, so there is no need to notify the application. + upstream_by_id_.erase(subscribe_id); + ControlStream* stream = GetControlStream(); + if (stream == nullptr) { + return; + } + MoqtFetchCancel message; + message.subscribe_id = subscribe_id; + stream->SendOrBufferMessage(framer_.SerializeFetchCancel(message)); + // The FETCH_CANCEL will cause a RESET_STREAM to return, which would be the + // same as a STOP_SENDING. However, a FETCH_CANCEL works even if the stream + // hasn't opened yet. +} + void MoqtSession::PublishedSubscription::SendDatagram(FullSequence sequence) { std::optional<PublishedObject> object = track_publisher_->GetCachedObject(sequence);
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index ca8e80c..adf4b7e 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -182,6 +182,15 @@ // Returns false if the subscription is not found. The session immediately // destroys all subscription state. void Unsubscribe(const FullTrackName& name); + // |callback| will be called when FETCH_OK or FETCH_ERROR is received, and + // delivers a pointer to MoqtFetchTask for application use. The callback + // transfers ownership of MoqtFetchTask to the application. + // To cancel a FETCH, simply destroy the FetchTask. + bool Fetch(const FullTrackName& name, FetchResponseCallback callback, + FullSequence start, uint64_t end_group, + std::optional<uint64_t> end_object, MoqtPriority priority, + std::optional<MoqtDeliveryOrder> delivery_order, + MoqtSubscribeParameters parameters = MoqtSubscribeParameters()); webtransport::Session* session() { return session_; } MoqtSessionCallbacks& callbacks() { return callbacks_; } @@ -264,8 +273,8 @@ void OnMaxSubscribeIdMessage(const MoqtMaxSubscribeId& message) override; void OnFetchMessage(const MoqtFetch& message) override; void OnFetchCancelMessage(const MoqtFetchCancel& message) override {} - void OnFetchOkMessage(const MoqtFetchOk& message) override {} - void OnFetchErrorMessage(const MoqtFetchError& message) override {} + void OnFetchOkMessage(const MoqtFetchOk& message) override; + void OnFetchErrorMessage(const MoqtFetchError& message) override; void OnObjectAckMessage(const MoqtObjectAck& message) override { auto subscription_it = session_->published_subscriptions_.find(message.subscribe_id); @@ -334,7 +343,6 @@ webtransport::Stream* stream_; // Once the subscribe ID is identified, set it here. quiche::QuicheWeakPtr<RemoteTrack> track_; - // std::optional<uint64_t> subscribe_id_ = std::nullopt; MoqtDataParser parser_; std::string partial_object_; }; @@ -562,7 +570,7 @@ SubscribeRemoteTrack* RemoteTrackByAlias(uint64_t track_alias); RemoteTrack* RemoteTrackById(uint64_t subscribe_id); - RemoteTrack* RemoteTrackByName(const FullTrackName& name); + SubscribeRemoteTrack* RemoteTrackByName(const FullTrackName& name); // Checks that a subscribe ID from a SUBSCRIBE or FETCH is valid, and throws // a session error if is not. @@ -576,6 +584,8 @@ MoqtDataStreamType type, bool is_first_on_stream, bool fin); + void CancelFetch(uint64_t subscribe_id); + // Sends an OBJECT_ACK message for a specific subscribe ID. void SendObjectAck(uint64_t subscribe_id, uint64_t group_id, uint64_t object_id, @@ -606,16 +616,13 @@ std::string error_; // Upstream SUBSCRIBE state. - // All the tracks the session is subscribed to, indexed by track_alias. - absl::flat_hash_map<uint64_t, std::unique_ptr<SubscribeRemoteTrack>> - subscribe_by_alias_; - // Upstream SUBSCRIBEs indexed by subscribe_id. - // TODO(martinduke): Add fetches to this. - absl::flat_hash_map<uint64_t, RemoteTrack*> upstream_by_id_; - // The application only has track names, so this allows MoqtSession to - // quickly find what it's looking for. Also allows a quick check for duplicate - // subscriptions. - absl::flat_hash_map<FullTrackName, RemoteTrack*> upstream_by_name_; + // Upstream SUBSCRIBEs and FETCHes, indexed by subscribe_id. + absl::flat_hash_map<uint64_t, std::unique_ptr<RemoteTrack>> upstream_by_id_; + // All SUBSCRIBEs, indexed by track_alias. + absl::flat_hash_map<uint64_t, SubscribeRemoteTrack*> subscribe_by_alias_; + // All SUBSCRIBEs, indexed by track name. + absl::flat_hash_map<FullTrackName, SubscribeRemoteTrack*> subscribe_by_name_; + // The next track alias to guess on a SUBSCRIBE. uint64_t next_remote_track_alias_ = 0; // The next subscribe ID that the local endpoint can send. uint64_t next_subscribe_id_ = 0;
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 1c845b6..172a152 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -2202,6 +2202,60 @@ stream_input->OnSubscribeAnnouncesMessage(announces); } +TEST_F(MoqtSessionTest, FetchThenOkThenCancel) { + webtransport::test::MockStream mock_stream; + std::unique_ptr<MoqtControlParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + std::unique_ptr<MoqtFetchTask> fetch_task; + session_.Fetch( + FullTrackName("foo", "bar"), + [&](std::unique_ptr<MoqtFetchTask> task) { + fetch_task = std::move(task); + }, + FullSequence(0, 0), 4, std::nullopt, 128, std::nullopt, + MoqtSubscribeParameters()); + MoqtFetchOk ok = { + /*subscribe_id=*/0, + /*group_order=*/MoqtDeliveryOrder::kAscending, + /*largest_id=*/FullSequence(3, 25), + MoqtSubscribeParameters(), + }; + stream_input->OnFetchOkMessage(ok); + ASSERT_NE(fetch_task, nullptr); + EXPECT_EQ(fetch_task->GetLargestId(), FullSequence(3, 25)); + EXPECT_TRUE(fetch_task->GetStatus().ok()); + PublishedObject object; + EXPECT_EQ(fetch_task->GetNextObject(object), + MoqtFetchTask::GetNextObjectResult::kPending); + // Cancel the fetch. + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kFetchCancel), _)); + fetch_task.reset(); +} + +TEST_F(MoqtSessionTest, FetchThenError) { + webtransport::test::MockStream mock_stream; + std::unique_ptr<MoqtControlParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + std::unique_ptr<MoqtFetchTask> fetch_task; + session_.Fetch( + FullTrackName("foo", "bar"), + [&](std::unique_ptr<MoqtFetchTask> task) { + fetch_task = std::move(task); + }, + FullSequence(0, 0), 4, std::nullopt, 128, std::nullopt, + MoqtSubscribeParameters()); + MoqtFetchError error = { + /*subscribe_id=*/0, + /*error_code=*/SubscribeErrorCode::kUnauthorized, + /*reason_phrase=*/"No username provided", + }; + stream_input->OnFetchErrorMessage(error); + ASSERT_NE(fetch_task, nullptr); + EXPECT_TRUE(absl::IsUnauthenticated(fetch_task->GetStatus())); + EXPECT_EQ(fetch_task->GetStatus().message(), "No username provided"); +} + // TODO: re-enable this test once this behavior is re-implemented. #if 0 TEST_F(MoqtSessionTest, SubscribeUpdateClosesSubscription) {
diff --git a/quiche/quic/moqt/moqt_track.cc b/quiche/quic/moqt/moqt_track.cc index 07f44d7..31980d1 100644 --- a/quiche/quic/moqt/moqt_track.cc +++ b/quiche/quic/moqt/moqt_track.cc
@@ -4,7 +4,16 @@ #include "quiche/quic/moqt/moqt_track.h" +#include <memory> +#include <optional> +#include <utility> + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_publisher.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/web_transport/web_transport.h" namespace moqt { @@ -16,4 +25,69 @@ return true; } +UpstreamFetch::~UpstreamFetch() { + if (task_.IsValid()) { + // Notify the task (which the application owns) that nothing more is coming. + // If this has already been called, UpstreamFetchTask will ignore it. + task_.GetIfAvailable()->OnStreamAndFetchClosed(kResetCodeUnknown, ""); + } +} + +void UpstreamFetch::OnFetchResult(FullSequence largest_id, absl::Status status, + TaskDestroyedCallback callback) { + auto task = std::make_unique<UpstreamFetchTask>(largest_id, status, + std::move(callback)); + task_ = task->weak_ptr(); + std::move(ok_callback_)(std::move(task)); +} + +UpstreamFetch::UpstreamFetchTask::~UpstreamFetchTask() { + if (task_destroyed_callback_) { + std::move(task_destroyed_callback_)(); + } +} + +MoqtFetchTask::GetNextObjectResult +UpstreamFetch::UpstreamFetchTask::GetNextObject(PublishedObject& output) { + if (!next_object_.has_value()) { + if (!status_.ok()) { + return kError; + } + if (eof_) { + return kEof; + } + return kPending; + } + output = *std::move(next_object_); + next_object_.reset(); + return kSuccess; +} + +void UpstreamFetch::UpstreamFetchTask::NewObject(PublishedObject& object) { + QUICHE_DCHECK(!next_object_.has_value()); + next_object_ = std::move(object); + if (object_available_callback_) { + object_available_callback_(); + } +} + +void UpstreamFetch::UpstreamFetchTask::OnStreamAndFetchClosed( + std::optional<webtransport::StreamErrorCode> error, + absl::string_view reason_phrase) { + if (eof_ || error.has_value()) { + return; + } + // Delete callbacks, because IncomingDataStream and UpstreamFetch are gone. + can_read_callback_ = nullptr; + task_destroyed_callback_ = nullptr; + if (!error.has_value()) { // This was a FIN. + eof_ = true; + } else { + status_ = MoqtStreamErrorToStatus(*error, reason_phrase); + } + if (object_available_callback_) { + object_available_callback_(); + } +} + } // namespace moqt
diff --git a/quiche/quic/moqt/moqt_track.h b/quiche/quic/moqt/moqt_track.h index e47464a..fe158e1 100644 --- a/quiche/quic/moqt/moqt_track.h +++ b/quiche/quic/moqt/moqt_track.h
@@ -8,14 +8,18 @@ #include <cstdint> #include <memory> #include <optional> +#include <utility> +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "quiche/quic/core/quic_time.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_priority.h" +#include "quiche/quic/moqt/moqt_publisher.h" #include "quiche/quic/moqt/moqt_subscribe_windows.h" #include "quiche/common/quiche_callbacks.h" #include "quiche/common/quiche_weak_ptr.h" +#include "quiche/web_transport/web_transport.h" namespace moqt { @@ -133,6 +137,106 @@ std::unique_ptr<MoqtSubscribe> subscribe_; }; +// MoqtSession calls this when a FETCH_OK or FETCH_ERROR is received. The +// destination of the callback owns |fetch_task| and MoqtSession will react +// safely if the owner destroys it. +using FetchResponseCallback = + quiche::SingleUseCallback<void(std::unique_ptr<MoqtFetchTask> fetch_task)>; + +// This is a callback to MoqtSession::IncomingDataStream. Called when the +// FetchTask has its object cache empty, on creation, and whenever the +// application reads it. +using CanReadCallback = quiche::MultiUseCallback<void()>; + +// If the application destroys the FetchTask, this is a signal to MoqtSession to +// cancel the FETCH and STOP_SENDING the stream. +using TaskDestroyedCallback = quiche::SingleUseCallback<void()>; + +// Class for upstream FETCH. It will notify the application using |callback| +// when a FETCH_OK or FETCH_ERROR is received. +class UpstreamFetch : public RemoteTrack { + public: + UpstreamFetch(const MoqtFetch& fetch, FetchResponseCallback callback) + : RemoteTrack(fetch.full_track_name, fetch.subscribe_id, + SubscribeWindow( + fetch.start_object, + FullSequence(fetch.end_group, + fetch.end_object.value_or(UINT64_MAX)))), + ok_callback_(std::move(callback)) { + // Immediately set the data stream type. + CheckDataStreamType(MoqtDataStreamType::kStreamHeaderFetch); + } + UpstreamFetch(const UpstreamFetch&) = delete; + ~UpstreamFetch(); + + class UpstreamFetchTask : public MoqtFetchTask { + public: + // If the UpstreamFetch is destroyed, it will call OnStreamAndFetchClosed + // which sets the TaskDestroyedCallback to nullptr. Thus, |callback| can + // assume that UpstreamFetch is valid. + UpstreamFetchTask(FullSequence largest_id, absl::Status status, + TaskDestroyedCallback callback) + : largest_id_(largest_id), + status_(status), + task_destroyed_callback_(std::move(callback)), + weak_ptr_factory_(this) {} + ~UpstreamFetchTask() override; + + // Implementation of MoqtFetchTask. + GetNextObjectResult GetNextObject(PublishedObject& output) override; + void SetObjectAvailableCallback( + ObjectsAvailableCallback callback) override { + object_available_callback_ = std::move(callback); + }; + absl::Status GetStatus() override { return status_; }; + FullSequence GetLargestId() const override { return largest_id_; } + + quiche::QuicheWeakPtr<UpstreamFetchTask> weak_ptr() { + return weak_ptr_factory_.Create(); + } + + // Manage the relationship with the data stream. + void OnStreamOpened(CanReadCallback callback) { + can_read_callback_ = std::move(callback); + } + + // Called when the data stream receives a new object. + void NewObject(PublishedObject& object); + + // Deletes callbacks to session or stream, updates the status. If |error| + // has no value, will append an EOF to the object stream. + void OnStreamAndFetchClosed( + std::optional<webtransport::StreamErrorCode> error, + absl::string_view reason_phrase); + + private: + FullSequence largest_id_; + absl::Status status_; + TaskDestroyedCallback task_destroyed_callback_; + + // Object delivery state. + std::optional<PublishedObject> next_object_; + bool eof_ = false; // The next object is EOF. + ObjectsAvailableCallback object_available_callback_; + CanReadCallback can_read_callback_; + + // Must be last. + quiche::QuicheWeakPtrFactory<UpstreamFetchTask> weak_ptr_factory_; + }; + + // Arrival of FETCH_OK/FETCH_ERROR. + void OnFetchResult(FullSequence largest_id, absl::Status status, + TaskDestroyedCallback callback); + + UpstreamFetchTask* task() { return task_.GetIfAvailable(); } + + private: + quiche::QuicheWeakPtr<UpstreamFetchTask> task_; + + // Initial values from Fetch() call. + FetchResponseCallback ok_callback_; // Will be destroyed on FETCH_OK. +}; + } // namespace moqt #endif // QUICHE_QUIC_MOQT_MOQT_SUBSCRIPTION_H_
diff --git a/quiche/quic/moqt/moqt_track_test.cc b/quiche/quic/moqt/moqt_track_test.cc index a5717ce..8072535 100644 --- a/quiche/quic/moqt/moqt_track_test.cc +++ b/quiche/quic/moqt/moqt_track_test.cc
@@ -4,11 +4,16 @@ #include "quiche/quic/moqt/moqt_track.h" +#include <memory> #include <optional> +#include <utility> +#include "absl/status/status.h" #include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_publisher.h" #include "quiche/quic/moqt/tools/moqt_mock_visitor.h" #include "quiche/quic/platform/api/quic_test.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" namespace moqt { @@ -64,7 +69,96 @@ EXPECT_FALSE(track_.InWindow(FullSequence(2, 0))); } -// TODO: Write test for GetStreamForSequence. +class UpstreamFetchTest : public quic::test::QuicTest { + protected: + UpstreamFetchTest() + : fetch_(fetch_message_, [&](std::unique_ptr<MoqtFetchTask> task) { + fetch_task_ = std::move(task); + }) {} + + MoqtFetch fetch_message_ = { + /*fetch_id=*/1, + /*full_track_name=*/FullTrackName("foo", "bar"), + /*subscriber_priority=*/128, + /*group_order=*/std::nullopt, + /*start_object=*/FullSequence(1, 1), + /*end_group=*/3, + /*end_object=*/100, + /*parameters=*/MoqtSubscribeParameters(), + }; + // The pointer held by the application. + UpstreamFetch fetch_; + std::unique_ptr<MoqtFetchTask> fetch_task_; +}; + +TEST_F(UpstreamFetchTest, Queries) { + EXPECT_EQ(fetch_.subscribe_id(), 1); + EXPECT_EQ(fetch_.full_track_name(), FullTrackName("foo", "bar")); + EXPECT_FALSE( + fetch_.CheckDataStreamType(MoqtDataStreamType::kStreamHeaderSubgroup)); + EXPECT_TRUE( + fetch_.CheckDataStreamType(MoqtDataStreamType::kStreamHeaderFetch)); + EXPECT_TRUE(fetch_.is_fetch()); + EXPECT_FALSE(fetch_.InWindow(FullSequence{1, 0})); + EXPECT_TRUE(fetch_.InWindow(FullSequence{1, 1})); + EXPECT_TRUE(fetch_.InWindow(FullSequence{3, 100})); + EXPECT_FALSE(fetch_.InWindow(FullSequence{3, 101})); +} + +TEST_F(UpstreamFetchTest, AllowError) { + EXPECT_TRUE(fetch_.ErrorIsAllowed()); + fetch_.OnObjectOrOk(); + EXPECT_FALSE(fetch_.ErrorIsAllowed()); +} + +TEST_F(UpstreamFetchTest, FetchResponse) { + EXPECT_EQ(fetch_task_, nullptr); + fetch_.OnFetchResult(FullSequence(3, 50), absl::OkStatus(), nullptr); + EXPECT_NE(fetch_task_, nullptr); + EXPECT_NE(fetch_.task(), nullptr); + EXPECT_TRUE(fetch_task_->GetStatus().ok()); + EXPECT_EQ(fetch_task_->GetLargestId(), FullSequence(3, 50)); +} + +TEST_F(UpstreamFetchTest, FetchClosedByMoqt) { + bool terminated = false; + fetch_.OnFetchResult(FullSequence(3, 50), absl::OkStatus(), + [&]() { terminated = true; }); + bool got_eof = false; + fetch_task_->SetObjectAvailableCallback([&]() { + PublishedObject object; + EXPECT_EQ(fetch_task_->GetNextObject(object), + MoqtFetchTask::GetNextObjectResult::kEof); + got_eof = true; + }); + fetch_.task()->OnStreamAndFetchClosed(std::nullopt, ""); + EXPECT_FALSE(terminated); + EXPECT_TRUE(got_eof); +} + +TEST_F(UpstreamFetchTest, FetchClosedByApplication) { + bool terminated = false; + fetch_.OnFetchResult(FullSequence(3, 50), absl::Status(), + [&]() { terminated = true; }); + fetch_task_.reset(); + EXPECT_TRUE(terminated); +} + +TEST_F(UpstreamFetchTest, ObjectRetrieval) { + fetch_.OnFetchResult(FullSequence(3, 50), absl::OkStatus(), nullptr); + PublishedObject object; + EXPECT_EQ(fetch_task_->GetNextObject(object), + MoqtFetchTask::GetNextObjectResult::kPending); + PublishedObject new_object(FullSequence(3, 0), + MoqtObjectStatus::kGroupDoesNotExist, 128, + quiche::QuicheMemSlice(), false); + fetch_task_->SetObjectAvailableCallback([&]() { + EXPECT_EQ(fetch_task_->GetNextObject(object), + MoqtFetchTask::GetNextObjectResult::kSuccess); + EXPECT_EQ(object.sequence, new_object.sequence); + }); + fetch_.task()->NewObject(new_object); +} } // namespace test
diff --git a/quiche/quic/moqt/test_tools/moqt_session_peer.h b/quiche/quic/moqt/test_tools/moqt_session_peer.h index 5426ef6..e867376 100644 --- a/quiche/quic/moqt/test_tools/moqt_session_peer.h +++ b/quiche/quic/moqt/test_tools/moqt_session_peer.h
@@ -17,7 +17,6 @@ #include "quiche/quic/moqt/moqt_session.h" #include "quiche/quic/moqt/moqt_track.h" #include "quiche/quic/moqt/tools/moqt_mock_visitor.h" -#include "quiche/quic/platform/api/quic_test.h" #include "quiche/web_transport/test_tools/mock_web_transport.h" #include "quiche/web_transport/web_transport.h" @@ -75,11 +74,12 @@ const MoqtSubscribe& subscribe, SubscribeRemoteTrack::Visitor* visitor) { auto track = std::make_unique<SubscribeRemoteTrack>(subscribe, visitor); - session->upstream_by_id_.try_emplace(subscribe.subscribe_id, track.get()); - session->upstream_by_name_.try_emplace(subscribe.full_track_name, - track.get()); session->subscribe_by_alias_.try_emplace(subscribe.track_alias, - std::move(track)); + track.get()); + session->subscribe_by_name_.try_emplace(subscribe.full_track_name, + track.get()); + session->upstream_by_id_.try_emplace(subscribe.subscribe_id, + std::move(track)); } static MoqtObjectListener* AddSubscription(