MoQT tracking of largest object ID in each group. Fixes a problem with generating SUBSCRIBE_DONE. PiperOrigin-RevId: 643339799
diff --git a/quiche/quic/moqt/moqt_integration_test.cc b/quiche/quic/moqt/moqt_integration_test.cc index 9690e4d..fbc1474 100644 --- a/quiche/quic/moqt/moqt_integration_test.cc +++ b/quiche/quic/moqt/moqt_integration_test.cc
@@ -162,7 +162,7 @@ EXPECT_TRUE(success); } -TEST_F(MoqtIntegrationTest, AnnounceSuccessSendDatainResponse) { +TEST_F(MoqtIntegrationTest, AnnounceSuccessSendDataInResponse) { EstablishSession(); // Set up the server to subscribe to "data" track for the namespace announce @@ -180,6 +180,10 @@ client_->session()->AddLocalTrack(FullTrackName{"test", "data"}, MoqtForwardingPreference::kGroup, &queue); queue.AddObject(MemSliceFromString("object data"), /*key=*/true); + bool received_subscribe_ok = false; + EXPECT_CALL(server_visitor, OnReply(_, _)).WillOnce([&]() { + received_subscribe_ok = true; + }); client_->session()->Announce( "test", [](absl::string_view, std::optional<MoqtAnnounceErrorReason>) {}); @@ -201,6 +205,7 @@ }); bool success = test_harness_.RunUntilWithDefaultTimeout( [&]() { return received_object; }); + EXPECT_TRUE(received_subscribe_ok); EXPECT_TRUE(success); }
diff --git a/quiche/quic/moqt/moqt_outgoing_queue.cc b/quiche/quic/moqt/moqt_outgoing_queue.cc index 84c2a85..f12732b 100644 --- a/quiche/quic/moqt/moqt_outgoing_queue.cc +++ b/quiche/quic/moqt/moqt_outgoing_queue.cc
@@ -52,16 +52,13 @@ MoqtOutgoingQueue_requires_kGroup, window.forwarding_preference() != MoqtForwardingPreference::kGroup) << "MoqtOutgoingQueue currently only supports kGroup."; - if (window.HasEnd()) { - // TODO: support this (this would require changing the logic for closing the - // stream below). - return absl::UnimplementedError("SUBSCRIBEs with an end are not supported"); - } return [this, &window]() { for (size_t i = 0; i < queue_.size(); ++i) { const uint64_t group_id = first_group_in_queue() + i; const Group& group = queue_[i]; - const bool is_last_group = (i == queue_.size() - 1); + const bool is_last_group = + ((i == queue_.size() - 1) || + !window.InWindow(FullSequence{group_id + 1, 0})); for (size_t j = 0; j < group.size(); ++j) { const FullSequence sequence{group_id, j}; if (!window.InWindow(sequence)) {
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index 3c41df1..b3934cc 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -4,6 +4,7 @@ #include "quiche/quic/moqt/moqt_session.h" +#include <algorithm> #include <array> #include <cstdint> #include <memory> @@ -724,7 +725,6 @@ session_->used_track_aliases_.insert(message.track_alias); } FullSequence start; - std::optional<FullSequence> end; if (message.start_group.has_value()) { // The filter is AbsoluteStart or AbsoluteRange. QUIC_BUG_IF(quic_bug_invalid_subscribe, !message.start_object.has_value()) @@ -742,16 +742,17 @@ --start.object; } } - if (message.end_group.has_value()) { - end = FullSequence(*message.end_group, message.end_object.has_value() - ? *message.end_object - : UINT64_MAX); - } LocalTrack::Visitor::PublishPastObjectsCallback publish_past_objects; - SubscribeWindow window = - SubscribeWindow(message.subscribe_id, track.forwarding_preference(), - track.next_sequence(), start, end); if (start < track.next_sequence() && track.visitor() != nullptr) { + // Pull a copy of objects that have already been published. + FullSequence end_of_past_subscription{ + message.end_group.has_value() ? *message.end_group : UINT64_MAX, + message.end_object.has_value() ? *message.end_object : UINT64_MAX}; + end_of_past_subscription = + std::min(end_of_past_subscription, track.next_sequence()); + SubscribeWindow window = + SubscribeWindow(message.subscribe_id, track.forwarding_preference(), + track.next_sequence(), start, end_of_past_subscription); absl::StatusOr<LocalTrack::Visitor::PublishPastObjectsCallback> past_objects_available = track.visitor()->OnSubscribeForPast(window); if (!past_objects_available.ok()) { @@ -767,11 +768,14 @@ SendOrBufferMessage(session_->framer_.SerializeSubscribeOk(subscribe_ok)); QUIC_DLOG(INFO) << ENDPOINT << "Created subscription for " << message.track_namespace << ":" << message.track_name; - if (end.has_value()) { - track.AddWindow(message.subscribe_id, start.group, start.object, end->group, - end->object); - } else { + if (!message.end_group.has_value()) { track.AddWindow(message.subscribe_id, start.group, start.object); + } else if (message.end_object.has_value()) { + track.AddWindow(message.subscribe_id, start.group, start.object, + *message.end_group, *message.end_object); + } else { + track.AddWindow(message.subscribe_id, start.group, start.object, + *message.end_group); } session_->local_track_by_subscribe_id_.emplace(message.subscribe_id, track.full_track_name());
diff --git a/quiche/quic/moqt/moqt_subscribe_windows.cc b/quiche/quic/moqt/moqt_subscribe_windows.cc index 3e07a0c..d8dcf60 100644 --- a/quiche/quic/moqt/moqt_subscribe_windows.cc +++ b/quiche/quic/moqt/moqt_subscribe_windows.cc
@@ -68,9 +68,6 @@ next_to_backfill_ = std::nullopt; } } - // TODO(martinduke): If the subscription ends in a full group with undefined - // object sequence, the only way to know to send SUBSCRIBE_DONE is by getting - // an upstream SUBSCRIBE_DONE. return (!next_to_backfill_.has_value() && end_.has_value() && *end_ <= sequence); }
diff --git a/quiche/quic/moqt/moqt_track.h b/quiche/quic/moqt/moqt_track.h index eaa20c6..8306f89 100644 --- a/quiche/quic/moqt/moqt_track.h +++ b/quiche/quic/moqt/moqt_track.h
@@ -9,6 +9,7 @@ #include <optional> #include <vector> +#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "quiche/quic/moqt/moqt_messages.h" @@ -72,6 +73,24 @@ } void AddWindow(uint64_t subscribe_id, uint64_t start_group, + uint64_t start_object, uint64_t end_group) { + // The end object might be unknown. + auto it = max_object_ids_.find(end_group); + if (end_group >= next_sequence_.group) { + // Group is not fully published yet, so end object is unknown. + windows_.AddWindow(subscribe_id, next_sequence_, start_group, + start_object, end_group, UINT64_MAX); + return; + } + while (it == max_object_ids_.end()) { + // Find the latest group ID that actually had an object. + it = max_object_ids_.find(--end_group); + } + windows_.AddWindow(subscribe_id, next_sequence_, start_group, start_object, + end_group, it->second); + } + + void AddWindow(uint64_t subscribe_id, uint64_t start_group, uint64_t start_object, uint64_t end_group, uint64_t end_object) { windows_.AddWindow(subscribe_id, next_sequence_, start_group, start_object, @@ -88,8 +107,18 @@ // Updates next_sequence_ if |sequence| is larger. void SentSequence(FullSequence sequence) { + if (sequence.group > next_sequence_.group) { + max_object_ids_[next_sequence_.group] = next_sequence_.object - 1; + } + if (sequence.group < next_sequence_.group) { + // Late object from previous group. + auto it = max_object_ids_.find(sequence.group); + if (it == max_object_ids_.end() || it->second < sequence.object) { + max_object_ids_[sequence.group] = sequence.object; + } + } if (next_sequence_ <= sequence) { - next_sequence_ = {sequence.group, sequence.object + 1}; + next_sequence_ = sequence.next(); } } @@ -116,6 +145,9 @@ // By recording the highest observed sequence number, MoQT can interpret // relative sequence numbers in SUBSCRIBEs. FullSequence next_sequence_ = {0, 0}; + // The highest object ID observed for each group ID. An entry only exists for + // the group after a later group has been observed. + absl::flat_hash_map<uint64_t, uint64_t> max_object_ids_; Visitor* visitor_; };
diff --git a/quiche/quic/moqt/moqt_track_test.cc b/quiche/quic/moqt/moqt_track_test.cc index 36fab35..40a52bb 100644 --- a/quiche/quic/moqt/moqt_track_test.cc +++ b/quiche/quic/moqt/moqt_track_test.cc
@@ -52,6 +52,38 @@ EXPECT_EQ(track_.GetWindow(0), nullptr); } +TEST_F(LocalTrackTest, GroupSubscriptionUsesMaxObjectId) { + // Populate max_object_ids_ + track_.SentSequence(FullSequence(0, 0)); + track_.SentSequence(FullSequence(1, 0)); + track_.SentSequence(FullSequence(1, 1)); + // Skip Group 2 + track_.SentSequence(FullSequence(3, 0)); + track_.SentSequence(FullSequence(3, 1)); + track_.SentSequence(FullSequence(3, 2)); + track_.SentSequence(FullSequence(3, 3)); + track_.SentSequence(FullSequence(4, 0)); + track_.SentSequence(FullSequence(4, 1)); + track_.SentSequence(FullSequence(4, 2)); + track_.SentSequence(FullSequence(4, 3)); + track_.SentSequence(FullSequence(4, 4)); + EXPECT_EQ(track_.next_sequence(), FullSequence(4, 5)); + track_.AddWindow(0, 1, 1, 3); + SubscribeWindow* window = track_.GetWindow(0); + EXPECT_TRUE(window->InWindow(FullSequence(3, 3))); + EXPECT_FALSE(window->InWindow(FullSequence(3, 4))); + // End on an empty group. + track_.AddWindow(1, 1, 1, 2); + window = track_.GetWindow(1); + EXPECT_TRUE(window->InWindow(FullSequence(1, 1))); + EXPECT_FALSE(window->InWindow(FullSequence(1, 2))); + // End on an group in progress. + track_.AddWindow(2, 1, 1, 4); + window = track_.GetWindow(2); + EXPECT_TRUE(window->InWindow(FullSequence(4, 9))); + EXPECT_FALSE(window->InWindow(FullSequence(5, 0))); +} + TEST_F(LocalTrackTest, ShouldSend) { track_.AddWindow(0, 4, 1); EXPECT_TRUE(track_.HasSubscriber());