Create OutgoingFetchStream and factor out OutgoingUniStream as a parent of both data stream types. PiperOrigin-RevId: 916033776
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index 37103af..b5e44ab 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -650,66 +650,6 @@ kDefaultGoAwayTimeout); } -void MoqtSession::PublishedFetch::FetchStreamVisitor::OnCanWrite() { - std::shared_ptr<PublishedFetch> fetch = fetch_.lock(); - if (fetch == nullptr) { - return; - } - PublishedObject object; - while (stream_->CanWrite()) { - MoqtFetchTask::GetNextObjectResult result = - fetch->fetch_task()->GetNextObject(object); - switch (result) { - case MoqtFetchTask::GetNextObjectResult::kSuccess: - // Skip ObjectDoesNotExist in FETCH. - if (object.metadata.status != MoqtObjectStatus::kNormal) { - QUIC_BUG(quic_bug_got_doesnotexist_in_fetch) - << "Got Non-normal object in FETCH"; - continue; - } - if (last_object_.has_value() && - object.metadata.location == last_object_->location) { - // This is the continuation of the previous object. - webtransport::StreamWriteOptions options; - absl::Status write_status = - stream_->Writev(absl::MakeSpan(object.payload), options); - if (!write_status.ok()) { - QUICHE_BUG(MoqtSession_WriteObjectToStream_write_failed) - << "Writing into MoQT stream failed despite CanWrite() being " - "true before; status: " - << write_status; - fetch->session_->Error(MoqtError::kInternalError, - "Data stream write error"); - return; - } - break; - } - if (fetch->session_->WriteObjectToStream( - stream_, fetch->request_id(), object.metadata, - std::move(object.payload), MoqtDataStreamType::Fetch(), - // last Object ID doesn't matter for FETCH, just use zero. - last_object_, /*fin=*/false)) { - last_object_ = object.metadata; - } - break; - case MoqtFetchTask::GetNextObjectResult::kPending: - return; - case MoqtFetchTask::GetNextObjectResult::kEof: - // TODO(martinduke): Either prefetch the next object, or alter the API - // so that we're not sending FIN in a separate frame. - if (!webtransport::SendFinOnStream(*stream_).ok()) { - QUIC_DVLOG(1) << "Sending FIN onStream " << stream_->GetStreamId() - << " failed"; - } - return; - case MoqtFetchTask::GetNextObjectResult::kError: - stream_->ResetWithUserCode(static_cast<webtransport::StreamErrorCode>( - fetch->fetch_task()->GetStatus().code())); - return; - } - } -} - void MoqtSession::GoAwayTimeoutDelegate::OnAlarm() { session_->Error(MoqtError::kGoawayTimeout, "Peer did not close session after GOAWAY"); @@ -808,7 +748,7 @@ return new_stream; } -bool MoqtSession::OpenDataStream(std::shared_ptr<PublishedFetch> fetch, +bool MoqtSession::OpenDataStream(PublishedFetch* fetch, webtransport::SendOrder send_order) { webtransport::Stream* new_stream = session_->OpenOutgoingUnidirectionalStream(); @@ -818,14 +758,24 @@ return false; } fetch->SetStreamId(new_stream->GetStreamId()); - new_stream->SetPriority(webtransport::StreamPriority{ - /*send_group_id=*/kMoqtSendGroupId, send_order}); // The line below will lead to updating ObjectsAvailableCallback in the // FetchTask to call OnCanWrite() on the stream. If there is an object // available, the callback will be invoked synchronously (i.e. before // SetVisitor() returns). - new_stream->SetVisitor( - std::make_unique<PublishedFetch::FetchStreamVisitor>(fetch, new_stream)); + new_stream->SetVisitor(std::make_unique<OutgoingFetchStream>( + framer_, new_stream, fetch->request_id(), + webtransport::StreamPriority{/*send_group_id=*/kMoqtSendGroupId, + send_order}, + fetch->release_fetch_task(), + // use weakptr to avoid use-after-free for this. + [weakptr = GetWeakPtr(), request_id = fetch->request_id()]() { + if (weakptr.IsValid()) { + auto session = + absl::down_cast<MoqtSession*>(weakptr.GetIfAvailable()); + session->incoming_fetches_.erase(request_id); + } + }, + &trace_recorder_)); return true; } @@ -864,7 +814,7 @@ auto fetch = incoming_fetches_.find(next->subscription_id); // Create the stream if the fetch still exists. if (fetch != incoming_fetches_.end() && - !OpenDataStream(fetch->second, next->send_order)) { + !OpenDataStream(fetch->second.get(), next->send_order)) { return; // A QUIC_BUG has fired because this shouldn't happen. } // FETCH needs only one stream, and can be deleted from the queue. Or, @@ -1567,8 +1517,8 @@ return SendRequestError(message.request_id, RequestErrorCode::kInvalidRange, std::nullopt, fetch->GetStatus().message()); } - auto published_fetch = std::make_unique<PublishedFetch>( - message.request_id, session_, std::move(fetch)); + auto published_fetch = + std::make_unique<PublishedFetch>(message.request_id, std::move(fetch)); auto result = session_->incoming_fetches_.emplace(message.request_id, std::move(published_fetch)); if (!result.second) { // Emplace failed. @@ -1578,7 +1528,7 @@ RequestErrorCode::kInternalError, std::nullopt, "Could not initialize FETCH state"); } - MoqtFetchTask* fetch_task = result.first->second->fetch_task(); + MoqtFetchTask* fetch_task = result.first->second->fetch_task_ptr(); fetch_task->SetFetchResponseCallback( [this, request_id = message.request_id]( std::variant<MoqtFetchOk, MoqtRequestError> message) { @@ -1610,7 +1560,7 @@ return; } if (!session_->session()->CanOpenNextOutgoingUnidirectionalStream() || - !session_->OpenDataStream(it->second, send_order)) { + !session_->OpenDataStream(it->second.get(), send_order)) { if (!session_->subscribes_with_queued_outgoing_data_streams_.contains( SubscriptionWithQueuedStream(request_id, send_order))) { // Put the FETCH in the queue for a new stream unless it has already @@ -2316,50 +2266,6 @@ // TODO: send PUBLISH_DONE if the subscription is done. } -bool MoqtSession::WriteObjectToStream( - webtransport::Stream* stream, uint64_t id, - const PublishedObjectMetadata& metadata, - std::vector<quiche::QuicheMemSlice> payload, MoqtDataStreamType type, - std::optional<PublishedObjectMetadata> last_object, bool fin) { - QUICHE_DCHECK(stream->CanWrite()); - MoqtObject header; - header.track_alias = id; - header.group_id = metadata.location.group; - header.subgroup_id = metadata.subgroup; - header.object_id = metadata.location.object; - header.publisher_priority = metadata.publisher_priority; - header.extension_headers = metadata.extensions; - header.object_status = metadata.status; - header.payload_length = metadata.payload_length; - - quiche::QuicheBuffer serialized_header = - framer_.SerializeObjectHeader(header, type, last_object); - // TODO(vasilvv): add a version of WebTransport write API that accepts - // memslices so that we can avoid a copy here. - std::vector<quiche::QuicheMemSlice> write_vector; - write_vector.reserve(payload.size() + 1); - write_vector.push_back(quiche::QuicheMemSlice(std::move(serialized_header))); - for (auto& slice : payload) { - write_vector.push_back(std::move(slice)); - } - webtransport::StreamWriteOptions options; - options.set_send_fin(fin); - absl::Status write_status = - stream->Writev(absl::MakeSpan(write_vector), options); - if (!write_status.ok()) { - QUICHE_BUG(MoqtSession_WriteObjectToStream_write_failed) - << "Writing into MoQT stream failed despite CanWrite() being true " - "before; status: " - << write_status; - Error(MoqtError::kInternalError, "Data stream write error"); - return false; - } - - QUIC_DVLOG(1) << "Stream " << stream->GetStreamId() << " successfully wrote " - << metadata.location << ", fin = " << fin; - return true; -} - void MoqtSession::OnMalformedTrack(RemoteTrack* track) { if (!track->is_fetch()) { absl::down_cast<SubscribeRemoteTrack*>(track)->visitor()->OnMalformedTrack( @@ -2450,15 +2356,6 @@ OnObjectSent(object->metadata.location); } -MoqtSession::PublishedFetch::FetchStreamVisitor::FetchStreamVisitor( - std::shared_ptr<PublishedFetch> fetch, webtransport::Stream* stream) - : fetch_(fetch), stream_(stream) { - fetch->fetch_task()->SetObjectAvailableCallback( - [this]() { this->OnCanWrite(); }); - fetch->session()->trace_recorder_.RecordFetchStreamCreated( - stream->GetStreamId()); -} - void MoqtSession::PublishedSubscription::ProcessObjectAck( const MoqtObjectAck& message) { session_->trace_recorder_.RecordObjectAck(
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index d66df53..5c1998a 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -14,6 +14,7 @@ #include <vector> #include "absl/base/nullability.h" +#include "absl/cleanup/cleanup.h" #include "absl/container/btree_map.h" #include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" @@ -547,49 +548,24 @@ class QUICHE_EXPORT PublishedFetch { public: - PublishedFetch(uint64_t request_id, MoqtSession* session, - std::unique_ptr<MoqtFetchTask> fetch) - : session_(session), - fetch_(std::move(fetch)), - request_id_(request_id) {} + PublishedFetch(uint64_t request_id, std::unique_ptr<MoqtFetchTask> fetch) + : request_id_(request_id), fetch_(std::move(fetch)) {} - class FetchStreamVisitor : public webtransport::StreamVisitor { - public: - FetchStreamVisitor(std::shared_ptr<PublishedFetch> fetch, - webtransport::Stream* stream); - ~FetchStreamVisitor() { - std::shared_ptr<PublishedFetch> fetch = fetch_.lock(); - if (fetch != nullptr) { - fetch->session()->incoming_fetches_.erase(fetch->request_id_); - } - } - // webtransport::StreamVisitor implementation. - void OnCanRead() override {} // Write-only stream. - void OnCanWrite() override; - void OnResetStreamReceived( - webtransport::StreamErrorCode /*error*/) override { - } // Write-only stream - void OnStopSendingReceived( - webtransport::StreamErrorCode /*error*/) override {} - void OnWriteSideInDataRecvdState() override {} - - private: - std::weak_ptr<PublishedFetch> fetch_; - std::optional<PublishedObjectMetadata> last_object_; - webtransport::Stream* stream_; - }; - - MoqtFetchTask* fetch_task() { return fetch_.get(); } - MoqtSession* session() { return session_; } + MoqtFetchTask* fetch_task_ptr() { return fetch_.get(); } + // Can only be called once. + std::unique_ptr<MoqtFetchTask> release_fetch_task() { + auto on_return = absl::MakeCleanup([this] { fetch_ = nullptr; }); + return std::move(fetch_); + } uint64_t request_id() const { return request_id_; } void SetStreamId(webtransport::StreamId id) { stream_id_ = id; } private: - MoqtSession* session_; - std::unique_ptr<MoqtFetchTask> fetch_; uint64_t request_id_; // Store the stream ID in case a FETCH_CANCEL requires a reset. std::optional<webtransport::StreamId> stream_id_; + // Temporary storage until the stream is created. + std::unique_ptr<MoqtFetchTask> fetch_; }; class QUICHE_EXPORT DownstreamTrackStatus : public MoqtObjectListener { @@ -698,7 +674,7 @@ webtransport::Stream* OpenDataStream(PublishedSubscription& subscription, const NewStreamParameters& parameters); // Returns false if creation failed. - [[nodiscard]] bool OpenDataStream(std::shared_ptr<PublishedFetch> fetch, + [[nodiscard]] bool OpenDataStream(PublishedFetch* fetch, webtransport::SendOrder send_order); SubscribeRemoteTrack* RemoteTrackByAlias(uint64_t track_alias); @@ -709,17 +685,6 @@ // a session error if is not. bool ValidateRequestId(uint64_t request_id); - // TODO(martinduke): Delete once Fetch uses OutgoingSubgroupStream. - // Actually sends an object on |stream| with track alias or fetch ID |id| - // and metadata in |object|. Not for use with datagrams. Returns |true| if - // the write was successful. - bool WriteObjectToStream(webtransport::Stream* stream, uint64_t id, - const PublishedObjectMetadata& metadata, - std::vector<quiche::QuicheMemSlice> payload, - MoqtDataStreamType type, - std::optional<PublishedObjectMetadata> last_object, - bool fin); - void CancelFetch(uint64_t request_id); // Sends an OBJECT_ACK message for a specific subscribe ID. @@ -813,10 +778,8 @@ absl::flat_hash_set<uint64_t> used_track_aliases_; uint64_t next_local_track_alias_ = 0; - // Incoming FETCHes, indexed by fetch ID. There will be other pointers to - // PublishedFetch, so storing a shared_ptr in the map provides pointer - // stability for the value. - absl::flat_hash_map<uint64_t, std::shared_ptr<PublishedFetch>> + // Incoming FETCHes, indexed by fetch ID. + absl::flat_hash_map<uint64_t, std::unique_ptr<PublishedFetch>> incoming_fetches_; absl::flat_hash_map<uint64_t, std::unique_ptr<DownstreamTrackStatus>>
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 08d0a61..ca4d6bd 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -61,7 +61,6 @@ namespace { -using ::quic::test::MemSliceFromString; using ::testing::_; using ::testing::Optional; using ::testing::Return;
diff --git a/quiche/quic/moqt/moqt_uni_stream.cc b/quiche/quic/moqt/moqt_uni_stream.cc index e0b9c37..f074643 100644 --- a/quiche/quic/moqt/moqt_uni_stream.cc +++ b/quiche/quic/moqt/moqt_uni_stream.cc
@@ -17,9 +17,11 @@ #include "quiche/quic/core/quic_time.h" #include "quiche/quic/core/quic_utils.h" #include "quiche/quic/moqt/moqt_error.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_framer.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_object.h" +#include "quiche/quic/moqt/moqt_priority.h" #include "quiche/quic/moqt/moqt_publisher.h" #include "quiche/quic/moqt/moqt_trace_recorder.h" #include "quiche/quic/moqt/moqt_types.h" @@ -31,6 +33,49 @@ namespace moqt { +void OutgoingUniStream::UpdatePriority(MoqtPriority subscriber_priority) { + priority_.send_order = UpdateSendOrderForSubscriberPriority( + priority_.send_order, subscriber_priority); + stream_.SetPriority(priority_); +} + +bool OutgoingUniStream::WriteObjectToStream(PublishedObject& object, + MoqtDataStreamType type) { + MoqtObject header; + header.track_alias = track_identifier_; + header.group_id = object.metadata.location.group; + header.subgroup_id = object.metadata.subgroup; + header.object_id = object.metadata.location.object; + header.publisher_priority = object.metadata.publisher_priority; + header.extension_headers = object.metadata.extensions; + header.object_status = object.metadata.status; + header.payload_length = object.metadata.payload_length; + + quiche::QuicheBuffer serialized_header = + framer_.SerializeObjectHeader(header, type, last_object_); + std::vector<quiche::QuicheMemSlice> write_vector; + write_vector.reserve(object.payload.size() + 1); + write_vector.push_back(quiche::QuicheMemSlice(std::move(serialized_header))); + for (auto& slice : object.payload) { + write_vector.push_back(std::move(slice)); + } + webtransport::StreamWriteOptions options; + options.set_send_fin(!type.IsFetch() && object.fin_after_this); + absl::Status write_status = + stream_.Writev(absl::MakeSpan(write_vector), options); + if (!write_status.ok()) { + QUICHE_BUG(MoqtSession_WriteObjectToStream_write_failed) + << "Writing into MoQT stream failed despite CanWrite being true " + "before; status: " + << write_status; + return false; + } + QUIC_DVLOG(1) << "Stream " << stream_.GetStreamId() << " successfully wrote " + << object.metadata.location + << ", fin = " << object.fin_after_this; + return true; +} + OutgoingSubgroupStream::OutgoingSubgroupStream( MoqtFramer framer, webtransport::Stream* absl_nonnull stream, DataStreamIndex index, uint64_t first_object, @@ -38,15 +83,13 @@ std::shared_ptr<MoqtTrackPublisher> absl_nonnull track_publisher, webtransport::StreamPriority priority, uint64_t track_alias, MoqtTraceRecorder* absl_nonnull trace_recorder) - : stream_(*stream), + : OutgoingUniStream(framer, stream, priority, track_alias), index_(index), visitor_(std::move(visitor)), - framer_(framer), + track_alias_(track_alias), publisher_(track_publisher), - next_object_(first_object), - priority_(priority) { - stream_.SetPriority(priority_); + next_object_(first_object) { trace_recorder->RecordSubgroupStreamCreated(stream->GetStreamId(), track_alias_, index); } @@ -83,7 +126,7 @@ if (visitor != nullptr) { visitor->OnStreamTimeout(stream_->index_); } - stream_->stream_.ResetWithUserCode(kResetCodeDeliveryTimeout); + stream_->stream().ResetWithUserCode(kResetCodeDeliveryTimeout); } void OutgoingSubgroupStream::SendObjects() { @@ -91,7 +134,7 @@ if (visitor == nullptr) { return; } - while (stream_.CanWrite()) { + while (stream().CanWrite()) { std::optional<PublishedObject> object = publisher_->GetCachedObject( index_.group, index_.subgroup, next_object_, already_delivered_); if (!object.has_value()) { @@ -107,7 +150,7 @@ if (!visitor->InWindow(object->metadata.location)) { // It is possible that the next object became irrelevant due to a // REQUEST_UPDATE. Close the stream if so. - absl::Status status = webtransport::SendFinOnStream(stream_); + absl::Status status = webtransport::SendFinOnStream(stream()); QUICHE_BUG_IF(OutgoingSubgroupStream_fin_due_to_update, !status.ok()) << "Writing FIN failed despite CanWrite() being true."; return; @@ -118,10 +161,18 @@ visitor->clock()->ApproximateNow() - object->metadata.arrival_time > delivery_timeout) { visitor->OnStreamTimeout(index_); - stream_.ResetWithUserCode(kResetCodeDeliveryTimeout); + stream().ResetWithUserCode(kResetCodeDeliveryTimeout); // No class access below this line. return; } + // Always include extension header length, because it's difficult to know + // a priori if they're going to appear on a stream. + if (!last_object().has_value()) { + type_ = MoqtDataStreamType::Subgroup( + index_.subgroup, next_object_, false, + object->metadata.publisher_priority == + publisher_->extensions().default_publisher_priority()); + } uint64_t start_offset = already_delivered_; already_delivered_ += quic::MemSliceSpanTotalSize(absl::MakeSpan(object->payload)); @@ -138,26 +189,27 @@ webtransport::StreamWriteOptions options; options.set_send_fin(object->fin_after_this); absl::Status write_status = - stream_.Writev(absl::MakeSpan(object->payload), options); + stream().Writev(absl::MakeSpan(object->payload), options); if (!write_status.ok()) { QUICHE_BUG(MoqtSession_WriteObjectToStream_write_failed) << "Writing into MoQT stream failed despite CanWrite() being true " "before; status: " << write_status; - stream_.ResetWithUserCode(kResetCodeInternalError); + stream().ResetWithUserCode(kResetCodeInternalError); return; } } else { - if (!WriteObjectToStream(*object)) { - stream_.ResetWithUserCode(kResetCodeInternalError); + if (!WriteObjectToStream(*object, type_)) { + stream().ResetWithUserCode(kResetCodeInternalError); // No class access below this line. return; } - last_object_ = object->metadata; - next_object_ = last_object_->location.object; + set_last_object(object->metadata); + next_object_ = object->metadata.location.object; visitor->OnObjectSent(object->metadata.location); } - if (already_delivered_ != last_object_->payload_length) { + QUICHE_DCHECK(last_object().has_value()); + if (already_delivered_ != last_object()->payload_length) { return; } ++next_object_; @@ -176,7 +228,7 @@ return; } // All data has already been sent; send a pure FIN. - absl::Status status = webtransport::SendFinOnStream(stream_); + absl::Status status = webtransport::SendFinOnStream(stream()); QUICHE_BUG_IF(OutgoingSubgroupStream_fin_failed, !status.ok()) << "Writing pure FIN failed."; SubscriptionPublisherInterface* visitor = visitor_.GetIfAvailable(); @@ -202,48 +254,81 @@ delivery_timeout_alarm_->Set(deadline); } -bool OutgoingSubgroupStream::WriteObjectToStream(PublishedObject& object) { - MoqtObject header; - header.track_alias = track_alias_; - header.group_id = object.metadata.location.group; - header.subgroup_id = object.metadata.subgroup; - header.object_id = object.metadata.location.object; - header.publisher_priority = object.metadata.publisher_priority; - header.extension_headers = object.metadata.extensions; - header.object_status = object.metadata.status; - header.payload_length = object.metadata.payload_length; +OutgoingFetchStream::OutgoingFetchStream( + MoqtFramer framer, webtransport::Stream* absl_nonnull stream, + uint64_t request_id, webtransport::StreamPriority priority, + std::unique_ptr<MoqtFetchTask> incoming_objects, + FetchStreamCloseCallback close_callback, + MoqtTraceRecorder* absl_nonnull trace_recorder) + : OutgoingUniStream(framer, stream, priority, request_id), + incoming_objects_(std::move(incoming_objects)), + close_callback_(std::move(close_callback)) { + incoming_objects_->SetObjectAvailableCallback( + [this]() { this->OnCanWrite(); }); + trace_recorder->RecordFetchStreamCreated(stream->GetStreamId()); +} - // Always include extension header length, because it's difficult to know - // a priori if they're going to appear on a stream. - if (!last_object_.has_value()) { - type_ = MoqtDataStreamType::Subgroup( - index_.subgroup, next_object_, false, - object.metadata.publisher_priority == - publisher_->extensions().default_publisher_priority()); +OutgoingFetchStream::~OutgoingFetchStream() { + if (close_callback_ != nullptr) { + std::move(close_callback_)(); } - quiche::QuicheBuffer serialized_header = - framer_.SerializeObjectHeader(header, type_, last_object_); - std::vector<quiche::QuicheMemSlice> write_vector; - write_vector.reserve(object.payload.size() + 1); - write_vector.push_back(quiche::QuicheMemSlice(std::move(serialized_header))); - for (auto& slice : object.payload) { - write_vector.push_back(std::move(slice)); + close_callback_ = nullptr; +} + +void OutgoingFetchStream::OnCanWrite() { + PublishedObject object; + while (stream().CanWrite()) { + MoqtFetchTask::GetNextObjectResult result = + incoming_objects_->GetNextObject(object); + switch (result) { + case MoqtFetchTask::GetNextObjectResult::kSuccess: + // Skip ObjectDoesNotExist in FETCH. + if (object.metadata.status != MoqtObjectStatus::kNormal) { + QUICHE_BUG(quiche_bug_got_doesnotexist_in_fetch) + << "Got Non-normal object in FETCH"; + continue; + } + if (last_object().has_value() && + object.metadata.location == last_object()->location) { + // This is the continuation of the previous object. + webtransport::StreamWriteOptions options; + absl::Status write_status = + stream().Writev(absl::MakeSpan(object.payload), options); + if (!write_status.ok()) { + QUICHE_BUG(MoqtSession_WriteObjectToStream_write_failed) + << "Writing into MoQT stream failed despite CanWrite() being " + "true before; status: " + << write_status; + stream().ResetWithUserCode(kResetCodeInternalError); + return; + } + break; + } + if (WriteObjectToStream(object, MoqtDataStreamType::Fetch())) { + set_last_object(object.metadata); + } + break; + case MoqtFetchTask::GetNextObjectResult::kPending: + return; + case MoqtFetchTask::GetNextObjectResult::kEof: + // TODO(martinduke): Either prefetch the next object, or alter the API + // so that we're not sending FIN in a separate frame. + if (!webtransport::SendFinOnStream(stream()).ok()) { + QUICHE_DVLOG(1) << "Sending FIN onStream " << stream().GetStreamId() + << " failed"; + } + return; + case MoqtFetchTask::GetNextObjectResult::kError: + stream().ResetWithUserCode(static_cast<webtransport::StreamErrorCode>( + incoming_objects_->GetStatus().code())); + return; + } } - webtransport::StreamWriteOptions options; - options.set_send_fin(object.fin_after_this); - absl::Status write_status = - stream_.Writev(absl::MakeSpan(write_vector), options); - if (!write_status.ok()) { - QUICHE_BUG(MoqtSession_WriteObjectToStream_write_failed) - << "Writing into MoQT stream failed despite CanWrite being true " - "before; status: " - << write_status; - return false; - } - QUICHE_DVLOG(1) << "Stream " << stream_.GetStreamId() - << " successfully wrote " << object.metadata.location - << ", fin = " << object.fin_after_this; - return true; +} + +void OutgoingFetchStream::OnStopSendingReceived( + webtransport::StreamErrorCode error_code) { + stream().ResetWithUserCode(error_code); } } // namespace moqt
diff --git a/quiche/quic/moqt/moqt_uni_stream.h b/quiche/quic/moqt/moqt_uni_stream.h index 9884b26..c2921c3 100644 --- a/quiche/quic/moqt/moqt_uni_stream.h +++ b/quiche/quic/moqt/moqt_uni_stream.h
@@ -8,11 +8,13 @@ #include <cstdint> #include <memory> #include <optional> +#include <utility> #include "absl/base/nullability.h" #include "quiche/quic/core/quic_alarm.h" #include "quiche/quic/core/quic_alarm_factory.h" #include "quiche/quic/core/quic_time.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_framer.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_object.h" @@ -21,6 +23,7 @@ #include "quiche/quic/moqt/moqt_trace_recorder.h" #include "quiche/quic/moqt/moqt_types.h" #include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_callbacks.h" #include "quiche/common/quiche_weak_ptr.h" #include "quiche/web_transport/web_transport.h" @@ -30,6 +33,55 @@ class MoqtSessionPeer; } +// A base class for locally initiated unidirectional streams, which can serve +// either a Subgroup or a FETCH response. It contains most of the machinery for +// managing the WebTransport stream. +class OutgoingUniStream : public webtransport::StreamVisitor { + public: + OutgoingUniStream(MoqtFramer framer, + webtransport::Stream* absl_nonnull stream, + webtransport::StreamPriority priority, + uint64_t track_identifier) + : stream_(*stream), + priority_(priority), + track_identifier_(track_identifier), + framer_(framer) { + stream_.SetPriority(priority_); + } + virtual ~OutgoingUniStream() = default; + + // webtransport::StreamVisitor implementation. + void OnCanRead() override {} // Write-only. + // OnCanWrite() deferred to children. + virtual void OnResetStreamReceived(webtransport::StreamErrorCode) override {} + // OnStopSendingReceived() deferred to children. + void OnWriteSideInDataRecvdState() override {} + + // Recomputes the send order and updates it for the associated stream. + void UpdatePriority(MoqtPriority subscriber_priority); + + protected: + webtransport::Stream& stream() { return stream_; } + std::optional<PublishedObjectMetadata>& last_object() { return last_object_; } + void set_last_object(PublishedObjectMetadata metadata) { + last_object_ = std::move(metadata); + } + + // Writes an object to the stream. Returns false if the write failed. The + // caller should reset the stream if that happens. + bool WriteObjectToStream(PublishedObject& object, MoqtDataStreamType type); + + private: + webtransport::Stream& stream_; // Always valid because it owns this object. + webtransport::StreamPriority priority_; + uint64_t track_identifier_; // track alias or fetch request ID. + + MoqtFramer framer_; + // Used to compute the object ID diff and pass metadata for partial objects. + // If nullopt, the stream header has not been written yet. + std::optional<PublishedObjectMetadata> last_object_; +}; + // This interface provides information about the subscription. class SubscriptionPublisherInterface { public: @@ -48,8 +100,7 @@ }; // This is for subscriptions only. FETCH uses its own construct. -class QUICHE_EXPORT OutgoingSubgroupStream - : public webtransport::StreamVisitor { +class QUICHE_EXPORT OutgoingSubgroupStream : public OutgoingUniStream { public: // |visitor| is owned by the subscription, so the WeakPtr also serves as a // liveness token. @@ -62,12 +113,9 @@ MoqtTraceRecorder* absl_nonnull trace_recorder); ~OutgoingSubgroupStream(); - // webtransport::StreamVisitor implementation. - void OnCanRead() override {} + // webtransport::StreamVisitor overrides. void OnCanWrite() override; - void OnResetStreamReceived(webtransport::StreamErrorCode) override {} void OnStopSendingReceived(webtransport::StreamErrorCode error_code) override; - void OnWriteSideInDataRecvdState() override {} class DeliveryTimeoutDelegate : public quic::QuicAlarm::DelegateWithoutContext { @@ -86,13 +134,6 @@ // Reset can be called directly on the stream, with no need to involve the // visitor. - // Recomputes the send order and updates it for the associated stream. - void UpdatePriority(MoqtPriority subscriber_priority) { - priority_.send_order = UpdateSendOrderForSubscriberPriority( - priority_.send_order, subscriber_priority); - stream_.SetPriority(priority_); - } - // Creates and sets an alarm for the given deadline. Does nothing if the // alarm is already created. void CreateAndSetAlarm(quic::QuicTime deadline); @@ -106,14 +147,9 @@ // the class, on a write error. void SendObjects(); - // Writes an object to the stream. Returns false if the write failed. The - // caller should reset the stream if that happens. - bool WriteObjectToStream(PublishedObject& object); - - webtransport::Stream& stream_; // Always valid because it owns this object. DataStreamIndex index_; quiche::QuicheWeakPtr<SubscriptionPublisherInterface> visitor_; - MoqtFramer framer_; + MoqtDataStreamType type_; uint64_t track_alias_; std::shared_ptr<MoqtTrackPublisher> publisher_; @@ -124,16 +160,34 @@ // Number of payload bytes from next_object_ that has already been written // to the stream. uint64_t already_delivered_ = 0; - // Used in subgroup streams to compute the object ID diff and pass metadata - // for partial objects. If nullopt, the stream header has not been written - // yet. - std::optional<PublishedObjectMetadata> last_object_; - webtransport::StreamPriority priority_; + // If this data stream is for SUBSCRIBE, reset it if an object has been // excessively delayed per Section 7.1.1.2. std::unique_ptr<quic::QuicAlarm> delivery_timeout_alarm_; }; +using FetchStreamCloseCallback = quiche::SingleUseCallback<void()>; + +class QUICHE_EXPORT OutgoingFetchStream : public OutgoingUniStream { + public: + OutgoingFetchStream(MoqtFramer framer, + webtransport::Stream* absl_nonnull stream, + uint64_t request_id, + webtransport::StreamPriority priority, + std::unique_ptr<MoqtFetchTask> incoming_objects, + FetchStreamCloseCallback close_callback, + MoqtTraceRecorder* absl_nonnull trace_recorder); + ~OutgoingFetchStream(); + + // webtransport::StreamVisitor implementation. + void OnCanWrite() override; + void OnStopSendingReceived(webtransport::StreamErrorCode error_code) override; + + private: + std::unique_ptr<MoqtFetchTask> incoming_objects_; + FetchStreamCloseCallback close_callback_; +}; + } // namespace moqt #endif // QUICHE_QUIC_MOQT_MOQT_UNI_STREAM_H_
diff --git a/quiche/quic/moqt/moqt_uni_stream_test.cc b/quiche/quic/moqt/moqt_uni_stream_test.cc index 45f82b4..4ebf4c5 100644 --- a/quiche/quic/moqt/moqt_uni_stream_test.cc +++ b/quiche/quic/moqt/moqt_uni_stream_test.cc
@@ -16,6 +16,7 @@ #include "quiche/quic/core/quic_alarm_factory.h" #include "quiche/quic/core/quic_time.h" #include "quiche/quic/moqt/moqt_error.h" +#include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_framer.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_names.h" @@ -304,6 +305,132 @@ stream_->OnCanWrite(); } +class OutgoingFetchStreamTest : public quic::test::QuicTest { + public: + OutgoingFetchStreamTest() + : task_(std::make_unique<StrictMock<MockFetchTask>>()), + task_ptr_(task_.get()), + trace_recorder_(nullptr) { + EXPECT_CALL(mock_stream_, GetStreamId()).WillRepeatedly(Return(14)); + EXPECT_CALL(mock_stream_, SetPriority); + stream_ = std::make_unique<OutgoingFetchStream>( + framer_, &mock_stream_, 10, webtransport::StreamPriority(), + std::move(task_), [this]() { close_callback_called_ = true; }, + &trace_recorder_); + } + ~OutgoingFetchStreamTest() override { + stream_.reset(); + EXPECT_TRUE(close_callback_called_); + } + + protected: + MoqtFramer framer_{true}; + StrictMock<webtransport::test::MockStream> mock_stream_; + std::unique_ptr<StrictMock<MockFetchTask>> task_; + MockFetchTask* task_ptr_; + MoqtTraceRecorder trace_recorder_; + bool close_callback_called_ = false; + std::unique_ptr<OutgoingFetchStream> stream_; +}; + +TEST_F(OutgoingFetchStreamTest, OnCanWritePending) { + EXPECT_CALL(mock_stream_, CanWrite()).WillOnce(Return(true)); + EXPECT_CALL(*task_ptr_, GetNextObject) + .WillOnce(Return(MoqtFetchTask::kPending)); + stream_->OnCanWrite(); +} + +TEST_F(OutgoingFetchStreamTest, OnCanWriteSuccess) { + PublishedObject obj = DefaultObject(); + EXPECT_CALL(mock_stream_, CanWrite()) + .WillOnce(Return(true)) + .WillOnce(Return(false)); + EXPECT_CALL(*task_ptr_, GetNextObject).WillOnce([&](PublishedObject& out) { + out = std::move(obj); + return MoqtFetchTask::kSuccess; + }); + EXPECT_CALL(mock_stream_, Writev).WillOnce(Return(absl::OkStatus())); + stream_->OnCanWrite(); +} + +TEST_F(OutgoingFetchStreamTest, OnCanWriteNonNormalStatus) { + PublishedObject obj = DefaultObject(); + obj.metadata.status = MoqtObjectStatus::kObjectDoesNotExist; + EXPECT_CALL(mock_stream_, CanWrite()) + .WillOnce(Return(true)) + .WillOnce(Return(false)); + EXPECT_CALL(*task_ptr_, GetNextObject).WillOnce([&](PublishedObject& out) { + out = std::move(obj); + return MoqtFetchTask::kSuccess; + }); + EXPECT_QUICHE_BUG(stream_->OnCanWrite(), "Got Non-normal object in FETCH"); +} + +TEST_F(OutgoingFetchStreamTest, OnCanWriteEof) { + EXPECT_CALL(mock_stream_, CanWrite()).WillOnce(Return(true)); + EXPECT_CALL(*task_ptr_, GetNextObject).WillOnce(Return(MoqtFetchTask::kEof)); + EXPECT_CALL(mock_stream_, Writev) + .WillOnce([](absl::Span<quiche::QuicheMemSlice> data, + const webtransport::StreamWriteOptions& options) { + EXPECT_TRUE(data.empty()); + EXPECT_TRUE(options.send_fin()); + return absl::OkStatus(); + }); + stream_->OnCanWrite(); +} + +TEST_F(OutgoingFetchStreamTest, OnCanWriteEofFail) { + EXPECT_CALL(mock_stream_, CanWrite()).WillOnce(Return(true)); + EXPECT_CALL(*task_ptr_, GetNextObject).WillOnce(Return(MoqtFetchTask::kEof)); + EXPECT_CALL(mock_stream_, Writev) + .WillOnce(Return(absl::InternalError("error"))); + stream_->OnCanWrite(); +} + +TEST_F(OutgoingFetchStreamTest, OnCanWriteWriteError) { + PublishedObject obj = DefaultObject(); + EXPECT_CALL(mock_stream_, CanWrite()) + .WillOnce(Return(true)) + .WillOnce(Return(false)); + EXPECT_CALL(*task_ptr_, GetNextObject).WillOnce([&](PublishedObject& out) { + out = std::move(obj); + return MoqtFetchTask::kSuccess; + }); + EXPECT_CALL(mock_stream_, Writev) + .WillOnce(Return(absl::InternalError("error"))); + EXPECT_QUICHE_BUG(stream_->OnCanWrite(), + "Writing into MoQT stream failed despite CanWrite being " + "true before; status: INTERNAL: error"); +} + +TEST_F(OutgoingFetchStreamTest, OnCanWriteError) { + EXPECT_CALL(mock_stream_, CanWrite()).WillOnce(Return(true)); + EXPECT_CALL(*task_ptr_, GetNextObject) + .WillOnce(Return(MoqtFetchTask::kError)); + EXPECT_CALL(*task_ptr_, GetStatus()) + .WillOnce(Return(absl::InternalError("error"))); + EXPECT_CALL( + mock_stream_, + ResetWithUserCode(static_cast<uint64_t>(absl::StatusCode::kInternal))); + stream_->OnCanWrite(); +} + +TEST_F(OutgoingFetchStreamTest, OnStopSendingReceived) { + EXPECT_CALL(mock_stream_, ResetWithUserCode(17)); + stream_->OnStopSendingReceived(17); +} + +TEST_F(OutgoingFetchStreamTest, UpdatePriority) { + EXPECT_CALL(mock_stream_, SetPriority(webtransport::StreamPriority{ + 0, 0x3fc0000000000000ULL})); + stream_->UpdatePriority(0); +} + +TEST_F(OutgoingFetchStreamTest, ObjectAvailableCallback) { + EXPECT_CALL(mock_stream_, CanWrite()).WillOnce(Return(false)); + task_ptr_->CallObjectsAvailableCallback(); +} + } // namespace } // namespace moqt::test