FIN MoQT streams when directed by the TrackPublisher. Update MoqtLiveRelayQueue to send these signals. Reset streams when MoqtLiveRelayQueue purges a Group from cache. PiperOrigin-RevId: 697269570
diff --git a/quiche/quic/moqt/moqt_cached_object.cc b/quiche/quic/moqt/moqt_cached_object.cc index 332edb4..ef78a61 100644 --- a/quiche/quic/moqt/moqt_cached_object.cc +++ b/quiche/quic/moqt/moqt_cached_object.cc
@@ -20,6 +20,7 @@ object.payload->data(), object.payload->length(), [retained_pointer = object.payload](const char*) {}); } + result.fin_after_this = object.fin_after_this; return result; }
diff --git a/quiche/quic/moqt/moqt_cached_object.h b/quiche/quic/moqt/moqt_cached_object.h index be556e2..1fe524a 100644 --- a/quiche/quic/moqt/moqt_cached_object.h +++ b/quiche/quic/moqt/moqt_cached_object.h
@@ -21,6 +21,7 @@ MoqtObjectStatus status; MoqtPriority publisher_priority; std::shared_ptr<quiche::QuicheMemSlice> payload; + bool fin_after_this; // This is the last object before FIN. }; // Transforms a CachedObject into a PublishedObject.
diff --git a/quiche/quic/moqt/moqt_live_relay_queue.cc b/quiche/quic/moqt/moqt_live_relay_queue.cc index 3ca2532..a81139e 100644 --- a/quiche/quic/moqt/moqt_live_relay_queue.cc +++ b/quiche/quic/moqt/moqt_live_relay_queue.cc
@@ -8,6 +8,7 @@ #include <optional> #include <vector> +#include "absl/base/attributes.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "quiche/quic/moqt/moqt_cached_object.h" @@ -22,11 +23,49 @@ namespace moqt { +bool MoqtLiveRelayQueue::AddFin(FullSequence sequence) { + switch (forwarding_preference_) { + case MoqtForwardingPreference::kDatagram: + return false; + case MoqtForwardingPreference::kTrack: + // TODO(martinduke): Support if it doesn't go away. + return false; + case MoqtForwardingPreference::kSubgroup: + break; + } + auto group_it = queue_.find(sequence.group); + if (group_it == queue_.end()) { + // Group does not exist. + return false; + } + Group& group = group_it->second; + auto subgroup_it = group.subgroups.find( + SubgroupPriority{publisher_priority_, sequence.subgroup}); + if (subgroup_it == group.subgroups.end()) { + // Subgroup does not exist. + return false; + } + if (subgroup_it->second.empty()) { + // Cannot FIN an empty subgroup. + return false; + } + if (subgroup_it->second.rbegin()->first != sequence.object) { + // The queue does not yet have the last object. + return false; + } + subgroup_it->second.rbegin()->second.fin_after_this = true; + for (MoqtObjectListener* listener : listeners_) { + listener->OnNewFinAvailable(sequence); + } + return true; +} + // TODO(martinduke): Unless Track Forwarding preference goes away, support it. bool MoqtLiveRelayQueue::AddRawObject(FullSequence sequence, MoqtObjectStatus status, MoqtPriority priority, - absl::string_view payload) { + absl::string_view payload, bool fin) { + bool last_object_in_stream = fin; if (queue_.size() == kMaxQueuedGroups) { if (queue_.begin()->first > sequence.group) { QUICHE_DLOG(INFO) << "Skipping object from group " << sequence.group @@ -35,6 +74,9 @@ } if (queue_.find(sequence.group) == queue_.end()) { // Erase the oldest group. + for (MoqtObjectListener* listener : listeners_) { + listener->OnGroupAbandoned(queue_.begin()->first); + } queue_.erase(queue_.begin()); } } @@ -74,17 +116,22 @@ } auto subgroup_it = group.subgroups.try_emplace( SubgroupPriority{priority, sequence.subgroup}); - auto& object_queue = subgroup_it.first->second; - if (!object_queue.empty()) { // Check if the new object is valid - auto last_object = object_queue.rbegin(); - if (last_object->first >= sequence.object) { - QUICHE_DLOG(INFO) << "Skipping object because it does not increase the " - << "object ID monotonically in the subgroup."; + auto& subgroup = subgroup_it.first->second; + if (!subgroup.empty()) { // Check if the new object is valid + CachedObject& last_object = subgroup.rbegin()->second; + if (last_object.fin_after_this) { + QUICHE_DLOG(INFO) << "Skipping object because it is after the end of the " + << "subgroup"; return false; } - if (last_object->second.status == MoqtObjectStatus::kEndOfSubgroup) { - QUICHE_DLOG(INFO) << "Skipping object because it is after the end of the " - << "subgroup."; + // If last_object has stream-ending status, it should have been caught by + // the fin_after_this check above. + QUICHE_DCHECK(last_object.status != MoqtObjectStatus::kEndOfSubgroup && + last_object.status != MoqtObjectStatus::kEndOfGroup && + last_object.status != MoqtObjectStatus::kEndOfTrack); + if (last_object.sequence.object >= sequence.object) { + QUICHE_DLOG(INFO) << "Skipping object because it does not increase the " + << "object ID monotonically in the subgroup."; return false; } } @@ -95,13 +142,17 @@ if (sequence.object >= group.next_object) { group.next_object = sequence.object + 1; } + // Anticipate stream FIN with most non-normal objects. switch (status) { case MoqtObjectStatus::kEndOfTrack: end_of_track_ = sequence; - break; - case MoqtObjectStatus::kEndOfGroup: + ABSL_FALLTHROUGH_INTENDED; case MoqtObjectStatus::kGroupDoesNotExist: + case MoqtObjectStatus::kEndOfGroup: group.complete = true; + ABSL_FALLTHROUGH_INTENDED; + case MoqtObjectStatus::kEndOfSubgroup: + last_object_in_stream = true; break; default: break; @@ -111,8 +162,8 @@ ? nullptr : std::make_shared<quiche::QuicheMemSlice>(quiche::QuicheBuffer::Copy( quiche::SimpleBufferAllocator::Get(), payload)); - object_queue.emplace(sequence.object, - CachedObject{sequence, status, priority, slice}); + subgroup.emplace(sequence.object, CachedObject{sequence, status, priority, + slice, last_object_in_stream}); for (MoqtObjectListener* listener : listeners_) { listener->OnNewObjectAvailable(sequence); } @@ -139,7 +190,7 @@ } // Find an object with ID of at least sequence.object. auto object_it = subgroup.lower_bound(sequence.object); - if (object_it == subgroup_it->second.end()) { + if (object_it == subgroup.end()) { // No object after the last one received. return std::nullopt; }
diff --git a/quiche/quic/moqt/moqt_live_relay_queue.h b/quiche/quic/moqt/moqt_live_relay_queue.h index 6b23e7a..97287da 100644 --- a/quiche/quic/moqt/moqt_live_relay_queue.h +++ b/quiche/quic/moqt/moqt_live_relay_queue.h
@@ -53,13 +53,21 @@ // occur. A false return value might result in a session error on the // inbound session, but this queue is the only place that retains enough state // to check. - bool AddObject(FullSequence sequence, MoqtObjectStatus status) { - return AddRawObject(sequence, status, publisher_priority_, ""); + bool AddObject(FullSequence sequence, MoqtObjectStatus status, + bool fin = false) { + return AddRawObject(sequence, status, publisher_priority_, "", fin); } - bool AddObject(FullSequence sequence, absl::string_view object) { + bool AddObject(FullSequence sequence, absl::string_view object, + bool fin = false) { return AddRawObject(sequence, MoqtObjectStatus::kNormal, - publisher_priority_, object); + publisher_priority_, object, fin); } + // Record a received FIN that did not come with the last object. + // If the forwarding preference is kDatagram or kTrack, |sequence| is ignored. + // Otherwise, |sequence| is used to determine which stream is being FINed. If + // the object ID does not match the last object ID in the stream, no action + // is taken. + bool AddFin(FullSequence sequence); // MoqtTrackPublisher implementation. const FullTrackName& GetTrackName() const override { return track_; } @@ -112,12 +120,12 @@ struct Group { uint64_t next_object = 0; - bool complete = false; + bool complete = false; // If true, kEndOfGroup has been received. absl::btree_map<SubgroupPriority, Subgroup> subgroups; }; bool AddRawObject(FullSequence sequence, MoqtObjectStatus status, - MoqtPriority priority, absl::string_view payload); + MoqtPriority priority, absl::string_view payload, bool fin); FullTrackName track_; MoqtForwardingPreference forwarding_preference_;
diff --git a/quiche/quic/moqt/moqt_live_relay_queue_test.cc b/quiche/quic/moqt/moqt_live_relay_queue_test.cc index 4e0bcbf..165b8fc 100644 --- a/quiche/quic/moqt/moqt_live_relay_queue_test.cc +++ b/quiche/quic/moqt/moqt_live_relay_queue_test.cc
@@ -67,6 +67,8 @@ } } + MOCK_METHOD(void, OnNewFinAvailable, (FullSequence sequence)); + MOCK_METHOD(void, OnGroupAbandoned, (uint64_t group_id)); MOCK_METHOD(void, CloseStreamForGroup, (uint64_t group_id), ()); MOCK_METHOD(void, CloseStreamForSubgroup, (uint64_t group_id, uint64_t subgroup_id), ()); @@ -198,9 +200,11 @@ EXPECT_CALL(queue, PublishObject(2, 0, "e")); EXPECT_CALL(queue, PublishObject(2, 1, "f")); EXPECT_CALL(queue, CloseStreamForGroup(2)); + EXPECT_CALL(queue, OnGroupAbandoned(0)); EXPECT_CALL(queue, PublishObject(3, 0, "g")); EXPECT_CALL(queue, PublishObject(3, 1, "h")); EXPECT_CALL(queue, CloseStreamForGroup(3)); + EXPECT_CALL(queue, OnGroupAbandoned(1)); EXPECT_CALL(queue, PublishObject(4, 0, "i")); EXPECT_CALL(queue, PublishObject(4, 1, "j")); } @@ -237,9 +241,11 @@ EXPECT_CALL(queue, PublishObject(2, 0, "e")); EXPECT_CALL(queue, PublishObject(2, 1, "f")); EXPECT_CALL(queue, CloseStreamForGroup(2)); + EXPECT_CALL(queue, OnGroupAbandoned(0)); EXPECT_CALL(queue, PublishObject(3, 0, "g")); EXPECT_CALL(queue, PublishObject(3, 1, "h")); EXPECT_CALL(queue, CloseStreamForGroup(3)); + EXPECT_CALL(queue, OnGroupAbandoned(1)); EXPECT_CALL(queue, PublishObject(4, 0, "i")); EXPECT_CALL(queue, PublishObject(4, 1, "j")); @@ -286,9 +292,11 @@ EXPECT_CALL(queue, PublishObject(2, 0, "e")); EXPECT_CALL(queue, PublishObject(2, 1, "f")); EXPECT_CALL(queue, CloseStreamForGroup(2)); + EXPECT_CALL(queue, OnGroupAbandoned(0)); EXPECT_CALL(queue, PublishObject(3, 0, "g")); EXPECT_CALL(queue, PublishObject(3, 1, "h")); EXPECT_CALL(queue, CloseStreamForGroup(3)); + EXPECT_CALL(queue, OnGroupAbandoned(1)); EXPECT_CALL(queue, PublishObject(4, 0, "i")); EXPECT_CALL(queue, PublishObject(4, 1, "j")); } @@ -428,6 +436,36 @@ EXPECT_FALSE(queue.AddObject(FullSequence{0, 0, 2}, "b")); } +TEST(MoqtLiveRelayQueue, AddObjectWithFin) { + TestMoqtLiveRelayQueue queue; + { + testing::InSequence seq; + EXPECT_CALL(queue, PublishObject(0, 0, "a")); + } + EXPECT_TRUE(queue.AddObject(FullSequence{0, 0, 0}, "a", true)); + std::optional<PublishedObject> object = + queue.GetCachedObject(FullSequence{0, 0}); + ASSERT_TRUE(object.has_value()); + EXPECT_EQ(object->status, MoqtObjectStatus::kNormal); + EXPECT_TRUE(object->fin_after_this); +} + +TEST(MoqtLiveRelayQueue, LateFin) { + TestMoqtLiveRelayQueue queue; + { + testing::InSequence seq; + EXPECT_CALL(queue, PublishObject(0, 0, "a")); + } + EXPECT_TRUE(queue.AddObject(FullSequence{0, 0, 0}, "a", false)); + EXPECT_CALL(queue, OnNewFinAvailable(FullSequence{0, 0})); + EXPECT_TRUE(queue.AddFin(FullSequence{0, 0})); + std::optional<PublishedObject> object = + queue.GetCachedObject(FullSequence{0, 0}); + ASSERT_TRUE(object.has_value()); + EXPECT_EQ(object->status, MoqtObjectStatus::kNormal); + EXPECT_TRUE(object->fin_after_this); +} + } // namespace } // namespace moqt::test
diff --git a/quiche/quic/moqt/moqt_outgoing_queue.cc b/quiche/quic/moqt/moqt_outgoing_queue.cc index 51b2478..ef53465 100644 --- a/quiche/quic/moqt/moqt_outgoing_queue.cc +++ b/quiche/quic/moqt/moqt_outgoing_queue.cc
@@ -41,6 +41,9 @@ if (queue_.size() == kMaxQueuedGroups) { queue_.erase(queue_.begin()); + for (MoqtObjectListener* listener : listeners_) { + listener->OnGroupAbandoned(current_group_id_ - kMaxQueuedGroups + 1); + } } queue_.emplace_back(); ++current_group_id_; @@ -52,9 +55,11 @@ void MoqtOutgoingQueue::AddRawObject(MoqtObjectStatus status, quiche::QuicheMemSlice payload) { FullSequence sequence{current_group_id_, queue_.back().size()}; + bool fin = forwarding_preference_ == MoqtForwardingPreference::kSubgroup && + status == MoqtObjectStatus::kEndOfGroup; queue_.back().push_back(CachedObject{ sequence, status, publisher_priority_, - std::make_shared<quiche::QuicheMemSlice>(std::move(payload))}); + std::make_shared<quiche::QuicheMemSlice>(std::move(payload)), fin}); for (MoqtObjectListener* listener : listeners_) { listener->OnNewObjectAvailable(sequence); } @@ -73,12 +78,7 @@ const std::vector<CachedObject>& group = queue_[sequence.group - first_group_in_queue()]; if (sequence.object >= group.size()) { - if (sequence.group == current_group_id_) { - return std::nullopt; - } - return PublishedObject{FullSequence{sequence.group, sequence.object}, - MoqtObjectStatus::kObjectDoesNotExist, - publisher_priority_, quiche::QuicheMemSlice()}; + return std::nullopt; } QUICHE_DCHECK(sequence == group[sequence.object].sequence); return CachedObjectToPublishedObject(group[sequence.object]);
diff --git a/quiche/quic/moqt/moqt_outgoing_queue_test.cc b/quiche/quic/moqt/moqt_outgoing_queue_test.cc index 059f549..deb8b31 100644 --- a/quiche/quic/moqt/moqt_outgoing_queue_test.cc +++ b/quiche/quic/moqt/moqt_outgoing_queue_test.cc
@@ -66,6 +66,8 @@ } } + MOCK_METHOD(void, OnNewFinAvailable, (FullSequence sequence)); + MOCK_METHOD(void, OnGroupAbandoned, (uint64_t group_id)); MOCK_METHOD(void, CloseStreamForGroup, (uint64_t group_id), ()); MOCK_METHOD(void, PublishObject, (uint64_t group_id, uint64_t object_id,
diff --git a/quiche/quic/moqt/moqt_publisher.h b/quiche/quic/moqt/moqt_publisher.h index be9c4f9..a0897f4 100644 --- a/quiche/quic/moqt/moqt_publisher.h +++ b/quiche/quic/moqt/moqt_publisher.h
@@ -26,6 +26,7 @@ MoqtObjectStatus status; MoqtPriority publisher_priority; quiche::QuicheMemSlice payload; + bool fin_after_this = false; }; // MoqtObjectListener is an interface for any entity that is listening for @@ -38,6 +39,12 @@ // available. The object payload itself may be retrieved via GetCachedObject // method of the associated track publisher. virtual void OnNewObjectAvailable(FullSequence sequence) = 0; + // Notifies that a pure FIN has arrived following |sequence|. + virtual void OnNewFinAvailable(FullSequence sequence) = 0; + + // No further object will be published for the given group, usually due to a + // timeout. The owner of the Listener may want to reset the relevant streams. + virtual void OnGroupAbandoned(uint64_t group_id) = 0; // Notifies that the Publisher is being destroyed, so no more objects are // coming.
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index 7c52c68..0f8fad7 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -1061,6 +1061,39 @@ "Publisher is gone"); } +void MoqtSession::PublishedSubscription::OnNewFinAvailable( + FullSequence sequence) { + if (!window_.InWindow(sequence)) { + return; + } + std::optional<webtransport::StreamId> stream_id = + stream_map().GetStreamForSequence(sequence); + if (!stream_id.has_value()) { + return; + } + webtransport::Stream* raw_stream = + session_->session_->GetStreamById(*stream_id); + if (raw_stream == nullptr) { + return; + } + OutgoingDataStream* stream = + static_cast<OutgoingDataStream*>(raw_stream->visitor()); + stream->Fin(sequence); +} + +void MoqtSession::PublishedSubscription::OnGroupAbandoned(uint64_t group_id) { + std::vector<webtransport::StreamId> streams = + stream_map().GetStreamsForGroup(group_id); + for (webtransport::StreamId stream_id : streams) { + webtransport::Stream* raw_stream = + session_->session_->GetStreamById(stream_id); + if (raw_stream == nullptr) { + continue; + } + raw_stream->ResetWithUserCode(kResetCodeTimedOut); + } +} + void MoqtSession::PublishedSubscription::Backfill() { const FullSequence start = window_.start(); const FullSequence end = track_publisher_->GetLargestSequence(); @@ -1261,6 +1294,17 @@ } } +void MoqtSession::OutgoingDataStream::Fin(FullSequence last_object) { + if (next_object_ <= last_object) { + // There is still data to send, do nothing. + return; + } + // All data has already been sent; send a pure FIN. + bool success = stream_->SendFin(); + QUICHE_BUG_IF(OutgoingDataStream_fin_failed, !success) + << "Writing pure FIN failed."; +} + void MoqtSession::OutgoingDataStream::SendNextObject( PublishedSubscription& subscription, PublishedObject object) { QUICHE_DCHECK(next_object_ <= object.sequence); @@ -1291,7 +1335,6 @@ session_->framer_.SerializeObjectHeader( header, GetMessageTypeForForwardingPreference(forwarding_preference), !stream_header_written_); - bool fin = false; switch (forwarding_preference) { case MoqtForwardingPreference::kTrack: if (object.status == MoqtObjectStatus::kEndOfGroup || @@ -1301,20 +1344,10 @@ } else { next_object_.object = header.object_id + 1; } - fin = object.status == MoqtObjectStatus::kEndOfTrack || - !subscription.InWindow(next_object_); break; case MoqtForwardingPreference::kSubgroup: - // TODO(martinduke): EndOfGroup and EndOfTrack implies the ability to - // close other streams/subgroups. PublishedObject should contain a boolean - // if the stream is safe to close. next_object_.object = header.object_id + 1; - fin = object.status == MoqtObjectStatus::kEndOfTrack || - object.status == MoqtObjectStatus::kEndOfGroup || - object.status == MoqtObjectStatus::kEndOfSubgroup || - object.status == MoqtObjectStatus::kGroupDoesNotExist || - !subscription.InWindow(next_object_); break; case MoqtForwardingPreference::kDatagram: @@ -1327,7 +1360,7 @@ std::array<absl::string_view, 2> write_vector = { serialized_header.AsStringView(), object.payload.AsStringView()}; quiche::StreamWriteOptions options; - options.set_send_fin(fin); + options.set_send_fin(object.fin_after_this); absl::Status write_status = stream_->Writev(write_vector, options); if (!write_status.ok()) { QUICHE_BUG(MoqtSession_SendNextObject_write_failed) @@ -1339,7 +1372,7 @@ } QUIC_DVLOG(1) << "Stream " << stream_->GetStreamId() << " successfully wrote " - << object.sequence << ", fin = " << fin + << object.sequence << ", fin = " << object.fin_after_this << ", next: " << next_object_; stream_header_written_ = true;
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index e24268b..04972b1 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -313,6 +313,8 @@ // This is only called for objects that have just arrived. void OnNewObjectAvailable(FullSequence sequence) override; void OnTrackPublisherGone() override; + void OnNewFinAvailable(FullSequence sequence) override; + void OnGroupAbandoned(uint64_t group_id) override; void ProcessObjectAck(const MoqtObjectAck& message) { if (monitoring_interface_ == nullptr) { return; @@ -402,6 +404,10 @@ // stream becomes write-blocked or closed. void SendObjects(PublishedSubscription& subscription); + // Sends a pure FIN on the stream, if the last object sent matches + // |last_object|. Otherwise, does nothing. + void Fin(FullSequence last_object); + // Recomputes the send order and updates it for the associated stream. void UpdateSendOrder(PublishedSubscription& subscription); @@ -412,8 +418,8 @@ // the stream and returns nullptr. PublishedSubscription* GetSubscriptionIfValid(); - // Actually sends an object on the stream; the object MUST be - // `next_object_`. + // Actually sends an object on the stream; the object MUST have object ID + // of at least `next_object_`. void SendNextObject(PublishedSubscription& subscription, PublishedObject object);
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 643364c..bc31674 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -1187,13 +1187,252 @@ }); EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 0))).WillRepeatedly([] { return PublishedObject{FullSequence(5, 0), MoqtObjectStatus::kNormal, 127, - MemSliceFromString("deadbeef")}; + MemSliceFromString("deadbeef"), false}; }); EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 1))).WillRepeatedly([] { return std::optional<PublishedObject>(); }); subscription->OnNewObjectAvailable(FullSequence(5, 0)); EXPECT_TRUE(correct_message); + EXPECT_FALSE(fin); +} + +TEST_F(MoqtSessionTest, FinDataStreamFromCache) { + FullTrackName ftn("foo", "bar"); + auto track = SetupPublisher(ftn, MoqtForwardingPreference::kSubgroup, + FullSequence(4, 2)); + MoqtObjectListener* subscription = + MoqtSessionPeer::AddSubscription(&session_, track, 0, 2, 5, 0); + + EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream()) + .WillOnce(Return(true)); + bool fin = false; + webtransport::test::MockStream mock_stream; + EXPECT_CALL(mock_stream, CanWrite()).WillRepeatedly([&] { return !fin; }); + EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream()) + .WillOnce(Return(&mock_stream)); + std::unique_ptr<webtransport::StreamVisitor> stream_visitor; + EXPECT_CALL(mock_stream, SetVisitor(_)) + .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> visitor) { + stream_visitor = std::move(visitor); + }); + EXPECT_CALL(mock_stream, visitor()).WillOnce([&] { + return stream_visitor.get(); + }); + EXPECT_CALL(mock_stream, GetStreamId()) + .WillRepeatedly(Return(kOutgoingUniStreamId)); + EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId)) + .WillRepeatedly(Return(&mock_stream)); + + // Verify first six message fields are sent correctly + bool correct_message = false; + const std::string kExpectedMessage = {0x04, 0x02, 0x05, 0x00, 0x00}; + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + correct_message = absl::StartsWith(data[0], kExpectedMessage); + fin = options.send_fin(); + return absl::OkStatus(); + }); + EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 0))).WillRepeatedly([] { + return PublishedObject{FullSequence(5, 0), MoqtObjectStatus::kNormal, 127, + MemSliceFromString("deadbeef"), true}; + }); + EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 1))).WillRepeatedly([] { + return std::optional<PublishedObject>(); + }); + subscription->OnNewObjectAvailable(FullSequence(5, 0)); + EXPECT_TRUE(correct_message); + EXPECT_TRUE(fin); +} + +TEST_F(MoqtSessionTest, GroupAbandoned) { + FullTrackName ftn("foo", "bar"); + auto track = SetupPublisher(ftn, MoqtForwardingPreference::kSubgroup, + FullSequence(4, 2)); + MoqtObjectListener* subscription = + MoqtSessionPeer::AddSubscription(&session_, track, 0, 2, 5, 0); + + EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream()) + .WillOnce(Return(true)); + bool fin = false; + webtransport::test::MockStream mock_stream; + EXPECT_CALL(mock_stream, CanWrite()).WillRepeatedly([&] { return !fin; }); + EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream()) + .WillOnce(Return(&mock_stream)); + std::unique_ptr<webtransport::StreamVisitor> stream_visitor; + EXPECT_CALL(mock_stream, SetVisitor(_)) + .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> visitor) { + stream_visitor = std::move(visitor); + }); + EXPECT_CALL(mock_stream, visitor()).WillOnce([&] { + return stream_visitor.get(); + }); + EXPECT_CALL(mock_stream, GetStreamId()) + .WillRepeatedly(Return(kOutgoingUniStreamId)); + EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId)) + .WillRepeatedly(Return(&mock_stream)); + + // Verify first six message fields are sent correctly + bool correct_message = false; + const std::string kExpectedMessage = {0x04, 0x02, 0x05, 0x00, 0x00}; + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + correct_message = absl::StartsWith(data[0], kExpectedMessage); + fin |= options.send_fin(); + return absl::OkStatus(); + }); + EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 0))).WillRepeatedly([] { + return PublishedObject{FullSequence(5, 0), MoqtObjectStatus::kNormal, 127, + MemSliceFromString("deadbeef"), true}; + }); + EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 1))).WillRepeatedly([] { + return std::optional<PublishedObject>(); + }); + subscription->OnNewObjectAvailable(FullSequence(5, 0)); + EXPECT_TRUE(correct_message); + EXPECT_TRUE(fin); + + EXPECT_CALL(mock_stream, ResetWithUserCode(kResetCodeTimedOut)); + subscription->OnGroupAbandoned(5); +} + +TEST_F(MoqtSessionTest, LateFinDataStream) { + FullTrackName ftn("foo", "bar"); + auto track = SetupPublisher(ftn, MoqtForwardingPreference::kSubgroup, + FullSequence(4, 2)); + MoqtObjectListener* subscription = + MoqtSessionPeer::AddSubscription(&session_, track, 0, 2, 5, 0); + + EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream()) + .WillOnce(Return(true)); + bool fin = false; + webtransport::test::MockStream mock_stream; + EXPECT_CALL(mock_stream, CanWrite()).WillRepeatedly([&] { return !fin; }); + EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream()) + .WillOnce(Return(&mock_stream)); + std::unique_ptr<webtransport::StreamVisitor> stream_visitor; + EXPECT_CALL(mock_stream, SetVisitor(_)) + .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> visitor) { + stream_visitor = std::move(visitor); + }); + EXPECT_CALL(mock_stream, visitor()).WillRepeatedly([&] { + return stream_visitor.get(); + }); + EXPECT_CALL(mock_stream, GetStreamId()) + .WillRepeatedly(Return(kOutgoingUniStreamId)); + EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId)) + .WillRepeatedly(Return(&mock_stream)); + + // Verify first six message fields are sent correctly + bool correct_message = false; + const std::string kExpectedMessage = {0x04, 0x02, 0x05, 0x00, 0x00}; + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + correct_message = absl::StartsWith(data[0], kExpectedMessage); + fin = options.send_fin(); + return absl::OkStatus(); + }); + EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 0))).WillRepeatedly([] { + return PublishedObject{FullSequence(5, 0), MoqtObjectStatus::kNormal, 127, + MemSliceFromString("deadbeef"), false}; + }); + EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 1))).WillRepeatedly([] { + return std::optional<PublishedObject>(); + }); + subscription->OnNewObjectAvailable(FullSequence(5, 0)); + EXPECT_TRUE(correct_message); + EXPECT_FALSE(fin); + fin = false; + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + EXPECT_TRUE(data.empty()); + fin = options.send_fin(); + return absl::OkStatus(); + }); + subscription->OnNewFinAvailable(FullSequence(5, 0)); +} + +TEST_F(MoqtSessionTest, SeparateFinForFutureObject) { + FullTrackName ftn("foo", "bar"); + auto track = SetupPublisher(ftn, MoqtForwardingPreference::kSubgroup, + FullSequence(4, 2)); + MoqtObjectListener* subscription = + MoqtSessionPeer::AddSubscription(&session_, track, 0, 2, 5, 0); + + EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream()) + .WillOnce(Return(true)); + bool fin = false; + webtransport::test::MockStream mock_stream; + EXPECT_CALL(mock_stream, CanWrite()).WillRepeatedly([&] { return !fin; }); + EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream()) + .WillOnce(Return(&mock_stream)); + std::unique_ptr<webtransport::StreamVisitor> stream_visitor; + EXPECT_CALL(mock_stream, SetVisitor(_)) + .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> visitor) { + stream_visitor = std::move(visitor); + }); + EXPECT_CALL(mock_stream, visitor()).WillRepeatedly([&] { + return stream_visitor.get(); + }); + EXPECT_CALL(mock_stream, GetStreamId()) + .WillRepeatedly(Return(kOutgoingUniStreamId)); + EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId)) + .WillRepeatedly(Return(&mock_stream)); + + // Verify first six message fields are sent correctly + bool correct_message = false; + const std::string kExpectedMessage = {0x04, 0x00, 0x02, 0x05, 0x00, 0x00}; + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + correct_message = absl::StartsWith(data[0], kExpectedMessage); + fin = options.send_fin(); + return absl::OkStatus(); + }); + EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 0))).WillRepeatedly([] { + return PublishedObject{FullSequence(5, 0), MoqtObjectStatus::kNormal, 127, + MemSliceFromString("deadbeef"), false}; + }); + EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 1))).WillRepeatedly([] { + return std::optional<PublishedObject>(); + }); + subscription->OnNewObjectAvailable(FullSequence(5, 0)); + EXPECT_FALSE(fin); + // Try to deliver (5,1), but fail. + EXPECT_CALL(mock_stream, CanWrite()).WillRepeatedly([&] { return false; }); + EXPECT_CALL(*track, GetCachedObject(_)).Times(0); + EXPECT_CALL(mock_stream, Writev(_, _)).Times(0); + subscription->OnNewObjectAvailable(FullSequence(5, 1)); + // Notify that FIN arrived, but do nothing with it because (5, 1) isn't sent. + EXPECT_CALL(mock_stream, Writev(_, _)).Times(0); + subscription->OnNewFinAvailable(FullSequence(5, 1)); + + // Reopen the window. + correct_message = false; + // object id, payload length, status. + const std::string kExpectedMessage2 = {0x01, 0x00, 0x03}; + EXPECT_CALL(mock_stream, CanWrite()).WillRepeatedly([&] { return true; }); + EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 1))).WillRepeatedly([] { + return PublishedObject{FullSequence(5, 1), MoqtObjectStatus::kEndOfGroup, + 127, MemSliceFromString(""), true}; + }); + EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 2))).WillRepeatedly([] { + return std::optional<PublishedObject>(); + }); + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + correct_message = absl::StartsWith(data[0], kExpectedMessage2); + fin = options.send_fin(); + return absl::OkStatus(); + }); + stream_visitor->OnCanWrite(); + EXPECT_TRUE(correct_message); + EXPECT_TRUE(fin); } // TODO: Test operation with multiple streams.
diff --git a/quiche/quic/moqt/moqt_subscribe_windows.cc b/quiche/quic/moqt/moqt_subscribe_windows.cc index 4637bd1..6e9c8e9 100644 --- a/quiche/quic/moqt/moqt_subscribe_windows.cc +++ b/quiche/quic/moqt/moqt_subscribe_windows.cc
@@ -4,6 +4,7 @@ #include "quiche/quic/moqt/moqt_subscribe_windows.h" +#include <cstdint> #include <optional> #include <vector> @@ -21,35 +22,61 @@ return (!end_.has_value() || seq <= *end_); } +ReducedSequenceIndex::ReducedSequenceIndex( + FullSequence sequence, MoqtForwardingPreference preference) { + switch (preference) { + case MoqtForwardingPreference::kTrack: + sequence_ = FullSequence(0, 0, 0); + break; + case MoqtForwardingPreference::kSubgroup: + sequence_ = FullSequence(sequence.group, sequence.subgroup, 0); + break; + case MoqtForwardingPreference::kDatagram: + sequence_ = FullSequence(sequence.group, sequence.object, 0); + return; + } +} + std::optional<webtransport::StreamId> SendStreamMap::GetStreamForSequence( FullSequence sequence) const { - ReducedSequenceIndex index(sequence, forwarding_preference_); - auto stream_it = send_streams_.find(index); - if (stream_it == send_streams_.end()) { + FullSequence index = + ReducedSequenceIndex(sequence, forwarding_preference_).sequence(); + auto group_it = send_streams_.find(index.group); + if (group_it == send_streams_.end()) { return std::nullopt; } - return stream_it->second; + auto subgroup_it = group_it->second.find(index.subgroup); + if (subgroup_it == group_it->second.end()) { + return std::nullopt; + } + return subgroup_it->second; } void SendStreamMap::AddStream(FullSequence sequence, webtransport::StreamId stream_id) { - ReducedSequenceIndex index(sequence, forwarding_preference_); - if (forwarding_preference_ == MoqtForwardingPreference::kDatagram) { - QUIC_BUG(quic_bug_moqt_draft_03_01) << "Adding a stream for datagram"; - return; - } - auto [stream_it, success] = send_streams_.emplace(index, stream_id); + FullSequence index = + ReducedSequenceIndex(sequence, forwarding_preference_).sequence(); + auto [it, result] = send_streams_.insert({index.group, Group()}); + auto [sg, success] = it->second.try_emplace(index.subgroup, stream_id); QUIC_BUG_IF(quic_bug_moqt_draft_03_02, !success) << "Stream already added"; } void SendStreamMap::RemoveStream(FullSequence sequence, webtransport::StreamId stream_id) { - ReducedSequenceIndex index(sequence, forwarding_preference_); - QUICHE_DCHECK(send_streams_.contains(index) && - send_streams_.find(index)->second == stream_id) - << "Requested to remove a stream ID that does not match the one in the " - "map"; - send_streams_.erase(index); + FullSequence index = + ReducedSequenceIndex(sequence, forwarding_preference_).sequence(); + auto group_it = send_streams_.find(index.group); + if (group_it == send_streams_.end()) { + QUICHE_NOTREACHED(); + return; + } + auto subgroup_it = group_it->second.find(index.subgroup); + if (subgroup_it == group_it->second.end() || + subgroup_it->second != stream_id) { + QUICHE_NOTREACHED(); + return; + } + group_it->second.erase(subgroup_it); } bool SubscribeWindow::UpdateStartEnd(FullSequence start, @@ -66,25 +93,25 @@ return true; } -ReducedSequenceIndex::ReducedSequenceIndex( - FullSequence sequence, MoqtForwardingPreference preference) { - switch (preference) { - case MoqtForwardingPreference::kTrack: - sequence_ = FullSequence(0, 0); - break; - case MoqtForwardingPreference::kSubgroup: - sequence_ = FullSequence(sequence.group, 0); - break; - case MoqtForwardingPreference::kDatagram: - sequence_ = sequence; - return; - } -} - std::vector<webtransport::StreamId> SendStreamMap::GetAllStreams() const { std::vector<webtransport::StreamId> ids; - for (const auto& [index, id] : send_streams_) { - ids.push_back(id); + for (const auto& [group, subgroup_map] : send_streams_) { + for (const auto& [subgroup, stream_id] : subgroup_map) { + ids.push_back(stream_id); + } + } + return ids; +} + +std::vector<webtransport::StreamId> SendStreamMap::GetStreamsForGroup( + uint64_t group_id) const { + std::vector<webtransport::StreamId> ids; + auto it = send_streams_.find(group_id); + if (it == send_streams_.end()) { + return ids; + } + for (const auto& [subgroup, stream_id] : it->second) { + ids.push_back(stream_id); } return ids; }
diff --git a/quiche/quic/moqt/moqt_subscribe_windows.h b/quiche/quic/moqt/moqt_subscribe_windows.h index b850f3e..d6f8ea4 100644 --- a/quiche/quic/moqt/moqt_subscribe_windows.h +++ b/quiche/quic/moqt/moqt_subscribe_windows.h
@@ -9,7 +9,7 @@ #include <optional> #include <vector> -#include "absl/container/flat_hash_map.h" +#include "absl/container/btree_map.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/common/platform/api/quiche_export.h" #include "quiche/web_transport/web_transport.h" @@ -36,7 +36,7 @@ : start_(start), end_(end) {} bool InWindow(const FullSequence& seq) const; - bool HasEnd() const { return end_.has_value(); } + const std::optional<FullSequence>& end() const { return end_; } FullSequence start() const { return start_; } // Updates the subscription window. Returns true if the update is valid (in @@ -44,6 +44,7 @@ bool UpdateStartEnd(FullSequence start, std::optional<FullSequence> end); private: + // The subgroups in these sequences have no meaning. FullSequence start_; std::optional<FullSequence> end_; }; @@ -61,6 +62,7 @@ bool operator!=(const ReducedSequenceIndex& other) const { return sequence_ != other.sequence_; } + FullSequence sequence() { return sequence_; } template <typename H> friend H AbslHashValue(H h, const ReducedSequenceIndex& m) { @@ -82,10 +84,12 @@ void AddStream(FullSequence sequence, webtransport::StreamId stream_id); void RemoveStream(FullSequence sequence, webtransport::StreamId stream_id); std::vector<webtransport::StreamId> GetAllStreams() const; + std::vector<webtransport::StreamId> GetStreamsForGroup( + uint64_t group_id) const; private: - absl::flat_hash_map<ReducedSequenceIndex, webtransport::StreamId> - send_streams_; + using Group = absl::btree_map<uint64_t, webtransport::StreamId>; + absl::btree_map<uint64_t, Group> send_streams_; MoqtForwardingPreference forwarding_preference_; };
diff --git a/quiche/quic/moqt/moqt_subscribe_windows_test.cc b/quiche/quic/moqt/moqt_subscribe_windows_test.cc index ba8687e..1fcc43d 100644 --- a/quiche/quic/moqt/moqt_subscribe_windows_test.cc +++ b/quiche/quic/moqt/moqt_subscribe_windows_test.cc
@@ -59,8 +59,13 @@ TEST_F(SubscribeWindowTest, AddQueryRemoveStreamIdDatagram) { SendStreamMap stream_map(MoqtForwardingPreference::kDatagram); - EXPECT_QUIC_BUG(stream_map.AddStream(FullSequence{4, 0}, 2), - "Adding a stream for datagram"); + stream_map.AddStream(FullSequence{4, 0}, 2), + stream_map.AddStream(FullSequence{4, 1}, 6); + EXPECT_EQ(stream_map.GetStreamForSequence(FullSequence(4, 0)), 2); + EXPECT_EQ(stream_map.GetStreamForSequence(FullSequence(4, 1)), 6); + EXPECT_EQ(stream_map.GetStreamForSequence(FullSequence(4, 2)), std::nullopt); + stream_map.RemoveStream(FullSequence{4, 1}, 6); + EXPECT_EQ(stream_map.GetStreamForSequence(FullSequence(4, 1)), std::nullopt); } TEST_F(SubscribeWindowTest, UpdateStartEnd) {