Send MoQT Fetch and handle FETCH_OK/FETCH_ERROR. The application can induce FETCH_CANCEL by destroying the MoqtFetchTask object it owns.
Does not include data stream handling or object processing.
Also:
- Make upstream_by_id_ the owner of upstream data structures, instead of subscribe_by_alias_.
- Avoid track alias collisions by changing the next proposed alias to be larger than any publisher-provided alias.
Not in production.
PiperOrigin-RevId: 707099356
diff --git a/quiche/quic/moqt/moqt_messages.cc b/quiche/quic/moqt/moqt_messages.cc
index e395190..3e0a226 100644
--- a/quiche/quic/moqt/moqt_messages.cc
+++ b/quiche/quic/moqt/moqt_messages.cc
@@ -9,13 +9,14 @@
#include <vector>
#include "absl/algorithm/container.h"
+#include "absl/status/status.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "quiche/quic/platform/api/quic_bug_tracker.h"
-#include "quiche/common/platform/api/quiche_bug_tracker.h"
+#include "quiche/web_transport/web_transport.h"
namespace moqt {
@@ -200,4 +201,16 @@
FullTrackName::FullTrackName(absl::Span<const absl::string_view> elements)
: tuple_(elements.begin(), elements.end()) {}
+absl::Status MoqtStreamErrorToStatus(webtransport::StreamErrorCode error_code,
+ absl::string_view reason_phrase) {
+ switch (error_code) {
+ case kResetCodeSubscriptionGone:
+ return absl::NotFoundError(reason_phrase);
+ case kResetCodeTimedOut:
+ return absl::DeadlineExceededError(reason_phrase);
+ default:
+ return absl::UnknownError(reason_phrase);
+ }
+}
+
} // namespace moqt
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h
index 4b189f9..508d103 100644
--- a/quiche/quic/moqt/moqt_messages.h
+++ b/quiche/quic/moqt/moqt_messages.h
@@ -16,6 +16,7 @@
#include <vector>
#include "absl/container/inlined_vector.h"
+#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
@@ -24,6 +25,7 @@
#include "quiche/quic/core/quic_versions.h"
#include "quiche/quic/moqt/moqt_priority.h"
#include "quiche/common/platform/api/quiche_export.h"
+#include "quiche/web_transport/web_transport.h"
namespace moqt {
@@ -121,9 +123,10 @@
// Error codes used by MoQT to reset streams.
// TODO: update with spec-defined error codes once those are available, see
// <https://github.com/moq-wg/moq-transport/issues/481>.
-inline constexpr uint64_t kResetCodeUnknown = 0x00;
-inline constexpr uint64_t kResetCodeSubscriptionGone = 0x01;
-inline constexpr uint64_t kResetCodeTimedOut = 0x02;
+inline constexpr webtransport::StreamErrorCode kResetCodeUnknown = 0x00;
+inline constexpr webtransport::StreamErrorCode kResetCodeSubscriptionGone =
+ 0x01;
+inline constexpr webtransport::StreamErrorCode kResetCodeTimedOut = 0x02;
enum class QUICHE_EXPORT MoqtRole : uint64_t {
kPublisher = 0x1,
@@ -583,6 +586,9 @@
MoqtDataStreamType GetMessageTypeForForwardingPreference(
MoqtForwardingPreference preference);
+absl::Status MoqtStreamErrorToStatus(webtransport::StreamErrorCode error_code,
+ absl::string_view reason_phrase);
+
} // namespace moqt
#endif // QUICHE_QUIC_MOQT_MOQT_MESSAGES_H_
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index c181578..8496188 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -417,7 +417,7 @@
}
void MoqtSession::Unsubscribe(const FullTrackName& name) {
- RemoteTrack* track = RemoteTrackByName(name);
+ SubscribeRemoteTrack* track = RemoteTrackByName(name);
if (track == nullptr) {
return;
}
@@ -425,10 +425,45 @@
message.subscribe_id = track->subscribe_id();
SendControlMessage(framer_.SerializeUnsubscribe(message));
// Destroy state.
- upstream_by_name_.erase(name);
+ subscribe_by_name_.erase(name);
+ subscribe_by_alias_.erase(track->track_alias());
upstream_by_id_.erase(track->subscribe_id());
- subscribe_by_alias_.erase(
- static_cast<SubscribeRemoteTrack*>(track)->track_alias());
+}
+
+bool MoqtSession::Fetch(const FullTrackName& name,
+ FetchResponseCallback callback, FullSequence start,
+ uint64_t end_group, std::optional<uint64_t> end_object,
+ MoqtPriority priority,
+ std::optional<MoqtDeliveryOrder> delivery_order,
+ MoqtSubscribeParameters parameters) {
+ if (peer_role_ == MoqtRole::kSubscriber) {
+ QUIC_DLOG(INFO) << ENDPOINT << "Tried to send FETCH to subscriber peer";
+ return false;
+ }
+ // TODO(martinduke): support authorization info
+ if (next_subscribe_id_ >= peer_max_subscribe_id_) {
+ QUIC_DLOG(INFO) << ENDPOINT << "Tried to send FETCH with ID "
+ << next_subscribe_id_
+ << " which is greater than the maximum ID "
+ << peer_max_subscribe_id_;
+ return false;
+ }
+ MoqtFetch message;
+ message.full_track_name = name;
+ message.subscribe_id = next_subscribe_id_++;
+ message.start_object = start;
+ message.end_group = end_group;
+ message.end_object = end_object;
+ message.subscriber_priority = priority;
+ message.group_order = delivery_order;
+ message.parameters = parameters;
+ message.parameters.object_ack_window = std::nullopt;
+ SendControlMessage(framer_.SerializeFetch(message));
+ QUIC_DLOG(INFO) << ENDPOINT << "Sent FETCH message for "
+ << message.full_track_name;
+ auto fetch = std::make_unique<UpstreamFetch>(message, std::move(callback));
+ upstream_by_id_.emplace(message.subscribe_id, std::move(fetch));
+ return true;
}
void MoqtSession::PublishedFetch::FetchStreamVisitor::OnCanWrite() {
@@ -517,7 +552,7 @@
<< peer_max_subscribe_id_;
return false;
}
- if (upstream_by_name_.contains(message.full_track_name)) {
+ if (subscribe_by_name_.contains(message.full_track_name)) {
QUIC_DLOG(INFO) << ENDPOINT << "Tried to send SUBSCRIBE for track "
<< message.full_track_name
<< " which is already subscribed";
@@ -529,6 +564,13 @@
return false;
}
message.subscribe_id = next_subscribe_id_++;
+ if (provided_track_alias.has_value()) {
+ message.track_alias = *provided_track_alias;
+ next_remote_track_alias_ =
+ std::max(next_remote_track_alias_, *provided_track_alias) + 1;
+ } else {
+ message.track_alias = next_remote_track_alias_++;
+ }
message.track_alias =
provided_track_alias.value_or(next_remote_track_alias_++);
if (SupportsObjectAck() && visitor != nullptr) {
@@ -546,9 +588,9 @@
QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE message for "
<< message.full_track_name;
auto track = std::make_unique<SubscribeRemoteTrack>(message, visitor);
- upstream_by_name_.emplace(message.full_track_name, track.get());
- upstream_by_id_.emplace(message.subscribe_id, track.get());
- subscribe_by_alias_.emplace(message.track_alias, std::move(track));
+ subscribe_by_name_.emplace(message.full_track_name, track.get());
+ subscribe_by_alias_.emplace(message.track_alias, track.get());
+ upstream_by_id_.emplace(message.subscribe_id, std::move(track));
return true;
}
@@ -608,7 +650,7 @@
if (it == subscribe_by_alias_.end()) {
return nullptr;
}
- return it->second.get();
+ return it->second;
}
RemoteTrack* MoqtSession::RemoteTrackById(uint64_t subscribe_id) {
@@ -616,12 +658,13 @@
if (it == upstream_by_id_.end()) {
return nullptr;
}
- return it->second;
+ return it->second.get();
}
-RemoteTrack* MoqtSession::RemoteTrackByName(const FullTrackName& name) {
- auto it = upstream_by_name_.find(name);
- if (it == upstream_by_name_.end()) {
+SubscribeRemoteTrack* MoqtSession::RemoteTrackByName(
+ const FullTrackName& name) {
+ auto it = subscribe_by_name_.find(name);
+ if (it == subscribe_by_name_.end()) {
return nullptr;
}
return it->second;
@@ -945,12 +988,10 @@
<< ", error = " << static_cast<int>(message.error_code)
<< " (" << message.reason_phrase << ")";
SubscribeRemoteTrack* subscribe = static_cast<SubscribeRemoteTrack*>(track);
- // Delete secondary references to the track. Preserve the owner
- // (subscribe_by_alias_) to get the original subscribe, if needed. Erasing the
- // other references now prevents an error due to a duplicate subscription in
- // Subscribe().
- session_->upstream_by_id_.erase(subscribe->subscribe_id());
- session_->upstream_by_name_.erase(subscribe->full_track_name());
+ // Delete the by-name entry at this point prevents Subscribe() from throwing
+ // an error due to a duplicate track name. The other entries for this
+ // subscribe will be deleted after calling Subscribe().
+ session_->subscribe_by_name_.erase(subscribe->full_track_name());
if (message.error_code == SubscribeErrorCode::kRetryTrackAlias) {
// Automatically resubscribe with new alias.
MoqtSubscribe& subscribe_message = subscribe->GetSubscribe();
@@ -961,6 +1002,7 @@
message.reason_phrase);
}
session_->subscribe_by_alias_.erase(subscribe->track_alias());
+ session_->upstream_by_id_.erase(subscribe->subscribe_id());
}
void MoqtSession::ControlStream::OnUnsubscribeMessage(
@@ -1196,6 +1238,80 @@
}
}
+void MoqtSession::ControlStream::OnFetchOkMessage(const MoqtFetchOk& message) {
+ RemoteTrack* track = session_->RemoteTrackById(message.subscribe_id);
+ if (track == nullptr) {
+ QUIC_DLOG(INFO) << ENDPOINT << "Received the FETCH_OK for "
+ << "subscribe_id = " << message.subscribe_id
+ << " but no track exists";
+ // Subscription state might have been destroyed for internal reasons.
+ return;
+ }
+ if (!track->is_fetch()) {
+ session_->Error(MoqtError::kProtocolViolation,
+ "Received FETCH_OK for a SUBSCRIBE");
+ return;
+ }
+ QUIC_DLOG(INFO) << ENDPOINT << "Received the FETCH_OK for subscribe_id = "
+ << message.subscribe_id << " " << track->full_track_name();
+ UpstreamFetch* fetch = static_cast<UpstreamFetch*>(track);
+ fetch->OnFetchResult(message.largest_id, absl::OkStatus(),
+ [=, session = session_]() {
+ session->CancelFetch(message.subscribe_id);
+ });
+}
+
+void MoqtSession::ControlStream::OnFetchErrorMessage(
+ const MoqtFetchError& message) {
+ RemoteTrack* track = session_->RemoteTrackById(message.subscribe_id);
+ if (track == nullptr) {
+ QUIC_DLOG(INFO) << ENDPOINT << "Received the FETCH_ERROR for "
+ << "subscribe_id = " << message.subscribe_id
+ << " but no track exists";
+ // Subscription state might have been destroyed for internal reasons.
+ return;
+ }
+ if (!track->is_fetch()) {
+ session_->Error(MoqtError::kProtocolViolation,
+ "Received FETCH_ERROR for a SUBSCRIBE");
+ return;
+ }
+ if (!track->ErrorIsAllowed()) {
+ session_->Error(MoqtError::kProtocolViolation,
+ "Received FETCH_ERROR after FETCH_OK or objects");
+ return;
+ }
+ QUIC_DLOG(INFO) << ENDPOINT << "Received the FETCH_ERROR for "
+ << "subscribe_id = " << message.subscribe_id << " ("
+ << track->full_track_name() << ")"
+ << ", error = " << static_cast<int>(message.error_code)
+ << " (" << message.reason_phrase << ")";
+ UpstreamFetch* fetch = static_cast<UpstreamFetch*>(track);
+ absl::Status status;
+ switch (message.error_code) {
+ case moqt::SubscribeErrorCode::kInternalError:
+ status = absl::InternalError(message.reason_phrase);
+ break;
+ case SubscribeErrorCode::kInvalidRange:
+ status = absl::OutOfRangeError(message.reason_phrase);
+ break;
+ case SubscribeErrorCode::kTrackDoesNotExist:
+ status = absl::NotFoundError(message.reason_phrase);
+ break;
+ case SubscribeErrorCode::kUnauthorized:
+ status = absl::UnauthenticatedError(message.reason_phrase);
+ break;
+ case SubscribeErrorCode::kTimeout:
+ status = absl::DeadlineExceededError(message.reason_phrase);
+ break;
+ default:
+ status = absl::UnknownError(message.reason_phrase);
+ break;
+ }
+ fetch->OnFetchResult(FullSequence(0, 0), status, nullptr);
+ session_->upstream_by_id_.erase(message.subscribe_id);
+}
+
void MoqtSession::ControlStream::OnParsingError(MoqtError error_code,
absl::string_view reason) {
session_->Error(error_code, absl::StrCat("Parse error: ", reason));
@@ -1688,6 +1804,22 @@
return true;
}
+void MoqtSession::CancelFetch(uint64_t subscribe_id) {
+ // This is only called from the callback where UpstreamFetchTask has been
+ // destroyed, so there is no need to notify the application.
+ upstream_by_id_.erase(subscribe_id);
+ ControlStream* stream = GetControlStream();
+ if (stream == nullptr) {
+ return;
+ }
+ MoqtFetchCancel message;
+ message.subscribe_id = subscribe_id;
+ stream->SendOrBufferMessage(framer_.SerializeFetchCancel(message));
+ // The FETCH_CANCEL will cause a RESET_STREAM to return, which would be the
+ // same as a STOP_SENDING. However, a FETCH_CANCEL works even if the stream
+ // hasn't opened yet.
+}
+
void MoqtSession::PublishedSubscription::SendDatagram(FullSequence sequence) {
std::optional<PublishedObject> object =
track_publisher_->GetCachedObject(sequence);
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
index ca8e80c..adf4b7e 100644
--- a/quiche/quic/moqt/moqt_session.h
+++ b/quiche/quic/moqt/moqt_session.h
@@ -182,6 +182,15 @@
// Returns false if the subscription is not found. The session immediately
// destroys all subscription state.
void Unsubscribe(const FullTrackName& name);
+ // |callback| will be called when FETCH_OK or FETCH_ERROR is received, and
+ // delivers a pointer to MoqtFetchTask for application use. The callback
+ // transfers ownership of MoqtFetchTask to the application.
+ // To cancel a FETCH, simply destroy the FetchTask.
+ bool Fetch(const FullTrackName& name, FetchResponseCallback callback,
+ FullSequence start, uint64_t end_group,
+ std::optional<uint64_t> end_object, MoqtPriority priority,
+ std::optional<MoqtDeliveryOrder> delivery_order,
+ MoqtSubscribeParameters parameters = MoqtSubscribeParameters());
webtransport::Session* session() { return session_; }
MoqtSessionCallbacks& callbacks() { return callbacks_; }
@@ -264,8 +273,8 @@
void OnMaxSubscribeIdMessage(const MoqtMaxSubscribeId& message) override;
void OnFetchMessage(const MoqtFetch& message) override;
void OnFetchCancelMessage(const MoqtFetchCancel& message) override {}
- void OnFetchOkMessage(const MoqtFetchOk& message) override {}
- void OnFetchErrorMessage(const MoqtFetchError& message) override {}
+ void OnFetchOkMessage(const MoqtFetchOk& message) override;
+ void OnFetchErrorMessage(const MoqtFetchError& message) override;
void OnObjectAckMessage(const MoqtObjectAck& message) override {
auto subscription_it =
session_->published_subscriptions_.find(message.subscribe_id);
@@ -334,7 +343,6 @@
webtransport::Stream* stream_;
// Once the subscribe ID is identified, set it here.
quiche::QuicheWeakPtr<RemoteTrack> track_;
- // std::optional<uint64_t> subscribe_id_ = std::nullopt;
MoqtDataParser parser_;
std::string partial_object_;
};
@@ -562,7 +570,7 @@
SubscribeRemoteTrack* RemoteTrackByAlias(uint64_t track_alias);
RemoteTrack* RemoteTrackById(uint64_t subscribe_id);
- RemoteTrack* RemoteTrackByName(const FullTrackName& name);
+ SubscribeRemoteTrack* RemoteTrackByName(const FullTrackName& name);
// Checks that a subscribe ID from a SUBSCRIBE or FETCH is valid, and throws
// a session error if is not.
@@ -576,6 +584,8 @@
MoqtDataStreamType type, bool is_first_on_stream,
bool fin);
+ void CancelFetch(uint64_t subscribe_id);
+
// Sends an OBJECT_ACK message for a specific subscribe ID.
void SendObjectAck(uint64_t subscribe_id, uint64_t group_id,
uint64_t object_id,
@@ -606,16 +616,13 @@
std::string error_;
// Upstream SUBSCRIBE state.
- // All the tracks the session is subscribed to, indexed by track_alias.
- absl::flat_hash_map<uint64_t, std::unique_ptr<SubscribeRemoteTrack>>
- subscribe_by_alias_;
- // Upstream SUBSCRIBEs indexed by subscribe_id.
- // TODO(martinduke): Add fetches to this.
- absl::flat_hash_map<uint64_t, RemoteTrack*> upstream_by_id_;
- // The application only has track names, so this allows MoqtSession to
- // quickly find what it's looking for. Also allows a quick check for duplicate
- // subscriptions.
- absl::flat_hash_map<FullTrackName, RemoteTrack*> upstream_by_name_;
+ // Upstream SUBSCRIBEs and FETCHes, indexed by subscribe_id.
+ absl::flat_hash_map<uint64_t, std::unique_ptr<RemoteTrack>> upstream_by_id_;
+ // All SUBSCRIBEs, indexed by track_alias.
+ absl::flat_hash_map<uint64_t, SubscribeRemoteTrack*> subscribe_by_alias_;
+ // All SUBSCRIBEs, indexed by track name.
+ absl::flat_hash_map<FullTrackName, SubscribeRemoteTrack*> subscribe_by_name_;
+ // The next track alias to guess on a SUBSCRIBE.
uint64_t next_remote_track_alias_ = 0;
// The next subscribe ID that the local endpoint can send.
uint64_t next_subscribe_id_ = 0;
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index 1c845b6..172a152 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -2202,6 +2202,60 @@
stream_input->OnSubscribeAnnouncesMessage(announces);
}
+TEST_F(MoqtSessionTest, FetchThenOkThenCancel) {
+ webtransport::test::MockStream mock_stream;
+ std::unique_ptr<MoqtControlParserVisitor> stream_input =
+ MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+ std::unique_ptr<MoqtFetchTask> fetch_task;
+ session_.Fetch(
+ FullTrackName("foo", "bar"),
+ [&](std::unique_ptr<MoqtFetchTask> task) {
+ fetch_task = std::move(task);
+ },
+ FullSequence(0, 0), 4, std::nullopt, 128, std::nullopt,
+ MoqtSubscribeParameters());
+ MoqtFetchOk ok = {
+ /*subscribe_id=*/0,
+ /*group_order=*/MoqtDeliveryOrder::kAscending,
+ /*largest_id=*/FullSequence(3, 25),
+ MoqtSubscribeParameters(),
+ };
+ stream_input->OnFetchOkMessage(ok);
+ ASSERT_NE(fetch_task, nullptr);
+ EXPECT_EQ(fetch_task->GetLargestId(), FullSequence(3, 25));
+ EXPECT_TRUE(fetch_task->GetStatus().ok());
+ PublishedObject object;
+ EXPECT_EQ(fetch_task->GetNextObject(object),
+ MoqtFetchTask::GetNextObjectResult::kPending);
+ // Cancel the fetch.
+ EXPECT_CALL(mock_stream,
+ Writev(ControlMessageOfType(MoqtMessageType::kFetchCancel), _));
+ fetch_task.reset();
+}
+
+TEST_F(MoqtSessionTest, FetchThenError) {
+ webtransport::test::MockStream mock_stream;
+ std::unique_ptr<MoqtControlParserVisitor> stream_input =
+ MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+ std::unique_ptr<MoqtFetchTask> fetch_task;
+ session_.Fetch(
+ FullTrackName("foo", "bar"),
+ [&](std::unique_ptr<MoqtFetchTask> task) {
+ fetch_task = std::move(task);
+ },
+ FullSequence(0, 0), 4, std::nullopt, 128, std::nullopt,
+ MoqtSubscribeParameters());
+ MoqtFetchError error = {
+ /*subscribe_id=*/0,
+ /*error_code=*/SubscribeErrorCode::kUnauthorized,
+ /*reason_phrase=*/"No username provided",
+ };
+ stream_input->OnFetchErrorMessage(error);
+ ASSERT_NE(fetch_task, nullptr);
+ EXPECT_TRUE(absl::IsUnauthenticated(fetch_task->GetStatus()));
+ EXPECT_EQ(fetch_task->GetStatus().message(), "No username provided");
+}
+
// TODO: re-enable this test once this behavior is re-implemented.
#if 0
TEST_F(MoqtSessionTest, SubscribeUpdateClosesSubscription) {
diff --git a/quiche/quic/moqt/moqt_track.cc b/quiche/quic/moqt/moqt_track.cc
index 07f44d7..31980d1 100644
--- a/quiche/quic/moqt/moqt_track.cc
+++ b/quiche/quic/moqt/moqt_track.cc
@@ -4,7 +4,16 @@
#include "quiche/quic/moqt/moqt_track.h"
+#include <memory>
+#include <optional>
+#include <utility>
+
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
#include "quiche/quic/moqt/moqt_messages.h"
+#include "quiche/quic/moqt/moqt_publisher.h"
+#include "quiche/common/platform/api/quiche_logging.h"
+#include "quiche/web_transport/web_transport.h"
namespace moqt {
@@ -16,4 +25,69 @@
return true;
}
+UpstreamFetch::~UpstreamFetch() {
+ if (task_.IsValid()) {
+ // Notify the task (which the application owns) that nothing more is coming.
+ // If this has already been called, UpstreamFetchTask will ignore it.
+ task_.GetIfAvailable()->OnStreamAndFetchClosed(kResetCodeUnknown, "");
+ }
+}
+
+void UpstreamFetch::OnFetchResult(FullSequence largest_id, absl::Status status,
+ TaskDestroyedCallback callback) {
+ auto task = std::make_unique<UpstreamFetchTask>(largest_id, status,
+ std::move(callback));
+ task_ = task->weak_ptr();
+ std::move(ok_callback_)(std::move(task));
+}
+
+UpstreamFetch::UpstreamFetchTask::~UpstreamFetchTask() {
+ if (task_destroyed_callback_) {
+ std::move(task_destroyed_callback_)();
+ }
+}
+
+MoqtFetchTask::GetNextObjectResult
+UpstreamFetch::UpstreamFetchTask::GetNextObject(PublishedObject& output) {
+ if (!next_object_.has_value()) {
+ if (!status_.ok()) {
+ return kError;
+ }
+ if (eof_) {
+ return kEof;
+ }
+ return kPending;
+ }
+ output = *std::move(next_object_);
+ next_object_.reset();
+ return kSuccess;
+}
+
+void UpstreamFetch::UpstreamFetchTask::NewObject(PublishedObject& object) {
+ QUICHE_DCHECK(!next_object_.has_value());
+ next_object_ = std::move(object);
+ if (object_available_callback_) {
+ object_available_callback_();
+ }
+}
+
+void UpstreamFetch::UpstreamFetchTask::OnStreamAndFetchClosed(
+ std::optional<webtransport::StreamErrorCode> error,
+ absl::string_view reason_phrase) {
+ if (eof_ || error.has_value()) {
+ return;
+ }
+ // Delete callbacks, because IncomingDataStream and UpstreamFetch are gone.
+ can_read_callback_ = nullptr;
+ task_destroyed_callback_ = nullptr;
+ if (!error.has_value()) { // This was a FIN.
+ eof_ = true;
+ } else {
+ status_ = MoqtStreamErrorToStatus(*error, reason_phrase);
+ }
+ if (object_available_callback_) {
+ object_available_callback_();
+ }
+}
+
} // namespace moqt
diff --git a/quiche/quic/moqt/moqt_track.h b/quiche/quic/moqt/moqt_track.h
index e47464a..fe158e1 100644
--- a/quiche/quic/moqt/moqt_track.h
+++ b/quiche/quic/moqt/moqt_track.h
@@ -8,14 +8,18 @@
#include <cstdint>
#include <memory>
#include <optional>
+#include <utility>
+#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "quiche/quic/core/quic_time.h"
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_priority.h"
+#include "quiche/quic/moqt/moqt_publisher.h"
#include "quiche/quic/moqt/moqt_subscribe_windows.h"
#include "quiche/common/quiche_callbacks.h"
#include "quiche/common/quiche_weak_ptr.h"
+#include "quiche/web_transport/web_transport.h"
namespace moqt {
@@ -133,6 +137,106 @@
std::unique_ptr<MoqtSubscribe> subscribe_;
};
+// MoqtSession calls this when a FETCH_OK or FETCH_ERROR is received. The
+// destination of the callback owns |fetch_task| and MoqtSession will react
+// safely if the owner destroys it.
+using FetchResponseCallback =
+ quiche::SingleUseCallback<void(std::unique_ptr<MoqtFetchTask> fetch_task)>;
+
+// This is a callback to MoqtSession::IncomingDataStream. Called when the
+// FetchTask has its object cache empty, on creation, and whenever the
+// application reads it.
+using CanReadCallback = quiche::MultiUseCallback<void()>;
+
+// If the application destroys the FetchTask, this is a signal to MoqtSession to
+// cancel the FETCH and STOP_SENDING the stream.
+using TaskDestroyedCallback = quiche::SingleUseCallback<void()>;
+
+// Class for upstream FETCH. It will notify the application using |callback|
+// when a FETCH_OK or FETCH_ERROR is received.
+class UpstreamFetch : public RemoteTrack {
+ public:
+ UpstreamFetch(const MoqtFetch& fetch, FetchResponseCallback callback)
+ : RemoteTrack(fetch.full_track_name, fetch.subscribe_id,
+ SubscribeWindow(
+ fetch.start_object,
+ FullSequence(fetch.end_group,
+ fetch.end_object.value_or(UINT64_MAX)))),
+ ok_callback_(std::move(callback)) {
+ // Immediately set the data stream type.
+ CheckDataStreamType(MoqtDataStreamType::kStreamHeaderFetch);
+ }
+ UpstreamFetch(const UpstreamFetch&) = delete;
+ ~UpstreamFetch();
+
+ class UpstreamFetchTask : public MoqtFetchTask {
+ public:
+ // If the UpstreamFetch is destroyed, it will call OnStreamAndFetchClosed
+ // which sets the TaskDestroyedCallback to nullptr. Thus, |callback| can
+ // assume that UpstreamFetch is valid.
+ UpstreamFetchTask(FullSequence largest_id, absl::Status status,
+ TaskDestroyedCallback callback)
+ : largest_id_(largest_id),
+ status_(status),
+ task_destroyed_callback_(std::move(callback)),
+ weak_ptr_factory_(this) {}
+ ~UpstreamFetchTask() override;
+
+ // Implementation of MoqtFetchTask.
+ GetNextObjectResult GetNextObject(PublishedObject& output) override;
+ void SetObjectAvailableCallback(
+ ObjectsAvailableCallback callback) override {
+ object_available_callback_ = std::move(callback);
+ };
+ absl::Status GetStatus() override { return status_; };
+ FullSequence GetLargestId() const override { return largest_id_; }
+
+ quiche::QuicheWeakPtr<UpstreamFetchTask> weak_ptr() {
+ return weak_ptr_factory_.Create();
+ }
+
+ // Manage the relationship with the data stream.
+ void OnStreamOpened(CanReadCallback callback) {
+ can_read_callback_ = std::move(callback);
+ }
+
+ // Called when the data stream receives a new object.
+ void NewObject(PublishedObject& object);
+
+ // Deletes callbacks to session or stream, updates the status. If |error|
+ // has no value, will append an EOF to the object stream.
+ void OnStreamAndFetchClosed(
+ std::optional<webtransport::StreamErrorCode> error,
+ absl::string_view reason_phrase);
+
+ private:
+ FullSequence largest_id_;
+ absl::Status status_;
+ TaskDestroyedCallback task_destroyed_callback_;
+
+ // Object delivery state.
+ std::optional<PublishedObject> next_object_;
+ bool eof_ = false; // The next object is EOF.
+ ObjectsAvailableCallback object_available_callback_;
+ CanReadCallback can_read_callback_;
+
+ // Must be last.
+ quiche::QuicheWeakPtrFactory<UpstreamFetchTask> weak_ptr_factory_;
+ };
+
+ // Arrival of FETCH_OK/FETCH_ERROR.
+ void OnFetchResult(FullSequence largest_id, absl::Status status,
+ TaskDestroyedCallback callback);
+
+ UpstreamFetchTask* task() { return task_.GetIfAvailable(); }
+
+ private:
+ quiche::QuicheWeakPtr<UpstreamFetchTask> task_;
+
+ // Initial values from Fetch() call.
+ FetchResponseCallback ok_callback_; // Will be destroyed on FETCH_OK.
+};
+
} // namespace moqt
#endif // QUICHE_QUIC_MOQT_MOQT_SUBSCRIPTION_H_
diff --git a/quiche/quic/moqt/moqt_track_test.cc b/quiche/quic/moqt/moqt_track_test.cc
index a5717ce..8072535 100644
--- a/quiche/quic/moqt/moqt_track_test.cc
+++ b/quiche/quic/moqt/moqt_track_test.cc
@@ -4,11 +4,16 @@
#include "quiche/quic/moqt/moqt_track.h"
+#include <memory>
#include <optional>
+#include <utility>
+#include "absl/status/status.h"
#include "quiche/quic/moqt/moqt_messages.h"
+#include "quiche/quic/moqt/moqt_publisher.h"
#include "quiche/quic/moqt/tools/moqt_mock_visitor.h"
#include "quiche/quic/platform/api/quic_test.h"
+#include "quiche/common/platform/api/quiche_mem_slice.h"
namespace moqt {
@@ -64,7 +69,96 @@
EXPECT_FALSE(track_.InWindow(FullSequence(2, 0)));
}
-// TODO: Write test for GetStreamForSequence.
+class UpstreamFetchTest : public quic::test::QuicTest {
+ protected:
+ UpstreamFetchTest()
+ : fetch_(fetch_message_, [&](std::unique_ptr<MoqtFetchTask> task) {
+ fetch_task_ = std::move(task);
+ }) {}
+
+ MoqtFetch fetch_message_ = {
+ /*fetch_id=*/1,
+ /*full_track_name=*/FullTrackName("foo", "bar"),
+ /*subscriber_priority=*/128,
+ /*group_order=*/std::nullopt,
+ /*start_object=*/FullSequence(1, 1),
+ /*end_group=*/3,
+ /*end_object=*/100,
+ /*parameters=*/MoqtSubscribeParameters(),
+ };
+ // The pointer held by the application.
+ UpstreamFetch fetch_;
+ std::unique_ptr<MoqtFetchTask> fetch_task_;
+};
+
+TEST_F(UpstreamFetchTest, Queries) {
+ EXPECT_EQ(fetch_.subscribe_id(), 1);
+ EXPECT_EQ(fetch_.full_track_name(), FullTrackName("foo", "bar"));
+ EXPECT_FALSE(
+ fetch_.CheckDataStreamType(MoqtDataStreamType::kStreamHeaderSubgroup));
+ EXPECT_TRUE(
+ fetch_.CheckDataStreamType(MoqtDataStreamType::kStreamHeaderFetch));
+ EXPECT_TRUE(fetch_.is_fetch());
+ EXPECT_FALSE(fetch_.InWindow(FullSequence{1, 0}));
+ EXPECT_TRUE(fetch_.InWindow(FullSequence{1, 1}));
+ EXPECT_TRUE(fetch_.InWindow(FullSequence{3, 100}));
+ EXPECT_FALSE(fetch_.InWindow(FullSequence{3, 101}));
+}
+
+TEST_F(UpstreamFetchTest, AllowError) {
+ EXPECT_TRUE(fetch_.ErrorIsAllowed());
+ fetch_.OnObjectOrOk();
+ EXPECT_FALSE(fetch_.ErrorIsAllowed());
+}
+
+TEST_F(UpstreamFetchTest, FetchResponse) {
+ EXPECT_EQ(fetch_task_, nullptr);
+ fetch_.OnFetchResult(FullSequence(3, 50), absl::OkStatus(), nullptr);
+ EXPECT_NE(fetch_task_, nullptr);
+ EXPECT_NE(fetch_.task(), nullptr);
+ EXPECT_TRUE(fetch_task_->GetStatus().ok());
+ EXPECT_EQ(fetch_task_->GetLargestId(), FullSequence(3, 50));
+}
+
+TEST_F(UpstreamFetchTest, FetchClosedByMoqt) {
+ bool terminated = false;
+ fetch_.OnFetchResult(FullSequence(3, 50), absl::OkStatus(),
+ [&]() { terminated = true; });
+ bool got_eof = false;
+ fetch_task_->SetObjectAvailableCallback([&]() {
+ PublishedObject object;
+ EXPECT_EQ(fetch_task_->GetNextObject(object),
+ MoqtFetchTask::GetNextObjectResult::kEof);
+ got_eof = true;
+ });
+ fetch_.task()->OnStreamAndFetchClosed(std::nullopt, "");
+ EXPECT_FALSE(terminated);
+ EXPECT_TRUE(got_eof);
+}
+
+TEST_F(UpstreamFetchTest, FetchClosedByApplication) {
+ bool terminated = false;
+ fetch_.OnFetchResult(FullSequence(3, 50), absl::Status(),
+ [&]() { terminated = true; });
+ fetch_task_.reset();
+ EXPECT_TRUE(terminated);
+}
+
+TEST_F(UpstreamFetchTest, ObjectRetrieval) {
+ fetch_.OnFetchResult(FullSequence(3, 50), absl::OkStatus(), nullptr);
+ PublishedObject object;
+ EXPECT_EQ(fetch_task_->GetNextObject(object),
+ MoqtFetchTask::GetNextObjectResult::kPending);
+ PublishedObject new_object(FullSequence(3, 0),
+ MoqtObjectStatus::kGroupDoesNotExist, 128,
+ quiche::QuicheMemSlice(), false);
+ fetch_task_->SetObjectAvailableCallback([&]() {
+ EXPECT_EQ(fetch_task_->GetNextObject(object),
+ MoqtFetchTask::GetNextObjectResult::kSuccess);
+ EXPECT_EQ(object.sequence, new_object.sequence);
+ });
+ fetch_.task()->NewObject(new_object);
+}
} // namespace test
diff --git a/quiche/quic/moqt/test_tools/moqt_session_peer.h b/quiche/quic/moqt/test_tools/moqt_session_peer.h
index 5426ef6..e867376 100644
--- a/quiche/quic/moqt/test_tools/moqt_session_peer.h
+++ b/quiche/quic/moqt/test_tools/moqt_session_peer.h
@@ -17,7 +17,6 @@
#include "quiche/quic/moqt/moqt_session.h"
#include "quiche/quic/moqt/moqt_track.h"
#include "quiche/quic/moqt/tools/moqt_mock_visitor.h"
-#include "quiche/quic/platform/api/quic_test.h"
#include "quiche/web_transport/test_tools/mock_web_transport.h"
#include "quiche/web_transport/web_transport.h"
@@ -75,11 +74,12 @@
const MoqtSubscribe& subscribe,
SubscribeRemoteTrack::Visitor* visitor) {
auto track = std::make_unique<SubscribeRemoteTrack>(subscribe, visitor);
- session->upstream_by_id_.try_emplace(subscribe.subscribe_id, track.get());
- session->upstream_by_name_.try_emplace(subscribe.full_track_name,
- track.get());
session->subscribe_by_alias_.try_emplace(subscribe.track_alias,
- std::move(track));
+ track.get());
+ session->subscribe_by_name_.try_emplace(subscribe.full_track_name,
+ track.get());
+ session->upstream_by_id_.try_emplace(subscribe.subscribe_id,
+ std::move(track));
}
static MoqtObjectListener* AddSubscription(