Refactor the framer and some of the publishing pathway for MOQT. Changes to framer: - Use wire_serialization.h API where possible. - Remove unused SerializeObjectPayload method. - Change SerializeObject into SerializeObjectHeader; we never want to copy the object payload here, since it almost always should be passed directly into the underlying stream using Writev. - Simplify some code paths (in case of SerializeObjectHeader at cost of repetition) in order to make them easier to read and modify going forward. Changes to PublishObject: - Always require full payload; we didn't really provide a way to write the remaining payload, and that would require a different API we could add later. - Check that end_of_stream is always set for Object/Datagram. PiperOrigin-RevId: 608134589
diff --git a/quiche/quic/moqt/moqt_framer.cc b/quiche/quic/moqt/moqt_framer.cc index 61caf1e..b69456a 100644 --- a/quiche/quic/moqt/moqt_framer.cc +++ b/quiche/quic/moqt/moqt_framer.cc
@@ -6,466 +6,353 @@ #include <cstddef> #include <cstdint> +#include <cstdlib> #include <optional> +#include <type_traits> +#include <utility> +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "quiche/quic/core/quic_data_writer.h" #include "quiche/quic/core/quic_time.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" #include "quiche/common/quiche_buffer_allocator.h" +#include "quiche/common/quiche_data_writer.h" +#include "quiche/common/simple_buffer_allocator.h" +#include "quiche/common/wire_serialization.h" namespace moqt { namespace { -inline size_t NeededVarIntLen(const uint64_t value) { - return static_cast<size_t>(quic::QuicDataWriter::GetVarInt62Len(value)); -} -inline size_t NeededVarIntLen(const MoqtVersion value) { - return static_cast<size_t>( - quic::QuicDataWriter::GetVarInt62Len(static_cast<uint64_t>(value))); -} -inline size_t NeededVarIntLen(const MoqtMessageType value) { - return static_cast<size_t>( - quic::QuicDataWriter::GetVarInt62Len(static_cast<uint64_t>(value))); -} -inline size_t NeededVarIntLen(const MoqtSubscribeLocationMode value) { - return static_cast<size_t>( - quic::QuicDataWriter::GetVarInt62Len(static_cast<uint64_t>(value))); -} -inline size_t ParameterLen(const uint64_t type, const uint64_t value_len) { - return NeededVarIntLen(type) + NeededVarIntLen(value_len) + value_len; -} -inline size_t LocationLength(const std::optional<MoqtSubscribeLocation> loc) { - if (!loc.has_value()) { - return NeededVarIntLen(MoqtSubscribeLocationMode::kNone); - } - if (loc->absolute) { - return NeededVarIntLen(MoqtSubscribeLocationMode::kAbsolute) + - NeededVarIntLen(loc->absolute_value); - } - // It's a relative value - if (loc->relative_value < 0) { - return NeededVarIntLen(MoqtSubscribeLocationMode::kRelativePrevious) + - NeededVarIntLen(static_cast<uint64_t>(loc->relative_value * -1)); - } - return NeededVarIntLen(MoqtSubscribeLocationMode::kRelativeNext) + - NeededVarIntLen(static_cast<uint64_t>(loc->relative_value)); -} -inline size_t LengthPrefixedStringLength(absl::string_view string) { - return NeededVarIntLen(string.length()) + string.length(); -} +using ::quiche::QuicheBuffer; +using ::quiche::WireOptional; +using ::quiche::WireSpan; +using ::quiche::WireStringWithVarInt62Length; +using ::quiche::WireVarInt62; -// This only supports values up to UINT8_MAX, as that's all that exists in the -// standard. -inline bool WriteVarIntParameter(quic::QuicDataWriter& writer, uint64_t type, - uint64_t value) { - if (!writer.WriteVarInt62(type)) { - return false; - } - if (!writer.WriteVarInt62(NeededVarIntLen(value))) { - return false; - } - return writer.WriteVarInt62(value); -} +// Encoding for MOQT Locations: +// https://moq-wg.github.io/moq-transport/draft-ietf-moq-transport.html#name-subscribe-locations +class WireLocation { + public: + using DataType = std::optional<MoqtSubscribeLocation>; + explicit WireLocation(const DataType& location) : location_(location) {} -inline bool WriteStringParameter(quic::QuicDataWriter& writer, uint64_t type, - absl::string_view value) { - if (!writer.WriteVarInt62(type)) { - return false; + size_t GetLengthOnWire() { + return quiche::ComputeLengthOnWire( + WireVarInt62(GetModeForSubscribeLocation(location_)), + WireOptional<WireVarInt62>(LocationOffsetOnTheWire(location_))); } - return writer.WriteStringPieceVarInt62(value); -} + absl::Status SerializeIntoWriter(quiche::QuicheDataWriter& writer) { + return quiche::SerializeIntoWriter( + writer, WireVarInt62(GetModeForSubscribeLocation(location_)), + WireOptional<WireVarInt62>(LocationOffsetOnTheWire(location_))); + } -inline bool WriteLocation(quic::QuicDataWriter& writer, - std::optional<MoqtSubscribeLocation> loc) { - if (!loc.has_value()) { - return writer.WriteVarInt62( - static_cast<uint64_t>(MoqtSubscribeLocationMode::kNone)); - } - if (loc->absolute) { - if (!writer.WriteVarInt62( - static_cast<uint64_t>(MoqtSubscribeLocationMode::kAbsolute))) { - return false; + private: + // For all location types other than None, we record a single varint after the + // type; this function computes the value of that varint. + static std::optional<uint64_t> LocationOffsetOnTheWire( + std::optional<MoqtSubscribeLocation> location) { + if (!location.has_value()) { + return std::nullopt; } - return writer.WriteVarInt62(loc->absolute_value); - } - if (loc->relative_value <= 0) { - if (!writer.WriteVarInt62(static_cast<uint64_t>( - MoqtSubscribeLocationMode::kRelativePrevious))) { - return false; + if (location->absolute) { + return location->absolute_value; } - return writer.WriteVarInt62( - static_cast<uint64_t>(loc->relative_value * -1)); + return location->relative_value <= 0 ? -location->relative_value + : location->relative_value + 1; } - if (!writer.WriteVarInt62( - static_cast<uint64_t>(MoqtSubscribeLocationMode::kRelativeNext))) { - return false; + + const DataType& location_; +}; + +// Encoding for string parameters as described in +// https://moq-wg.github.io/moq-transport/draft-ietf-moq-transport.html#name-parameters +struct StringParameter { + template <typename Enum> + StringParameter(Enum type, absl::string_view data) + : type(static_cast<uint64_t>(type)), data(data) { + static_assert(std::is_enum_v<Enum>); } - return writer.WriteVarInt62(static_cast<uint64_t>(loc->relative_value - 1)); + + uint64_t type; + absl::string_view data; +}; +class WireStringParameter { + public: + using DataType = StringParameter; + + explicit WireStringParameter(const StringParameter& parameter) + : parameter_(parameter) {} + size_t GetLengthOnWire() { + return quiche::ComputeLengthOnWire( + WireVarInt62(parameter_.type), + WireStringWithVarInt62Length(parameter_.data)); + } + absl::Status SerializeIntoWriter(quiche::QuicheDataWriter& writer) { + return quiche::SerializeIntoWriter( + writer, WireVarInt62(parameter_.type), + WireStringWithVarInt62Length(parameter_.data)); + } + + private: + const StringParameter& parameter_; +}; + +// Encoding for integer parameters as described in +// https://moq-wg.github.io/moq-transport/draft-ietf-moq-transport.html#name-parameters +struct IntParameter { + template <typename Enum, typename Param> + IntParameter(Enum type, Param value) + : type(static_cast<uint64_t>(type)), value(static_cast<uint64_t>(value)) { + static_assert(std::is_enum_v<Enum>); + static_assert(std::is_enum_v<Param> || std::is_unsigned_v<Param>); + } + + uint64_t type; + uint64_t value; +}; +class WireIntParameter { + public: + using DataType = IntParameter; + + explicit WireIntParameter(const IntParameter& parameter) + : parameter_(parameter) {} + size_t GetLengthOnWire() { + return quiche::ComputeLengthOnWire( + WireVarInt62(parameter_.type), + WireVarInt62(NeededVarIntLen(parameter_.value)), + WireVarInt62(parameter_.value)); + } + absl::Status SerializeIntoWriter(quiche::QuicheDataWriter& writer) { + return quiche::SerializeIntoWriter( + writer, WireVarInt62(parameter_.type), + WireVarInt62(NeededVarIntLen(parameter_.value)), + WireVarInt62(parameter_.value)); + } + + private: + size_t NeededVarIntLen(const uint64_t value) { + return static_cast<size_t>(quic::QuicDataWriter::GetVarInt62Len(value)); + } + + const IntParameter& parameter_; +}; + +// Serializes data into buffer using the default allocator. Invokes QUICHE_BUG +// on failure. +template <typename... Ts> +QuicheBuffer Serialize(Ts... data) { + absl::StatusOr<QuicheBuffer> buffer = quiche::SerializeIntoBuffer( + quiche::SimpleBufferAllocator::Get(), data...); + if (!buffer.ok()) { + QUICHE_BUG(moqt_failed_serialization) + << "Failed to serialize MoQT frame: " << buffer.status(); + return QuicheBuffer(); + } + return *std::move(buffer); } } // namespace -quiche::QuicheBuffer MoqtFramer::SerializeObject( - const MoqtObject& message, const absl::string_view payload, - bool is_first_in_stream) { - if (message.payload_length.has_value() && - *message.payload_length < payload.length()) { +quiche::QuicheBuffer MoqtFramer::SerializeObjectHeader( + const MoqtObject& message, bool is_first_in_stream) { + if (!message.payload_length.has_value() && + !(message.forwarding_preference == MoqtForwardingPreference::kObject || + message.forwarding_preference == MoqtForwardingPreference::kDatagram)) { QUIC_BUG(quic_bug_serialize_object_input_01) - << "payload_size is too small for payload"; + << "Track or Group forwarding preference requires knowing the object " + "length in advance"; return quiche::QuicheBuffer(); } - if (!is_first_in_stream && - (message.forwarding_preference == MoqtForwardingPreference::kObject || - message.forwarding_preference == MoqtForwardingPreference::kDatagram)) { - QUIC_BUG(quic_bug_serialize_object_input_02) - << "Object or Datagram forwarding_preference must be first in stream"; - return quiche::QuicheBuffer(); - } - // Figure out the total message size based on message type and payload. - size_t buffer_size = NeededVarIntLen(message.object_id) + payload.length(); - uint64_t message_type = static_cast<uint64_t>( - GetMessageTypeForForwardingPreference(message.forwarding_preference)); - if (is_first_in_stream) { - buffer_size += NeededVarIntLen(message_type) + - NeededVarIntLen(message.subscribe_id) + - NeededVarIntLen(message.track_alias) + - NeededVarIntLen(message.group_id) + - NeededVarIntLen(message.object_send_order); - } else if (message.forwarding_preference == - MoqtForwardingPreference::kTrack) { - buffer_size += NeededVarIntLen(message.group_id); - } - uint64_t reported_payload_length = message.payload_length.has_value() - ? message.payload_length.value() - : payload.length(); - if (message.forwarding_preference == MoqtForwardingPreference::kTrack || - message.forwarding_preference == MoqtForwardingPreference::kGroup) { - buffer_size += NeededVarIntLen(reported_payload_length); - } - // Write to buffer. - quiche::QuicheBuffer buffer(allocator_, buffer_size); - quic::QuicDataWriter writer(buffer.size(), buffer.data()); - if (is_first_in_stream) { - writer.WriteVarInt62(message_type); - writer.WriteVarInt62(message.subscribe_id); - writer.WriteVarInt62(message.track_alias); - if (message.forwarding_preference != MoqtForwardingPreference::kTrack) { - writer.WriteVarInt62(message.group_id); - if (message.forwarding_preference != MoqtForwardingPreference::kGroup) { - writer.WriteVarInt62(message.object_id); - } + if (!is_first_in_stream) { + switch (message.forwarding_preference) { + case MoqtForwardingPreference::kTrack: + return Serialize(WireVarInt62(message.group_id), + WireVarInt62(message.object_id), + WireVarInt62(*message.payload_length)); + case MoqtForwardingPreference::kGroup: + return Serialize(WireVarInt62(message.object_id), + WireVarInt62(*message.payload_length)); + default: + QUIC_BUG(quic_bug_serialize_object_input_02) + << "Object or Datagram forwarding_preference must be first in " + "stream"; + return quiche::QuicheBuffer(); } - writer.WriteVarInt62(message.object_send_order); } + MoqtMessageType message_type = + GetMessageTypeForForwardingPreference(message.forwarding_preference); switch (message.forwarding_preference) { case MoqtForwardingPreference::kTrack: - writer.WriteVarInt62(message.group_id); - [[fallthrough]]; + return Serialize( + WireVarInt62(message_type), WireVarInt62(message.subscribe_id), + WireVarInt62(message.track_alias), + WireVarInt62(message.object_send_order), + WireVarInt62(message.group_id), WireVarInt62(message.object_id), + WireVarInt62(*message.payload_length)); case MoqtForwardingPreference::kGroup: - writer.WriteVarInt62(message.object_id); - writer.WriteVarInt62(reported_payload_length); - break; - default: - break; + return Serialize( + WireVarInt62(message_type), WireVarInt62(message.subscribe_id), + WireVarInt62(message.track_alias), WireVarInt62(message.group_id), + WireVarInt62(message.object_send_order), + WireVarInt62(message.object_id), + WireVarInt62(*message.payload_length)); + case MoqtForwardingPreference::kObject: + case MoqtForwardingPreference::kDatagram: + return Serialize( + WireVarInt62(message_type), WireVarInt62(message.subscribe_id), + WireVarInt62(message.track_alias), WireVarInt62(message.group_id), + WireVarInt62(message.object_id), + WireVarInt62(message.object_send_order)); } - writer.WriteStringPiece(payload); - return buffer; -} - -quiche::QuicheBuffer MoqtFramer::SerializeObjectPayload( - const absl::string_view payload) { - quiche::QuicheBuffer buffer(allocator_, payload.length()); - quic::QuicDataWriter writer(buffer.size(), buffer.data()); - writer.WriteStringPiece(payload); - return buffer; } quiche::QuicheBuffer MoqtFramer::SerializeClientSetup( const MoqtClientSetup& message) { - size_t buffer_size = NeededVarIntLen(MoqtMessageType::kClientSetup) + - NeededVarIntLen(message.supported_versions.size()); - for (MoqtVersion version : message.supported_versions) { - buffer_size += NeededVarIntLen(version); - } - uint64_t num_params = 0; + absl::InlinedVector<IntParameter, 1> int_parameters; + absl::InlinedVector<StringParameter, 1> string_parameters; if (message.role.has_value()) { - num_params++; - buffer_size += - ParameterLen(static_cast<uint64_t>(MoqtSetupParameter::kRole), 1); + int_parameters.push_back( + IntParameter(MoqtSetupParameter::kRole, *message.role)); } if (!using_webtrans_ && message.path.has_value()) { - num_params++; - buffer_size += - ParameterLen(static_cast<uint64_t>(MoqtSetupParameter::kPath), - message.path->length()); + string_parameters.push_back( + StringParameter(MoqtSetupParameter::kPath, *message.path)); } - buffer_size += NeededVarIntLen(num_params); - quiche::QuicheBuffer buffer(allocator_, buffer_size); - quic::QuicDataWriter writer(buffer.size(), buffer.data()); - writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kClientSetup)); - writer.WriteVarInt62(message.supported_versions.size()); - for (MoqtVersion version : message.supported_versions) { - writer.WriteVarInt62(static_cast<uint64_t>(version)); - } - writer.WriteVarInt62(num_params); - if (message.role.has_value()) { - WriteVarIntParameter(writer, - static_cast<uint64_t>(MoqtSetupParameter::kRole), - static_cast<uint64_t>(*message.role)); - } - if (!using_webtrans_ && message.path.has_value()) { - WriteStringParameter(writer, - static_cast<uint64_t>(MoqtSetupParameter::kPath), - *message.path); - } - QUICHE_DCHECK(writer.remaining() == 0); - return buffer; + return Serialize( + WireVarInt62(MoqtMessageType::kClientSetup), + WireVarInt62(message.supported_versions.size()), + WireSpan<WireVarInt62, MoqtVersion>(message.supported_versions), + WireVarInt62(string_parameters.size() + int_parameters.size()), + WireSpan<WireIntParameter>(int_parameters), + WireSpan<WireStringParameter>(string_parameters)); } quiche::QuicheBuffer MoqtFramer::SerializeServerSetup( const MoqtServerSetup& message) { - size_t buffer_size = NeededVarIntLen(MoqtMessageType::kServerSetup) + - NeededVarIntLen(message.selected_version); - uint64_t num_params = 0; + absl::InlinedVector<IntParameter, 1> int_parameters; if (message.role.has_value()) { - num_params++; - buffer_size += - ParameterLen(static_cast<uint64_t>(MoqtSetupParameter::kRole), 1); + int_parameters.push_back( + IntParameter(MoqtSetupParameter::kRole, *message.role)); } - buffer_size += NeededVarIntLen(num_params); - quiche::QuicheBuffer buffer(allocator_, buffer_size); - quic::QuicDataWriter writer(buffer.size(), buffer.data()); - writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kServerSetup)); - writer.WriteVarInt62(static_cast<uint64_t>(message.selected_version)); - writer.WriteVarInt62(num_params); - if (message.role.has_value()) { - WriteVarIntParameter(writer, - static_cast<uint64_t>(MoqtSetupParameter::kRole), - static_cast<uint64_t>(*message.role)); - } - QUICHE_DCHECK(writer.remaining() == 0); - return buffer; + return Serialize(WireVarInt62(MoqtMessageType::kServerSetup), + WireVarInt62(message.selected_version), + WireVarInt62(int_parameters.size()), + WireSpan<WireIntParameter>(int_parameters)); } quiche::QuicheBuffer MoqtFramer::SerializeSubscribe( const MoqtSubscribe& message) { if (!message.start_group.has_value() || !message.start_object.has_value()) { - QUIC_LOG(INFO) << "start_group or start_object is missing"; + QUICHE_BUG(MoqtFramer_start_group_missing) + << "start_group or start_object is missing"; return quiche::QuicheBuffer(); } if (message.end_group.has_value() != message.end_object.has_value()) { - QUIC_LOG(INFO) << "end_group and end_object must both be None or both " - << "non-None"; + QUICHE_BUG(MoqtFramer_end_mismatch) + << "end_group and end_object must both be None or both non-None"; return quiche::QuicheBuffer(); } - size_t buffer_size = NeededVarIntLen(MoqtMessageType::kSubscribe) + - NeededVarIntLen(message.subscribe_id) + - NeededVarIntLen(message.track_alias) + - LengthPrefixedStringLength(message.track_namespace) + - LengthPrefixedStringLength(message.track_name) + - LocationLength(message.start_group) + - LocationLength(message.start_object) + - LocationLength(message.end_group) + - LocationLength(message.end_object); - uint64_t num_params = 0; + absl::InlinedVector<StringParameter, 1> string_params; if (message.authorization_info.has_value()) { - num_params++; - buffer_size += ParameterLen( - static_cast<uint64_t>(MoqtTrackRequestParameter::kAuthorizationInfo), - message.authorization_info->length()); + string_params.push_back( + StringParameter(MoqtTrackRequestParameter::kAuthorizationInfo, + *message.authorization_info)); } - buffer_size += NeededVarIntLen(num_params); - quiche::QuicheBuffer buffer(allocator_, buffer_size); - quic::QuicDataWriter writer(buffer.size(), buffer.data()); - writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kSubscribe)); - writer.WriteVarInt62(static_cast<uint64_t>(message.subscribe_id)); - writer.WriteVarInt62(static_cast<uint64_t>(message.track_alias)); - writer.WriteStringPieceVarInt62(message.track_namespace); - writer.WriteStringPieceVarInt62(message.track_name); - WriteLocation(writer, message.start_group); - WriteLocation(writer, message.start_object); - WriteLocation(writer, message.end_group); - WriteLocation(writer, message.end_object); - writer.WriteVarInt62(num_params); - if (message.authorization_info.has_value()) { - WriteStringParameter( - writer, - static_cast<uint64_t>(MoqtTrackRequestParameter::kAuthorizationInfo), - *message.authorization_info); - } - QUICHE_DCHECK(writer.remaining() == 0); - return buffer; + return Serialize( + WireVarInt62(MoqtMessageType::kSubscribe), + WireVarInt62(message.subscribe_id), WireVarInt62(message.track_alias), + WireStringWithVarInt62Length(message.track_namespace), + WireStringWithVarInt62Length(message.track_name), + WireLocation(message.start_group), WireLocation(message.start_object), + WireLocation(message.end_group), WireLocation(message.end_object), + WireVarInt62(string_params.size()), + WireSpan<WireStringParameter>(string_params)); } quiche::QuicheBuffer MoqtFramer::SerializeSubscribeOk( const MoqtSubscribeOk& message) { - size_t buffer_size = - NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kSubscribeOk)) + - NeededVarIntLen(message.subscribe_id) + - NeededVarIntLen(message.expires.ToMilliseconds()); - quiche::QuicheBuffer buffer(allocator_, buffer_size); - quic::QuicDataWriter writer(buffer.size(), buffer.data()); - writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kSubscribeOk)); - writer.WriteVarInt62(message.subscribe_id); - writer.WriteVarInt62(message.expires.ToMilliseconds()); - QUICHE_DCHECK(writer.remaining() == 0); - return buffer; + return Serialize(WireVarInt62(MoqtMessageType::kSubscribeOk), + WireVarInt62(message.subscribe_id), + WireVarInt62(message.expires.ToMilliseconds())); } quiche::QuicheBuffer MoqtFramer::SerializeSubscribeError( const MoqtSubscribeError& message) { - size_t buffer_size = - NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kSubscribeError)) + - NeededVarIntLen(message.subscribe_id) + - NeededVarIntLen(static_cast<uint64_t>(message.error_code)) + - LengthPrefixedStringLength(message.reason_phrase) + - NeededVarIntLen(message.track_alias); - quiche::QuicheBuffer buffer(allocator_, buffer_size); - quic::QuicDataWriter writer(buffer.size(), buffer.data()); - writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kSubscribeError)); - writer.WriteVarInt62(message.subscribe_id); - writer.WriteVarInt62(static_cast<uint64_t>(message.error_code)); - writer.WriteStringPieceVarInt62(message.reason_phrase); - writer.WriteVarInt62(message.track_alias); - QUICHE_DCHECK(writer.remaining() == 0); - return buffer; + return Serialize(WireVarInt62(MoqtMessageType::kSubscribeError), + WireVarInt62(message.subscribe_id), + WireVarInt62(message.error_code), + WireStringWithVarInt62Length(message.reason_phrase), + WireVarInt62(message.track_alias)); } quiche::QuicheBuffer MoqtFramer::SerializeUnsubscribe( const MoqtUnsubscribe& message) { - size_t buffer_size = - NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kUnsubscribe)) + - NeededVarIntLen(message.subscribe_id); - quiche::QuicheBuffer buffer(allocator_, buffer_size); - quic::QuicDataWriter writer(buffer.size(), buffer.data()); - writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kUnsubscribe)); - writer.WriteVarInt62(message.subscribe_id); - QUICHE_DCHECK(writer.remaining() == 0); - return buffer; + return Serialize(WireVarInt62(MoqtMessageType::kUnsubscribe), + WireVarInt62(message.subscribe_id)); } quiche::QuicheBuffer MoqtFramer::SerializeSubscribeFin( const MoqtSubscribeFin& message) { - size_t buffer_size = - NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kSubscribeFin)) + - NeededVarIntLen(message.subscribe_id) + - NeededVarIntLen(message.final_group) + - NeededVarIntLen(message.final_object); - quiche::QuicheBuffer buffer(allocator_, buffer_size); - quic::QuicDataWriter writer(buffer.size(), buffer.data()); - writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kSubscribeFin)); - writer.WriteVarInt62(message.subscribe_id); - writer.WriteVarInt62(message.final_group); - writer.WriteVarInt62(message.final_object); - QUICHE_DCHECK(writer.remaining() == 0); - return buffer; + return Serialize(WireVarInt62(MoqtMessageType::kSubscribeFin), + WireVarInt62(message.subscribe_id), + WireVarInt62(message.final_group), + WireVarInt62(message.final_object)); } quiche::QuicheBuffer MoqtFramer::SerializeSubscribeRst( const MoqtSubscribeRst& message) { - size_t buffer_size = - NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kSubscribeRst)) + - NeededVarIntLen(message.subscribe_id) + - NeededVarIntLen(message.error_code) + - LengthPrefixedStringLength(message.reason_phrase) + - NeededVarIntLen(message.final_group) + - NeededVarIntLen(message.final_object); - quiche::QuicheBuffer buffer(allocator_, buffer_size); - quic::QuicDataWriter writer(buffer.size(), buffer.data()); - writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kSubscribeRst)); - writer.WriteVarInt62(message.subscribe_id); - writer.WriteVarInt62(message.error_code); - writer.WriteStringPieceVarInt62(message.reason_phrase); - writer.WriteVarInt62(message.final_group); - writer.WriteVarInt62(message.final_object); - QUICHE_DCHECK(writer.remaining() == 0); - return buffer; + return Serialize( + WireVarInt62(MoqtMessageType::kSubscribeRst), + WireVarInt62(message.subscribe_id), WireVarInt62(message.error_code), + WireStringWithVarInt62Length(message.reason_phrase), + WireVarInt62(message.final_group), WireVarInt62(message.final_object)); } quiche::QuicheBuffer MoqtFramer::SerializeAnnounce( const MoqtAnnounce& message) { - size_t buffer_size = - NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kAnnounce)) + - LengthPrefixedStringLength(message.track_namespace); - uint64_t num_params = 0; + absl::InlinedVector<StringParameter, 1> string_params; if (message.authorization_info.has_value()) { - num_params++; - buffer_size += ParameterLen( - static_cast<uint64_t>(MoqtTrackRequestParameter::kAuthorizationInfo), - message.authorization_info->length()); + string_params.push_back( + StringParameter(MoqtTrackRequestParameter::kAuthorizationInfo, + *message.authorization_info)); } - buffer_size += NeededVarIntLen(num_params); - quiche::QuicheBuffer buffer(allocator_, buffer_size); - quic::QuicDataWriter writer(buffer.size(), buffer.data()); - writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kAnnounce)); - writer.WriteStringPieceVarInt62(message.track_namespace); - writer.WriteVarInt62(num_params); - if (message.authorization_info.has_value()) { - WriteStringParameter( - writer, - static_cast<uint64_t>(MoqtTrackRequestParameter::kAuthorizationInfo), - *message.authorization_info); - } - QUICHE_DCHECK(writer.remaining() == 0); - return buffer; + return Serialize( + WireVarInt62(static_cast<uint64_t>(MoqtMessageType::kAnnounce)), + WireStringWithVarInt62Length(message.track_namespace), + WireVarInt62(string_params.size()), + WireSpan<WireStringParameter>(string_params)); } quiche::QuicheBuffer MoqtFramer::SerializeAnnounceOk( const MoqtAnnounceOk& message) { - size_t buffer_size = - NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kAnnounceOk)) + - LengthPrefixedStringLength(message.track_namespace); - quiche::QuicheBuffer buffer(allocator_, buffer_size); - quic::QuicDataWriter writer(buffer.size(), buffer.data()); - writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kAnnounceOk)); - writer.WriteStringPieceVarInt62(message.track_namespace); - QUICHE_DCHECK(writer.remaining() == 0); - return buffer; + return Serialize(WireVarInt62(MoqtMessageType::kAnnounceOk), + WireStringWithVarInt62Length(message.track_namespace)); } quiche::QuicheBuffer MoqtFramer::SerializeAnnounceError( const MoqtAnnounceError& message) { - size_t buffer_size = - NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kAnnounceError)) + - LengthPrefixedStringLength(message.track_namespace) + - NeededVarIntLen(message.error_code) + - LengthPrefixedStringLength(message.reason_phrase); - quiche::QuicheBuffer buffer(allocator_, buffer_size); - quic::QuicDataWriter writer(buffer.size(), buffer.data()); - writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kAnnounceError)); - writer.WriteStringPieceVarInt62(message.track_namespace); - writer.WriteVarInt62(message.error_code); - writer.WriteStringPieceVarInt62(message.reason_phrase); - QUICHE_DCHECK(writer.remaining() == 0); - return buffer; + return Serialize(WireVarInt62(MoqtMessageType::kAnnounceError), + WireStringWithVarInt62Length(message.track_namespace), + WireVarInt62(message.error_code), + WireStringWithVarInt62Length(message.reason_phrase)); } quiche::QuicheBuffer MoqtFramer::SerializeUnannounce( const MoqtUnannounce& message) { - size_t buffer_size = - NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kUnannounce)) + - LengthPrefixedStringLength(message.track_namespace); - quiche::QuicheBuffer buffer(allocator_, buffer_size); - quic::QuicDataWriter writer(buffer.size(), buffer.data()); - writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kUnannounce)); - writer.WriteStringPieceVarInt62(message.track_namespace); - QUICHE_DCHECK(writer.remaining() == 0); - return buffer; + return Serialize(WireVarInt62(MoqtMessageType::kUnannounce), + WireStringWithVarInt62Length(message.track_namespace)); } quiche::QuicheBuffer MoqtFramer::SerializeGoAway(const MoqtGoAway& message) { - size_t buffer_size = - NeededVarIntLen(static_cast<uint64_t>(MoqtMessageType::kGoAway)) + - LengthPrefixedStringLength(message.new_session_uri); - quiche::QuicheBuffer buffer(allocator_, buffer_size); - quic::QuicDataWriter writer(buffer.size(), buffer.data()); - writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kGoAway)); - writer.WriteStringPieceVarInt62(message.new_session_uri); - QUICHE_DCHECK(writer.remaining() == 0); - return buffer; + return Serialize(WireVarInt62(MoqtMessageType::kGoAway), + WireStringWithVarInt62Length(message.new_session_uri)); } } // namespace moqt
diff --git a/quiche/quic/moqt/moqt_framer.h b/quiche/quic/moqt/moqt_framer.h index 080e76e..e4070c5 100644 --- a/quiche/quic/moqt/moqt_framer.h +++ b/quiche/quic/moqt/moqt_framer.h
@@ -32,19 +32,10 @@ // Serialize functions. Takes structured data and serializes it into a // QuicheBuffer for delivery to the stream. - // SerializeObject also takes a payload. |payload_size| might simply be the - // size of |payload|, or it could be larger if there is more data coming, or - // it could be nullopt if the final length is unknown. - // If |message.payload_size| is nullopt but the forwarding preference requires - // a payload length, will assume that payload.length() is the correct value. - // If |message.payload_size| is smaller than |payload|, or - // |message.forwarding preference| is not consistent with - // |is_first_in_stream|, returns an empty buffer and triggers QUIC_BUG. - quiche::QuicheBuffer SerializeObject(const MoqtObject& message, - absl::string_view payload, - bool is_first_in_stream); - // Build a buffer for additional payload data. - quiche::QuicheBuffer SerializeObjectPayload(absl::string_view payload); + // Serializes the header for an object, including the appropriate stream + // header if `is_first_in_stream` is set to true. + quiche::QuicheBuffer SerializeObjectHeader(const MoqtObject& message, + bool is_first_in_stream); quiche::QuicheBuffer SerializeClientSetup(const MoqtClientSetup& message); quiche::QuicheBuffer SerializeServerSetup(const MoqtServerSetup& message); // Returns an empty buffer if there is an illegal combination of locations.
diff --git a/quiche/quic/moqt/moqt_framer_test.cc b/quiche/quic/moqt/moqt_framer_test.cc index ff28d92..fe152d2 100644 --- a/quiche/quic/moqt/moqt_framer_test.cc +++ b/quiche/quic/moqt/moqt_framer_test.cc
@@ -5,9 +5,12 @@ #include "quiche/quic/moqt/moqt_framer.h" #include <memory> +#include <optional> #include <string> #include <vector> +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/test_tools/moqt_test_message.h" #include "quiche/quic/platform/api/quic_expect_bug.h" @@ -70,6 +73,22 @@ (info.param.uses_web_transport ? "WebTransport" : "QUIC"); } +quiche::QuicheBuffer SerializeObject(MoqtFramer& framer, + const MoqtObject& message, + absl::string_view payload, + bool is_first_in_stream) { + MoqtObject adjusted_message = message; + adjusted_message.payload_length = payload.size(); + quiche::QuicheBuffer header = + framer.SerializeObjectHeader(adjusted_message, is_first_in_stream); + if (header.empty()) { + return quiche::QuicheBuffer(); + } + return quiche::QuicheBuffer::Copy( + quiche::SimpleBufferAllocator::Get(), + absl::StrCat(header.AsStringView(), payload)); +} + class MoqtFramerTest : public quic::test::QuicTestWithParam<MoqtFramerTestParams> { public: @@ -90,8 +109,8 @@ case MoqtMessageType::kObjectPreferDatagram: case MoqtMessageType::kStreamHeaderTrack: case MoqtMessageType::kStreamHeaderGroup: { - auto data = std::get<MoqtObject>(structured_data); - return framer_.SerializeObject(data, "foo", true); + MoqtObject data = std::get<MoqtObject>(structured_data); + return SerializeObject(framer_, data, "foo", true); } case MoqtMessageType::kSubscribe: { auto data = std::get<MoqtSubscribe>(structured_data); @@ -178,28 +197,28 @@ TEST_F(MoqtFramerSimpleTest, GroupMiddler) { auto header = std::make_unique<StreamHeaderGroupMessage>(); - auto buffer1 = framer_.SerializeObject( - std::get<MoqtObject>(header->structured_data()), "foo", true); + auto buffer1 = SerializeObject( + framer_, std::get<MoqtObject>(header->structured_data()), "foo", true); EXPECT_EQ(buffer1.size(), header->total_message_size()); EXPECT_EQ(buffer1.AsStringView(), header->PacketSample()); auto middler = std::make_unique<StreamMiddlerGroupMessage>(); - auto buffer2 = framer_.SerializeObject( - std::get<MoqtObject>(middler->structured_data()), "bar", false); + auto buffer2 = SerializeObject( + framer_, std::get<MoqtObject>(middler->structured_data()), "bar", false); EXPECT_EQ(buffer2.size(), middler->total_message_size()); EXPECT_EQ(buffer2.AsStringView(), middler->PacketSample()); } TEST_F(MoqtFramerSimpleTest, TrackMiddler) { auto header = std::make_unique<StreamHeaderTrackMessage>(); - auto buffer1 = framer_.SerializeObject( - std::get<MoqtObject>(header->structured_data()), "foo", true); + auto buffer1 = SerializeObject( + framer_, std::get<MoqtObject>(header->structured_data()), "foo", true); EXPECT_EQ(buffer1.size(), header->total_message_size()); EXPECT_EQ(buffer1.AsStringView(), header->PacketSample()); auto middler = std::make_unique<StreamMiddlerTrackMessage>(); - auto buffer2 = framer_.SerializeObject( - std::get<MoqtObject>(middler->structured_data()), "bar", false); + auto buffer2 = SerializeObject( + framer_, std::get<MoqtObject>(middler->structured_data()), "bar", false); EXPECT_EQ(buffer2.size(), middler->total_message_size()); EXPECT_EQ(buffer2.AsStringView(), middler->PacketSample()); } @@ -212,21 +231,16 @@ /*object_id=*/6, /*object_send_order=*/7, /*forwarding_preference=*/MoqtForwardingPreference::kObject, - /*payload_length=*/1, + /*payload_length=*/std::nullopt, }; quiche::QuicheBuffer buffer; - EXPECT_QUIC_BUG(buffer = framer_.SerializeObject(object, "foo", true), - "payload_size is too small for payload"); - EXPECT_TRUE(buffer.empty()); - object.payload_length = 3; - EXPECT_QUIC_BUG(buffer = framer_.SerializeObject(object, "foo", false), - "Object or Datagram forwarding_preference must be first " - "in stream"); - EXPECT_TRUE(buffer.empty()); object.forwarding_preference = MoqtForwardingPreference::kDatagram; - EXPECT_QUIC_BUG(buffer = framer_.SerializeObject(object, "foo", false), - "Object or Datagram forwarding_preference must be first " - "in stream"); + EXPECT_QUIC_BUG(buffer = framer_.SerializeObjectHeader(object, false), + "must be first"); + EXPECT_TRUE(buffer.empty()); + object.forwarding_preference = MoqtForwardingPreference::kGroup; + EXPECT_QUIC_BUG(buffer = framer_.SerializeObjectHeader(object, false), + "requires knowing the object length"); EXPECT_TRUE(buffer.empty()); }
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h index 9922c04..dcd23f9 100644 --- a/quiche/quic/moqt/moqt_messages.h +++ b/quiche/quic/moqt/moqt_messages.h
@@ -202,6 +202,19 @@ } }; +inline MoqtSubscribeLocationMode GetModeForSubscribeLocation( + const std::optional<MoqtSubscribeLocation>& location) { + if (!location.has_value()) { + return MoqtSubscribeLocationMode::kNone; + } + if (location->absolute) { + return MoqtSubscribeLocationMode::kAbsolute; + } + return location->relative_value >= 0 + ? MoqtSubscribeLocationMode::kRelativeNext + : MoqtSubscribeLocationMode::kRelativePrevious; +} + struct QUICHE_EXPORT MoqtSubscribe { uint64_t subscribe_id; uint64_t track_alias;
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index 40285d5..2aba2b2 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -4,6 +4,7 @@ #include "quiche/quic/moqt/moqt_session.h" +#include <array> #include <cstdint> #include <memory> #include <optional> @@ -19,6 +20,7 @@ #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_subscribe_windows.h" #include "quiche/quic/moqt/moqt_track.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" #include "quiche/common/platform/api/quiche_logging.h" #include "quiche/common/quiche_buffer_allocator.h" #include "quiche/common/quiche_stream.h" @@ -263,11 +265,13 @@ uint64_t group_id, uint64_t object_id, uint64_t object_send_order, MoqtForwardingPreference forwarding_preference, - absl::string_view payload, - std::optional<uint64_t> payload_length, - bool end_of_stream) { - if (payload_length.has_value() && *payload_length < payload.length()) { - QUICHE_DLOG(ERROR) << ENDPOINT << "Payload too short"; + absl::string_view payload, bool end_of_stream) { + if ((forwarding_preference == MoqtForwardingPreference::kObject || + forwarding_preference == MoqtForwardingPreference::kDatagram) && + !end_of_stream) { + QUIC_BUG(MoqtSession_PublishObject_end_of_stream_required) + << "Forwarding preferences of Object or Datagram require stream to be " + "immediately closed"; return false; } auto track_it = local_tracks_.find(full_track_name); @@ -289,7 +293,7 @@ object.object_id = object_id; object.object_send_order = object_send_order; object.forwarding_preference = forwarding_preference; - object.payload_length = payload_length; + object.payload_length = payload.size(); int failures = 0; quiche::StreamWriteOptions write_options; write_options.set_send_fin(end_of_stream); @@ -322,10 +326,10 @@ continue; } object.subscribe_id = subscription->subscribe_id(); - if (quiche::WriteIntoStream( - *stream, - framer_.SerializeObject(object, payload, new_stream).AsStringView(), - write_options) != absl::OkStatus()) { + quiche::QuicheBuffer header = + framer_.SerializeObjectHeader(object, new_stream); + std::array<absl::string_view, 2> views = {header.AsStringView(), payload}; + if (!stream->Writev(views, write_options).ok()) { QUICHE_DLOG(ERROR) << ENDPOINT << "Failed to write OBJECT message"; ++failures; continue;
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index ee0720f..0bd0fa6 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -121,9 +121,9 @@ uint64_t object_id, uint64_t object_send_order, MoqtForwardingPreference forwarding_preference, absl::string_view payload, - std::optional<uint64_t> payload_length, bool end_of_stream); // TODO: Add an API to FIN the stream for a particular track/group/object. + // TODO: Add an API to send partial objects. private: friend class test::MoqtSessionPeer;
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 158c7a5..cc24b9e 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -416,7 +416,7 @@ // Send Sequence (2, 0) so that next_sequence is set correctly. session_.PublishObject(ftn, 2, 0, 0, MoqtForwardingPreference::kObject, "foo", - std::nullopt, true); + true); // Peer subscribes to (0, 0) MoqtSubscribe request = { /*subscribe_id=*/1, @@ -678,7 +678,7 @@ // No subscription; this is a no-op except to update next_sequence. EXPECT_CALL(mock_stream, Writev(_, _)).Times(0); session_.PublishObject(ftn, 4, 1, 0, MoqtForwardingPreference::kObject, - "deadbeef", std::nullopt, true); + "deadbeef", true); EXPECT_EQ(MoqtSessionPeer::next_sequence(&session_, ftn), FullSequence(4, 2)); // Publish in window. @@ -703,7 +703,7 @@ return absl::OkStatus(); }); session_.PublishObject(ftn, 5, 0, 0, MoqtForwardingPreference::kObject, - "deadbeef", std::nullopt, true); + "deadbeef", true); EXPECT_TRUE(correct_message); } @@ -720,9 +720,8 @@ ; EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream()) .WillOnce(Return(false)); - EXPECT_FALSE(session_.PublishObject(ftn, 5, 0, 0, - MoqtForwardingPreference::kObject, - "deadbeef", std::nullopt, true)); + EXPECT_FALSE(session_.PublishObject( + ftn, 5, 0, 0, MoqtForwardingPreference::kObject, "deadbeef", true)); } TEST_F(MoqtSessionTest, GetStreamByIdFails) { @@ -740,9 +739,8 @@ .WillRepeatedly(Return(kOutgoingUniStreamId)); EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId)) .WillOnce(Return(nullptr)); - EXPECT_FALSE(session_.PublishObject(ftn, 5, 0, 0, - MoqtForwardingPreference::kObject, - "deadbeef", std::nullopt, true)); + EXPECT_FALSE(session_.PublishObject( + ftn, 5, 0, 0, MoqtForwardingPreference::kObject, "deadbeef", true)); } TEST_F(MoqtSessionTest, SubscribeProposesBadTrackAlias) {
diff --git a/quiche/quic/moqt/tools/chat_client_bin.cc b/quiche/quic/moqt/tools/chat_client_bin.cc index cab2790..282d2d8 100644 --- a/quiche/quic/moqt/tools/chat_client_bin.cc +++ b/quiche/quic/moqt/tools/chat_client_bin.cc
@@ -364,8 +364,7 @@ client.session()->PublishObject( client.my_track_name(), client.next_sequence().group++, client.next_sequence().object, /*object_send_order=*/0, - moqt::MoqtForwardingPreference::kObject, message_to_send, - /*payload_length=*/std::nullopt, true); + moqt::MoqtForwardingPreference::kObject, message_to_send, true); } return 0; }