Update MOQT to the (real) draft-01. PiperOrigin-RevId: 578609842
diff --git a/quiche/quic/moqt/moqt_framer.cc b/quiche/quic/moqt/moqt_framer.cc index 294b15c..b03293f 100644 --- a/quiche/quic/moqt/moqt_framer.cc +++ b/quiche/quic/moqt/moqt_framer.cc
@@ -8,6 +8,7 @@ #include <cstdint> #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "quiche/quic/core/quic_data_writer.h" #include "quiche/quic/core/quic_time.h" #include "quiche/quic/core/quic_types.h" @@ -18,28 +19,55 @@ namespace { -inline size_t NeededVarIntLen(uint64_t value) { +inline size_t NeededVarIntLen(const uint64_t value) { return static_cast<size_t>(quic::QuicDataWriter::GetVarInt62Len(value)); } -inline size_t NeededVarIntLen(MoqtVersion value) { +inline size_t NeededVarIntLen(const MoqtVersion value) { return static_cast<size_t>( quic::QuicDataWriter::GetVarInt62Len(static_cast<uint64_t>(value))); } -inline size_t ParameterLen(uint64_t type, uint64_t value_len) { +inline size_t NeededVarIntLen(const MoqtMessageType value) { + return static_cast<size_t>( + quic::QuicDataWriter::GetVarInt62Len(static_cast<uint64_t>(value))); +} +inline size_t NeededVarIntLen(const MoqtSubscribeLocationMode value) { + return static_cast<size_t>( + quic::QuicDataWriter::GetVarInt62Len(static_cast<uint64_t>(value))); +} +inline size_t ParameterLen(const uint64_t type, const uint64_t value_len) { return NeededVarIntLen(type) + NeededVarIntLen(value_len) + value_len; } +inline size_t LocationLength(const absl::optional<MoqtSubscribeLocation> loc) { + if (!loc.has_value()) { + return NeededVarIntLen(MoqtSubscribeLocationMode::kNone); + } + if (loc->absolute) { + return NeededVarIntLen(MoqtSubscribeLocationMode::kAbsolute) + + NeededVarIntLen(loc->absolute_value); + } + // It's a relative value + if (loc->relative_value < 0) { + return NeededVarIntLen(MoqtSubscribeLocationMode::kRelativePrevious) + + NeededVarIntLen(static_cast<uint64_t>(loc->relative_value * -1)); + } + return NeededVarIntLen(MoqtSubscribeLocationMode::kRelativeNext) + + NeededVarIntLen(static_cast<uint64_t>(loc->relative_value)); +} +inline size_t LengthPrefixedStringLength(absl::string_view string) { + return NeededVarIntLen(string.length()) + string.length(); +} // This only supports values up to UINT8_MAX, as that's all that exists in the // standard. -inline bool WriteIntParameter(quic::QuicDataWriter& writer, uint64_t type, - uint8_t value) { +inline bool WriteVarIntParameter(quic::QuicDataWriter& writer, uint64_t type, + uint64_t value) { if (!writer.WriteVarInt62(type)) { return false; } - if (!writer.WriteVarInt62(1)) { + if (!writer.WriteVarInt62(NeededVarIntLen(value))) { return false; } - return writer.WriteUInt8(value); + return writer.WriteVarInt62(value); } inline bool WriteStringParameter(quic::QuicDataWriter& writer, uint64_t type, @@ -50,32 +78,65 @@ return writer.WriteStringPieceVarInt62(value); } +inline bool WriteLocation(quic::QuicDataWriter& writer, + absl::optional<MoqtSubscribeLocation> loc) { + if (!loc.has_value()) { + return writer.WriteVarInt62( + static_cast<uint64_t>(MoqtSubscribeLocationMode::kNone)); + } + if (loc->absolute) { + if (!writer.WriteVarInt62( + static_cast<uint64_t>(MoqtSubscribeLocationMode::kAbsolute))) { + return false; + } + return writer.WriteVarInt62(loc->absolute_value); + } + if (loc->relative_value < 0) { + if (!writer.WriteVarInt62(static_cast<uint64_t>( + MoqtSubscribeLocationMode::kRelativePrevious))) { + return false; + } + return writer.WriteVarInt62( + static_cast<uint64_t>(loc->relative_value * -1)); + } + if (!writer.WriteVarInt62( + static_cast<uint64_t>(MoqtSubscribeLocationMode::kRelativeNext))) { + return false; + } + return writer.WriteVarInt62(static_cast<uint64_t>(loc->relative_value)); +} + } // namespace quiche::QuicheBuffer MoqtFramer::SerializeObject( - const MoqtObject& message, const absl::string_view payload, - const size_t known_payload_size) { - if (known_payload_size > 0 && known_payload_size < payload.length()) { + const MoqtObject& message, const absl::string_view payload) { + if (message.payload_length.has_value() && + *message.payload_length < payload.length()) { + QUICHE_DLOG(INFO) << "payload_size is too small for payload"; return quiche::QuicheBuffer(); } - size_t varint_len = NeededVarIntLen(message.track_id) + - NeededVarIntLen(message.group_sequence) + - NeededVarIntLen(message.object_sequence) + - NeededVarIntLen(message.object_send_order); - size_t message_len = - known_payload_size == 0 ? 0 : (known_payload_size + varint_len); + uint64_t message_type = + static_cast<uint64_t>(message.payload_length.has_value() + ? MoqtMessageType::kObjectWithPayloadLength + : MoqtMessageType::kObjectWithoutPayloadLength); size_t buffer_size = - varint_len + payload.length() + - NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kObject)) + - NeededVarIntLen(message_len); + NeededVarIntLen(message_type) + NeededVarIntLen(message.track_id) + + NeededVarIntLen(message.group_sequence) + + NeededVarIntLen(message.object_sequence) + + NeededVarIntLen(message.object_send_order) + payload.length(); + if (message.payload_length.has_value()) { + buffer_size += NeededVarIntLen(*message.payload_length); + } quiche::QuicheBuffer buffer(allocator_, buffer_size); quic::QuicDataWriter writer(buffer.size(), buffer.data()); - writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kObject)); - writer.WriteVarInt62(message_len); + writer.WriteVarInt62(message_type); writer.WriteVarInt62(message.track_id); writer.WriteVarInt62(message.group_sequence); writer.WriteVarInt62(message.object_sequence); writer.WriteVarInt62(message.object_send_order); + if (message.payload_length.has_value()) { + writer.WriteVarInt62(*message.payload_length); + } writer.WriteStringPiece(payload); return buffer; } @@ -88,46 +149,38 @@ return buffer; } -quiche::QuicheBuffer MoqtFramer::SerializeSetup(const MoqtSetup& message) { - size_t message_len; - if (perspective_ == quic::Perspective::IS_CLIENT) { - message_len = NeededVarIntLen(message.supported_versions.size()); - for (MoqtVersion version : message.supported_versions) { - message_len += NeededVarIntLen(version); - } - // TODO: figure out if the role needs to be sent on the client side or on - // both sides. - if (message.role.has_value()) { - message_len += - ParameterLen(static_cast<uint64_t>(MoqtSetupParameter::kRole), 1); - } - if (!using_webtrans_ && message.path.has_value()) { - message_len += - ParameterLen(static_cast<uint64_t>(MoqtSetupParameter::kPath), - message.path->length()); - } - } else { - message_len = NeededVarIntLen(message.supported_versions[0]); +quiche::QuicheBuffer MoqtFramer::SerializeClientSetup( + const MoqtClientSetup& message) { + size_t buffer_size = NeededVarIntLen(MoqtMessageType::kClientSetup) + + NeededVarIntLen(message.supported_versions.size()); + for (MoqtVersion version : message.supported_versions) { + buffer_size += NeededVarIntLen(version); } - size_t buffer_size = - message_len + - NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kSetup)) + - NeededVarIntLen(message_len); + uint64_t num_params = 0; + if (message.role.has_value()) { + num_params++; + buffer_size += + ParameterLen(static_cast<uint64_t>(MoqtSetupParameter::kRole), 1); + } + if (!using_webtrans_ && message.path.has_value()) { + num_params++; + buffer_size += + ParameterLen(static_cast<uint64_t>(MoqtSetupParameter::kPath), + message.path->length()); + } + buffer_size += NeededVarIntLen(num_params); quiche::QuicheBuffer buffer(allocator_, buffer_size); quic::QuicDataWriter writer(buffer.size(), buffer.data()); - writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kSetup)); - writer.WriteVarInt62(message_len); - if (perspective_ == quic::Perspective::IS_SERVER) { - writer.WriteVarInt62(static_cast<uint64_t>(message.supported_versions[0])); - return buffer; - } + writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kClientSetup)); writer.WriteVarInt62(message.supported_versions.size()); for (MoqtVersion version : message.supported_versions) { writer.WriteVarInt62(static_cast<uint64_t>(version)); } + writer.WriteVarInt62(num_params); if (message.role.has_value()) { - WriteIntParameter(writer, static_cast<uint64_t>(MoqtSetupParameter::kRole), - static_cast<uint8_t>(message.role.value())); + WriteVarIntParameter(writer, + static_cast<uint64_t>(MoqtSetupParameter::kRole), + static_cast<uint64_t>(message.role.value())); } if (!using_webtrans_ && message.path.has_value()) { WriteStringParameter(writer, @@ -137,45 +190,66 @@ return buffer; } +quiche::QuicheBuffer MoqtFramer::SerializeServerSetup( + const MoqtServerSetup& message) { + size_t buffer_size = NeededVarIntLen(MoqtMessageType::kServerSetup) + + NeededVarIntLen(message.selected_version); + uint64_t num_params = 0; + if (message.role.has_value()) { + num_params++; + buffer_size += + ParameterLen(static_cast<uint64_t>(MoqtSetupParameter::kRole), + static_cast<uint64_t>(message.role.value())); + } + buffer_size += NeededVarIntLen(num_params); + quiche::QuicheBuffer buffer(allocator_, buffer_size); + quic::QuicDataWriter writer(buffer.size(), buffer.data()); + writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kServerSetup)); + writer.WriteVarInt62(static_cast<uint64_t>(message.selected_version)); + writer.WriteVarInt62(num_params); + if (message.role.has_value()) { + WriteVarIntParameter(writer, + static_cast<uint64_t>(MoqtSetupParameter::kRole), + static_cast<uint64_t>(message.role.value())); + } + return buffer; +} + quiche::QuicheBuffer MoqtFramer::SerializeSubscribeRequest( const MoqtSubscribeRequest& message) { - size_t message_len = NeededVarIntLen(message.full_track_name.length()) + - message.full_track_name.length(); - if (message.group_sequence.has_value()) { - message_len += ParameterLen( - static_cast<uint64_t>(MoqtTrackRequestParameter::kGroupSequence), 1); + if (!message.start_group.has_value() || !message.start_object.has_value()) { + QUIC_LOG(INFO) << "start_group or start_object is missing"; + return quiche::QuicheBuffer(); } - if (message.object_sequence.has_value()) { - message_len += ParameterLen( - static_cast<uint64_t>(MoqtTrackRequestParameter::kObjectSequence), 1); + if (message.end_group.has_value() != message.end_object.has_value()) { + QUIC_LOG(INFO) << "end_group and end_object must both be None or both " + << "non-None"; + return quiche::QuicheBuffer(); } + size_t buffer_size = NeededVarIntLen(MoqtMessageType::kSubscribeRequest) + + LengthPrefixedStringLength(message.full_track_name) + + LocationLength(message.start_group) + + LocationLength(message.start_object) + + LocationLength(message.end_group) + + LocationLength(message.end_object); + uint64_t num_params = 0; if (message.authorization_info.has_value()) { - message_len += ParameterLen( + num_params++; + buffer_size += ParameterLen( static_cast<uint64_t>(MoqtTrackRequestParameter::kAuthorizationInfo), message.authorization_info->length()); } - size_t buffer_size = - message_len + - NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kObject)) + - NeededVarIntLen(message_len); + buffer_size += NeededVarIntLen(num_params); quiche::QuicheBuffer buffer(allocator_, buffer_size); quic::QuicDataWriter writer(buffer.size(), buffer.data()); writer.WriteVarInt62( static_cast<uint64_t>(MoqtMessageType::kSubscribeRequest)); - writer.WriteVarInt62(message_len); writer.WriteStringPieceVarInt62(message.full_track_name); - if (message.group_sequence.has_value()) { - WriteIntParameter( - writer, - static_cast<uint64_t>(MoqtTrackRequestParameter::kGroupSequence), - message.group_sequence.value()); - } - if (message.object_sequence.has_value()) { - WriteIntParameter( - writer, - static_cast<uint64_t>(MoqtTrackRequestParameter::kObjectSequence), - message.object_sequence.value()); - } + WriteLocation(writer, message.start_group); + WriteLocation(writer, message.start_object); + WriteLocation(writer, message.end_group); + WriteLocation(writer, message.end_object); + writer.WriteVarInt62(num_params); if (message.authorization_info.has_value()) { WriteStringParameter( writer, @@ -187,19 +261,17 @@ quiche::QuicheBuffer MoqtFramer::SerializeSubscribeOk( const MoqtSubscribeOk& message) { - size_t message_len = NeededVarIntLen(message.full_track_name.length()) + - message.full_track_name.length() + - NeededVarIntLen(message.track_id) + - NeededVarIntLen(message.expires.ToMilliseconds()); size_t buffer_size = - message_len + NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kSubscribeOk)) + - NeededVarIntLen(message_len); + LengthPrefixedStringLength(message.track_namespace) + + LengthPrefixedStringLength(message.track_name) + + NeededVarIntLen(message.track_id) + + NeededVarIntLen(message.expires.ToMilliseconds()); quiche::QuicheBuffer buffer(allocator_, buffer_size); quic::QuicDataWriter writer(buffer.size(), buffer.data()); writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kSubscribeOk)); - writer.WriteVarInt62(message_len); - writer.WriteStringPieceVarInt62(message.full_track_name); + writer.WriteStringPieceVarInt62(message.track_namespace); + writer.WriteStringPieceVarInt62(message.track_name); writer.WriteVarInt62(message.track_id); writer.WriteVarInt62(message.expires.ToMilliseconds()); return buffer; @@ -207,20 +279,17 @@ quiche::QuicheBuffer MoqtFramer::SerializeSubscribeError( const MoqtSubscribeError& message) { - size_t message_len = NeededVarIntLen(message.full_track_name.length()) + - message.full_track_name.length() + - NeededVarIntLen(message.error_code) + - NeededVarIntLen(message.reason_phrase.length()) + - message.reason_phrase.length(); size_t buffer_size = - message_len + NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kSubscribeError)) + - NeededVarIntLen(message_len); + LengthPrefixedStringLength(message.track_namespace) + + LengthPrefixedStringLength(message.track_name) + + NeededVarIntLen(message.error_code) + + LengthPrefixedStringLength(message.reason_phrase); quiche::QuicheBuffer buffer(allocator_, buffer_size); quic::QuicDataWriter writer(buffer.size(), buffer.data()); writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kSubscribeError)); - writer.WriteVarInt62(message_len); - writer.WriteStringPieceVarInt62(message.full_track_name); + writer.WriteStringPieceVarInt62(message.track_namespace); + writer.WriteStringPieceVarInt62(message.track_name); writer.WriteVarInt62(message.error_code); writer.WriteStringPieceVarInt62(message.reason_phrase); return buffer; @@ -228,38 +297,76 @@ quiche::QuicheBuffer MoqtFramer::SerializeUnsubscribe( const MoqtUnsubscribe& message) { - size_t message_len = NeededVarIntLen(message.full_track_name.length()) + - message.full_track_name.length(); size_t buffer_size = - message_len + NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kUnsubscribe)) + - NeededVarIntLen(message_len); + LengthPrefixedStringLength(message.track_namespace) + + LengthPrefixedStringLength(message.track_name); quiche::QuicheBuffer buffer(allocator_, buffer_size); quic::QuicDataWriter writer(buffer.size(), buffer.data()); writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kUnsubscribe)); - writer.WriteVarInt62(message_len); - writer.WriteStringPieceVarInt62(message.full_track_name); + writer.WriteStringPieceVarInt62(message.track_namespace); + writer.WriteStringPieceVarInt62(message.track_name); + return buffer; +} + +quiche::QuicheBuffer MoqtFramer::SerializeSubscribeFin( + const MoqtSubscribeFin& message) { + size_t buffer_size = + NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kSubscribeFin)) + + LengthPrefixedStringLength(message.track_namespace) + + LengthPrefixedStringLength(message.track_name) + + NeededVarIntLen(message.final_group) + + NeededVarIntLen(message.final_object); + quiche::QuicheBuffer buffer(allocator_, buffer_size); + quic::QuicDataWriter writer(buffer.size(), buffer.data()); + writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kSubscribeFin)); + writer.WriteStringPieceVarInt62(message.track_namespace); + writer.WriteStringPieceVarInt62(message.track_name); + writer.WriteVarInt62(message.final_group); + writer.WriteVarInt62(message.final_object); + return buffer; +} + +quiche::QuicheBuffer MoqtFramer::SerializeSubscribeRst( + const MoqtSubscribeRst& message) { + size_t buffer_size = + NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kSubscribeRst)) + + LengthPrefixedStringLength(message.track_namespace) + + LengthPrefixedStringLength(message.track_name) + + NeededVarIntLen(message.error_code) + + LengthPrefixedStringLength(message.reason_phrase) + + NeededVarIntLen(message.final_group) + + NeededVarIntLen(message.final_object); + quiche::QuicheBuffer buffer(allocator_, buffer_size); + quic::QuicDataWriter writer(buffer.size(), buffer.data()); + writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kSubscribeRst)); + writer.WriteStringPieceVarInt62(message.track_namespace); + writer.WriteStringPieceVarInt62(message.track_name); + writer.WriteVarInt62(message.error_code); + writer.WriteStringPieceVarInt62(message.reason_phrase); + writer.WriteVarInt62(message.final_group); + writer.WriteVarInt62(message.final_object); return buffer; } quiche::QuicheBuffer MoqtFramer::SerializeAnnounce( const MoqtAnnounce& message) { - size_t message_len = NeededVarIntLen(message.track_namespace.length()) + - message.track_namespace.length(); + size_t buffer_size = + NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kAnnounce)) + + LengthPrefixedStringLength(message.track_namespace); + uint64_t num_params = 0; if (message.authorization_info.has_value()) { - message_len += ParameterLen( + num_params++; + buffer_size += ParameterLen( static_cast<uint64_t>(MoqtTrackRequestParameter::kAuthorizationInfo), message.authorization_info->length()); } - size_t buffer_size = - message_len + - NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kAnnounce)) + - NeededVarIntLen(message_len); + buffer_size += NeededVarIntLen(num_params); quiche::QuicheBuffer buffer(allocator_, buffer_size); quic::QuicDataWriter writer(buffer.size(), buffer.data()); writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kAnnounce)); - writer.WriteVarInt62(message_len); writer.WriteStringPieceVarInt62(message.track_namespace); + writer.WriteVarInt62(num_params); if (message.authorization_info.has_value()) { WriteStringParameter( writer, @@ -271,35 +378,26 @@ quiche::QuicheBuffer MoqtFramer::SerializeAnnounceOk( const MoqtAnnounceOk& message) { - size_t message_len = NeededVarIntLen(message.track_namespace.length()) + - message.track_namespace.length(); size_t buffer_size = - message_len + NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kAnnounceOk)) + - NeededVarIntLen(message_len); + LengthPrefixedStringLength(message.track_namespace); quiche::QuicheBuffer buffer(allocator_, buffer_size); quic::QuicDataWriter writer(buffer.size(), buffer.data()); writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kAnnounceOk)); - writer.WriteVarInt62(message_len); writer.WriteStringPieceVarInt62(message.track_namespace); return buffer; } quiche::QuicheBuffer MoqtFramer::SerializeAnnounceError( const MoqtAnnounceError& message) { - size_t message_len = NeededVarIntLen(message.track_namespace.length()) + - message.track_namespace.length() + - NeededVarIntLen(message.error_code) + - NeededVarIntLen(message.reason_phrase.length()) + - message.reason_phrase.length(); size_t buffer_size = - message_len + NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kAnnounceError)) + - NeededVarIntLen(message_len); + LengthPrefixedStringLength(message.track_namespace) + + NeededVarIntLen(message.error_code) + + LengthPrefixedStringLength(message.reason_phrase); quiche::QuicheBuffer buffer(allocator_, buffer_size); quic::QuicDataWriter writer(buffer.size(), buffer.data()); writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kAnnounceError)); - writer.WriteVarInt62(message_len); writer.WriteStringPieceVarInt62(message.track_namespace); writer.WriteVarInt62(message.error_code); writer.WriteStringPieceVarInt62(message.reason_phrase); @@ -308,28 +406,24 @@ quiche::QuicheBuffer MoqtFramer::SerializeUnannounce( const MoqtUnannounce& message) { - size_t message_len = NeededVarIntLen(message.track_namespace.length()) + - message.track_namespace.length(); size_t buffer_size = - message_len + NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kUnannounce)) + - NeededVarIntLen(message_len); + LengthPrefixedStringLength(message.track_namespace); quiche::QuicheBuffer buffer(allocator_, buffer_size); quic::QuicDataWriter writer(buffer.size(), buffer.data()); writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kUnannounce)); - writer.WriteVarInt62(message_len); writer.WriteStringPieceVarInt62(message.track_namespace); return buffer; } -quiche::QuicheBuffer MoqtFramer::SerializeGoAway() { +quiche::QuicheBuffer MoqtFramer::SerializeGoAway(const MoqtGoAway& message) { size_t buffer_size = NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kGoAway)) + - NeededVarIntLen(0); + LengthPrefixedStringLength(message.new_session_uri); quiche::QuicheBuffer buffer(allocator_, buffer_size); quic::QuicDataWriter writer(buffer.size(), buffer.data()); writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kGoAway)); - writer.WriteVarInt62(0); + writer.WriteStringPieceVarInt62(message.new_session_uri); return buffer; }
diff --git a/quiche/quic/moqt/moqt_framer.h b/quiche/quic/moqt/moqt_framer.h index d88ecfb..6d0356b 100644 --- a/quiche/quic/moqt/moqt_framer.h +++ b/quiche/quic/moqt/moqt_framer.h
@@ -8,6 +8,7 @@ #include <cstddef> #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "quiche/quic/core/quic_types.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/common/platform/api/quiche_export.h" @@ -25,40 +26,39 @@ // different streams. class QUICHE_EXPORT MoqtFramer { public: - MoqtFramer(quiche::QuicheBufferAllocator* allocator, - quic::Perspective perspective, bool using_webtrans) - : allocator_(allocator), - perspective_(perspective), - using_webtrans_(using_webtrans) {} + MoqtFramer(quiche::QuicheBufferAllocator* allocator, bool using_webtrans) + : allocator_(allocator), using_webtrans_(using_webtrans) {} // Serialize functions. Takes structured data and serializes it into a // QuicheBuffer for delivery to the stream. - // SerializeObject also takes a payload. |known_payload_size| is used in - // encoding the message length. If zero, the message length as also encoded as - // zero to indicate the message ends with the stream. If nonzero, and too - // small to fit the varints and the provided payload, returns an empty buffer. + // SerializeObject also takes a payload. |payload_size| might simply be the + // size of |payload|, or it could be larger if there is more data coming, or + // it could be nullopt if the final length is unknown. If |payload_size| is + // smaller than |payload|, returns an empty buffer. quiche::QuicheBuffer SerializeObject(const MoqtObject& message, - absl::string_view payload, - size_t known_payload_size); + absl::string_view payload); // Build a buffer for additional payload data. quiche::QuicheBuffer SerializeObjectPayload(absl::string_view payload); - quiche::QuicheBuffer SerializeSetup(const MoqtSetup& message); + quiche::QuicheBuffer SerializeClientSetup(const MoqtClientSetup& message); + quiche::QuicheBuffer SerializeServerSetup(const MoqtServerSetup& message); + // Returns an empty buffer if there is an illegal combination of locations. quiche::QuicheBuffer SerializeSubscribeRequest( const MoqtSubscribeRequest& message); quiche::QuicheBuffer SerializeSubscribeOk(const MoqtSubscribeOk& message); quiche::QuicheBuffer SerializeSubscribeError( const MoqtSubscribeError& message); quiche::QuicheBuffer SerializeUnsubscribe(const MoqtUnsubscribe& message); + quiche::QuicheBuffer SerializeSubscribeFin(const MoqtSubscribeFin& message); + quiche::QuicheBuffer SerializeSubscribeRst(const MoqtSubscribeRst& message); quiche::QuicheBuffer SerializeAnnounce(const MoqtAnnounce& message); quiche::QuicheBuffer SerializeAnnounceOk(const MoqtAnnounceOk& message); quiche::QuicheBuffer SerializeAnnounceError(const MoqtAnnounceError& message); quiche::QuicheBuffer SerializeUnannounce(const MoqtUnannounce& message); - quiche::QuicheBuffer SerializeGoAway(); + quiche::QuicheBuffer SerializeGoAway(const MoqtGoAway& message); private: quiche::QuicheBufferAllocator* allocator_; - quic::Perspective perspective_; bool using_webtrans_; };
diff --git a/quiche/quic/moqt/moqt_framer_test.cc b/quiche/quic/moqt/moqt_framer_test.cc index abfac63..f261e1f 100644 --- a/quiche/quic/moqt/moqt_framer_test.cc +++ b/quiche/quic/moqt/moqt_framer_test.cc
@@ -18,46 +18,45 @@ namespace moqt::test { struct MoqtFramerTestParams { - MoqtFramerTestParams(MoqtMessageType message_type, - quic::Perspective perspective, bool uses_web_transport) - : message_type(message_type), - perspective(perspective), - uses_web_transport(uses_web_transport) {} + MoqtFramerTestParams(MoqtMessageType message_type, bool uses_web_transport) + : message_type(message_type), uses_web_transport(uses_web_transport) {} MoqtMessageType message_type; - quic::Perspective perspective; bool uses_web_transport; }; std::vector<MoqtFramerTestParams> GetMoqtFramerTestParams() { std::vector<MoqtFramerTestParams> params; std::vector<MoqtMessageType> message_types = { - MoqtMessageType::kObject, MoqtMessageType::kSetup, - MoqtMessageType::kSubscribeRequest, MoqtMessageType::kSubscribeOk, - MoqtMessageType::kSubscribeError, MoqtMessageType::kAnnounce, - MoqtMessageType::kAnnounceOk, MoqtMessageType::kAnnounceError, + MoqtMessageType::kObjectWithPayloadLength, + MoqtMessageType::kObjectWithoutPayloadLength, + MoqtMessageType::kClientSetup, + MoqtMessageType::kServerSetup, + MoqtMessageType::kSubscribeRequest, + MoqtMessageType::kSubscribeOk, + MoqtMessageType::kSubscribeError, + MoqtMessageType::kUnsubscribe, + MoqtMessageType::kSubscribeFin, + MoqtMessageType::kSubscribeRst, + MoqtMessageType::kAnnounce, + MoqtMessageType::kAnnounceOk, + MoqtMessageType::kAnnounceError, + MoqtMessageType::kUnannounce, MoqtMessageType::kGoAway, }; - std::vector<quic::Perspective> perspectives = { - quic::Perspective::IS_SERVER, - quic::Perspective::IS_CLIENT, - }; std::vector<bool> uses_web_transport_bool = { false, true, }; for (const MoqtMessageType message_type : message_types) { - if (message_type == MoqtMessageType::kSetup) { - for (const quic::Perspective perspective : perspectives) { - for (const bool uses_web_transport : uses_web_transport_bool) { - params.push_back(MoqtFramerTestParams(message_type, perspective, - uses_web_transport)); - } + if (message_type == MoqtMessageType::kClientSetup) { + for (const bool uses_web_transport : uses_web_transport_bool) { + params.push_back( + MoqtFramerTestParams(message_type, uses_web_transport)); } } else { // All other types are processed the same for either perspective or // transport. - params.push_back(MoqtFramerTestParams( - message_type, quic::Perspective::IS_SERVER, true)); + params.push_back(MoqtFramerTestParams(message_type, true)); } } return params; @@ -66,9 +65,7 @@ std::string ParamNameFormatter( const testing::TestParamInfo<MoqtFramerTestParams>& info) { return MoqtMessageTypeToString(info.param.message_type) + "_" + - (info.param.perspective == quic::Perspective::IS_SERVER ? "Server" - : "Client") + - "_" + (info.param.uses_web_transport ? "WebTransport" : "QUIC"); + (info.param.uses_web_transport ? "WebTransport" : "QUIC"); } class MoqtFramerTest @@ -76,30 +73,40 @@ public: MoqtFramerTest() : message_type_(GetParam().message_type), - is_client_(GetParam().perspective == quic::Perspective::IS_CLIENT), webtrans_(GetParam().uses_web_transport), buffer_allocator_(quiche::SimpleBufferAllocator::Get()), - framer_(buffer_allocator_, GetParam().perspective, - GetParam().uses_web_transport) {} + framer_(buffer_allocator_, GetParam().uses_web_transport) {} std::unique_ptr<TestMessageBase> MakeMessage(MoqtMessageType message_type) { switch (message_type) { - case MoqtMessageType::kObject: - return std::make_unique<ObjectMessage>(); - case MoqtMessageType::kSetup: - return std::make_unique<SetupMessage>(!is_client_, webtrans_); + case MoqtMessageType::kObjectWithPayloadLength: + return std::make_unique<ObjectMessageWithLength>(); + case MoqtMessageType::kObjectWithoutPayloadLength: + return std::make_unique<ObjectMessageWithoutLength>(); + case MoqtMessageType::kClientSetup: + return std::make_unique<ClientSetupMessage>(webtrans_); + case MoqtMessageType::kServerSetup: + return std::make_unique<ServerSetupMessage>(); case MoqtMessageType::kSubscribeRequest: return std::make_unique<SubscribeRequestMessage>(); case MoqtMessageType::kSubscribeOk: return std::make_unique<SubscribeOkMessage>(); case MoqtMessageType::kSubscribeError: return std::make_unique<SubscribeErrorMessage>(); + case MoqtMessageType::kUnsubscribe: + return std::make_unique<UnsubscribeMessage>(); + case MoqtMessageType::kSubscribeFin: + return std::make_unique<SubscribeFinMessage>(); + case MoqtMessageType::kSubscribeRst: + return std::make_unique<SubscribeRstMessage>(); case MoqtMessageType::kAnnounce: return std::make_unique<AnnounceMessage>(); case moqt::MoqtMessageType::kAnnounceOk: return std::make_unique<AnnounceOkMessage>(); case moqt::MoqtMessageType::kAnnounceError: return std::make_unique<AnnounceErrorMessage>(); + case moqt::MoqtMessageType::kUnannounce: + return std::make_unique<UnannounceMessage>(); case moqt::MoqtMessageType::kGoAway: return std::make_unique<GoAwayMessage>(); default: @@ -110,13 +117,18 @@ quiche::QuicheBuffer SerializeMessage( TestMessageBase::MessageStructuredData& structured_data) { switch (message_type_) { - case MoqtMessageType::kObject: { + case MoqtMessageType::kObjectWithPayloadLength: + case MoqtMessageType::kObjectWithoutPayloadLength: { auto data = std::get<MoqtObject>(structured_data); - return framer_.SerializeObject(data, "foo", 3); + return framer_.SerializeObject(data, "foo"); } - case MoqtMessageType::kSetup: { - auto data = std::get<MoqtSetup>(structured_data); - return framer_.SerializeSetup(data); + case MoqtMessageType::kClientSetup: { + auto data = std::get<MoqtClientSetup>(structured_data); + return framer_.SerializeClientSetup(data); + } + case MoqtMessageType::kServerSetup: { + auto data = std::get<MoqtServerSetup>(structured_data); + return framer_.SerializeServerSetup(data); } case MoqtMessageType::kSubscribeRequest: { auto data = std::get<MoqtSubscribeRequest>(structured_data); @@ -134,6 +146,14 @@ auto data = std::get<MoqtUnsubscribe>(structured_data); return framer_.SerializeUnsubscribe(data); } + case MoqtMessageType::kSubscribeFin: { + auto data = std::get<MoqtSubscribeFin>(structured_data); + return framer_.SerializeSubscribeFin(data); + } + case MoqtMessageType::kSubscribeRst: { + auto data = std::get<MoqtSubscribeRst>(structured_data); + return framer_.SerializeSubscribeRst(data); + } case MoqtMessageType::kAnnounce: { auto data = std::get<MoqtAnnounce>(structured_data); return framer_.SerializeAnnounce(data); @@ -151,13 +171,13 @@ return framer_.SerializeUnannounce(data); } case moqt::MoqtMessageType::kGoAway: { - return framer_.SerializeGoAway(); + auto data = std::get<MoqtGoAway>(structured_data); + return framer_.SerializeGoAway(data); } } } MoqtMessageType message_type_; - bool is_client_; bool webtrans_; quiche::SimpleBufferAllocator* buffer_allocator_; MoqtFramer framer_;
diff --git a/quiche/quic/moqt/moqt_messages.cc b/quiche/quic/moqt/moqt_messages.cc index 93a4b95..9c9021c 100644 --- a/quiche/quic/moqt/moqt_messages.cc +++ b/quiche/quic/moqt/moqt_messages.cc
@@ -10,10 +10,14 @@ std::string MoqtMessageTypeToString(const MoqtMessageType message_type) { switch (message_type) { - case MoqtMessageType::kObject: - return "OBJECT"; - case MoqtMessageType::kSetup: - return "SETUP"; + case MoqtMessageType::kObjectWithPayloadLength: + return "OBJECT_WITH_LENGTH"; + case MoqtMessageType::kObjectWithoutPayloadLength: + return "OBJECT_WITHOUT_LENGTH"; + case MoqtMessageType::kClientSetup: + return "CLIENT_SETUP"; + case MoqtMessageType::kServerSetup: + return "SERVER_SETUP"; case MoqtMessageType::kSubscribeRequest: return "SUBSCRIBE_REQUEST"; case MoqtMessageType::kSubscribeOk: @@ -22,6 +26,10 @@ return "SUBSCRIBE_ERROR"; case MoqtMessageType::kUnsubscribe: return "UNSUBSCRIBE"; + case MoqtMessageType::kSubscribeFin: + return "SUBSCRIBE_FIN"; + case MoqtMessageType::kSubscribeRst: + return "SUBSCRIBE_RST"; case MoqtMessageType::kAnnounce: return "ANNOUNCE"; case MoqtMessageType::kAnnounceOk:
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h index b137e59..7203fb3 100644 --- a/quiche/quic/moqt/moqt_messages.h +++ b/quiche/quic/moqt/moqt_messages.h
@@ -36,11 +36,11 @@ // The maximum length of a message, excluding any OBJECT payload. This prevents // DoS attack via forcing the parser to buffer a large message (OBJECT payloads // are not buffered by the parser). -inline constexpr size_t kMaxMessageHeaderSize = 4096; +inline constexpr size_t kMaxMessageHeaderSize = 2048; enum class QUICHE_EXPORT MoqtMessageType : uint64_t { - kObject = 0x00, - kSetup = 0x01, + kObjectWithPayloadLength = 0x00, + kObjectWithoutPayloadLength = 0x02, kSubscribeRequest = 0x03, kSubscribeOk = 0x04, kSubscribeError = 0x05, @@ -49,7 +49,11 @@ kAnnounceError = 0x08, kUnannounce = 0x09, kUnsubscribe = 0x0a, + kSubscribeFin = 0x0b, + kSubscribeRst = 0x0c, kGoAway = 0x10, + kClientSetup = 0x40, + kServerSetup = 0x41, }; enum class QUICHE_EXPORT MoqtRole : uint64_t { @@ -64,47 +68,104 @@ }; enum class QUICHE_EXPORT MoqtTrackRequestParameter : uint64_t { - kGroupSequence = 0x0, - kObjectSequence = 0x1, + // These two should have been deleted in draft-01. + // kGroupSequence = 0x0, + // kObjectSequence = 0x1, kAuthorizationInfo = 0x2, }; -struct QUICHE_EXPORT MoqtSetup { +struct QUICHE_EXPORT MoqtClientSetup { std::vector<MoqtVersion> supported_versions; absl::optional<MoqtRole> role; absl::optional<absl::string_view> path; }; +struct QUICHE_EXPORT MoqtServerSetup { + MoqtVersion selected_version; + absl::optional<MoqtRole> role; +}; + struct QUICHE_EXPORT MoqtObject { uint64_t track_id; uint64_t group_sequence; uint64_t object_sequence; uint64_t object_send_order; + absl::optional<uint64_t> payload_length; // Message also includes the object payload. }; +enum class QUICHE_EXPORT MoqtSubscribeLocationMode : uint64_t { + kNone = 0x0, + kAbsolute = 0x1, + kRelativePrevious = 0x2, + kRelativeNext = 0x3, +}; + +// kNone: absl::optional<MoqtSubscribeLocation> is nullopt. +// kAbsolute: absolute = true +// kRelativePrevious: absolute is false; relative_value is negative +// kRelativeNext: absolute is true; relative_value is positive +struct QUICHE_EXPORT MoqtSubscribeLocation { + MoqtSubscribeLocation(bool is_absolute, uint64_t abs) + : absolute(is_absolute), absolute_value(abs) {} + MoqtSubscribeLocation(bool is_absolute, int64_t rel) + : absolute(is_absolute), relative_value(rel) {} + bool absolute; + union { + uint64_t absolute_value; + int64_t relative_value; + }; + bool operator==(const MoqtSubscribeLocation& other) const { + return absolute == other.absolute && + ((absolute && absolute_value == other.absolute_value) || + (!absolute && relative_value == other.relative_value)); + } +}; + struct QUICHE_EXPORT MoqtSubscribeRequest { absl::string_view full_track_name; - absl::optional<uint64_t> group_sequence; - absl::optional<uint64_t> object_sequence; + // If the mode is kNone, the these are absl::nullopt. + absl::optional<MoqtSubscribeLocation> start_group; + absl::optional<MoqtSubscribeLocation> start_object; + absl::optional<MoqtSubscribeLocation> end_group; + absl::optional<MoqtSubscribeLocation> end_object; absl::optional<absl::string_view> authorization_info; }; struct QUICHE_EXPORT MoqtSubscribeOk { - absl::string_view full_track_name; + absl::string_view track_namespace; + absl::string_view track_name; uint64_t track_id; // The message uses ms, but expires is in us. quic::QuicTimeDelta expires = quic::QuicTimeDelta::FromMilliseconds(0); }; struct QUICHE_EXPORT MoqtSubscribeError { - absl::string_view full_track_name; + absl::string_view track_namespace; + absl::string_view track_name; uint64_t error_code; absl::string_view reason_phrase; }; struct QUICHE_EXPORT MoqtUnsubscribe { - absl::string_view full_track_name; + absl::string_view track_namespace; + absl::string_view track_name; +}; + +struct QUICHE_EXPORT MoqtSubscribeFin { + absl::string_view track_namespace; + absl::string_view track_name; + uint64_t final_group; + uint64_t final_object; +}; + +struct QUICHE_EXPORT MoqtSubscribeRst { + absl::string_view track_namespace; + absl::string_view track_name; + uint64_t error_code; + absl::string_view reason_phrase; + uint64_t final_group; + uint64_t final_object; }; struct QUICHE_EXPORT MoqtAnnounce { @@ -126,7 +187,9 @@ absl::string_view track_namespace; }; -struct QUICHE_EXPORT MoqtGoAway {}; +struct QUICHE_EXPORT MoqtGoAway { + absl::string_view new_session_uri; +}; std::string MoqtMessageTypeToString(MoqtMessageType message_type);
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc index 9ed3bc7..12050af 100644 --- a/quiche/quic/moqt/moqt_parser.cc +++ b/quiche/quic/moqt/moqt_parser.cc
@@ -4,10 +4,10 @@ #include "quiche/quic/moqt/moqt_parser.h" -#include <algorithm> #include <cstddef> #include <cstdint> #include <cstring> +#include <memory> #include <string> #include "absl/cleanup/cleanup.h" @@ -16,672 +16,519 @@ #include "absl/types/optional.h" #include "quiche/quic/core/quic_data_reader.h" #include "quiche/quic/core/quic_time.h" -#include "quiche/quic/core/quic_types.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/common/platform/api/quiche_logging.h" -#include "quiche/common/quiche_endian.h" namespace moqt { -namespace { - -// Minus the type, length, and payload, an OBJECT consists of 4 Varints. -constexpr size_t kMaxObjectHeaderSize = 32; - -} // namespace - // The buffering philosophy is complicated, to minimize copying. Here is an // overview: -// If the message type is present, this is stored in message_type_. If part of -// the message type varint is partially present, that is buffered (requiring a -// copy). -// Same for message length. // If the entire message body is present (except for OBJECT payload), it is // parsed and delivered. If not, the partial body is buffered. (requiring a // copy). // Any OBJECT payload is always delivered to the application without copying. // If something has been buffered, when more data arrives copy just enough of it // to finish parsing that thing, then resume normal processing. -void MoqtParser::ProcessData(absl::string_view data, bool end_of_stream) { +void MoqtParser::ProcessData(absl::string_view data, bool fin) { if (no_more_data_) { - if (!data.empty() || !end_of_stream) { - ParseError("Data after end of stream"); - } - return; + ParseError("Data after end of stream"); } if (processing_) { return; } processing_ = true; auto on_return = absl::MakeCleanup([&] { processing_ = false; }); - no_more_data_ = end_of_stream; - quic::QuicDataReader reader(data); - if (!MaybeMergeDataWithBuffer(reader, end_of_stream)) { - return; - } - if (end_of_stream && reader.IsDoneReading() && object_metadata_.has_value()) { - // A FIN arrives while delivering OBJECT payload. - visitor_.OnObjectMessage(object_metadata_.value(), data, true); - EndOfMessage(); - } - while (!reader.IsDoneReading()) { - absl::optional<size_t> processed; - if (!GetMessageTypeAndLength(reader)) { - absl::StrAppend(&buffered_message_, reader.PeekRemainingPayload()); - break; - } - // Cursor is at start of the message. - if (end_of_stream && NoMessageLength()) { - *message_length_ = reader.BytesRemaining(); - } - if (*message_type_ != MoqtMessageType::kObject && - *message_type_ != MoqtMessageType::kGoAway) { - // Parse OBJECT in case the message is very large. GOAWAY is length zero, - // so always process. - if (NoMessageLength()) { - // Can't parse it yet. - absl::StrAppend(&buffered_message_, reader.PeekRemainingPayload()); - break; - } - if (*message_length_ > kMaxMessageHeaderSize) { - ParseError("Message too long"); - return; - } - if (*message_length_ > reader.BytesRemaining()) { - // There definitely isn't enough to process the message. - absl::StrAppend(&buffered_message_, reader.PeekRemainingPayload()); - break; - } - } - processed = ProcessMessage(FetchMessage(reader)); - if (!processed.has_value()) { - if (*message_type_ == MoqtMessageType::kObject && - (NoMessageLength() || reader.BytesRemaining() < *message_length_)) { - // The parser can attempt to process OBJECT before receiving the whole - // message length. If it doesn't parse the varints, it will buffer the - // message. - absl::StrAppend(&buffered_message_, reader.PeekRemainingPayload()); - break; - } - // Non-OBJECT or OBJECT with the complete specified length, but the data - // was not parseable. - ParseError("Not able to parse message given specified length"); + // Check for early fin + if (fin) { + no_more_data_ = true; + if (object_metadata_.has_value() && + object_metadata_->payload_length.has_value() && + *object_metadata_->payload_length > data.length()) { + ParseError("End of stream before complete OBJECT PAYLOAD"); return; } - if (*processed == *message_length_) { - EndOfMessage(); - } else { - if (*message_type_ != MoqtMessageType::kObject) { - // Partial processing of non-OBJECT is not allowed. - ParseError("Specified message length too long"); + if (!buffered_message_.empty() && data.empty()) { + ParseError("End of stream before complete message"); + return; + } + } + absl::optional<quic::QuicDataReader> reader = absl::nullopt; + size_t original_buffer_size = buffered_message_.size(); + // There are three cases: the parser has already delivered an OBJECT header + // and is now delivering payload; part of a message is in the buffer; or + // no message is in progress. + if (object_metadata_.has_value()) { + // This is additional payload for an OBJECT. + QUICHE_DCHECK(buffered_message_.empty()); + if (!object_metadata_->payload_length.has_value()) { + // Deliver the data and exit. + visitor_.OnObjectMessage(object_metadata_.value(), data, fin); + if (fin) { + object_metadata_.reset(); + } + return; + } + if (data.length() < payload_length_remaining_) { + // Does not finish the payload; deliver and exit. + visitor_.OnObjectMessage(object_metadata_.value(), data, false); + payload_length_remaining_ -= data.length(); + return; + } + // Finishes the payload. Deliver and continue. + reader.emplace(data); + visitor_.OnObjectMessage(object_metadata_.value(), + data.substr(0, payload_length_remaining_), true); + reader->Seek(payload_length_remaining_); + object_metadata_.reset(); + } else if (!buffered_message_.empty()) { + absl::StrAppend(&buffered_message_, data); + reader.emplace(buffered_message_); + } else { + // No message in progress. + reader.emplace(data); + } + size_t total_processed = 0; + while (!reader->IsDoneReading()) { + size_t message_len = ProcessMessage(reader->PeekRemainingPayload(), fin); + if (message_len == 0) { + if (reader->BytesRemaining() > kMaxMessageHeaderSize) { + ParseError("Cannot parse non-OBJECT messages > 2KB"); return; } - // This is a partially processed OBJECT payload. - if (!NoMessageLength()) { - *message_length_ -= *processed; + if (fin) { + ParseError("FIN after incomplete message"); + return; } - } - if (!reader.Seek(*processed)) { - QUICHE_DCHECK(false); - ParseError("Internal Error"); - } - } - if (end_of_stream && - (!buffered_message_.empty() || object_metadata_.has_value() || - message_type_.has_value() || message_length_.has_value())) { - // If the stream is ending, there should be no message in progress. - ParseError("Incomplete message at end of stream"); - } -} - -bool MoqtParser::MaybeMergeDataWithBuffer(quic::QuicDataReader& reader, - bool end_of_stream) { - // Copy as much information as necessary from |data| to complete the - // message or OBJECT header. Minimize unnecessary copying! - if (buffered_message_.empty()) { - return true; - } - quic::QuicDataReader buffer(buffered_message_); - if (!message_length_.has_value()) { - // The buffer contains part of the message type or length. - if (buffer.BytesRemaining() > buffer.PeekVarInt62Length()) { - ParseError("Internal Error"); - QUICHE_DCHECK(false); - return false; - } - size_t bytes_needed = buffer.PeekVarInt62Length() - buffer.BytesRemaining(); - if (bytes_needed > reader.BytesRemaining()) { - // Not enough to complete! - absl::StrAppend(&buffered_message_, reader.PeekRemainingPayload()); - return false; - } - absl::StrAppend(&buffered_message_, - reader.PeekRemainingPayload().substr(0, bytes_needed)); - if (!reader.Seek(bytes_needed)) { - QUICHE_DCHECK(false); - ParseError("Internal Error"); - return false; - } - quic::QuicDataReader new_buffer(buffered_message_); - uint64_t value; - if (!new_buffer.ReadVarInt62(&value)) { - QUICHE_DCHECK(false); - ParseError("Internal Error"); - return false; - } - if (message_type_.has_value()) { - message_length_ = value; - } else { - message_type_ = static_cast<MoqtMessageType>(value); - } - // GOAWAY is special. Report the message as soon as the type and length - // are complete. - if (message_type_.has_value() && message_length_.has_value() && - *message_type_ == MoqtMessageType::kGoAway) { - ProcessGoAway(new_buffer.PeekRemainingPayload()); - EndOfMessage(); - return false; - } - // Proceed to normal parsing. - buffered_message_.clear(); - return true; - } - // It's a partially buffered message - if (NoMessageLength()) { - if (end_of_stream) { - message_length_ = buffer.BytesRemaining() + reader.BytesRemaining(); - } else if (*message_type_ != MoqtMessageType::kObject) { - absl::StrAppend(&buffered_message_, reader.PeekRemainingPayload()); - return false; - } - } - if (*message_type_ == MoqtMessageType::kObject) { - // OBJECT is a special case. Append up to KMaxObjectHeaderSize bytes to the - // buffer and see if that allows parsing. - QUICHE_DCHECK(!object_metadata_.has_value()); - size_t original_buffer_size = buffer.BytesRemaining(); - size_t bytes_to_pull = reader.BytesRemaining(); - // No check for *message_length_ == 0 below! Mutants will complain if there - // is a check. If message_length_ < original_buffer_size, the second - // argument will be a very large unsigned integer, which will be irrelevant - // due to std::min. - bytes_to_pull = std::min(reader.BytesRemaining(), - *message_length_ - original_buffer_size); - // Mutants complains that the line below doesn't fail any tests. This is a - // performance optimization to avoid copying large amounts of object payload - // into the buffer when only the OBJECT header will be processed. There is - // no observable behavior change if this line is removed. - bytes_to_pull = std::min(bytes_to_pull, kMaxObjectHeaderSize); - absl::StrAppend(&buffered_message_, - reader.PeekRemainingPayload().substr(0, bytes_to_pull)); - absl::optional<size_t> processed = - ProcessObjectVarints(absl::string_view(buffered_message_)); - if (!processed.has_value()) { - if ((!NoMessageLength() && - buffered_message_.length() == *message_length_) || - buffered_message_.length() > kMaxObjectHeaderSize) { - ParseError("Not able to parse buffered message given specified length"); + if (buffered_message_.empty()) { + // If the buffer is not empty, |data| has already been copied there. + absl::StrAppend(&buffered_message_, reader->PeekRemainingPayload()); } - return false; + break; } - if (*processed > 0 && !reader.Seek(*processed - original_buffer_size)) { - ParseError("Internal Error"); - return false; - } - if (*processed == *message_length_) { - // This covers an edge case where the peer has sent an OBJECT message with - // no content. - visitor_.OnObjectMessage(object_metadata_.value(), absl::string_view(), - true); - EndOfMessage(); - return true; - } - if (!NoMessageLength()) { - *message_length_ -= *processed; - } - // Object payload is never processed in the buffer. - buffered_message_.clear(); - return true; + // A message was successfully processed. + total_processed += message_len; + reader->Seek(message_len); } - size_t bytes_to_pull = - (buffer.BytesRemaining() + reader.BytesRemaining() < *message_length_) - ? reader.BytesRemaining() - : *message_length_ - buffer.BytesRemaining(); - absl::StrAppend(&buffered_message_, - reader.PeekRemainingPayload().substr(0, bytes_to_pull)); - if (!reader.Seek(bytes_to_pull)) { - QUICHE_DCHECK(false); - ParseError("Internal Error"); - return false; + if (original_buffer_size > 0) { + buffered_message_.erase(0, total_processed); } - if (buffered_message_.length() < *message_length_) { - // Not enough bytes present. - return false; - } - absl::optional<size_t> processed = - ProcessMessage(absl::string_view(buffered_message_)); - if (!processed.has_value()) { - ParseError("Not able to parse buffered message given specified length"); - return false; - } - if (*processed != *message_length_) { - ParseError("Buffered message length too long for message contents"); - return false; - } - EndOfMessage(); - return true; -} - -absl::optional<size_t> MoqtParser::ProcessMessage(absl::string_view data) { - switch (*message_type_) { - case MoqtMessageType::kObject: - return ProcessObject(data); - case MoqtMessageType::kSetup: - return ProcessSetup(data); - case MoqtMessageType::kSubscribeRequest: - return ProcessSubscribeRequest(data); - case MoqtMessageType::kSubscribeOk: - return ProcessSubscribeOk(data); - case MoqtMessageType::kSubscribeError: - return ProcessSubscribeError(data); - case MoqtMessageType::kUnsubscribe: - return ProcessUnsubscribe(data); - case MoqtMessageType::kAnnounce: - return ProcessAnnounce(data); - case MoqtMessageType::kAnnounceOk: - return ProcessAnnounceOk(data); - case MoqtMessageType::kAnnounceError: - return ProcessAnnounceError(data); - case MoqtMessageType::kUnannounce: - return ProcessUnannounce(data); - case MoqtMessageType::kGoAway: - return ProcessGoAway(data); - default: - ParseError("Unknown message type"); - return absl::nullopt; + if (fin && object_metadata_.has_value()) { + ParseError("Received FIN mid-payload"); } } -absl::optional<size_t> MoqtParser::ProcessObjectVarints( - absl::string_view data) { - if (object_metadata_.has_value()) { +size_t MoqtParser::ProcessMessage(absl::string_view data, bool fin) { + uint64_t value; + quic::QuicDataReader reader(data); + if (!reader.ReadVarInt62(&value)) { return 0; } + auto type = static_cast<MoqtMessageType>(value); + switch (type) { + case MoqtMessageType::kObjectWithPayloadLength: + return ProcessObject(reader, true, fin); + case MoqtMessageType::kObjectWithoutPayloadLength: + return ProcessObject(reader, false, fin); + case MoqtMessageType::kClientSetup: + return ProcessClientSetup(reader); + case MoqtMessageType::kServerSetup: + return ProcessServerSetup(reader); + case MoqtMessageType::kSubscribeRequest: + return ProcessSubscribeRequest(reader); + case MoqtMessageType::kSubscribeOk: + return ProcessSubscribeOk(reader); + case MoqtMessageType::kSubscribeError: + return ProcessSubscribeError(reader); + case MoqtMessageType::kUnsubscribe: + return ProcessUnsubscribe(reader); + case MoqtMessageType::kSubscribeFin: + return ProcessSubscribeFin(reader); + case MoqtMessageType::kSubscribeRst: + return ProcessSubscribeRst(reader); + case MoqtMessageType::kAnnounce: + return ProcessAnnounce(reader); + case MoqtMessageType::kAnnounceOk: + return ProcessAnnounceOk(reader); + case MoqtMessageType::kAnnounceError: + return ProcessAnnounceError(reader); + case MoqtMessageType::kUnannounce: + return ProcessUnannounce(reader); + case MoqtMessageType::kGoAway: + return ProcessGoAway(reader); + default: + ParseError("Unknown message type"); + return 0; + } +} + +size_t MoqtParser::ProcessObject(quic::QuicDataReader& reader, bool has_length, + bool fin) { + QUICHE_DCHECK(!object_metadata_.has_value()); object_metadata_ = MoqtObject(); - quic::QuicDataReader reader(data); - if (reader.ReadVarInt62(&object_metadata_->track_id) && - reader.ReadVarInt62(&object_metadata_->group_sequence) && - reader.ReadVarInt62(&object_metadata_->object_sequence) && - reader.ReadVarInt62(&object_metadata_->object_send_order)) { - return reader.PreviouslyReadPayload().length(); + uint64_t length; + if (!reader.ReadVarInt62(&object_metadata_->track_id) || + !reader.ReadVarInt62(&object_metadata_->group_sequence) || + !reader.ReadVarInt62(&object_metadata_->object_sequence) || + !reader.ReadVarInt62(&object_metadata_->object_send_order) || + (has_length && !reader.ReadVarInt62(&length))) { + object_metadata_.reset(); + return 0; } - object_metadata_ = absl::nullopt; - QUICHE_DCHECK(reader.PreviouslyReadPayload().length() < kMaxObjectHeaderSize); - return absl::nullopt; -} - -absl::optional<size_t> MoqtParser::ProcessObject(absl::string_view data) { - quic::QuicDataReader reader(data); - size_t payload_length = *message_length_; - absl::optional<size_t> processed = ProcessObjectVarints(data); - if (!processed.has_value() && !object_metadata_.has_value()) { - // Could not obtain the whole object header. - return absl::nullopt; + if (has_length) { + object_metadata_->payload_length = length; } - if (!reader.Seek(*processed)) { - ParseError("Internal Error"); - return absl::nullopt; - } - if (payload_length != 0) { - payload_length -= *processed; - } - QUICHE_DCHECK(NoMessageLength() || reader.BytesRemaining() <= payload_length); + bool received_complete_message = + (fin && !has_length) || + (has_length && + *object_metadata_->payload_length <= reader.BytesRemaining()); + size_t payload_to_draw = (!has_length || *object_metadata_->payload_length >= + reader.BytesRemaining()) + ? reader.BytesRemaining() + : *object_metadata_->payload_length; visitor_.OnObjectMessage( - object_metadata_.value(), reader.PeekRemainingPayload(), - reader.BytesRemaining() == payload_length && !NoMessageLength()); - return data.length(); + object_metadata_.value(), + reader.PeekRemainingPayload().substr(0, payload_to_draw), + received_complete_message); + if (received_complete_message) { + object_metadata_.reset(); + } + reader.Seek(payload_to_draw); + payload_length_remaining_ = length - payload_to_draw; + return reader.PreviouslyReadPayload().length(); } -absl::optional<size_t> MoqtParser::ProcessSetup(absl::string_view data) { - MoqtSetup setup; - quic::QuicDataReader reader(data); +size_t MoqtParser::ProcessClientSetup(quic::QuicDataReader& reader) { + MoqtClientSetup setup; uint64_t number_of_supported_versions; - if (perspective_ == quic::Perspective::IS_SERVER) { - if (!reader.ReadVarInt62(&number_of_supported_versions)) { - return absl::nullopt; - } - } else { - number_of_supported_versions = 1; + if (!reader.ReadVarInt62(&number_of_supported_versions)) { + return 0; } - uint64_t value; + uint64_t version; for (uint64_t i = 0; i < number_of_supported_versions; ++i) { - if (!reader.ReadVarInt62(&value)) { - return absl::nullopt; + if (!reader.ReadVarInt62(&version)) { + return 0; } - setup.supported_versions.push_back(static_cast<MoqtVersion>(value)); + setup.supported_versions.push_back(static_cast<MoqtVersion>(version)); + } + uint64_t num_params; + if (!reader.ReadVarInt62(&num_params)) { + return 0; } // Parse parameters - while (!reader.IsDoneReading()) { - if (!reader.ReadVarInt62(&value)) { - return absl::nullopt; + for (uint64_t i = 0; i < num_params; ++i) { + uint64_t type; + absl::string_view value; + if (!ReadParameter(reader, type, value)) { + return 0; } - auto parameter_key = static_cast<MoqtSetupParameter>(value); - absl::string_view field; - switch (parameter_key) { + auto key = static_cast<MoqtSetupParameter>(type); + switch (key) { case MoqtSetupParameter::kRole: if (setup.role.has_value()) { ParseError("ROLE parameter appears twice in SETUP"); - return absl::nullopt; + return 0; } - if (!ReadIntegerPieceVarInt62(reader, value)) { - return absl::nullopt; - } - setup.role = static_cast<MoqtRole>(value); + uint64_t index; + StringViewToVarInt(value, index); + setup.role = static_cast<MoqtRole>(index); break; case MoqtSetupParameter::kPath: if (uses_web_transport_) { ParseError( "WebTransport connection is using PATH parameter in SETUP"); - return absl::nullopt; - } - if (perspective_ == quic::Perspective::IS_CLIENT) { - ParseError("PATH parameter sent by server in SETUP"); - return absl::nullopt; + return 0; } if (setup.path.has_value()) { - ParseError("PATH parameter appears twice in SETUP"); - return absl::nullopt; + ParseError("PATH parameter appears twice in CLIENT_SETUP"); + return 0; } - if (!reader.ReadStringPieceVarInt62(&field)) { - return absl::nullopt; - } - setup.path = field; + setup.path = value; break; default: // Skip over the parameter. - if (!reader.ReadStringPieceVarInt62(&field)) { - return absl::nullopt; - } break; } } - if (perspective_ == quic::Perspective::IS_SERVER) { - if (!setup.role.has_value()) { - ParseError("ROLE SETUP parameter missing from Client message"); - return absl::nullopt; - } - if (!uses_web_transport_ && !setup.path.has_value()) { - ParseError("PATH SETUP parameter missing from Client message over QUIC"); - return absl::nullopt; - } + if (!setup.role.has_value()) { + ParseError("ROLE parameter missing from CLIENT_SETUP message"); + return 0; } - visitor_.OnSetupMessage(setup); + if (!uses_web_transport_ && !setup.path.has_value()) { + ParseError("PATH SETUP parameter missing from Client message over QUIC"); + return 0; + } + visitor_.OnClientSetupMessage(setup); return reader.PreviouslyReadPayload().length(); } -absl::optional<size_t> MoqtParser::ProcessSubscribeRequest( - absl::string_view data) { - MoqtSubscribeRequest subscribe_request; - quic::QuicDataReader reader(data); - absl::string_view field; - if (!reader.ReadStringPieceVarInt62(&subscribe_request.full_track_name)) { - return absl::nullopt; +size_t MoqtParser::ProcessServerSetup(quic::QuicDataReader& reader) { + MoqtServerSetup setup; + uint64_t version; + if (!reader.ReadVarInt62(&version)) { + return 0; } - uint64_t value; - while (!reader.IsDoneReading()) { - if (!reader.ReadVarInt62(&value)) { - return absl::nullopt; + setup.selected_version = static_cast<MoqtVersion>(version); + uint64_t num_params; + if (!reader.ReadVarInt62(&num_params)) { + return 0; + } + // Parse parameters + for (uint64_t i = 0; i < num_params; ++i) { + uint64_t type; + absl::string_view value; + if (!ReadParameter(reader, type, value)) { + return 0; } - auto parameter_key = static_cast<MoqtTrackRequestParameter>(value); - switch (parameter_key) { - case MoqtTrackRequestParameter::kGroupSequence: - if (subscribe_request.group_sequence.has_value()) { - ParseError( - "GROUP_SEQUENCE parameter appears twice in SUBSCRIBE_REQUEST"); - return absl::nullopt; + auto key = static_cast<MoqtSetupParameter>(type); + switch (key) { + case MoqtSetupParameter::kRole: + if (setup.role.has_value()) { + ParseError("ROLE parameter appears twice in SETUP"); + return 0; } - if (!ReadIntegerPieceVarInt62(reader, value)) { - return absl::nullopt; - } - subscribe_request.group_sequence = value; + uint64_t index; + StringViewToVarInt(value, index); + setup.role = static_cast<MoqtRole>(index); break; - case MoqtTrackRequestParameter::kObjectSequence: - if (subscribe_request.object_sequence.has_value()) { - ParseError( - "OBJECT_SEQUENCE parameter appears twice in SUBSCRIBE_REQUEST"); - return absl::nullopt; - } - if (!ReadIntegerPieceVarInt62(reader, value)) { - return absl::nullopt; - } - subscribe_request.object_sequence = value; + case MoqtSetupParameter::kPath: + ParseError("PATH parameter in SERVER_SETUP"); + return 0; + default: + // Skip over the parameter. break; + } + } + visitor_.OnServerSetupMessage(setup); + return reader.PreviouslyReadPayload().length(); +} + +size_t MoqtParser::ProcessSubscribeRequest(quic::QuicDataReader& reader) { + MoqtSubscribeRequest subscribe_request; + if (!reader.ReadStringPieceVarInt62(&subscribe_request.full_track_name)) { + return 0; + } + if (!ReadLocation(reader, subscribe_request.start_group)) { + return 0; + } + if (!subscribe_request.start_group.has_value()) { + ParseError("START_GROUP must not be None in SUBSCRIBE_REQUEST"); + return 0; + } + if (!ReadLocation(reader, subscribe_request.start_object)) { + return 0; + } + if (!subscribe_request.start_object.has_value()) { + ParseError("START_OBJECT must not be None in SUBSCRIBE_REQUEST"); + return 0; + } + if (!ReadLocation(reader, subscribe_request.end_group)) { + return 0; + } + if (!ReadLocation(reader, subscribe_request.end_object)) { + return 0; + } + if (subscribe_request.end_group.has_value() != + subscribe_request.end_object.has_value()) { + ParseError( + "SUBSCRIBE_REQUEST end_group and end_object must be both None " + "or both non_None"); + return 0; + } + uint64_t num_params; + if (!reader.ReadVarInt62(&num_params)) { + return 0; + } + for (uint64_t i = 0; i < num_params; ++i) { + uint64_t type; + absl::string_view value; + if (!ReadParameter(reader, type, value)) { + return 0; + } + auto key = static_cast<MoqtTrackRequestParameter>(type); + switch (key) { case MoqtTrackRequestParameter::kAuthorizationInfo: if (subscribe_request.authorization_info.has_value()) { ParseError( "AUTHORIZATION_INFO parameter appears twice in " "SUBSCRIBE_REQUEST"); - return absl::nullopt; + return 0; } - if (!reader.ReadStringPieceVarInt62(&field)) { - return absl::nullopt; - } - subscribe_request.authorization_info = field; + subscribe_request.authorization_info = value; break; default: // Skip over the parameter. - if (!reader.ReadStringPieceVarInt62(&field)) { - return absl::nullopt; - } break; } } - if (reader.IsDoneReading()) { - visitor_.OnSubscribeRequestMessage(subscribe_request); - } + visitor_.OnSubscribeRequestMessage(subscribe_request); return reader.PreviouslyReadPayload().length(); } -absl::optional<size_t> MoqtParser::ProcessSubscribeOk(absl::string_view data) { +size_t MoqtParser::ProcessSubscribeOk(quic::QuicDataReader& reader) { MoqtSubscribeOk subscribe_ok; - quic::QuicDataReader reader(data); - if (!reader.ReadStringPieceVarInt62(&subscribe_ok.full_track_name)) { - return absl::nullopt; + if (!reader.ReadStringPieceVarInt62(&subscribe_ok.track_namespace)) { + return 0; + } + if (!reader.ReadStringPieceVarInt62(&subscribe_ok.track_name)) { + return 0; } if (!reader.ReadVarInt62(&subscribe_ok.track_id)) { - return absl::nullopt; + return 0; } uint64_t milliseconds; if (!reader.ReadVarInt62(&milliseconds)) { - return absl::nullopt; + return 0; } subscribe_ok.expires = quic::QuicTimeDelta::FromMilliseconds(milliseconds); - if (reader.IsDoneReading()) { - visitor_.OnSubscribeOkMessage(subscribe_ok); - } + visitor_.OnSubscribeOkMessage(subscribe_ok); return reader.PreviouslyReadPayload().length(); } -absl::optional<size_t> MoqtParser::ProcessSubscribeError( - absl::string_view data) { +size_t MoqtParser::ProcessSubscribeError(quic::QuicDataReader& reader) { MoqtSubscribeError subscribe_error; - quic::QuicDataReader reader(data); - if (!reader.ReadStringPieceVarInt62(&subscribe_error.full_track_name)) { - return absl::nullopt; + if (!reader.ReadStringPieceVarInt62(&subscribe_error.track_namespace)) { + return 0; + } + if (!reader.ReadStringPieceVarInt62(&subscribe_error.track_name)) { + return 0; } if (!reader.ReadVarInt62(&subscribe_error.error_code)) { - return absl::nullopt; + return 0; } if (!reader.ReadStringPieceVarInt62(&subscribe_error.reason_phrase)) { - return absl::nullopt; + return 0; } - if (reader.IsDoneReading()) { - visitor_.OnSubscribeErrorMessage(subscribe_error); - } + visitor_.OnSubscribeErrorMessage(subscribe_error); return reader.PreviouslyReadPayload().length(); } -absl::optional<size_t> MoqtParser::ProcessUnsubscribe(absl::string_view data) { +size_t MoqtParser::ProcessUnsubscribe(quic::QuicDataReader& reader) { MoqtUnsubscribe unsubscribe; - quic::QuicDataReader reader(data); - if (!reader.ReadStringPieceVarInt62(&unsubscribe.full_track_name)) { - return absl::nullopt; + if (!reader.ReadStringPieceVarInt62(&unsubscribe.track_namespace)) { + return 0; } - if (reader.IsDoneReading()) { - visitor_.OnUnsubscribeMessage(unsubscribe); + if (!reader.ReadStringPieceVarInt62(&unsubscribe.track_name)) { + return 0; } + visitor_.OnUnsubscribeMessage(unsubscribe); return reader.PreviouslyReadPayload().length(); } -absl::optional<size_t> MoqtParser::ProcessAnnounce(absl::string_view data) { - MoqtAnnounce announce; - quic::QuicDataReader reader(data); - absl::string_view field; - if (!reader.ReadStringPieceVarInt62(&field)) { - return absl::nullopt; +size_t MoqtParser::ProcessSubscribeFin(quic::QuicDataReader& reader) { + MoqtSubscribeFin subscribe_fin; + if (!reader.ReadStringPieceVarInt62(&subscribe_fin.track_namespace)) { + return 0; } - announce.track_namespace = field; - bool saw_group_sequence = false, saw_object_sequence = false; - while (!reader.IsDoneReading()) { - uint64_t value; - if (!reader.ReadVarInt62(&value)) { - return absl::nullopt; + if (!reader.ReadStringPieceVarInt62(&subscribe_fin.track_name)) { + return 0; + } + if (!reader.ReadVarInt62(&subscribe_fin.final_group)) { + return 0; + } + if (!reader.ReadVarInt62(&subscribe_fin.final_object)) { + return 0; + } + visitor_.OnSubscribeFinMessage(subscribe_fin); + return reader.PreviouslyReadPayload().length(); +} + +size_t MoqtParser::ProcessSubscribeRst(quic::QuicDataReader& reader) { + MoqtSubscribeRst subscribe_rst; + if (!reader.ReadStringPieceVarInt62(&subscribe_rst.track_namespace)) { + return 0; + } + if (!reader.ReadStringPieceVarInt62(&subscribe_rst.track_name)) { + return 0; + } + if (!reader.ReadVarInt62(&subscribe_rst.error_code)) { + return 0; + } + if (!reader.ReadStringPieceVarInt62(&subscribe_rst.reason_phrase)) { + return 0; + } + if (!reader.ReadVarInt62(&subscribe_rst.final_group)) { + return 0; + } + if (!reader.ReadVarInt62(&subscribe_rst.final_object)) { + return 0; + } + visitor_.OnSubscribeRstMessage(subscribe_rst); + return reader.PreviouslyReadPayload().length(); +} + +size_t MoqtParser::ProcessAnnounce(quic::QuicDataReader& reader) { + MoqtAnnounce announce; + if (!reader.ReadStringPieceVarInt62(&announce.track_namespace)) { + return 0; + } + uint64_t num_params; + if (!reader.ReadVarInt62(&num_params)) { + return 0; + } + for (uint64_t i = 0; i < num_params; ++i) { + uint64_t type; + absl::string_view value; + if (!ReadParameter(reader, type, value)) { + return 0; } - auto parameter_key = static_cast<MoqtTrackRequestParameter>(value); - switch (parameter_key) { - case MoqtTrackRequestParameter::kGroupSequence: - // Not used, but check for duplicates. - if (saw_group_sequence) { - ParseError("GROUP_SEQUENCE parameter appears twice in ANNOUNCE"); - return absl::nullopt; - } - if (!reader.ReadStringPieceVarInt62(&field)) { - return absl::nullopt; - } - saw_group_sequence = true; - break; - case MoqtTrackRequestParameter::kObjectSequence: - // Not used, but check for duplicates. - if (saw_object_sequence) { - ParseError("OBJECT_SEQUENCE parameter appears twice in ANNOUNCE"); - return absl::nullopt; - } - if (!reader.ReadStringPieceVarInt62(&field)) { - return absl::nullopt; - } - saw_object_sequence = true; - break; + auto key = static_cast<MoqtTrackRequestParameter>(type); + switch (key) { case MoqtTrackRequestParameter::kAuthorizationInfo: if (announce.authorization_info.has_value()) { ParseError("AUTHORIZATION_INFO parameter appears twice in ANNOUNCE"); - return absl::nullopt; + return 0; } - if (!reader.ReadStringPieceVarInt62(&field)) { - return absl::nullopt; - } - announce.authorization_info = field; + announce.authorization_info = value; break; default: // Skip over the parameter. - if (!reader.ReadStringPieceVarInt62(&field)) { - return absl::nullopt; - } break; } } - if (reader.IsDoneReading()) { - visitor_.OnAnnounceMessage(announce); - } + visitor_.OnAnnounceMessage(announce); return reader.PreviouslyReadPayload().length(); } -absl::optional<size_t> MoqtParser::ProcessAnnounceOk(absl::string_view data) { +size_t MoqtParser::ProcessAnnounceOk(quic::QuicDataReader& reader) { MoqtAnnounceOk announce_ok; - quic::QuicDataReader reader(data); if (!reader.ReadStringPieceVarInt62(&announce_ok.track_namespace)) { - return absl::nullopt; + return 0; } - if (reader.IsDoneReading()) { - visitor_.OnAnnounceOkMessage(announce_ok); - } + visitor_.OnAnnounceOkMessage(announce_ok); return reader.PreviouslyReadPayload().length(); } -absl::optional<size_t> MoqtParser::ProcessAnnounceError( - absl::string_view data) { +size_t MoqtParser::ProcessAnnounceError(quic::QuicDataReader& reader) { MoqtAnnounceError announce_error; - quic::QuicDataReader reader(data); if (!reader.ReadStringPieceVarInt62(&announce_error.track_namespace)) { - return absl::nullopt; + return 0; } if (!reader.ReadVarInt62(&announce_error.error_code)) { - return absl::nullopt; + return 0; } if (!reader.ReadStringPieceVarInt62(&announce_error.reason_phrase)) { - return absl::nullopt; + return 0; } - if (reader.IsDoneReading()) { - visitor_.OnAnnounceErrorMessage(announce_error); - } + visitor_.OnAnnounceErrorMessage(announce_error); return reader.PreviouslyReadPayload().length(); } -absl::optional<size_t> MoqtParser::ProcessUnannounce(absl::string_view data) { +size_t MoqtParser::ProcessUnannounce(quic::QuicDataReader& reader) { MoqtUnannounce unannounce; - quic::QuicDataReader reader(data); if (!reader.ReadStringPieceVarInt62(&unannounce.track_namespace)) { - return absl::nullopt; + return 0; } - if (reader.IsDoneReading()) { - visitor_.OnUnannounceMessage(unannounce); - } + visitor_.OnUnannounceMessage(unannounce); return reader.PreviouslyReadPayload().length(); } -absl::optional<size_t> MoqtParser::ProcessGoAway(absl::string_view data) { - if (!data.empty()) { - // GOAWAY can only be followed by end_of_stream. Anything else is an error. - ParseError("GOAWAY has data following"); - return absl::nullopt; +size_t MoqtParser::ProcessGoAway(quic::QuicDataReader& reader) { + MoqtGoAway goaway; + if (!reader.ReadStringPieceVarInt62(&goaway.new_session_uri)) { + return 0; } - visitor_.OnGoAwayMessage(); - return 0; -} - -bool MoqtParser::GetMessageTypeAndLength(quic::QuicDataReader& reader) { - if (!message_type_.has_value()) { - uint64_t value; - if (!reader.ReadVarInt62(&value)) { - return false; - } - message_type_ = static_cast<MoqtMessageType>(value); - } - if (!message_length_.has_value()) { - uint64_t value; - if (!reader.ReadVarInt62(&value)) { - return false; - } - message_length_ = value; - } - return true; -} - -void MoqtParser::EndOfMessage() { - buffered_message_.clear(); - message_type_ = absl::nullopt; - message_length_ = absl::nullopt; - object_metadata_ = absl::nullopt; -} - -absl::string_view MoqtParser::FetchMessage(quic::QuicDataReader& reader) { - if (message_length_ == 0) { - return reader.PeekRemainingPayload(); - } - if (message_length_ > reader.BytesRemaining()) { - QUICHE_DCHECK(message_type_ == MoqtMessageType::kObject); - return reader.PeekRemainingPayload(); - } - return reader.PeekRemainingPayload().substr(0, *message_length_); + visitor_.OnGoAwayMessage(goaway); + return reader.PreviouslyReadPayload().length(); } void MoqtParser::ParseError(absl::string_view reason) { @@ -693,20 +540,69 @@ visitor_.OnParsingError(reason); } -bool MoqtParser::ReadIntegerPieceVarInt62(quic::QuicDataReader& reader, - uint64_t& result) { - absl::string_view field; - if (!reader.ReadStringPieceVarInt62(&field)) { +bool MoqtParser::ReadVarIntPieceVarInt62(quic::QuicDataReader& reader, + uint64_t& result) { + uint64_t length; + if (!reader.ReadVarInt62(&length)) { return false; } - if (field.size() > sizeof(uint64_t)) { - ParseError("Cannot parse explicit length integers longer than 8 bytes"); + uint64_t actual_length = static_cast<uint64_t>(reader.PeekVarInt62Length()); + if (length != actual_length) { + ParseError("Parameter VarInt has length field mismatch"); return false; } - result = 0; - memcpy((uint8_t*)&result + sizeof(result) - field.size(), field.data(), - field.size()); - result = quiche::QuicheEndian::NetToHost64(result); + if (!reader.ReadVarInt62(&result)) { + return false; + } + return true; +} + +bool MoqtParser::ReadLocation(quic::QuicDataReader& reader, + absl::optional<MoqtSubscribeLocation>& loc) { + uint64_t ui64; + if (!reader.ReadVarInt62(&ui64)) { + return false; + } + auto mode = static_cast<MoqtSubscribeLocationMode>(ui64); + if (mode == MoqtSubscribeLocationMode::kNone) { + loc = absl::nullopt; + return true; + } + if (!reader.ReadVarInt62(&ui64)) { + return false; + } + switch (mode) { + case MoqtSubscribeLocationMode::kAbsolute: + loc = MoqtSubscribeLocation(true, ui64); + break; + case MoqtSubscribeLocationMode::kRelativePrevious: + loc = MoqtSubscribeLocation(false, -1 * static_cast<int64_t>(ui64)); + break; + case MoqtSubscribeLocationMode::kRelativeNext: + loc = MoqtSubscribeLocation(false, static_cast<int64_t>(ui64)); + break; + default: + ParseError("Unknown location mode"); + return false; + } + return true; +} + +bool MoqtParser::ReadParameter(quic::QuicDataReader& reader, uint64_t& type, + absl::string_view& value) { + if (!reader.ReadVarInt62(&type)) { + return false; + } + return reader.ReadStringPieceVarInt62(&value); +} + +bool MoqtParser::StringViewToVarInt(absl::string_view& sv, uint64_t& vi) { + quic::QuicDataReader reader(sv); + if (static_cast<size_t>(reader.PeekVarInt62Length()) != sv.length()) { + ParseError("Parameter length does not match varint encoding"); + return false; + } + reader.ReadVarInt62(&vi); return true; }
diff --git a/quiche/quic/moqt/moqt_parser.h b/quiche/quic/moqt/moqt_parser.h index bb23546..b45a707 100644 --- a/quiche/quic/moqt/moqt_parser.h +++ b/quiche/quic/moqt/moqt_parser.h
@@ -9,13 +9,11 @@ #include <cstddef> #include <cstdint> -#include <memory> #include <string> #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "quiche/quic/core/quic_data_reader.h" -#include "quiche/quic/core/quic_types.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/common/platform/api/quiche_export.h" @@ -32,105 +30,92 @@ virtual void OnObjectMessage(const MoqtObject& message, absl::string_view payload, bool end_of_message) = 0; - // All of these are called only when the entire specified message length has - // arrived, which requires a stream FIN if the length is zero. The parser - // retains ownership of the memory. - virtual void OnSetupMessage(const MoqtSetup& message) = 0; + // All of these are called only when the entire message has arrived. The + // parser retains ownership of the memory. + virtual void OnClientSetupMessage(const MoqtClientSetup& message) = 0; + virtual void OnServerSetupMessage(const MoqtServerSetup& message) = 0; virtual void OnSubscribeRequestMessage( const MoqtSubscribeRequest& message) = 0; virtual void OnSubscribeOkMessage(const MoqtSubscribeOk& message) = 0; virtual void OnSubscribeErrorMessage(const MoqtSubscribeError& message) = 0; virtual void OnUnsubscribeMessage(const MoqtUnsubscribe& message) = 0; + virtual void OnSubscribeFinMessage(const MoqtSubscribeFin& message) = 0; + virtual void OnSubscribeRstMessage(const MoqtSubscribeRst& message) = 0; virtual void OnAnnounceMessage(const MoqtAnnounce& message) = 0; virtual void OnAnnounceOkMessage(const MoqtAnnounceOk& message) = 0; virtual void OnAnnounceErrorMessage(const MoqtAnnounceError& message) = 0; virtual void OnUnannounceMessage(const MoqtUnannounce& message) = 0; - // In an exception to the above, the parser calls this when it gets two bytes, - // whether or not it includes stream FIN. When a zero-length message has - // special meaning, a message with an actual length of zero is tricky! - virtual void OnGoAwayMessage() = 0; + virtual void OnGoAwayMessage(const MoqtGoAway& message) = 0; virtual void OnParsingError(absl::string_view reason) = 0; }; class QUICHE_EXPORT MoqtParser { public: - MoqtParser(quic::Perspective perspective, bool uses_web_transport, - MoqtParserVisitor& visitor) - : visitor_(visitor), - perspective_(perspective), - uses_web_transport_(uses_web_transport) {} + MoqtParser(bool uses_web_transport, MoqtParserVisitor& visitor) + : visitor_(visitor), uses_web_transport_(uses_web_transport) {} ~MoqtParser() = default; // Take a buffer from the transport in |data|. Parse each complete message and - // call the appropriate visitor function. If |end_of_stream| is true, there + // call the appropriate visitor function. If |fin| is true, there // is no more data arriving on the stream, so the parser will deliver any // message encoded as to run to the end of the stream. // All bytes can be freed. Calls OnParsingError() when there is a parsing // error. - // Any calls after sending |end_of_stream| = true will be ignored. - void ProcessData(absl::string_view data, bool end_of_stream); + // Any calls after sending |fin| = true will be ignored. + void ProcessData(absl::string_view data, bool fin); private: - // Copies the minimum amount of data in |reader| to buffered_message_ in order - // to process what is in there, and does the processing. Returns true if - // additional processing can occur, false otherwise. - bool MaybeMergeDataWithBuffer(quic::QuicDataReader& reader, - bool end_of_stream); - // The central switch statement to dispatch a message to the correct - // Process* function. Returns nullopt if it could not parse the full messsage + // Process* function. Returns 0 if it could not parse the full messsage // (except for object payload). Otherwise, returns the number of bytes // processed. - absl::optional<size_t> ProcessMessage(absl::string_view data); - // A helper function to parse just the varints in an OBJECT. - absl::optional<size_t> ProcessObjectVarints(absl::string_view data); + size_t ProcessMessage(absl::string_view data, bool fin); // The Process* functions parse the serialized data into the appropriate // structs, and call the relevant visitor function for further action. Returns - // the number of bytes consumed if the message is complete; returns nullopt - // otherwise. These functions can throw a fatal error if the message length - // is insufficient. - absl::optional<size_t> ProcessObject(absl::string_view data); - absl::optional<size_t> ProcessSetup(absl::string_view data); - absl::optional<size_t> ProcessSubscribeRequest(absl::string_view data); - absl::optional<size_t> ProcessSubscribeOk(absl::string_view data); - absl::optional<size_t> ProcessSubscribeError(absl::string_view data); - absl::optional<size_t> ProcessUnsubscribe(absl::string_view data); - absl::optional<size_t> ProcessAnnounce(absl::string_view data); - absl::optional<size_t> ProcessAnnounceOk(absl::string_view data); - absl::optional<size_t> ProcessAnnounceError(absl::string_view data); - absl::optional<size_t> ProcessUnannounce(absl::string_view data); - absl::optional<size_t> ProcessGoAway(absl::string_view data); + // the number of bytes consumed if the message is complete; returns 0 + // otherwise. + size_t ProcessObject(quic::QuicDataReader& reader, bool has_length, bool fin); + size_t ProcessClientSetup(quic::QuicDataReader& reader); + size_t ProcessServerSetup(quic::QuicDataReader& reader); + size_t ProcessSubscribeRequest(quic::QuicDataReader& reader); + size_t ProcessSubscribeOk(quic::QuicDataReader& reader); + size_t ProcessSubscribeError(quic::QuicDataReader& reader); + size_t ProcessUnsubscribe(quic::QuicDataReader& reader); + size_t ProcessSubscribeFin(quic::QuicDataReader& reader); + size_t ProcessSubscribeRst(quic::QuicDataReader& reader); + size_t ProcessAnnounce(quic::QuicDataReader& reader); + size_t ProcessAnnounceOk(quic::QuicDataReader& reader); + size_t ProcessAnnounceError(quic::QuicDataReader& reader); + size_t ProcessUnannounce(quic::QuicDataReader& reader); + size_t ProcessGoAway(quic::QuicDataReader& reader); - // If the message length field is zero, it runs to the end of the stream. - bool NoMessageLength() { return *message_length_ == 0; } - // If type and or length are not already stored for this message, reads it out - // of the data in |reader| and stores it in the appropriate members. Returns - // false if length is not available. - bool GetMessageTypeAndLength(quic::QuicDataReader& reader); - void EndOfMessage(); - // Get a string_view of the part of the reader covered by message_length_, - // with exceptions for OBJECT messages. - absl::string_view FetchMessage(quic::QuicDataReader& reader); void ParseError(absl::string_view reason); // Reads an integer whose length is specified by a preceding VarInt62 and // returns it in |result|. Returns false if parsing fails. - bool ReadIntegerPieceVarInt62(quic::QuicDataReader& reader, uint64_t& result); + bool ReadVarIntPieceVarInt62(quic::QuicDataReader& reader, uint64_t& result); + // Read a Location field from SUBSCRIBE REQUEST + bool ReadLocation(quic::QuicDataReader& reader, + absl::optional<MoqtSubscribeLocation>& loc); + // Read a parameter and return the value as a string_view. Returns false if + // |reader| does not have enough data. + bool ReadParameter(quic::QuicDataReader& reader, uint64_t& type, + absl::string_view& value); + // Convert a string view to a varint. Throws an error and returns false if the + // string_view is not exactly the right length. + bool StringViewToVarInt(absl::string_view& sv, uint64_t& vi); MoqtParserVisitor& visitor_; - // Client or server? - quic::Perspective perspective_; bool uses_web_transport_; - bool no_more_data_ = false; // Fatal error or end_of_stream. No more parsing. + bool no_more_data_ = false; // Fatal error or fin. No more parsing. bool parsing_error_ = false; std::string buffered_message_; - absl::optional<MoqtMessageType> message_type_ = absl::nullopt; - absl::optional<size_t> message_length_ = absl::nullopt; // Metadata for an object which is delivered in parts. absl::optional<MoqtObject> object_metadata_ = absl::nullopt; + size_t payload_length_remaining_; bool processing_ = false; // True if currently in ProcessData(), to prevent // re-entrancy.
diff --git a/quiche/quic/moqt/moqt_parser_test.cc b/quiche/quic/moqt/moqt_parser_test.cc index cd4c40b..ae6d545 100644 --- a/quiche/quic/moqt/moqt_parser_test.cc +++ b/quiche/quic/moqt/moqt_parser_test.cc
@@ -13,54 +13,64 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_data_writer.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/test_tools/moqt_test_message.h" #include "quiche/quic/platform/api/quic_test.h" namespace moqt::test { +namespace { + +bool IsObjectMessage(MoqtMessageType type) { + return (type == MoqtMessageType::kObjectWithPayloadLength || + type == MoqtMessageType::kObjectWithoutPayloadLength); +} + +std::vector<MoqtMessageType> message_types = { + MoqtMessageType::kObjectWithPayloadLength, + MoqtMessageType::kObjectWithoutPayloadLength, + MoqtMessageType::kClientSetup, + MoqtMessageType::kServerSetup, + MoqtMessageType::kSubscribeRequest, + MoqtMessageType::kSubscribeOk, + MoqtMessageType::kSubscribeError, + MoqtMessageType::kUnsubscribe, + MoqtMessageType::kSubscribeFin, + MoqtMessageType::kSubscribeRst, + MoqtMessageType::kAnnounce, + MoqtMessageType::kAnnounceOk, + MoqtMessageType::kAnnounceError, + MoqtMessageType::kUnannounce, + MoqtMessageType::kGoAway, +}; + +} // namespace + struct MoqtParserTestParams { - MoqtParserTestParams(MoqtMessageType message_type, - quic::Perspective perspective, bool uses_web_transport) - : message_type(message_type), - perspective(perspective), - uses_web_transport(uses_web_transport) {} + MoqtParserTestParams(MoqtMessageType message_type, bool uses_web_transport) + : message_type(message_type), uses_web_transport(uses_web_transport) {} MoqtMessageType message_type; - quic::Perspective perspective; bool uses_web_transport; }; std::vector<MoqtParserTestParams> GetMoqtParserTestParams() { std::vector<MoqtParserTestParams> params; - std::vector<MoqtMessageType> message_types = { - MoqtMessageType::kObject, MoqtMessageType::kSetup, - MoqtMessageType::kSubscribeRequest, MoqtMessageType::kSubscribeOk, - MoqtMessageType::kSubscribeError, MoqtMessageType::kAnnounce, - MoqtMessageType::kAnnounceOk, MoqtMessageType::kAnnounceError, - MoqtMessageType::kGoAway, - }; - std::vector<quic::Perspective> perspectives = { - quic::Perspective::IS_SERVER, - quic::Perspective::IS_CLIENT, - }; + std::vector<bool> uses_web_transport_bool = { false, true, }; for (const MoqtMessageType message_type : message_types) { - if (message_type == MoqtMessageType::kSetup) { - for (const quic::Perspective perspective : perspectives) { - for (const bool uses_web_transport : uses_web_transport_bool) { - params.push_back(MoqtParserTestParams(message_type, perspective, - uses_web_transport)); - } + if (message_type == MoqtMessageType::kClientSetup) { + for (const bool uses_web_transport : uses_web_transport_bool) { + params.push_back( + MoqtParserTestParams(message_type, uses_web_transport)); } } else { // All other types are processed the same for either perspective or // transport. - params.push_back(MoqtParserTestParams( - message_type, quic::Perspective::IS_SERVER, true)); + params.push_back(MoqtParserTestParams(message_type, true)); } } return params; @@ -69,9 +79,7 @@ std::string ParamNameFormatter( const testing::TestParamInfo<MoqtParserTestParams>& info) { return MoqtMessageTypeToString(info.param.message_type) + "_" + - (info.param.perspective == quic::Perspective::IS_SERVER ? "Server" - : "Client") + - "_" + (info.param.uses_web_transport ? "WebTransport" : "QUIC"); + (info.param.uses_web_transport ? "WebTransport" : "QUIC"); } class MoqtParserTestVisitor : public MoqtParserVisitor { @@ -80,20 +88,27 @@ void OnObjectMessage(const MoqtObject& message, absl::string_view payload, bool end_of_message) override { + MoqtObject object = message; object_payload_ = payload; end_of_message_ = end_of_message; messages_received_++; - last_message_ = TestMessageBase::MessageStructuredData(message); + last_message_ = TestMessageBase::MessageStructuredData(object); } - void OnSetupMessage(const MoqtSetup& message) override { + void OnClientSetupMessage(const MoqtClientSetup& message) override { end_of_message_ = true; messages_received_++; - MoqtSetup setup = message; - if (setup.path.has_value()) { - string0_ = std::string(setup.path.value()); - setup.path = absl::string_view(string0_); + MoqtClientSetup client_setup = message; + if (client_setup.path.has_value()) { + string0_ = std::string(client_setup.path.value()); + client_setup.path = absl::string_view(string0_); } - last_message_ = TestMessageBase::MessageStructuredData(setup); + last_message_ = TestMessageBase::MessageStructuredData(client_setup); + } + void OnServerSetupMessage(const MoqtServerSetup& message) override { + end_of_message_ = true; + messages_received_++; + MoqtServerSetup server_setup = message; + last_message_ = TestMessageBase::MessageStructuredData(server_setup); } void OnSubscribeRequestMessage(const MoqtSubscribeRequest& message) override { end_of_message_ = true; @@ -111,16 +126,20 @@ end_of_message_ = true; messages_received_++; MoqtSubscribeOk subscribe_ok = message; - string0_ = std::string(subscribe_ok.full_track_name); - subscribe_ok.full_track_name = absl::string_view(string0_); + string0_ = std::string(subscribe_ok.track_namespace); + subscribe_ok.track_namespace = absl::string_view(string0_); + string1_ = std::string(subscribe_ok.track_name); + subscribe_ok.track_name = absl::string_view(string1_); last_message_ = TestMessageBase::MessageStructuredData(subscribe_ok); } void OnSubscribeErrorMessage(const MoqtSubscribeError& message) override { end_of_message_ = true; messages_received_++; MoqtSubscribeError subscribe_error = message; - string0_ = std::string(subscribe_error.full_track_name); - subscribe_error.full_track_name = absl::string_view(string0_); + string0_ = std::string(subscribe_error.track_namespace); + subscribe_error.track_namespace = absl::string_view(string0_); + string1_ = std::string(subscribe_error.track_name); + subscribe_error.track_name = absl::string_view(string1_); string1_ = std::string(subscribe_error.reason_phrase); subscribe_error.reason_phrase = absl::string_view(string1_); last_message_ = TestMessageBase::MessageStructuredData(subscribe_error); @@ -129,10 +148,34 @@ end_of_message_ = true; messages_received_++; MoqtUnsubscribe unsubscribe = message; - string0_ = std::string(unsubscribe.full_track_name); - unsubscribe.full_track_name = absl::string_view(string0_); + string0_ = std::string(unsubscribe.track_namespace); + unsubscribe.track_namespace = absl::string_view(string0_); + string1_ = std::string(unsubscribe.track_name); + unsubscribe.track_name = absl::string_view(string1_); last_message_ = TestMessageBase::MessageStructuredData(unsubscribe); } + void OnSubscribeFinMessage(const MoqtSubscribeFin& message) override { + end_of_message_ = true; + messages_received_++; + MoqtSubscribeFin subscribe_fin = message; + string0_ = std::string(subscribe_fin.track_namespace); + subscribe_fin.track_namespace = absl::string_view(string0_); + string1_ = std::string(subscribe_fin.track_name); + subscribe_fin.track_name = absl::string_view(string1_); + last_message_ = TestMessageBase::MessageStructuredData(subscribe_fin); + } + void OnSubscribeRstMessage(const MoqtSubscribeRst& message) override { + end_of_message_ = true; + messages_received_++; + MoqtSubscribeRst subscribe_rst = message; + string0_ = std::string(subscribe_rst.track_namespace); + subscribe_rst.track_namespace = absl::string_view(string0_); + string1_ = std::string(subscribe_rst.track_name); + subscribe_rst.track_name = absl::string_view(string1_); + string2_ = std::string(subscribe_rst.reason_phrase); + subscribe_rst.reason_phrase = absl::string_view(string2_); + last_message_ = TestMessageBase::MessageStructuredData(subscribe_rst); + } void OnAnnounceMessage(const MoqtAnnounce& message) override { end_of_message_ = true; messages_received_++; @@ -171,11 +214,14 @@ unannounce.track_namespace = absl::string_view(string0_); last_message_ = TestMessageBase::MessageStructuredData(unannounce); } - void OnGoAwayMessage() override { + void OnGoAwayMessage(const MoqtGoAway& message) override { got_goaway_ = true; end_of_message_ = true; messages_received_++; - last_message_ = TestMessageBase::MessageStructuredData(); + MoqtGoAway goaway = message; + string0_ = std::string(goaway.new_session_uri); + goaway.new_session_uri = absl::string_view(string0_); + last_message_ = TestMessageBase::MessageStructuredData(goaway); } void OnParsingError(absl::string_view reason) override { QUIC_LOG(INFO) << "Parsing error: " << reason; @@ -190,7 +236,7 @@ absl::optional<TestMessageBase::MessageStructuredData> last_message_; // Stored strings for last_message_. The visitor API does not promise the // memory pointed to by string_views is persistent. - std::string string0_, string1_; + std::string string0_, string1_, string2_; }; class MoqtParserTest @@ -198,29 +244,39 @@ public: MoqtParserTest() : message_type_(GetParam().message_type), - is_client_(GetParam().perspective == quic::Perspective::IS_CLIENT), webtrans_(GetParam().uses_web_transport), - parser_(GetParam().perspective, GetParam().uses_web_transport, - visitor_) {} + parser_(GetParam().uses_web_transport, visitor_) {} std::unique_ptr<TestMessageBase> MakeMessage(MoqtMessageType message_type) { switch (message_type) { - case MoqtMessageType::kObject: - return std::make_unique<ObjectMessage>(); - case MoqtMessageType::kSetup: - return std::make_unique<SetupMessage>(is_client_, webtrans_); + case MoqtMessageType::kObjectWithPayloadLength: + return std::make_unique<ObjectMessageWithLength>(); + case MoqtMessageType::kObjectWithoutPayloadLength: + return std::make_unique<ObjectMessageWithoutLength>(); + case MoqtMessageType::kClientSetup: + return std::make_unique<ClientSetupMessage>(webtrans_); + case MoqtMessageType::kServerSetup: + return std::make_unique<ClientSetupMessage>(webtrans_); case MoqtMessageType::kSubscribeRequest: return std::make_unique<SubscribeRequestMessage>(); case MoqtMessageType::kSubscribeOk: return std::make_unique<SubscribeOkMessage>(); case MoqtMessageType::kSubscribeError: return std::make_unique<SubscribeErrorMessage>(); + case MoqtMessageType::kUnsubscribe: + return std::make_unique<UnsubscribeMessage>(); + case MoqtMessageType::kSubscribeFin: + return std::make_unique<SubscribeFinMessage>(); + case MoqtMessageType::kSubscribeRst: + return std::make_unique<SubscribeRstMessage>(); case MoqtMessageType::kAnnounce: return std::make_unique<AnnounceMessage>(); case moqt::MoqtMessageType::kAnnounceOk: return std::make_unique<AnnounceOkMessage>(); case moqt::MoqtMessageType::kAnnounceError: return std::make_unique<AnnounceErrorMessage>(); + case moqt::MoqtMessageType::kUnannounce: + return std::make_unique<UnannounceMessage>(); case moqt::MoqtMessageType::kGoAway: return std::make_unique<GoAwayMessage>(); default: @@ -230,7 +286,6 @@ MoqtParserTestVisitor visitor_; MoqtMessageType message_type_; - bool is_client_; bool webtrans_; MoqtParser parser_; }; @@ -241,11 +296,11 @@ TEST_P(MoqtParserTest, OneMessage) { std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); - parser_.ProcessData(message->PacketSample(), false); + parser_.ProcessData(message->PacketSample(), true); EXPECT_EQ(visitor_.messages_received_, 1); EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); EXPECT_TRUE(visitor_.end_of_message_); - if (message_type_ == MoqtMessageType::kObject) { + if (IsObjectMessage(message_type_)) { // Check payload message. EXPECT_TRUE(visitor_.object_payload_.has_value()); EXPECT_EQ(*(visitor_.object_payload_), "foo"); @@ -255,79 +310,17 @@ TEST_P(MoqtParserTest, OneMessageWithLongVarints) { std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); message->ExpandVarints(); - parser_.ProcessData(message->PacketSample(), false); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); - EXPECT_TRUE(visitor_.end_of_message_); - if (message_type_ == MoqtMessageType::kObject) { - // Check payload message. - EXPECT_EQ(visitor_.object_payload_, "foo"); - } - EXPECT_FALSE(visitor_.parsing_error_.has_value()); -} - -TEST_P(MoqtParserTest, MessageNoLengthWithFin) { - std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); - message->set_message_size(0); parser_.ProcessData(message->PacketSample(), true); EXPECT_EQ(visitor_.messages_received_, 1); EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); EXPECT_TRUE(visitor_.end_of_message_); - if (message_type_ == MoqtMessageType::kObject) { + if (IsObjectMessage(message_type_)) { // Check payload message. - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), "foo"); + EXPECT_EQ(visitor_.object_payload_, "foo"); } EXPECT_FALSE(visitor_.parsing_error_.has_value()); } -TEST_P(MoqtParserTest, MessageNoLengthSeparateFinObjectOrGoAway) { - // OBJECT and GOAWAY can return on a zero-length message even without - // receiving a FIN. - if (message_type_ != MoqtMessageType::kObject && - message_type_ != MoqtMessageType::kGoAway) { - return; - } - std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); - message->set_message_size(0); - parser_.ProcessData(message->PacketSample(), false); - EXPECT_EQ(visitor_.messages_received_, 1); - if (message_type_ == MoqtMessageType::kGoAway) { - EXPECT_TRUE(visitor_.got_goaway_); - EXPECT_TRUE(visitor_.end_of_message_); - return; - } - EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), "foo"); - EXPECT_FALSE(visitor_.end_of_message_); - - parser_.ProcessData(absl::string_view(), true); // send the FIN - EXPECT_EQ(visitor_.messages_received_, 2); - EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), ""); - EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); -} - -TEST_P(MoqtParserTest, MessageNoLengthSeparateFinOtherTypes) { - if (message_type_ == MoqtMessageType::kObject || - message_type_ == MoqtMessageType::kGoAway) { - return; - } - std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); - message->set_message_size(0); - parser_.ProcessData(message->PacketSample(), false); - EXPECT_EQ(visitor_.messages_received_, 0); - parser_.ProcessData(absl::string_view(), true); // send the FIN - EXPECT_EQ(visitor_.messages_received_, 1); - - EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); - EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); -} - TEST_P(MoqtParserTest, TwoPartMessage) { std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); // The test Object message has payload for less then half the message length, @@ -340,127 +333,32 @@ parser_.ProcessData( message->PacketSample().substr( first_data_size, message->total_message_size() - first_data_size), - false); + true); EXPECT_EQ(visitor_.messages_received_, 1); EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); - EXPECT_TRUE(visitor_.end_of_message_); - if (message_type_ == MoqtMessageType::kObject) { + if (IsObjectMessage(message_type_)) { EXPECT_EQ(visitor_.object_payload_, "foo"); } - EXPECT_FALSE(visitor_.parsing_error_.has_value()); -} - -// Send the header + some payload, pure payload, then pure payload to end the -// message. -TEST_P(MoqtParserTest, ThreePartObject) { - if (message_type_ != MoqtMessageType::kObject) { - return; - } - std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); - message->set_message_size(0); - // The test Object message has payload for less then half the message length, - // so splitting the message in half will prevent the first half from being - // processed. - parser_.ProcessData(message->PacketSample(), false); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); - EXPECT_FALSE(visitor_.end_of_message_); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), "foo"); - - // second part - parser_.ProcessData("bar", false); - EXPECT_EQ(visitor_.messages_received_, 2); - EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); - EXPECT_FALSE(visitor_.end_of_message_); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), "bar"); - - // third part includes FIN - parser_.ProcessData("deadbeef", true); - EXPECT_EQ(visitor_.messages_received_, 3); - EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), "deadbeef"); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); -} - -// Send the part of header, rest of header + payload, plus payload. -TEST_P(MoqtParserTest, ThreePartObjectFirstIncomplete) { - if (message_type_ != MoqtMessageType::kObject) { - return; - } - std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); - message->set_message_size(0); - - // first part - parser_.ProcessData(message->PacketSample().substr(0, 4), false); - EXPECT_EQ(visitor_.messages_received_, 0); - - // second part. Add padding to it. - message->set_wire_image_size(100); - parser_.ProcessData( - message->PacketSample().substr(4, message->total_message_size() - 4), - false); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); - EXPECT_FALSE(visitor_.end_of_message_); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(visitor_.object_payload_->length(), 94); - - // third part includes FIN - parser_.ProcessData("bar", true); - EXPECT_EQ(visitor_.messages_received_, 2); - EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); - EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), "bar"); EXPECT_FALSE(visitor_.parsing_error_.has_value()); } TEST_P(MoqtParserTest, OneByteAtATime) { std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); - message->set_message_size(0); - constexpr size_t kObjectPrePayloadSize = 6; + size_t kObjectPayloadSize = 3; for (size_t i = 0; i < message->total_message_size(); ++i) { - parser_.ProcessData(message->PacketSample().substr(i, 1), false); - if (message_type_ == MoqtMessageType::kGoAway && - i == message->total_message_size() - 1) { - // OnGoAway() is called before FIN. - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); - EXPECT_TRUE(visitor_.end_of_message_); - break; - } - if (message_type_ != MoqtMessageType::kObject || - i < kObjectPrePayloadSize) { - // OBJECTs will have to buffer for the first 5 bytes (until the varints - // are done). The sixth byte is a bare OBJECT header, so the parser does - // not notify the visitor. + if (!IsObjectMessage(message_type_)) { EXPECT_EQ(visitor_.messages_received_, 0); - } else { - // OBJECT payload processing. - EXPECT_EQ(visitor_.messages_received_, i - kObjectPrePayloadSize + 1); - EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - if (i == 5) { - EXPECT_EQ(visitor_.object_payload_->length(), 0); - } else { - EXPECT_EQ(visitor_.object_payload_->length(), 1); - EXPECT_EQ((*visitor_.object_payload_)[0], - message->PacketSample().substr(i, 1)[0]); - } } EXPECT_FALSE(visitor_.end_of_message_); + parser_.ProcessData(message->PacketSample().substr(i, 1), false); } - // Send FIN - parser_.ProcessData(absl::string_view(), true); - if (message_type_ == MoqtMessageType::kObject) { - EXPECT_EQ(visitor_.messages_received_, - message->total_message_size() - kObjectPrePayloadSize + 1); - } else { - EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_EQ(visitor_.messages_received_, + (IsObjectMessage(message_type_) ? (kObjectPayloadSize + 1) : 1)); + if (message_type_ == MoqtMessageType::kObjectWithoutPayloadLength) { + EXPECT_FALSE(visitor_.end_of_message_); + parser_.ProcessData(absl::string_view(), true); // Needs the FIN + EXPECT_EQ(visitor_.messages_received_, kObjectPayloadSize + 2); } EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); EXPECT_TRUE(visitor_.end_of_message_); @@ -470,196 +368,52 @@ TEST_P(MoqtParserTest, OneByteAtATimeLongerVarints) { std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); message->ExpandVarints(); - message->set_message_size(0); - constexpr size_t kObjectPrePayloadSize = 28; + size_t kObjectPayloadSize = 3; for (size_t i = 0; i < message->total_message_size(); ++i) { - parser_.ProcessData(message->PacketSample().substr(i, 1), false); - if (message_type_ == MoqtMessageType::kGoAway && - i == message->total_message_size() - 1) { - // OnGoAway() is called before FIN. - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); - EXPECT_TRUE(visitor_.end_of_message_); - break; - } - if (message_type_ != MoqtMessageType::kObject || - i < kObjectPrePayloadSize) { - // OBJECTs will have to buffer for the first 5 bytes (until the varints - // are done). The sixth byte is a bare OBJECT header, so the parser does - // not notify the visitor. + if (!IsObjectMessage(message_type_)) { EXPECT_EQ(visitor_.messages_received_, 0); - } else { - // OBJECT payload processing. - EXPECT_EQ(visitor_.messages_received_, i - kObjectPrePayloadSize + 1); - EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - if (i == 5) { - EXPECT_EQ(visitor_.object_payload_->length(), 0); - } else { - EXPECT_EQ(visitor_.object_payload_->length(), 1); - EXPECT_EQ((*visitor_.object_payload_)[0], - message->PacketSample().substr(i, 1)[0]); - } } EXPECT_FALSE(visitor_.end_of_message_); + parser_.ProcessData(message->PacketSample().substr(i, 1), false); } - // Send FIN - parser_.ProcessData(absl::string_view(), true); - if (message_type_ == MoqtMessageType::kObject) { - EXPECT_EQ(visitor_.messages_received_, - message->total_message_size() - kObjectPrePayloadSize + 1); - } else { - EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_EQ(visitor_.messages_received_, + (IsObjectMessage(message_type_) ? (kObjectPayloadSize + 1) : 1)); + if (message_type_ == MoqtMessageType::kObjectWithoutPayloadLength) { + EXPECT_FALSE(visitor_.end_of_message_); + parser_.ProcessData(absl::string_view(), true); // Needs the FIN + EXPECT_EQ(visitor_.messages_received_, kObjectPayloadSize + 2); } EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); EXPECT_TRUE(visitor_.end_of_message_); EXPECT_FALSE(visitor_.parsing_error_.has_value()); } -TEST_P(MoqtParserTest, OneByteAtATimeKnownLength) { +TEST_P(MoqtParserTest, EarlyFin) { std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); - constexpr size_t kObjectPrePayloadSize = 6; - // Send all but the last byte - for (size_t i = 0; i < message->total_message_size() - 1; ++i) { - parser_.ProcessData(message->PacketSample().substr(i, 1), false); - if (message_type_ != MoqtMessageType::kObject || - i < kObjectPrePayloadSize) { - // OBJECTs will have to buffer for the first 5 bytes (until the varints - // are done). The sixth byte is a bare OBJECT header, so the parser does - // not notify the visitor. - EXPECT_EQ(visitor_.messages_received_, 0); - } else { - // OBJECT payload processing. - EXPECT_EQ(visitor_.messages_received_, i - kObjectPrePayloadSize + 1); - EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - if (i == 5) { - EXPECT_EQ(visitor_.object_payload_->length(), 0); - } else { - EXPECT_EQ(visitor_.object_payload_->length(), 1); - EXPECT_EQ((*visitor_.object_payload_)[0], - message->PacketSample().substr(i, 1)[0]); - } - } - EXPECT_FALSE(visitor_.end_of_message_); - } - // Send last byte parser_.ProcessData( - message->PacketSample().substr(message->total_message_size() - 1, 1), - false); - if (message_type_ == MoqtMessageType::kObject) { - EXPECT_EQ(visitor_.messages_received_, - message->total_message_size() - kObjectPrePayloadSize); - EXPECT_EQ(visitor_.object_payload_->length(), 1); - EXPECT_EQ((*visitor_.object_payload_)[0], - message->PacketSample().substr(message->total_message_size() - 1, - 1)[0]); - } else { - EXPECT_EQ(visitor_.messages_received_, 1); - } - EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); - EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); -} - -TEST_P(MoqtParserTest, LengthTooShort) { - if (message_type_ == MoqtMessageType::kGoAway || - message_type_ == MoqtMessageType::kAnnounceOk) { - // GOAWAY already has length zero. ANNOUNCE_OK works for any message length. - return; - } - auto message = MakeMessage(message_type_); - if (message_type_ == MoqtMessageType::kSetup && - GetParam().perspective == quic::Perspective::IS_CLIENT) { - // Unless varints are longer than necessary, the message is only one byte - // long. - message->ExpandVarints(); - } - size_t truncate = (message_type_ == MoqtMessageType::kObject) ? 4 : 1; - message->set_message_size(message->message_size() - truncate); - parser_.ProcessData(message->PacketSample(), false); + message->PacketSample().substr(0, message->total_message_size() / 2), + true); EXPECT_EQ(visitor_.messages_received_, 0); EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, - "Not able to parse message given specified length"); + EXPECT_EQ(*visitor_.parsing_error_, "FIN after incomplete message"); } -// Buffered packets are a different code path, so test them separately. -TEST_P(MoqtParserTest, LengthTooShortInBufferedPacket) { - if (message_type_ == MoqtMessageType::kGoAway || - message_type_ == MoqtMessageType::kAnnounceOk) { - // GOAWAY already has length zero. ANNOUNCE_OK works for any message length. - return; - } - auto message = MakeMessage(message_type_); - if (message_type_ == MoqtMessageType::kSetup && - GetParam().perspective == quic::Perspective::IS_CLIENT) { - // Unless varints are longer than necessary, the message is only one byte - // long. - message->ExpandVarints(); - } - EXPECT_EQ(visitor_.messages_received_, 0); - size_t truncate = (message_type_ == MoqtMessageType::kObject) ? 5 : 2; - message->set_message_size(message->message_size() - truncate + 1); +TEST_P(MoqtParserTest, SeparateEarlyFin) { + std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); parser_.ProcessData( - message->PacketSample().substr(0, message->total_message_size() - 1), + message->PacketSample().substr(0, message->total_message_size() / 2), false); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); - // send the last byte - parser_.ProcessData( - message->PacketSample().substr(message->total_message_size() - 1, 1), - false); + parser_.ProcessData(absl::string_view(), true); EXPECT_EQ(visitor_.messages_received_, 0); EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, - "Not able to parse buffered message given specified length"); + EXPECT_EQ(*visitor_.parsing_error_, "End of stream before complete message"); } -TEST_P(MoqtParserTest, LengthTooLong) { - if (message_type_ == MoqtMessageType::kAnnounceOk || - message_type_ == MoqtMessageType::kObject || - message_type_ == MoqtMessageType::kSetup || - message_type_ == MoqtMessageType::kSubscribeRequest || - message_type_ == MoqtMessageType::kAnnounce) { - // OBJECT and ANNOUNCE_OK work for any message length. - // SETUP, SUBSCRIBE_REQUEST, and ANNOUNCE have parameters, so an additional - // byte will cause the message to be interpreted as being too short. - return; - } - auto message = MakeMessage(message_type_); - message->set_message_size(message->message_size() + 1); - parser_.ProcessData(message->PacketSample(), false); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(visitor_.messages_received_, 0); - if (message_type_ == MoqtMessageType::kGoAway) { - EXPECT_EQ(*visitor_.parsing_error_, "GOAWAY has data following"); - } else { - EXPECT_EQ(*visitor_.parsing_error_, "Specified message length too long"); - } -} - -TEST_P(MoqtParserTest, LengthExceedsBufferSize) { - if (message_type_ == MoqtMessageType::kObject) { - // OBJECT works for any length. - return; - } - auto message = MakeMessage(message_type_); - message->set_message_size(kMaxMessageHeaderSize + 1); - parser_.ProcessData(message->PacketSample(), false); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(visitor_.messages_received_, 0); - if (message_type_ == MoqtMessageType::kGoAway) { - EXPECT_EQ(*visitor_.parsing_error_, "GOAWAY has data following"); - } else { - EXPECT_EQ(*visitor_.parsing_error_, "Message too long"); - } -} - -// Tests for message-specific error cases. -class MoqtParserErrorTest : public quic::test::QuicTest { +// Tests for message-specific error cases, and behaviors for a single message +// type. +class MoqtMessageSpecificTest : public quic::test::QuicTest { public: - MoqtParserErrorTest() {} + MoqtMessageSpecificTest() {} MoqtParserTestVisitor visitor_; @@ -667,10 +421,92 @@ static constexpr bool kRawQuic = false; }; -TEST_F(MoqtParserErrorTest, SetupRoleAppearsTwice) { - MoqtParser parser(quic::Perspective::IS_SERVER, kRawQuic, visitor_); +TEST_F(MoqtMessageSpecificTest, ObjectNoLengthSeparateFin) { + // OBJECT can return on an unknown-length message even without receiving a + // FIN. + MoqtParser parser(kRawQuic, visitor_); + auto message = std::make_unique<ObjectMessageWithoutLength>(); + parser.ProcessData(message->PacketSample(), false); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + EXPECT_EQ(*(visitor_.object_payload_), "foo"); + EXPECT_FALSE(visitor_.end_of_message_); + + parser.ProcessData(absl::string_view(), true); // send the FIN + EXPECT_EQ(visitor_.messages_received_, 2); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + EXPECT_EQ(*(visitor_.object_payload_), ""); + EXPECT_TRUE(visitor_.end_of_message_); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); +} + +// Send the header + some payload, pure payload, then pure payload to end the +// message. +TEST_F(MoqtMessageSpecificTest, ThreePartObject) { + MoqtParser parser(kRawQuic, visitor_); + auto message = std::make_unique<ObjectMessageWithoutLength>(); + parser.ProcessData(message->PacketSample(), false); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_FALSE(visitor_.end_of_message_); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + EXPECT_EQ(*(visitor_.object_payload_), "foo"); + + // second part + parser.ProcessData("bar", false); + EXPECT_EQ(visitor_.messages_received_, 2); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_FALSE(visitor_.end_of_message_); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + EXPECT_EQ(*(visitor_.object_payload_), "bar"); + + // third part includes FIN + parser.ProcessData("deadbeef", true); + EXPECT_EQ(visitor_.messages_received_, 3); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.end_of_message_); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + EXPECT_EQ(*(visitor_.object_payload_), "deadbeef"); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); +} + +// Send the part of header, rest of header + payload, plus payload. +TEST_F(MoqtMessageSpecificTest, ThreePartObjectFirstIncomplete) { + MoqtParser parser(kRawQuic, visitor_); + auto message = std::make_unique<ObjectMessageWithoutLength>(); + + // first part + parser.ProcessData(message->PacketSample().substr(0, 4), false); + EXPECT_EQ(visitor_.messages_received_, 0); + + // second part. Add padding to it. + message->set_wire_image_size(100); + parser.ProcessData( + message->PacketSample().substr(4, message->total_message_size() - 4), + false); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_FALSE(visitor_.end_of_message_); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + EXPECT_EQ(visitor_.object_payload_->length(), 95); + + // third part includes FIN + parser.ProcessData("bar", true); + EXPECT_EQ(visitor_.messages_received_, 2); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.end_of_message_); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + EXPECT_EQ(*(visitor_.object_payload_), "bar"); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); +} + +TEST_F(MoqtMessageSpecificTest, SetupRoleAppearsTwice) { + MoqtParser parser(kRawQuic, visitor_); char setup[] = { - 0x01, 0x0e, 0x02, 0x01, 0x02, // versions + 0x40, 0x40, 0x02, 0x01, 0x02, // versions + 0x03, // 3 params 0x00, 0x01, 0x03, // role = both 0x00, 0x01, 0x03, // role = both 0x01, 0x03, 0x66, 0x6f, 0x6f // path = "foo" @@ -681,36 +517,39 @@ EXPECT_EQ(*visitor_.parsing_error_, "ROLE parameter appears twice in SETUP"); } -TEST_F(MoqtParserErrorTest, SetupRoleIsMissing) { - MoqtParser parser(quic::Perspective::IS_SERVER, kRawQuic, visitor_); +TEST_F(MoqtMessageSpecificTest, SetupRoleIsMissing) { + MoqtParser parser(kRawQuic, visitor_); char setup[] = { - 0x01, 0x08, 0x02, 0x01, 0x02, // versions = 1, 2 + 0x40, 0x40, 0x02, 0x01, 0x02, // versions = 1, 2 + 0x01, // 1 param 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" }; parser.ProcessData(absl::string_view(setup, sizeof(setup)), false); EXPECT_EQ(visitor_.messages_received_, 0); EXPECT_TRUE(visitor_.parsing_error_.has_value()); EXPECT_EQ(*visitor_.parsing_error_, - "ROLE SETUP parameter missing from Client message"); + "ROLE parameter missing from CLIENT_SETUP message"); } -TEST_F(MoqtParserErrorTest, SetupPathFromServer) { - MoqtParser parser(quic::Perspective::IS_CLIENT, kRawQuic, visitor_); +TEST_F(MoqtMessageSpecificTest, SetupPathFromServer) { + MoqtParser parser(kRawQuic, visitor_); char setup[] = { - 0x01, 0x06, + 0x40, 0x41, 0x01, // version = 1 + 0x01, // 1 param 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" }; parser.ProcessData(absl::string_view(setup, sizeof(setup)), false); EXPECT_EQ(visitor_.messages_received_, 0); EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, "PATH parameter sent by server in SETUP"); + EXPECT_EQ(*visitor_.parsing_error_, "PATH parameter in SERVER_SETUP"); } -TEST_F(MoqtParserErrorTest, SetupPathAppearsTwice) { - MoqtParser parser(quic::Perspective::IS_SERVER, kRawQuic, visitor_); +TEST_F(MoqtMessageSpecificTest, SetupPathAppearsTwice) { + MoqtParser parser(kRawQuic, visitor_); char setup[] = { - 0x01, 0x10, 0x02, 0x01, 0x02, // versions = 1, 2 + 0x40, 0x40, 0x02, 0x01, 0x02, // versions = 1, 2 + 0x03, // 3 params 0x00, 0x01, 0x03, // role = both 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" @@ -718,13 +557,15 @@ parser.ProcessData(absl::string_view(setup, sizeof(setup)), false); EXPECT_EQ(visitor_.messages_received_, 0); EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, "PATH parameter appears twice in SETUP"); + EXPECT_EQ(*visitor_.parsing_error_, + "PATH parameter appears twice in CLIENT_SETUP"); } -TEST_F(MoqtParserErrorTest, SetupPathOverWebtrans) { - MoqtParser parser(quic::Perspective::IS_SERVER, kWebTrans, visitor_); +TEST_F(MoqtMessageSpecificTest, SetupPathOverWebtrans) { + MoqtParser parser(kWebTrans, visitor_); char setup[] = { - 0x01, 0x0b, 0x02, 0x01, 0x02, // versions = 1, 2 + 0x40, 0x40, 0x02, 0x01, 0x02, // versions = 1, 2 + 0x02, // 2 params 0x00, 0x01, 0x03, // role = both 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" }; @@ -735,10 +576,11 @@ "WebTransport connection is using PATH parameter in SETUP"); } -TEST_F(MoqtParserErrorTest, SetupPathMissing) { - MoqtParser parser(quic::Perspective::IS_SERVER, kRawQuic, visitor_); +TEST_F(MoqtMessageSpecificTest, SetupPathMissing) { + MoqtParser parser(kRawQuic, visitor_); char setup[] = { - 0x01, 0x06, 0x02, 0x01, 0x02, // versions = 1, 2 + 0x40, 0x40, 0x02, 0x01, 0x02, // versions = 1, 2 + 0x01, // 1 param 0x00, 0x01, 0x03, // role = both }; parser.ProcessData(absl::string_view(setup, sizeof(setup)), false); @@ -748,63 +590,17 @@ "PATH SETUP parameter missing from Client message over QUIC"); } -TEST_F(MoqtParserErrorTest, SetupRoleTooLong) { - MoqtParser parser(quic::Perspective::IS_SERVER, kRawQuic, visitor_); - char setup[] = { - 0x01, 0x0e, 0x02, 0x01, 0x02, // versions - // role = both - 0x00, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x01, - 0x03, 0x66, 0x6f, 0x6f // path = "foo" - }; - parser.ProcessData(absl::string_view(setup, sizeof(setup)), false); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, - "Cannot parse explicit length integers longer than 8 bytes"); -} - -TEST_F(MoqtParserErrorTest, SubscribeRequestGroupSequenceTwice) { - MoqtParser parser(quic::Perspective::IS_SERVER, kWebTrans, visitor_); +TEST_F(MoqtMessageSpecificTest, SubscribeRequestAuthorizationInfoTwice) { + MoqtParser parser(kWebTrans, visitor_); char subscribe_request[] = { - 0x03, 0x12, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" - 0x00, 0x01, 0x01, // group_sequence = 1 - 0x00, 0x01, 0x01, // group_sequence = 1 - 0x01, 0x01, 0x02, // object_sequence = 2 - 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" - }; - parser.ProcessData( - absl::string_view(subscribe_request, sizeof(subscribe_request)), false); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, - "GROUP_SEQUENCE parameter appears twice in SUBSCRIBE_REQUEST"); -} - -TEST_F(MoqtParserErrorTest, SubscribeRequestObjectSequenceTwice) { - MoqtParser parser(quic::Perspective::IS_SERVER, kWebTrans, visitor_); - char subscribe_request[] = { - 0x03, 0x12, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" - 0x00, 0x01, 0x01, // group_sequence = 1 - 0x01, 0x01, 0x02, // object_sequence = 2 - 0x01, 0x01, 0x02, // object_sequence = 2 - 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" - }; - parser.ProcessData( - absl::string_view(subscribe_request, sizeof(subscribe_request)), false); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, - "OBJECT_SEQUENCE parameter appears twice in SUBSCRIBE_REQUEST"); -} - -TEST_F(MoqtParserErrorTest, SubscribeRequestAuthorizationInfoTwice) { - MoqtParser parser(quic::Perspective::IS_SERVER, kWebTrans, visitor_); - char subscribe_request[] = { - 0x03, 0x14, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" - 0x00, 0x01, 0x01, // group_sequence = 1 - 0x01, 0x01, 0x02, // object_sequence = 2 - 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" - 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + 0x03, 0x03, 0x66, 0x6f, 0x6f, // full_track_name = "foo" + 0x02, 0x04, // start_group = 4 (relative previous) + 0x01, 0x01, // start_object = 1 (absolute) + 0x00, // end_group = none + 0x00, // end_object = none + 0x02, // two params + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" }; parser.ProcessData( absl::string_view(subscribe_request, sizeof(subscribe_request)), false); @@ -814,42 +610,13 @@ "AUTHORIZATION_INFO parameter appears twice in SUBSCRIBE_REQUEST"); } -TEST_F(MoqtParserErrorTest, AnnounceGroupSequenceTwice) { - MoqtParser parser(quic::Perspective::IS_SERVER, kWebTrans, visitor_); +TEST_F(MoqtMessageSpecificTest, AnnounceAuthorizationInfoTwice) { + MoqtParser parser(kWebTrans, visitor_); char announce[] = { - 0x06, 0x0f, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" - 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" - 0x00, 0x01, 0x01, // group_sequence = 1 - 0x00, 0x01, 0x01, // group_sequence = 1 - }; - parser.ProcessData(absl::string_view(announce, sizeof(announce)), false); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, - "GROUP_SEQUENCE parameter appears twice in ANNOUNCE"); -} - -TEST_F(MoqtParserErrorTest, AnnounceObjectSequenceTwice) { - MoqtParser parser(quic::Perspective::IS_SERVER, kWebTrans, visitor_); - char announce[] = { - 0x06, 0x0e, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" - 0x01, 0x01, 0x02, // object_sequence = 2 - 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" - 0x01, 0x01, 0x02, // object_sequence = 2 - }; - parser.ProcessData(absl::string_view(announce, sizeof(announce)), false); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, - "OBJECT_SEQUENCE parameter appears twice in ANNOUNCE"); -} - -TEST_F(MoqtParserErrorTest, AnnounceAuthorizationInfoTwice) { - MoqtParser parser(quic::Perspective::IS_SERVER, kWebTrans, visitor_); - char announce[] = { - 0x06, 0x0e, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" - 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" - 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + 0x06, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x02, // 2 params + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" }; parser.ProcessData(absl::string_view(announce, sizeof(announce)), false); EXPECT_EQ(visitor_.messages_received_, 0); @@ -858,4 +625,208 @@ "AUTHORIZATION_INFO parameter appears twice in ANNOUNCE"); } +TEST_F(MoqtMessageSpecificTest, FinMidPayload) { + MoqtParser parser(kRawQuic, visitor_); + auto message = std::make_unique<ObjectMessageWithLength>(); + parser.ProcessData( + message->PacketSample().substr(0, message->total_message_size() - 1), + true); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, "Received FIN mid-payload"); +} + +TEST_F(MoqtMessageSpecificTest, PartialPayloadThenFin) { + MoqtParser parser(kRawQuic, visitor_); + auto message = std::make_unique<ObjectMessageWithLength>(); + parser.ProcessData( + message->PacketSample().substr(0, message->total_message_size() - 1), + false); + parser.ProcessData(absl::string_view(), true); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "End of stream before complete OBJECT PAYLOAD"); +} + +TEST_F(MoqtMessageSpecificTest, DataAfterFin) { + MoqtParser parser(kRawQuic, visitor_); + parser.ProcessData(absl::string_view(), true); // Find FIN + parser.ProcessData("foo", false); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, "Data after end of stream"); +} + +TEST_F(MoqtMessageSpecificTest, Setup2KB) { + MoqtParser parser(kRawQuic, visitor_); + char big_message[2 * kMaxMessageHeaderSize]; + quic::QuicDataWriter writer(sizeof(big_message), big_message); + writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kServerSetup)); + writer.WriteVarInt62(0x1); // version + writer.WriteVarInt62(0x1); // num_params + writer.WriteVarInt62(0xbeef); // unknown param + writer.WriteVarInt62(kMaxMessageHeaderSize); // very long parameter + writer.WriteRepeatedByte(0x04, kMaxMessageHeaderSize); + // Send incomplete message + parser.ProcessData(absl::string_view(big_message, writer.length() - 1), + false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, "Cannot parse non-OBJECT messages > 2KB"); +} + +TEST_F(MoqtMessageSpecificTest, UnknownMessageType) { + MoqtParser parser(kRawQuic, visitor_); + char message[4]; + quic::QuicDataWriter writer(sizeof(message), message); + writer.WriteVarInt62(0xbeef); // unknown message type + parser.ProcessData(absl::string_view(message, writer.length()), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, "Unknown message type"); +} + +TEST_F(MoqtMessageSpecificTest, StartGroupIsNone) { + MoqtParser parser(kRawQuic, visitor_); + char subscribe_request[] = { + 0x03, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" + 0x00, // start_group = none + 0x01, 0x01, // start_object = 1 (absolute) + 0x00, // end_group = none + 0x00, // end_object = none + 0x01, // 1 parameter + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + parser.ProcessData( + absl::string_view(subscribe_request, sizeof(subscribe_request)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "START_GROUP must not be None in SUBSCRIBE_REQUEST"); +} + +TEST_F(MoqtMessageSpecificTest, StartObjectIsNone) { + MoqtParser parser(kRawQuic, visitor_); + char subscribe_request[] = { + 0x03, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" + 0x02, 0x04, // start_group = 4 (relative previous) + 0x00, // start_object = none + 0x00, // end_group = none + 0x00, // end_object = none + 0x01, // 1 parameter + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + parser.ProcessData( + absl::string_view(subscribe_request, sizeof(subscribe_request)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "START_OBJECT must not be None in SUBSCRIBE_REQUEST"); +} + +TEST_F(MoqtMessageSpecificTest, EndGroupIsNoneEndObjectIsNoNone) { + MoqtParser parser(kRawQuic, visitor_); + char subscribe_request[] = { + 0x03, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" + 0x02, 0x04, // start_group = 4 (relative previous) + 0x01, 0x01, // start_object = 1 (absolute) + 0x00, // end_group = none + 0x01, 0x01, // end_object = 1 (absolute) + 0x01, // 1 parameter + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + parser.ProcessData( + absl::string_view(subscribe_request, sizeof(subscribe_request)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "SUBSCRIBE_REQUEST end_group and end_object must be both None " + "or both non_None"); +} + +TEST_F(MoqtMessageSpecificTest, AllMessagesTogether) { + char buffer[5000]; + MoqtParser parser(kRawQuic, visitor_); + size_t write = 0; + size_t read = 0; + int fully_received = 0; + std::unique_ptr<TestMessageBase> prev_message = nullptr; + for (MoqtMessageType type : message_types) { + // Each iteration, process from the halfway point of one message to the + // halfway point of the next. + if (type == MoqtMessageType::kObjectWithoutPayloadLength) { + continue; // Cannot be followed with another message. + } + std::unique_ptr<TestMessageBase> message; + switch (type) { + case MoqtMessageType::kObjectWithPayloadLength: + message = std::make_unique<ObjectMessageWithLength>(); + break; + case MoqtMessageType::kObjectWithoutPayloadLength: + continue; // Cannot be followed with another message; + case MoqtMessageType::kClientSetup: + message = std::make_unique<ClientSetupMessage>(kRawQuic); + break; + case MoqtMessageType::kServerSetup: + message = std::make_unique<ClientSetupMessage>(kRawQuic); + break; + case MoqtMessageType::kSubscribeRequest: + message = std::make_unique<SubscribeRequestMessage>(); + break; + case MoqtMessageType::kSubscribeOk: + message = std::make_unique<SubscribeOkMessage>(); + break; + case MoqtMessageType::kSubscribeError: + message = std::make_unique<SubscribeErrorMessage>(); + break; + case MoqtMessageType::kUnsubscribe: + message = std::make_unique<UnsubscribeMessage>(); + break; + case MoqtMessageType::kSubscribeFin: + message = std::make_unique<SubscribeFinMessage>(); + break; + case MoqtMessageType::kSubscribeRst: + message = std::make_unique<SubscribeRstMessage>(); + break; + case MoqtMessageType::kAnnounce: + message = std::make_unique<AnnounceMessage>(); + break; + case moqt::MoqtMessageType::kAnnounceOk: + message = std::make_unique<AnnounceOkMessage>(); + break; + case moqt::MoqtMessageType::kAnnounceError: + message = std::make_unique<AnnounceErrorMessage>(); + break; + case moqt::MoqtMessageType::kUnannounce: + message = std::make_unique<UnannounceMessage>(); + break; + case moqt::MoqtMessageType::kGoAway: + message = std::make_unique<GoAwayMessage>(); + break; + default: + message = nullptr; + break; + } + memcpy(buffer + write, message->PacketSample().data(), + message->total_message_size()); + size_t new_read = write + message->total_message_size() / 2; + parser.ProcessData(absl::string_view(buffer + read, new_read - read), + false); + EXPECT_EQ(visitor_.messages_received_, fully_received); + if (prev_message != nullptr) { + EXPECT_TRUE( + prev_message->EqualFieldValues(visitor_.last_message_.value())); + } + fully_received++; + read = new_read; + write += message->total_message_size(); + prev_message = std::move(message); + } + // Deliver the rest + parser.ProcessData(absl::string_view(buffer + read, write - read), true); + EXPECT_EQ(visitor_.messages_received_, fully_received); + EXPECT_TRUE(prev_message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); +} + } // namespace moqt::test
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index f8b89ea..9530082 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -9,6 +9,7 @@ #include <utility> #include <vector> +#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "quiche/quic/core/quic_types.h" @@ -40,13 +41,14 @@ control_stream->SetVisitor(std::make_unique<Stream>( this, control_stream, /*is_control_stream=*/true)); control_stream_ = control_stream->GetStreamId(); - MoqtSetup setup = MoqtSetup{ + MoqtClientSetup setup = MoqtClientSetup{ .supported_versions = std::vector<MoqtVersion>{parameters_.version}, - .role = MoqtRole::kBoth}; + .role = MoqtRole::kBoth, + }; if (!parameters_.using_webtrans) { setup.path = parameters_.path; } - quiche::QuicheBuffer serialized_setup = framer_.SerializeSetup(setup); + quiche::QuicheBuffer serialized_setup = framer_.SerializeClientSetup(setup); bool success = control_stream->Write(serialized_setup.AsStringView()); if (!success) { Error("Failed to write client SETUP message"); @@ -120,7 +122,7 @@ } } -void MoqtSession::Stream::OnSetupMessage(const MoqtSetup& message) { +void MoqtSession::Stream::OnClientSetupMessage(const MoqtClientSetup& message) { if (is_control_stream_.has_value()) { if (!*is_control_stream_) { session_->Error("Received SETUP on non-control stream"); @@ -129,6 +131,10 @@ } else { is_control_stream_ = true; } + if (perspective() == Perspective::IS_CLIENT) { + session_->Error("Received CLIENT_SETUP from server"); + return; + } if (absl::c_find(message.supported_versions, session_->parameters_.version) == message.supported_versions.end()) { session_->Error(absl::StrCat("Version mismatch: expected 0x", @@ -137,12 +143,11 @@ } QUICHE_DLOG(INFO) << ENDPOINT << "Received the SETUP message"; if (session_->parameters_.perspective == Perspective::IS_SERVER) { - MoqtSetup response = - MoqtSetup{.supported_versions = - std::vector<MoqtVersion>{session_->parameters_.version}, - .role = MoqtRole::kBoth}; + MoqtServerSetup response; + response.selected_version = session_->parameters_.version; + response.role = MoqtRole::kBoth; bool success = stream_->Write( - session_->framer_.SerializeSetup(response).AsStringView()); + session_->framer_.SerializeServerSetup(response).AsStringView()); if (!success) { session_->Error("Failed to write server SETUP message"); return; @@ -153,6 +158,29 @@ std::move(session_->session_established_callback_)(); } +void MoqtSession::Stream::OnServerSetupMessage(const MoqtServerSetup& message) { + if (is_control_stream_.has_value()) { + if (!*is_control_stream_) { + session_->Error("Received SETUP on non-control stream"); + return; + } + } else { + is_control_stream_ = true; + } + if (perspective() == Perspective::IS_SERVER) { + session_->Error("Received SERVER_SETUP from client"); + return; + } + if (message.selected_version != session_->parameters_.version) { + session_->Error(absl::StrCat("Version mismatch: expected 0x", + absl::Hex(session_->parameters_.version))); + return; + } + QUICHE_DLOG(INFO) << ENDPOINT << "Received the SETUP message"; + // TODO: handle role and path. + std::move(session_->session_established_callback_)(); +} + void MoqtSession::Stream::OnParsingError(absl::string_view reason) { session_->Error(absl::StrCat("Parse error: ", reason)); }
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index 0faa8c0..dc98b2c 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -34,7 +34,7 @@ parameters_(parameters), session_established_callback_(std::move(session_established_callback)), session_terminated_callback_(std::move(session_terminated_callback)), - framer_(quiche::SimpleBufferAllocator::Get(), parameters.perspective, + framer_(quiche::SimpleBufferAllocator::Get(), parameters.using_webtrans) {} // webtransport::SessionVisitor implementation. @@ -58,14 +58,12 @@ Stream(MoqtSession* session, webtransport::Stream* stream) : session_(session), stream_(stream), - parser_(session->parameters_.perspective, - session->parameters_.using_webtrans, *this) {} + parser_(session->parameters_.using_webtrans, *this) {} Stream(MoqtSession* session, webtransport::Stream* stream, bool is_control_stream) : session_(session), stream_(stream), - parser_(session->parameters_.perspective, - session->parameters_.using_webtrans, *this), + parser_(session->parameters_.using_webtrans, *this), is_control_stream_(is_control_stream) {} // webtransport::StreamVisitor implementation. @@ -78,17 +76,20 @@ // MoqtParserVisitor implementation. void OnObjectMessage(const MoqtObject& message, absl::string_view payload, bool end_of_message) override {} - void OnSetupMessage(const MoqtSetup& message) override; + void OnClientSetupMessage(const MoqtClientSetup& message) override; + void OnServerSetupMessage(const MoqtServerSetup& message) override; void OnSubscribeRequestMessage( const MoqtSubscribeRequest& message) override {} void OnSubscribeOkMessage(const MoqtSubscribeOk& message) override {} void OnSubscribeErrorMessage(const MoqtSubscribeError& message) override {} void OnUnsubscribeMessage(const MoqtUnsubscribe& message) override {} + void OnSubscribeFinMessage(const MoqtSubscribeFin& message) override {} + void OnSubscribeRstMessage(const MoqtSubscribeRst& message) override {} void OnAnnounceMessage(const MoqtAnnounce& message) override {} void OnAnnounceOkMessage(const MoqtAnnounceOk& message) override {} void OnAnnounceErrorMessage(const MoqtAnnounceError& message) override {} void OnUnannounceMessage(const MoqtUnannounce& message) override {} - void OnGoAwayMessage() override {} + void OnGoAwayMessage(const MoqtGoAway& message) override {} void OnParsingError(absl::string_view reason) override; quic::Perspective perspective() const {
diff --git a/quiche/quic/moqt/test_tools/moqt_test_message.h b/quiche/quic/moqt/test_tools/moqt_test_message.h index 7cf4803..b5d1020 100644 --- a/quiche/quic/moqt/test_tools/moqt_test_message.h +++ b/quiche/quic/moqt/test_tools/moqt_test_message.h
@@ -5,7 +5,6 @@ #ifndef QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_TEST_MESSAGE_H_ #define QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_TEST_MESSAGE_H_ -#include <algorithm> #include <cstddef> #include <cstdint> #include <cstring> @@ -34,28 +33,16 @@ virtual ~TestMessageBase() = default; MoqtMessageType message_type() const { return message_type_; } - typedef absl::variant<MoqtSetup, MoqtObject, MoqtSubscribeRequest, - MoqtSubscribeOk, MoqtSubscribeError, MoqtUnsubscribe, - MoqtAnnounce, MoqtAnnounceOk, MoqtAnnounceError, - MoqtUnannounce, MoqtGoAway> + typedef absl::variant<MoqtClientSetup, MoqtServerSetup, MoqtObject, + MoqtSubscribeRequest, MoqtSubscribeOk, + MoqtSubscribeError, MoqtUnsubscribe, MoqtSubscribeFin, + MoqtSubscribeRst, MoqtAnnounce, MoqtAnnounceOk, + MoqtAnnounceError, MoqtUnannounce, MoqtGoAway> MessageStructuredData; // The total actual size of the message. size_t total_message_size() const { return wire_image_size_; } - // The message size indicated in the second varint in every message. - size_t message_size() const { - quic::QuicDataReader reader(PacketSample()); - uint64_t value; - if (!reader.ReadVarInt62(&value)) { - return 0; - } - if (!reader.ReadVarInt62(&value)) { - return 0; - } - return value; - } - absl::string_view PacketSample() const { return absl::string_view(wire_image_, wire_image_size_); } @@ -67,37 +54,6 @@ // Returns a copy of the structured data for the message. virtual MessageStructuredData structured_data() const = 0; - // Sets the message length field. If |message_size| == 0, just change the - // field in the wire image. If another value, this will either truncate the - // message or increase its length (which adds uninitialized bytes). This can - // be useful for playing with different Object Payload lengths, for example. - void set_message_size(uint64_t message_size) { - char new_wire_image[sizeof(wire_image_)]; - quic::QuicDataReader reader(PacketSample()); - quic::QuicDataWriter writer(sizeof(new_wire_image), new_wire_image); - uint64_t type; - auto field_size = reader.PeekVarInt62Length(); - reader.ReadVarInt62(&type); - writer.WriteVarInt62WithForcedLength( - type, std::max(field_size, writer.GetVarInt62Len(type))); - uint64_t original_length; - field_size = reader.PeekVarInt62Length(); - reader.ReadVarInt62(&original_length); - // Try to preserve the original field length, unless it's too small. - writer.WriteVarInt62WithForcedLength( - message_size, - std::max(field_size, writer.GetVarInt62Len(message_size))); - writer.WriteStringPiece(reader.PeekRemainingPayload()); - memcpy(wire_image_, new_wire_image, writer.length()); - wire_image_size_ = writer.length(); - if (message_size > original_length) { - wire_image_size_ += (message_size - original_length); - } - if (message_size > 0 && message_size < original_length) { - wire_image_size_ -= (original_length - message_size); - } - } - // Compares |values| to the derived class's structured data to make sure // they are equal. virtual bool EqualFieldValues(MessageStructuredData& values) const = 0; @@ -123,8 +79,6 @@ quic::QuicDataReader reader( absl::string_view(wire_image_, wire_image_size_)); quic::QuicDataWriter writer(sizeof(new_wire_image), new_wire_image); - size_t message_length = 0; - int item = 0; size_t i = 0; while (!reader.IsDoneReading()) { if (i >= varints.length() || varints[i++] == '-') { @@ -134,44 +88,31 @@ continue; } uint64_t value; - item++; reader.ReadVarInt62(&value); writer.WriteVarInt62WithForcedLength( value, static_cast<quiche::QuicheVariableLengthIntegerLength>( next_varint_len)); - if (item == 2) { - // this is the message length field. - message_length = value; - } next_varint_len *= 2; if (next_varint_len == 16) { next_varint_len = 2; } } - if (message_length > 0) { - // Update message length. Based on the progression of next_varint_len, - // the message_type is 2 bytes and message_length is 4 bytes. - message_length = writer.length() - 6; - auto new_writer = quic::QuicDataWriter(4, (char*)&new_wire_image[2]); - new_writer.WriteVarInt62WithForcedLength( - message_length, - static_cast<quiche::QuicheVariableLengthIntegerLength>(4)); - } memcpy(wire_image_, new_wire_image, writer.length()); wire_image_size_ = writer.length(); } - private: + protected: MoqtMessageType message_type_; + + private: char wire_image_[kMaxMessageHeaderSize + 20]; size_t wire_image_size_; }; +// Base class for the two subtypes of Object Message. class QUICHE_NO_EXPORT ObjectMessage : public TestMessageBase { public: - ObjectMessage() : TestMessageBase(MoqtMessageType::kObject) { - SetWireImage(raw_packet_, sizeof(raw_packet_)); - } + ObjectMessage(MoqtMessageType type) : TestMessageBase(type) {} bool EqualFieldValues(MessageStructuredData& values) const override { auto cast = std::get<MoqtObject>(values); @@ -202,97 +143,185 @@ return TestMessageBase::MessageStructuredData(object_); } - private: - uint8_t raw_packet_[9] = { - 0x00, 0x07, 0x04, 0x05, 0x06, 0x07, // varints - 0x66, 0x6f, 0x6f, // payload = "foo" - }; + protected: MoqtObject object_ = { /*track_id=*/4, /*group_sequence=*/5, /*object_sequence=*/6, /*object_send_order=*/7, + /*payload_length=*/absl::nullopt, }; }; -class QUICHE_NO_EXPORT SetupMessage : public TestMessageBase { +class QUICHE_NO_EXPORT ObjectMessageWithLength : public ObjectMessage { public: - explicit SetupMessage(bool client_parser, bool webtrans) - : TestMessageBase(MoqtMessageType::kSetup), client_(client_parser) { - if (client_parser) { - SetWireImage(server_raw_packet_, sizeof(server_raw_packet_)); + ObjectMessageWithLength() + : ObjectMessage(MoqtMessageType::kObjectWithPayloadLength) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + object_.payload_length = payload_length_; + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtObject>(values); + if (cast.payload_length != payload_length_) { + QUIC_LOG(INFO) << "OBJECT Payload Length mismatch"; + return false; + } + return ObjectMessage::EqualFieldValues(values); + } + + void ExpandVarints() override { + ExpandVarintsImpl("vvvvvv"); // first six fields are varints + } + + MessageStructuredData structured_data() const override { + return TestMessageBase::MessageStructuredData(object_); + } + + private: + uint8_t raw_packet_[9] = { + 0x00, 0x04, 0x05, 0x06, 0x07, // varints + 0x03, 0x66, 0x6f, 0x6f, // payload = "foo" + }; + absl::optional<uint64_t> payload_length_ = 3; +}; + +class QUICHE_NO_EXPORT ObjectMessageWithoutLength : public ObjectMessage { + public: + ObjectMessageWithoutLength() + : ObjectMessage(MoqtMessageType::kObjectWithoutPayloadLength) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtObject>(values); + if (cast.payload_length != absl::nullopt) { + QUIC_LOG(INFO) << "OBJECT Payload Length mismatch"; + return false; + } + return ObjectMessage::EqualFieldValues(values); + } + + void ExpandVarints() override { + ExpandVarintsImpl("vvvvv"); // first six fields are varints + } + + MessageStructuredData structured_data() const override { + return TestMessageBase::MessageStructuredData(object_); + } + + private: + uint8_t raw_packet_[8] = { + 0x02, 0x04, 0x05, 0x06, 0x07, // varints + 0x66, 0x6f, 0x6f, // payload = "foo" + }; +}; + +class QUICHE_NO_EXPORT ClientSetupMessage : public TestMessageBase { + public: + explicit ClientSetupMessage(bool webtrans) + : TestMessageBase(MoqtMessageType::kClientSetup) { + if (webtrans) { + // Should not send PATH. + client_setup_.path = absl::nullopt; + raw_packet_[5] = 0x01; // only one parameter + SetWireImage(raw_packet_, sizeof(raw_packet_) - 5); } else { - SetWireImage(client_raw_packet_, sizeof(client_raw_packet_)); - if (webtrans) { - // Should not send PATH. - set_message_size(message_size() - 5); - client_setup_.path = absl::nullopt; - } + SetWireImage(raw_packet_, sizeof(raw_packet_)); } } bool EqualFieldValues(MessageStructuredData& values) const override { - auto cast = std::get<MoqtSetup>(values); - const MoqtSetup* compare = client_ ? &server_setup_ : &client_setup_; - if (cast.supported_versions.size() != compare->supported_versions.size()) { - QUIC_LOG(INFO) << "SETUP number of supported versions mismatch"; + auto cast = std::get<MoqtClientSetup>(values); + if (cast.supported_versions.size() != + client_setup_.supported_versions.size()) { + QUIC_LOG(INFO) << "CLIENT_SETUP number of supported versions mismatch"; return false; } for (uint64_t i = 0; i < cast.supported_versions.size(); ++i) { // Listed versions are 1 and 2, in that order. - if (cast.supported_versions[i] != compare->supported_versions[i]) { - QUIC_LOG(INFO) << "SETUP supported version mismatch"; + if (cast.supported_versions[i] != client_setup_.supported_versions[i]) { + QUIC_LOG(INFO) << "CLIENT_SETUP supported version mismatch"; return false; } } - if (cast.role != compare->role) { - QUIC_LOG(INFO) << "SETUP role mismatch"; + if (cast.role != client_setup_.role) { + QUIC_LOG(INFO) << "CLIENT_SETUP role mismatch"; return false; } - if (cast.path != compare->path) { - QUIC_LOG(INFO) << "SETUP path mismatch"; + if (cast.path != client_setup_.path) { + QUIC_LOG(INFO) << "CLIENT_SETUP path mismatch"; return false; } return true; } void ExpandVarints() override { - if (client_) { - ExpandVarintsImpl("vvvvvvv-vv---"); // skip one byte for Role value + if (client_setup_.path.has_value()) { + ExpandVarintsImpl("--vvvvvv-vv---"); + // first two bytes are already a 2B varint. Also, don't expand parameter + // varints because that messes up the parameter length field. } else { - ExpandVarintsImpl("vvv"); // all three are varints + ExpandVarintsImpl("--vvvvvv-"); } } MessageStructuredData structured_data() const override { - if (client_) { - return TestMessageBase::MessageStructuredData(server_setup_); - } return TestMessageBase::MessageStructuredData(client_setup_); } private: - bool client_; - uint8_t client_raw_packet_[13] = { - 0x01, 0x0b, 0x02, 0x01, 0x02, // versions - 0x00, 0x01, 0x03, // role = both - 0x01, 0x03, 0x66, 0x6f, 0x6f // path = "foo" + uint8_t raw_packet_[14] = { + 0x40, 0x40, // type + 0x02, 0x01, 0x02, // versions + 0x02, // 2 parameters + 0x00, 0x01, 0x03, // role = both + 0x01, 0x03, 0x66, 0x6f, 0x6f // path = "foo" }; - uint8_t server_raw_packet_[3] = { - 0x01, 0x01, - 0x01, // version - }; - MoqtSetup client_setup_ = { + MoqtClientSetup client_setup_ = { /*supported_versions=*/std::vector<MoqtVersion>( {static_cast<MoqtVersion>(1), static_cast<MoqtVersion>(2)}), /*role=*/MoqtRole::kBoth, /*path=*/"foo", }; - MoqtSetup server_setup_ = { - /*supported_versions=*/std::vector<MoqtVersion>( - {static_cast<MoqtVersion>(1)}), +}; + +class QUICHE_NO_EXPORT ServerSetupMessage : public TestMessageBase { + public: + explicit ServerSetupMessage() + : TestMessageBase(MoqtMessageType::kServerSetup) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtServerSetup>(values); + if (cast.selected_version != server_setup_.selected_version) { + QUIC_LOG(INFO) << "SERVER_SETUP selected version mismatch"; + return false; + } + if (cast.role != server_setup_.role) { + QUIC_LOG(INFO) << "SERVER_SETUP role mismatch"; + return false; + } + return true; + } + + void ExpandVarints() override { + ExpandVarintsImpl("--v"); // first two bytes are already a 2b varint + } + + MessageStructuredData structured_data() const override { + return TestMessageBase::MessageStructuredData(server_setup_); + } + + private: + uint8_t raw_packet_[4] = { + 0x40, 0x41, // type + 0x01, 0x00, // version, zero params + }; + MoqtServerSetup server_setup_ = { + /*selected_version=*/static_cast<MoqtVersion>(1), /*role=*/absl::nullopt, - /*path=*/absl::nullopt, }; }; @@ -309,12 +338,20 @@ QUIC_LOG(INFO) << "SUBSCRIBE REQUEST full track name mismatch"; return false; } - if (cast.group_sequence != subscribe_request_.group_sequence) { - QUIC_LOG(INFO) << "SUBSCRIBE REQUEST group sequence mismatch"; + if (cast.start_group != subscribe_request_.start_group) { + QUIC_LOG(INFO) << "SUBSCRIBE REQUEST start group mismatch"; return false; } - if (cast.object_sequence != subscribe_request_.object_sequence) { - QUIC_LOG(INFO) << "SUBSCRIBE REQUEST object sequence mismatch"; + if (cast.start_object != subscribe_request_.start_object) { + QUIC_LOG(INFO) << "SUBSCRIBE REQUEST start object mismatch"; + return false; + } + if (cast.end_group != subscribe_request_.end_group) { + QUIC_LOG(INFO) << "SUBSCRIBE REQUEST end group mismatch"; + return false; + } + if (cast.end_object != subscribe_request_.end_object) { + QUIC_LOG(INFO) << "SUBSCRIBE REQUEST end object mismatch"; return false; } if (cast.authorization_info != subscribe_request_.authorization_info) { @@ -324,7 +361,7 @@ return true; } - void ExpandVarints() override { ExpandVarintsImpl("vvv---vv-vv-vv"); } + void ExpandVarints() override { ExpandVarintsImpl("vv---vvvvvvvvv---"); } MessageStructuredData structured_data() const override { return TestMessageBase::MessageStructuredData(subscribe_request_); @@ -332,16 +369,21 @@ private: uint8_t raw_packet_[17] = { - 0x03, 0x0f, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" - 0x00, 0x01, 0x01, // group_sequence = 1 - 0x01, 0x01, 0x02, // object_sequence = 2 - 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + 0x03, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" + 0x02, 0x04, // start_group = 4 (relative previous) + 0x01, 0x01, // start_object = 1 (absolute) + 0x00, // end_group = none + 0x00, // end_object = none + 0x01, // 1 parameter + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" }; MoqtSubscribeRequest subscribe_request_ = { /*full_track_name=*/"foo", - /*group_sequence=*/1, - /*object_sequence=*/2, + /*start_group=*/MoqtSubscribeLocation(false, (int64_t)(-4)), + /*start_object=*/MoqtSubscribeLocation(true, (uint64_t)1), + /*end_group=*/absl::nullopt, + /*end_object=*/absl::nullopt, /*authorization_info=*/"bar", }; }; @@ -354,33 +396,42 @@ bool EqualFieldValues(MessageStructuredData& values) const override { auto cast = std::get<MoqtSubscribeOk>(values); - if (cast.full_track_name != subscribe_ok_.full_track_name) { + if (cast.track_namespace != subscribe_ok_.track_namespace) { + QUIC_LOG(INFO) << "SUBSCRIBE OK track namespace mismatch"; + return false; + } + if (cast.track_name != subscribe_ok_.track_name) { + QUIC_LOG(INFO) << "SUBSCRIBE OK track name mismatch"; return false; } if (cast.track_id != subscribe_ok_.track_id) { + QUIC_LOG(INFO) << "SUBSCRIBE OK track ID mismatch"; return false; } if (cast.expires != subscribe_ok_.expires) { + QUIC_LOG(INFO) << "SUBSCRIBE OK expiration mismatch"; return false; } return true; } - void ExpandVarints() override { ExpandVarintsImpl("vvv---vv"); } + void ExpandVarints() override { ExpandVarintsImpl("vv---v---vv"); } MessageStructuredData structured_data() const override { return TestMessageBase::MessageStructuredData(subscribe_ok_); } private: - uint8_t raw_packet_[8] = { - 0x04, 0x06, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" - 0x01, // track_id = 1 - 0x02, // expires = 2 + uint8_t raw_packet_[11] = { + 0x04, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x03, 0x62, 0x61, 0x72, // track_namespace = "bar" + 0x01, // track_id = 1 + 0x02, // expires = 2 }; MoqtSubscribeOk subscribe_ok_ = { - /*full_track_name=*/"foo", + /*track_namespace=*/"foo", + /*track_name=*/"bar", /*track_id=*/1, /*expires=*/quic::QuicTimeDelta::FromMilliseconds(2), }; @@ -394,8 +445,12 @@ bool EqualFieldValues(MessageStructuredData& values) const override { auto cast = std::get<MoqtSubscribeError>(values); - if (cast.full_track_name != subscribe_error_.full_track_name) { - QUIC_LOG(INFO) << "SUBSCRIBE ERROR full track name mismatch"; + if (cast.track_namespace != subscribe_error_.track_namespace) { + QUIC_LOG(INFO) << "SUBSCRIBE ERROR track namespace mismatch"; + return false; + } + if (cast.track_name != subscribe_error_.track_name) { + QUIC_LOG(INFO) << "SUBSCRIBE ERROR track name mismatch"; return false; } if (cast.error_code != subscribe_error_.error_code) { @@ -409,21 +464,23 @@ return true; } - void ExpandVarints() override { ExpandVarintsImpl("vvv---vv---"); } + void ExpandVarints() override { ExpandVarintsImpl("vv---v---vv---"); } MessageStructuredData structured_data() const override { return TestMessageBase::MessageStructuredData(subscribe_error_); } private: - uint8_t raw_packet_[11] = { - 0x05, 0x09, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" - 0x01, // error_code = 1 - 0x03, 0x62, 0x61, 0x72, // reason_phrase = "bar" + uint8_t raw_packet_[14] = { + 0x05, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x03, 0x62, 0x61, 0x72, // track_namespace = "bar" + 0x01, // error_code = 1 + 0x03, 0x62, 0x61, 0x72, // reason_phrase = "bar" }; MoqtSubscribeError subscribe_error_ = { - /*full_track_name=*/"foo", + /*track_namespace=*/"foo", + /*track_name=*/"bar", /*subscribe=*/1, /*reason_phrase=*/"bar", }; @@ -437,26 +494,142 @@ bool EqualFieldValues(MessageStructuredData& values) const override { auto cast = std::get<MoqtUnsubscribe>(values); - if (cast.full_track_name != unsubscribe_.full_track_name) { - QUIC_LOG(INFO) << "UNSUBSCRIBE full track name mismatch"; + if (cast.track_namespace != unsubscribe_.track_namespace) { + QUIC_LOG(INFO) << "UNSUBSCRIBE track name mismatch"; + return false; + } + if (cast.track_name != unsubscribe_.track_name) { + QUIC_LOG(INFO) << "UNSUBSCRIBE track name mismatch"; return false; } return true; } - void ExpandVarints() override { ExpandVarintsImpl("vvv---"); } + void ExpandVarints() override { ExpandVarintsImpl("vv---v---"); } MessageStructuredData structured_data() const override { return TestMessageBase::MessageStructuredData(unsubscribe_); } private: - uint8_t raw_packet_[6] = { - 0x0a, 0x04, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" + uint8_t raw_packet_[9] = { + 0x0a, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x03, 0x62, 0x61, 0x72, // track_namespace = "bar" }; MoqtUnsubscribe unsubscribe_ = { - /*full_track_name=*/"foo", + /*track_namespace=*/"foo", + /*track_name=*/"bar", + }; +}; + +class QUICHE_NO_EXPORT SubscribeFinMessage : public TestMessageBase { + public: + SubscribeFinMessage() : TestMessageBase(MoqtMessageType::kSubscribeFin) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtSubscribeFin>(values); + if (cast.track_namespace != subscribe_fin_.track_namespace) { + QUIC_LOG(INFO) << "SUBSCRIBE_FIN track name mismatch"; + return false; + } + if (cast.track_name != subscribe_fin_.track_name) { + QUIC_LOG(INFO) << "SUBSCRIBE_FIN track name mismatch"; + return false; + } + if (cast.final_group != subscribe_fin_.final_group) { + QUIC_LOG(INFO) << "SUBSCRIBE_FIN final group mismatch"; + return false; + } + if (cast.final_object != subscribe_fin_.final_object) { + QUIC_LOG(INFO) << "SUBSCRIBE_FIN final object mismatch"; + return false; + } + return true; + } + + void ExpandVarints() override { ExpandVarintsImpl("vv---v---vv"); } + + MessageStructuredData structured_data() const override { + return TestMessageBase::MessageStructuredData(subscribe_fin_); + } + + private: + uint8_t raw_packet_[11] = { + 0x0b, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x03, 0x62, 0x61, 0x72, // track_namespace = "bar" + 0x08, // final_group = 8 + 0x0c, // final_object = 12 + }; + + MoqtSubscribeFin subscribe_fin_ = { + /*track_namespace=*/"foo", + /*track_name=*/"bar", + /*final_group=*/8, + /*final_object=*/12, + }; +}; + +class QUICHE_NO_EXPORT SubscribeRstMessage : public TestMessageBase { + public: + SubscribeRstMessage() : TestMessageBase(MoqtMessageType::kSubscribeRst) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtSubscribeRst>(values); + if (cast.track_namespace != subscribe_rst_.track_namespace) { + QUIC_LOG(INFO) << "SUBSCRIBE_RST track name mismatch"; + return false; + } + if (cast.track_name != subscribe_rst_.track_name) { + QUIC_LOG(INFO) << "SUBSCRIBE_RST track name mismatch"; + return false; + } + if (cast.error_code != subscribe_rst_.error_code) { + QUIC_LOG(INFO) << "SUBSCRIBE_RST error code mismatch"; + return false; + } + if (cast.reason_phrase != subscribe_rst_.reason_phrase) { + QUIC_LOG(INFO) << "SUBSCRIBE_RST reason phrase mismatch"; + return false; + } + if (cast.final_group != subscribe_rst_.final_group) { + QUIC_LOG(INFO) << "SUBSCRIBE_RST final group mismatch"; + return false; + } + if (cast.final_object != subscribe_rst_.final_object) { + QUIC_LOG(INFO) << "SUBSCRIBE_RST final object mismatch"; + return false; + } + return true; + } + + void ExpandVarints() override { ExpandVarintsImpl("vv---v---vv--vv"); } + + MessageStructuredData structured_data() const override { + return TestMessageBase::MessageStructuredData(subscribe_rst_); + } + + private: + uint8_t raw_packet_[15] = { + 0x0c, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x03, 0x62, 0x61, 0x72, // track_namespace = "bar" + 0x03, // error_code = 3 + 0x02, 0x68, 0x69, // reason_phrase = "hi" + 0x08, // final_group = 8 + 0x0c, // final_object = 12 + }; + + MoqtSubscribeRst subscribe_rst_ = { + /*track_namespace=*/"foo", + /*track_name=*/"bar", + /*error_code=*/3, + /*reason_phrase=*/"hi", + /*final_group=*/8, + /*final_object=*/12, }; }; @@ -479,7 +652,7 @@ return true; } - void ExpandVarints() override { ExpandVarintsImpl("vvv---vv---"); } + void ExpandVarints() override { ExpandVarintsImpl("vv---vvv---"); } MessageStructuredData structured_data() const override { return TestMessageBase::MessageStructuredData(announce_); @@ -487,8 +660,9 @@ private: uint8_t raw_packet_[11] = { - 0x06, 0x09, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" - 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + 0x06, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x01, // 1 parameter + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" }; MoqtAnnounce announce_ = { @@ -512,15 +686,15 @@ return true; } - void ExpandVarints() override { ExpandVarintsImpl("vvv---"); } + void ExpandVarints() override { ExpandVarintsImpl("vv---"); } MessageStructuredData structured_data() const override { return TestMessageBase::MessageStructuredData(announce_ok_); } private: - uint8_t raw_packet_[6] = { - 0x07, 0x04, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + uint8_t raw_packet_[5] = { + 0x07, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" }; MoqtAnnounceOk announce_ok_ = { @@ -551,17 +725,17 @@ return true; } - void ExpandVarints() override { ExpandVarintsImpl("vvv---vv---"); } + void ExpandVarints() override { ExpandVarintsImpl("vv---vv---"); } MessageStructuredData structured_data() const override { return TestMessageBase::MessageStructuredData(announce_error_); } private: - uint8_t raw_packet_[11] = { - 0x08, 0x09, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" - 0x01, // error__code = 1 - 0x03, 0x62, 0x61, 0x72, // reason_phrase = "bar" + uint8_t raw_packet_[10] = { + 0x08, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x01, // error_code = 1 + 0x03, 0x62, 0x61, 0x72, // reason_phrase = "bar" }; MoqtAnnounceError announce_error_ = { @@ -586,15 +760,15 @@ return true; } - void ExpandVarints() override { ExpandVarintsImpl("vvv---"); } + void ExpandVarints() override { ExpandVarintsImpl("vv---"); } MessageStructuredData structured_data() const override { return TestMessageBase::MessageStructuredData(unannounce_); } private: - uint8_t raw_packet_[6] = { - 0x09, 0x04, 0x03, 0x66, 0x6f, 0x6f, // track_namespace + uint8_t raw_packet_[5] = { + 0x09, 0x03, 0x66, 0x6f, 0x6f, // track_namespace }; MoqtUnannounce unannounce_ = { @@ -608,23 +782,29 @@ SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& /*values*/) const override { + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtGoAway>(values); + if (cast.new_session_uri != goaway_.new_session_uri) { + QUIC_LOG(INFO) << "UNSUBSCRIBE full track name mismatch"; + return false; + } return true; } - void ExpandVarints() override { ExpandVarintsImpl("vv"); } + void ExpandVarints() override { ExpandVarintsImpl("vv---"); } MessageStructuredData structured_data() const override { return TestMessageBase::MessageStructuredData(goaway_); } private: - uint8_t raw_packet_[2] = { - 0x10, - 0x00, + uint8_t raw_packet_[5] = { + 0x10, 0x03, 0x66, 0x6f, 0x6f, }; - MoqtGoAway goaway_ = {}; + MoqtGoAway goaway_ = { + /*new_session_uri=*/"foo", + }; }; } // namespace moqt::test