Refactor MoQT upstream SUBSCRIBE processing and data structures, streamlining and simplifying them in preparation for also sending FETCH. - Move subscribe-specific properties to SubscribeRemoteTrack, which derives from RemoteTrack. - Store subscribes in a map keyed by track alias, which owns each SubscribeRemoteTrack, plus separate maps of non-owning pointers to those records indexed by subscribe ID and by full track name. Pointers to FetchTask will also be stored in the subscribe-ID map. - Eliminate "active subscribes" and keep that state in SubscribeRemoteTrack instead, since all subscribes are now indexed by ID. - Allow access to the stream type in the parser so it can be verified against the subscription. - Based on @vasilvv's suggestion, delete a check in quiche_weak_ptr.h that was misfiring. - Add support for upstream UNSUBSCRIBE to avoid a problem with ChatClient not clearing subscription state. Not in production. PiperOrigin-RevId: 700756286
diff --git a/quiche/quic/moqt/moqt_integration_test.cc b/quiche/quic/moqt/moqt_integration_test.cc index c7630a5..dc1820b 100644 --- a/quiche/quic/moqt/moqt_integration_test.cc +++ b/quiche/quic/moqt/moqt_integration_test.cc
@@ -149,7 +149,7 @@ EXPECT_CALL(server_callbacks_.incoming_announce_callback, Call(FullTrackName{"foo"})) .WillOnce(Return(std::nullopt)); - MockRemoteTrackVisitor server_visitor; + MockSubscribeRemoteTrackVisitor server_visitor; testing::MockFunction<void( FullTrackName track_namespace, std::optional<MoqtAnnounceErrorReason> error_message)> @@ -166,7 +166,7 @@ EXPECT_FALSE(error.has_value()); server_->session()->SubscribeCurrentGroup(track_name, &server_visitor); }); - EXPECT_CALL(server_visitor, OnReply(_, _)).WillOnce([&]() { + EXPECT_CALL(server_visitor, OnReply(_, _, _)).WillOnce([&]() { matches = true; }); bool success = @@ -179,7 +179,7 @@ // Set up the server to subscribe to "data" track for the namespace announce // it receives. - MockRemoteTrackVisitor server_visitor; + MockSubscribeRemoteTrackVisitor server_visitor; EXPECT_CALL(server_callbacks_.incoming_announce_callback, Call(_)) .WillOnce([&](FullTrackName track_namespace) { FullTrackName track_name = track_namespace; @@ -196,7 +196,7 @@ client_->session()->set_publisher(&known_track_publisher); queue->AddObject(MemSliceFromString("object data"), /*key=*/true); bool received_subscribe_ok = false; - EXPECT_CALL(server_visitor, OnReply(_, _)).WillOnce([&]() { + EXPECT_CALL(server_visitor, OnReply(_, _, _)).WillOnce([&]() { received_subscribe_ok = true; }); client_->session()->Announce( @@ -232,7 +232,7 @@ {MoqtForwardingPreference::kSubgroup, MoqtForwardingPreference::kDatagram}) { SCOPED_TRACE(MoqtForwardingPreferenceToString(forwarding_preference)); - MockRemoteTrackVisitor client_visitor; + MockSubscribeRemoteTrackVisitor client_visitor; std::string name = absl::StrCat("pref_", static_cast<int>(forwarding_preference)); auto queue = std::make_shared<MoqtOutgoingQueue>( @@ -295,7 +295,7 @@ {MoqtForwardingPreference::kSubgroup, MoqtForwardingPreference::kDatagram}) { SCOPED_TRACE(MoqtForwardingPreferenceToString(forwarding_preference)); - MockRemoteTrackVisitor client_visitor; + 
MockSubscribeRemoteTrackVisitor client_visitor; std::string name = absl::StrCat("pref_", static_cast<int>(forwarding_preference)); auto queue = std::make_shared<MoqtOutgoingQueue>( @@ -379,10 +379,10 @@ auto track_publisher = std::make_shared<MockTrackPublisher>(full_track_name); publisher.Add(track_publisher); - MockRemoteTrackVisitor client_visitor; + MockSubscribeRemoteTrackVisitor client_visitor; std::optional<absl::string_view> expected_reason = std::nullopt; bool received_ok = false; - EXPECT_CALL(client_visitor, OnReply(full_track_name, expected_reason)) + EXPECT_CALL(client_visitor, OnReply(full_track_name, _, expected_reason)) .WillOnce([&]() { received_ok = true; }); client_->session()->SubscribeAbsolute(full_track_name, 0, 0, &client_visitor); bool success = @@ -399,10 +399,10 @@ auto track_publisher = std::make_shared<MockTrackPublisher>(full_track_name); publisher.Add(track_publisher); - MockRemoteTrackVisitor client_visitor; + MockSubscribeRemoteTrackVisitor client_visitor; std::optional<absl::string_view> expected_reason = std::nullopt; bool received_ok = false; - EXPECT_CALL(client_visitor, OnReply(full_track_name, expected_reason)) + EXPECT_CALL(client_visitor, OnReply(full_track_name, _, expected_reason)) .WillOnce([&]() { received_ok = true; }); client_->session()->SubscribeCurrentObject(full_track_name, &client_visitor); bool success = @@ -419,10 +419,10 @@ auto track_publisher = std::make_shared<MockTrackPublisher>(full_track_name); publisher.Add(track_publisher); - MockRemoteTrackVisitor client_visitor; + MockSubscribeRemoteTrackVisitor client_visitor; std::optional<absl::string_view> expected_reason = std::nullopt; bool received_ok = false; - EXPECT_CALL(client_visitor, OnReply(full_track_name, expected_reason)) + EXPECT_CALL(client_visitor, OnReply(full_track_name, _, expected_reason)) .WillOnce([&]() { received_ok = true; }); client_->session()->SubscribeCurrentGroup(full_track_name, &client_visitor); bool success = @@ -433,10 +433,10 @@ 
TEST_F(MoqtIntegrationTest, SubscribeError) { EstablishSession(); FullTrackName full_track_name("foo", "bar"); - MockRemoteTrackVisitor client_visitor; + MockSubscribeRemoteTrackVisitor client_visitor; std::optional<absl::string_view> expected_reason = "No tracks published"; bool received_ok = false; - EXPECT_CALL(client_visitor, OnReply(full_track_name, expected_reason)) + EXPECT_CALL(client_visitor, OnReply(full_track_name, _, expected_reason)) .WillOnce([&]() { received_ok = true; }); client_->session()->SubscribeCurrentObject(full_track_name, &client_visitor); bool success = @@ -452,7 +452,7 @@ ConnectEndpoints(); FullTrackName full_track_name("foo", "bar"); - MockRemoteTrackVisitor client_visitor; + MockSubscribeRemoteTrackVisitor client_visitor; MoqtKnownTrackPublisher publisher; server_->session()->set_publisher(&publisher); @@ -468,8 +468,9 @@ .WillOnce([&](MoqtObjectAckFunction new_ack_function) { ack_function = std::move(new_ack_function); }); - EXPECT_CALL(client_visitor, OnReply(_, _)) - .WillOnce([&](const FullTrackName&, std::optional<absl::string_view>) { + EXPECT_CALL(client_visitor, OnReply(_, _, _)) + .WillOnce([&](const FullTrackName&, std::optional<FullSequence>, + std::optional<absl::string_view>) { ack_function(10, 20, quic::QuicTimeDelta::FromMicroseconds(-123)); ack_function(100, 200, quic::QuicTimeDelta::FromMicroseconds(456)); });
diff --git a/quiche/quic/moqt/moqt_parser.h b/quiche/quic/moqt/moqt_parser.h index 58795c4..34953e4 100644 --- a/quiche/quic/moqt/moqt_parser.h +++ b/quiche/quic/moqt/moqt_parser.h
@@ -20,6 +20,10 @@ namespace moqt { +namespace test { +class MoqtDataParserPeer; +} + class QUICHE_EXPORT MoqtControlParserVisitor { public: virtual ~MoqtControlParserVisitor() = default; @@ -195,7 +199,11 @@ // be used for testing. void set_chunk_size(size_t size) { chunk_size_ = size; } + std::optional<MoqtDataStreamType> stream_type() const { return type_; } + private: + friend class test::MoqtDataParserPeer; + // If there is buffered data from the previous attempt at parsing it, new data // will be added in `chunk_size_`-sized chunks. constexpr static size_t kDefaultChunkSize = 64;
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index 0bf32bf..10b5939 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -198,11 +198,27 @@ << message.object_id << " priority " << message.publisher_priority << " length " << payload.size(); - auto [full_track_name, visitor] = - TrackPropertiesFromAlias(message, MoqtForwardingPreference::kDatagram); + SubscribeRemoteTrack* track = RemoteTrackByAlias(message.track_alias); + if (track == nullptr) { + return; + } + if (!track->CheckDataStreamType(MoqtDataStreamType::kObjectDatagram)) { + Error(MoqtError::kProtocolViolation, + "Received DATAGRAM for non-datagram track"); + return; + } + if (!track->InWindow(FullSequence(message.group_id, message.object_id))) { + // TODO(martinduke): a recent SUBSCRIBE_UPDATE could put us here, and it's + // not an error. + return; + } + QUICHE_CHECK(!track->is_fetch()); + track->OnObjectOrOk(); + SubscribeRemoteTrack::Visitor* visitor = track->visitor(); if (visitor != nullptr) { visitor->OnObjectFragment( - full_track_name, FullSequence{message.group_id, 0, message.object_id}, + track->full_track_name(), + FullSequence{message.group_id, 0, message.object_id}, message.publisher_priority, message.object_status, payload, true); } } @@ -248,7 +264,7 @@ bool MoqtSession::SubscribeAbsolute(const FullTrackName& name, uint64_t start_group, uint64_t start_object, - RemoteTrack::Visitor* visitor, + SubscribeRemoteTrack::Visitor* visitor, MoqtSubscribeParameters parameters) { MoqtSubscribe message; message.full_track_name = name; @@ -265,7 +281,7 @@ bool MoqtSession::SubscribeAbsolute(const FullTrackName& name, uint64_t start_group, uint64_t start_object, uint64_t end_group, - RemoteTrack::Visitor* visitor, + SubscribeRemoteTrack::Visitor* visitor, MoqtSubscribeParameters parameters) { if (end_group < start_group) { QUIC_DLOG(ERROR) << "Subscription end is before beginning"; @@ -286,7 +302,7 @@ bool MoqtSession::SubscribeAbsolute(const FullTrackName& name, uint64_t start_group, uint64_t start_object, uint64_t end_group, uint64_t end_object, - RemoteTrack::Visitor* visitor, + SubscribeRemoteTrack::Visitor* 
visitor, MoqtSubscribeParameters parameters) { if (end_group < start_group) { QUIC_DLOG(ERROR) << "Subscription end is before beginning"; @@ -309,7 +325,7 @@ } bool MoqtSession::SubscribeCurrentObject(const FullTrackName& name, - RemoteTrack::Visitor* visitor, + SubscribeRemoteTrack::Visitor* visitor, MoqtSubscribeParameters parameters) { MoqtSubscribe message; message.full_track_name = name; @@ -324,7 +340,7 @@ } bool MoqtSession::SubscribeCurrentGroup(const FullTrackName& name, - RemoteTrack::Visitor* visitor, + SubscribeRemoteTrack::Visitor* visitor, MoqtSubscribeParameters parameters) { MoqtSubscribe message; message.full_track_name = name; @@ -339,6 +355,21 @@ return Subscribe(message, visitor); } +void MoqtSession::Unsubscribe(const FullTrackName& name) { + RemoteTrack* track = RemoteTrackByName(name); + if (track == nullptr) { + return; + } + MoqtUnsubscribe message; + message.subscribe_id = track->subscribe_id(); + SendControlMessage(framer_.SerializeUnsubscribe(message)); + // Destroy state. 
+ upstream_by_name_.erase(name); + upstream_by_id_.erase(track->subscribe_id()); + subscribe_by_alias_.erase( + static_cast<SubscribeRemoteTrack*>(track)->track_alias()); +} + void MoqtSession::PublishedFetch::FetchStreamVisitor::OnCanWrite() { std::shared_ptr<PublishedFetch> fetch = fetch_.lock(); if (fetch == nullptr) { @@ -411,7 +442,8 @@ } bool MoqtSession::Subscribe(MoqtSubscribe& message, - RemoteTrack::Visitor* visitor) { + SubscribeRemoteTrack::Visitor* visitor, + std::optional<uint64_t> provided_track_alias) { if (peer_role_ == MoqtRole::kSubscriber) { QUIC_DLOG(INFO) << ENDPOINT << "Tried to send SUBSCRIBE to subscriber peer"; return false; @@ -424,16 +456,20 @@ << peer_max_subscribe_id_; return false; } - message.subscribe_id = next_subscribe_id_++; - auto it = remote_track_aliases_.find(message.full_track_name); - if (it != remote_track_aliases_.end()) { - message.track_alias = it->second; - if (message.track_alias >= next_remote_track_alias_) { - next_remote_track_alias_ = message.track_alias + 1; - } - } else { - message.track_alias = next_remote_track_alias_++; + if (upstream_by_name_.contains(message.full_track_name)) { + QUIC_DLOG(INFO) << ENDPOINT << "Tried to send SUBSCRIBE for track " + << message.full_track_name + << " which is already subscribed"; + return false; } + if (provided_track_alias.has_value() && + subscribe_by_alias_.contains(*provided_track_alias)) { + Error(MoqtError::kProtocolViolation, "Provided track alias already in use"); + return false; + } + message.subscribe_id = next_subscribe_id_++; + message.track_alias = + provided_track_alias.value_or(next_remote_track_alias_++); if (SupportsObjectAck() && visitor != nullptr) { // Since we do not expose subscribe IDs directly in the API, instead wrap // the session and subscribe ID in a callback. 
@@ -448,7 +484,10 @@ SendControlMessage(framer_.SerializeSubscribe(message)); QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE message for " << message.full_track_name; - active_subscribes_.try_emplace(message.subscribe_id, message, visitor); + auto track = std::make_unique<SubscribeRemoteTrack>(message, visitor); + upstream_by_name_.emplace(message.full_track_name, track.get()); + upstream_by_id_.emplace(message.subscribe_id, track.get()); + subscribe_by_alias_.emplace(message.track_alias, std::move(track)); return true; } @@ -503,6 +542,30 @@ return true; } +SubscribeRemoteTrack* MoqtSession::RemoteTrackByAlias(uint64_t track_alias) { + auto it = subscribe_by_alias_.find(track_alias); + if (it == subscribe_by_alias_.end()) { + return nullptr; + } + return it->second.get(); +} + +RemoteTrack* MoqtSession::RemoteTrackById(uint64_t subscribe_id) { + auto it = upstream_by_id_.find(subscribe_id); + if (it == upstream_by_id_.end()) { + return nullptr; + } + return it->second; +} + +RemoteTrack* MoqtSession::RemoteTrackByName(const FullTrackName& name) { + auto it = upstream_by_name_.find(name); + if (it == upstream_by_name_.end()) { + return nullptr; + } + return it->second; +} + void MoqtSession::OnCanCreateNewOutgoingUnidirectionalStream() { while (!subscribes_with_queued_outgoing_data_streams_.empty() && session_->CanOpenNextOutgoingUnidirectionalStream()) { @@ -555,58 +618,6 @@ SendControlMessage(framer_.SerializeMaxSubscribeId(message)); } -std::pair<FullTrackName, RemoteTrack::Visitor*> -MoqtSession::TrackPropertiesFromAlias( - const MoqtObject& message, - std::optional<MoqtForwardingPreference> forwarding_preference) { - auto it = remote_tracks_.find(message.track_alias); - if (it == remote_tracks_.end()) { - ActiveSubscribe* subscribe = nullptr; - // SUBSCRIBE_OK has not arrived yet, but deliver the object. Indexing - // active_subscribes_ by track alias would make this faster if the - // subscriber has tons of incomplete subscribes. 
- for (auto& open_subscribe : active_subscribes_) { - if (open_subscribe.second.message.track_alias == message.track_alias) { - subscribe = &open_subscribe.second; - break; - } - } - if (subscribe == nullptr) { - return std::pair<FullTrackName, RemoteTrack::Visitor*>( - {FullTrackName{}, nullptr}); - } - subscribe->received_object = true; - if (forwarding_preference.has_value()) { - if (subscribe->forwarding_preference.has_value()) { - if (*forwarding_preference != *subscribe->forwarding_preference) { - Error(MoqtError::kProtocolViolation, - "Forwarding preference changes mid-track"); - return std::pair<FullTrackName, RemoteTrack::Visitor*>( - {FullTrackName{}, nullptr}); - } - } else { - subscribe->forwarding_preference = *forwarding_preference; - } - } else { - QUICHE_BUG(quic_subscribe_no_forwarding preference) - << "Objects from a subscribe should know the forwarding preference"; - } - return std::make_pair(subscribe->message.full_track_name, - subscribe->visitor); - } - RemoteTrack& track = it->second; - // Update the forwarding preference if it is present. - if (forwarding_preference.has_value() && - !track.CheckForwardingPreference(*forwarding_preference)) { - // Incorrect forwarding preference. 
- Error(MoqtError::kProtocolViolation, - "Forwarding preference changes mid-track"); - return std::pair<FullTrackName, RemoteTrack::Visitor*>( - {FullTrackName{}, nullptr}); - } - return std::make_pair(track.full_track_name(), track.visitor()); -} - bool MoqtSession::ValidateSubscribeId(uint64_t subscribe_id) { if (peer_role_ == MoqtRole::kPublisher) { QUIC_DLOG(INFO) << ENDPOINT << "Publisher peer sent SUBSCRIBE"; @@ -820,66 +831,75 @@ void MoqtSession::ControlStream::OnSubscribeOkMessage( const MoqtSubscribeOk& message) { - auto it = session_->active_subscribes_.find(message.subscribe_id); - if (it == session_->active_subscribes_.end()) { - session_->Error(MoqtError::kProtocolViolation, - "Received SUBSCRIBE_OK for nonexistent subscribe"); + RemoteTrack* track = session_->RemoteTrackById(message.subscribe_id); + if (track == nullptr) { + QUIC_DLOG(INFO) << ENDPOINT << "Received the SUBSCRIBE_OK for " + << "subscribe_id = " << message.subscribe_id + << " but no track exists"; + // Subscription state might have been destroyed for internal reasons. return; } - MoqtSubscribe& subscribe = it->second.message; + if (track->is_fetch()) { + session_->Error(MoqtError::kProtocolViolation, + "Received SUBSCRIBE_OK for a FETCH"); + return; + } QUIC_DLOG(INFO) << ENDPOINT << "Received the SUBSCRIBE_OK for " << "subscribe_id = " << message.subscribe_id << " " - << subscribe.full_track_name; - // Copy the Remote Track from session_->active_subscribes_ to - // session_->remote_tracks_. 
- RemoteTrack::Visitor* visitor = it->second.visitor; - auto [track_iter, new_entry] = session_->remote_tracks_.try_emplace( - subscribe.track_alias, subscribe.full_track_name, subscribe.track_alias, - visitor); - if (it->second.forwarding_preference.has_value()) { - if (!track_iter->second.CheckForwardingPreference( - *it->second.forwarding_preference)) { - session_->Error(MoqtError::kProtocolViolation, - "Forwarding preference different in early objects"); - return; - } + << track->full_track_name(); + SubscribeRemoteTrack* subscribe = static_cast<SubscribeRemoteTrack*>(track); + subscribe->OnObjectOrOk(); + // TODO(martinduke): Handle expires field. + // TODO(martinduke): Resize the window based on largest_id. + if (subscribe->visitor() != nullptr) { + subscribe->visitor()->OnReply(track->full_track_name(), message.largest_id, + std::nullopt); } - // TODO: handle expires. - if (visitor != nullptr) { - visitor->OnReply(subscribe.full_track_name, std::nullopt); - } - session_->active_subscribes_.erase(it); + subscribe->OnObjectOrOk(); } void MoqtSession::ControlStream::OnSubscribeErrorMessage( const MoqtSubscribeError& message) { - auto it = session_->active_subscribes_.find(message.subscribe_id); - if (it == session_->active_subscribes_.end()) { - session_->Error(MoqtError::kProtocolViolation, - "Received SUBSCRIBE_ERROR for nonexistent subscribe"); + RemoteTrack* track = session_->RemoteTrackById(message.subscribe_id); + if (track == nullptr) { + QUIC_DLOG(INFO) << ENDPOINT << "Received the SUBSCRIBE_ERROR for " + << "subscribe_id = " << message.subscribe_id + << " but no track exists"; + // Subscription state might have been destroyed for internal reasons. 
return; } - if (it->second.received_object) { + if (track->is_fetch()) { session_->Error(MoqtError::kProtocolViolation, - "Received SUBSCRIBE_ERROR after object"); + "Received SUBSCRIBE_ERROR for a FETCH"); return; } - MoqtSubscribe& subscribe = it->second.message; + if (!track->ErrorIsAllowed()) { + session_->Error(MoqtError::kProtocolViolation, + "Received SUBSCRIBE_ERROR after SUBSCRIBE_OK or objects"); + return; + } QUIC_DLOG(INFO) << ENDPOINT << "Received the SUBSCRIBE_ERROR for " << "subscribe_id = " << message.subscribe_id << " (" - << subscribe.full_track_name << ")" + << track->full_track_name() << ")" << ", error = " << static_cast<int>(message.error_code) << " (" << message.reason_phrase << ")"; - RemoteTrack::Visitor* visitor = it->second.visitor; + SubscribeRemoteTrack* subscribe = static_cast<SubscribeRemoteTrack*>(track); + // Delete secondary references to the track. Preserve the owner + // (subscribe_by_alias_) to get the original subscribe, if needed. Erasing the + // other references now prevents an error due to a duplicate subscription in + // Subscribe(). + session_->upstream_by_id_.erase(subscribe->subscribe_id()); + session_->upstream_by_name_.erase(subscribe->full_track_name()); if (message.error_code == SubscribeErrorCode::kRetryTrackAlias) { // Automatically resubscribe with new alias. 
- session_->remote_track_aliases_[subscribe.full_track_name] = - message.track_alias; - session_->Subscribe(subscribe, visitor); - } else if (visitor != nullptr) { - visitor->OnReply(subscribe.full_track_name, message.reason_phrase); + MoqtSubscribe& subscribe_message = subscribe->GetSubscribe(); + session_->Subscribe(subscribe_message, subscribe->visitor(), + message.track_alias); + } else if (subscribe->visitor() != nullptr) { + subscribe->visitor()->OnReply(subscribe->full_track_name(), std::nullopt, + message.reason_phrase); } - session_->active_subscribes_.erase(it); + session_->subscribe_by_alias_.erase(subscribe->track_alias()); } void MoqtSession::ControlStream::OnUnsubscribeMessage( @@ -1084,11 +1104,37 @@ payload = absl::string_view(partial_object_); } } - auto [full_track_name, visitor] = session_->TrackPropertiesFromAlias( - message, MoqtForwardingPreference::kSubgroup); - if (visitor != nullptr) { - visitor->OnObjectFragment( - full_track_name, + QUICHE_BUG_IF(quic_bug_object_with_no_stream_type, + !parser_.stream_type().has_value()) + << "Object delivered without a stream type"; + // Get a pointer to the upstream state. + RemoteTrack* track = track_.GetIfAvaliable(); + if (track == nullptr) { + track = (*parser_.stream_type() == MoqtDataStreamType::kStreamHeaderFetch) + // message.track_alias is actually a fetch ID for fetches. + ? session_->RemoteTrackById(message.track_alias) + : session_->RemoteTrackByAlias(message.track_alias); + if (track == nullptr) { + stream_->SendStopSending(kResetCodeSubscriptionGone); + // Received object for nonexistent track. + return; + } + track_ = track->weak_ptr(); + } + if (!track->CheckDataStreamType(*parser_.stream_type())) { + session_->Error(MoqtError::kProtocolViolation, + "Received object for a track with a different stream type"); + return; + } + if (!track->InWindow(FullSequence(message.group_id, message.object_id))) { + // This is not an error. It can be the result of a recent SUBSCRIBE_UPDATE. 
+ return; + } + track->OnObjectOrOk(); + SubscribeRemoteTrack* subscribe = static_cast<SubscribeRemoteTrack*>(track); + if (subscribe->visitor() != nullptr) { + subscribe->visitor()->OnObjectFragment( + track->full_track_name(), FullSequence{message.group_id, message.subgroup_id.value_or(0), message.object_id}, message.publisher_priority, message.object_status, payload,
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index c4ae794..628f188 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -29,6 +29,7 @@ #include "quiche/common/platform/api/quiche_export.h" #include "quiche/common/quiche_buffer_allocator.h" #include "quiche/common/quiche_callbacks.h" +#include "quiche/common/quiche_weak_ptr.h" #include "quiche/web_transport/web_transport.h" namespace moqt { @@ -117,24 +118,28 @@ // Subscribe from (start_group, start_object) to the end of the track. bool SubscribeAbsolute( const FullTrackName& name, uint64_t start_group, uint64_t start_object, - RemoteTrack::Visitor* visitor, + SubscribeRemoteTrack::Visitor* visitor, MoqtSubscribeParameters parameters = MoqtSubscribeParameters()); // Subscribe from (start_group, start_object) to the end of end_group. bool SubscribeAbsolute( const FullTrackName& name, uint64_t start_group, uint64_t start_object, - uint64_t end_group, RemoteTrack::Visitor* visitor, + uint64_t end_group, SubscribeRemoteTrack::Visitor* visitor, MoqtSubscribeParameters parameters = MoqtSubscribeParameters()); // Subscribe from (start_group, start_object) to (end_group, end_object). bool SubscribeAbsolute( const FullTrackName& name, uint64_t start_group, uint64_t start_object, - uint64_t end_group, uint64_t end_object, RemoteTrack::Visitor* visitor, + uint64_t end_group, uint64_t end_object, + SubscribeRemoteTrack::Visitor* visitor, MoqtSubscribeParameters parameters = MoqtSubscribeParameters()); bool SubscribeCurrentObject( - const FullTrackName& name, RemoteTrack::Visitor* visitor, + const FullTrackName& name, SubscribeRemoteTrack::Visitor* visitor, MoqtSubscribeParameters parameters = MoqtSubscribeParameters()); bool SubscribeCurrentGroup( - const FullTrackName& name, RemoteTrack::Visitor* visitor, + const FullTrackName& name, SubscribeRemoteTrack::Visitor* visitor, MoqtSubscribeParameters parameters = MoqtSubscribeParameters()); + // Returns false if the subscription is not found. The session immediately + // destroys all subscription state. 
+ void Unsubscribe(const FullTrackName& name); webtransport::Session* session() { return session_; } MoqtSessionCallbacks& callbacks() { return callbacks_; } @@ -285,6 +290,9 @@ MoqtSession* session_; webtransport::Stream* stream_; + // Once the subscribe ID is identified, set it here. + quiche::QuicheWeakPtr<RemoteTrack> track_; + // std::optional<uint64_t> subscribe_id_ = std::nullopt; MoqtDataParser parser_; std::string partial_object_; }; @@ -494,8 +502,10 @@ // is present. void SendControlMessage(quiche::QuicheBuffer message); - // Returns false if the SUBSCRIBE isn't sent. - bool Subscribe(MoqtSubscribe& message, RemoteTrack::Visitor* visitor); + // Returns false if the SUBSCRIBE isn't sent. |provided_track_alias| has a + // value only if this call is due to a SUBSCRIBE_ERROR. + bool Subscribe(MoqtSubscribe& message, SubscribeRemoteTrack::Visitor* visitor, + std::optional<uint64_t> provided_track_alias = std::nullopt); // Opens a new data stream, or queues it if the session is flow control // blocked. @@ -508,13 +518,9 @@ // Returns false if creation failed. [[nodiscard]] bool OpenDataStream(std::shared_ptr<PublishedFetch> fetch); - // Get FullTrackName and visitor for a subscribe_id and track_alias. Returns - // an empty FullTrackName tuple and nullptr if not present. If the caller has - // information about the track's forwarding preference, it can be passed via - // |forwarding_preference| so that it can be stored in RemoteTrack. - std::pair<FullTrackName, RemoteTrack::Visitor*> TrackPropertiesFromAlias( - const MoqtObject& message, - std::optional<MoqtForwardingPreference> forwarding_preference); + SubscribeRemoteTrack* RemoteTrackByAlias(uint64_t track_alias); + RemoteTrack* RemoteTrackById(uint64_t subscribe_id); + RemoteTrack* RemoteTrackByName(const FullTrackName& name); // Checks that a subscribe ID from a SUBSCRIBE or FETCH is valid, and throws // a session error if is not. 
@@ -557,11 +563,22 @@ bool peer_supports_object_ack_ = false; std::string error_; + // Upstream SUBSCRIBE state. // All the tracks the session is subscribed to, indexed by track_alias. - absl::flat_hash_map<uint64_t, RemoteTrack> remote_tracks_; - // Look up aliases for remote tracks by name - absl::flat_hash_map<FullTrackName, uint64_t> remote_track_aliases_; + absl::flat_hash_map<uint64_t, std::unique_ptr<SubscribeRemoteTrack>> + subscribe_by_alias_; + // Upstream SUBSCRIBEs indexed by subscribe_id. + // TODO(martinduke): Add fetches to this. + absl::flat_hash_map<uint64_t, RemoteTrack*> upstream_by_id_; + // The application only has track names, so this allows MoqtSession to + // quickly find what it's looking for. Also allows a quick check for duplicate + // subscriptions. + absl::flat_hash_map<FullTrackName, RemoteTrack*> upstream_by_name_; uint64_t next_remote_track_alias_ = 0; + // The next subscribe ID that the local endpoint can send. + uint64_t next_subscribe_id_ = 0; + // The maximum subscribe ID that the local endpoint can send. + uint64_t peer_max_subscribe_id_ = 0; // All open incoming subscriptions, indexed by track name, used to check for // duplicates. @@ -585,21 +602,6 @@ absl::flat_hash_map<uint64_t, std::shared_ptr<PublishedFetch>> incoming_fetches_; - // Indexed by subscribe_id. - struct ActiveSubscribe { - MoqtSubscribe message; - RemoteTrack::Visitor* visitor; - // The forwarding preference of the first received object, which all - // subsequent objects must match. - std::optional<MoqtForwardingPreference> forwarding_preference; - // If true, an object has arrived for the subscription before SUBSCRIBE_OK - // arrived. - bool received_object = false; - }; - // Outgoing SUBSCRIBEs that have not received SUBSCRIBE_OK or SUBSCRIBE_ERROR. - absl::flat_hash_map<uint64_t, ActiveSubscribe> active_subscribes_; - uint64_t next_subscribe_id_ = 0; - // Monitoring interfaces for expected incoming subscriptions. 
absl::flat_hash_map<FullTrackName, MoqtPublishingMonitorInterface*> monitoring_interfaces_for_published_tracks_; @@ -613,12 +615,10 @@ // parameter, and other checks have changed/been disabled. MoqtRole peer_role_ = MoqtRole::kPubSub; - // The maximum subscribe ID that the local endpoint can send. - uint64_t peer_max_subscribe_id_ = 0; - // The maximum subscribe ID sent to the peer. - uint64_t local_max_subscribe_id_ = 0; // The minimum subscribe ID the peer can use that is monotonically increasing. uint64_t next_incoming_subscribe_id_ = 0; + // The maximum subscribe ID sent to the peer. + uint64_t local_max_subscribe_id_ = 0; // Must be last. Token used to make sure that the streams do not call into // the session when the session has already been destroyed.
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 2e7aed0..3298cfc 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -24,7 +24,6 @@ #include "quiche/quic/moqt/moqt_parser.h" #include "quiche/quic/moqt/moqt_priority.h" #include "quiche/quic/moqt/moqt_publisher.h" -#include "quiche/quic/moqt/moqt_track.h" #include "quiche/quic/moqt/test_tools/moqt_framer_utils.h" #include "quiche/quic/moqt/test_tools/moqt_session_peer.h" #include "quiche/quic/moqt/tools/moqt_mock_visitor.h" @@ -50,6 +49,21 @@ constexpr webtransport::StreamId kIncomingUniStreamId = 15; constexpr webtransport::StreamId kOutgoingUniStreamId = 14; +MoqtSubscribe DefaultSubscribe() { + MoqtSubscribe subscribe = { + /*subscribe_id=*/1, + /*track_alias=*/2, + /*full_track_name=*/FullTrackName("foo", "bar"), + /*subscriber_priority=*/0x80, + /*group_order=*/std::nullopt, + /*start_group=*/0, + /*start_object=*/0, + /*end_group=*/std::nullopt, + /*end_object=*/std::nullopt, + }; + return subscribe; +} + static std::shared_ptr<MockTrackPublisher> SetupPublisher( FullTrackName track_name, MoqtForwardingPreference forwarding_preference, FullSequence largest_sequence) { @@ -195,18 +209,7 @@ } TEST_F(MoqtSessionTest, AddLocalTrack) { - MoqtSubscribe request = { - /*subscribe_id=*/1, - /*track_alias=*/2, - /*full_track_name=*/FullTrackName({"foo", "bar"}), - /*subscriber_priority=*/0x80, - /*group_order=*/std::nullopt, - /*start_group=*/0, - /*start_object=*/0, - /*end_group=*/std::nullopt, - /*end_object=*/std::nullopt, - /*parameters=*/MoqtSubscribeParameters(), - }; + MoqtSubscribe request = DefaultSubscribe(); webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); @@ -298,26 +301,13 @@ .WillRepeatedly(Return(FullSequence(10, 20))); publisher_.Add(track); - // Peer subscribes to (0, 0) - MoqtSubscribe request = { - /*subscribe_id=*/1, - /*track_alias=*/2, - /*full_track_name=*/FullTrackName({"foo", "bar"}), - /*subscriber_priority=*/0x80, - /*group_order=*/std::nullopt, - /*start_group=*/0, - /*start_object=*/0, - 
/*end_group=*/std::nullopt, - /*end_object=*/std::nullopt, - /*parameters=*/MoqtSubscribeParameters(), - }; webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); EXPECT_CALL( mock_stream, Writev(ControlMessageOfType(MoqtMessageType::kSubscribeError), _)); - stream_input->OnSubscribeMessage(request); + stream_input->OnSubscribeMessage(DefaultSubscribe()); } TEST_F(MoqtSessionTest, TwoSubscribesForTrack) { @@ -439,19 +429,7 @@ } TEST_F(MoqtSessionTest, SubscribeIdNotIncreasing) { - // Peer subscribes to (0, 0) - MoqtSubscribe request = { - /*subscribe_id=*/1, - /*track_alias=*/2, - /*full_track_name=*/FullTrackName({"foo", "bar"}), - /*subscriber_priority=*/0x80, - /*group_order=*/std::nullopt, - /*start_group=*/0, - /*start_object=*/0, - /*end_group=*/std::nullopt, - /*end_object=*/std::nullopt, - /*parameters=*/MoqtSubscribeParameters(), - }; + MoqtSubscribe request = DefaultSubscribe(); webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); @@ -475,7 +453,7 @@ TEST_F(MoqtSessionTest, TooManySubscribes) { MoqtSessionPeer::set_next_subscribe_id(&session_, kDefaultInitialMaxSubscribeId); - MockRemoteTrackVisitor remote_track_visitor; + MockSubscribeRemoteTrackVisitor remote_track_visitor; webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); @@ -488,11 +466,26 @@ &remote_track_visitor)); } +TEST_F(MoqtSessionTest, SubscribeDuplicateTrackName) { + MockSubscribeRemoteTrackVisitor remote_track_visitor; + webtransport::test::MockStream mock_stream; + std::unique_ptr<MoqtControlParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + EXPECT_CALL(mock_session_, GetStreamById(_)) + 
.WillRepeatedly(Return(&mock_stream)); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _)); + EXPECT_TRUE(session_.SubscribeCurrentGroup(FullTrackName("foo", "bar"), + &remote_track_visitor)); + EXPECT_FALSE(session_.SubscribeCurrentGroup(FullTrackName("foo", "bar"), + &remote_track_visitor)); +} + TEST_F(MoqtSessionTest, SubscribeWithOk) { webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); - MockRemoteTrackVisitor remote_track_visitor; + MockSubscribeRemoteTrackVisitor remote_track_visitor; EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); EXPECT_CALL(mock_stream, Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _)); @@ -503,8 +496,9 @@ /*subscribe_id=*/0, /*expires=*/quic::QuicTimeDelta::FromMilliseconds(0), }; - EXPECT_CALL(remote_track_visitor, OnReply(_, _)) + EXPECT_CALL(remote_track_visitor, OnReply(_, _, _)) .WillOnce([&](const FullTrackName& ftn, + std::optional<FullSequence> /*largest_id*/, std::optional<absl::string_view> error_message) { EXPECT_EQ(ftn, FullTrackName("foo", "bar")); EXPECT_FALSE(error_message.has_value()); @@ -515,7 +509,7 @@ TEST_F(MoqtSessionTest, MaxSubscribeIdChangesResponse) { MoqtSessionPeer::set_next_subscribe_id(&session_, kDefaultInitialMaxSubscribeId + 1); - MockRemoteTrackVisitor remote_track_visitor; + MockSubscribeRemoteTrackVisitor remote_track_visitor; EXPECT_FALSE(session_.SubscribeCurrentGroup(FullTrackName("foo", "bar"), &remote_track_visitor)); MoqtMaxSubscribeId max_subscribe_id = { @@ -590,7 +584,7 @@ webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); - MockRemoteTrackVisitor remote_track_visitor; + MockSubscribeRemoteTrackVisitor remote_track_visitor; EXPECT_CALL(mock_session_, 
GetStreamById(_)).WillOnce(Return(&mock_stream)); EXPECT_CALL(mock_stream, Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _)); @@ -603,8 +597,9 @@ /*reason_phrase=*/"deadbeef", /*track_alias=*/2, }; - EXPECT_CALL(remote_track_visitor, OnReply(_, _)) + EXPECT_CALL(remote_track_visitor, OnReply(_, _, _)) .WillOnce([&](const FullTrackName& ftn, + std::optional<FullSequence> /*largest_id*/, std::optional<absl::string_view> error_message) { EXPECT_EQ(ftn, FullTrackName("foo", "bar")); EXPECT_EQ(*error_message, "deadbeef"); @@ -612,6 +607,21 @@ stream_input->OnSubscribeErrorMessage(error); } +TEST_F(MoqtSessionTest, Unsubscribe) { + webtransport::test::MockStream mock_stream; + std::unique_ptr<MoqtControlParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + MockSubscribeRemoteTrackVisitor remote_track_visitor; + MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), + &remote_track_visitor); + EXPECT_CALL(mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kUnsubscribe), _)); + EXPECT_NE(MoqtSessionPeer::remote_track(&session_, 2), nullptr); + session_.Unsubscribe(FullTrackName("foo", "bar")); + // State is destroyed. 
+ EXPECT_EQ(MoqtSessionPeer::remote_track(&session_, 2), nullptr); +} + TEST_F(MoqtSessionTest, ReplyToAnnounce) { webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtControlParserVisitor> stream_input = @@ -630,10 +640,10 @@ } TEST_F(MoqtSessionTest, IncomingObject) { - MockRemoteTrackVisitor visitor_; + MockSubscribeRemoteTrackVisitor visitor_; FullTrackName ftn("foo", "bar"); std::string payload = "deadbeef"; - MoqtSessionPeer::CreateRemoteTrack(&session_, ftn, &visitor_, 2); + MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), &visitor_); MoqtObject object = { /*track_alias=*/2, /*group_sequence=*/0, @@ -645,7 +655,8 @@ }; webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtDataParserVisitor> object_stream = - MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream); + MoqtSessionPeer::CreateIncomingDataStream( + &session_, &mock_stream, MoqtDataStreamType::kStreamHeaderSubgroup); EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _)).Times(1); EXPECT_CALL(mock_stream, GetStreamId()) @@ -654,10 +665,10 @@ } TEST_F(MoqtSessionTest, IncomingPartialObject) { - MockRemoteTrackVisitor visitor_; + MockSubscribeRemoteTrackVisitor visitor_; FullTrackName ftn("foo", "bar"); std::string payload = "deadbeef"; - MoqtSessionPeer::CreateRemoteTrack(&session_, ftn, &visitor_, 2); + MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), &visitor_); MoqtObject object = { /*track_alias=*/2, /*group_sequence=*/0, @@ -669,7 +680,8 @@ }; webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtDataParserVisitor> object_stream = - MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream); + MoqtSessionPeer::CreateIncomingDataStream( + &session_, &mock_stream, MoqtDataStreamType::kStreamHeaderSubgroup); EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _)).Times(1); EXPECT_CALL(mock_stream, GetStreamId()) @@ -683,10 +695,10 @@ parameters.deliver_partial_objects = true; MoqtSession 
session(&mock_session_, parameters, session_callbacks_.AsSessionCallbacks()); - MockRemoteTrackVisitor visitor_; + MockSubscribeRemoteTrackVisitor visitor_; FullTrackName ftn("foo", "bar"); std::string payload = "deadbeef"; - MoqtSessionPeer::CreateRemoteTrack(&session, ftn, &visitor_, 2); + MoqtSessionPeer::CreateRemoteTrack(&session, DefaultSubscribe(), &visitor_); MoqtObject object = { /*track_alias=*/2, /*group_sequence=*/0, @@ -698,7 +710,8 @@ }; webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtDataParserVisitor> object_stream = - MoqtSessionPeer::CreateIncomingDataStream(&session, &mock_stream); + MoqtSessionPeer::CreateIncomingDataStream( + &session, &mock_stream, MoqtDataStreamType::kStreamHeaderSubgroup); EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _)).Times(2); EXPECT_CALL(mock_stream, GetStreamId()) @@ -708,21 +721,10 @@ } TEST_F(MoqtSessionTest, ObjectBeforeSubscribeOk) { - MockRemoteTrackVisitor visitor_; + MockSubscribeRemoteTrackVisitor visitor_; FullTrackName ftn("foo", "bar"); std::string payload = "deadbeef"; - MoqtSubscribe subscribe = { - /*subscribe_id=*/1, - /*track_alias=*/2, - /*full_track_name=*/ftn, - /*subscriber_priority=*/0x80, - /*group_order=*/std::nullopt, - /*start_group=*/0, - /*start_object=*/0, - /*end_group=*/std::nullopt, - /*end_object=*/std::nullopt, - }; - MoqtSessionPeer::AddActiveSubscribe(&session_, 1, subscribe, &visitor_); + MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), &visitor_); MoqtObject object = { /*track_alias=*/2, /*group_sequence=*/0, @@ -734,7 +736,8 @@ }; webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtDataParserVisitor> object_stream = - MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream); + MoqtSessionPeer::CreateIncomingDataStream( + &session_, &mock_stream, MoqtDataStreamType::kStreamHeaderSubgroup); EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _)) .WillOnce([&](const FullTrackName& full_track_name, FullSequence sequence, 
@@ -758,26 +761,15 @@ webtransport::test::MockStream mock_control_stream; std::unique_ptr<MoqtControlParserVisitor> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream); - EXPECT_CALL(visitor_, OnReply(_, _)).Times(1); + EXPECT_CALL(visitor_, OnReply(_, _, _)).Times(1); control_stream->OnSubscribeOkMessage(ok); } TEST_F(MoqtSessionTest, ObjectBeforeSubscribeError) { - MockRemoteTrackVisitor visitor; + MockSubscribeRemoteTrackVisitor visitor; FullTrackName ftn("foo", "bar"); std::string payload = "deadbeef"; - MoqtSubscribe subscribe = { - /*subscribe_id=*/1, - /*track_alias=*/2, - /*full_track_name=*/ftn, - /*subscriber_priority=*/0x80, - /*group_order=*/std::nullopt, - /*start_group=*/0, - /*start_object=*/0, - /*end_group=*/std::nullopt, - /*end_object=*/std::nullopt, - }; - MoqtSessionPeer::AddActiveSubscribe(&session_, 1, subscribe, &visitor); + MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), &visitor); MoqtObject object = { /*track_alias=*/2, /*group_sequence=*/0, @@ -789,7 +781,8 @@ }; webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtDataParserVisitor> object_stream = - MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream); + MoqtSessionPeer::CreateIncomingDataStream( + &session_, &mock_stream, MoqtDataStreamType::kStreamHeaderSubgroup); EXPECT_CALL(visitor, OnObjectFragment(_, _, _, _, _, _)) .WillOnce([&](const FullTrackName& full_track_name, FullSequence sequence, @@ -813,123 +806,53 @@ webtransport::test::MockStream mock_control_stream; std::unique_ptr<MoqtControlParserVisitor> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream); - EXPECT_CALL(mock_session_, - CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation), - "Received SUBSCRIBE_ERROR after object")) + EXPECT_CALL( + mock_session_, + CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation), + "Received SUBSCRIBE_ERROR after SUBSCRIBE_OK or objects")) .Times(1); 
control_stream->OnSubscribeErrorMessage(subscribe_error); } -TEST_F(MoqtSessionTest, TwoEarlyObjectsDifferentForwarding) { - MockRemoteTrackVisitor visitor; - FullTrackName ftn("foo", "bar"); - std::string payload = "deadbeef"; - MoqtSubscribe subscribe = { - /*subscribe_id=*/1, - /*track_alias=*/2, - /*full_track_name=*/ftn, - /*subscriber_priority=*/0x80, - /*group_order=*/std::nullopt, - /*start_group=*/0, - /*start_object=*/0, - /*end_group=*/std::nullopt, - /*end_object=*/std::nullopt, - }; - MoqtSessionPeer::AddActiveSubscribe(&session_, 1, subscribe, &visitor); - MoqtObject object = { - /*track_alias=*/2, - /*group_sequence=*/0, - /*object_sequence=*/0, - /*publisher_priority=*/0, - /*object_status=*/MoqtObjectStatus::kNormal, - /*subgroup_id=*/0, - /*payload_length=*/8, - }; - webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtDataParserVisitor> object_stream = - MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream); +TEST_F(MoqtSessionTest, SubscribeErrorWithTrackAlias) { + MockSubscribeRemoteTrackVisitor visitor; + MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), &visitor); - EXPECT_CALL(visitor, OnObjectFragment(_, _, _, _, _, _)) - .WillOnce([&](const FullTrackName& full_track_name, FullSequence sequence, - MoqtPriority publisher_priority, MoqtObjectStatus status, - absl::string_view payload, bool end_of_message) { - EXPECT_EQ(full_track_name, ftn); - EXPECT_EQ(sequence.group, object.group_id); - EXPECT_EQ(sequence.object, object.object_id); - }); - EXPECT_CALL(mock_stream, GetStreamId()) - .WillRepeatedly(Return(kIncomingUniStreamId)); - object_stream->OnObjectMessage(object, payload, true); - char datagram[] = {0x01, 0x02, 0x00, 0x00, 0x00, 0x08, 0x64, - 0x65, 0x61, 0x64, 0x62, 0x65, 0x65, 0x66}; - EXPECT_CALL(mock_session_, - CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation), - "Forwarding preference changes mid-track")) + // SUBSCRIBE_ERROR arrives + MoqtSubscribeError subscribe_error = { 
+ /*subscribe_id=*/1, + /*error_code=*/SubscribeErrorCode::kRetryTrackAlias, + /*reason_phrase=*/"foo", + /*track_alias =*/3, + }; + webtransport::test::MockStream mock_control_stream; + std::unique_ptr<MoqtControlParserVisitor> control_stream = + MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream); + EXPECT_CALL(mock_control_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _)) .Times(1); - session_.OnDatagramReceived(absl::string_view(datagram, sizeof(datagram))); + control_stream->OnSubscribeErrorMessage(subscribe_error); } -TEST_F(MoqtSessionTest, EarlyObjectForwardingDoesNotMatchTrack) { - MockRemoteTrackVisitor visitor; - FullTrackName ftn("foo", "bar"); - std::string payload = "deadbeef"; - MoqtSubscribe subscribe = { +TEST_F(MoqtSessionTest, SubscribeErrorWithBadTrackAlias) { + MockSubscribeRemoteTrackVisitor visitor; + MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), &visitor); + + // SUBSCRIBE_ERROR arrives + MoqtSubscribeError subscribe_error = { /*subscribe_id=*/1, - /*track_alias=*/2, - /*full_track_name=*/ftn, - /*subscriber_priority=*/0x80, - /*group_order=*/std::nullopt, - /*start_group=*/0, - /*start_object=*/0, - /*end_group=*/std::nullopt, - /*end_object=*/std::nullopt, - }; - MoqtSessionPeer::AddActiveSubscribe(&session_, 1, subscribe, &visitor); - MoqtObject object = { - /*track_alias=*/2, - /*group_sequence=*/0, - /*object_sequence=*/0, - /*publisher_priority=*/0, - /*object_status=*/MoqtObjectStatus::kNormal, - /*subgroup_id=*/0, - /*payload_length=*/8, - }; - webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtDataParserVisitor> object_stream = - MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream); - - EXPECT_CALL(visitor, OnObjectFragment(_, _, _, _, _, _)) - .WillOnce([&](const FullTrackName& full_track_name, FullSequence sequence, - MoqtPriority publisher_priority, MoqtObjectStatus status, - absl::string_view payload, bool end_of_message) { - 
EXPECT_EQ(full_track_name, ftn); - EXPECT_EQ(sequence.group, object.group_id); - EXPECT_EQ(sequence.object, object.object_id); - }); - EXPECT_CALL(mock_stream, GetStreamId()) - .WillRepeatedly(Return(kIncomingUniStreamId)); - object_stream->OnObjectMessage(object, payload, true); - - MoqtSessionPeer::CreateRemoteTrack(&session_, ftn, &visitor, 2); - // The track already exists, and has a different forwarding preference. - MoqtSessionPeer::remote_track(&session_, 2) - .CheckForwardingPreference(MoqtForwardingPreference::kDatagram); - - // SUBSCRIBE_OK arrives - MoqtSubscribeOk ok = { - /*subscribe_id=*/1, - /*expires=*/quic::QuicTimeDelta::FromMilliseconds(0), - /*group_order=*/MoqtDeliveryOrder::kAscending, - /*largest_id=*/std::nullopt, + /*error_code=*/SubscribeErrorCode::kRetryTrackAlias, + /*reason_phrase=*/"foo", + /*track_alias =*/2, }; webtransport::test::MockStream mock_control_stream; std::unique_ptr<MoqtControlParserVisitor> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream); EXPECT_CALL(mock_session_, CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation), - "Forwarding preference different in early objects")) + "Provided track alias already in use")) .Times(1); - control_stream->OnSubscribeOkMessage(ok); + control_stream->OnSubscribeErrorMessage(subscribe_error); } TEST_F(MoqtSessionTest, CreateOutgoingDataStreamAndSend) { @@ -1428,10 +1351,10 @@ } TEST_F(MoqtSessionTest, ReceiveDatagram) { - MockRemoteTrackVisitor visitor_; + MockSubscribeRemoteTrackVisitor visitor_; FullTrackName ftn("foo", "bar"); std::string payload = "deadbeef"; - MoqtSessionPeer::CreateRemoteTrack(&session_, ftn, &visitor_, 2); + MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), &visitor_); MoqtObject object = { /*track_alias=*/2, /*group_sequence=*/0, @@ -1452,11 +1375,10 @@ session_.OnDatagramReceived(absl::string_view(datagram, sizeof(datagram))); } -TEST_F(MoqtSessionTest, ForwardingPreferenceMismatch) { - 
MockRemoteTrackVisitor visitor_; - FullTrackName ftn("foo", "bar"); +TEST_F(MoqtSessionTest, DataStreamTypeMismatch) { + MockSubscribeRemoteTrackVisitor visitor_; std::string payload = "deadbeef"; - MoqtSessionPeer::CreateRemoteTrack(&session_, ftn, &visitor_, 2); + MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), &visitor_); MoqtObject object = { /*track_alias=*/2, /*group_sequence=*/0, @@ -1468,7 +1390,8 @@ }; webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtDataParserVisitor> object_stream = - MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream); + MoqtSessionPeer::CreateIncomingDataStream( + &session_, &mock_stream, MoqtDataStreamType::kStreamHeaderSubgroup); EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _)).Times(1); EXPECT_CALL(mock_stream, GetStreamId()) @@ -1478,11 +1401,46 @@ 0x65, 0x61, 0x64, 0x62, 0x65, 0x65, 0x66}; EXPECT_CALL(mock_session_, CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation), - "Forwarding preference changes mid-track")) + "Received DATAGRAM for non-datagram track")) .Times(1); session_.OnDatagramReceived(absl::string_view(datagram, sizeof(datagram))); } +TEST_F(MoqtSessionTest, StreamObjectOutOfWindow) { + MockSubscribeRemoteTrackVisitor visitor_; + std::string payload = "deadbeef"; + MoqtSubscribe subscribe = DefaultSubscribe(); + subscribe.start_group = 1; + MoqtSessionPeer::CreateRemoteTrack(&session_, subscribe, &visitor_); + MoqtObject object = { + /*track_alias=*/2, + /*group_sequence=*/0, + /*object_sequence=*/0, + /*publisher_priority=*/0, + /*object_status=*/MoqtObjectStatus::kNormal, + /*subgroup_id=*/0, + /*payload_length=*/8, + }; + webtransport::test::MockStream mock_stream; + std::unique_ptr<MoqtDataParserVisitor> object_stream = + MoqtSessionPeer::CreateIncomingDataStream( + &session_, &mock_stream, MoqtDataStreamType::kStreamHeaderSubgroup); + EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _)).Times(0); + object_stream->OnObjectMessage(object, 
payload, true); +} + +TEST_F(MoqtSessionTest, DatagramOutOfWindow) { + MockSubscribeRemoteTrackVisitor visitor_; + std::string payload = "deadbeef"; + MoqtSubscribe subscribe = DefaultSubscribe(); + subscribe.start_group = 1; + MoqtSessionPeer::CreateRemoteTrack(&session_, subscribe, &visitor_); + char datagram[] = {0x01, 0x02, 0x00, 0x00, 0x80, 0x08, 0x64, + 0x65, 0x61, 0x64, 0x62, 0x65, 0x65, 0x66}; + EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _)).Times(0); + session_.OnDatagramReceived(absl::string_view(datagram, sizeof(datagram))); +} + TEST_F(MoqtSessionTest, AnnounceToPublisher) { MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kPublisher); testing::MockFunction<void( @@ -1496,18 +1454,6 @@ TEST_F(MoqtSessionTest, SubscribeFromPublisher) { MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kPublisher); - MoqtSubscribe request = { - /*subscribe_id=*/1, - /*track_alias=*/2, - /*full_track_name=*/FullTrackName({"foo", "bar"}), - /*subscriber_priority=*/0x80, - /*group_order=*/std::nullopt, - /*start_group=*/0, - /*start_object=*/0, - /*end_group=*/std::nullopt, - /*end_object=*/std::nullopt, - /*parameters=*/MoqtSubscribeParameters(), - }; webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); @@ -1517,7 +1463,7 @@ "Received SUBSCRIBE from publisher")) .Times(1); EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_)).Times(1); - stream_input->OnSubscribeMessage(request); + stream_input->OnSubscribeMessage(DefaultSubscribe()); } TEST_F(MoqtSessionTest, AnnounceFromSubscriber) {
diff --git a/quiche/quic/moqt/moqt_track.cc b/quiche/quic/moqt/moqt_track.cc index ee85664..07f44d7 100644 --- a/quiche/quic/moqt/moqt_track.cc +++ b/quiche/quic/moqt/moqt_track.cc
@@ -4,18 +4,15 @@ #include "quiche/quic/moqt/moqt_track.h" -#include <cstdint> - #include "quiche/quic/moqt/moqt_messages.h" namespace moqt { -bool RemoteTrack::CheckForwardingPreference( - MoqtForwardingPreference preference) { - if (forwarding_preference_.has_value()) { - return forwarding_preference_.value() == preference; +bool RemoteTrack::CheckDataStreamType(MoqtDataStreamType type) { + if (data_stream_type_.has_value()) { + return data_stream_type_.value() == type; } - forwarding_preference_ = preference; + data_stream_type_ = type; return true; }
diff --git a/quiche/quic/moqt/moqt_track.h b/quiche/quic/moqt/moqt_track.h index 0da047a..e47464a 100644 --- a/quiche/quic/moqt/moqt_track.h +++ b/quiche/quic/moqt/moqt_track.h
@@ -6,13 +6,16 @@ #define QUICHE_QUIC_MOQT_MOQT_SUBSCRIPTION_H_ #include <cstdint> +#include <memory> #include <optional> #include "absl/strings/string_view.h" #include "quiche/quic/core/quic_time.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_priority.h" +#include "quiche/quic/moqt/moqt_subscribe_windows.h" #include "quiche/common/quiche_callbacks.h" +#include "quiche/common/quiche_weak_ptr.h" namespace moqt { @@ -20,9 +23,65 @@ quiche::MultiUseCallback<void(uint64_t group_id, uint64_t object_id, quic::QuicTimeDelta delta_from_deadline)>; -// A track on the peer to which the session has subscribed. +// State common to both SUBSCRIBE and FETCH upstream. class RemoteTrack { public: + RemoteTrack(const FullTrackName& full_track_name, uint64_t id, + SubscribeWindow window) + : full_track_name_(full_track_name), + subscribe_id_(id), + window_(window), + weak_ptr_factory_(this) {} + virtual ~RemoteTrack() = default; + + const FullTrackName& full_track_name() const { return full_track_name_; } + // If FETCH_ERROR or SUBSCRIBE_ERROR arrives after OK or an object, it is a + // protocol violation. + virtual void OnObjectOrOk() { error_is_allowed_ = false; } + bool ErrorIsAllowed() const { return error_is_allowed_; } + + // When called while processing the first object in the track, sets the + // data stream type to the value indicated by the incoming encoding. + // Otherwise, returns true if the incoming object does not violate the rule + // that the type is consistent. + bool CheckDataStreamType(MoqtDataStreamType type); + + bool is_fetch() const { + return data_stream_type_.has_value() && + *data_stream_type_ == MoqtDataStreamType::kStreamHeaderFetch; + } + + uint64_t subscribe_id() const { return subscribe_id_; } + + // Is the object one that was requested?
+ bool InWindow(FullSequence sequence) const { + return window_.InWindow(sequence); + } + + void ChangeWindow(const SubscribeWindow& window) { window_ = window; } + + quiche::QuicheWeakPtr<RemoteTrack> weak_ptr() { + return weak_ptr_factory_.Create(); + } + + private: + const FullTrackName full_track_name_; + const uint64_t subscribe_id_; + SubscribeWindow window_; + std::optional<MoqtDataStreamType> data_stream_type_; + // If false, an object or OK message has been received, so any ERROR message + // is a protocol violation. + bool error_is_allowed_ = true; + + // Must be last. + quiche::QuicheWeakPtrFactory<RemoteTrack> weak_ptr_factory_; +}; + +// A track on the peer to which the session has subscribed. +class SubscribeRemoteTrack : public RemoteTrack { + public: + // TODO: Separate this out (as it's used by the application) and give it a + // name like MoqtTrackSubscriber. class Visitor { public: virtual ~Visitor() = default; @@ -31,6 +90,7 @@ // automatically retry. virtual void OnReply( const FullTrackName& full_track_name, + std::optional<FullSequence> largest_id, std::optional<absl::string_view> error_reason_phrase) = 0; // Called when the subscription process is far enough that it is possible to // send OBJECT_ACK messages; provides a callback to do so.
The callback is @@ -43,35 +103,34 @@ absl::string_view object, bool end_of_message) = 0; // TODO(martinduke): Add final sequence numbers }; - RemoteTrack(const FullTrackName& full_track_name, uint64_t track_alias, - Visitor* visitor) - : full_track_name_(full_track_name), - track_alias_(track_alias), - visitor_(visitor) {} + SubscribeRemoteTrack(const MoqtSubscribe& subscribe, Visitor* visitor) + : RemoteTrack(subscribe.full_track_name, subscribe.subscribe_id, + SubscribeWindow(subscribe.start_group.value_or(0), + subscribe.start_object.value_or(0), + subscribe.end_group.value_or(UINT64_MAX), + subscribe.end_object.value_or(UINT64_MAX))), + track_alias_(subscribe.track_alias), + visitor_(visitor), + subscribe_(std::make_unique<MoqtSubscribe>(subscribe)) {} - const FullTrackName& full_track_name() { return full_track_name_; } - + void OnObjectOrOk() override { + subscribe_.reset(); // No SUBSCRIBE_ERROR, no need to store this anymore. + RemoteTrack::OnObjectOrOk(); + } uint64_t track_alias() const { return track_alias_; } - Visitor* visitor() { return visitor_; } - - // When called while processing the first object in the track, sets the - // forwarding preference to the value indicated by the incoming encoding. - // Otherwise, returns true if the incoming object does not violate the rule - // that the preference is consistent. - bool CheckForwardingPreference(MoqtForwardingPreference preference); - - std::optional<MoqtForwardingPreference> forwarding_preference() const { - return forwarding_preference_; + MoqtSubscribe& GetSubscribe() { + // Must not be called after OnObjectOrOk(), which frees subscribe_. No need + // to null subscribe_ here: this object will soon be destroyed. + return *subscribe_; } private: - // TODO: There is no accounting for the number of outstanding subscribes, - // because we can't match track names to individual subscribes.
- const FullTrackName full_track_name_; const uint64_t track_alias_; Visitor* visitor_; - std::optional<MoqtForwardingPreference> forwarding_preference_; + // For convenience, store the subscribe message if it has to be re-sent with + // a new track alias. + std::unique_ptr<MoqtSubscribe> subscribe_; }; } // namespace moqt
diff --git a/quiche/quic/moqt/moqt_track_test.cc b/quiche/quic/moqt/moqt_track_test.cc index cd62dd8..a5717ce 100644 --- a/quiche/quic/moqt/moqt_track_test.cc +++ b/quiche/quic/moqt/moqt_track_test.cc
@@ -4,6 +4,9 @@ #include "quiche/quic/moqt/moqt_track.h" +#include <optional> + +#include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/tools/moqt_mock_visitor.h" #include "quiche/quic/platform/api/quic_test.h" @@ -11,27 +14,54 @@ namespace test { -class RemoteTrackTest : public quic::test::QuicTest { +class SubscribeRemoteTrackTest : public quic::test::QuicTest { public: - RemoteTrackTest() - : track_(FullTrackName("foo", "bar"), /*track_alias=*/5, &visitor_) {} - RemoteTrack track_; - MockRemoteTrackVisitor visitor_; + SubscribeRemoteTrackTest() : track_(subscribe_, &visitor_) {} + + MockSubscribeRemoteTrackVisitor visitor_; + MoqtSubscribe subscribe_ = { + /*subscribe_id=*/1, + /*track_alias=*/2, + /*full_track_name=*/FullTrackName("foo", "bar"), + /*subscriber_priority=*/128, + /*group_order=*/std::nullopt, + /*start_group=*/2, + /*start_object=*/0, + /*end_group=*/std::nullopt, + /*end_object=*/std::nullopt, + /*parameters=*/MoqtSubscribeParameters(), + }; + SubscribeRemoteTrack track_; }; -TEST_F(RemoteTrackTest, Queries) { +TEST_F(SubscribeRemoteTrackTest, Queries) { EXPECT_EQ(track_.full_track_name(), FullTrackName("foo", "bar")); - EXPECT_EQ(track_.track_alias(), 5); + EXPECT_EQ(track_.subscribe_id(), 1); + EXPECT_EQ(track_.track_alias(), 2); EXPECT_EQ(track_.visitor(), &visitor_); + EXPECT_FALSE(track_.is_fetch()); } -TEST_F(RemoteTrackTest, UpdateForwardingPreference) { +TEST_F(SubscribeRemoteTrackTest, UpdateDataStreamType) { EXPECT_TRUE( - track_.CheckForwardingPreference(MoqtForwardingPreference::kSubgroup)); + track_.CheckDataStreamType(MoqtDataStreamType::kStreamHeaderSubgroup)); EXPECT_TRUE( - track_.CheckForwardingPreference(MoqtForwardingPreference::kSubgroup)); - EXPECT_FALSE( - track_.CheckForwardingPreference(MoqtForwardingPreference::kDatagram)); + track_.CheckDataStreamType(MoqtDataStreamType::kStreamHeaderSubgroup)); + EXPECT_FALSE(track_.CheckDataStreamType(MoqtDataStreamType::kObjectDatagram)); +} + +TEST_F(SubscribeRemoteTrackTest, AllowError) { + EXPECT_TRUE(track_.ErrorIsAllowed()); +
EXPECT_EQ(track_.GetSubscribe().subscribe_id, subscribe_.subscribe_id); + track_.OnObjectOrOk(); + EXPECT_FALSE(track_.ErrorIsAllowed()); +} + +TEST_F(SubscribeRemoteTrackTest, Windows) { + EXPECT_TRUE(track_.InWindow(FullSequence(2, 0))); + SubscribeWindow new_window(2, 1); + track_.ChangeWindow(new_window); + EXPECT_FALSE(track_.InWindow(FullSequence(2, 0))); } // TODO: Write test for GetStreamForSequence.
diff --git a/quiche/quic/moqt/test_tools/moqt_session_peer.h b/quiche/quic/moqt/test_tools/moqt_session_peer.h index d41edd6..5426ef6 100644 --- a/quiche/quic/moqt/test_tools/moqt_session_peer.h +++ b/quiche/quic/moqt/test_tools/moqt_session_peer.h
@@ -23,6 +23,13 @@ namespace moqt::test { +class MoqtDataParserPeer { + public: + static void SetType(MoqtDataParser* parser, MoqtDataStreamType type) { + parser->type_ = type; + } +}; + class MoqtSessionPeer { public: static constexpr webtransport::StreamId kControlStreamId = 4; @@ -43,9 +50,11 @@ } static std::unique_ptr<MoqtDataParserVisitor> CreateIncomingDataStream( - MoqtSession* session, webtransport::Stream* stream) { + MoqtSession* session, webtransport::Stream* stream, + MoqtDataStreamType type) { auto new_stream = std::make_unique<MoqtSession::IncomingDataStream>(session, stream); + MoqtDataParserPeer::SetType(&new_stream->parser_, type); return new_stream; } @@ -62,18 +71,15 @@ return static_cast<MoqtSession::ControlStream*>(visitor); } - static void CreateRemoteTrack(MoqtSession* session, const FullTrackName& name, - RemoteTrack::Visitor* visitor, - uint64_t track_alias) { - session->remote_tracks_.try_emplace(track_alias, name, track_alias, - visitor); - session->remote_track_aliases_.try_emplace(name, track_alias); - } - - static void AddActiveSubscribe(MoqtSession* session, uint64_t subscribe_id, - MoqtSubscribe& subscribe, - RemoteTrack::Visitor* visitor) { - session->active_subscribes_[subscribe_id] = {subscribe, visitor}; + static void CreateRemoteTrack(MoqtSession* session, + const MoqtSubscribe& subscribe, + SubscribeRemoteTrack::Visitor* visitor) { + auto track = std::make_unique<SubscribeRemoteTrack>(subscribe, visitor); + session->upstream_by_id_.try_emplace(subscribe.subscribe_id, track.get()); + session->upstream_by_name_.try_emplace(subscribe.full_track_name, + track.get()); + session->subscribe_by_alias_.try_emplace(subscribe.track_alias, + std::move(track)); } static MoqtObjectListener* AddSubscription( @@ -109,8 +115,9 @@ session->peer_role_ = role; } - static RemoteTrack& remote_track(MoqtSession* session, uint64_t track_alias) { - return session->remote_tracks_.find(track_alias)->second; + static SubscribeRemoteTrack* 
remote_track(MoqtSession* session, + uint64_t track_alias) { + return session->RemoteTrackByAlias(track_alias); } static void set_next_subscribe_id(MoqtSession* session, uint64_t id) {
diff --git a/quiche/quic/moqt/tools/chat_client.cc b/quiche/quic/moqt/tools/chat_client.cc index 28c93c1..e98cf75 100644 --- a/quiche/quic/moqt/tools/chat_client.cc +++ b/quiche/quic/moqt/tools/chat_client.cc
@@ -109,6 +109,7 @@ void ChatClient::RemoteTrackVisitor::OnReply( const FullTrackName& full_track_name, + std::optional<FullSequence> /*largest_id*/, std::optional<absl::string_view> reason_phrase) { client_->subscribes_to_make_--; if (full_track_name == client_->chat_strings_->GetCatalogName()) { @@ -200,7 +201,7 @@ } void ChatClient::ProcessCatalog(absl::string_view object, - RemoteTrack::Visitor* visitor, + SubscribeRemoteTrack::Visitor* visitor, uint64_t group_sequence, uint64_t object_sequence) { std::string message(object); @@ -245,7 +246,7 @@ continue; } if (!add) { - // TODO: Unsubscribe from the user that's leaving + session_->Unsubscribe(chat_strings_->GetFullTrackNameFromUsername(user)); std::cout << user << "left the chat\n"; other_users_.erase(user); continue;
diff --git a/quiche/quic/moqt/tools/chat_client.h b/quiche/quic/moqt/tools/chat_client.h index 024b78d..e2c341f 100644 --- a/quiche/quic/moqt/tools/chat_client.h +++ b/quiche/quic/moqt/tools/chat_client.h
@@ -88,11 +88,13 @@ quic::QuicEventLoop* event_loop() { return event_loop_; } - class QUICHE_EXPORT RemoteTrackVisitor : public moqt::RemoteTrack::Visitor { + class QUICHE_EXPORT RemoteTrackVisitor + : public moqt::SubscribeRemoteTrack::Visitor { public: RemoteTrackVisitor(ChatClient* client) : client_(client) {} void OnReply(const moqt::FullTrackName& full_track_name, + std::optional<FullSequence> largest_id, std::optional<absl::string_view> reason_phrase) override; void OnCanAckObjects(MoqtObjectAckFunction) override {} @@ -127,7 +129,7 @@ // Objects from the same catalog group arrive on the same stream, and in // object sequence order. void ProcessCatalog(absl::string_view object, - moqt::RemoteTrack::Visitor* visitor, + moqt::SubscribeRemoteTrack::Visitor* visitor, uint64_t group_sequence, uint64_t object_sequence); struct ChatUser {
diff --git a/quiche/quic/moqt/tools/chat_server.cc b/quiche/quic/moqt/tools/chat_server.cc index c4f0327..3f31a47 100644 --- a/quiche/quic/moqt/tools/chat_server.cc +++ b/quiche/quic/moqt/tools/chat_server.cc
@@ -72,6 +72,7 @@ void ChatServer::RemoteTrackVisitor::OnReply( const moqt::FullTrackName& full_track_name, + std::optional<FullSequence> /*largest_id*/, std::optional<absl::string_view> reason_phrase) { std::cout << "Subscription to user " << server_->strings().GetUsernameFromFullTrackName(full_track_name)
diff --git a/quiche/quic/moqt/tools/chat_server.h b/quiche/quic/moqt/tools/chat_server.h index 1af6e5c..0c883ed 100644 --- a/quiche/quic/moqt/tools/chat_server.h +++ b/quiche/quic/moqt/tools/chat_server.h
@@ -35,10 +35,11 @@ absl::string_view chat_id, absl::string_view output_file); ~ChatServer(); - class RemoteTrackVisitor : public RemoteTrack::Visitor { + class RemoteTrackVisitor : public SubscribeRemoteTrack::Visitor { public: explicit RemoteTrackVisitor(ChatServer* server); void OnReply(const moqt::FullTrackName& full_track_name, + std::optional<FullSequence> largest_id, std::optional<absl::string_view> reason_phrase) override; void OnCanAckObjects(MoqtObjectAckFunction) override {} void OnObjectFragment(
diff --git a/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc b/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc index f31fff4..4b7aa2d 100644 --- a/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc +++ b/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc
@@ -160,13 +160,14 @@ } private: - class NamespaceHandler : public RemoteTrack::Visitor { + class NamespaceHandler : public SubscribeRemoteTrack::Visitor { public: explicit NamespaceHandler(absl::string_view directory) : directory_(directory) {} void OnReply( const FullTrackName& full_track_name, + std::optional<FullSequence> /*largest_id*/, std::optional<absl::string_view> error_reason_phrase) override { if (error_reason_phrase.has_value()) { QUICHE_LOG(ERROR) << "Failed to subscribe to the peer track "
diff --git a/quiche/quic/moqt/tools/moqt_mock_visitor.h b/quiche/quic/moqt/tools/moqt_mock_visitor.h index 972fc0c..7031696 100644 --- a/quiche/quic/moqt/tools/moqt_mock_visitor.h +++ b/quiche/quic/moqt/tools/moqt_mock_visitor.h
@@ -77,10 +77,11 @@ FullTrackName track_name_; }; -class MockRemoteTrackVisitor : public RemoteTrack::Visitor { +class MockSubscribeRemoteTrackVisitor : public SubscribeRemoteTrack::Visitor { public: MOCK_METHOD(void, OnReply, (const FullTrackName& full_track_name, + std::optional<FullSequence> largest_id, std::optional<absl::string_view> error_reason_phrase), (override)); MOCK_METHOD(void, OnCanAckObjects, (MoqtObjectAckFunction ack_function),
diff --git a/quiche/quic/moqt/tools/moqt_simulator_bin.cc b/quiche/quic/moqt/tools/moqt_simulator_bin.cc index 39cf2c8..b2a30e6 100644 --- a/quiche/quic/moqt/tools/moqt_simulator_bin.cc +++ b/quiche/quic/moqt/tools/moqt_simulator_bin.cc
@@ -186,12 +186,13 @@ std::vector<QuicBandwidth> bitrate_history_; }; -class ObjectReceiver : public RemoteTrack::Visitor { +class ObjectReceiver : public SubscribeRemoteTrack::Visitor { public: explicit ObjectReceiver(const QuicClock* clock, QuicTimeDelta deadline) : clock_(clock), deadline_(deadline) {} void OnReply(const FullTrackName& full_track_name, + std::optional<FullSequence> /*largest_id*/, std::optional<absl::string_view> error_reason_phrase) override { QUICHE_CHECK(full_track_name == TrackName()); QUICHE_CHECK(!error_reason_phrase.has_value()) << *error_reason_phrase;