Update MoQT FETCH to draft-11. This is mainly adding Absolute Joining FETCH. This does not include an API to send an Absolute Joining FETCH. PiperOrigin-RevId: 776739732
diff --git a/quiche/quic/moqt/moqt_framer.cc b/quiche/quic/moqt/moqt_framer.cc index 66fb3af..842a356 100644 --- a/quiche/quic/moqt/moqt_framer.cc +++ b/quiche/quic/moqt/moqt_framer.cc
@@ -10,6 +10,7 @@ #include <optional> #include <string> #include <utility> +#include <variant> #include <vector> #include "absl/status/status.h" @@ -676,13 +677,16 @@ } quiche::QuicheBuffer MoqtFramer::SerializeFetch(const MoqtFetch& message) { - if (!message.joining_fetch.has_value() && - (message.end_group < message.start_object.group || - (message.end_group == message.start_object.group && - message.end_object.has_value() && - *message.end_object < message.start_object.object))) { - QUICHE_BUG(MoqtFramer_invalid_fetch) << "Invalid FETCH object range"; - return quiche::QuicheBuffer(); + if (std::holds_alternative<StandaloneFetch>(message.fetch)) { + const StandaloneFetch& standalone_fetch = + std::get<StandaloneFetch>(message.fetch); + if (standalone_fetch.end_group < standalone_fetch.start_object.group || + (standalone_fetch.end_group == standalone_fetch.start_object.group && + standalone_fetch.end_object.has_value() && + *standalone_fetch.end_object < standalone_fetch.start_object.object)) { + QUICHE_BUG(MoqtFramer_invalid_fetch) << "Invalid FETCH object range"; + return quiche::QuicheBuffer(); + } } KeyValuePairList parameters; VersionSpecificParametersToKeyValuePairList(message.parameters, parameters); @@ -691,28 +695,42 @@ << "Serializing invalid MoQT parameters"; return quiche::QuicheBuffer(); } - if (message.joining_fetch.has_value()) { + if (std::holds_alternative<StandaloneFetch>(message.fetch)) { + const StandaloneFetch& standalone_fetch = + std::get<StandaloneFetch>(message.fetch); return SerializeControlMessage( MoqtMessageType::kFetch, WireVarInt62(message.fetch_id), WireUint8(message.subscriber_priority), WireDeliveryOrder(message.group_order), - WireVarInt62(FetchType::kJoining), - WireVarInt62(message.joining_fetch->joining_subscribe_id), - WireVarInt62(message.joining_fetch->preceding_group_offset), + WireVarInt62(FetchType::kStandalone), + WireFullTrackName(standalone_fetch.full_track_name), + WireVarInt62(standalone_fetch.start_object.group), + WireVarInt62(standalone_fetch.start_object.object), + WireVarInt62(standalone_fetch.end_group), + WireVarInt62(standalone_fetch.end_object.has_value() + ? *standalone_fetch.end_object + 1 + : 0), WireKeyValuePairList(parameters)); } + uint64_t subscribe_id; + uint64_t joining_start; + if (std::holds_alternative<JoiningFetchRelative>(message.fetch)) { + const JoiningFetchRelative& joining_fetch = + std::get<JoiningFetchRelative>(message.fetch); + subscribe_id = joining_fetch.joining_subscribe_id; + joining_start = joining_fetch.joining_start; + } else { + const JoiningFetchAbsolute& joining_fetch = + std::get<JoiningFetchAbsolute>(message.fetch); + subscribe_id = joining_fetch.joining_subscribe_id; + joining_start = joining_fetch.joining_start; + } return SerializeControlMessage( MoqtMessageType::kFetch, WireVarInt62(message.fetch_id), WireUint8(message.subscriber_priority), WireDeliveryOrder(message.group_order), - WireVarInt62(FetchType::kStandalone), - WireFullTrackName(message.full_track_name), - WireVarInt62(message.start_object.group), - WireVarInt62(message.start_object.object), - WireVarInt62(message.end_group), - WireVarInt62(message.end_object.has_value() ? *message.end_object + 1 - : 0), - WireKeyValuePairList(parameters)); + WireVarInt62(message.fetch.index() + 1), WireVarInt62(subscribe_id), + WireVarInt62(joining_start), WireKeyValuePairList(parameters)); } quiche::QuicheBuffer MoqtFramer::SerializeFetchCancel(
diff --git a/quiche/quic/moqt/moqt_framer_test.cc b/quiche/quic/moqt/moqt_framer_test.cc index ab64df2..f507d80 100644 --- a/quiche/quic/moqt/moqt_framer_test.cc +++ b/quiche/quic/moqt/moqt_framer_test.cc
@@ -447,11 +447,13 @@ /*subscribe_id =*/1, /*subscriber_priority=*/2, /*group_order=*/MoqtDeliveryOrder::kAscending, - /*joining_fetch=*/std::nullopt, - /*full_track_name=*/FullTrackName{"foo", "bar"}, - /*start_object=*/Location{1, 2}, - /*end_group=*/1, - /*end_object=*/1, + /*fetch=*/ + StandaloneFetch{ + FullTrackName("foo", "bar"), + /*start_object=*/Location{1, 2}, + /*end_group=*/1, + /*end_object=*/1, + }, /*parameters=*/ VersionSpecificParameters(AuthTokenType::kOutOfBand, "baz"), }; @@ -459,8 +461,8 @@ EXPECT_QUIC_BUG(buffer = framer_.SerializeFetch(fetch), "Invalid FETCH object range"); EXPECT_EQ(buffer.size(), 0); - fetch.end_group = 0; - fetch.end_object = std::nullopt; + std::get<StandaloneFetch>(fetch.fetch).end_group = 0; + std::get<StandaloneFetch>(fetch.fetch).end_object = std::nullopt; EXPECT_QUIC_BUG(buffer = framer_.SerializeFetch(fetch), "Invalid FETCH object range"); EXPECT_EQ(buffer.size(), 0); @@ -498,8 +500,16 @@ EXPECT_EQ(*end_group, 5); } -TEST_F(MoqtFramerSimpleTest, JoiningFetch) { - JoiningFetchMessage message; +TEST_F(MoqtFramerSimpleTest, RelativeJoiningFetch) { + RelativeJoiningFetchMessage message; + quiche::QuicheBuffer buffer = + framer_.SerializeFetch(std::get<MoqtFetch>(message.structured_data())); + EXPECT_EQ(buffer.size(), message.total_message_size()); + EXPECT_EQ(buffer.AsStringView(), message.PacketSample()); +} + +TEST_F(MoqtFramerSimpleTest, AbsoluteJoiningFetch) { + AbsoluteJoiningFetchMessage message; quiche::QuicheBuffer buffer = framer_.SerializeFetch(std::get<MoqtFetch>(message.structured_data())); EXPECT_EQ(buffer.size(), message.total_message_size());
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h index f5ec8c2..43d4ba6 100644 --- a/quiche/quic/moqt/moqt_messages.h +++ b/quiche/quic/moqt/moqt_messages.h
@@ -14,10 +14,10 @@ #include <string> #include <tuple> #include <utility> +#include <variant> #include <vector> #include "absl/container/btree_map.h" -#include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -709,28 +709,68 @@ enum class QUICHE_EXPORT FetchType : uint64_t { kStandalone = 0x1, - kJoining = 0x2, + kRelativeJoining = 0x2, + kAbsoluteJoining = 0x3, }; -struct JoiningFetch { - JoiningFetch(uint64_t joining_subscribe_id, uint64_t preceding_group_offset) +struct StandaloneFetch { + StandaloneFetch() = default; + StandaloneFetch(FullTrackName full_track_name, Location start_object, + uint64_t end_group, std::optional<uint64_t> end_object) + : full_track_name(full_track_name), + start_object(start_object), + end_group(end_group), + end_object(end_object) {} + FullTrackName full_track_name; + Location start_object; // subgroup is ignored + uint64_t end_group; + std::optional<uint64_t> end_object; + bool operator==(const StandaloneFetch& other) const { + return full_track_name == other.full_track_name && + start_object == other.start_object && end_group == other.end_group && + end_object == other.end_object; + } + bool operator!=(const StandaloneFetch& other) const { + return !(*this == other); + } +}; + +struct JoiningFetchRelative { + JoiningFetchRelative(uint64_t joining_subscribe_id, uint64_t joining_start) : joining_subscribe_id(joining_subscribe_id), - preceding_group_offset(preceding_group_offset) {} + joining_start(joining_start) {} uint64_t joining_subscribe_id; - uint64_t preceding_group_offset; + uint64_t joining_start; + bool operator==(const JoiningFetchRelative& other) const { + return joining_subscribe_id == other.joining_subscribe_id && + joining_start == other.joining_start; + } + bool operator!=(const JoiningFetchRelative& other) const { + return !(*this == other); + } +}; + +struct JoiningFetchAbsolute { + JoiningFetchAbsolute(uint64_t joining_subscribe_id, uint64_t joining_start) + : joining_subscribe_id(joining_subscribe_id), + joining_start(joining_start) {} + uint64_t joining_subscribe_id; + uint64_t joining_start; + bool operator==(const JoiningFetchAbsolute& other) const { + return joining_subscribe_id == other.joining_subscribe_id && + joining_start == other.joining_start; + } + bool operator!=(const JoiningFetchAbsolute& other) const { + return !(*this == other); + } }; struct QUICHE_EXPORT MoqtFetch { uint64_t fetch_id; MoqtPriority subscriber_priority; std::optional<MoqtDeliveryOrder> group_order; - // If joining_fetch has a value, then the parser will not populate the name - // and ranges. The session will populate them instead. - std::optional<JoiningFetch> joining_fetch; - FullTrackName full_track_name; - Location start_object; - uint64_t end_group; - std::optional<uint64_t> end_object; + std::variant<StandaloneFetch, JoiningFetchRelative, JoiningFetchAbsolute> + fetch; VersionSpecificParameters parameters; };
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc index 710848c..d29b5a2 100644 --- a/quiche/quic/moqt/moqt_parser.cc +++ b/quiche/quic/moqt/moqt_parser.cc
@@ -787,7 +787,6 @@ size_t MoqtControlParser::ProcessFetch(quic::QuicDataReader& reader) { MoqtFetch fetch; uint8_t group_order; - uint64_t end_object; uint64_t type; if (!reader.ReadVarInt62(&fetch.fetch_id) || !reader.ReadUInt8(&fetch.subscriber_priority) || @@ -799,32 +798,45 @@ return 0; } switch (static_cast<FetchType>(type)) { - case FetchType::kJoining: { + case FetchType::kAbsoluteJoining: { uint64_t joining_subscribe_id; - uint64_t preceding_group_offset; + uint64_t joining_start; if (!reader.ReadVarInt62(&joining_subscribe_id) || - !reader.ReadVarInt62(&preceding_group_offset)) { + !reader.ReadVarInt62(&joining_start)) { return 0; } - fetch.joining_fetch = - JoiningFetch{joining_subscribe_id, preceding_group_offset}; + fetch.fetch = JoiningFetchAbsolute{joining_subscribe_id, joining_start}; + break; + } + case FetchType::kRelativeJoining: { + uint64_t joining_subscribe_id; + uint64_t joining_start; + if (!reader.ReadVarInt62(&joining_subscribe_id) || + !reader.ReadVarInt62(&joining_start)) { + return 0; + } + fetch.fetch = JoiningFetchRelative{joining_subscribe_id, joining_start}; break; } case FetchType::kStandalone: { - fetch.joining_fetch = std::nullopt; - if (!ReadFullTrackName(reader, fetch.full_track_name) || - !reader.ReadVarInt62(&fetch.start_object.group) || - !reader.ReadVarInt62(&fetch.start_object.object) || - !reader.ReadVarInt62(&fetch.end_group) || + fetch.fetch = StandaloneFetch(); + StandaloneFetch& standalone_fetch = + std::get<StandaloneFetch>(fetch.fetch); + uint64_t end_object; + if (!ReadFullTrackName(reader, standalone_fetch.full_track_name) || + !reader.ReadVarInt62(&standalone_fetch.start_object.group) || + !reader.ReadVarInt62(&standalone_fetch.start_object.object) || + !reader.ReadVarInt62(&standalone_fetch.end_group) || !reader.ReadVarInt62(&end_object)) { return 0; } - fetch.end_object = + standalone_fetch.end_object = end_object == 0 ? std::optional<uint64_t>() : (end_object - 1); - if (fetch.end_group < fetch.start_object.group || - (fetch.end_group == fetch.start_object.group && - fetch.end_object.has_value() && - *fetch.end_object < fetch.start_object.object)) { + if (standalone_fetch.end_group < standalone_fetch.start_object.group || + (standalone_fetch.end_group == standalone_fetch.start_object.group && + standalone_fetch.end_object.has_value() && + *standalone_fetch.end_object < + standalone_fetch.start_object.object)) { ParseError("End object comes before start object in FETCH"); return 0; }
diff --git a/quiche/quic/moqt/moqt_parser_test.cc b/quiche/quic/moqt/moqt_parser_test.cc index 9b8bd29..b527d59 100644 --- a/quiche/quic/moqt/moqt_parser_test.cc +++ b/quiche/quic/moqt/moqt_parser_test.cc
@@ -1313,10 +1313,22 @@ EXPECT_EQ(visitor_.parsing_error_, "Invalid number of namespace elements"); } -TEST_F(MoqtMessageSpecificTest, JoiningFetch) { +TEST_F(MoqtMessageSpecificTest, RelativeJoiningFetch) { webtransport::test::InMemoryStream stream(/*stream_id=*/0); MoqtControlParser parser(kRawQuic, &stream, visitor_); - JoiningFetchMessage message; + RelativeJoiningFetchMessage message; + stream.Receive(message.PacketSample(), false); + parser.ReadAndDispatchMessages(); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_EQ(visitor_.parsing_error_, std::nullopt); + EXPECT_TRUE(visitor_.last_message_.has_value() && + message.EqualFieldValues(*visitor_.last_message_)); +} + +TEST_F(MoqtMessageSpecificTest, AbsoluteJoiningFetch) { + webtransport::test::InMemoryStream stream(/*stream_id=*/0); + MoqtControlParser parser(kRawQuic, &stream, visitor_); + AbsoluteJoiningFetchMessage message; stream.Receive(message.PacketSample(), false); parser.ReadAndDispatchMessages(); EXPECT_EQ(visitor_.messages_received_, 1);
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index 2167463..8f9f585 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -6,6 +6,7 @@ #include <algorithm> #include <array> +#include <cstddef> #include <cstdint> #include <memory> #include <optional> @@ -483,29 +484,26 @@ return false; } MoqtFetch message; - message.full_track_name = name; + message.fetch = StandaloneFetch(name, start, end_group, end_object); message.fetch_id = next_request_id_; next_request_id_ += 2; - message.start_object = start; - message.end_group = end_group; - message.end_object = end_object; message.subscriber_priority = priority; message.group_order = delivery_order; message.parameters = parameters; SendControlMessage(framer_.SerializeFetch(message)); - QUIC_DLOG(INFO) << ENDPOINT << "Sent FETCH message for " - << message.full_track_name; - auto fetch = std::make_unique<UpstreamFetch>(message, std::move(callback)); + QUIC_DLOG(INFO) << ENDPOINT << "Sent FETCH message for " << name; + auto fetch = std::make_unique<UpstreamFetch>( + message, std::get<StandaloneFetch>(message.fetch), std::move(callback)); upstream_by_id_.emplace(message.fetch_id, std::move(fetch)); return true; } -bool MoqtSession::JoiningFetch(const FullTrackName& name, - SubscribeRemoteTrack::Visitor* visitor, - uint64_t num_previous_groups, - VersionSpecificParameters parameters) { +bool MoqtSession::RelativeJoiningFetch(const FullTrackName& name, + SubscribeRemoteTrack::Visitor* visitor, + uint64_t num_previous_groups, + VersionSpecificParameters parameters) { QUICHE_DCHECK(name.IsValid()); - return JoiningFetch( + return RelativeJoiningFetch( name, visitor, [this, id = next_request_id_](std::unique_ptr<MoqtFetchTask> fetch_task) { // Move the fetch_task to the subscribe to plumb into its visitor. @@ -522,13 +520,11 @@ parameters); } -bool MoqtSession::JoiningFetch(const FullTrackName& name, - SubscribeRemoteTrack::Visitor* visitor, - FetchResponseCallback callback, - uint64_t num_previous_groups, - MoqtPriority priority, - std::optional<MoqtDeliveryOrder> delivery_order, - VersionSpecificParameters parameters) { +bool MoqtSession::RelativeJoiningFetch( + const FullTrackName& name, SubscribeRemoteTrack::Visitor* visitor, + FetchResponseCallback callback, uint64_t num_previous_groups, + MoqtPriority priority, std::optional<MoqtDeliveryOrder> delivery_order, + VersionSpecificParameters parameters) { QUICHE_DCHECK(name.IsValid()); if ((next_request_id_ + 2) >= peer_max_request_id_) { QUIC_DLOG(INFO) << ENDPOINT << "Tried to send JOINING_FETCH with ID " @@ -554,12 +550,12 @@ next_request_id_ += 2; fetch.subscriber_priority = priority; fetch.group_order = delivery_order; - fetch.joining_fetch = {subscribe.request_id, num_previous_groups}; + fetch.fetch = JoiningFetchRelative{subscribe.request_id, num_previous_groups}; fetch.parameters = parameters; SendControlMessage(framer_.SerializeFetch(fetch)); QUIC_DLOG(INFO) << ENDPOINT << "Sent Joining FETCH message for " << name; auto upstream_fetch = - std::make_unique<UpstreamFetch>(fetch, std::move(callback)); + std::make_unique<UpstreamFetch>(fetch, name, std::move(callback)); upstream_by_id_.emplace(fetch.fetch_id, std::move(upstream_fetch)); return true; } @@ -1359,8 +1355,20 @@ Location start_object; uint64_t end_group; std::optional<uint64_t> end_object; - if (message.joining_fetch.has_value()) { - uint64_t joining_subscribe_id = message.joining_fetch->joining_subscribe_id; + if (std::holds_alternative<StandaloneFetch>(message.fetch)) { + const StandaloneFetch& standalone_fetch = + std::get<StandaloneFetch>(message.fetch); + track_name = standalone_fetch.full_track_name; + start_object = standalone_fetch.start_object; + end_group = standalone_fetch.end_group; + end_object = standalone_fetch.end_object; + } else { + uint64_t joining_subscribe_id = + std::holds_alternative<JoiningFetchRelative>(message.fetch) + ? std::get<struct JoiningFetchRelative>(message.fetch) + .joining_subscribe_id + : std::get<JoiningFetchAbsolute>(message.fetch) + .joining_subscribe_id; auto it = session_->published_subscriptions_.find(joining_subscribe_id); if (it == session_->published_subscriptions_.end()) { QUIC_DLOG(INFO) << ENDPOINT << "Received a JOINING_FETCH for " @@ -1383,20 +1391,26 @@ } track_name = it->second->publisher().GetTrackName(); Location fetch_end = it->second->GetWindowStart(); - if (message.joining_fetch->preceding_group_offset > fetch_end.group) { - start_object = Location(0, 0); + if (std::holds_alternative<JoiningFetchRelative>(message.fetch)) { + const JoiningFetchRelative& relative_fetch = + std::get<JoiningFetchRelative>(message.fetch); + if (relative_fetch.joining_start > fetch_end.group) { + start_object = Location(0, 0); + } else { + start_object = + Location(fetch_end.group - relative_fetch.joining_start, 0); + } } else { - start_object = Location( - fetch_end.group - message.joining_fetch->preceding_group_offset, 0); + const JoiningFetchAbsolute& absolute_fetch = + std::get<JoiningFetchAbsolute>(message.fetch); + start_object = + Location(fetch_end.group - absolute_fetch.joining_start, 0); } end_group = fetch_end.group; end_object = fetch_end.object - 1; - } else { - track_name = message.full_track_name; - start_object = message.start_object; - end_group = message.end_group; - end_object = message.end_object; } + // The check for end_object < start_object is done in + // MoqtTrackPublisher::Fetch(). QUIC_DLOG(INFO) << ENDPOINT << "Received a FETCH for " << track_name; absl::StatusOr<std::shared_ptr<MoqtTrackPublisher>> track_publisher = session_->publisher_->GetTrack(track_name);
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index 67336e2..63cbab8 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -156,19 +156,19 @@ // in the FETCH will not be filled by with ObjectDoesNotExist. If the FETCH // fails for any reason, the application will not receive a notification; it // will just appear to be missing objects. - bool JoiningFetch(const FullTrackName& name, - SubscribeRemoteTrack::Visitor* visitor, - uint64_t num_previous_groups, - VersionSpecificParameters parameters) override; + bool RelativeJoiningFetch(const FullTrackName& name, + SubscribeRemoteTrack::Visitor* visitor, + uint64_t num_previous_groups, + VersionSpecificParameters parameters) override; // Sends both a SUBSCRIBE and a joining FETCH, beginning |num_previous_groups| // groups before the current group. The application provides |callback| to // fully control acceptance of Fetched objects. - bool JoiningFetch(const FullTrackName& name, - SubscribeRemoteTrack::Visitor* visitor, - FetchResponseCallback callback, - uint64_t num_previous_groups, MoqtPriority priority, - std::optional<MoqtDeliveryOrder> delivery_order, - VersionSpecificParameters parameters) override; + bool RelativeJoiningFetch(const FullTrackName& name, + SubscribeRemoteTrack::Visitor* visitor, + FetchResponseCallback callback, + uint64_t num_previous_groups, MoqtPriority priority, + std::optional<MoqtDeliveryOrder> delivery_order, + VersionSpecificParameters parameters) override; // Send a GOAWAY message to the peer. |new_session_uri| must be empty if // called by the client.
diff --git a/quiche/quic/moqt/moqt_session_interface.h b/quiche/quic/moqt/moqt_session_interface.h index 8e822f9..6846c96 100644 --- a/quiche/quic/moqt/moqt_session_interface.h +++ b/quiche/quic/moqt/moqt_session_interface.h
@@ -94,20 +94,21 @@ // in the FETCH will not be filled by with ObjectDoesNotExist. If the FETCH // fails for any reason, the application will not receive a notification; it // will just appear to be missing objects. - virtual bool JoiningFetch(const FullTrackName& name, - SubscribeRemoteTrack::Visitor* visitor, - uint64_t num_previous_groups, - VersionSpecificParameters parameters) = 0; + virtual bool RelativeJoiningFetch(const FullTrackName& name, + SubscribeRemoteTrack::Visitor* visitor, + uint64_t num_previous_groups, + VersionSpecificParameters parameters) = 0; // Sends both a SUBSCRIBE and a joining FETCH, beginning `num_previous_groups` // groups before the current group. `callback` acts the same way as the // callback for the regular Fetch() call. - virtual bool JoiningFetch(const FullTrackName& name, - SubscribeRemoteTrack::Visitor* visitor, - FetchResponseCallback callback, - uint64_t num_previous_groups, MoqtPriority priority, - std::optional<MoqtDeliveryOrder> delivery_order, - VersionSpecificParameters parameters) = 0; + virtual bool RelativeJoiningFetch( + const FullTrackName& name, SubscribeRemoteTrack::Visitor* visitor, + FetchResponseCallback callback, uint64_t num_previous_groups, + MoqtPriority priority, std::optional<MoqtDeliveryOrder> delivery_order, + VersionSpecificParameters parameters) = 0; + + // TODO(martinduke): Add an API for absolute joining fetch. // TODO: Add SubscribeAnnounces, UnsubscribeAnnounces method. // TODO: Add Announce, Unannounce method.
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 9b25531..8131aa2 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -79,11 +79,8 @@ /*fetch_id=*/1, /*subscriber_priority=*/0x80, /*group_order=*/std::nullopt, - /*joining_fetch=*/std::nullopt, - kDefaultTrackName(), - /*start=*/Location(0, 0), - /*end_group=*/1, - /*end_object=*/std::nullopt, + /*fetch=*/ + StandaloneFetch(kDefaultTrackName(), Location(0, 0), 1, std::nullopt), /*parameters=*/VersionSpecificParameters(), }; return fetch; @@ -2442,7 +2439,7 @@ // Joining FETCH arrives. The resulting Fetch should begin at (2, 0). MoqtFetch fetch = DefaultFetch(); fetch.fetch_id = 3; - fetch.joining_fetch = {1, 2}; + fetch.fetch = JoiningFetchRelative(1, 2); EXPECT_CALL(*track, Fetch(Location(2, 0), 4, std::optional<uint64_t>(10), _)) .WillOnce(Return(std::make_unique<MockFetchTask>())); stream_input->OnFetchMessage(fetch); @@ -2452,7 +2449,7 @@ std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MoqtFetch fetch = DefaultFetch(); - fetch.joining_fetch = {1, 2}; + fetch.fetch = JoiningFetchRelative(1, 2); MoqtFetchError expected_error = { /*request_id=*/1, /*error_code=*/RequestErrorCode::kTrackDoesNotExist, @@ -2473,7 +2470,7 @@ MoqtFetch fetch = DefaultFetch(); fetch.fetch_id = 3; - fetch.joining_fetch = {1, 2}; + fetch.fetch = JoiningFetchRelative(1, 2); EXPECT_CALL(mock_session_, CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation), "Joining Fetch for non-LatestObject subscribe")) @@ -2503,13 +2500,13 @@ /*fetch_id=*/2, /*subscriber_priority=*/0x80, /*group_order=*/MoqtDeliveryOrder::kAscending, - /*joining_fetch=*/JoiningFetch(0, 1), + /*fetch=*/JoiningFetchRelative(0, 1), }; EXPECT_CALL(mock_stream_, Writev(SerializedControlMessage(expected_subscribe), _)); EXPECT_CALL(mock_stream_, Writev(SerializedControlMessage(expected_fetch), _)); - EXPECT_TRUE(session_.JoiningFetch( + EXPECT_TRUE(session_.RelativeJoiningFetch( expected_subscribe.full_track_name, &remote_track_visitor, nullptr, 1, 0x80, MoqtDeliveryOrder::kAscending, VersionSpecificParameters())); } @@ -2524,9 +2521,9 @@ Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _)); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kFetch), _)); - EXPECT_TRUE(session_.JoiningFetch(FullTrackName("foo", "bar"), - &remote_track_visitor, 0, - VersionSpecificParameters())); + EXPECT_TRUE(session_.RelativeJoiningFetch(FullTrackName("foo", "bar"), + &remote_track_visitor, 0, + VersionSpecificParameters())); EXPECT_CALL(remote_track_visitor, OnReply).Times(1); stream_input->OnSubscribeOkMessage(
diff --git a/quiche/quic/moqt/moqt_track.h b/quiche/quic/moqt/moqt_track.h index 74761f0..6f5b96c 100644 --- a/quiche/quic/moqt/moqt_track.h +++ b/quiche/quic/moqt/moqt_track.h
@@ -232,13 +232,35 @@ // when a FETCH_OK or FETCH_ERROR is received. class UpstreamFetch : public RemoteTrack { public: - UpstreamFetch(const MoqtFetch& fetch, FetchResponseCallback callback) - : RemoteTrack(fetch.full_track_name, fetch.fetch_id, - fetch.joining_fetch.has_value() - ? SubscribeWindow(Location(0, 0)) - : SubscribeWindow(fetch.start_object, fetch.end_group, - fetch.end_object), - fetch.subscriber_priority), + // Standalone Fetch constructor + UpstreamFetch(const MoqtFetch& fetch, const StandaloneFetch standalone, + FetchResponseCallback callback) + : RemoteTrack( + standalone.full_track_name, fetch.fetch_id, + SubscribeWindow(standalone.start_object, standalone.end_group, + standalone.end_object), + fetch.subscriber_priority), + ok_callback_(std::move(callback)) { + // Immediately set the data stream type. + CheckDataStreamType(MoqtDataStreamType::kStreamHeaderFetch); + } + // Relative Joining Fetch constructor + UpstreamFetch(const MoqtFetch& fetch, FullTrackName full_track_name, + FetchResponseCallback callback) + : RemoteTrack(full_track_name, fetch.fetch_id, + SubscribeWindow(Location(0, 0)), fetch.subscriber_priority), + ok_callback_(std::move(callback)) { + // Immediately set the data stream type. + CheckDataStreamType(MoqtDataStreamType::kStreamHeaderFetch); + } + // Absolute Joining Fetch constructor + UpstreamFetch(const MoqtFetch& fetch, FullTrackName full_track_name, + JoiningFetchAbsolute absolute_joining, + FetchResponseCallback callback) + : RemoteTrack( + full_track_name, fetch.fetch_id, + SubscribeWindow(Location(absolute_joining.joining_start, 0)), + fetch.subscriber_priority), ok_callback_(std::move(callback)) { // Immediately set the data stream type. CheckDataStreamType(MoqtDataStreamType::kStreamHeaderFetch);
diff --git a/quiche/quic/moqt/moqt_track_test.cc b/quiche/quic/moqt/moqt_track_test.cc index 5090f3d..44a2f1a 100644 --- a/quiche/quic/moqt/moqt_track_test.cc +++ b/quiche/quic/moqt/moqt_track_test.cc
@@ -96,19 +96,17 @@ class UpstreamFetchTest : public quic::test::QuicTest { protected: UpstreamFetchTest() - : fetch_(fetch_message_, [&](std::unique_ptr<MoqtFetchTask> task) { - fetch_task_ = std::move(task); - }) {} + : fetch_(fetch_message_, std::get<StandaloneFetch>(fetch_message_.fetch), + [&](std::unique_ptr<MoqtFetchTask> task) { + fetch_task_ = std::move(task); + }) {} MoqtFetch fetch_message_ = { /*request_id=*/1, /*subscriber_priority=*/128, /*group_order=*/std::nullopt, - /*joining_fetch=*/std::nullopt, - /*full_track_name=*/FullTrackName("foo", "bar"), - /*start_object=*/Location(1, 1), - /*end_group=*/3, - /*end_object=*/100, + /*fetch=*/ + StandaloneFetch(FullTrackName("foo", "bar"), Location(1, 1), 3, 100), VersionSpecificParameters(), }; // The pointer held by the application.
diff --git a/quiche/quic/moqt/test_tools/moqt_session_peer.h b/quiche/quic/moqt/test_tools/moqt_session_peer.h index 634fe12..467c2fb 100644 --- a/quiche/quic/moqt/test_tools/moqt_session_peer.h +++ b/quiche/quic/moqt/test_tools/moqt_session_peer.h
@@ -183,17 +183,15 @@ 0, 128, std::nullopt, - std::nullopt, - FullTrackName{"foo", "bar"}, - Location{0, 0}, - 4, - std::nullopt, + StandaloneFetch(FullTrackName{"foo", "bar"}, Location{0, 0}, 4, + std::nullopt), VersionSpecificParameters(), }; std::unique_ptr<MoqtFetchTask> task; auto [it, success] = session->upstream_by_id_.try_emplace( 0, std::make_unique<UpstreamFetch>( - fetch_message, [&](std::unique_ptr<MoqtFetchTask> fetch_task) { + fetch_message, std::get<StandaloneFetch>(fetch_message.fetch), + [&](std::unique_ptr<MoqtFetchTask> fetch_task) { task = std::move(fetch_task); })); QUICHE_DCHECK(success);
diff --git a/quiche/quic/moqt/test_tools/moqt_test_message.h b/quiche/quic/moqt/test_tools/moqt_test_message.h index de211c6..1ad79b2 100644 --- a/quiche/quic/moqt/test_tools/moqt_test_message.h +++ b/quiche/quic/moqt/test_tools/moqt_test_message.h
@@ -1321,39 +1321,10 @@ QUIC_LOG(INFO) << "FETCH group_order mismatch"; return false; } - if (cast.joining_fetch.has_value() != fetch_.joining_fetch.has_value()) { - QUIC_LOG(INFO) << "FETCH type mismatch"; + if (cast.fetch != fetch_.fetch) { + QUIC_LOG(INFO) << "FETCH mismatch"; return false; } - if (cast.joining_fetch.has_value()) { - if (cast.joining_fetch->joining_subscribe_id != - fetch_.joining_fetch->joining_subscribe_id) { - QUIC_LOG(INFO) << "FETCH joining_subscribe_id mismatch"; - return false; - } - if (cast.joining_fetch->preceding_group_offset != - fetch_.joining_fetch->preceding_group_offset) { - QUIC_LOG(INFO) << "FETCH preceding_group_offset mismatch"; - return false; - } - } else { - if (cast.full_track_name != fetch_.full_track_name) { - QUIC_LOG(INFO) << "FETCH full_track_name mismatch"; - return false; - } - if (cast.start_object != fetch_.start_object) { - QUIC_LOG(INFO) << "FETCH start_object mismatch"; - return false; - } - if (cast.end_group != fetch_.end_group) { - QUIC_LOG(INFO) << "FETCH end_group mismatch"; - return false; - } - if (cast.end_object != fetch_.end_object) { - QUIC_LOG(INFO) << "FETCH end_object mismatch"; - return false; - } - } if (cast.parameters != fetch_.parameters) { QUIC_LOG(INFO) << "FETCH parameters mismatch"; return false; @@ -1373,8 +1344,8 @@ // Avoid varint nonsense. QUICHE_CHECK(group < 64); QUICHE_CHECK(!object.has_value() || *object < 64); - fetch_.end_group = group; - fetch_.end_object = object; + std::get<StandaloneFetch>(fetch_.fetch).end_group = group; + std::get<StandaloneFetch>(fetch_.fetch).end_object = object; raw_packet_[18] = group; raw_packet_[19] = object.has_value() ? (*object + 1) : 0; SetWireImage(raw_packet_, sizeof(raw_packet_)); @@ -1403,20 +1374,22 @@ /*fetch_id =*/1, /*subscriber_priority=*/2, /*group_order=*/MoqtDeliveryOrder::kAscending, - /*joining_fetch=*/std::optional<JoiningFetch>(), - FullTrackName("foo", "bar"), - /*start_object=*/Location{1, 2}, - /*end_group=*/5, - /*end_object=*/6, + /*fetch =*/ + StandaloneFetch{ + FullTrackName("foo", "bar"), + /*start_object=*/Location{1, 2}, + /*end_group=*/5, + /*end_object=*/6, + }, VersionSpecificParameters(AuthTokenType::kOutOfBand, "baz"), }; }; // This is not used in the parameterized Parser and Framer tests, because it // does not have its own MoqtMessageType. -class QUICHE_NO_EXPORT JoiningFetchMessage : public TestMessageBase { +class QUICHE_NO_EXPORT RelativeJoiningFetchMessage : public TestMessageBase { public: - JoiningFetchMessage() : TestMessageBase() { + RelativeJoiningFetchMessage() : TestMessageBase() { SetWireImage(raw_packet_, sizeof(raw_packet_)); } bool EqualFieldValues(MessageStructuredData& values) const override { @@ -1433,39 +1406,10 @@ QUIC_LOG(INFO) << "FETCH group_order mismatch"; return false; } - if (cast.joining_fetch.has_value() != fetch_.joining_fetch.has_value()) { - QUIC_LOG(INFO) << "FETCH type mismatch"; + if (cast.fetch != fetch_.fetch) { + QUIC_LOG(INFO) << "FETCH mismatch"; return false; } - if (cast.joining_fetch.has_value()) { - if (cast.joining_fetch->joining_subscribe_id != - fetch_.joining_fetch->joining_subscribe_id) { - QUIC_LOG(INFO) << "FETCH joining_subscribe_id mismatch"; - return false; - } - if (cast.joining_fetch->preceding_group_offset != - fetch_.joining_fetch->preceding_group_offset) { - QUIC_LOG(INFO) << "FETCH preceding_group_offset mismatch"; - return false; - } - } else { - if (cast.full_track_name != fetch_.full_track_name) { - QUIC_LOG(INFO) << "FETCH full_track_name mismatch"; - return false; - } - if (cast.start_object != fetch_.start_object) { - QUIC_LOG(INFO) << "FETCH start_object mismatch"; - return false; - } - if (cast.end_group != fetch_.end_group) { - QUIC_LOG(INFO) << "FETCH end_group mismatch"; - return false; - } - if (cast.end_object != fetch_.end_object) { - QUIC_LOG(INFO) << "FETCH end_object mismatch"; - return false; - } - } if (cast.parameters != fetch_.parameters) { QUIC_LOG(INFO) << "FETCH parameters mismatch"; return false; @@ -1492,7 +1436,7 @@ 0x01, // fetch_id = 1 0x02, // priority = kHigh 0x01, // group_order = kAscending - 0x02, // type = kJoining + 0x02, // type = kRelativeJoining 0x02, 0x02, // joining_subscribe_id = 2, 2 groups 0x01, 0x01, 0x05, 0x03, 0x00, 0x62, 0x61, 0x7a, // parameters = "baz" }; @@ -1501,12 +1445,72 @@ /*fetch_id =*/1, /*subscriber_priority=*/2, /*group_order=*/MoqtDeliveryOrder::kAscending, - /*joining_fetch=*/JoiningFetch{2, 2}, - /* the next four are ignored for joining fetches*/ - FullTrackName("foo", "bar"), - /*start_object=*/Location{1, 2}, - /*end_group=*/5, - /*end_object=*/6, + /*fetch=*/JoiningFetchRelative{2, 2}, + VersionSpecificParameters(AuthTokenType::kOutOfBand, "baz"), + }; +}; + +// This is not used in the parameterized Parser and Framer tests, because it +// does not have its own MoqtMessageType. +class QUICHE_NO_EXPORT AbsoluteJoiningFetchMessage : public TestMessageBase { + public: + AbsoluteJoiningFetchMessage() : TestMessageBase() { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtFetch>(values); + if (cast.fetch_id != fetch_.fetch_id) { + QUIC_LOG(INFO) << "FETCH fetch_id mismatch"; + return false; + } + if (cast.subscriber_priority != fetch_.subscriber_priority) { + QUIC_LOG(INFO) << "FETCH subscriber_priority mismatch"; + return false; + } + if (cast.group_order != fetch_.group_order) { + QUIC_LOG(INFO) << "FETCH group_order mismatch"; + return false; + } + if (cast.fetch != fetch_.fetch) { + QUIC_LOG(INFO) << "FETCH mismatch"; + return false; + } + if (cast.parameters != fetch_.parameters) { + QUIC_LOG(INFO) << "FETCH parameters mismatch"; + return false; + } + return true; + } + + void ExpandVarints() override { + ExpandVarintsImpl("v--vvv---v---vvvvvv-----"); + } + + MessageStructuredData structured_data() const override { + return TestMessageBase::MessageStructuredData(fetch_); + } + + void SetGroupOrder(uint8_t group_order) { + raw_packet_[5] = static_cast<uint8_t>(group_order); + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + private: + uint8_t raw_packet_[17] = { + 0x16, 0x00, 0x0e, + 0x01, // fetch_id = 1 + 0x02, // priority = kHigh + 0x01, // group_order = kAscending + 0x03, // type = kAbsoluteJoining + 0x02, 0x02, // joining_subscribe_id = 2, group_id = 2 + 0x01, 0x01, 0x05, 0x03, 0x00, 0x62, 0x61, 0x7a, // parameters = "baz" + }; + + MoqtFetch fetch_ = { + /*fetch_id =*/1, + /*subscriber_priority=*/2, + /*group_order=*/MoqtDeliveryOrder::kAscending, + /*fetch=*/JoiningFetchAbsolute{2, 2}, VersionSpecificParameters(AuthTokenType::kOutOfBand, "baz"), }; };
diff --git a/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc b/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc index 7e70c29..7bf3b31 100644 --- a/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc +++ b/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc
@@ -155,8 +155,8 @@ absl::StrSplit(track_list, ',', absl::AllowEmpty()); for (absl::string_view track : tracks_to_subscribe) { FullTrackName full_track_name(track_namespace, track); - session_->JoiningFetch(full_track_name, &it->second, 0, - VersionSpecificParameters()); + session_->RelativeJoiningFetch(full_track_name, &it->second, 0, + VersionSpecificParameters()); } return std::nullopt;
diff --git a/quiche/quic/moqt/tools/moqt_simulator_bin.cc b/quiche/quic/moqt/tools/moqt_simulator_bin.cc index 3bf12ee..5e51814 100644 --- a/quiche/quic/moqt/tools/moqt_simulator_bin.cc +++ b/quiche/quic/moqt/tools/moqt_simulator_bin.cc
@@ -447,8 +447,8 @@ if (!parameters_.delivery_timeout.IsInfinite()) { subscription_parameters.delivery_timeout = parameters_.delivery_timeout; } - server_session()->JoiningFetch(TrackName(), &receiver_, 0, - subscription_parameters); + server_session()->RelativeJoiningFetch(TrackName(), &receiver_, 0, + subscription_parameters); simulator_.RunFor(parameters_.duration); // At the end, we wait for eight RTTs until the connection settles down.