Handle MoQT Objects that arrive before SUBSCRIBE_OK.
PiperOrigin-RevId: 605309658
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index f78d3f5..0c425a1 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -399,16 +399,28 @@
}
}
auto it = session_->remote_tracks_.find(message.track_alias);
+ RemoteTrack::Visitor* visitor = nullptr;
+ absl::string_view track_namespace;
+ absl::string_view track_name;
if (it == session_->remote_tracks_.end()) {
- // No SUBSCRIBE_OK received with this alias, return.
- return;
+ // SUBSCRIBE_OK has not arrived yet, but deliver it.
+ auto subscribe_it = session_->active_subscribes_.find(message.subscribe_id);
+ if (subscribe_it == session_->active_subscribes_.end()) {
+ return;
+ }
+ visitor = subscribe_it->second.visitor;
+ track_namespace = subscribe_it->second.message.track_namespace;
+ track_name = subscribe_it->second.message.track_name;
+ } else {
+ visitor = it->second.visitor();
+ track_namespace = it->second.full_track_name().track_namespace;
+ track_name = it->second.full_track_name().track_name;
}
- RemoteTrack& subscription = it->second;
- if (subscription.visitor() != nullptr) {
- subscription.visitor()->OnObjectFragment(
- subscription.full_track_name(), message.group_id, message.object_id,
- message.object_send_order, message.forwarding_preference, payload,
- end_of_message);
+ if (visitor != nullptr) {
+ visitor->OnObjectFragment(
+ FullTrackName(track_namespace, track_name), message.group_id,
+ message.object_id, message.object_send_order,
+ message.forwarding_preference, payload, end_of_message);
}
partial_object_.clear();
}
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index 69021e5..1f4f9c0 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -99,6 +99,12 @@
session->remote_track_aliases_.try_emplace(name, track_alias);
}
+ static void AddActiveSubscribe(MoqtSession* session, uint64_t subscribe_id,
+ MoqtSubscribe& subscribe,
+ RemoteTrack::Visitor* visitor) {
+ session->active_subscribes_[subscribe_id] = {subscribe, visitor};
+ }
+
static void AddSubscription(MoqtSession* session, FullTrackName& name,
uint64_t subscribe_id, uint64_t track_alias,
uint64_t start_group, uint64_t start_object) {
@@ -614,6 +620,60 @@
object_stream->OnObjectMessage(object, payload, true); // complete the object
}
+TEST_F(MoqtSessionTest, ObjectBeforeSubscribeOk) {
+ MockRemoteTrackVisitor visitor_;
+ FullTrackName ftn("foo", "bar");
+ std::string payload = "deadbeef";
+ MoqtSubscribe subscribe = {
+ /*subscribe_id=*/1,
+ /*track_alias=*/2,
+ /*track_namespace=*/ftn.track_namespace,
+ /*track_name=*/ftn.track_name,
+ /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
+ /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
+ /*end_group=*/std::nullopt,
+ /*end_object=*/std::nullopt,
+ };
+ MoqtSessionPeer::AddActiveSubscribe(&session_, 1, subscribe, &visitor_);
+ MoqtObject object = {
+ /*subscribe_id=*/1,
+ /*track_alias=*/2,
+ /*group_sequence=*/0,
+ /*object_sequence=*/0,
+ /*object_send_order=*/0,
+ /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
+ /*payload_length=*/8,
+ };
+ StrictMock<webtransport::test::MockStream> mock_stream;
+ std::unique_ptr<MoqtParserVisitor> object_stream =
+ MoqtSessionPeer::CreateUniStream(&session_, &mock_stream);
+
+ EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _, _))
+ .WillOnce([&](const FullTrackName& full_track_name,
+ uint64_t group_sequence, uint64_t object_sequence,
+ uint64_t object_send_order,
+ MoqtForwardingPreference forwarding_preference,
+ absl::string_view payload, bool end_of_message) {
+ EXPECT_EQ(full_track_name, ftn);
+ EXPECT_EQ(group_sequence, object.group_id);
+ EXPECT_EQ(object_sequence, object.object_id);
+ });
+ EXPECT_CALL(mock_stream, GetStreamId())
+ .WillRepeatedly(Return(kIncomingUniStreamId));
+ object_stream->OnObjectMessage(object, payload, true);
+
+ // SUBSCRIBE_OK arrives
+ MoqtSubscribeOk ok = {
+ /*subscribe_id=*/1,
+ /*expires=*/quic::QuicTimeDelta::FromMilliseconds(0),
+ };
+ StrictMock<webtransport::test::MockStream> mock_control_stream;
+ std::unique_ptr<MoqtParserVisitor> control_stream =
+ MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream);
+ EXPECT_CALL(visitor_, OnReply(_, _)).Times(1);
+ control_stream->OnSubscribeOkMessage(ok);
+}
+
TEST_F(MoqtSessionTest, CreateUniStreamAndSend) {
StrictMock<webtransport::test::MockStream> mock_stream;
FullTrackName ftn("foo", "bar");