FIN MoQT streams when directed by the TrackPublisher.
Update MoqtLiveRelayQueue to send these signals.
Reset streams when MoqtLiveRelayQueue purges a Group from cache.
PiperOrigin-RevId: 697269570
diff --git a/quiche/quic/moqt/moqt_cached_object.cc b/quiche/quic/moqt/moqt_cached_object.cc
index 332edb4..ef78a61 100644
--- a/quiche/quic/moqt/moqt_cached_object.cc
+++ b/quiche/quic/moqt/moqt_cached_object.cc
@@ -20,6 +20,7 @@
object.payload->data(), object.payload->length(),
[retained_pointer = object.payload](const char*) {});
}
+ result.fin_after_this = object.fin_after_this;
return result;
}
diff --git a/quiche/quic/moqt/moqt_cached_object.h b/quiche/quic/moqt/moqt_cached_object.h
index be556e2..1fe524a 100644
--- a/quiche/quic/moqt/moqt_cached_object.h
+++ b/quiche/quic/moqt/moqt_cached_object.h
@@ -21,6 +21,7 @@
MoqtObjectStatus status;
MoqtPriority publisher_priority;
std::shared_ptr<quiche::QuicheMemSlice> payload;
+ bool fin_after_this; // This is the last object before FIN.
};
// Transforms a CachedObject into a PublishedObject.
diff --git a/quiche/quic/moqt/moqt_live_relay_queue.cc b/quiche/quic/moqt/moqt_live_relay_queue.cc
index 3ca2532..a81139e 100644
--- a/quiche/quic/moqt/moqt_live_relay_queue.cc
+++ b/quiche/quic/moqt/moqt_live_relay_queue.cc
@@ -8,6 +8,7 @@
#include <optional>
#include <vector>
+#include "absl/base/attributes.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "quiche/quic/moqt/moqt_cached_object.h"
@@ -22,11 +23,49 @@
namespace moqt {
+bool MoqtLiveRelayQueue::AddFin(FullSequence sequence) {
+ switch (forwarding_preference_) {
+ case MoqtForwardingPreference::kDatagram:
+ return false;
+ case MoqtForwardingPreference::kTrack:
+ // TODO(martinduke): Support if it doesn't go away.
+ return false;
+ case MoqtForwardingPreference::kSubgroup:
+ break;
+ }
+ auto group_it = queue_.find(sequence.group);
+ if (group_it == queue_.end()) {
+ // Group does not exist.
+ return false;
+ }
+ Group& group = group_it->second;
+ auto subgroup_it = group.subgroups.find(
+ SubgroupPriority{publisher_priority_, sequence.subgroup});
+ if (subgroup_it == group.subgroups.end()) {
+ // Subgroup does not exist.
+ return false;
+ }
+ if (subgroup_it->second.empty()) {
+ // Cannot FIN an empty subgroup.
+ return false;
+ }
+ if (subgroup_it->second.rbegin()->first != sequence.object) {
+ // The queue does not yet have the last object.
+ return false;
+ }
+ subgroup_it->second.rbegin()->second.fin_after_this = true;
+ for (MoqtObjectListener* listener : listeners_) {
+ listener->OnNewFinAvailable(sequence);
+ }
+ return true;
+}
+
// TODO(martinduke): Unless Track Forwarding preference goes away, support it.
bool MoqtLiveRelayQueue::AddRawObject(FullSequence sequence,
MoqtObjectStatus status,
MoqtPriority priority,
- absl::string_view payload) {
+ absl::string_view payload, bool fin) {
+ bool last_object_in_stream = fin;
if (queue_.size() == kMaxQueuedGroups) {
if (queue_.begin()->first > sequence.group) {
QUICHE_DLOG(INFO) << "Skipping object from group " << sequence.group
@@ -35,6 +74,9 @@
}
if (queue_.find(sequence.group) == queue_.end()) {
// Erase the oldest group.
+ for (MoqtObjectListener* listener : listeners_) {
+ listener->OnGroupAbandoned(queue_.begin()->first);
+ }
queue_.erase(queue_.begin());
}
}
@@ -74,17 +116,22 @@
}
auto subgroup_it = group.subgroups.try_emplace(
SubgroupPriority{priority, sequence.subgroup});
- auto& object_queue = subgroup_it.first->second;
- if (!object_queue.empty()) { // Check if the new object is valid
- auto last_object = object_queue.rbegin();
- if (last_object->first >= sequence.object) {
- QUICHE_DLOG(INFO) << "Skipping object because it does not increase the "
- << "object ID monotonically in the subgroup.";
+ auto& subgroup = subgroup_it.first->second;
+ if (!subgroup.empty()) { // Check if the new object is valid
+ CachedObject& last_object = subgroup.rbegin()->second;
+ if (last_object.fin_after_this) {
+ QUICHE_DLOG(INFO) << "Skipping object because it is after the end of the "
+ << "subgroup";
return false;
}
- if (last_object->second.status == MoqtObjectStatus::kEndOfSubgroup) {
- QUICHE_DLOG(INFO) << "Skipping object because it is after the end of the "
- << "subgroup.";
+ // If last_object has stream-ending status, it should have been caught by
+ // the fin_after_this check above.
+ QUICHE_DCHECK(last_object.status != MoqtObjectStatus::kEndOfSubgroup &&
+ last_object.status != MoqtObjectStatus::kEndOfGroup &&
+ last_object.status != MoqtObjectStatus::kEndOfTrack);
+ if (last_object.sequence.object >= sequence.object) {
+ QUICHE_DLOG(INFO) << "Skipping object because it does not increase the "
+ << "object ID monotonically in the subgroup.";
return false;
}
}
@@ -95,13 +142,17 @@
if (sequence.object >= group.next_object) {
group.next_object = sequence.object + 1;
}
+ // Anticipate stream FIN with most non-normal objects.
switch (status) {
case MoqtObjectStatus::kEndOfTrack:
end_of_track_ = sequence;
- break;
- case MoqtObjectStatus::kEndOfGroup:
+ ABSL_FALLTHROUGH_INTENDED;
case MoqtObjectStatus::kGroupDoesNotExist:
+ case MoqtObjectStatus::kEndOfGroup:
group.complete = true;
+ ABSL_FALLTHROUGH_INTENDED;
+ case MoqtObjectStatus::kEndOfSubgroup:
+ last_object_in_stream = true;
break;
default:
break;
@@ -111,8 +162,8 @@
? nullptr
: std::make_shared<quiche::QuicheMemSlice>(quiche::QuicheBuffer::Copy(
quiche::SimpleBufferAllocator::Get(), payload));
- object_queue.emplace(sequence.object,
- CachedObject{sequence, status, priority, slice});
+ subgroup.emplace(sequence.object, CachedObject{sequence, status, priority,
+ slice, last_object_in_stream});
for (MoqtObjectListener* listener : listeners_) {
listener->OnNewObjectAvailable(sequence);
}
@@ -139,7 +190,7 @@
}
// Find an object with ID of at least sequence.object.
auto object_it = subgroup.lower_bound(sequence.object);
- if (object_it == subgroup_it->second.end()) {
+ if (object_it == subgroup.end()) {
// No object after the last one received.
return std::nullopt;
}
diff --git a/quiche/quic/moqt/moqt_live_relay_queue.h b/quiche/quic/moqt/moqt_live_relay_queue.h
index 6b23e7a..97287da 100644
--- a/quiche/quic/moqt/moqt_live_relay_queue.h
+++ b/quiche/quic/moqt/moqt_live_relay_queue.h
@@ -53,13 +53,21 @@
// occur. A false return value might result in a session error on the
// inbound session, but this queue is the only place that retains enough state
// to check.
- bool AddObject(FullSequence sequence, MoqtObjectStatus status) {
- return AddRawObject(sequence, status, publisher_priority_, "");
+ bool AddObject(FullSequence sequence, MoqtObjectStatus status,
+ bool fin = false) {
+ return AddRawObject(sequence, status, publisher_priority_, "", fin);
}
- bool AddObject(FullSequence sequence, absl::string_view object) {
+ bool AddObject(FullSequence sequence, absl::string_view object,
+ bool fin = false) {
return AddRawObject(sequence, MoqtObjectStatus::kNormal,
- publisher_priority_, object);
+ publisher_priority_, object, fin);
}
+ // Record a received FIN that did not come with the last object.
+ // If the forwarding preference is kDatagram or kTrack, |sequence| is ignored.
+ // Otherwise, |sequence| is used to determine which stream is being FINed. If
+ // the object ID does not match the last object ID in the stream, no action
+ // is taken.
+ bool AddFin(FullSequence sequence);
// MoqtTrackPublisher implementation.
const FullTrackName& GetTrackName() const override { return track_; }
@@ -112,12 +120,12 @@
struct Group {
uint64_t next_object = 0;
- bool complete = false;
+ bool complete = false; // If true, kEndOfGroup has been received.
absl::btree_map<SubgroupPriority, Subgroup> subgroups;
};
bool AddRawObject(FullSequence sequence, MoqtObjectStatus status,
- MoqtPriority priority, absl::string_view payload);
+ MoqtPriority priority, absl::string_view payload, bool fin);
FullTrackName track_;
MoqtForwardingPreference forwarding_preference_;
diff --git a/quiche/quic/moqt/moqt_live_relay_queue_test.cc b/quiche/quic/moqt/moqt_live_relay_queue_test.cc
index 4e0bcbf..165b8fc 100644
--- a/quiche/quic/moqt/moqt_live_relay_queue_test.cc
+++ b/quiche/quic/moqt/moqt_live_relay_queue_test.cc
@@ -67,6 +67,8 @@
}
}
+ MOCK_METHOD(void, OnNewFinAvailable, (FullSequence sequence));
+ MOCK_METHOD(void, OnGroupAbandoned, (uint64_t group_id));
MOCK_METHOD(void, CloseStreamForGroup, (uint64_t group_id), ());
MOCK_METHOD(void, CloseStreamForSubgroup,
(uint64_t group_id, uint64_t subgroup_id), ());
@@ -198,9 +200,11 @@
EXPECT_CALL(queue, PublishObject(2, 0, "e"));
EXPECT_CALL(queue, PublishObject(2, 1, "f"));
EXPECT_CALL(queue, CloseStreamForGroup(2));
+ EXPECT_CALL(queue, OnGroupAbandoned(0));
EXPECT_CALL(queue, PublishObject(3, 0, "g"));
EXPECT_CALL(queue, PublishObject(3, 1, "h"));
EXPECT_CALL(queue, CloseStreamForGroup(3));
+ EXPECT_CALL(queue, OnGroupAbandoned(1));
EXPECT_CALL(queue, PublishObject(4, 0, "i"));
EXPECT_CALL(queue, PublishObject(4, 1, "j"));
}
@@ -237,9 +241,11 @@
EXPECT_CALL(queue, PublishObject(2, 0, "e"));
EXPECT_CALL(queue, PublishObject(2, 1, "f"));
EXPECT_CALL(queue, CloseStreamForGroup(2));
+ EXPECT_CALL(queue, OnGroupAbandoned(0));
EXPECT_CALL(queue, PublishObject(3, 0, "g"));
EXPECT_CALL(queue, PublishObject(3, 1, "h"));
EXPECT_CALL(queue, CloseStreamForGroup(3));
+ EXPECT_CALL(queue, OnGroupAbandoned(1));
EXPECT_CALL(queue, PublishObject(4, 0, "i"));
EXPECT_CALL(queue, PublishObject(4, 1, "j"));
@@ -286,9 +292,11 @@
EXPECT_CALL(queue, PublishObject(2, 0, "e"));
EXPECT_CALL(queue, PublishObject(2, 1, "f"));
EXPECT_CALL(queue, CloseStreamForGroup(2));
+ EXPECT_CALL(queue, OnGroupAbandoned(0));
EXPECT_CALL(queue, PublishObject(3, 0, "g"));
EXPECT_CALL(queue, PublishObject(3, 1, "h"));
EXPECT_CALL(queue, CloseStreamForGroup(3));
+ EXPECT_CALL(queue, OnGroupAbandoned(1));
EXPECT_CALL(queue, PublishObject(4, 0, "i"));
EXPECT_CALL(queue, PublishObject(4, 1, "j"));
}
@@ -428,6 +436,36 @@
EXPECT_FALSE(queue.AddObject(FullSequence{0, 0, 2}, "b"));
}
+TEST(MoqtLiveRelayQueue, AddObjectWithFin) {
+ TestMoqtLiveRelayQueue queue;
+ {
+ testing::InSequence seq;
+ EXPECT_CALL(queue, PublishObject(0, 0, "a"));
+ }
+ EXPECT_TRUE(queue.AddObject(FullSequence{0, 0, 0}, "a", true));
+ std::optional<PublishedObject> object =
+ queue.GetCachedObject(FullSequence{0, 0});
+ ASSERT_TRUE(object.has_value());
+ EXPECT_EQ(object->status, MoqtObjectStatus::kNormal);
+ EXPECT_TRUE(object->fin_after_this);
+}
+
+TEST(MoqtLiveRelayQueue, LateFin) {
+ TestMoqtLiveRelayQueue queue;
+ {
+ testing::InSequence seq;
+ EXPECT_CALL(queue, PublishObject(0, 0, "a"));
+ }
+ EXPECT_TRUE(queue.AddObject(FullSequence{0, 0, 0}, "a", false));
+ EXPECT_CALL(queue, OnNewFinAvailable(FullSequence{0, 0}));
+ EXPECT_TRUE(queue.AddFin(FullSequence{0, 0}));
+ std::optional<PublishedObject> object =
+ queue.GetCachedObject(FullSequence{0, 0});
+ ASSERT_TRUE(object.has_value());
+ EXPECT_EQ(object->status, MoqtObjectStatus::kNormal);
+ EXPECT_TRUE(object->fin_after_this);
+}
+
} // namespace
} // namespace moqt::test
diff --git a/quiche/quic/moqt/moqt_outgoing_queue.cc b/quiche/quic/moqt/moqt_outgoing_queue.cc
index 51b2478..ef53465 100644
--- a/quiche/quic/moqt/moqt_outgoing_queue.cc
+++ b/quiche/quic/moqt/moqt_outgoing_queue.cc
@@ -41,6 +41,9 @@
if (queue_.size() == kMaxQueuedGroups) {
queue_.erase(queue_.begin());
+ for (MoqtObjectListener* listener : listeners_) {
+ listener->OnGroupAbandoned(current_group_id_ - kMaxQueuedGroups + 1);
+ }
}
queue_.emplace_back();
++current_group_id_;
@@ -52,9 +55,11 @@
void MoqtOutgoingQueue::AddRawObject(MoqtObjectStatus status,
quiche::QuicheMemSlice payload) {
FullSequence sequence{current_group_id_, queue_.back().size()};
+ bool fin = forwarding_preference_ == MoqtForwardingPreference::kSubgroup &&
+ status == MoqtObjectStatus::kEndOfGroup;
queue_.back().push_back(CachedObject{
sequence, status, publisher_priority_,
- std::make_shared<quiche::QuicheMemSlice>(std::move(payload))});
+ std::make_shared<quiche::QuicheMemSlice>(std::move(payload)), fin});
for (MoqtObjectListener* listener : listeners_) {
listener->OnNewObjectAvailable(sequence);
}
@@ -73,12 +78,7 @@
const std::vector<CachedObject>& group =
queue_[sequence.group - first_group_in_queue()];
if (sequence.object >= group.size()) {
- if (sequence.group == current_group_id_) {
- return std::nullopt;
- }
- return PublishedObject{FullSequence{sequence.group, sequence.object},
- MoqtObjectStatus::kObjectDoesNotExist,
- publisher_priority_, quiche::QuicheMemSlice()};
+ return std::nullopt;
}
QUICHE_DCHECK(sequence == group[sequence.object].sequence);
return CachedObjectToPublishedObject(group[sequence.object]);
diff --git a/quiche/quic/moqt/moqt_outgoing_queue_test.cc b/quiche/quic/moqt/moqt_outgoing_queue_test.cc
index 059f549..deb8b31 100644
--- a/quiche/quic/moqt/moqt_outgoing_queue_test.cc
+++ b/quiche/quic/moqt/moqt_outgoing_queue_test.cc
@@ -66,6 +66,8 @@
}
}
+ MOCK_METHOD(void, OnNewFinAvailable, (FullSequence sequence));
+ MOCK_METHOD(void, OnGroupAbandoned, (uint64_t group_id));
MOCK_METHOD(void, CloseStreamForGroup, (uint64_t group_id), ());
MOCK_METHOD(void, PublishObject,
(uint64_t group_id, uint64_t object_id,
diff --git a/quiche/quic/moqt/moqt_publisher.h b/quiche/quic/moqt/moqt_publisher.h
index be9c4f9..a0897f4 100644
--- a/quiche/quic/moqt/moqt_publisher.h
+++ b/quiche/quic/moqt/moqt_publisher.h
@@ -26,6 +26,7 @@
MoqtObjectStatus status;
MoqtPriority publisher_priority;
quiche::QuicheMemSlice payload;
+ bool fin_after_this = false;
};
// MoqtObjectListener is an interface for any entity that is listening for
@@ -38,6 +39,12 @@
// available. The object payload itself may be retrieved via GetCachedObject
// method of the associated track publisher.
virtual void OnNewObjectAvailable(FullSequence sequence) = 0;
+ // Notifies that a pure FIN has arrived following |sequence|.
+ virtual void OnNewFinAvailable(FullSequence sequence) = 0;
+
+ // No further object will be published for the given group, usually due to a
+ // timeout. The owner of the Listener may want to reset the relevant streams.
+ virtual void OnGroupAbandoned(uint64_t group_id) = 0;
// Notifies that the Publisher is being destroyed, so no more objects are
// coming.
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index 7c52c68..0f8fad7 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -1061,6 +1061,39 @@
"Publisher is gone");
}
+void MoqtSession::PublishedSubscription::OnNewFinAvailable(
+ FullSequence sequence) {
+ if (!window_.InWindow(sequence)) {
+ return;
+ }
+ std::optional<webtransport::StreamId> stream_id =
+ stream_map().GetStreamForSequence(sequence);
+ if (!stream_id.has_value()) {
+ return;
+ }
+ webtransport::Stream* raw_stream =
+ session_->session_->GetStreamById(*stream_id);
+ if (raw_stream == nullptr) {
+ return;
+ }
+ OutgoingDataStream* stream =
+ static_cast<OutgoingDataStream*>(raw_stream->visitor());
+ stream->Fin(sequence);
+}
+
+void MoqtSession::PublishedSubscription::OnGroupAbandoned(uint64_t group_id) {
+ std::vector<webtransport::StreamId> streams =
+ stream_map().GetStreamsForGroup(group_id);
+ for (webtransport::StreamId stream_id : streams) {
+ webtransport::Stream* raw_stream =
+ session_->session_->GetStreamById(stream_id);
+ if (raw_stream == nullptr) {
+ continue;
+ }
+ raw_stream->ResetWithUserCode(kResetCodeTimedOut);
+ }
+}
+
void MoqtSession::PublishedSubscription::Backfill() {
const FullSequence start = window_.start();
const FullSequence end = track_publisher_->GetLargestSequence();
@@ -1261,6 +1294,17 @@
}
}
+void MoqtSession::OutgoingDataStream::Fin(FullSequence last_object) {
+ if (next_object_ <= last_object) {
+ // There is still data to send, do nothing.
+ return;
+ }
+ // All data has already been sent; send a pure FIN.
+ bool success = stream_->SendFin();
+ QUICHE_BUG_IF(OutgoingDataStream_fin_failed, !success)
+ << "Writing pure FIN failed.";
+}
+
void MoqtSession::OutgoingDataStream::SendNextObject(
PublishedSubscription& subscription, PublishedObject object) {
QUICHE_DCHECK(next_object_ <= object.sequence);
@@ -1291,7 +1335,6 @@
session_->framer_.SerializeObjectHeader(
header, GetMessageTypeForForwardingPreference(forwarding_preference),
!stream_header_written_);
- bool fin = false;
switch (forwarding_preference) {
case MoqtForwardingPreference::kTrack:
if (object.status == MoqtObjectStatus::kEndOfGroup ||
@@ -1301,20 +1344,10 @@
} else {
next_object_.object = header.object_id + 1;
}
- fin = object.status == MoqtObjectStatus::kEndOfTrack ||
- !subscription.InWindow(next_object_);
break;
case MoqtForwardingPreference::kSubgroup:
- // TODO(martinduke): EndOfGroup and EndOfTrack implies the ability to
- // close other streams/subgroups. PublishedObject should contain a boolean
- // if the stream is safe to close.
next_object_.object = header.object_id + 1;
- fin = object.status == MoqtObjectStatus::kEndOfTrack ||
- object.status == MoqtObjectStatus::kEndOfGroup ||
- object.status == MoqtObjectStatus::kEndOfSubgroup ||
- object.status == MoqtObjectStatus::kGroupDoesNotExist ||
- !subscription.InWindow(next_object_);
break;
case MoqtForwardingPreference::kDatagram:
@@ -1327,7 +1360,7 @@
std::array<absl::string_view, 2> write_vector = {
serialized_header.AsStringView(), object.payload.AsStringView()};
quiche::StreamWriteOptions options;
- options.set_send_fin(fin);
+ options.set_send_fin(object.fin_after_this);
absl::Status write_status = stream_->Writev(write_vector, options);
if (!write_status.ok()) {
QUICHE_BUG(MoqtSession_SendNextObject_write_failed)
@@ -1339,7 +1372,7 @@
}
QUIC_DVLOG(1) << "Stream " << stream_->GetStreamId() << " successfully wrote "
- << object.sequence << ", fin = " << fin
+ << object.sequence << ", fin = " << object.fin_after_this
<< ", next: " << next_object_;
stream_header_written_ = true;
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
index e24268b..04972b1 100644
--- a/quiche/quic/moqt/moqt_session.h
+++ b/quiche/quic/moqt/moqt_session.h
@@ -313,6 +313,8 @@
// This is only called for objects that have just arrived.
void OnNewObjectAvailable(FullSequence sequence) override;
void OnTrackPublisherGone() override;
+ void OnNewFinAvailable(FullSequence sequence) override;
+ void OnGroupAbandoned(uint64_t group_id) override;
void ProcessObjectAck(const MoqtObjectAck& message) {
if (monitoring_interface_ == nullptr) {
return;
@@ -402,6 +404,10 @@
// stream becomes write-blocked or closed.
void SendObjects(PublishedSubscription& subscription);
+ // Sends a pure FIN on the stream, if the last object sent matches
+ // |last_object|. Otherwise, does nothing.
+ void Fin(FullSequence last_object);
+
// Recomputes the send order and updates it for the associated stream.
void UpdateSendOrder(PublishedSubscription& subscription);
@@ -412,8 +418,8 @@
// the stream and returns nullptr.
PublishedSubscription* GetSubscriptionIfValid();
- // Actually sends an object on the stream; the object MUST be
- // `next_object_`.
+ // Actually sends an object on the stream; the object MUST have object ID
+ // of at least `next_object_`.
void SendNextObject(PublishedSubscription& subscription,
PublishedObject object);
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index 643364c..bc31674 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -1187,13 +1187,252 @@
});
EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 0))).WillRepeatedly([] {
return PublishedObject{FullSequence(5, 0), MoqtObjectStatus::kNormal, 127,
- MemSliceFromString("deadbeef")};
+ MemSliceFromString("deadbeef"), false};
});
EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 1))).WillRepeatedly([] {
return std::optional<PublishedObject>();
});
subscription->OnNewObjectAvailable(FullSequence(5, 0));
EXPECT_TRUE(correct_message);
+ EXPECT_FALSE(fin);
+}
+
+TEST_F(MoqtSessionTest, FinDataStreamFromCache) {
+ FullTrackName ftn("foo", "bar");
+ auto track = SetupPublisher(ftn, MoqtForwardingPreference::kSubgroup,
+ FullSequence(4, 2));
+ MoqtObjectListener* subscription =
+ MoqtSessionPeer::AddSubscription(&session_, track, 0, 2, 5, 0);
+
+ EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream())
+ .WillOnce(Return(true));
+ bool fin = false;
+ webtransport::test::MockStream mock_stream;
+ EXPECT_CALL(mock_stream, CanWrite()).WillRepeatedly([&] { return !fin; });
+ EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream())
+ .WillOnce(Return(&mock_stream));
+ std::unique_ptr<webtransport::StreamVisitor> stream_visitor;
+ EXPECT_CALL(mock_stream, SetVisitor(_))
+ .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> visitor) {
+ stream_visitor = std::move(visitor);
+ });
+ EXPECT_CALL(mock_stream, visitor()).WillOnce([&] {
+ return stream_visitor.get();
+ });
+ EXPECT_CALL(mock_stream, GetStreamId())
+ .WillRepeatedly(Return(kOutgoingUniStreamId));
+ EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId))
+ .WillRepeatedly(Return(&mock_stream));
+
+ // Verify first six message fields are sent correctly
+ bool correct_message = false;
+ const std::string kExpectedMessage = {0x04, 0x02, 0x05, 0x00, 0x00};
+ EXPECT_CALL(mock_stream, Writev(_, _))
+ .WillOnce([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ correct_message = absl::StartsWith(data[0], kExpectedMessage);
+ fin = options.send_fin();
+ return absl::OkStatus();
+ });
+ EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 0))).WillRepeatedly([] {
+ return PublishedObject{FullSequence(5, 0), MoqtObjectStatus::kNormal, 127,
+ MemSliceFromString("deadbeef"), true};
+ });
+ EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 1))).WillRepeatedly([] {
+ return std::optional<PublishedObject>();
+ });
+ subscription->OnNewObjectAvailable(FullSequence(5, 0));
+ EXPECT_TRUE(correct_message);
+ EXPECT_TRUE(fin);
+}
+
+TEST_F(MoqtSessionTest, GroupAbandoned) {
+ FullTrackName ftn("foo", "bar");
+ auto track = SetupPublisher(ftn, MoqtForwardingPreference::kSubgroup,
+ FullSequence(4, 2));
+ MoqtObjectListener* subscription =
+ MoqtSessionPeer::AddSubscription(&session_, track, 0, 2, 5, 0);
+
+ EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream())
+ .WillOnce(Return(true));
+ bool fin = false;
+ webtransport::test::MockStream mock_stream;
+ EXPECT_CALL(mock_stream, CanWrite()).WillRepeatedly([&] { return !fin; });
+ EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream())
+ .WillOnce(Return(&mock_stream));
+ std::unique_ptr<webtransport::StreamVisitor> stream_visitor;
+ EXPECT_CALL(mock_stream, SetVisitor(_))
+ .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> visitor) {
+ stream_visitor = std::move(visitor);
+ });
+ EXPECT_CALL(mock_stream, visitor()).WillOnce([&] {
+ return stream_visitor.get();
+ });
+ EXPECT_CALL(mock_stream, GetStreamId())
+ .WillRepeatedly(Return(kOutgoingUniStreamId));
+ EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId))
+ .WillRepeatedly(Return(&mock_stream));
+
+ // Verify first six message fields are sent correctly
+ bool correct_message = false;
+ const std::string kExpectedMessage = {0x04, 0x02, 0x05, 0x00, 0x00};
+ EXPECT_CALL(mock_stream, Writev(_, _))
+ .WillOnce([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ correct_message = absl::StartsWith(data[0], kExpectedMessage);
+ fin |= options.send_fin();
+ return absl::OkStatus();
+ });
+ EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 0))).WillRepeatedly([] {
+ return PublishedObject{FullSequence(5, 0), MoqtObjectStatus::kNormal, 127,
+ MemSliceFromString("deadbeef"), true};
+ });
+ EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 1))).WillRepeatedly([] {
+ return std::optional<PublishedObject>();
+ });
+ subscription->OnNewObjectAvailable(FullSequence(5, 0));
+ EXPECT_TRUE(correct_message);
+ EXPECT_TRUE(fin);
+
+ EXPECT_CALL(mock_stream, ResetWithUserCode(kResetCodeTimedOut));
+ subscription->OnGroupAbandoned(5);
+}
+
+TEST_F(MoqtSessionTest, LateFinDataStream) {
+ FullTrackName ftn("foo", "bar");
+ auto track = SetupPublisher(ftn, MoqtForwardingPreference::kSubgroup,
+ FullSequence(4, 2));
+ MoqtObjectListener* subscription =
+ MoqtSessionPeer::AddSubscription(&session_, track, 0, 2, 5, 0);
+
+ EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream())
+ .WillOnce(Return(true));
+ bool fin = false;
+ webtransport::test::MockStream mock_stream;
+ EXPECT_CALL(mock_stream, CanWrite()).WillRepeatedly([&] { return !fin; });
+ EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream())
+ .WillOnce(Return(&mock_stream));
+ std::unique_ptr<webtransport::StreamVisitor> stream_visitor;
+ EXPECT_CALL(mock_stream, SetVisitor(_))
+ .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> visitor) {
+ stream_visitor = std::move(visitor);
+ });
+ EXPECT_CALL(mock_stream, visitor()).WillRepeatedly([&] {
+ return stream_visitor.get();
+ });
+ EXPECT_CALL(mock_stream, GetStreamId())
+ .WillRepeatedly(Return(kOutgoingUniStreamId));
+ EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId))
+ .WillRepeatedly(Return(&mock_stream));
+
+ // Verify first six message fields are sent correctly
+ bool correct_message = false;
+ const std::string kExpectedMessage = {0x04, 0x02, 0x05, 0x00, 0x00};
+ EXPECT_CALL(mock_stream, Writev(_, _))
+ .WillOnce([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ correct_message = absl::StartsWith(data[0], kExpectedMessage);
+ fin = options.send_fin();
+ return absl::OkStatus();
+ });
+ EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 0))).WillRepeatedly([] {
+ return PublishedObject{FullSequence(5, 0), MoqtObjectStatus::kNormal, 127,
+ MemSliceFromString("deadbeef"), false};
+ });
+ EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 1))).WillRepeatedly([] {
+ return std::optional<PublishedObject>();
+ });
+ subscription->OnNewObjectAvailable(FullSequence(5, 0));
+ EXPECT_TRUE(correct_message);
+ EXPECT_FALSE(fin);
+ fin = false;
+ EXPECT_CALL(mock_stream, Writev(_, _))
+ .WillOnce([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ EXPECT_TRUE(data.empty());
+ fin = options.send_fin();
+ return absl::OkStatus();
+ });
+ subscription->OnNewFinAvailable(FullSequence(5, 0));
+}
+
+TEST_F(MoqtSessionTest, SeparateFinForFutureObject) {
+ FullTrackName ftn("foo", "bar");
+ auto track = SetupPublisher(ftn, MoqtForwardingPreference::kSubgroup,
+ FullSequence(4, 2));
+ MoqtObjectListener* subscription =
+ MoqtSessionPeer::AddSubscription(&session_, track, 0, 2, 5, 0);
+
+ EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream())
+ .WillOnce(Return(true));
+ bool fin = false;
+ webtransport::test::MockStream mock_stream;
+ EXPECT_CALL(mock_stream, CanWrite()).WillRepeatedly([&] { return !fin; });
+ EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream())
+ .WillOnce(Return(&mock_stream));
+ std::unique_ptr<webtransport::StreamVisitor> stream_visitor;
+ EXPECT_CALL(mock_stream, SetVisitor(_))
+ .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> visitor) {
+ stream_visitor = std::move(visitor);
+ });
+ EXPECT_CALL(mock_stream, visitor()).WillRepeatedly([&] {
+ return stream_visitor.get();
+ });
+ EXPECT_CALL(mock_stream, GetStreamId())
+ .WillRepeatedly(Return(kOutgoingUniStreamId));
+ EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId))
+ .WillRepeatedly(Return(&mock_stream));
+
+ // Verify first six message fields are sent correctly
+ bool correct_message = false;
+ const std::string kExpectedMessage = {0x04, 0x00, 0x02, 0x05, 0x00, 0x00};
+ EXPECT_CALL(mock_stream, Writev(_, _))
+ .WillOnce([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ correct_message = absl::StartsWith(data[0], kExpectedMessage);
+ fin = options.send_fin();
+ return absl::OkStatus();
+ });
+ EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 0))).WillRepeatedly([] {
+ return PublishedObject{FullSequence(5, 0), MoqtObjectStatus::kNormal, 127,
+ MemSliceFromString("deadbeef"), false};
+ });
+ EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 1))).WillRepeatedly([] {
+ return std::optional<PublishedObject>();
+ });
+ subscription->OnNewObjectAvailable(FullSequence(5, 0));
+ EXPECT_FALSE(fin);
+ // Try to deliver (5,1), but fail.
+ EXPECT_CALL(mock_stream, CanWrite()).WillRepeatedly([&] { return false; });
+ EXPECT_CALL(*track, GetCachedObject(_)).Times(0);
+ EXPECT_CALL(mock_stream, Writev(_, _)).Times(0);
+ subscription->OnNewObjectAvailable(FullSequence(5, 1));
+ // Notify that FIN arrived, but do nothing with it because (5, 1) isn't sent.
+ EXPECT_CALL(mock_stream, Writev(_, _)).Times(0);
+ subscription->OnNewFinAvailable(FullSequence(5, 1));
+
+ // Reopen the window.
+ correct_message = false;
+ // object id, payload length, status.
+ const std::string kExpectedMessage2 = {0x01, 0x00, 0x03};
+ EXPECT_CALL(mock_stream, CanWrite()).WillRepeatedly([&] { return true; });
+ EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 1))).WillRepeatedly([] {
+ return PublishedObject{FullSequence(5, 1), MoqtObjectStatus::kEndOfGroup,
+ 127, MemSliceFromString(""), true};
+ });
+ EXPECT_CALL(*track, GetCachedObject(FullSequence(5, 2))).WillRepeatedly([] {
+ return std::optional<PublishedObject>();
+ });
+ EXPECT_CALL(mock_stream, Writev(_, _))
+ .WillOnce([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ correct_message = absl::StartsWith(data[0], kExpectedMessage2);
+ fin = options.send_fin();
+ return absl::OkStatus();
+ });
+ stream_visitor->OnCanWrite();
+ EXPECT_TRUE(correct_message);
+ EXPECT_TRUE(fin);
}
// TODO: Test operation with multiple streams.
diff --git a/quiche/quic/moqt/moqt_subscribe_windows.cc b/quiche/quic/moqt/moqt_subscribe_windows.cc
index 4637bd1..6e9c8e9 100644
--- a/quiche/quic/moqt/moqt_subscribe_windows.cc
+++ b/quiche/quic/moqt/moqt_subscribe_windows.cc
@@ -4,6 +4,7 @@
#include "quiche/quic/moqt/moqt_subscribe_windows.h"
+#include <cstdint>
#include <optional>
#include <vector>
@@ -21,35 +22,61 @@
return (!end_.has_value() || seq <= *end_);
}
+ReducedSequenceIndex::ReducedSequenceIndex(
+ FullSequence sequence, MoqtForwardingPreference preference) {
+ switch (preference) {
+ case MoqtForwardingPreference::kTrack:
+ sequence_ = FullSequence(0, 0, 0);
+ break;
+ case MoqtForwardingPreference::kSubgroup:
+ sequence_ = FullSequence(sequence.group, sequence.subgroup, 0);
+ break;
+ case MoqtForwardingPreference::kDatagram:
+ sequence_ = FullSequence(sequence.group, sequence.object, 0);
+ return;
+ }
+}
+
std::optional<webtransport::StreamId> SendStreamMap::GetStreamForSequence(
FullSequence sequence) const {
- ReducedSequenceIndex index(sequence, forwarding_preference_);
- auto stream_it = send_streams_.find(index);
- if (stream_it == send_streams_.end()) {
+ FullSequence index =
+ ReducedSequenceIndex(sequence, forwarding_preference_).sequence();
+ auto group_it = send_streams_.find(index.group);
+ if (group_it == send_streams_.end()) {
return std::nullopt;
}
- return stream_it->second;
+ auto subgroup_it = group_it->second.find(index.subgroup);
+ if (subgroup_it == group_it->second.end()) {
+ return std::nullopt;
+ }
+ return subgroup_it->second;
}
void SendStreamMap::AddStream(FullSequence sequence,
webtransport::StreamId stream_id) {
- ReducedSequenceIndex index(sequence, forwarding_preference_);
- if (forwarding_preference_ == MoqtForwardingPreference::kDatagram) {
- QUIC_BUG(quic_bug_moqt_draft_03_01) << "Adding a stream for datagram";
- return;
- }
- auto [stream_it, success] = send_streams_.emplace(index, stream_id);
+ FullSequence index =
+ ReducedSequenceIndex(sequence, forwarding_preference_).sequence();
+ auto [it, result] = send_streams_.insert({index.group, Group()});
+ auto [sg, success] = it->second.try_emplace(index.subgroup, stream_id);
QUIC_BUG_IF(quic_bug_moqt_draft_03_02, !success) << "Stream already added";
}
void SendStreamMap::RemoveStream(FullSequence sequence,
webtransport::StreamId stream_id) {
- ReducedSequenceIndex index(sequence, forwarding_preference_);
- QUICHE_DCHECK(send_streams_.contains(index) &&
- send_streams_.find(index)->second == stream_id)
- << "Requested to remove a stream ID that does not match the one in the "
- "map";
- send_streams_.erase(index);
+ FullSequence index =
+ ReducedSequenceIndex(sequence, forwarding_preference_).sequence();
+ auto group_it = send_streams_.find(index.group);
+ if (group_it == send_streams_.end()) {
+ QUICHE_NOTREACHED();
+ return;
+ }
+ auto subgroup_it = group_it->second.find(index.subgroup);
+ if (subgroup_it == group_it->second.end() ||
+ subgroup_it->second != stream_id) {
+ QUICHE_NOTREACHED();
+ return;
+ }
+ group_it->second.erase(subgroup_it);
}
bool SubscribeWindow::UpdateStartEnd(FullSequence start,
@@ -66,25 +93,25 @@
return true;
}
-ReducedSequenceIndex::ReducedSequenceIndex(
- FullSequence sequence, MoqtForwardingPreference preference) {
- switch (preference) {
- case MoqtForwardingPreference::kTrack:
- sequence_ = FullSequence(0, 0);
- break;
- case MoqtForwardingPreference::kSubgroup:
- sequence_ = FullSequence(sequence.group, 0);
- break;
- case MoqtForwardingPreference::kDatagram:
- sequence_ = sequence;
- return;
- }
-}
-
std::vector<webtransport::StreamId> SendStreamMap::GetAllStreams() const {
std::vector<webtransport::StreamId> ids;
- for (const auto& [index, id] : send_streams_) {
- ids.push_back(id);
+ for (const auto& [group, subgroup_map] : send_streams_) {
+ for (const auto& [subgroup, stream_id] : subgroup_map) {
+ ids.push_back(stream_id);
+ }
+ }
+ return ids;
+}
+
+std::vector<webtransport::StreamId> SendStreamMap::GetStreamsForGroup(
+ uint64_t group_id) const {
+ std::vector<webtransport::StreamId> ids;
+ auto it = send_streams_.find(group_id);
+ if (it == send_streams_.end()) {
+ return ids;
+ }
+ for (const auto& [subgroup, stream_id] : it->second) {
+ ids.push_back(stream_id);
}
return ids;
}
diff --git a/quiche/quic/moqt/moqt_subscribe_windows.h b/quiche/quic/moqt/moqt_subscribe_windows.h
index b850f3e..d6f8ea4 100644
--- a/quiche/quic/moqt/moqt_subscribe_windows.h
+++ b/quiche/quic/moqt/moqt_subscribe_windows.h
@@ -9,7 +9,7 @@
#include <optional>
#include <vector>
-#include "absl/container/flat_hash_map.h"
+#include "absl/container/btree_map.h"
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/common/platform/api/quiche_export.h"
#include "quiche/web_transport/web_transport.h"
@@ -36,7 +36,7 @@
: start_(start), end_(end) {}
bool InWindow(const FullSequence& seq) const;
- bool HasEnd() const { return end_.has_value(); }
+ const std::optional<FullSequence>& end() const { return end_; }
FullSequence start() const { return start_; }
// Updates the subscription window. Returns true if the update is valid (in
@@ -44,6 +44,7 @@
bool UpdateStartEnd(FullSequence start, std::optional<FullSequence> end);
private:
+ // The subgroups in these sequences have no meaning.
FullSequence start_;
std::optional<FullSequence> end_;
};
@@ -61,6 +62,7 @@
bool operator!=(const ReducedSequenceIndex& other) const {
return sequence_ != other.sequence_;
}
+ FullSequence sequence() { return sequence_; }
template <typename H>
friend H AbslHashValue(H h, const ReducedSequenceIndex& m) {
@@ -82,10 +84,12 @@
void AddStream(FullSequence sequence, webtransport::StreamId stream_id);
void RemoveStream(FullSequence sequence, webtransport::StreamId stream_id);
std::vector<webtransport::StreamId> GetAllStreams() const;
+ std::vector<webtransport::StreamId> GetStreamsForGroup(
+ uint64_t group_id) const;
private:
- absl::flat_hash_map<ReducedSequenceIndex, webtransport::StreamId>
- send_streams_;
+ using Group = absl::btree_map<uint64_t, webtransport::StreamId>;
+ absl::btree_map<uint64_t, Group> send_streams_;
MoqtForwardingPreference forwarding_preference_;
};
diff --git a/quiche/quic/moqt/moqt_subscribe_windows_test.cc b/quiche/quic/moqt/moqt_subscribe_windows_test.cc
index ba8687e..1fcc43d 100644
--- a/quiche/quic/moqt/moqt_subscribe_windows_test.cc
+++ b/quiche/quic/moqt/moqt_subscribe_windows_test.cc
@@ -59,8 +59,13 @@
TEST_F(SubscribeWindowTest, AddQueryRemoveStreamIdDatagram) {
SendStreamMap stream_map(MoqtForwardingPreference::kDatagram);
- EXPECT_QUIC_BUG(stream_map.AddStream(FullSequence{4, 0}, 2),
- "Adding a stream for datagram");
+ stream_map.AddStream(FullSequence{4, 0}, 2),
+ stream_map.AddStream(FullSequence{4, 1}, 6);
+ EXPECT_EQ(stream_map.GetStreamForSequence(FullSequence(4, 0)), 2);
+ EXPECT_EQ(stream_map.GetStreamForSequence(FullSequence(4, 1)), 6);
+ EXPECT_EQ(stream_map.GetStreamForSequence(FullSequence(4, 2)), std::nullopt);
+ stream_map.RemoveStream(FullSequence{4, 1}, 6);
+ EXPECT_EQ(stream_map.GetStreamForSequence(FullSequence(4, 1)), std::nullopt);
}
TEST_F(SubscribeWindowTest, UpdateStartEnd) {