Move IncomingDataStream to moqt_uni_stream.h. Other preparatory changes for requests on bidi streams. MoqtSessionTest now has numerous redundant, verbose tests. They are retained here to show that changes are roughly a no-op, but will delete in a later CL. PiperOrigin-RevId: 922928416
diff --git a/quiche/quic/moqt/moqt_integration_test.cc b/quiche/quic/moqt/moqt_integration_test.cc index d715290..1ce2c70 100644 --- a/quiche/quic/moqt/moqt_integration_test.cc +++ b/quiche/quic/moqt/moqt_integration_test.cc
@@ -31,10 +31,8 @@ #include "quiche/quic/moqt/moqt_session.h" #include "quiche/quic/moqt/moqt_session_callbacks.h" #include "quiche/quic/moqt/moqt_session_interface.h" -#include "quiche/quic/moqt/moqt_track.h" #include "quiche/quic/moqt/moqt_types.h" #include "quiche/quic/moqt/test_tools/moqt_mock_visitor.h" -#include "quiche/quic/moqt/test_tools/moqt_session_peer.h" #include "quiche/quic/moqt/test_tools/moqt_simulator_harness.h" #include "quiche/quic/test_tools/quic_test_utils.h" #include "quiche/quic/test_tools/simulator/test_harness.h" @@ -676,12 +674,7 @@ // a new attempt. EXPECT_TRUE(client_->session()->Subscribe(full_track_name, &subscribe_visitor_, parameters)); - EXPECT_CALL(subscribe_visitor_, OnReply) - .WillOnce( - [](const FullTrackName&, - std::variant<SubscribeOkData, MoqtRequestErrorInfo> response) { - EXPECT_TRUE(std::holds_alternative<MoqtRequestErrorInfo>(response)); - }); // Teardown + EXPECT_CALL(subscribe_visitor_, OnPublishDone); // Test teardown } TEST_F(MoqtIntegrationTest, ObjectAcks) {
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index 55b3398..ac83fe8 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -173,7 +173,8 @@ void MoqtSession::OnIncomingUnidirectionalStreamAvailable() { while (webtransport::Stream* stream = session_->AcceptIncomingUnidirectionalStream()) { - stream->SetVisitor(std::make_unique<IncomingDataStream>(this, stream)); + stream->SetVisitor( + std::make_unique<IncomingDataStream>(stream, this, callbacks_.clock)); stream->visitor()->OnCanRead(); } } @@ -484,7 +485,28 @@ SendControlMessage(framer_.SerializeSubscribe(message)); QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE message for " << message.full_track_name; - auto track = std::make_unique<SubscribeRemoteTrack>(message, visitor); + auto track = std::make_unique<SubscribeRemoteTrack>( + message, visitor, + [this, request_id = message.request_id, ftn = name]() { + // Deletion callback + subscribe_by_name_.erase(ftn); + upstream_by_id_.erase(request_id); + }, + [this](uint64_t alias, SubscribeRemoteTrack* track) { + // Track alias registry callback. + if (is_closing_) { + return true; + } + if (track == nullptr) { + subscribe_by_alias_.erase(alias); + return true; + } + auto [it, success] = subscribe_by_alias_.try_emplace(alias, track); + if (!success) { + Error(MoqtError::kDuplicateTrackAlias, ""); + } + return success; + }); subscribe_by_name_.emplace(message.full_track_name, track.get()); upstream_by_id_.emplace(message.request_id, std::move(track)); return true; @@ -535,7 +557,7 @@ MoqtUnsubscribe message; message.request_id = track->request_id(); SendControlMessage(framer_.SerializeUnsubscribe(message)); - DestroySubscription(track); + track->Destroy(); } bool MoqtSession::Fetch(const FullTrackName& name, @@ -565,7 +587,10 @@ SendControlMessage(framer_.SerializeFetch(message)); QUIC_DLOG(INFO) << ENDPOINT << "Sent FETCH message for " << name; auto fetch = std::make_unique<UpstreamFetch>( - message, std::get<StandaloneFetch>(message.fetch), std::move(callback)); + message, std::get<StandaloneFetch>(message.fetch), std::move(callback), + [this, id = message.request_id]() { // Deletion callback + upstream_by_id_.erase(id); + }); upstream_by_id_.emplace(message.request_id, std::move(fetch)); return true; } @@ -618,8 +643,11 @@ fetch.parameters = parameters; SendControlMessage(framer_.SerializeFetch(fetch)); QUIC_DLOG(INFO) << ENDPOINT << "Sent Joining FETCH message for " << name; - auto upstream_fetch = - std::make_unique<UpstreamFetch>(fetch, name, std::move(callback)); + auto upstream_fetch = std::make_unique<UpstreamFetch>( + fetch, name, std::move(callback), + /*Deletion callback=*/[this, id = fetch.request_id]() { + upstream_by_id_.erase(id); + }); upstream_by_id_.emplace(fetch.request_id, std::move(upstream_fetch)); return true; } @@ -661,28 +689,6 @@ published_subscriptions_.erase(it); } -void MoqtSession::MaybeDestroySubscription(SubscribeRemoteTrack* subscribe) { - if (subscribe != nullptr && subscribe->all_streams_closed()) { - DestroySubscription(subscribe); - } -} - -void MoqtSession::DestroySubscription(SubscribeRemoteTrack* subscribe) { - if (subscribe->ErrorIsAllowed()) { - subscribe->visitor()->OnReply( - subscribe->full_track_name(), - MoqtRequestErrorInfo{RequestErrorCode::kNotSupported, std::nullopt, - "Subscription closed"}); - } else { - subscribe->visitor()->OnPublishDone(subscribe->full_track_name()); - } - subscribe_by_name_.erase(subscribe->full_track_name()); - if (subscribe->track_alias().has_value()) { - subscribe_by_alias_.erase(*subscribe->track_alias()); - } - upstream_by_id_.erase(subscribe->request_id()); -} - void MoqtSession::UpdateTrackPriority( uint64_t request_id, std::optional<MoqtTrackPriority> old_priority, MoqtTrackPriority new_priority) { @@ -997,13 +1003,10 @@ SubscribeRemoteTrack* subscribe = absl::down_cast<SubscribeRemoteTrack*>(track); subscribe->OnObjectOrOk(); - auto [it, success] = - session_->subscribe_by_alias_.try_emplace(message.track_alias, subscribe); - if (!success) { - session_->Error(MoqtError::kDuplicateTrackAlias, ""); + if (!subscribe->set_track_alias(message.track_alias)) { + // A duplicate track alias could destroy the session. return absl::OkStatus(); } - subscribe->set_track_alias(message.track_alias); std::optional<SubscriptionFilter> filter = subscribe->parameters().subscription_filter; if (filter.has_value()) { @@ -1094,17 +1097,13 @@ } else { SubscribeRemoteTrack* subscribe = absl::down_cast<SubscribeRemoteTrack*>(track); - // Delete the by-name entry at this point prevents Subscribe() from - // throwing an error due to a duplicate track name. The other entries for - // this subscribe will be deleted after calling Subscribe(). - session_->subscribe_by_name_.erase(subscribe->full_track_name()); if (subscribe->visitor() != nullptr) { subscribe->visitor()->OnReply(subscribe->full_track_name(), error_info); } } if (!session_->is_closing_) { // The visitor might have closed the session. - session_->upstream_by_id_.erase(message.request_id); + track->Destroy(); } return absl::OkStatus(); } @@ -1157,11 +1156,8 @@ auto* subscribe = absl::down_cast<SubscribeRemoteTrack*>(it->second.get()); QUIC_DLOG(INFO) << ENDPOINT << "Received a PUBLISH_DONE for " << it->second->full_track_name(); - subscribe->OnPublishDone( - message.stream_count, session_->callbacks_.clock, - absl::WrapUnique(session_->alarm_factory_->CreateAlarm( - new PublishDoneDelegate(session_, subscribe)))); - session_->MaybeDestroySubscription(subscribe); + subscribe->OnPublishDone(message.stream_count, session_->callbacks_.clock, + session_->alarm_factory_.get()); return absl::OkStatus(); } @@ -1539,268 +1535,6 @@ error_reason); } -void MoqtSession::IncomingDataStream::OnObjectMessage(const MoqtObject& message, - absl::string_view payload, - bool end_of_message) { - QUICHE_DVLOG(1) << ENDPOINT << "Received OBJECT message on stream " - << stream_->GetStreamId() << " for track alias " - << message.track_alias << " with sequence " - << message.group_id << ":" << message.object_id - << " priority " << message.publisher_priority << " length " - << payload.size() << " length " << message.payload_length - << (end_of_message ? "F" : ""); - if (!session_->parameters_.deliver_partial_objects) { - if (!end_of_message) { // Buffer partial object. - if (partial_object_.empty()) { - // Avoid redundant allocations by reserving the appropriate amount of - // memory if known. - partial_object_.reserve(message.payload_length); - } - absl::StrAppend(&partial_object_, payload); - return; - } - if (!partial_object_.empty()) { // Completes the object - absl::StrAppend(&partial_object_, payload); - payload = absl::string_view(partial_object_); - } - } - if (payload.empty() && bytes_received_this_object_ > 0 && !end_of_message) { - return; // Nothing arrived. - } - if (!parser_.stream_type().has_value()) { - QUICHE_BUG(quic_bug_object_with_no_stream_type) - << "Object delivered without a stream type"; - return; - } - // Get a pointer to the upstream state. - RemoteTrack* track = track_.GetIfAvailable(); - if (track == nullptr) { - track = (parser_.stream_type()->IsFetch()) - // message.track_alias is actually a fetch ID for fetches. - ? session_->RemoteTrackById(message.track_alias) - : session_->RemoteTrackByAlias(message.track_alias); - if (track == nullptr) { - stream_->SendStopSending(kResetCodeCancelled); - // Received object for nonexistent track. - return; - } - track_ = track->weak_ptr(); - } - if (!track->CheckDataStreamType(*parser_.stream_type())) { - session_->Error(MoqtError::kProtocolViolation, - "Received object for a track with a different stream type"); - return; - } - Location location(message.group_id, message.object_id); - if (!track->InWindow(Location(message.group_id, message.object_id))) { - // This is not an error. It can be the result of a recent REQUEST_UPDATE. - return; - } - if (!track->is_fetch()) { - if (!index_.has_value()) { - if (!message.subgroup_id.has_value()) { - QUICHE_BUG(quiche_bug_moqt_subgroup_id_missing) - << "Missing subgroup ID on SUBSCRIBE stream"; - return; - } - index_ = DataStreamIndex(message.group_id, *message.subgroup_id); - } - if (no_more_objects_) { - // Already got a stream-ending object. While the lower layer won't - // deliver data after the FIN, there could have been an EndOfGroup or - // EndOfTrack signal. - session_->OnMalformedTrack(track); - return; - } - if (end_of_message) { - next_object_id_ = message.object_id + 1; - if (message.object_status == MoqtObjectStatus::kEndOfTrack || - message.object_status == MoqtObjectStatus::kEndOfGroup) { - no_more_objects_ = true; - } - } - SubscribeRemoteTrack* subscribe = - absl::down_cast<SubscribeRemoteTrack*>(track); - subscribe->OnObjectOrOk(); - if (subscribe->visitor() != nullptr) { - PublishedObjectMetadata metadata; - metadata.location = Location(message.group_id, message.object_id); - metadata.subgroup = message.subgroup_id; - metadata.extensions = message.extension_headers; - metadata.status = message.object_status; - metadata.publisher_priority = message.publisher_priority; - metadata.payload_length = message.payload_length; - metadata.arrival_time = session_->callbacks_.clock->Now(); - subscribe->visitor()->OnObjectFragment(track->full_track_name(), metadata, - payload, - bytes_received_this_object_); - } - } else { // FETCH - track->OnObjectOrOk(); - UpstreamFetch* fetch = absl::down_cast<UpstreamFetch*>(track); - if (!fetch->LocationIsValid(Location(message.group_id, message.object_id), - message.object_status, end_of_message)) { - // TODO(martinduke): in https://github.com/moq-wg/moq-transport/pull/1409 - // I make the case that this should be a protocol violation. Update if - // that proposal is accepted (at which point - // QuicSession::OnMalformedTrack can be removed, since all the - // remaining conditions are at the application layer). - session_->OnMalformedTrack(track); - return; - } - UpstreamFetch::UpstreamFetchTask* task = fetch->task(); - if (task == nullptr) { - // The application killed the FETCH. - stream_->SendStopSending(kResetCodeCancelled); - return; - } - if (!task->HasObject()) { - task->NewObject(message); - } - if (task->NeedsMorePayload() && !payload.empty()) { - task->AppendPayloadToObject(payload); - } - } - if (end_of_message) { - bytes_received_this_object_ = 0; - } else { - bytes_received_this_object_ += payload.size(); - } - partial_object_.clear(); -} - -MoqtSession::IncomingDataStream::~IncomingDataStream() { - QUICHE_DVLOG(1) << ENDPOINT << "Destroying incoming data stream " - << stream_->GetStreamId(); - if (!parser_.track_alias().has_value()) { - QUIC_DVLOG(1) << ENDPOINT - << "Destroying incoming data stream before " - "learning track alias"; - return; - } - if (!track_.IsValid()) { - return; - } - if (parser_.stream_type().has_value() && parser_.stream_type()->IsFetch()) { - session_->upstream_by_id_.erase(*parser_.track_alias()); - return; - } - if (session_->is_closing_) { - return; - } - // It's a subscribe. - SubscribeRemoteTrack* subscribe = - absl::down_cast<SubscribeRemoteTrack*>(track_.GetIfAvailable()); - if (subscribe == nullptr) { - return; - } - subscribe->OnStreamClosed(fin_received_, index_); - session_->MaybeDestroySubscription(subscribe); -} - -void MoqtSession::IncomingDataStream::MaybeReadOneObject() { - if (!parser_.track_alias().has_value() || - !parser_.stream_type().has_value() || !parser_.stream_type()->IsFetch()) { - QUICHE_BUG(quic_bug_read_one_object_parser_unexpected_state) - << "Requesting object, parser in unexpected state"; - } - RemoteTrack* track = session_->RemoteTrackById(*parser_.track_alias()); - if (track == nullptr || !track->is_fetch()) { - QUICHE_BUG(quic_bug_read_one_object_track_unexpected_state) - << "Requesting object, track in unexpected state"; - return; - } - UpstreamFetch* fetch = absl::down_cast<UpstreamFetch*>(track); - UpstreamFetch::UpstreamFetchTask* task = fetch->task(); - if (task == nullptr) { - return; - } - if (task->HasObject() && !task->NeedsMorePayload()) { - return; // The message is complete. Do not read more. - } - uint64_t start_length = task->payload_length(); - parser_.ReadAtMostOneObject(); - // If it read an object, it called OnObjectMessage and may have altered the - // task's object state. - if (task->payload_length() > start_length) { - task->NotifyNewObject(); - } -} - -void MoqtSession::IncomingDataStream::OnCanRead() { - if (!parser_.stream_type().has_value()) { - parser_.ReadStreamType(); - if (!parser_.stream_type().has_value()) { - return; - } - } - if (parser_.stream_type()->IsPadding()) { - (void)stream_->SkipBytes(stream_->ReadableBytes()); - return; - } - bool knew_track_alias = parser_.track_alias().has_value(); - if (!knew_track_alias) { - parser_.ReadTrackAlias(); - if (!parser_.track_alias().has_value()) { - return; - } - } - QUICHE_CHECK(parser_.stream_type().has_value()); - QUICHE_CHECK(parser_.track_alias().has_value()); - if (parser_.stream_type()->IsSubgroup()) { - if (!knew_track_alias) { - // This is a new stream for a subscribe. Notify the subscription. - auto it = session_->subscribe_by_alias_.find(*parser_.track_alias()); - if (it == session_->subscribe_by_alias_.end()) { - QUIC_DLOG(INFO) << ENDPOINT - << "Received object for a track with no SUBSCRIBE"; - // This is a not a session error because there might be an UNSUBSCRIBE - // or SUBSCRIBE_OK (containing the track alias) in flight. - stream_->SendStopSending(kResetCodeCancelled); - return; - } - it->second->OnStreamOpened(); - parser_.set_default_publisher_priority( - it->second->default_publisher_priority()); - } - parser_.ReadAllData(); - return; - } - auto it = session_->upstream_by_id_.find(*parser_.track_alias()); - if (it == session_->upstream_by_id_.end()) { - QUIC_DLOG(INFO) << ENDPOINT << "Received object for a track with no FETCH"; - // This is a not a session error because there might be an UNSUBSCRIBE in - // flight. - stream_->SendStopSending(kResetCodeCancelled); - return; - } - if (it->second == nullptr) { - QUICHE_BUG(quiche_bug_moqt_fetch_pointer_is_null) - << "Fetch pointer is null"; - return; - } - UpstreamFetch* fetch = absl::down_cast<UpstreamFetch*>(it->second.get()); - if (!knew_track_alias) { - // If the task already exists (FETCH_OK has arrived), the callback will - // immediately execute to read the first object. Otherwise, it will only - // execute when the task is created or a cached object is read. - fetch->OnStreamOpened([this]() { MaybeReadOneObject(); }); - return; - } - MaybeReadOneObject(); -} - -void MoqtSession::IncomingDataStream::OnControlMessageReceived() { - session_->Error(MoqtError::kProtocolViolation, - "Received a control message on a data stream"); -} - -void MoqtSession::IncomingDataStream::OnParsingError(MoqtError error_code, - absl::string_view reason) { - session_->Error(error_code, absl::StrCat("Parse error: ", reason)); -} - - void MoqtSession::OnMalformedTrack(RemoteTrack* track) { if (!track->is_fetch()) { absl::down_cast<SubscribeRemoteTrack*>(track)->visitor()->OnMalformedTrack( @@ -1838,13 +1572,7 @@ RequestErrorCode::kUninterested, std::nullopt, "Session closed"}); } while (!upstream_by_id_.empty()) { - auto upstream = upstream_by_id_.begin(); - if (upstream->second->is_fetch()) { - upstream_by_id_.erase(upstream); - continue; - } - DestroySubscription( - absl::down_cast<SubscribeRemoteTrack*>(upstream->second.get())); + upstream_by_id_.begin()->second->Destroy(); } } @@ -1852,9 +1580,13 @@ if (is_closing_) { return; } + auto it = upstream_by_id_.find(request_id); + if (it == upstream_by_id_.end()) { + return; + } + it->second->Destroy(); // This is only called from the callback where UpstreamFetchTask has been // destroyed, so there is no need to notify the application. - upstream_by_id_.erase(request_id); ControlStream* stream = GetControlStream(); if (stream == nullptr) { return;
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index 664004a..fe3da70 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -59,6 +59,7 @@ class QUICHE_EXPORT MoqtSession : public MoqtSessionInterface, public SessionToPublisherInterface, + public SessionToUniStreamInterface, public webtransport::SessionVisitor { public: MoqtSession(webtransport::Session* session, MoqtSessionParameters parameters, @@ -147,6 +148,30 @@ return is_closing_ ? nullptr : session_; } + // SessionToUniStreamInterface implementation. + bool deliver_partial_objects() const { + return parameters_.deliver_partial_objects; + } + // Called when the incoming track is malformed per Section 2.5 of + // draft-ietf-moqt-moq-transport-12. Unsubscribe and notify the application so + // the error can be propagated downstream, if necessary. + void OnMalformedTrack(RemoteTrack* track); + quiche::QuicheWeakPtr<RemoteTrack> GetSubscribe(uint64_t track_alias) { + auto it = subscribe_by_alias_.find(track_alias); + if (it == subscribe_by_alias_.end()) { + return quiche::QuicheWeakPtr<RemoteTrack>(); + } + return it->second->weak_ptr(); + } + quiche::QuicheWeakPtr<RemoteTrack> GetFetch(uint64_t request_id) { + auto it = upstream_by_id_.find(request_id); + if (it == upstream_by_id_.end()) { + return quiche::QuicheWeakPtr<RemoteTrack>(); + } + return it->second->weak_ptr(); + } + // Error() defined in MoqtSessionInterface. + // Send a GOAWAY message to the peer. |new_session_uri| must be empty if // called by the client. void GoAway(absl::string_view new_session_uri); @@ -289,53 +314,6 @@ // Must be last. quiche::QuicheWeakPtrFactory<ControlStream> weak_ptr_factory_; }; - class QUICHE_EXPORT IncomingDataStream : public webtransport::StreamVisitor, - public MoqtDataParserVisitor { - public: - IncomingDataStream(MoqtSession* session, webtransport::Stream* stream) - : session_(session), stream_(stream), parser_(stream, this) {} - ~IncomingDataStream(); - - // webtransport::StreamVisitor implementation. - void OnCanRead() override; - void OnCanWrite() override {} - void OnResetStreamReceived(webtransport::StreamErrorCode) override {} - void OnStopSendingReceived( - webtransport::StreamErrorCode /*error*/) override {} - void OnWriteSideInDataRecvdState() override {} - - // MoqtParserVisitor implementation. - // TODO: Handle a stream FIN. - void OnObjectMessage(const MoqtObject& message, absl::string_view payload, - bool end_of_message) override; - void OnFin() override { fin_received_ = true; } - void OnParsingError(MoqtError error_code, - absl::string_view reason) override; - - quic::Perspective perspective() const { - return session_->parameters_.perspective; - } - - webtransport::Stream* stream() const { return stream_; } - - void MaybeReadOneObject(); - - private: - friend class test::MoqtSessionPeer; - void OnControlMessageReceived(); - - uint64_t next_object_id_ = 0; - bool no_more_objects_ = false; // EndOfGroup or EndOfTrack was received. - std::optional<DataStreamIndex> index_; // Only set for subscribe. - bool fin_received_ = false; - MoqtSession* session_; - webtransport::Stream* stream_; - // Once the subscribe ID is identified, set it here. - quiche::QuicheWeakPtr<RemoteTrack> track_; - MoqtDataParser parser_; - std::string partial_object_; - uint64_t bytes_received_this_object_ = 0; - }; class QUICHE_EXPORT PublishedFetch { public: @@ -431,21 +409,6 @@ MoqtSession* session_; }; - class PublishDoneDelegate : public quic::QuicAlarm::DelegateWithoutContext { - public: - PublishDoneDelegate(MoqtSession* session, SubscribeRemoteTrack* subscribe) - : session_(session), subscribe_(subscribe) {} - - void OnAlarm() override { session_->DestroySubscription(subscribe_); } - - private: - MoqtSession* session_; - SubscribeRemoteTrack* subscribe_; - }; - - void MaybeDestroySubscription(SubscribeRemoteTrack* subscribe); - void DestroySubscription(SubscribeRemoteTrack* subscribe); - // Returns the pointer to the control stream, or nullptr if none is present. ControlStream* GetControlStream() { return control_stream_.GetIfAvailable(); } // Sends a message on the control stream; QUICHE_DCHECKs if no control stream @@ -486,11 +449,6 @@ return parameters_.support_object_acks && peer_supports_object_ack_; } - // Called when the incoming track is malformed per Section 2.5 of - // draft-ietf-moqt-moq-transport-12. Unsubscribe and notify the application so - // the error can be propagated downstream, if necessary. - void OnMalformedTrack(RemoteTrack* track); - // When the session is closing, clean up state without waiting for the // underlying WebTransport session to be destroyed. void CleanUpState(); @@ -519,7 +477,9 @@ MoqtTraceRecorder trace_recorder_; // Upstream SUBSCRIBE state. - // Upstream SUBSCRIBEs and FETCHes, indexed by subscribe_id. + // Upstream SUBSCRIBEs and FETCHes, indexed by subscribe_id. Do not erase + // directly, call RemoteTrack::Destroy(), except in deletion callbacks passed + // to RemoteTrack. absl::flat_hash_map<uint64_t, std::unique_ptr<RemoteTrack>> upstream_by_id_; // All SUBSCRIBEs, indexed by track_alias. absl::flat_hash_map<uint64_t, SubscribeRemoteTrack*> subscribe_by_alias_;
diff --git a/quiche/quic/moqt/moqt_session_interface.h b/quiche/quic/moqt/moqt_session_interface.h index af1d6d0..7b34a1d 100644 --- a/quiche/quic/moqt/moqt_session_interface.h +++ b/quiche/quic/moqt/moqt_session_interface.h
@@ -90,9 +90,7 @@ class SubscribeVisitor { public: virtual ~SubscribeVisitor() = default; - // Called when the session receives a response to the SUBSCRIBE, unless it's - // a REQUEST_ERROR with a new track_alias. In that case, the session will - // automatically retry. + // Called when the session receives a response to the SUBSCRIBE. virtual void OnReply( const FullTrackName& full_track_name, std::variant<SubscribeOkData, MoqtRequestErrorInfo> response) = 0; @@ -104,6 +102,8 @@ virtual void OnObjectFragment(const FullTrackName& full_track_name, const PublishedObjectMetadata& metadata, absl::string_view object, uint64_t offset) = 0; + // Called when the subscription state goes away, regardless of whether or not + // there was a PUBLISH_DONE message. virtual void OnPublishDone(FullTrackName full_track_name) = 0; // Called when the track is malformed per Section 2.5 of // draft-ietf-moqt-moq-transport-12. If the application is a relay, it MUST
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 7802cff..662a6bc 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -1142,7 +1142,8 @@ }; std::unique_ptr<MoqtDataParserVisitor> object_stream = MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream_, - kDefaultSubgroupStreamType); + kDefaultSubgroupStreamType, 2, + &remote_track_visitor_); EXPECT_CALL(remote_track_visitor_, OnObjectFragment) .WillOnce([&](const FullTrackName& track_name, @@ -1180,7 +1181,8 @@ }; std::unique_ptr<MoqtDataParserVisitor> object_stream = MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream_, - kDefaultSubgroupStreamType); + kDefaultSubgroupStreamType, 2, + &remote_track_visitor_); EXPECT_CALL(remote_track_visitor_, OnObjectFragment).Times(1); EXPECT_CALL(mock_stream_, GetStreamId()) @@ -1211,7 +1213,8 @@ }; std::unique_ptr<MoqtDataParserVisitor> object_stream = MoqtSessionPeer::CreateIncomingDataStream(&session, &mock_stream_, - kDefaultSubgroupStreamType); + kDefaultSubgroupStreamType, 2, + &remote_track_visitor_); EXPECT_CALL(mock_stream_, GetStreamId()) .WillRepeatedly(Return(kIncomingUniStreamId)); EXPECT_CALL(remote_track_visitor_, OnObjectFragment(ftn, _, payload, 0)); @@ -1242,7 +1245,7 @@ }; std::unique_ptr<MoqtDataParserVisitor> object_stream = MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream_, - kDefaultSubgroupStreamType); + kDefaultSubgroupStreamType, 2); EXPECT_CALL(mock_stream_, SendStopSending); object_stream->OnObjectMessage(object, payload, true); @@ -1261,26 +1264,24 @@ } TEST_F(MoqtSessionTest, SubscribeOkWithBadTrackAlias) { - MockSubscribeRemoteTrackVisitor visitor; - // Create open subscription. - MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultLocalSubscribe(), - /*track_alias=*/2, &visitor); - MoqtSubscribe subscribe2 = DefaultLocalSubscribe(); - subscribe2.request_id += 2; - subscribe2.full_track_name = FullTrackName("foo2", "bar2"); - MoqtSessionPeer::CreateRemoteTrack(&session_, subscribe2, std::nullopt, - &visitor); - - // SUBSCRIBE_OK arrives + // Create open subscription. We cannot use CreateRemoteTrack because that + // skips the code that sets the track alias callbacks. + webtransport::test::MockStream mock_control_stream; + std::unique_ptr<MoqtBidiStreamTestWrapper> control_stream = + MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream); + session_.Subscribe(FullTrackName("foo", "bar"), &remote_track_visitor_, + MessageParameters()); MoqtSubscribeOk subscribe_ok = { - subscribe2.request_id, + /*request_id=*/0, /*track_alias=*/2, MessageParameters(), TrackExtensions(), }; - webtransport::test::MockStream mock_control_stream; - std::unique_ptr<MoqtBidiStreamTestWrapper> control_stream = - MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream); + control_stream->ReceiveMessage(subscribe_ok); + // Second subscribe, but OK has the same track alias. + session_.Subscribe(FullTrackName("foo2", "bar2"), &remote_track_visitor_, + MessageParameters()); + subscribe_ok.request_id += 2; EXPECT_CALL( mock_session_, CloseSession(static_cast<uint64_t>(MoqtError::kDuplicateTrackAlias), "")); @@ -1485,7 +1486,8 @@ }; std::unique_ptr<MoqtDataParserVisitor> object_stream = MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream_, - kDefaultSubgroupStreamType); + kDefaultSubgroupStreamType, 2, + &remote_track_visitor_); EXPECT_CALL(remote_track_visitor_, OnObjectFragment).Times(0); object_stream->OnObjectMessage(object, payload, true); } @@ -2614,12 +2616,11 @@ } SubscribeRemoteTrack* track = MoqtSessionPeer::remote_track(&session_, 0); ASSERT_NE(track, nullptr); - EXPECT_FALSE(track->all_streams_closed()); stream_input->ReceiveMessage( MoqtPublishDone(0, PublishDoneCode::kTrackEnded, kNumStreams, "foo")); track = MoqtSessionPeer::remote_track(&session_, 0); + EXPECT_CALL(remote_track_visitor_, OnPublishDone).Times(0); ASSERT_NE(track, nullptr); - EXPECT_FALSE(track->all_streams_closed()); EXPECT_CALL(remote_track_visitor_, OnPublishDone(_)); for (uint64_t i = 0; i < kNumStreams; ++i) { data_streams[i].reset(); @@ -2675,7 +2676,6 @@ } SubscribeRemoteTrack* track = MoqtSessionPeer::remote_track(&session_, 0); ASSERT_NE(track, nullptr); - EXPECT_FALSE(track->all_streams_closed()); EXPECT_CALL(remote_track_visitor_, OnPublishDone(_)); stream_input->ReceiveMessage( MoqtPublishDone(0, PublishDoneCode::kTrackEnded, kNumStreams, "foo")); @@ -2730,11 +2730,11 @@ } SubscribeRemoteTrack* track = MoqtSessionPeer::remote_track(&session_, 0); ASSERT_NE(track, nullptr); - EXPECT_FALSE(track->all_streams_closed()); + EXPECT_CALL(remote_track_visitor_, OnPublishDone).Times(0); // stream_count includes a stream that was never sent. stream_input->ReceiveMessage( MoqtPublishDone(0, PublishDoneCode::kTrackEnded, kNumStreams + 1, "foo")); - EXPECT_FALSE(track->all_streams_closed()); + EXPECT_CALL(remote_track_visitor_, OnPublishDone).Times(0); auto* publish_done_alarm = absl::down_cast<quic::test::MockAlarmFactory::TestAlarm*>( MoqtSessionPeer::GetPublishDoneAlarm(track)); @@ -2755,7 +2755,8 @@ &session_, &mock_stream_, MoqtDataStreamType::Subgroup(/*subgroup_id=*/0, /*first_object_id=*/0, /*no_extension_headers=*/true, - /*has_default_priority=*/false)); + /*has_default_priority=*/false), + 2); object_stream->OnObjectMessage( MoqtObject(/*track_alias=*/2, /*group_id=*/0, /*object_id=*/0, /*publisher_priority=*/0x80, /*extension_headers=*/"", @@ -2776,7 +2777,6 @@ } TEST_F(MoqtSessionTest, SubgroupStreamObjectAfterTrackEnd) { - MockSubscribeRemoteTrackVisitor remote_track_visitor_; MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), /*track_alias=*/2, &remote_track_visitor_); webtransport::test::MockStream control_stream; @@ -2787,7 +2787,8 @@ &session_, &mock_stream_, MoqtDataStreamType::Subgroup(/*subgroup_id=*/0, /*first_object_id=*/0, /*no_extension_headers=*/true, - /*has_default_priority=*/false)); + /*has_default_priority=*/false), + /*track_alias=*/2); object_stream->OnObjectMessage( MoqtObject(/*track_alias=*/2, /*group_id=*/0, /*object_id=*/0, /*publisher_priority=*/0x80, /*extension_headers=*/"", @@ -2813,7 +2814,7 @@ MoqtSessionPeer::CreateUpstreamFetch(&session_, &stream); std::unique_ptr<MoqtDataParserVisitor> object_stream = MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream_, - MoqtDataStreamType::Fetch()); + MoqtDataStreamType::Fetch(), 0); object_stream->OnObjectMessage( MoqtObject(/*request_id=*/0, /*group_id=*/0, /*object_id=*/1, /*publisher_priority=*/0x80, /*extension_headers=*/"",
diff --git a/quiche/quic/moqt/moqt_track.cc b/quiche/quic/moqt/moqt_track.cc index e7278a6..2d36933 100644 --- a/quiche/quic/moqt/moqt_track.cc +++ b/quiche/quic/moqt/moqt_track.cc
@@ -14,6 +14,7 @@ #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_alarm_factory.h" #include "quiche/quic/core/quic_clock.h" #include "quiche/quic/core/quic_time.h" #include "quiche/quic/moqt/moqt_error.h" @@ -38,14 +39,14 @@ } // namespace -bool RemoteTrack::CheckDataStreamType(MoqtDataStreamType type) { - if (is_fetch() && !type.IsFetch()) { - return false; +SubscribeRemoteTrack::~SubscribeRemoteTrack() { + if (publish_done_alarm_ != nullptr) { + publish_done_alarm_->PermanentCancel(); } - if (!is_fetch() && !type.IsSubgroup()) { - return false; + if (register_track_alias_callback_ && track_alias_.has_value()) { + register_track_alias_callback_(*track_alias_, nullptr); } - return true; + visitor_->OnPublishDone(full_track_name()); } void SubscribeRemoteTrack::OnStreamOpened() { @@ -68,6 +69,10 @@ visitor_->OnStreamReset(full_track_name(), *index); } } + if (all_streams_closed()) { + Destroy(); + return; + } if (publish_done_alarm_ == nullptr) { return; } @@ -76,10 +81,15 @@ void SubscribeRemoteTrack::OnPublishDone( uint64_t stream_count, const quic::QuicClock* clock, - std::unique_ptr<quic::QuicAlarm> publish_done_alarm) { + quic::QuicAlarmFactory* alarm_factory) { total_streams_ = stream_count; clock_ = clock; - publish_done_alarm_ = std::move(publish_done_alarm); + if (all_streams_closed()) { + Destroy(); + return; + } + publish_done_alarm_ = std::unique_ptr<quic::QuicAlarm>( + alarm_factory->CreateAlarm(new PublishDoneDelegate(this))); MaybeSetPublishDoneAlarm(); } @@ -229,6 +239,8 @@ } UpstreamFetch::UpstreamFetchTask::~UpstreamFetchTask() { + // Set status_ so that callbacks into UpstreamFetchTask exit early. + status_ = absl::CancelledError("UpstreamFetchTask destroyed"); if (task_destroyed_callback_) { std::move(task_destroyed_callback_)(); } @@ -262,6 +274,7 @@ output.metadata.publisher_priority = next_object_->publisher_priority; output.metadata.payload_length = next_object_->payload_length; output.fin_after_this = false; + // TODO(martinduke): Make sure the whole object has been delivered. if (output.metadata.location == largest_location_) { // This is the last object. eof_ = true;
diff --git a/quiche/quic/moqt/moqt_track.h b/quiche/quic/moqt/moqt_track.h index 0994cee..036aac2 100644 --- a/quiche/quic/moqt/moqt_track.h +++ b/quiche/quic/moqt/moqt_track.h
@@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +// TODO(martinduke): Rename this file to moqt_subscriber.h + #ifndef QUICHE_QUIC_MOQT_MOQT_TRACK_H_ #define QUICHE_QUIC_MOQT_MOQT_TRACK_H_ @@ -13,7 +15,9 @@ #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_alarm_factory.h" #include "quiche/quic/core/quic_time.h" +#include "quiche/quic/moqt/moqt_bidi_stream.h" #include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_messages.h" @@ -38,20 +42,19 @@ // State common to both SUBSCRIBE and FETCH upstream. class RemoteTrack { public: - RemoteTrack(const FullTrackName& full_track_name, uint64_t id) + RemoteTrack(const FullTrackName& full_track_name, uint64_t id, + BidiStreamDeletedCallback callback) : full_track_name_(full_track_name), request_id_(id), + delete_callback_(std::move(callback)), weak_ptr_factory_(this) {} - virtual ~RemoteTrack() = default; + virtual ~RemoteTrack() { Destroy(); } - FullTrackName full_track_name() const { return full_track_name_; } + const FullTrackName& full_track_name() const { return full_track_name_; } // If REQUEST_ERROR arrives after OK or an object, it is a protocol violation. virtual void OnObjectOrOk() { error_is_allowed_ = false; } bool ErrorIsAllowed() const { return error_is_allowed_; } - // Makes sure the data stream type is consistent with the track type. - bool CheckDataStreamType(MoqtDataStreamType type); - uint64_t request_id() const { return request_id_; } // Is the object one that was requested? @@ -61,11 +64,17 @@ return weak_ptr_factory_.Create(); } - virtual MoqtPriority subscriber_priority() const = 0; - virtual void set_subscriber_priority(MoqtPriority priority) = 0; - virtual bool is_fetch() const = 0; + void Destroy() { + if (delete_callback_ == nullptr) { + return; + } + BidiStreamDeletedCallback delete_callback = std::move(delete_callback_); + delete_callback_ = nullptr; + std::move(delete_callback)(); + } + private: const FullTrackName full_track_name_; const uint64_t request_id_; @@ -73,6 +82,7 @@ // If false, an object or OK message has been received, so any ERROR message // is a protocol violation. bool error_is_allowed_ = true; + BidiStreamDeletedCallback delete_callback_; // Must be last. quiche::QuicheWeakPtrFactory<RemoteTrack> weak_ptr_factory_; @@ -81,33 +91,41 @@ // A track on the peer to which the session has subscribed. class SubscribeRemoteTrack : public RemoteTrack { public: + // If the second argument is null, delete the registration. Returns false if + // it fails due to a duplicate track alias, destroying the session. + using RegisterTrackAliasCallback = + quiche::MultiUseCallback<bool(uint64_t, SubscribeRemoteTrack*)>; + // We're using BidiStreamDeletedCallback here because this will move to a + // bidi stream. SubscribeRemoteTrack(const MoqtSubscribe& subscribe, - SubscribeVisitor* visitor) - : RemoteTrack(subscribe.full_track_name, subscribe.request_id), + SubscribeVisitor* visitor, + BidiStreamDeletedCallback callback, + RegisterTrackAliasCallback register_track_alias_callback) + : RemoteTrack(subscribe.full_track_name, subscribe.request_id, + std::move(callback)), parameters_(subscribe.parameters), - visitor_(visitor) {} - ~SubscribeRemoteTrack() override { - if (publish_done_alarm_ != nullptr) { - publish_done_alarm_->PermanentCancel(); - } - } + visitor_(visitor), + register_track_alias_callback_( + std::move(register_track_alias_callback)) {} + ~SubscribeRemoteTrack() override; void OnObjectOrOk() override { RemoteTrack::OnObjectOrOk(); } std::optional<uint64_t> track_alias() const { return track_alias_; } - void set_track_alias(uint64_t track_alias) { + // Returns false if the callback returns false, meaning the session has been + // destroyed. + [[nodiscard]] bool set_track_alias(uint64_t track_alias) { track_alias_.emplace(track_alias); + if (register_track_alias_callback_) { + return register_track_alias_callback_(track_alias, this); + } + return true; } - SubscribeVisitor* visitor() { return visitor_; } - void OnStreamOpened(); void OnStreamClosed(bool fin_received, std::optional<DataStreamIndex> index); void OnPublishDone(uint64_t stream_count, const quic::QuicClock* clock, - std::unique_ptr<quic::QuicAlarm> publish_done_alarm); - bool all_streams_closed() const { - return total_streams_.has_value() && *total_streams_ == streams_closed_; - } + quic::QuicAlarmFactory* alarm_factory); // The application can request a Joining FETCH but also for FETCH objects to // be delivered via SubscribeRemoteTrack::Visitor::OnObjectFragment(). When @@ -115,9 +133,6 @@ // FETCH objects to pipe directly into the visitor. void OnJoiningFetchReady(std::unique_ptr<MoqtFetchTask> fetch_task); - bool forward() const { return parameters_.forward(); } - void set_forward(bool forward) { parameters_.set_forward(forward); } - bool is_fetch() const override { return false; } MessageParameters& parameters() { return parameters_; } @@ -127,12 +142,6 @@ (!parameters_.subscription_filter.has_value() || parameters_.subscription_filter->InWindow(location)); } - MoqtPriority subscriber_priority() const override { - return parameters_.subscriber_priority.value_or(kDefaultSubscriberPriority); - } - void set_subscriber_priority(MoqtPriority priority) override { - parameters_.subscriber_priority = priority; - } MoqtPriority default_publisher_priority() const { return default_publisher_priority_; @@ -141,7 +150,6 @@ default_publisher_priority_ = priority; } - bool dynamic_groups() { return dynamic_groups_; } void set_dynamic_groups(bool dynamic_groups) { dynamic_groups_ = dynamic_groups; } @@ -154,11 +162,27 @@ publisher_delivery_timeout_ = publisher_delivery_timeout; } + SubscribeVisitor* visitor() const { return visitor_; } + private: friend class test::MoqtSessionPeer; friend class test::SubscribeRemoteTrackPeer; + class PublishDoneDelegate : public quic::QuicAlarm::DelegateWithoutContext { + public: + PublishDoneDelegate(SubscribeRemoteTrack* subscribe) + : subscribe_(subscribe) {} + + void OnAlarm() override { subscribe_->Destroy(); } + + private: + SubscribeRemoteTrack* subscribe_; + }; + void MaybeSetPublishDoneAlarm(); + bool all_streams_closed() const { + return total_streams_.has_value() && *total_streams_ == streams_closed_; + } MessageParameters parameters_; quic::QuicTimeDelta publisher_delivery_timeout_ = kDefaultDeliveryTimeout; @@ -174,6 +198,7 @@ int currently_open_streams_ = 0; // Every stream that has received FIN or RESET_STREAM. uint64_t streams_closed_ = 0; + RegisterTrackAliasCallback register_track_alias_callback_; // Value assigned on PUBLISH_DONE. Can destroy subscription state if // streams_closed_ == total_streams_. std::optional<uint64_t> total_streams_; @@ -196,8 +221,10 @@ public: // Standalone Fetch constructor UpstreamFetch(const MoqtFetch& fetch, const StandaloneFetch standalone, - FetchResponseCallback callback) - : RemoteTrack(standalone.full_track_name, fetch.request_id), + FetchResponseCallback callback, + BidiStreamDeletedCallback delete_callback) + : RemoteTrack(standalone.full_track_name, fetch.request_id, + std::move(delete_callback)), group_order_(fetch.parameters.group_order.value_or( MoqtDeliveryOrder::kAscending)), start_(standalone.start_location), @@ -207,8 +234,10 @@ ok_callback_(std::move(callback)) {} // Relative Joining Fetch constructor UpstreamFetch(const MoqtFetch& fetch, FullTrackName full_track_name, - FetchResponseCallback callback) - : RemoteTrack(full_track_name, fetch.request_id), + FetchResponseCallback callback, + BidiStreamDeletedCallback delete_callback) + : RemoteTrack(full_track_name, fetch.request_id, + std::move(delete_callback)), group_order_(fetch.parameters.group_order.value_or( MoqtDeliveryOrder::kAscending)), relative_groups_( @@ -219,8 +248,10 @@ // Absolute Joining Fetch constructor UpstreamFetch(const MoqtFetch& fetch, FullTrackName full_track_name, JoiningFetchAbsolute absolute_joining, - FetchResponseCallback callback) - : RemoteTrack(full_track_name, fetch.request_id), + FetchResponseCallback callback, + BidiStreamDeletedCallback delete_callback) + : RemoteTrack(full_track_name, fetch.request_id, + std::move(delete_callback)), group_order_(fetch.parameters.group_order.value_or( MoqtDeliveryOrder::kAscending)), start_(Location(absolute_joining.joining_start, 0)), @@ -234,12 +265,8 @@ return (location >= start_ && location <= end_); } - MoqtPriority subscriber_priority() const override { - return subscriber_priority_; - } - void set_subscriber_priority(MoqtPriority priority) override { - subscriber_priority_ = priority; - } + // Called when the data stream is destroyed. + void OnStreamClosed() { Destroy(); } class UpstreamFetchTask : public MoqtFetchTask { public:
diff --git a/quiche/quic/moqt/moqt_track_test.cc b/quiche/quic/moqt/moqt_track_test.cc index 152c95b..60677e7 100644 --- a/quiche/quic/moqt/moqt_track_test.cc +++ b/quiche/quic/moqt/moqt_track_test.cc
@@ -4,6 +4,7 @@ #include "quiche/quic/moqt/moqt_track.h" +#include <cstdint> #include <memory> #include <optional> #include <utility> @@ -48,12 +49,23 @@ class SubscribeRemoteTrackTest : public quic::test::QuicTest { public: - SubscribeRemoteTrackTest() : track_(subscribe_, &visitor_) {} + SubscribeRemoteTrackTest() + : track_( + subscribe_, &visitor_, [this]() { deleted_ = true; }, + [this](uint64_t, SubscribeRemoteTrack* track) { + alias_registered_ = (track != nullptr); + if (alias_registered_) { + EXPECT_EQ(track, &track_); + } + return true; + }) {} MockSubscribeRemoteTrackVisitor visitor_; MoqtSubscribe subscribe_ = {/*request_id=*/1, FullTrackName("foo", "bar"), MessageParameters(Location(2, 0))}; SubscribeRemoteTrack track_; + bool alias_registered_ = false; + bool deleted_ = false; }; TEST_F(SubscribeRemoteTrackTest, Queries) { @@ -62,16 +74,10 @@ EXPECT_FALSE(track_.track_alias().has_value()); EXPECT_EQ(track_.visitor(), &visitor_); EXPECT_FALSE(track_.is_fetch()); - track_.set_track_alias(1); + EXPECT_TRUE(track_.set_track_alias(1)); EXPECT_EQ(track_.track_alias(), 1); } -TEST_F(SubscribeRemoteTrackTest, UpdateDataStreamType) { - EXPECT_TRUE(track_.CheckDataStreamType( - MoqtDataStreamType::Subgroup(1, 1, true, false))); - EXPECT_FALSE(track_.CheckDataStreamType(MoqtDataStreamType::Fetch())); -} - TEST_F(SubscribeRemoteTrackTest, AllowError) { EXPECT_TRUE(track_.ErrorIsAllowed()); track_.OnObjectOrOk(); @@ -187,10 +193,12 @@ class UpstreamFetchTest : public quic::test::QuicTest { protected: UpstreamFetchTest() - : fetch_(fetch_message_, std::get<StandaloneFetch>(fetch_message_.fetch), - [&](std::unique_ptr<MoqtFetchTask> task) { - fetch_task_ = std::move(task); - }) {} + : fetch_( + fetch_message_, std::get<StandaloneFetch>(fetch_message_.fetch), + [&](std::unique_ptr<MoqtFetchTask> task) { + fetch_task_ = std::move(task); + }, + [&]() { deleted_ = true; }) {} MoqtFetch fetch_message_ = { /*request_id=*/1, @@ -201,14 +209,12 @@ // The pointer held by the application. UpstreamFetch fetch_; std::unique_ptr<MoqtFetchTask> fetch_task_; + bool deleted_ = false; }; TEST_F(UpstreamFetchTest, Queries) { EXPECT_EQ(fetch_.request_id(), 1); EXPECT_EQ(fetch_.full_track_name(), FullTrackName("foo", "bar")); - EXPECT_FALSE(fetch_.CheckDataStreamType( - MoqtDataStreamType::Subgroup(1, 2, true, false))); - EXPECT_TRUE(fetch_.CheckDataStreamType(MoqtDataStreamType::Fetch())); EXPECT_TRUE(fetch_.is_fetch()); EXPECT_FALSE(fetch_.InWindow(Location{1, 0})); EXPECT_TRUE(fetch_.InWindow(Location{1, 1})); @@ -389,11 +395,12 @@ TEST_F(UpstreamFetchTest, LocationIsValidOkGroupAscendingIncorrectly) { fetch_message_.parameters.group_order = MoqtDeliveryOrder::kDescending; - UpstreamFetch fetch(fetch_message_, - std::get<StandaloneFetch>(fetch_message_.fetch), - [&](std::unique_ptr<MoqtFetchTask> task) { - fetch_task_ = std::move(task); - }); + UpstreamFetch fetch( + fetch_message_, std::get<StandaloneFetch>(fetch_message_.fetch), + [&](std::unique_ptr<MoqtFetchTask> task) { + fetch_task_ = std::move(task); + }, + []() {}); fetch.OnFetchResult(Location(3, 50), absl::OkStatus(), nullptr); EXPECT_TRUE( fetch.LocationIsValid(Location(2, 1), MoqtObjectStatus::kNormal, true)); @@ -447,11 +454,12 @@ JoiningFetchRelative(1, 2), MessageParameters(), }; - UpstreamFetch relative_fetch(relative_fetch_message, - FullTrackName("foo", "bar"), - [&](std::unique_ptr<MoqtFetchTask> task) { - fetch_task_ = std::move(task); - }); + UpstreamFetch relative_fetch( + relative_fetch_message, FullTrackName("foo", "bar"), + [&](std::unique_ptr<MoqtFetchTask> task) { + fetch_task_ = std::move(task); + }, + []() {}); relative_fetch.OnFetchResult(Location(10, 50), absl::OkStatus(), nullptr); EXPECT_FALSE(relative_fetch.InWindow(Location(7, 35))); EXPECT_TRUE(relative_fetch.InWindow(Location(8, 0))); @@ -463,11 +471,12 @@ JoiningFetchRelative(1, 10), MessageParameters(), }; - UpstreamFetch relative_fetch(relative_fetch_message, - FullTrackName("foo", "bar"), - [&](std::unique_ptr<MoqtFetchTask> task) { - fetch_task_ = std::move(task); - }); + UpstreamFetch relative_fetch( + relative_fetch_message, FullTrackName("foo", "bar"), + [&](std::unique_ptr<MoqtFetchTask> task) { + fetch_task_ = std::move(task); + }, + []() {}); relative_fetch.OnFetchResult(Location(1, 50), absl::OkStatus(), nullptr); EXPECT_TRUE(relative_fetch.InWindow(Location(0, 0))); EXPECT_TRUE(relative_fetch.InWindow(Location(1, 50)));
diff --git a/quiche/quic/moqt/moqt_uni_stream.cc b/quiche/quic/moqt/moqt_uni_stream.cc index f074643..ce67b69 100644 --- a/quiche/quic/moqt/moqt_uni_stream.cc +++ b/quiche/quic/moqt/moqt_uni_stream.cc
@@ -10,9 +10,12 @@ #include <utility> #include <vector> +#include "absl/base/casts.h" #include "absl/base/nullability.h" #include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "quiche/quic/core/quic_time.h" #include "quiche/quic/core/quic_utils.h" @@ -24,6 +27,7 @@ #include "quiche/quic/moqt/moqt_priority.h" #include "quiche/quic/moqt/moqt_publisher.h" #include "quiche/quic/moqt/moqt_trace_recorder.h" +#include "quiche/quic/moqt/moqt_track.h" #include "quiche/quic/moqt/moqt_types.h" #include "quiche/common/quiche_buffer_allocator.h" #include "quiche/common/quiche_mem_slice.h" @@ -331,4 +335,244 @@ stream().ResetWithUserCode(error_code); } +IncomingDataStream::~IncomingDataStream() { + QUICHE_DVLOG(1) << "Destroying incoming data stream " + << stream_->GetStreamId(); + if (!parser_.track_alias().has_value()) { + QUIC_DVLOG(1) << "Destroying incoming data stream before " + "learning track alias"; + return; + } + if (!track_.IsValid()) { + return; + } + if (IsFetch()) { + auto fetch = absl::down_cast<UpstreamFetch*>(track_.GetIfAvailable()); + if (fetch != nullptr) { + fetch->OnStreamClosed(); + } + return; + } + // It's a subscribe. + auto subscribe = + absl::down_cast<SubscribeRemoteTrack*>(track_.GetIfAvailable()); + if (subscribe == nullptr) { + return; + } + subscribe->OnStreamClosed(fin_received_, index_); +} + +void IncomingDataStream::OnObjectMessage(const MoqtObject& message, + absl::string_view payload, + bool end_of_message) { + QUICHE_DVLOG(1) << "Received OBJECT message on stream " + << stream_->GetStreamId() << " for track alias " + << message.track_alias << " with sequence " + << message.group_id << ":" << message.object_id + << " priority " << message.publisher_priority << " length " + << payload.size() << " length " << message.payload_length + << (end_of_message ? "F" : ""); + if (!session_->deliver_partial_objects()) { + if (!end_of_message) { // Buffer partial object. + if (partial_object_.empty()) { + // Avoid redundant allocations by reserving the appropriate amount of + // memory if known. + partial_object_.reserve(message.payload_length); + } + absl::StrAppend(&partial_object_, payload); + return; + } + if (!partial_object_.empty()) { // Completes the object + absl::StrAppend(&partial_object_, payload); + payload = absl::string_view(partial_object_); + } + } + if (payload.empty() && bytes_received_this_object_ > 0 && !end_of_message) { + return; // Nothing arrived. + } + if (!parser_.track_alias().has_value()) { + QUICHE_BUG(quic_bug_object_with_no_stream_type) + << "Object delivered without preliminaries"; + return; + } + // Get a pointer to the upstream state. + if (!track_.IsValid()) { + track_ = IsFetch() ? session_->GetFetch(message.track_alias) + : session_->GetSubscribe(message.track_alias); + } + if (!track_.IsValid()) { + // The request has gone away. + stream_->SendStopSending(kResetCodeCancelled); + return; + } + Location location(message.group_id, message.object_id); + RemoteTrack* track = track_.GetIfAvailable(); + if (track == nullptr || + !track->InWindow(Location(message.group_id, message.object_id))) { + // This is not an error. It can be the result of a recent REQUEST_UPDATE or + // UNSUBSCRIBE. + return; + } + if (!IsFetch()) { + if (!index_.has_value()) { + if (!message.subgroup_id.has_value()) { + QUICHE_BUG(quiche_bug_moqt_subgroup_id_missing) + << "Missing subgroup ID on SUBSCRIBE stream"; + return; + } + index_ = DataStreamIndex(message.group_id, *message.subgroup_id); + } + if (no_more_objects_) { + // Already got a stream-ending object. While the lower layer won't + // deliver data after the FIN, there could have been an EndOfGroup or + // EndOfTrack signal. + session_->OnMalformedTrack(track); + return; + } + if (end_of_message) { + next_object_id_ = message.object_id + 1; + if (message.object_status == MoqtObjectStatus::kEndOfTrack || + message.object_status == MoqtObjectStatus::kEndOfGroup) { + no_more_objects_ = true; + } + } + SubscribeRemoteTrack* subscribe = + absl::down_cast<SubscribeRemoteTrack*>(track); + subscribe->OnObjectOrOk(); + if (visitor_ != nullptr) { + PublishedObjectMetadata metadata; + metadata.location = Location(message.group_id, message.object_id); + metadata.subgroup = message.subgroup_id; + metadata.extensions = message.extension_headers; + metadata.status = message.object_status; + metadata.publisher_priority = message.publisher_priority; + metadata.payload_length = message.payload_length; + metadata.arrival_time = clock_->Now(); + visitor_->OnObjectFragment(track->full_track_name(), metadata, payload, + bytes_received_this_object_); + } + } else { // FETCH + track->OnObjectOrOk(); + UpstreamFetch* fetch = absl::down_cast<UpstreamFetch*>(track); + if (!fetch->LocationIsValid(Location(message.group_id, message.object_id), + message.object_status, end_of_message)) { + // TODO(martinduke): in https://github.com/moq-wg/moq-transport/pull/1409 + // I make the case that this should be a protocol violation. Update if + // that proposal is accepted (at which point + // QuicSession::OnMalformedTrack can be removed, since all the + // remaining conditions are at the application layer). + session_->OnMalformedTrack(track); + return; + } + UpstreamFetch::UpstreamFetchTask* task = fetch->task(); + if (task == nullptr) { + // The application killed the FETCH. + stream_->SendStopSending(kResetCodeCancelled); + return; + } + if (!task->HasObject()) { + task->NewObject(message); + } + if (task->NeedsMorePayload() && !payload.empty()) { + task->AppendPayloadToObject(payload); + } + } + if (end_of_message) { + bytes_received_this_object_ = 0; + } else { + bytes_received_this_object_ += payload.size(); + } + partial_object_.clear(); +} + +void IncomingDataStream::MaybeReadOneObject() { + if (!parser_.track_alias().has_value() || + !parser_.stream_type().has_value() || !parser_.stream_type()->IsFetch()) { + QUICHE_BUG(quic_bug_read_one_object_parser_unexpected_state) + << "Requesting object, parser in unexpected state"; + } + if (!track_.IsValid()) { + return; + } + UpstreamFetch* fetch = + absl::down_cast<UpstreamFetch*>(track_.GetIfAvailable()); + UpstreamFetch::UpstreamFetchTask* task = fetch->task(); + if (task == nullptr) { + return; + } + if (task->HasObject() && !task->NeedsMorePayload()) { + return; // The message is complete. Do not read more. + } + uint64_t start_length = task->payload_length(); + parser_.ReadAtMostOneObject(); + // If it read an object, it called OnObjectMessage and may have altered the + // task's object state. + if (task->payload_length() > start_length) { + task->NotifyNewObject(); + } +} + +void IncomingDataStream::OnCanRead() { + if (!parser_.stream_type().has_value()) { + parser_.ReadStreamType(); + if (!parser_.stream_type().has_value()) { + return; + } + } + if (parser_.stream_type()->IsPadding()) { + (void)stream_->SkipBytes(stream_->ReadableBytes()); + return; + } + bool knew_track_alias = parser_.track_alias().has_value(); + if (!knew_track_alias) { + parser_.ReadTrackAlias(); + if (!parser_.track_alias().has_value()) { + return; + } + } + QUICHE_CHECK(parser_.stream_type().has_value()); + QUICHE_CHECK(parser_.track_alias().has_value()); + if (parser_.stream_type()->IsSubgroup()) { + if (!knew_track_alias) { + track_ = session_->GetSubscribe(*parser_.track_alias()); + // This is a new stream for a subscribe. Notify the subscription. + SubscribeRemoteTrack* subscribe = + absl::down_cast<SubscribeRemoteTrack*>(track_.GetIfAvailable()); + if (subscribe == nullptr) { + stream_->SendStopSending(kResetCodeCancelled); + return; + } + subscribe->OnStreamOpened(); + parser_.set_default_publisher_priority( + subscribe->default_publisher_priority()); + visitor_ = subscribe->visitor(); + } + parser_.ReadAllData(); + return; + } + // FETCH + if (!knew_track_alias) { + track_ = session_->GetFetch(*parser_.track_alias()); + } + if (!track_.IsValid()) { + stream_->SendStopSending(kResetCodeCancelled); + return; + } + UpstreamFetch* fetch = + absl::down_cast<UpstreamFetch*>(track_.GetIfAvailable()); + if (!knew_track_alias) { + // If the task already exists (FETCH_OK has arrived), the callback will + // immediately execute to read the first object. Otherwise, it will only + // execute when the task is created or a cached object is read. + fetch->OnStreamOpened([this]() { MaybeReadOneObject(); }); + return; + } + MaybeReadOneObject(); +} + +void IncomingDataStream::OnParsingError(MoqtError error_code, + absl::string_view reason) { + session_->Error(error_code, absl::StrCat("Parse error: ", reason)); +} + } // namespace moqt
diff --git a/quiche/quic/moqt/moqt_uni_stream.h b/quiche/quic/moqt/moqt_uni_stream.h index fd0d7d8..f828e21 100644 --- a/quiche/quic/moqt/moqt_uni_stream.h +++ b/quiche/quic/moqt/moqt_uni_stream.h
@@ -8,19 +8,25 @@ #include <cstdint> #include <memory> #include <optional> +#include <string> #include <utility> #include "absl/base/nullability.h" +#include "absl/strings/string_view.h" #include "quiche/quic/core/quic_alarm.h" #include "quiche/quic/core/quic_alarm_factory.h" #include "quiche/quic/core/quic_time.h" +#include "quiche/quic/moqt/moqt_error.h" #include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_framer.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_object.h" +#include "quiche/quic/moqt/moqt_parser.h" #include "quiche/quic/moqt/moqt_priority.h" #include "quiche/quic/moqt/moqt_publisher.h" +#include "quiche/quic/moqt/moqt_session_interface.h" #include "quiche/quic/moqt/moqt_trace_recorder.h" +#include "quiche/quic/moqt/moqt_track.h" #include "quiche/quic/moqt/moqt_types.h" #include "quiche/common/platform/api/quiche_export.h" #include "quiche/common/quiche_callbacks.h" @@ -31,6 +37,7 @@ namespace test { class OutgoingSubgroupStreamPeer; +class MoqtSessionPeer; } // A base class for locally initiated unidirectional streams, which can serve @@ -188,6 +195,70 @@ FetchStreamCloseCallback close_callback_; }; +class SessionToUniStreamInterface { + public: + virtual ~SessionToUniStreamInterface() = default; + virtual bool deliver_partial_objects() const = 0; + virtual void OnMalformedTrack(RemoteTrack* name) = 0; + virtual quiche::QuicheWeakPtr<RemoteTrack> GetSubscribe( + uint64_t track_alias) = 0; + virtual quiche::QuicheWeakPtr<RemoteTrack> GetFetch(uint64_t request_id) = 0; + virtual void Error(MoqtError error_code, absl::string_view reason) = 0; +}; + +class QUICHE_EXPORT IncomingDataStream : public webtransport::StreamVisitor, + public MoqtDataParserVisitor { + public: + IncomingDataStream(webtransport::Stream* absl_nonnull stream, + SessionToUniStreamInterface* absl_nonnull session, + const quic::QuicClock* absl_nonnull clock) + : stream_(stream), + parser_(stream, this), + session_(session), + clock_(clock) {} + ~IncomingDataStream(); + + // webtransport::StreamVisitor implementation. + void OnCanRead() override; + void OnCanWrite() override {} + void OnResetStreamReceived(webtransport::StreamErrorCode) override {} + void OnStopSendingReceived(webtransport::StreamErrorCode /*error*/) override { + } + void OnWriteSideInDataRecvdState() override {} + + // MoqtParserVisitor implementation. + // TODO: Handle a stream FIN. + void OnObjectMessage(const MoqtObject& message, absl::string_view payload, + bool end_of_message) override; + void OnFin() override { fin_received_ = true; } + void OnParsingError(MoqtError error_code, absl::string_view reason) override; + + webtransport::Stream* stream() const { return stream_; } + + void MaybeReadOneObject(); + + private: + friend class test::MoqtSessionPeer; + bool IsFetch() const { + return parser_.stream_type().has_value() && + parser_.stream_type()->IsFetch(); + } + + uint64_t next_object_id_ = 0; + bool no_more_objects_ = false; // EndOfGroup or EndOfTrack was received. + std::optional<DataStreamIndex> index_; // Only set for subscribe. + bool fin_received_ = false; + webtransport::Stream* stream_; + SubscribeVisitor* visitor_ = nullptr; + // Once the subscribe ID is identified, set it here. + quiche::QuicheWeakPtr<RemoteTrack> track_; + MoqtDataParser parser_; + std::string partial_object_; + uint64_t bytes_received_this_object_ = 0; + SessionToUniStreamInterface* session_; + const quic::QuicClock* absl_nonnull clock_; +}; + } // namespace moqt #endif // QUICHE_QUIC_MOQT_MOQT_UNI_STREAM_H_
diff --git a/quiche/quic/moqt/moqt_uni_stream_test.cc b/quiche/quic/moqt/moqt_uni_stream_test.cc index c021460..8459c98 100644 --- a/quiche/quic/moqt/moqt_uni_stream_test.cc +++ b/quiche/quic/moqt/moqt_uni_stream_test.cc
@@ -19,9 +19,11 @@ #include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_framer.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" +#include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_object.h" #include "quiche/quic/moqt/moqt_trace_recorder.h" +#include "quiche/quic/moqt/moqt_track.h" #include "quiche/quic/moqt/moqt_types.h" #include "quiche/quic/moqt/test_tools/moqt_mock_visitor.h" #include "quiche/quic/moqt/test_tools/moqt_session_peer.h" @@ -31,6 +33,7 @@ #include "quiche/common/platform/api/quiche_expect_bug.h" #include "quiche/common/quiche_mem_slice.h" #include "quiche/common/quiche_weak_ptr.h" +#include "quiche/web_transport/test_tools/in_memory_stream.h" #include "quiche/web_transport/test_tools/mock_web_transport.h" #include "quiche/web_transport/web_transport.h" @@ -453,6 +456,201 @@ task_ptr_->CallObjectsAvailableCallback(); } +MoqtObject kDefaultObject = { + 2, // track_alias + 0, // group_id + 0, // object_id + 0x80, // publisher_priority + "", // extension_headers + MoqtObjectStatus::kNormal, + 0, // subgroup_id + 0, // payload_length +}; + +class MockSessionToUniStreamInterface : public SessionToUniStreamInterface { + public: + MockSessionToUniStreamInterface() = default; + ~MockSessionToUniStreamInterface() override = default; + + MOCK_METHOD(bool, deliver_partial_objects, (), (const, override)); + MOCK_METHOD(void, OnMalformedTrack, (RemoteTrack*), (override)); + MOCK_METHOD(quiche::QuicheWeakPtr<RemoteTrack>, GetSubscribe, (uint64_t), + (override)); + MOCK_METHOD(quiche::QuicheWeakPtr<RemoteTrack>, GetFetch, (uint64_t), + (override)); + MOCK_METHOD(void, Error, (MoqtError, absl::string_view), (override)); +}; + +class IncomingDataStreamTest : public quic::test::QuicTest { + public: + IncomingDataStreamTest() + : mock_stream_(14), + ftn_("foo", "bar"), + subscribe_message_(1, ftn_, MessageParameters()) { + EXPECT_CALL(session_, deliver_partial_objects()) + .WillRepeatedly(Return(false)); + track_ = std::make_unique<SubscribeRemoteTrack>( + subscribe_message_, &visitor_, []() {}, + [this](uint64_t alias, SubscribeRemoteTrack* track) -> bool { + alias_ = alias; + alias_track_ = track; + return true; + }); + EXPECT_TRUE(track_->set_track_alias(2)); + CreateStream(); + } + + void CreateStream() { + stream_ = std::make_unique<IncomingDataStream>(&mock_stream_, &session_, + &mock_clock_); + } + void ProcessStreamType(MoqtDataStreamType type) { + uint8_t type_byte = static_cast<uint8_t>(type.value()); + mock_stream_.Receive( + absl::string_view(reinterpret_cast<const char*>(&type_byte), 1), false); + stream_->OnCanRead(); + } + void ProcessAlias(uint8_t alias) { + mock_stream_.Receive( + absl::string_view(reinterpret_cast<const char*>(&alias), 1), false); + EXPECT_CALL(session_, GetSubscribe(alias)) + .WillOnce(Return(track_->weak_ptr())); + stream_->OnCanRead(); + EXPECT_EQ(alias_, alias); + EXPECT_EQ(alias_track_, track_.get()); + } + + webtransport::test::InMemoryStream mock_stream_; + testing::NiceMock<MockSessionToUniStreamInterface> session_; + quic::MockClock mock_clock_; + FullTrackName ftn_; + MoqtSubscribe subscribe_message_; + testing::NiceMock<MockSubscribeRemoteTrackVisitor> visitor_; + std::unique_ptr<SubscribeRemoteTrack> track_; + std::unique_ptr<IncomingDataStream> stream_; + uint64_t alias_ = 0; + SubscribeRemoteTrack* alias_track_ = nullptr; +}; + +TEST_F(IncomingDataStreamTest, DestructorBeforeTrackAlias) { + // The stream doesn't know the track, so there's no visitor to notify. + EXPECT_CALL(visitor_, OnStreamReset).Times(0); + stream_.reset(); +} + +TEST_F(IncomingDataStreamTest, DestructorAfterObject) { + ProcessStreamType(MoqtDataStreamType::Subgroup(0, 0, false, 0x80)); + ProcessAlias(2); + EXPECT_CALL(visitor_, OnObjectFragment); + stream_->OnObjectMessage(kDefaultObject, "", true); + EXPECT_CALL(visitor_, OnStreamReset); + stream_.reset(); +} + +TEST_F(IncomingDataStreamTest, DestructorAfterFin) { + ProcessStreamType(MoqtDataStreamType::Subgroup(0, 0, false, 0x80)); + ProcessAlias(2); + EXPECT_CALL(visitor_, OnObjectFragment); + stream_->OnObjectMessage(kDefaultObject, "", true); + stream_->OnFin(); + EXPECT_CALL(visitor_, OnStreamFin); + stream_.reset(); +} + +TEST_F(IncomingDataStreamTest, OnParsingError) { + EXPECT_CALL(session_, + Error(MoqtError::kProtocolViolation, "Parse error: reason")) + .Times(1); + stream_->OnParsingError(MoqtError::kProtocolViolation, "reason"); +} + +TEST_F(IncomingDataStreamTest, OnObjectMessageNoTrackAliasError) { + EXPECT_QUICHE_BUG(stream_->OnObjectMessage(kDefaultObject, "payload", true), + "Object delivered without preliminaries"); +} + +TEST_F(IncomingDataStreamTest, OnObjectMessageBufferPartialObject) { + ProcessStreamType(MoqtDataStreamType::Subgroup(0, 0, false, 0x80)); + ProcessAlias(2); + MoqtObject object = kDefaultObject; + object.payload_length = 10; + EXPECT_CALL(visitor_, OnObjectFragment).Times(0); + stream_->OnObjectMessage(object, "foo", false); + EXPECT_CALL(visitor_, OnObjectFragment); + stream_->OnObjectMessage(object, "bar", true); +} + +TEST_F(IncomingDataStreamTest, OnObjectMessageInvalidTrack) { + ProcessStreamType(MoqtDataStreamType::Subgroup(0, 0, false, 0x80)); + uint8_t alias = 2; + mock_stream_.Receive( + absl::string_view(reinterpret_cast<const char*>(&alias), 1), false); + EXPECT_CALL(session_, GetSubscribe(2)) + .WillOnce(Return(quiche::QuicheWeakPtr<RemoteTrack>())); + stream_->OnCanRead(); + EXPECT_TRUE(mock_stream_.was_reset()); +} + +TEST_F(IncomingDataStreamTest, OnObjectMessageNotInWindow) { + ProcessStreamType(MoqtDataStreamType::Subgroup(0, 0, false, 0x80)); + ProcessAlias(2); + track_->parameters().set_forward(false); + EXPECT_CALL(visitor_, OnObjectFragment).Times(0); + stream_->OnObjectMessage(kDefaultObject, "", true); +} + +TEST_F(IncomingDataStreamTest, OnObjectMessageMissingSubgroupId) { + ProcessStreamType(MoqtDataStreamType::Subgroup(0, 0, false, 0x80)); + ProcessAlias(2); + MoqtObject object = kDefaultObject; + object.subgroup_id = std::nullopt; + EXPECT_QUICHE_BUG(stream_->OnObjectMessage(object, "", true), + "Missing subgroup ID on SUBSCRIBE stream"); +} + +TEST_F(IncomingDataStreamTest, OnObjectMessageMalformedTrack) { + ProcessStreamType(MoqtDataStreamType::Subgroup(0, 0, false, 0x80)); + ProcessAlias(2); + MoqtObject object = kDefaultObject; + object.object_status = MoqtObjectStatus::kEndOfTrack; + EXPECT_CALL(visitor_, OnObjectFragment); + stream_->OnObjectMessage(object, "", true); + + EXPECT_CALL(session_, OnMalformedTrack(track_.get())); + MoqtObject object2 = object; + object2.object_id = 1; + object2.object_status = MoqtObjectStatus::kNormal; + stream_->OnObjectMessage(object2, "", true); +} + +TEST_F(IncomingDataStreamTest, MaybeReadOneObjectUnexpectedState) { + EXPECT_QUICHE_BUG(stream_->MaybeReadOneObject(), + "Requesting object, parser in unexpected state"); +} + +TEST_F(IncomingDataStreamTest, OnCanReadFetchNewTrackAliasInvalidFetch) { + char fetch_bytes[] = {0x05, 0x03}; + mock_stream_.Receive(absl::string_view(fetch_bytes, 2), false); + EXPECT_CALL(session_, GetFetch(3)) + .WillOnce(Return(quiche::QuicheWeakPtr<RemoteTrack>())); + stream_->OnCanRead(); + EXPECT_TRUE(mock_stream_.was_reset()); +} + +TEST_F(IncomingDataStreamTest, OnCanReadFetchNewTrackAliasSuccess) { + MoqtFetch fetch; + fetch.request_id = 3; + StandaloneFetch standalone(ftn_, Location(0, 0), Location(0, 9)); + auto upstream_fetch = std::make_unique<UpstreamFetch>( + fetch, standalone, [](std::unique_ptr<MoqtFetchTask>) {}, []() {}); + upstream_fetch->OnFetchResult(Location(0, 0), absl::OkStatus(), []() {}); + EXPECT_CALL(session_, GetFetch(3)) + .WillOnce(Return(upstream_fetch->weak_ptr())); + char fetch_bytes[] = {0x05, 0x03}; + mock_stream_.Receive(absl::string_view(fetch_bytes, 2), false); + stream_->OnCanRead(); +} + } // namespace } // namespace moqt::test
diff --git a/quiche/quic/moqt/test_tools/moqt_session_peer.h b/quiche/quic/moqt/test_tools/moqt_session_peer.h index a4ba325..d70c396 100644 --- a/quiche/quic/moqt/test_tools/moqt_session_peer.h +++ b/quiche/quic/moqt/test_tools/moqt_session_peer.h
@@ -5,7 +5,6 @@ #ifndef QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_SESSION_PEER_H_ #define QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_SESSION_PEER_H_ -#include <cstddef> #include <cstdint> #include <memory> #include <optional> @@ -27,8 +26,6 @@ #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_parser.h" -#include "quiche/quic/moqt/moqt_priority.h" -#include "quiche/quic/moqt/moqt_publisher.h" #include "quiche/quic/moqt/moqt_session.h" #include "quiche/quic/moqt/moqt_session_interface.h" #include "quiche/quic/moqt/moqt_subscription.h" @@ -50,6 +47,10 @@ parser->type_ = type; parser->next_input_ = MoqtDataParser::NextInput::kTrackAlias; } + static void SetTrackAlias(MoqtDataParser* parser, uint64_t track_alias) { + parser->metadata_.track_alias = track_alias; + parser->next_input_ = MoqtDataParser::NextInput::kGroupId; + } }; // Helper class to interact with MOQT bidi streams in tests. @@ -102,18 +103,24 @@ static std::unique_ptr<MoqtDataParserVisitor> CreateIncomingDataStream( MoqtSession* session, webtransport::Stream* stream, - MoqtDataStreamType type) { - auto new_stream = - std::make_unique<MoqtSession::IncomingDataStream>(session, stream); + MoqtDataStreamType type, + std::optional<uint64_t> track_alias = std::nullopt, + SubscribeVisitor* visitor = nullptr) { + auto new_stream = std::make_unique<IncomingDataStream>( + stream, session, session->callbacks_.clock); MoqtDataParserPeer::SetType(&new_stream->parser_, type); + if (track_alias.has_value()) { + MoqtDataParserPeer::SetTrackAlias(&new_stream->parser_, *track_alias); + new_stream->visitor_ = visitor; + } return new_stream; } static std::unique_ptr<webtransport::StreamVisitor> CreateIncomingStreamVisitor(MoqtSession* session, webtransport::Stream* stream) { - auto new_stream = - std::make_unique<MoqtSession::IncomingDataStream>(session, stream); + auto new_stream = std::make_unique<IncomingDataStream>( + stream, session, session->callbacks_.clock); return new_stream; } @@ -140,10 +147,23 @@ const MoqtSubscribe& subscribe, const std::optional<uint64_t> track_alias, SubscribeVisitor* visitor) { - auto track = std::make_unique<SubscribeRemoteTrack>(subscribe, visitor); + auto track = std::make_unique<SubscribeRemoteTrack>( + subscribe, visitor, + [session = session, ftn = subscribe.full_track_name, + id = subscribe.request_id]() { + session->subscribe_by_name_.erase(ftn); + session->upstream_by_id_.erase(id); + }, + [session = session](uint64_t alias, SubscribeRemoteTrack* track) { + if (track == nullptr) { + session->subscribe_by_alias_.erase(alias); + return true; + } + session->subscribe_by_alias_[alias] = track; + return true; + }); if (track_alias.has_value()) { - track->set_track_alias(*track_alias); - session->subscribe_by_alias_.try_emplace(*track_alias, track.get()); + ASSERT_TRUE(track->set_track_alias(*track_alias)); } session->subscribe_by_name_.try_emplace(subscribe.full_track_name, track.get()); @@ -192,7 +212,8 @@ fetch_message, std::get<StandaloneFetch>(fetch_message.fetch), [&](std::unique_ptr<MoqtFetchTask> fetch_task) { task = std::move(fetch_task); - })); + }, + [session = session]() { session->upstream_by_id_.erase(0); })); QUICHE_DCHECK(success); UpstreamFetch* fetch = absl::down_cast<UpstreamFetch*>(it->second.get()); // Initialize the fetch task
diff --git a/quiche/web_transport/test_tools/in_memory_stream.h b/quiche/web_transport/test_tools/in_memory_stream.h index 5a27bde..340fe7b 100644 --- a/quiche/web_transport/test_tools/in_memory_stream.h +++ b/quiche/web_transport/test_tools/in_memory_stream.h
@@ -76,6 +76,7 @@ } bool fin_sent() const { return fin_sent_; } + bool was_reset() const { return abruptly_terminated_; } protected: virtual void OnWrite(absl::string_view data) {}