Update MoQT to draft-04 - roll the version number - new frames: SUBSCRIBE_UPDATE, ANNOUNCE_CANCEL, TRACK_STATUS_REQUEST, TRACK_STATUS - new SUBSCRIBE format - Session sends SUBSCRIBE_DONE after UNSUBSCRIBE or some SUBSCRIBE_UPDATEs. Not in production. PiperOrigin-RevId: 641898167
diff --git a/quiche/quic/moqt/moqt_framer.cc b/quiche/quic/moqt/moqt_framer.cc index 05a0440..4bb3a04 100644 --- a/quiche/quic/moqt/moqt_framer.cc +++ b/quiche/quic/moqt/moqt_framer.cc
@@ -37,42 +37,6 @@ using ::quiche::WireUint8; using ::quiche::WireVarInt62; -// Encoding for MOQT Locations: -// https://moq-wg.github.io/moq-transport/draft-ietf-moq-transport.html#name-subscribe-locations -class WireLocation { - public: - using DataType = std::optional<MoqtSubscribeLocation>; - explicit WireLocation(const DataType& location) : location_(location) {} - - size_t GetLengthOnWire() { - return quiche::ComputeLengthOnWire( - WireVarInt62(GetModeForSubscribeLocation(location_)), - WireOptional<WireVarInt62>(LocationOffsetOnTheWire(location_))); - } - absl::Status SerializeIntoWriter(quiche::QuicheDataWriter& writer) { - return quiche::SerializeIntoWriter( - writer, WireVarInt62(GetModeForSubscribeLocation(location_)), - WireOptional<WireVarInt62>(LocationOffsetOnTheWire(location_))); - } - - private: - // For all location types other than None, we record a single varint after the - // type; this function computes the value of that varint. - static std::optional<uint64_t> LocationOffsetOnTheWire( - std::optional<MoqtSubscribeLocation> location) { - if (!location.has_value()) { - return std::nullopt; - } - if (location->absolute) { - return location->absolute_value; - } - return location->relative_value <= 0 ? -location->relative_value - : location->relative_value + 1; - } - - const DataType& location_; -}; - // Encoding for string parameters as described in // https://moq-wg.github.io/moq-transport/draft-ietf-moq-transport.html#name-parameters struct StringParameter { @@ -260,14 +224,9 @@ quiche::QuicheBuffer MoqtFramer::SerializeSubscribe( const MoqtSubscribe& message) { - if (!message.start_group.has_value() || !message.start_object.has_value()) { - QUICHE_BUG(MoqtFramer_start_group_missing) - << "start_group or start_object is missing"; - return quiche::QuicheBuffer(); - } - if (message.end_group.has_value() != message.end_object.has_value()) { - QUICHE_BUG(MoqtFramer_end_mismatch) - << "end_group and end_object must both be None or both non-None"; + MoqtFilterType filter_type = GetFilterType(message); + if (filter_type == MoqtFilterType::kNone) { + QUICHE_BUG(MoqtFramer_invalid_subscribe) << "Invalid object range"; return quiche::QuicheBuffer(); } absl::InlinedVector<StringParameter, 1> string_params; @@ -276,15 +235,42 @@ StringParameter(MoqtTrackRequestParameter::kAuthorizationInfo, *message.authorization_info)); } - return Serialize( - WireVarInt62(MoqtMessageType::kSubscribe), - WireVarInt62(message.subscribe_id), WireVarInt62(message.track_alias), - WireStringWithVarInt62Length(message.track_namespace), - WireStringWithVarInt62Length(message.track_name), - WireLocation(message.start_group), WireLocation(message.start_object), - WireLocation(message.end_group), WireLocation(message.end_object), - WireVarInt62(string_params.size()), - WireSpan<WireStringParameter>(string_params)); + switch (filter_type) { + case MoqtFilterType::kLatestGroup: + case MoqtFilterType::kLatestObject: + return Serialize( + WireVarInt62(MoqtMessageType::kSubscribe), + WireVarInt62(message.subscribe_id), WireVarInt62(message.track_alias), + WireStringWithVarInt62Length(message.track_namespace), + WireStringWithVarInt62Length(message.track_name), + WireVarInt62(filter_type), WireVarInt62(string_params.size()), + WireSpan<WireStringParameter>(string_params)); + case MoqtFilterType::kAbsoluteStart: + return Serialize( + WireVarInt62(MoqtMessageType::kSubscribe), + WireVarInt62(message.subscribe_id), WireVarInt62(message.track_alias), + WireStringWithVarInt62Length(message.track_namespace), + WireStringWithVarInt62Length(message.track_name), + WireVarInt62(filter_type), WireVarInt62(*message.start_group), + WireVarInt62(*message.start_object), + WireVarInt62(string_params.size()), + WireSpan<WireStringParameter>(string_params)); + case MoqtFilterType::kAbsoluteRange: + return Serialize( + WireVarInt62(MoqtMessageType::kSubscribe), + WireVarInt62(message.subscribe_id), WireVarInt62(message.track_alias), + WireStringWithVarInt62Length(message.track_namespace), + WireStringWithVarInt62Length(message.track_name), + WireVarInt62(filter_type), WireVarInt62(*message.start_group), + WireVarInt62(*message.start_object), WireVarInt62(*message.end_group), + WireVarInt62(message.end_object.has_value() ? *message.end_object + 1 + : 0), + WireVarInt62(string_params.size()), + WireSpan<WireStringParameter>(string_params)); + default: + QUICHE_BUG(MoqtFramer_end_group_missing) << "Subscribe framing error."; + return quiche::QuicheBuffer(); + } } quiche::QuicheBuffer MoqtFramer::SerializeSubscribeOk( @@ -333,6 +319,29 @@ WireStringWithVarInt62Length(message.reason_phrase), WireUint8(0)); } +quiche::QuicheBuffer MoqtFramer::SerializeSubscribeUpdate( + const MoqtSubscribeUpdate& message) { + uint64_t end_group = + message.end_group.has_value() ? *message.end_group + 1 : 0; + uint64_t end_object = + message.end_object.has_value() ? *message.end_object + 1 : 0; + if (end_group == 0 && end_object != 0) { + QUICHE_BUG(MoqtFramer_invalid_subscribe_update) << "Invalid object range"; + return quiche::QuicheBuffer(); + } + absl::InlinedVector<StringParameter, 1> string_params; + if (message.authorization_info.has_value()) { + string_params.push_back( + StringParameter(MoqtTrackRequestParameter::kAuthorizationInfo, + *message.authorization_info)); + } + return Serialize( + WireVarInt62(MoqtMessageType::kSubscribeUpdate), + WireVarInt62(message.subscribe_id), WireVarInt62(message.start_group), + WireVarInt62(message.start_object), WireVarInt62(end_group), + WireVarInt62(end_object), WireSpan<WireStringParameter>(string_params)); +} + quiche::QuicheBuffer MoqtFramer::SerializeAnnounce( const MoqtAnnounce& message) { absl::InlinedVector<StringParameter, 1> string_params; @@ -362,12 +371,35 @@ WireStringWithVarInt62Length(message.reason_phrase)); } +quiche::QuicheBuffer MoqtFramer::SerializeAnnounceCancel( + const MoqtAnnounceCancel& message) { + return Serialize(WireVarInt62(MoqtMessageType::kAnnounceCancel), + WireStringWithVarInt62Length(message.track_namespace)); +} + +quiche::QuicheBuffer MoqtFramer::SerializeTrackStatusRequest( + const MoqtTrackStatusRequest& message) { + return Serialize(WireVarInt62(MoqtMessageType::kTrackStatusRequest), + WireStringWithVarInt62Length(message.track_namespace), + WireStringWithVarInt62Length(message.track_name)); +} + quiche::QuicheBuffer MoqtFramer::SerializeUnannounce( const MoqtUnannounce& message) { return Serialize(WireVarInt62(MoqtMessageType::kUnannounce), WireStringWithVarInt62Length(message.track_namespace)); } +quiche::QuicheBuffer MoqtFramer::SerializeTrackStatus( + const MoqtTrackStatus& message) { + return Serialize(WireVarInt62(MoqtMessageType::kTrackStatus), + WireStringWithVarInt62Length(message.track_namespace), + WireStringWithVarInt62Length(message.track_name), + WireVarInt62(message.status_code), + WireVarInt62(message.last_group), + WireVarInt62(message.last_object)); +} + quiche::QuicheBuffer MoqtFramer::SerializeGoAway(const MoqtGoAway& message) { return Serialize(WireVarInt62(MoqtMessageType::kGoAway), WireStringWithVarInt62Length(message.new_session_uri));
diff --git a/quiche/quic/moqt/moqt_framer.h b/quiche/quic/moqt/moqt_framer.h index 7f19fa8..6aef8b6 100644 --- a/quiche/quic/moqt/moqt_framer.h +++ b/quiche/quic/moqt/moqt_framer.h
@@ -5,11 +5,7 @@ #ifndef QUICHE_QUIC_MOQT_MOQT_FRAMER_H_ #define QUICHE_QUIC_MOQT_MOQT_FRAMER_H_ -#include <cstddef> -#include <optional> - #include "absl/strings/string_view.h" -#include "quiche/quic/core/quic_types.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/common/platform/api/quiche_export.h" #include "quiche/common/quiche_buffer_allocator.h" @@ -47,10 +43,17 @@ const MoqtSubscribeError& message); quiche::QuicheBuffer SerializeUnsubscribe(const MoqtUnsubscribe& message); quiche::QuicheBuffer SerializeSubscribeDone(const MoqtSubscribeDone& message); + quiche::QuicheBuffer SerializeSubscribeUpdate( + const MoqtSubscribeUpdate& message); quiche::QuicheBuffer SerializeAnnounce(const MoqtAnnounce& message); quiche::QuicheBuffer SerializeAnnounceOk(const MoqtAnnounceOk& message); quiche::QuicheBuffer SerializeAnnounceError(const MoqtAnnounceError& message); + quiche::QuicheBuffer SerializeAnnounceCancel( + const MoqtAnnounceCancel& message); + quiche::QuicheBuffer SerializeTrackStatusRequest( + const MoqtTrackStatusRequest& message); quiche::QuicheBuffer SerializeUnannounce(const MoqtUnannounce& message); + quiche::QuicheBuffer SerializeTrackStatus(const MoqtTrackStatus& message); quiche::QuicheBuffer SerializeGoAway(const MoqtGoAway& message); private:
diff --git a/quiche/quic/moqt/moqt_framer_test.cc b/quiche/quic/moqt/moqt_framer_test.cc index 607e1db..7907d30 100644 --- a/quiche/quic/moqt/moqt_framer_test.cc +++ b/quiche/quic/moqt/moqt_framer_test.cc
@@ -4,6 +4,9 @@ #include "quiche/quic/moqt/moqt_framer.h" +#include <cerrno> +#include <cstddef> +#include <cstdint> #include <memory> #include <optional> #include <string> @@ -30,21 +33,15 @@ std::vector<MoqtFramerTestParams> GetMoqtFramerTestParams() { std::vector<MoqtFramerTestParams> params; std::vector<MoqtMessageType> message_types = { - MoqtMessageType::kObjectStream, - MoqtMessageType::kSubscribe, - MoqtMessageType::kSubscribeOk, - MoqtMessageType::kSubscribeError, - MoqtMessageType::kUnsubscribe, - MoqtMessageType::kSubscribeDone, - MoqtMessageType::kAnnounce, - MoqtMessageType::kAnnounceOk, - MoqtMessageType::kAnnounceError, - MoqtMessageType::kUnannounce, - MoqtMessageType::kGoAway, - MoqtMessageType::kClientSetup, - MoqtMessageType::kServerSetup, - MoqtMessageType::kStreamHeaderTrack, - MoqtMessageType::kStreamHeaderGroup, + MoqtMessageType::kObjectStream, MoqtMessageType::kSubscribe, + MoqtMessageType::kSubscribeOk, MoqtMessageType::kSubscribeError, + MoqtMessageType::kUnsubscribe, MoqtMessageType::kSubscribeDone, + MoqtMessageType::kAnnounceCancel, MoqtMessageType::kTrackStatusRequest, + MoqtMessageType::kTrackStatus, MoqtMessageType::kAnnounce, + MoqtMessageType::kAnnounceOk, MoqtMessageType::kAnnounceError, + MoqtMessageType::kUnannounce, MoqtMessageType::kGoAway, + MoqtMessageType::kClientSetup, MoqtMessageType::kServerSetup, + MoqtMessageType::kStreamHeaderTrack, MoqtMessageType::kStreamHeaderGroup, }; std::vector<bool> uses_web_transport_bool = { false, @@ -141,10 +138,22 @@ auto data = std::get<MoqtAnnounceError>(structured_data); return framer_.SerializeAnnounceError(data); } + case moqt::MoqtMessageType::kAnnounceCancel: { + auto data = std::get<MoqtAnnounceCancel>(structured_data); + return framer_.SerializeAnnounceCancel(data); + } + case moqt::MoqtMessageType::kTrackStatusRequest: { + auto data = std::get<MoqtTrackStatusRequest>(structured_data); + return framer_.SerializeTrackStatusRequest(data); + } case MoqtMessageType::kUnannounce: { auto data = std::get<MoqtUnannounce>(structured_data); return framer_.SerializeUnannounce(data); } + case moqt::MoqtMessageType::kTrackStatus: { + auto data = std::get<MoqtTrackStatus>(structured_data); + return framer_.SerializeTrackStatus(data); + } case moqt::MoqtMessageType::kGoAway: { auto data = std::get<MoqtGoAway>(structured_data); return framer_.SerializeGoAway(data); @@ -189,6 +198,12 @@ quiche::SimpleBufferAllocator* buffer_allocator_; MoqtFramer framer_; + + // Obtain a pointer to an arbitrary offset in a serialized buffer. + const uint8_t* BufferAtOffset(quiche::QuicheBuffer& buffer, size_t offset) { + const char* data = buffer.data(); + return reinterpret_cast<const uint8_t*>(data + offset); + } }; TEST_F(MoqtFramerSimpleTest, GroupMiddler) { @@ -258,4 +273,154 @@ EXPECT_EQ(buffer.AsStringView(), datagram->PacketSample()); } +TEST_F(MoqtFramerSimpleTest, AllSubscribeInputs) { + for (std::optional<uint64_t> start_group : + {std::optional<uint64_t>(), std::optional<uint64_t>(4)}) { + for (std::optional<uint64_t> start_object : + {std::optional<uint64_t>(), std::optional<uint64_t>(0)}) { + for (std::optional<uint64_t> end_group : + {std::optional<uint64_t>(), std::optional<uint64_t>(7)}) { + for (std::optional<uint64_t> end_object : + {std::optional<uint64_t>(), std::optional<uint64_t>(3)}) { + MoqtSubscribe subscribe = { + /*subscribe_id=*/3, + /*track_alias=*/4, + /*track_namespace=*/"foo", + /*track_name=*/"abcd", + start_group, + start_object, + end_group, + end_object, + /*authorization_info=*/"bar", + }; + quiche::QuicheBuffer buffer; + MoqtFilterType expected_filter_type = MoqtFilterType::kNone; + if (!start_group.has_value() && !start_object.has_value() && + !end_group.has_value() && !end_object.has_value()) { + expected_filter_type = MoqtFilterType::kLatestObject; + } else if (!start_group.has_value() && start_object.has_value() && + *start_object == 0 && !end_group.has_value() && + !end_object.has_value()) { + expected_filter_type = MoqtFilterType::kLatestGroup; + } else if (start_group.has_value() && start_object.has_value() && + !end_group.has_value() && !end_object.has_value()) { + expected_filter_type = MoqtFilterType::kAbsoluteStart; + } else if (start_group.has_value() && start_object.has_value() && + end_group.has_value()) { + expected_filter_type = MoqtFilterType::kAbsoluteRange; + } + if (expected_filter_type == MoqtFilterType::kNone) { + EXPECT_QUIC_BUG(buffer = framer_.SerializeSubscribe(subscribe), + "Invalid object range"); + EXPECT_EQ(buffer.size(), 0); + continue; + } + buffer = framer_.SerializeSubscribe(subscribe); + // Go to the filter type. + const uint8_t* read = BufferAtOffset(buffer, 12); + EXPECT_EQ(static_cast<MoqtFilterType>(*read), expected_filter_type); + EXPECT_GT(buffer.size(), 0); + if (expected_filter_type == MoqtFilterType::kAbsoluteRange && + end_object.has_value()) { + const uint8_t* object_id = read + 4; + EXPECT_EQ(*object_id, *end_object + 1); + } + } + } + } + } +} + +TEST_F(MoqtFramerSimpleTest, SubscribeEndBeforeStart) { + MoqtSubscribe subscribe = { + /*subscribe_id=*/3, + /*track_alias=*/4, + /*track_namespace=*/"foo", + /*track_name=*/"abcd", + /*start_group=*/std::optional<uint64_t>(4), + /*start_object=*/std::optional<uint64_t>(3), + /*end_group=*/std::optional<uint64_t>(3), + /*end_object=*/std::nullopt, + /*authorization_info=*/"bar", + }; + quiche::QuicheBuffer buffer; + EXPECT_QUIC_BUG(buffer = framer_.SerializeSubscribe(subscribe), + "Invalid object range"); + EXPECT_EQ(buffer.size(), 0); + subscribe.end_group = 4; + subscribe.end_object = 1; + EXPECT_QUIC_BUG(buffer = framer_.SerializeSubscribe(subscribe), + "Invalid object range"); + EXPECT_EQ(buffer.size(), 0); +} + +TEST_F(MoqtFramerSimpleTest, SubscribeLatestGroupNonzeroObject) { + MoqtSubscribe subscribe = { + /*subscribe_id=*/3, + /*track_alias=*/4, + /*track_namespace=*/"foo", + /*track_name=*/"abcd", + /*start_group=*/std::nullopt, + /*start_object=*/std::optional<uint64_t>(3), + /*end_group=*/std::nullopt, + /*end_object=*/std::nullopt, + /*authorization_info=*/"bar", + }; + quiche::QuicheBuffer buffer; + EXPECT_QUIC_BUG(buffer = framer_.SerializeSubscribe(subscribe), + "Invalid object range"); + EXPECT_EQ(buffer.size(), 0); +} + +TEST_F(MoqtFramerSimpleTest, SubscribeUpdateEndGroupOnly) { + MoqtSubscribeUpdate subscribe_update = { + /*subscribe_id=*/3, + /*start_group=*/4, + /*start_object=*/3, + /*end_group=*/4, + /*end_object=*/std::nullopt, + /*authorization_info=*/"bar", + }; + quiche::QuicheBuffer buffer; + buffer = framer_.SerializeSubscribeUpdate(subscribe_update); + EXPECT_GT(buffer.size(), 0); + const uint8_t* end_group = BufferAtOffset(buffer, 4); + EXPECT_EQ(*end_group, 5); + const uint8_t* end_object = end_group + 1; + EXPECT_EQ(*end_object, 0); +} + +TEST_F(MoqtFramerSimpleTest, SubscribeUpdateIncrementsEnd) { + MoqtSubscribeUpdate subscribe_update = { + /*subscribe_id=*/3, + /*start_group=*/4, + /*start_object=*/3, + /*end_group=*/4, + /*end_object=*/6, + /*authorization_info=*/"bar", + }; + quiche::QuicheBuffer buffer; + buffer = framer_.SerializeSubscribeUpdate(subscribe_update); + EXPECT_GT(buffer.size(), 0); + const uint8_t* end_group = BufferAtOffset(buffer, 4); + EXPECT_EQ(*end_group, 5); + const uint8_t* end_object = end_group + 1; + EXPECT_EQ(*end_object, 7); +} + +TEST_F(MoqtFramerSimpleTest, SubscribeUpdateInvalidRange) { + MoqtSubscribeUpdate subscribe_update = { + /*subscribe_id=*/3, + /*start_group=*/4, + /*start_object=*/3, + /*end_group=*/std::nullopt, + /*end_object=*/6, + /*authorization_info=*/"bar", + }; + quiche::QuicheBuffer buffer; + EXPECT_QUIC_BUG(buffer = framer_.SerializeSubscribeUpdate(subscribe_update), + "Invalid object range"); + EXPECT_EQ(buffer.size(), 0); +} + } // namespace moqt::test
diff --git a/quiche/quic/moqt/moqt_integration_test.cc b/quiche/quic/moqt/moqt_integration_test.cc index 56286fc..3f71bb4 100644 --- a/quiche/quic/moqt/moqt_integration_test.cc +++ b/quiche/quic/moqt/moqt_integration_test.cc
@@ -18,7 +18,6 @@ #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_outgoing_queue.h" #include "quiche/quic/moqt/moqt_session.h" -#include "quiche/quic/moqt/moqt_track.h" #include "quiche/quic/moqt/tools/moqt_mock_visitor.h" #include "quiche/quic/test_tools/crypto_test_utils.h" #include "quiche/quic/test_tools/quic_test_utils.h" @@ -123,9 +122,9 @@ public: void CreateDefaultEndpoints() { client_ = std::make_unique<ClientEndpoint>( - &test_harness_.simulator(), "Client", "Server", MoqtVersion::kDraft03); + &test_harness_.simulator(), "Client", "Server", MoqtVersion::kDraft04); server_ = std::make_unique<ServerEndpoint>( - &test_harness_.simulator(), "Server", "Client", MoqtVersion::kDraft03); + &test_harness_.simulator(), "Server", "Client", MoqtVersion::kDraft04); test_harness_.set_client(client_.get()); test_harness_.set_server(server_.get()); } @@ -176,7 +175,7 @@ &test_harness_.simulator(), "Client", "Server", MoqtVersion::kUnrecognizedVersionForTests); server_ = std::make_unique<ServerEndpoint>( - &test_harness_.simulator(), "Server", "Client", MoqtVersion::kDraft03); + &test_harness_.simulator(), "Server", "Client", MoqtVersion::kDraft04); test_harness_.set_client(client_.get()); test_harness_.set_server(server_.get()); WireUpEndpoints(); @@ -327,7 +326,7 @@ EXPECT_TRUE(success); } -TEST_F(MoqtIntegrationTest, SubscribeRelativeOk) { +TEST_F(MoqtIntegrationTest, SubscribeCurrentObjectOk) { EstablishSession(); FullTrackName full_track_name("foo", "bar"); MockLocalTrackVisitor server_visitor; @@ -338,9 +337,9 @@ bool received_ok = false; EXPECT_CALL(client_visitor, OnReply(full_track_name, expected_reason)) .WillOnce([&]() { received_ok = true; }); - client_->session()->SubscribeRelative(full_track_name.track_namespace, - full_track_name.track_name, 10, 10, - &client_visitor); + client_->session()->SubscribeCurrentObject(full_track_name.track_namespace, + full_track_name.track_name, + &client_visitor); bool success = test_harness_.RunUntilWithDefaultTimeout([&]() { return received_ok; }); EXPECT_TRUE(success); @@ -373,9 +372,9 @@ bool received_ok = false; EXPECT_CALL(client_visitor, OnReply(full_track_name, expected_reason)) .WillOnce([&]() { received_ok = true; }); - client_->session()->SubscribeRelative(full_track_name.track_namespace, - full_track_name.track_name, 10, 10, - &client_visitor); + client_->session()->SubscribeCurrentObject(full_track_name.track_namespace, + full_track_name.track_name, + &client_visitor); bool success = test_harness_.RunUntilWithDefaultTimeout([&]() { return received_ok; }); EXPECT_TRUE(success);
diff --git a/quiche/quic/moqt/moqt_messages.cc b/quiche/quic/moqt/moqt_messages.cc index fba312b..18a73ad 100644 --- a/quiche/quic/moqt/moqt_messages.cc +++ b/quiche/quic/moqt/moqt_messages.cc
@@ -10,6 +10,42 @@ namespace moqt { +MoqtFilterType GetFilterType(const MoqtSubscribe& message) { + if (!message.end_group.has_value() && message.end_object.has_value()) { + return MoqtFilterType::kNone; + } + bool has_start = + message.start_group.has_value() && message.start_object.has_value(); + if (message.end_group.has_value()) { + if (has_start) { + if (*message.end_group < *message.start_group) { + return MoqtFilterType::kNone; + } else if (*message.end_group == *message.start_group && + *message.end_object <= *message.start_object) { + if (*message.end_object < *message.start_object) { + return MoqtFilterType::kNone; + } else if (*message.end_object == *message.start_object) { + return MoqtFilterType::kAbsoluteStart; + } + } + return MoqtFilterType::kAbsoluteRange; + } + } else { + if (has_start) { + return MoqtFilterType::kAbsoluteStart; + } else if (!message.start_group.has_value()) { + if (message.start_object.has_value()) { + if (message.start_object.value() == 0) { + return MoqtFilterType::kLatestGroup; + } + } else { + return MoqtFilterType::kLatestObject; + } + } + } + return MoqtFilterType::kNone; +} + std::string MoqtMessageTypeToString(const MoqtMessageType message_type) { switch (message_type) { case MoqtMessageType::kObjectStream: @@ -30,6 +66,14 @@ return "UNSUBSCRIBE"; case MoqtMessageType::kSubscribeDone: return "SUBSCRIBE_DONE"; + case MoqtMessageType::kSubscribeUpdate: + return "SUBSCRIBE_UPDATE"; + case MoqtMessageType::kAnnounceCancel: + return "ANNOUNCE_CANCEL"; + case MoqtMessageType::kTrackStatusRequest: + return "TRACK_STATUS_REQUEST"; + case MoqtMessageType::kTrackStatus: + return "TRACK_STATUS"; case MoqtMessageType::kAnnounce: return "ANNOUNCE"; case MoqtMessageType::kAnnounceOk:
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h index eef04cc..da72fea 100644 --- a/quiche/quic/moqt/moqt_messages.h +++ b/quiche/quic/moqt/moqt_messages.h
@@ -27,7 +27,7 @@ } enum class MoqtVersion : uint64_t { - kDraft03 = 0xff000003, + kDraft04 = 0xff000004, kUnrecognizedVersionForTests = 0xfe0000ff, }; @@ -49,6 +49,7 @@ enum class QUICHE_EXPORT MoqtMessageType : uint64_t { kObjectStream = 0x00, kObjectDatagram = 0x01, + kSubscribeUpdate = 0x02, kSubscribe = 0x03, kSubscribeOk = 0x04, kSubscribeError = 0x05, @@ -58,6 +59,9 @@ kUnannounce = 0x09, kUnsubscribe = 0x0a, kSubscribeDone = 0x0b, + kAnnounceCancel = 0x0c, + kTrackStatusRequest = 0x0d, + kTrackStatus = 0x0e, kGoAway = 0x10, kClientSetup = 0x40, kServerSetup = 0x41, @@ -117,7 +121,7 @@ (track_namespace == other.track_namespace && track_name < other.track_name); } - FullTrackName& operator=(FullTrackName other) { + FullTrackName& operator=(const FullTrackName& other) { track_namespace = other.track_namespace; track_name = other.track_name; return *this; @@ -192,60 +196,39 @@ std::optional<uint64_t> payload_length; }; -enum class QUICHE_EXPORT MoqtSubscribeLocationMode : uint64_t { +enum class QUICHE_EXPORT MoqtFilterType : uint64_t { kNone = 0x0, - kAbsolute = 0x1, - kRelativePrevious = 0x2, - kRelativeNext = 0x3, + kLatestGroup = 0x1, + kLatestObject = 0x2, + kAbsoluteStart = 0x3, + kAbsoluteRange = 0x4, }; -// kNone: std::optional<MoqtSubscribeLocation> is nullopt. -// kAbsolute: absolute = true -// kRelativePrevious: absolute is false; relative_value is negative -// kRelativeNext: absolute is true; relative_value is positive -struct QUICHE_EXPORT MoqtSubscribeLocation { - MoqtSubscribeLocation(bool is_absolute, uint64_t abs) - : absolute(is_absolute), absolute_value(abs) {} - MoqtSubscribeLocation(bool is_absolute, int64_t rel) - : absolute(is_absolute), relative_value(rel) {} - bool absolute; - union { - uint64_t absolute_value; - int64_t relative_value; - }; - bool operator==(const MoqtSubscribeLocation& other) const { - return absolute == other.absolute && - ((absolute && absolute_value == other.absolute_value) || - (!absolute && relative_value == other.relative_value)); - } -}; - -inline MoqtSubscribeLocationMode GetModeForSubscribeLocation( - const std::optional<MoqtSubscribeLocation>& location) { - if (!location.has_value()) { - return MoqtSubscribeLocationMode::kNone; - } - if (location->absolute) { - return MoqtSubscribeLocationMode::kAbsolute; - } - return location->relative_value >= 0 - ? MoqtSubscribeLocationMode::kRelativeNext - : MoqtSubscribeLocationMode::kRelativePrevious; -} - struct QUICHE_EXPORT MoqtSubscribe { uint64_t subscribe_id; uint64_t track_alias; std::string track_namespace; std::string track_name; + // The combinations of these that have values indicate the filter type. + // SG: Start Group; SO: Start Object; EG: End Group; EO: End Object; + // (none): KLatestObject + // SO: kLatestGroup (must be zero) + // SG, SO: kAbsoluteStart + // SG, SO, EG, EO: kAbsoluteRange + // SG, SO, EG: kAbsoluteRange (request whole last group) + // All other combinations are invalid. + std::optional<uint64_t> start_group; + std::optional<uint64_t> start_object; + std::optional<uint64_t> end_group; + std::optional<uint64_t> end_object; // If the mode is kNone, the these are std::nullopt. - std::optional<MoqtSubscribeLocation> start_group; - std::optional<MoqtSubscribeLocation> start_object; - std::optional<MoqtSubscribeLocation> end_group; - std::optional<MoqtSubscribeLocation> end_object; std::optional<std::string> authorization_info; }; +// Deduce the filter type from the combination of group and object IDs. Returns +// kNone if the state of the subscribe is invalid. +MoqtFilterType GetFilterType(const MoqtSubscribe& message); + struct QUICHE_EXPORT MoqtSubscribeOk { uint64_t subscribe_id; // The message uses ms, but expires is in us. @@ -283,11 +266,20 @@ struct QUICHE_EXPORT MoqtSubscribeDone { uint64_t subscribe_id; - uint64_t status_code; + SubscribeDoneCode status_code; std::string reason_phrase; std::optional<FullSequence> final_id; }; +struct QUICHE_EXPORT MoqtSubscribeUpdate { + uint64_t subscribe_id; + uint64_t start_group; + uint64_t start_object; + std::optional<uint64_t> end_group; + std::optional<uint64_t> end_object; + std::optional<std::string> authorization_info; +}; + struct QUICHE_EXPORT MoqtAnnounce { std::string track_namespace; std::optional<std::string> authorization_info; @@ -307,6 +299,31 @@ std::string track_namespace; }; +enum class QUICHE_EXPORT MoqtTrackStatusCode : uint64_t { + kInProgress = 0x0, + kDoesNotExist = 0x1, + kNotYetBegun = 0x2, + kFinished = 0x3, + kStatusNotAvailable = 0x4, +}; + +struct QUICHE_EXPORT MoqtTrackStatus { + std::string track_namespace; + std::string track_name; + MoqtTrackStatusCode status_code; + uint64_t last_group; + uint64_t last_object; +}; + +struct QUICHE_EXPORT MoqtAnnounceCancel { + std::string track_namespace; +}; + +struct QUICHE_EXPORT MoqtTrackStatusRequest { + std::string track_namespace; + std::string track_name; +}; + struct QUICHE_EXPORT MoqtGoAway { std::string new_session_uri; };
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc index 1e8cf8a..e1eba0b 100644 --- a/quiche/quic/moqt/moqt_parser.cc +++ b/quiche/quic/moqt/moqt_parser.cc
@@ -168,14 +168,22 @@ return ProcessUnsubscribe(reader); case MoqtMessageType::kSubscribeDone: return ProcessSubscribeDone(reader); + case MoqtMessageType::kSubscribeUpdate: + return ProcessSubscribeUpdate(reader); case MoqtMessageType::kAnnounce: return ProcessAnnounce(reader); case MoqtMessageType::kAnnounceOk: return ProcessAnnounceOk(reader); case MoqtMessageType::kAnnounceError: return ProcessAnnounceError(reader); + case MoqtMessageType::kAnnounceCancel: + return ProcessAnnounceCancel(reader); + case MoqtMessageType::kTrackStatusRequest: + return ProcessTrackStatusRequest(reader); case MoqtMessageType::kUnannounce: return ProcessUnannounce(reader); + case MoqtMessageType::kTrackStatus: + return ProcessTrackStatus(reader); case MoqtMessageType::kGoAway: return ProcessGoAway(reader); default: @@ -370,34 +378,53 @@ size_t MoqtParser::ProcessSubscribe(quic::QuicDataReader& reader) { MoqtSubscribe subscribe_request; + uint64_t filter, group, object; if (!reader.ReadVarInt62(&subscribe_request.subscribe_id) || !reader.ReadVarInt62(&subscribe_request.track_alias) || !reader.ReadStringVarInt62(subscribe_request.track_namespace) || !reader.ReadStringVarInt62(subscribe_request.track_name) || - !ReadLocation(reader, subscribe_request.start_group)) { + !reader.ReadVarInt62(&filter)) { return 0; } - if (!subscribe_request.start_group.has_value()) { - ParseError("START_GROUP must not be None in SUBSCRIBE"); - return 0; - } - if (!ReadLocation(reader, subscribe_request.start_object)) { - return 0; - } - if (!subscribe_request.start_object.has_value()) { - ParseError("START_OBJECT must not be None in SUBSCRIBE"); - return 0; - } - if (!ReadLocation(reader, subscribe_request.end_group) || - !ReadLocation(reader, subscribe_request.end_object)) { - return 0; - } - if (subscribe_request.end_group.has_value() != - subscribe_request.end_object.has_value()) { - ParseError( - "SUBSCRIBE end_group and end_object must be both None " - "or both non_None"); - return 0; + MoqtFilterType filter_type = static_cast<MoqtFilterType>(filter); + switch (filter_type) { + case MoqtFilterType::kLatestGroup: + subscribe_request.start_object = 0; + break; + case MoqtFilterType::kLatestObject: + break; + case MoqtFilterType::kAbsoluteStart: + case MoqtFilterType::kAbsoluteRange: + if (!reader.ReadVarInt62(&group) || !reader.ReadVarInt62(&object)) { + return 0; + } + subscribe_request.start_group = group; + subscribe_request.start_object = object; + if (filter_type == MoqtFilterType::kAbsoluteStart) { + break; + } + if (!reader.ReadVarInt62(&group) || !reader.ReadVarInt62(&object)) { + return 0; + } + subscribe_request.end_group = group; + if (subscribe_request.end_group < subscribe_request.start_group) { + ParseError("End group is less than start group"); + return 0; + } + if (object == 0) { + subscribe_request.end_object = std::nullopt; + } else { + subscribe_request.end_object = object - 1; + if (subscribe_request.start_group == subscribe_request.end_group && + subscribe_request.end_object < subscribe_request.start_object) { + ParseError("End object comes before start object"); + return 0; + } + } + break; + default: + ParseError("Invalid filter type"); + return 0; } uint64_t num_params; if (!reader.ReadVarInt62(&num_params)) { @@ -415,7 +442,7 @@ if (subscribe_request.authorization_info.has_value()) { ParseError( "AUTHORIZATION_INFO parameter appears twice in " - "SUBSCRIBE_REQUEST"); + "SUBSCRIBE"); return 0; } subscribe_request.authorization_info = value; @@ -480,12 +507,14 @@ size_t MoqtParser::ProcessSubscribeDone(quic::QuicDataReader& reader) { MoqtSubscribeDone subscribe_done; uint8_t content_exists; + uint64_t value; if (!reader.ReadVarInt62(&subscribe_done.subscribe_id) || - !reader.ReadVarInt62(&subscribe_done.status_code) || + !reader.ReadVarInt62(&value) || !reader.ReadStringVarInt62(subscribe_done.reason_phrase) || !reader.ReadUInt8(&content_exists)) { return 0; } + subscribe_done.status_code = static_cast<SubscribeDoneCode>(value); if (content_exists > 1) { ParseError("SUBSCRIBE_DONE ContentExists has invalid value"); return 0; @@ -501,6 +530,66 @@ return reader.PreviouslyReadPayload().length(); } +size_t MoqtParser::ProcessSubscribeUpdate(quic::QuicDataReader& reader) { + MoqtSubscribeUpdate subscribe_update; + uint64_t end_group, end_object, num_params; + if (!reader.ReadVarInt62(&subscribe_update.subscribe_id) || + !reader.ReadVarInt62(&subscribe_update.start_group) || + !reader.ReadVarInt62(&subscribe_update.start_object) || + !reader.ReadVarInt62(&end_group) || !reader.ReadVarInt62(&end_object) || + !reader.ReadVarInt62(&num_params)) { + return 0; + } + if (end_group == 0) { + // end_group remains nullopt. + if (end_object > 0) { + ParseError("SUBSCRIBE_UPDATE has end_object but no end_group"); + return 0; + } + } else { + subscribe_update.end_group = end_group - 1; + if (subscribe_update.end_group < subscribe_update.start_group) { + ParseError("End group is less than start group"); + return 0; + } + } + if (end_object > 0) { + subscribe_update.end_object = end_object - 1; + if (subscribe_update.end_object.has_value() && + subscribe_update.start_group == *subscribe_update.end_group && + *subscribe_update.end_object < subscribe_update.start_object) { + ParseError("End object comes before start object"); + return 0; + } + } else { + subscribe_update.end_object = std::nullopt; + } + for (uint64_t i = 0; i < num_params; ++i) { + uint64_t type; + absl::string_view value; + if (!ReadParameter(reader, type, value)) { + return 0; + } + auto key = static_cast<MoqtTrackRequestParameter>(type); + switch (key) { + case MoqtTrackRequestParameter::kAuthorizationInfo: + if (subscribe_update.authorization_info.has_value()) { + ParseError( + "AUTHORIZATION_INFO parameter appears twice in " + "SUBSCRIBE_UPDATE"); + return 0; + } + subscribe_update.authorization_info = value; + break; + default: + // Skip over the parameter. + break; + } + } + visitor_.OnSubscribeUpdateMessage(subscribe_update); + return reader.PreviouslyReadPayload().length(); +} + size_t MoqtParser::ProcessAnnounce(quic::QuicDataReader& reader) { MoqtAnnounce announce; if (!reader.ReadStringVarInt62(announce.track_namespace)) { @@ -560,6 +649,27 @@ return reader.PreviouslyReadPayload().length(); } +size_t MoqtParser::ProcessAnnounceCancel(quic::QuicDataReader& reader) { + MoqtAnnounceCancel announce_cancel; + if (!reader.ReadStringVarInt62(announce_cancel.track_namespace)) { + return 0; + } + visitor_.OnAnnounceCancelMessage(announce_cancel); + return reader.PreviouslyReadPayload().length(); +} + +size_t MoqtParser::ProcessTrackStatusRequest(quic::QuicDataReader& reader) { + MoqtTrackStatusRequest track_status_request; + if (!reader.ReadStringVarInt62(track_status_request.track_namespace)) { + return 0; + } + if (!reader.ReadStringVarInt62(track_status_request.track_name)) { + return 0; + } + visitor_.OnTrackStatusRequestMessage(track_status_request); + return reader.PreviouslyReadPayload().length(); +} + size_t MoqtParser::ProcessUnannounce(quic::QuicDataReader& reader) { MoqtUnannounce unannounce; if (!reader.ReadStringVarInt62(unannounce.track_namespace)) { @@ -569,6 +679,21 @@ return reader.PreviouslyReadPayload().length(); } +size_t MoqtParser::ProcessTrackStatus(quic::QuicDataReader& reader) { + MoqtTrackStatus track_status; + uint64_t value; + if (!reader.ReadStringVarInt62(track_status.track_namespace) || + !reader.ReadStringVarInt62(track_status.track_name) || + !reader.ReadVarInt62(&value) || + !reader.ReadVarInt62(&track_status.last_group) || + !reader.ReadVarInt62(&track_status.last_object)) { + return 0; + } + track_status.status_code = static_cast<MoqtTrackStatusCode>(value); + visitor_.OnTrackStatusMessage(track_status); + return reader.PreviouslyReadPayload().length(); +} + size_t MoqtParser::ProcessGoAway(quic::QuicDataReader& reader) { MoqtGoAway goaway; if (!reader.ReadStringVarInt62(goaway.new_session_uri)) { @@ -631,37 +756,6 @@ return true; } -bool MoqtParser::ReadLocation(quic::QuicDataReader& reader, - std::optional<MoqtSubscribeLocation>& loc) { - uint64_t ui64; - if (!reader.ReadVarInt62(&ui64)) { - return false; - } - auto mode = static_cast<MoqtSubscribeLocationMode>(ui64); - if (mode == MoqtSubscribeLocationMode::kNone) { - loc = std::nullopt; - return true; - } - if (!reader.ReadVarInt62(&ui64)) { - return false; - } - switch (mode) { - case MoqtSubscribeLocationMode::kAbsolute: - loc = MoqtSubscribeLocation(true, ui64); - break; - case MoqtSubscribeLocationMode::kRelativePrevious: - loc = MoqtSubscribeLocation(false, -1 * static_cast<int64_t>(ui64)); - break; - case MoqtSubscribeLocationMode::kRelativeNext: - loc = MoqtSubscribeLocation(false, static_cast<int64_t>(ui64) + 1); - break; - default: - ParseError("Unknown location mode"); - return false; - } - return true; -} - bool MoqtParser::ReadParameter(quic::QuicDataReader& reader, uint64_t& type, absl::string_view& value) { if (!reader.ReadVarInt62(&type)) {
diff --git a/quiche/quic/moqt/moqt_parser.h b/quiche/quic/moqt/moqt_parser.h index 084e6fe..04a6ef7 100644 --- a/quiche/quic/moqt/moqt_parser.h +++ b/quiche/quic/moqt/moqt_parser.h
@@ -39,10 +39,15 @@ virtual void OnSubscribeErrorMessage(const MoqtSubscribeError& message) = 0; virtual void OnUnsubscribeMessage(const MoqtUnsubscribe& message) = 0; virtual void OnSubscribeDoneMessage(const MoqtSubscribeDone& message) = 0; + virtual void OnSubscribeUpdateMessage(const MoqtSubscribeUpdate& message) = 0; virtual void OnAnnounceMessage(const MoqtAnnounce& message) = 0; virtual void OnAnnounceOkMessage(const MoqtAnnounceOk& message) = 0; virtual void OnAnnounceErrorMessage(const MoqtAnnounceError& message) = 0; + virtual void OnAnnounceCancelMessage(const MoqtAnnounceCancel& message) = 0; + virtual void OnTrackStatusRequestMessage( + const MoqtTrackStatusRequest& message) = 0; virtual void OnUnannounceMessage(const MoqtUnannounce& message) = 0; + virtual void OnTrackStatusMessage(const MoqtTrackStatus& message) = 0; virtual void OnGoAwayMessage(const MoqtGoAway& message) = 0; virtual void OnParsingError(MoqtError code, absl::string_view reason) = 0; @@ -90,10 +95,14 @@ size_t ProcessSubscribeError(quic::QuicDataReader& reader); size_t ProcessUnsubscribe(quic::QuicDataReader& reader); size_t ProcessSubscribeDone(quic::QuicDataReader& reader); + size_t ProcessSubscribeUpdate(quic::QuicDataReader& reader); size_t ProcessAnnounce(quic::QuicDataReader& reader); size_t ProcessAnnounceOk(quic::QuicDataReader& reader); size_t ProcessAnnounceError(quic::QuicDataReader& reader); + size_t ProcessAnnounceCancel(quic::QuicDataReader& reader); + size_t ProcessTrackStatusRequest(quic::QuicDataReader& reader); size_t ProcessUnannounce(quic::QuicDataReader& reader); + size_t ProcessTrackStatus(quic::QuicDataReader& reader); size_t ProcessGoAway(quic::QuicDataReader& reader); static size_t ParseObjectHeader(quic::QuicDataReader& reader, @@ -106,9 +115,6 @@ // Reads an integer whose length is specified by a preceding VarInt62 and // returns it in |result|. Returns false if parsing fails. bool ReadVarIntPieceVarInt62(quic::QuicDataReader& reader, uint64_t& result); - // Read a Location field from SUBSCRIBE REQUEST - bool ReadLocation(quic::QuicDataReader& reader, - std::optional<MoqtSubscribeLocation>& loc); // Read a parameter and return the value as a string_view. Returns false if // |reader| does not have enough data. bool ReadParameter(quic::QuicDataReader& reader, uint64_t& type,
diff --git a/quiche/quic/moqt/moqt_parser_test.cc b/quiche/quic/moqt/moqt_parser_test.cc index 9d0f992..4794f0e 100644 --- a/quiche/quic/moqt/moqt_parser_test.cc +++ b/quiche/quic/moqt/moqt_parser_test.cc
@@ -41,8 +41,12 @@ MoqtMessageType::kSubscribe, MoqtMessageType::kSubscribeOk, MoqtMessageType::kSubscribeError, + MoqtMessageType::kSubscribeUpdate, MoqtMessageType::kUnsubscribe, MoqtMessageType::kSubscribeDone, + MoqtMessageType::kAnnounceCancel, + MoqtMessageType::kTrackStatusRequest, + MoqtMessageType::kTrackStatus, MoqtMessageType::kAnnounce, MoqtMessageType::kAnnounceOk, MoqtMessageType::kAnnounceError, @@ -125,6 +129,9 @@ void OnSubscribeErrorMessage(const MoqtSubscribeError& message) override { OnControlMessage(message); } + void OnSubscribeUpdateMessage(const MoqtSubscribeUpdate& message) override { + OnControlMessage(message); + } void OnUnsubscribeMessage(const MoqtUnsubscribe& message) override { OnControlMessage(message); } @@ -140,9 +147,19 @@ void OnAnnounceErrorMessage(const MoqtAnnounceError& message) override { OnControlMessage(message); } + void OnAnnounceCancelMessage(const MoqtAnnounceCancel& message) override { + OnControlMessage(message); + } + void OnTrackStatusRequestMessage( + const MoqtTrackStatusRequest& message) override { + OnControlMessage(message); + } void OnUnannounceMessage(const MoqtUnannounce& message) override { OnControlMessage(message); } + void OnTrackStatusMessage(const MoqtTrackStatus& message) override { + OnControlMessage(message); + } void OnGoAwayMessage(const MoqtGoAway& message) override { OnControlMessage(message); } @@ -606,19 +623,33 @@ char subscribe[] = { 0x03, 0x01, 0x02, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" 0x04, 0x61, 0x62, 0x63, 0x64, // track_name = "abcd" - 0x02, 0x04, // start_group = 4 (relative previous) - 0x01, 0x01, // start_object = 1 (absolute) - 0x00, // end_group = none - 0x00, // end_object = none - 0x02, // two params - 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" - 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + 0x02, // filter_type = kLatestObject + 0x02, // two params + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" }; parser.ProcessData(absl::string_view(subscribe, sizeof(subscribe)), false); EXPECT_EQ(visitor_.messages_received_, 0); EXPECT_TRUE(visitor_.parsing_error_.has_value()); EXPECT_EQ(*visitor_.parsing_error_, - "AUTHORIZATION_INFO parameter appears twice in SUBSCRIBE_REQUEST"); + "AUTHORIZATION_INFO parameter appears twice in SUBSCRIBE"); + EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); +} + +TEST_F(MoqtMessageSpecificTest, SubscribeUpdateAuthorizationInfoTwice) { + MoqtParser parser(kWebTrans, visitor_); + char subscribe_update[] = { + 0x02, 0x02, 0x03, 0x01, 0x05, 0x06, // start and end sequences + 0x02, // 2 parameters + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + parser.ProcessData( + absl::string_view(subscribe_update, sizeof(subscribe_update)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "AUTHORIZATION_INFO parameter appears twice in SUBSCRIBE_UPDATE"); EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); } @@ -703,65 +734,230 @@ EXPECT_EQ(*visitor_.parsing_error_, "Unknown message type"); } -TEST_F(MoqtMessageSpecificTest, StartGroupIsNone) { +TEST_F(MoqtMessageSpecificTest, LatestGroup) { MoqtParser parser(kRawQuic, visitor_); char subscribe[] = { 0x03, 0x01, 0x02, // id and alias 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" 0x04, 0x61, 0x62, 0x63, 0x64, // track_name = "abcd" - 0x00, // start_group = none - 0x01, 0x01, // start_object = 1 (absolute) - 0x00, // end_group = none - 0x00, // end_object = none + 0x01, // filter_type = kLatestGroup 0x01, // 1 parameter 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" }; parser.ProcessData(absl::string_view(subscribe, sizeof(subscribe)), false); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, - "START_GROUP must not be None in SUBSCRIBE"); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_TRUE(visitor_.last_message_.has_value()); + MoqtSubscribe message = + std::get<MoqtSubscribe>(visitor_.last_message_.value()); + EXPECT_FALSE(message.start_group.has_value()); + EXPECT_EQ(message.start_object, 0); + EXPECT_FALSE(message.end_group.has_value()); + EXPECT_FALSE(message.end_object.has_value()); } -TEST_F(MoqtMessageSpecificTest, StartObjectIsNone) { +TEST_F(MoqtMessageSpecificTest, LatestObject) { MoqtParser parser(kRawQuic, visitor_); char subscribe[] = { 0x03, 0x01, 0x02, // id and alias 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" 0x04, 0x61, 0x62, 0x63, 0x64, // track_name = "abcd" - 0x02, 0x04, // start_group = 4 (relative previous) - 0x00, // start_object = none - 0x00, // end_group = none - 0x00, // end_object = none + 0x02, // filter_type = kLatestObject 0x01, // 1 parameter 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" }; parser.ProcessData(absl::string_view(subscribe, sizeof(subscribe)), false); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, - "START_OBJECT must not be None in SUBSCRIBE"); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); + MoqtSubscribe message = + std::get<MoqtSubscribe>(visitor_.last_message_.value()); + EXPECT_FALSE(message.start_group.has_value()); + EXPECT_FALSE(message.start_object.has_value()); + EXPECT_FALSE(message.end_group.has_value()); + EXPECT_FALSE(message.end_object.has_value()); } -TEST_F(MoqtMessageSpecificTest, EndGroupIsNoneEndObjectIsNoNone) { +TEST_F(MoqtMessageSpecificTest, AbsoluteStart) { MoqtParser parser(kRawQuic, visitor_); char subscribe[] = { 0x03, 0x01, 0x02, // id and alias 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" 0x04, 0x61, 0x62, 0x63, 0x64, // track_name = "abcd" - 0x02, 0x04, // start_group = 4 (relative previous) - 0x01, 0x01, // start_object = 1 (absolute) - 0x00, // end_group = none - 0x01, 0x01, // end_object = 1 (absolute) + 0x03, // filter_type = kAbsoluteStart + 0x04, // start_group = 4 + 0x01, // start_object = 1 + 0x01, // 1 parameter + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + parser.ProcessData(absl::string_view(subscribe, sizeof(subscribe)), false); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); + MoqtSubscribe message = + std::get<MoqtSubscribe>(visitor_.last_message_.value()); + EXPECT_EQ(message.start_group.value(), 4); + EXPECT_EQ(message.start_object.value(), 1); + EXPECT_FALSE(message.end_group.has_value()); + EXPECT_FALSE(message.end_object.has_value()); +} + +TEST_F(MoqtMessageSpecificTest, AbsoluteRangeExplicitEndObject) { + MoqtParser parser(kRawQuic, visitor_); + char subscribe[] = { + 0x03, 0x01, 0x02, // id and alias + 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x04, 0x61, 0x62, 0x63, 0x64, // track_name = "abcd" + 0x04, // filter_type = kAbsoluteStart + 0x04, // start_group = 4 + 0x01, // start_object = 1 + 0x07, // end_group = 7 + 0x03, // end_object = 2 + 0x01, // 1 parameter + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + parser.ProcessData(absl::string_view(subscribe, sizeof(subscribe)), false); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); + MoqtSubscribe message = + std::get<MoqtSubscribe>(visitor_.last_message_.value()); + EXPECT_EQ(message.start_group.value(), 4); + EXPECT_EQ(message.start_object.value(), 1); + EXPECT_EQ(message.end_group.value(), 7); + EXPECT_EQ(message.end_object.value(), 2); +} + +TEST_F(MoqtMessageSpecificTest, AbsoluteRangeWholeEndGroup) { + MoqtParser parser(kRawQuic, visitor_); + char subscribe[] = { + 0x03, 0x01, 0x02, // id and alias + 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x04, 0x61, 0x62, 0x63, 0x64, // track_name = "abcd" + 0x04, // filter_type = kAbsoluteRange + 0x04, // start_group = 4 + 0x01, // start_object = 1 + 0x07, // end_group = 7 + 0x00, // end whole group + 0x01, // 1 parameter + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + parser.ProcessData(absl::string_view(subscribe, sizeof(subscribe)), false); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); + MoqtSubscribe message = + std::get<MoqtSubscribe>(visitor_.last_message_.value()); + EXPECT_EQ(message.start_group.value(), 4); + EXPECT_EQ(message.start_object.value(), 1); + EXPECT_EQ(message.end_group.value(), 7); + EXPECT_FALSE(message.end_object.has_value()); +} + +TEST_F(MoqtMessageSpecificTest, AbsoluteRangeEndGroupTooLow) { + MoqtParser parser(kRawQuic, visitor_); + char subscribe[] = { + 0x03, 0x01, 0x02, // id and alias + 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x04, 0x61, 0x62, 0x63, 0x64, // track_name = "abcd" + 0x04, // filter_type = kAbsoluteRange + 0x04, // start_group = 4 + 0x01, // start_object = 1 + 0x03, // end_group = 3 + 0x00, // end whole group 0x01, // 1 parameter 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" }; parser.ProcessData(absl::string_view(subscribe, sizeof(subscribe)), false); EXPECT_EQ(visitor_.messages_received_, 0); EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, "End group is less than start group"); +} + +TEST_F(MoqtMessageSpecificTest, AbsoluteRangeExactlyOneObject) { + MoqtParser parser(kRawQuic, visitor_); + char subscribe[] = { + 0x03, 0x01, 0x02, // id and alias + 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x04, 0x61, 0x62, 0x63, 0x64, // track_name = "abcd" + 0x04, // filter_type = kAbsoluteRange + 0x04, // start_group = 4 + 0x01, // start_object = 1 + 0x04, // end_group = 4 + 0x02, // end object = 1 + 0x00, // no parameters + }; + parser.ProcessData(absl::string_view(subscribe, sizeof(subscribe)), false); + EXPECT_EQ(visitor_.messages_received_, 1); +} + +TEST_F(MoqtMessageSpecificTest, SubscribeUpdateExactlyOneObject) { + MoqtParser parser(kRawQuic, visitor_); + char subscribe_update[] = { + 0x02, 0x02, 0x03, 0x01, 0x04, 0x07, // start and end sequences + 0x00, // No parameters + }; + parser.ProcessData( + absl::string_view(subscribe_update, sizeof(subscribe_update)), false); + EXPECT_EQ(visitor_.messages_received_, 1); +} + +TEST_F(MoqtMessageSpecificTest, SubscribeUpdateEndGroupTooLow) { + MoqtParser parser(kRawQuic, visitor_); + char subscribe_update[] = { + 0x02, 0x02, 0x03, 0x01, 0x03, 0x06, // start and end sequences + 0x01, // 1 parameter + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + parser.ProcessData( + absl::string_view(subscribe_update, sizeof(subscribe_update)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, "End group is less than start group"); +} + +TEST_F(MoqtMessageSpecificTest, AbsoluteRangeEndObjectTooLow) { + MoqtParser parser(kRawQuic, visitor_); + char subscribe[] = { + 0x03, 0x01, 0x02, // id and alias + 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x04, 0x61, 0x62, 0x63, 0x64, // track_name = "abcd" + 0x04, // filter_type = kAbsoluteRange + 0x04, // start_group = 4 + 0x01, // start_object = 1 + 0x04, // end_group = 4 + 0x01, // end_object = 0 + 0x01, // 1 parameter + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + parser.ProcessData(absl::string_view(subscribe, sizeof(subscribe)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, "End object comes before start object"); +} + +TEST_F(MoqtMessageSpecificTest, SubscribeUpdateEndObjectTooLow) { + MoqtParser parser(kRawQuic, visitor_); + char subscribe_update[] = { + 0x02, 0x02, 0x03, 0x02, 0x04, 0x01, // start and end sequences + 0x01, // 1 parameter + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + parser.ProcessData( + absl::string_view(subscribe_update, sizeof(subscribe_update)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, "End object comes before start object"); +} + +TEST_F(MoqtMessageSpecificTest, SubscribeUpdateNoEndGroup) { + MoqtParser parser(kRawQuic, visitor_); + char subscribe_update[] = { + 0x02, 0x02, 0x03, 0x02, 0x00, 0x01, // start and end sequences + 0x01, // 1 parameter + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + parser.ProcessData( + absl::string_view(subscribe_update, sizeof(subscribe_update)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); EXPECT_EQ(*visitor_.parsing_error_, - "SUBSCRIBE end_group and end_object must be both None " - "or both non_None"); + "SUBSCRIBE_UPDATE has end_object but no end_group"); } TEST_F(MoqtMessageSpecificTest, AllMessagesTogether) { @@ -775,7 +971,7 @@ // Each iteration, process from the halfway point of one message to the // halfway point of the next. if (IsObjectMessage(type)) { - continue; // Objects cannot share a stream with other meessages. + continue; // Objects cannot share a stream with other messages. } std::unique_ptr<TestMessageBase> message = CreateTestMessage(type, kRawQuic); @@ -800,31 +996,6 @@ EXPECT_FALSE(visitor_.parsing_error_.has_value()); } -TEST_F(MoqtMessageSpecificTest, RelativeLocation) { - MoqtParser parser(kRawQuic, visitor_); - char subscribe[] = { - 0x03, 0x01, 0x02, // id and alias - 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" - 0x04, 0x61, 0x62, 0x63, 0x64, // track_name = "abcd" - 0x02, 0x00, // start_group = 0 (relative previous) - 0x03, 0x00, // start_object = 1 (relative next) - 0x00, // end_group = none - 0x00, // end_object = none - 0x01, // 1 parameter - 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" - }; - parser.ProcessData(absl::string_view(subscribe, sizeof(subscribe)), false); - EXPECT_EQ(visitor_.messages_received_, 1); - MoqtSubscribe message = std::get<MoqtSubscribe>(*visitor_.last_message_); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); - ASSERT_TRUE(message.start_group.has_value()); - ASSERT_FALSE(message.start_group->absolute); - EXPECT_EQ(message.start_group->relative_value, 0); - ASSERT_TRUE(message.start_object.has_value()); - ASSERT_FALSE(message.start_object->absolute); - EXPECT_EQ(message.start_object->relative_value, 1); -} - TEST_F(MoqtMessageSpecificTest, DatagramSuccessful) { ObjectDatagramMessage message; MoqtObject object;
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index 8781b26..06d2808 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -197,8 +197,8 @@ MoqtSubscribe message; message.track_namespace = track_namespace; message.track_name = name; - message.start_group = MoqtSubscribeLocation(true, start_group); - message.start_object = MoqtSubscribeLocation(true, start_object); + message.start_group = start_group; + message.start_object = start_object; message.end_group = std::nullopt; message.end_object = std::nullopt; if (!auth_info.empty()) { @@ -210,6 +210,29 @@ bool MoqtSession::SubscribeAbsolute(absl::string_view track_namespace, absl::string_view name, uint64_t start_group, uint64_t start_object, + uint64_t end_group, + RemoteTrack::Visitor* visitor, + absl::string_view auth_info) { + if (end_group < start_group) { + QUIC_DLOG(ERROR) << "Subscription end is before beginning"; + return false; + } + MoqtSubscribe message; + message.track_namespace = track_namespace; + message.track_name = name; + message.start_group = start_group; + message.start_object = start_object; + message.end_group = end_group; + message.end_object = std::nullopt; + if (!auth_info.empty()) { + message.authorization_info = std::move(auth_info); + } + return Subscribe(message, visitor); +} + +bool MoqtSession::SubscribeAbsolute(absl::string_view track_namespace, + absl::string_view name, + uint64_t start_group, uint64_t start_object, uint64_t end_group, uint64_t end_object, RemoteTrack::Visitor* visitor, absl::string_view auth_info) { @@ -224,26 +247,25 @@ MoqtSubscribe message; message.track_namespace = track_namespace; message.track_name = name; - message.start_group = MoqtSubscribeLocation(true, start_group); - message.start_object = MoqtSubscribeLocation(true, start_object); - message.end_group = MoqtSubscribeLocation(true, end_group); - message.end_object = MoqtSubscribeLocation(true, end_object); + message.start_group = start_group; + message.start_object = start_object; + message.end_group = end_group; + message.end_object = end_object; if (!auth_info.empty()) { message.authorization_info = std::move(auth_info); } return Subscribe(message, visitor); } -bool MoqtSession::SubscribeRelative(absl::string_view track_namespace, - absl::string_view name, int64_t start_group, - int64_t start_object, - RemoteTrack::Visitor* visitor, - absl::string_view auth_info) { +bool MoqtSession::SubscribeCurrentObject(absl::string_view track_namespace, + absl::string_view name, + RemoteTrack::Visitor* visitor, + absl::string_view auth_info) { MoqtSubscribe message; message.track_namespace = track_namespace; message.track_name = name; - message.start_group = MoqtSubscribeLocation(false, start_group); - message.start_object = MoqtSubscribeLocation(false, start_object); + message.start_group = std::nullopt; + message.start_object = std::nullopt; message.end_group = std::nullopt; message.end_object = std::nullopt; if (!auth_info.empty()) { @@ -260,8 +282,8 @@ message.track_namespace = track_namespace; message.track_name = name; // First object of current group. - message.start_group = MoqtSubscribeLocation(false, (uint64_t)0); - message.start_object = MoqtSubscribeLocation(true, (int64_t)0); + message.start_group = std::nullopt; + message.start_object = 0; message.end_group = std::nullopt; message.end_object = std::nullopt; if (!auth_info.empty()) { @@ -270,6 +292,36 @@ return Subscribe(message, visitor); } +bool MoqtSession::SubscribeIsDone(uint64_t subscribe_id, SubscribeDoneCode code, + absl::string_view reason_phrase) { + // Search all the tracks to find the subscribe ID. + auto name_it = local_track_by_subscribe_id_.find(subscribe_id); + if (name_it == local_track_by_subscribe_id_.end()) { + return false; + } + auto track_it = local_tracks_.find(name_it->second); + if (track_it == local_tracks_.end()) { + return false; + } + LocalTrack& track = track_it->second; + MoqtSubscribeDone subscribe_done; + subscribe_done.subscribe_id = subscribe_id; + subscribe_done.status_code = code; + subscribe_done.reason_phrase = reason_phrase; + SubscribeWindow* window = track.GetWindow(subscribe_id); + if (window == nullptr) { + return false; + } + subscribe_done.final_id = window->largest_delivered(); + SendControlMessage(framer_.SerializeSubscribeDone(subscribe_done)); + QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE_DONE message for " + << subscribe_id; + // Clean up the subscription + track.DeleteWindow(subscribe_id); + local_track_by_subscribe_id_.erase(name_it); + return true; +} + bool MoqtSession::Subscribe(MoqtSubscribe& message, RemoteTrack::Visitor* visitor) { if (peer_role_ == MoqtRole::kSubscriber) { @@ -388,6 +440,7 @@ quiche::StreamWriteOptions write_options; write_options.set_send_fin(end_of_stream); for (auto subscription : subscriptions) { + subscription->OnObjectDelivered(FullSequence(group_id, object_id)); if (forwarding_preference == MoqtForwardingPreference::kDatagram) { object.subscribe_id = subscription->subscribe_id(); quiche::QuicheBuffer datagram = @@ -662,19 +715,37 @@ } session_->used_track_aliases_.insert(message.track_alias); } - std::optional<FullSequence> start = session_->LocationToAbsoluteNumber( - track, message.start_group, message.start_object); - QUICHE_DCHECK(start.has_value()); // Parser enforces this. - std::optional<FullSequence> end = session_->LocationToAbsoluteNumber( - track, message.end_group, message.end_object); + FullSequence start; + std::optional<FullSequence> end; + if (message.start_group.has_value()) { + // The filter is AbsoluteStart or AbsoluteRange. + QUIC_BUG_IF(quic_bug_invalid_subscribe, !message.start_object.has_value()) + << "Start group without start object"; + start = FullSequence(*message.start_group, *message.start_object); + } else { + // The filter is LatestObject or LatestGroup. + start = track.next_sequence(); + if (message.start_object.has_value()) { + // The filter is LatestGroup. + QUIC_BUG_IF(quic_bug_invalid_subscribe, *message.start_object != 0) + << "LatestGroup does not start with zero"; + start.object = 0; + } else { + --start.object; + } + } + if (message.end_group.has_value()) { + end = FullSequence(*message.end_group, message.end_object.has_value() + ? *message.end_object + : UINT64_MAX); + } LocalTrack::Visitor::PublishPastObjectsCallback publish_past_objects; SubscribeWindow window = end.has_value() ? SubscribeWindow(message.subscribe_id, track.forwarding_preference(), - start->group, start->object, end->group, - end->object) + start.group, start.object, end->group, end->object) : SubscribeWindow(message.subscribe_id, track.forwarding_preference(), - start->group, start->object); + start.group, start.object); if (start < track.next_sequence() && track.visitor() != nullptr) { absl::StatusOr<LocalTrack::Visitor::PublishPastObjectsCallback> past_objects_available = track.visitor()->OnSubscribeForPast(window); @@ -692,11 +763,13 @@ QUIC_DLOG(INFO) << ENDPOINT << "Created subscription for " << message.track_namespace << ":" << message.track_name; if (end.has_value()) { - track.AddWindow(message.subscribe_id, start->group, start->object, - end->group, end->object); + track.AddWindow(message.subscribe_id, start.group, start.object, end->group, + end->object); } else { - track.AddWindow(message.subscribe_id, start->group, start->object); + track.AddWindow(message.subscribe_id, start.group, start.object); } + session_->local_track_by_subscribe_id_.emplace(message.subscribe_id, + track.full_track_name()); if (publish_past_objects) { std::move(publish_past_objects)(); } @@ -772,11 +845,43 @@ } void MoqtSession::Stream::OnUnsubscribeMessage(const MoqtUnsubscribe& message) { + session_->SubscribeIsDone(message.subscribe_id, + SubscribeDoneCode::kUnsubscribed, ""); +} + +void MoqtSession::Stream::OnSubscribeUpdateMessage( + const MoqtSubscribeUpdate& message) { // Search all the tracks to find the subscribe ID. - for (auto& [name, track] : session_->local_tracks_) { - track.DeleteWindow(message.subscribe_id); + auto name_it = + session_->local_track_by_subscribe_id_.find(message.subscribe_id); + if (name_it == session_->local_track_by_subscribe_id_.end()) { + return; } - // TODO(martinduke): Send SUBSCRIBE_DONE in response. + auto track_it = session_->local_tracks_.find(name_it->second); + if (track_it == session_->local_tracks_.end()) { + return; + } + LocalTrack& track = track_it->second; + SubscribeWindow* window = track.GetWindow(message.subscribe_id); + if (window == nullptr) { + return; + } + FullSequence start(message.start_group, message.start_object); + std::optional<FullSequence> end; + if (message.end_group.has_value()) { + end = FullSequence(*message.end_group, message.end_object.has_value() + ? *message.end_object + : UINT64_MAX); + } + // TODO(martinduke): Handle the case where the update range is invalid. + if (window->UpdateStartEnd(start, end)) { + std::optional<FullSequence> largest_delivered = window->largest_delivered(); + if (largest_delivered.has_value() && end <= *largest_delivered) { + session_->SubscribeIsDone(message.subscribe_id, + SubscribeDoneCode::kSubscriptionEnded, + "SUBSCRIBE_UPDATE moved subscription end"); + } + } } void MoqtSession::Stream::OnAnnounceMessage(const MoqtAnnounce& message) { @@ -855,28 +960,6 @@ return true; } -std::optional<FullSequence> MoqtSession::LocationToAbsoluteNumber( - const LocalTrack& track, const std::optional<MoqtSubscribeLocation>& group, - const std::optional<MoqtSubscribeLocation>& object) { - FullSequence sequence; - if (!group.has_value() || !object.has_value()) { - return std::nullopt; - } - if (group->absolute) { - sequence.group = group->absolute_value; - } else { - sequence.group = track.next_sequence().group + group->relative_value; - } - if (object->absolute) { - sequence.object = object->absolute_value; - } else { - // Subtract 1 because the relative value is computed from the largest sent - // sequence number, not the next one. - sequence.object = track.next_sequence().object + object->relative_value - 1; - } - return sequence; -} - void MoqtSession::Stream::SendOrBufferMessage(quiche::QuicheBuffer message, bool fin) { quiche::StreamWriteOptions options;
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index fab05df..7e5fb7f 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -102,23 +102,34 @@ // Returns true if SUBSCRIBE was sent. If there is already a subscription to // the track, the message will still be sent. However, the visitor will be // ignored. + // Subscribe from (start_group, start_object) to the end of the track. bool SubscribeAbsolute(absl::string_view track_namespace, absl::string_view name, uint64_t start_group, uint64_t start_object, RemoteTrack::Visitor* visitor, absl::string_view auth_info = ""); + // Subscribe from (start_group, start_object) to the end of end_group. + bool SubscribeAbsolute(absl::string_view track_namespace, + absl::string_view name, uint64_t start_group, + uint64_t start_object, uint64_t end_group, + RemoteTrack::Visitor* visitor, + absl::string_view auth_info = ""); + // Subscribe from (start_group, start_object) to (end_group, end_object). bool SubscribeAbsolute(absl::string_view track_namespace, absl::string_view name, uint64_t start_group, uint64_t start_object, uint64_t end_group, uint64_t end_object, RemoteTrack::Visitor* visitor, absl::string_view auth_info = ""); - bool SubscribeRelative(absl::string_view track_namespace, - absl::string_view name, int64_t start_group, - int64_t start_object, RemoteTrack::Visitor* visitor, - absl::string_view auth_info = ""); + bool SubscribeCurrentObject(absl::string_view track_namespace, + absl::string_view name, + RemoteTrack::Visitor* visitor, + absl::string_view auth_info = ""); bool SubscribeCurrentGroup(absl::string_view track_namespace, absl::string_view name, RemoteTrack::Visitor* visitor, absl::string_view auth_info = ""); + // Returns true if SUBSCRIBE_DONE was sent. + bool SubscribeIsDone(uint64_t subscribe_id, SubscribeDoneCode code, + absl::string_view reason_phrase); // Returns false if it could not open a stream when necessary, or if the // track does not exist (there was no call to AddLocalTrack). Will still @@ -171,10 +182,15 @@ void OnUnsubscribeMessage(const MoqtUnsubscribe& message) override; void OnSubscribeDoneMessage(const MoqtSubscribeDone& /*message*/) override { } + void OnSubscribeUpdateMessage(const MoqtSubscribeUpdate& message) override; void OnAnnounceMessage(const MoqtAnnounce& message) override; void OnAnnounceOkMessage(const MoqtAnnounceOk& message) override; void OnAnnounceErrorMessage(const MoqtAnnounceError& message) override; + void OnAnnounceCancelMessage(const MoqtAnnounceCancel& message) override {}; + void OnTrackStatusRequestMessage( + const MoqtTrackStatusRequest& message) override {}; void OnUnannounceMessage(const MoqtUnannounce& /*message*/) override {} + void OnTrackStatusMessage(const MoqtTrackStatus& message) override {} void OnGoAwayMessage(const MoqtGoAway& /*message*/) override {} void OnParsingError(MoqtError error_code, absl::string_view reason) override; @@ -214,11 +230,7 @@ // Returns false if the SUBSCRIBE isn't sent. bool Subscribe(MoqtSubscribe& message, RemoteTrack::Visitor* visitor); - // converts two MoqtLocations into absolute sequences. - std::optional<FullSequence> LocationToAbsoluteNumber( - const LocalTrack& track, - const std::optional<MoqtSubscribeLocation>& group, - const std::optional<MoqtSubscribeLocation>& object); + // Returns the stream ID if successful, nullopt if not. // TODO: Add a callback if stream creation is delayed. std::optional<webtransport::StreamId> OpenUnidirectionalStream(); @@ -246,6 +258,7 @@ // All the tracks the peer can subscribe to. absl::flat_hash_map<FullTrackName, LocalTrack> local_tracks_; + absl::flat_hash_map<uint64_t, FullTrackName> local_track_by_subscribe_id_; // This is only used to check for track_alias collisions. absl::flat_hash_set<uint64_t> used_track_aliases_; uint64_t next_local_track_alias_ = 0;
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 9f9370a..0b40053 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -42,7 +42,7 @@ constexpr webtransport::StreamId kOutgoingUniStreamId = 14; constexpr MoqtSessionParameters default_parameters = { - /*version=*/MoqtVersion::kDraft03, + /*version=*/MoqtVersion::kDraft04, /*perspective=*/quic::Perspective::IS_CLIENT, /*using_webtrans=*/true, /*path=*/std::string(), @@ -108,15 +108,19 @@ session->active_subscribes_[subscribe_id] = {subscribe, visitor}; } + static LocalTrack& local_track(MoqtSession* session, FullTrackName& name) { + return session->local_tracks_.find(name)->second; + } + static void AddSubscription(MoqtSession* session, FullTrackName& name, uint64_t subscribe_id, uint64_t track_alias, uint64_t start_group, uint64_t start_object) { - auto it = session->local_tracks_.find(name); - ASSERT_NE(it, session->local_tracks_.end()); - LocalTrack& track = it->second; + LocalTrack& track = local_track(session, name); track.set_track_alias(track_alias); track.AddWindow(subscribe_id, start_group, start_object); session->used_track_aliases_.emplace(track_alias); + session->local_track_by_subscribe_id_.emplace(subscribe_id, + track.full_track_name()); } static FullSequence next_sequence(MoqtSession* session, FullTrackName& name) { @@ -185,7 +189,7 @@ &session_, visitor.get()); // Handle the server setup MoqtServerSetup setup = { - MoqtVersion::kDraft03, + MoqtVersion::kDraft04, MoqtRole::kPubSub, }; EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1); @@ -194,7 +198,7 @@ TEST_F(MoqtSessionTest, OnClientSetup) { MoqtSessionParameters server_parameters = { - /*version=*/MoqtVersion::kDraft03, + /*version=*/MoqtVersion::kDraft04, /*perspective=*/quic::Perspective::IS_SERVER, /*using_webtrans=*/true, /*path=*/"", @@ -206,7 +210,7 @@ std::unique_ptr<MoqtParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&server_session, &mock_stream); MoqtClientSetup setup = { - /*supported_versions=*/{MoqtVersion::kDraft03}, + /*supported_versions=*/{MoqtVersion::kDraft04}, /*role=*/MoqtRole::kPubSub, /*path=*/std::nullopt, }; @@ -283,8 +287,8 @@ /*track_alias=*/2, /*track_namespace=*/"foo", /*track_name=*/"bar", - /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), - /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), + /*start_group=*/0, + /*start_object=*/0, /*end_group=*/std::nullopt, /*end_object=*/std::nullopt, /*authorization_info=*/std::nullopt, @@ -410,8 +414,8 @@ /*track_alias=*/2, /*track_namespace=*/"foo", /*track_name=*/"bar", - /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), - /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), + /*start_group=*/0, + /*start_object=*/0, /*end_group=*/std::nullopt, /*end_object=*/std::nullopt, /*authorization_info=*/std::nullopt, @@ -446,8 +450,8 @@ /*track_alias=*/2, /*track_namespace=*/"foo", /*track_name=*/"bar", - /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), - /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), + /*start_group=*/0, + /*start_object=*/0, /*end_group=*/std::nullopt, /*end_object=*/std::nullopt, /*authorization_info=*/std::nullopt, @@ -607,7 +611,7 @@ TEST_F(MoqtSessionTest, IncomingPartialObjectNoBuffer) { MoqtSessionParameters parameters = { - /*version=*/MoqtVersion::kDraft03, + /*version=*/MoqtVersion::kDraft04, /*perspective=*/quic::Perspective::IS_CLIENT, /*using_webtrans=*/true, /*path=*/"", @@ -648,8 +652,8 @@ /*track_alias=*/2, /*track_namespace=*/ftn.track_namespace, /*track_name=*/ftn.track_name, - /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), - /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), + /*start_group=*/0, + /*start_object=*/0, /*end_group=*/std::nullopt, /*end_object=*/std::nullopt, }; @@ -703,8 +707,8 @@ /*track_alias=*/2, /*track_namespace=*/ftn.track_namespace, /*track_name=*/ftn.track_name, - /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), - /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), + /*start_group=*/0, + /*start_object=*/0, /*end_group=*/std::nullopt, /*end_object=*/std::nullopt, }; @@ -762,8 +766,8 @@ /*track_alias=*/2, /*track_namespace=*/ftn.track_namespace, /*track_name=*/ftn.track_name, - /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), - /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), + /*start_group=*/0, + /*start_object=*/0, /*end_group=*/std::nullopt, /*end_object=*/std::nullopt, }; @@ -812,8 +816,8 @@ /*track_alias=*/2, /*track_namespace=*/ftn.track_namespace, /*track_name=*/ftn.track_name, - /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), - /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), + /*start_group=*/0, + /*start_object=*/0, /*end_group=*/std::nullopt, /*end_object=*/std::nullopt, }; @@ -953,8 +957,8 @@ /*track_alias=*/3, // Doesn't match 2. /*track_namespace=*/"foo", /*track_name=*/"bar", - /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), - /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), + /*start_group=*/0, + /*start_object=*/0, /*end_group=*/std::nullopt, /*end_object=*/std::nullopt, /*authorization_info=*/std::nullopt, @@ -1018,7 +1022,7 @@ TEST_F(MoqtSessionTest, OneBidirectionalStreamServer) { MoqtSessionParameters server_parameters = { - /*version=*/MoqtVersion::kDraft03, + /*version=*/MoqtVersion::kDraft04, /*perspective=*/quic::Perspective::IS_SERVER, /*using_webtrans=*/true, /*path=*/"", @@ -1030,7 +1034,7 @@ std::unique_ptr<MoqtParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&server_session, &mock_stream); MoqtClientSetup setup = { - /*supported_versions*/ {MoqtVersion::kDraft03}, + /*supported_versions*/ {MoqtVersion::kDraft04}, /*role=*/MoqtRole::kPubSub, /*path=*/std::nullopt, }; @@ -1074,7 +1078,18 @@ MoqtUnsubscribe unsubscribe = { /*subscribe_id=*/0, }; + EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream)); + bool correct_message = false; + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + correct_message = true; + EXPECT_EQ(*ExtractMessageType(data[0]), + MoqtMessageType::kSubscribeDone); + return absl::OkStatus(); + }); stream_input->OnUnsubscribeMessage(unsubscribe); + EXPECT_TRUE(correct_message); EXPECT_FALSE(session_.HasSubscribers(ftn)); } @@ -1176,8 +1191,8 @@ /*track_alias=*/2, /*track_namespace=*/"foo", /*track_name=*/"bar", - /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), - /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)), + /*start_group=*/0, + /*start_object=*/0, /*end_group=*/std::nullopt, /*end_object=*/std::nullopt, /*authorization_info=*/std::nullopt, @@ -1210,6 +1225,41 @@ stream_input->OnAnnounceMessage(announce); } +TEST_F(MoqtSessionTest, SubscribeUpdateClosesSubscription) { + MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kSubscriber); + FullTrackName ftn("foo", "bar"); + MockLocalTrackVisitor track_visitor; + session_.AddLocalTrack(ftn, MoqtForwardingPreference::kTrack, &track_visitor); + MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0); + // Get the window, set the maximum delivered. + LocalTrack& track = MoqtSessionPeer::local_track(&session_, ftn); + track.GetWindow(0)->OnObjectDelivered(FullSequence(7, 3)); + // Update the end to fall at the last delivered object. + MoqtSubscribeUpdate update = { + /*subscribe_id=*/0, + /*start_group=*/5, + /*start_object=*/0, + /*end_group=*/7, + /*end_object=*/3, + }; + StrictMock<webtransport::test::MockStream> mock_stream; + std::unique_ptr<MoqtParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream)); + bool correct_message = false; + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + correct_message = true; + EXPECT_EQ(*ExtractMessageType(data[0]), + MoqtMessageType::kSubscribeDone); + return absl::OkStatus(); + }); + stream_input->OnSubscribeUpdateMessage(update); + EXPECT_TRUE(correct_message); + EXPECT_FALSE(session_.HasSubscribers(ftn)); +} + // TODO: Cover more error cases in the above } // namespace test
diff --git a/quiche/quic/moqt/moqt_subscribe_windows.h b/quiche/quic/moqt/moqt_subscribe_windows.h index 75a204f..40c2564 100644 --- a/quiche/quic/moqt/moqt_subscribe_windows.h +++ b/quiche/quic/moqt/moqt_subscribe_windows.h
@@ -59,14 +59,39 @@ return forwarding_preference_; } + void OnObjectDelivered(FullSequence sequence) { + if (!largest_delivered_.has_value() || *largest_delivered_ < sequence) { + largest_delivered_ = sequence; + } + } + + std::optional<FullSequence>& largest_delivered() { + return largest_delivered_; + } + + // Returns true if the updated values are valid. + bool UpdateStartEnd(FullSequence start, std::optional<FullSequence> end) { + // Can't make the subscription window bigger. + if (!InWindow(start)) { + return false; + } + if (end_.has_value() && (!end.has_value() || *end_ < *end)) { + return false; + } + start_ = start; + end_ = end; + return true; + } + private: // Converts an object sequence number into one that matches the way that // stream IDs are being mapped. (See the comment for send_streams_ below.) FullSequence SequenceToIndex(FullSequence sequence) const; const uint64_t subscribe_id_; - const FullSequence start_; - const std::optional<FullSequence> end_ = std::nullopt; + FullSequence start_; + std::optional<FullSequence> end_ = std::nullopt; + std::optional<FullSequence> largest_delivered_; // Store open streams for this subscription. If the forwarding preference is // kTrack, there is one entry under sequence (0, 0). If kGroup, each entry is // under (group, 0). If kObject, it's tracked under the full sequence. If
diff --git a/quiche/quic/moqt/moqt_subscribe_windows_test.cc b/quiche/quic/moqt/moqt_subscribe_windows_test.cc index c509be0..a144474 100644 --- a/quiche/quic/moqt/moqt_subscribe_windows_test.cc +++ b/quiche/quic/moqt/moqt_subscribe_windows_test.cc
@@ -82,6 +82,45 @@ EXPECT_QUIC_BUG(window.AddStream(4, 0, 2), "Adding a stream for datagram"); } +TEST_F(SubscribeWindowTest, OnObjectDelivered) { + SubscribeWindow window(subscribe_id_, MoqtForwardingPreference::kObject, + start_group_, start_object_, end_group_, end_object_); + EXPECT_FALSE(window.largest_delivered().has_value()); + window.OnObjectDelivered(FullSequence(4, 1)); + EXPECT_TRUE(window.largest_delivered().has_value()); + EXPECT_EQ(window.largest_delivered().value(), FullSequence(4, 1)); + window.OnObjectDelivered(FullSequence(4, 2)); + EXPECT_EQ(window.largest_delivered().value(), FullSequence(4, 2)); + window.OnObjectDelivered(FullSequence(4, 0)); + EXPECT_EQ(window.largest_delivered().value(), FullSequence(4, 2)); +} + +TEST_F(SubscribeWindowTest, UpdateStartEnd) { + SubscribeWindow window(subscribe_id_, MoqtForwardingPreference::kObject, + start_group_, start_object_, end_group_, end_object_); + EXPECT_TRUE( + window.UpdateStartEnd(FullSequence(start_group_, start_object_ + 1), + FullSequence(end_group_, end_object_ - 1))); + EXPECT_FALSE(window.InWindow(FullSequence(start_group_, start_object_))); + EXPECT_FALSE(window.InWindow(FullSequence(end_group_, end_object_))); + EXPECT_FALSE( + window.UpdateStartEnd(FullSequence(start_group_, start_object_), + FullSequence(end_group_, end_object_ - 1))); + EXPECT_FALSE( + window.UpdateStartEnd(FullSequence(start_group_, start_object_ + 1), + FullSequence(end_group_, end_object_))); +} + +TEST_F(SubscribeWindowTest, UpdateStartEndOpenEnded) { + SubscribeWindow window(subscribe_id_, MoqtForwardingPreference::kObject, + start_group_, start_object_); + EXPECT_TRUE(window.UpdateStartEnd(FullSequence(start_group_, start_object_), + FullSequence(end_group_, end_object_))); + EXPECT_FALSE(window.InWindow(FullSequence(end_group_, end_object_ + 1))); + EXPECT_FALSE(window.UpdateStartEnd(FullSequence(start_group_, start_object_), + std::nullopt)); +} + class QUICHE_EXPORT MoqtSubscribeWindowsTest : public quic::test::QuicTest { public: MoqtSubscribeWindowsTest() : windows_(MoqtForwardingPreference::kObject) {}
diff --git a/quiche/quic/moqt/test_tools/moqt_test_message.h b/quiche/quic/moqt/test_tools/moqt_test_message.h index 98dee0f..227d298 100644 --- a/quiche/quic/moqt/test_tools/moqt_test_message.h +++ b/quiche/quic/moqt/test_tools/moqt_test_message.h
@@ -36,9 +36,10 @@ typedef absl::variant<MoqtClientSetup, MoqtServerSetup, MoqtObject, MoqtSubscribe, MoqtSubscribeOk, MoqtSubscribeError, - MoqtUnsubscribe, MoqtSubscribeDone, MoqtAnnounce, - MoqtAnnounceOk, MoqtAnnounceError, MoqtUnannounce, - MoqtGoAway> + MoqtUnsubscribe, MoqtSubscribeDone, MoqtSubscribeUpdate, + MoqtAnnounce, MoqtAnnounceOk, MoqtAnnounceError, + MoqtAnnounceCancel, MoqtTrackStatusRequest, + MoqtUnannounce, MoqtTrackStatus, MoqtGoAway> MessageStructuredData; // The total actual size of the message. @@ -446,42 +447,27 @@ return true; } - void ExpandVarints() override { - ExpandVarintsImpl("vvvv---v----vvvvvvvvv"); - } + void ExpandVarints() override { ExpandVarintsImpl("vvvv---v----vvvvvv---"); } MessageStructuredData structured_data() const override { return TestMessageBase::MessageStructuredData(subscribe_); } private: - uint8_t raw_packet_[24] = { - 0x03, - 0x01, - 0x02, // id and alias - 0x03, - 0x66, - 0x6f, - 0x6f, // track_namespace = "foo" - 0x04, - 0x61, - 0x62, - 0x63, - 0x64, // track_name = "abcd" - 0x02, - 0x04, // start_group = 4 (relative previous) - 0x01, - 0x01, // start_object = 1 (absolute) - 0x00, // end_group = none - 0x00, // end_object = none - // TODO(martinduke): figure out what to do about the missing num - // parameters field. - 0x01, // 1 parameter - 0x02, - 0x03, - 0x62, - 0x61, - 0x72, // authorization_info = "bar" + uint8_t raw_packet_[21] = { + 0x03, 0x01, + 0x02, // id and alias + 0x03, 0x66, 0x6f, + 0x6f, // track_namespace = "foo" + 0x04, 0x61, 0x62, 0x63, + 0x64, // track_name = "abcd" + 0x03, // Filter type: Absolute Start + 0x04, // start_group = 4 (relative previous) + 0x01, // start_object = 1 (absolute) + // No EndGroup or EndObject + 0x01, // 1 parameter + 0x02, 0x03, 0x62, 0x61, + 0x72, // authorization_info = "bar" }; MoqtSubscribe subscribe_ = { @@ -489,8 +475,8 @@ /*track_alias=*/2, /*track_namespace=*/"foo", /*track_name=*/"abcd", - /*start_group=*/MoqtSubscribeLocation(false, (int64_t)(-4)), - /*start_object=*/MoqtSubscribeLocation(true, (uint64_t)1), + /*start_group=*/4, + /*start_object=*/1, /*end_group=*/std::nullopt, /*end_object=*/std::nullopt, /*authorization_info=*/"bar", @@ -671,12 +657,71 @@ MoqtSubscribeDone subscribe_done_ = { /*subscribe_id=*/2, - /*error_code=*/3, + /*error_code=*/SubscribeDoneCode::kTrackEnded, /*reason_phrase=*/"hi", /*final_id=*/FullSequence(8, 12), }; }; +class QUICHE_NO_EXPORT SubscribeUpdateMessage : public TestMessageBase { + public: + SubscribeUpdateMessage() + : TestMessageBase(MoqtMessageType::kSubscribeUpdate) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtSubscribeUpdate>(values); + if (cast.subscribe_id != subscribe_update_.subscribe_id) { + QUIC_LOG(INFO) << "SUBSCRIBE_UPDATE subscribe ID mismatch"; + return false; + } + if (cast.start_group != subscribe_update_.start_group) { + QUIC_LOG(INFO) << "SUBSCRIBE_UPDATE start group mismatch"; + return false; + } + if (cast.start_object != subscribe_update_.start_object) { + QUIC_LOG(INFO) << "SUBSCRIBE_UPDATE start group mismatch"; + return false; + } + if (cast.end_group != subscribe_update_.end_group) { + QUIC_LOG(INFO) << "SUBSCRIBE_UPDATE end group mismatch"; + return false; + } + if (cast.end_object != subscribe_update_.end_object) { + QUIC_LOG(INFO) << "SUBSCRIBE_UPDATE end group mismatch"; + return false; + } + if (cast.authorization_info != subscribe_update_.authorization_info) { + QUIC_LOG(INFO) << "SUBSCRIBE_UPDATE authorization info mismatch"; + return false; + } + return true; + } + + void ExpandVarints() override { ExpandVarintsImpl("vvvvvvvvv---"); } + + MessageStructuredData structured_data() const override { + return TestMessageBase::MessageStructuredData(subscribe_update_); + } + + private: + uint8_t raw_packet_[12] = { + 0x02, 0x02, 0x03, 0x01, 0x05, 0x06, // start and end sequences + 0x01, // 1 parameter + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + + MoqtSubscribeUpdate subscribe_update_ = { + /*subscribe_id=*/2, + /*start_group=*/3, + /*start_object=*/1, + /*end_group=*/4, + /*end_object=*/5, + /*authorization_info=*/"bar", + }; +}; + class QUICHE_NO_EXPORT AnnounceMessage : public TestMessageBase { public: AnnounceMessage() : TestMessageBase(MoqtMessageType::kAnnounce) { @@ -789,6 +834,75 @@ }; }; +class QUICHE_NO_EXPORT AnnounceCancelMessage : public TestMessageBase { + public: + AnnounceCancelMessage() : TestMessageBase(MoqtMessageType::kAnnounceCancel) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtAnnounceCancel>(values); + if (cast.track_namespace != announce_cancel_.track_namespace) { + QUIC_LOG(INFO) << "ANNOUNCE CANCEL track namespace mismatch"; + return false; + } + return true; + } + + void ExpandVarints() override { ExpandVarintsImpl("vv---"); } + + MessageStructuredData structured_data() const override { + return TestMessageBase::MessageStructuredData(announce_cancel_); + } + + private: + uint8_t raw_packet_[5] = { + 0x0c, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + }; + + MoqtAnnounceCancel announce_cancel_ = { + /*track_namespace=*/"foo", + }; +}; + +class QUICHE_NO_EXPORT TrackStatusRequestMessage : public TestMessageBase { + public: + TrackStatusRequestMessage() + : TestMessageBase(MoqtMessageType::kTrackStatusRequest) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtTrackStatusRequest>(values); + if (cast.track_namespace != track_status_request_.track_namespace) { + QUIC_LOG(INFO) << "TRACK STATUS REQUEST track namespace mismatch"; + return false; + } + if (cast.track_name != track_status_request_.track_name) { + QUIC_LOG(INFO) << "TRACK STATUS REQUEST track name mismatch"; + return false; + } + return true; + } + + void ExpandVarints() override { ExpandVarintsImpl("vv---v----"); } + + MessageStructuredData structured_data() const override { + return TestMessageBase::MessageStructuredData(track_status_request_); + } + + private: + uint8_t raw_packet_[10] = { + 0x0d, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x04, 0x61, 0x62, 0x63, 0x64, // track_name = "abcd" + }; + + MoqtTrackStatusRequest track_status_request_ = { + /*track_namespace=*/"foo", + /*track_name=*/"abcd", + }; +}; + class QUICHE_NO_EXPORT UnannounceMessage : public TestMessageBase { public: UnannounceMessage() : TestMessageBase(MoqtMessageType::kUnannounce) { @@ -820,6 +934,59 @@ }; }; +class QUICHE_NO_EXPORT TrackStatusMessage : public TestMessageBase { + public: + TrackStatusMessage() : TestMessageBase(MoqtMessageType::kTrackStatus) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtTrackStatus>(values); + if (cast.track_namespace != track_status_.track_namespace) { + QUIC_LOG(INFO) << "TRACK STATUS track namespace mismatch"; + return false; + } + if (cast.track_name != track_status_.track_name) { + QUIC_LOG(INFO) << "TRACK STATUS track name mismatch"; + return false; + } + if (cast.status_code != track_status_.status_code) { + QUIC_LOG(INFO) << "TRACK STATUS code mismatch"; + return false; + } + if (cast.last_group != track_status_.last_group) { + QUIC_LOG(INFO) << "TRACK STATUS last group mismatch"; + return false; + } + if (cast.last_object != track_status_.last_object) { + QUIC_LOG(INFO) << "TRACK STATUS last object mismatch"; + return false; + } + return true; + } + + void ExpandVarints() override { ExpandVarintsImpl("vv---v----vvv"); } + + MessageStructuredData structured_data() const override { + return TestMessageBase::MessageStructuredData(track_status_); + } + + private: + uint8_t raw_packet_[13] = { + 0x0e, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x04, 0x61, 0x62, 0x63, 0x64, // track_name = "abcd" + 0x00, 0x0c, 0x14, // status, last_group, last_object + }; + + MoqtTrackStatus track_status_ = { + /*track_namespace=*/"foo", + /*track_name=*/"abcd", + /*status_code=*/MoqtTrackStatusCode::kInProgress, + /*last_group=*/12, + /*last_object=*/20, + }; +}; + class QUICHE_NO_EXPORT GoAwayMessage : public TestMessageBase { public: GoAwayMessage() : TestMessageBase(MoqtMessageType::kGoAway) { @@ -869,14 +1036,22 @@ return std::make_unique<UnsubscribeMessage>(); case MoqtMessageType::kSubscribeDone: return std::make_unique<SubscribeDoneMessage>(); + case MoqtMessageType::kSubscribeUpdate: + return std::make_unique<SubscribeUpdateMessage>(); case MoqtMessageType::kAnnounce: return std::make_unique<AnnounceMessage>(); case MoqtMessageType::kAnnounceOk: return std::make_unique<AnnounceOkMessage>(); case MoqtMessageType::kAnnounceError: return std::make_unique<AnnounceErrorMessage>(); + case MoqtMessageType::kAnnounceCancel: + return std::make_unique<AnnounceCancelMessage>(); + case MoqtMessageType::kTrackStatusRequest: + return std::make_unique<TrackStatusRequestMessage>(); case MoqtMessageType::kUnannounce: return std::make_unique<UnannounceMessage>(); + case MoqtMessageType::kTrackStatus: + return std::make_unique<TrackStatusMessage>(); case MoqtMessageType::kGoAway: return std::make_unique<GoAwayMessage>(); case MoqtMessageType::kClientSetup:
diff --git a/quiche/quic/moqt/tools/chat_client_bin.cc b/quiche/quic/moqt/tools/chat_client_bin.cc index d557137..b7f0f75 100644 --- a/quiche/quic/moqt/tools/chat_client_bin.cc +++ b/quiche/quic/moqt/tools/chat_client_bin.cc
@@ -293,9 +293,9 @@ auto new_user = other_users_.emplace( std::make_pair(user, ChatUser(to_subscribe, group_sequence))); ChatUser& user_record = new_user.first->second; - session_->SubscribeRelative(user_record.full_track_name.track_namespace, - user_record.full_track_name.track_name, 0, - 0, visitor); + session_->SubscribeCurrentGroup( + user_record.full_track_name.track_namespace, + user_record.full_track_name.track_name, visitor); subscribes_to_make_++; } else { if (it->second.from_group == group_sequence) {
diff --git a/quiche/quic/moqt/tools/moqt_client.cc b/quiche/quic/moqt/tools/moqt_client.cc index 4920c16..a8cc392 100644 --- a/quiche/quic/moqt/tools/moqt_client.cc +++ b/quiche/quic/moqt/tools/moqt_client.cc
@@ -89,7 +89,7 @@ } MoqtSessionParameters parameters; - parameters.version = MoqtVersion::kDraft03; + parameters.version = MoqtVersion::kDraft04; parameters.perspective = quic::Perspective::IS_CLIENT, parameters.using_webtrans = true; parameters.path = "";
diff --git a/quiche/quic/moqt/tools/moqt_server.cc b/quiche/quic/moqt/tools/moqt_server.cc index 9a60977..869c3ec 100644 --- a/quiche/quic/moqt/tools/moqt_server.cc +++ b/quiche/quic/moqt/tools/moqt_server.cc
@@ -33,7 +33,7 @@ parameters.perspective = quic::Perspective::IS_SERVER; parameters.path = path; parameters.using_webtrans = true; - parameters.version = MoqtVersion::kDraft03; + parameters.version = MoqtVersion::kDraft04; parameters.deliver_partial_objects = false; auto moqt_session = std::make_unique<MoqtSession>(session, parameters); std::move (*configurator)(moqt_session.get());