MoQ Transport Parser. PiperOrigin-RevId: 568955764
diff --git a/build/source_list.bzl b/build/source_list.bzl index 9ca439c..c9c0224 100644 --- a/build/source_list.bzl +++ b/build/source_list.bzl
@@ -1476,8 +1476,14 @@ "quic/load_balancer/load_balancer_server_id_test.cc", ] moqt_hdrs = [ + "quic/moqt/moqt_messages.h", + "quic/moqt/moqt_parser.h", + "quic/moqt/test_tools/moqt_test_message.h", ] moqt_srcs = [ + "quic/moqt/moqt_messages.cc", + "quic/moqt/moqt_parser.cc", + "quic/moqt/moqt_parser_test.cc", ] binary_http_hdrs = [ "binary_http/binary_http_message.h",
diff --git a/build/source_list.gni b/build/source_list.gni index d8978d9..e70f003 100644 --- a/build/source_list.gni +++ b/build/source_list.gni
@@ -1480,10 +1480,14 @@ "src/quiche/quic/load_balancer/load_balancer_server_id_test.cc", ] moqt_hdrs = [ - + "src/quiche/quic/moqt/moqt_messages.h", + "src/quiche/quic/moqt/moqt_parser.h", + "src/quiche/quic/moqt/test_tools/moqt_test_message.h", ] moqt_srcs = [ - + "src/quiche/quic/moqt/moqt_messages.cc", + "src/quiche/quic/moqt/moqt_parser.cc", + "src/quiche/quic/moqt/moqt_parser_test.cc", ] binary_http_hdrs = [ "src/quiche/binary_http/binary_http_message.h",
diff --git a/build/source_list.json b/build/source_list.json index 6dbfde3..07bf1a6 100644 --- a/build/source_list.json +++ b/build/source_list.json
@@ -1479,10 +1479,14 @@ "quiche/quic/load_balancer/load_balancer_server_id_test.cc" ], "moqt_hdrs": [ - + "quiche/quic/moqt/moqt_messages.h", + "quiche/quic/moqt/moqt_parser.h", + "quiche/quic/moqt/test_tools/moqt_test_message.h" ], "moqt_srcs": [ - + "quiche/quic/moqt/moqt_messages.cc", + "quiche/quic/moqt/moqt_parser.cc", + "quiche/quic/moqt/moqt_parser_test.cc" ], "binary_http_hdrs": [ "quiche/binary_http/binary_http_message.h"
diff --git a/quiche/quic/moqt/moqt_messages.cc b/quiche/quic/moqt/moqt_messages.cc new file mode 100644 index 0000000..9b37dc7 --- /dev/null +++ b/quiche/quic/moqt/moqt_messages.cc
@@ -0,0 +1,34 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/moqt/moqt_messages.h" + +#include <string> + +namespace moqt { + +std::string MoqtMessageTypeToString(const MoqtMessageType message_type) { + switch (message_type) { + case MoqtMessageType::kObject: + return "OBJECT"; + case MoqtMessageType::kSetup: + return "SETUP"; + case MoqtMessageType::kSubscribeRequest: + return "SUBSCRIBE_REQUEST"; + case MoqtMessageType::kSubscribeOk: + return "SUBSCRIBE_OK"; + case MoqtMessageType::kSubscribeError: + return "SUBSCRIBE_ERROR"; + case MoqtMessageType::kAnnounce: + return "ANNOUNCE"; + case MoqtMessageType::kAnnounceOk: + return "ANNOUNCE_OK"; + case MoqtMessageType::kAnnounceError: + return "ANNOUNCE_ERROR"; + case MoqtMessageType::kGoAway: + return "GOAWAY"; + } +} + +} // namespace moqt
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h new file mode 100644 index 0000000..30c1d97 --- /dev/null +++ b/quiche/quic/moqt/moqt_messages.h
@@ -0,0 +1,112 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Structured data for message types in draft-ietf-moq-transport-00. + +#ifndef QUICHE_QUIC_MOQT_MOQT_MESSAGES_H_ +#define QUICHE_QUIC_MOQT_MOQT_MESSAGES_H_ + +#include <cstddef> +#include <cstdint> +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace moqt { + +// The maximum length of a message, excluding any OBJECT payload. This prevents +// DoS attack via forcing the parser to buffer a large message (OBJECT payloads +// are not buffered by the parser). +inline constexpr size_t kMaxMessageHeaderSize = 4096; + +enum class QUICHE_EXPORT MoqtMessageType : uint64_t { + kObject = 0x00, + kSetup = 0x01, + kSubscribeRequest = 0x03, + kSubscribeOk = 0x04, + kSubscribeError = 0x05, + kAnnounce = 0x06, + kAnnounceOk = 0x7, + kAnnounceError = 0x08, + kGoAway = 0x10, +}; + +enum class QUICHE_EXPORT MoqtRole : uint64_t { + kIngestion = 0x1, + kDelivery = 0x2, + kBoth = 0x3, +}; + +enum class QUICHE_EXPORT MoqtSetupParameter : uint64_t { + kRole = 0x0, + kPath = 0x1, +}; + +enum class QUICHE_EXPORT MoqtTrackRequestParameter : uint64_t { + kGroupSequence = 0x0, + kObjectSequence = 0x1, + kAuthorizationInfo = 0x2, +}; + +struct QUICHE_EXPORT MoqtSetup { + uint64_t number_of_supported_versions; + std::vector<uint64_t> supported_versions; + absl::optional<MoqtRole> role; + absl::optional<absl::string_view> path; +}; + +struct QUICHE_EXPORT MoqtObject { + uint64_t track_id; + uint64_t group_sequence; + uint64_t object_sequence; + uint64_t object_send_order; + // Message also includes the object payload. +}; + +struct QUICHE_EXPORT MoqtSubscribeRequest { + absl::string_view full_track_name; + absl::optional<uint64_t> group_sequence; + absl::optional<uint64_t> object_sequence; + absl::optional<absl::string_view> authorization_info; +}; + +struct QUICHE_EXPORT MoqtSubscribeOk { + absl::string_view full_track_name; + uint64_t track_id; + // The message uses ms, but expires is in us. + quic::QuicTimeDelta expires = quic::QuicTimeDelta::FromMilliseconds(0); +}; + +struct QUICHE_EXPORT MoqtSubscribeError { + absl::string_view full_track_name; + uint64_t error_code; + absl::string_view reason_phrase; +}; + +struct QUICHE_EXPORT MoqtAnnounce { + absl::string_view track_namespace; + absl::optional<absl::string_view> authorization_info; +}; + +struct QUICHE_EXPORT MoqtAnnounceOk { + absl::string_view track_namespace; +}; + +struct QUICHE_EXPORT MoqtAnnounceError { + absl::string_view track_namespace; + uint64_t error_code; + absl::string_view reason_phrase; +}; + +struct QUICHE_EXPORT MoqtGoAway {}; + +std::string MoqtMessageTypeToString(MoqtMessageType message_type); + +} // namespace moqt + +#endif // QUICHE_QUIC_MOQT_MOQT_MESSAGES_H_
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc new file mode 100644 index 0000000..d90af04 --- /dev/null +++ b/quiche/quic/moqt/moqt_parser.cc
@@ -0,0 +1,688 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/moqt/moqt_parser.h" + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <string> + +#include "absl/cleanup/cleanup.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_endian.h" + +namespace moqt { + +namespace { + +// Minus the type, length, and payload, an OBJECT consists of 4 Varints. +constexpr size_t kMaxObjectHeaderSize = 32; + +} // namespace + +// The buffering philosophy is complicated, to minimize copying. Here is an +// overview: +// If the message type is present, this is stored in message_type_. If part of +// the message type varint is partially present, that is buffered (requiring a +// copy). +// Same for message length. +// If the entire message body is present (except for OBJECT payload), it is +// parsed and delivered. If not, the partial body is buffered. (requiring a +// copy). +// Any OBJECT payload is always delivered to the application without copying. +// If something has been buffered, when more data arrives copy just enough of it +// to finish parsing that thing, then resume normal processing. +void MoqtParser::ProcessData(absl::string_view data, bool end_of_stream) { + if (no_more_data_) { + if (!data.empty() || !end_of_stream) { + ParseError("Data after end of stream"); + } + return; + } + if (processing_) { + return; + } + processing_ = true; + auto on_return = absl::MakeCleanup([&] { processing_ = false; }); + no_more_data_ = end_of_stream; + quic::QuicDataReader reader(data); + if (!MaybeMergeDataWithBuffer(reader, end_of_stream)) { + return; + } + if (end_of_stream && reader.IsDoneReading() && object_metadata_.has_value()) { + // A FIN arrives while delivering OBJECT payload. + visitor_.OnObjectMessage(object_metadata_.value(), data, true); + EndOfMessage(); + } + while (!reader.IsDoneReading()) { + absl::optional<size_t> processed; + if (!GetMessageTypeAndLength(reader)) { + absl::StrAppend(&buffered_message_, reader.PeekRemainingPayload()); + break; + } + // Cursor is at start of the message. + if (end_of_stream && NoMessageLength()) { + *message_length_ = reader.BytesRemaining(); + } + if (*message_type_ != MoqtMessageType::kObject && + *message_type_ != MoqtMessageType::kGoAway) { + // Parse OBJECT in case the message is very large. GOAWAY is length zero, + // so always process. + if (NoMessageLength()) { + // Can't parse it yet. + absl::StrAppend(&buffered_message_, reader.PeekRemainingPayload()); + break; + } + if (*message_length_ > kMaxMessageHeaderSize) { + ParseError("Message too long"); + return; + } + if (*message_length_ > reader.BytesRemaining()) { + // There definitely isn't enough to process the message. + absl::StrAppend(&buffered_message_, reader.PeekRemainingPayload()); + break; + } + } + processed = ProcessMessage(FetchMessage(reader)); + if (!processed.has_value()) { + if (*message_type_ == MoqtMessageType::kObject && + (NoMessageLength() || reader.BytesRemaining() < *message_length_)) { + // The parser can attempt to process OBJECT before receiving the whole + // message length. If it doesn't parse the varints, it will buffer the + // message. + absl::StrAppend(&buffered_message_, reader.PeekRemainingPayload()); + break; + } + // Non-OBJECT or OBJECT with the complete specified length, but the data + // was not parseable. + ParseError("Not able to parse message given specified length"); + return; + } + if (*processed == *message_length_) { + EndOfMessage(); + } else { + if (*message_type_ != MoqtMessageType::kObject) { + // Partial processing of non-OBJECT is not allowed. + ParseError("Specified message length too long"); + return; + } + // This is a partially processed OBJECT payload. + if (!NoMessageLength()) { + *message_length_ -= *processed; + } + } + if (!reader.Seek(*processed)) { + QUICHE_DCHECK(false); + ParseError("Internal Error"); + } + } + if (end_of_stream && + (!buffered_message_.empty() || object_metadata_.has_value() || + message_type_.has_value() || message_length_.has_value())) { + // If the stream is ending, there should be no message in progress. + ParseError("Incomplete message at end of stream"); + } +} + +bool MoqtParser::MaybeMergeDataWithBuffer(quic::QuicDataReader& reader, + bool end_of_stream) { + // Copy as much information as necessary from |data| to complete the + // message or OBJECT header. Minimize unnecessary copying! + if (buffered_message_.empty()) { + return true; + } + quic::QuicDataReader buffer(buffered_message_); + if (!message_length_.has_value()) { + // The buffer contains part of the message type or length. + if (buffer.BytesRemaining() > buffer.PeekVarInt62Length()) { + ParseError("Internal Error"); + QUICHE_DCHECK(false); + return false; + } + size_t bytes_needed = buffer.PeekVarInt62Length() - buffer.BytesRemaining(); + if (bytes_needed > reader.BytesRemaining()) { + // Not enough to complete! + absl::StrAppend(&buffered_message_, reader.PeekRemainingPayload()); + return false; + } + absl::StrAppend(&buffered_message_, + reader.PeekRemainingPayload().substr(0, bytes_needed)); + if (!reader.Seek(bytes_needed)) { + QUICHE_DCHECK(false); + ParseError("Internal Error"); + return false; + } + quic::QuicDataReader new_buffer(buffered_message_); + uint64_t value; + if (!new_buffer.ReadVarInt62(&value)) { + QUICHE_DCHECK(false); + ParseError("Internal Error"); + return false; + } + if (message_type_.has_value()) { + message_length_ = value; + } else { + message_type_ = static_cast<MoqtMessageType>(value); + } + // GOAWAY is special. Report the message as soon as the type and length + // are complete. + if (message_type_.has_value() && message_length_.has_value() && + *message_type_ == MoqtMessageType::kGoAway) { + ProcessGoAway(new_buffer.PeekRemainingPayload()); + EndOfMessage(); + return false; + } + // Proceed to normal parsing. + buffered_message_.clear(); + return true; + } + // It's a partially buffered message + if (NoMessageLength()) { + if (end_of_stream) { + message_length_ = buffer.BytesRemaining() + reader.BytesRemaining(); + } else if (*message_type_ != MoqtMessageType::kObject) { + absl::StrAppend(&buffered_message_, reader.PeekRemainingPayload()); + return false; + } + } + if (*message_type_ == MoqtMessageType::kObject) { + // OBJECT is a special case. Append up to KMaxObjectHeaderSize bytes to the + // buffer and see if that allows parsing. + QUICHE_DCHECK(!object_metadata_.has_value()); + size_t original_buffer_size = buffer.BytesRemaining(); + size_t bytes_to_pull = reader.BytesRemaining(); + // No check for *message_length_ == 0 below! Mutants will complain if there + // is a check. If message_length_ < original_buffer_size, the second + // argument will be a very large unsigned integer, which will be irrelevant + // due to std::min. + bytes_to_pull = std::min(reader.BytesRemaining(), + *message_length_ - original_buffer_size); + // Mutants complains that the line below doesn't fail any tests. This is a + // performance optimization to avoid copying large amounts of object payload + // into the buffer when only the OBJECT header will be processed. There is + // no observable behavior change if this line is removed. + bytes_to_pull = std::min(bytes_to_pull, kMaxObjectHeaderSize); + absl::StrAppend(&buffered_message_, + reader.PeekRemainingPayload().substr(0, bytes_to_pull)); + absl::optional<size_t> processed = + ProcessObjectVarints(absl::string_view(buffered_message_)); + if (!processed.has_value()) { + if ((!NoMessageLength() && + buffered_message_.length() == *message_length_) || + buffered_message_.length() > kMaxObjectHeaderSize) { + ParseError("Not able to parse buffered message given specified length"); + } + return false; + } + if (*processed > 0 && !reader.Seek(*processed - original_buffer_size)) { + ParseError("Internal Error"); + return false; + } + if (*processed == *message_length_) { + // This covers an edge case where the peer has sent an OBJECT message with + // no content. + visitor_.OnObjectMessage(object_metadata_.value(), absl::string_view(), + true); + EndOfMessage(); + return true; + } + if (!NoMessageLength()) { + *message_length_ -= *processed; + } + // Object payload is never processed in the buffer. + buffered_message_.clear(); + return true; + } + size_t bytes_to_pull = + (buffer.BytesRemaining() + reader.BytesRemaining() < *message_length_) + ? reader.BytesRemaining() + : *message_length_ - buffer.BytesRemaining(); + absl::StrAppend(&buffered_message_, + reader.PeekRemainingPayload().substr(0, bytes_to_pull)); + if (!reader.Seek(bytes_to_pull)) { + QUICHE_DCHECK(false); + ParseError("Internal Error"); + return false; + } + if (buffered_message_.length() < *message_length_) { + // Not enough bytes present. + return false; + } + absl::optional<size_t> processed = + ProcessMessage(absl::string_view(buffered_message_)); + if (!processed.has_value()) { + ParseError("Not able to parse buffered message given specified length"); + return false; + } + if (*processed != *message_length_) { + ParseError("Buffered message length too long for message contents"); + return false; + } + EndOfMessage(); + return true; +} + +absl::optional<size_t> MoqtParser::ProcessMessage(absl::string_view data) { + switch (*message_type_) { + case MoqtMessageType::kObject: + return ProcessObject(data); + case MoqtMessageType::kSetup: + return ProcessSetup(data); + case MoqtMessageType::kSubscribeRequest: + return ProcessSubscribeRequest(data); + case MoqtMessageType::kSubscribeOk: + return ProcessSubscribeOk(data); + case MoqtMessageType::kSubscribeError: + return ProcessSubscribeError(data); + case MoqtMessageType::kAnnounce: + return ProcessAnnounce(data); + case MoqtMessageType::kAnnounceOk: + return ProcessAnnounceOk(data); + case MoqtMessageType::kAnnounceError: + return ProcessAnnounceError(data); + case MoqtMessageType::kGoAway: + return ProcessGoAway(data); + default: + ParseError("Unknown message type"); + return absl::nullopt; + } +} + +absl::optional<size_t> MoqtParser::ProcessObjectVarints( + absl::string_view data) { + if (object_metadata_.has_value()) { + return 0; + } + object_metadata_ = MoqtObject(); + quic::QuicDataReader reader(data); + if (reader.ReadVarInt62(&object_metadata_->track_id) && + reader.ReadVarInt62(&object_metadata_->group_sequence) && + reader.ReadVarInt62(&object_metadata_->object_sequence) && + reader.ReadVarInt62(&object_metadata_->object_send_order)) { + return reader.PreviouslyReadPayload().length(); + } + object_metadata_ = absl::nullopt; + QUICHE_DCHECK(reader.PreviouslyReadPayload().length() < kMaxObjectHeaderSize); + return absl::nullopt; +} + +absl::optional<size_t> MoqtParser::ProcessObject(absl::string_view data) { + quic::QuicDataReader reader(data); + size_t payload_length = *message_length_; + absl::optional<size_t> processed = ProcessObjectVarints(data); + if (!processed.has_value() && !object_metadata_.has_value()) { + // Could not obtain the whole object header. + return absl::nullopt; + } + if (!reader.Seek(*processed)) { + ParseError("Internal Error"); + return absl::nullopt; + } + if (payload_length != 0) { + payload_length -= *processed; + } + QUICHE_DCHECK(NoMessageLength() || reader.BytesRemaining() <= payload_length); + visitor_.OnObjectMessage( + object_metadata_.value(), reader.PeekRemainingPayload(), + reader.BytesRemaining() == payload_length && !NoMessageLength()); + return data.length(); +} + +absl::optional<size_t> MoqtParser::ProcessSetup(absl::string_view data) { + MoqtSetup setup; + quic::QuicDataReader reader(data); + if (perspective_ == quic::Perspective::IS_SERVER) { + if (!reader.ReadVarInt62(&setup.number_of_supported_versions)) { + return absl::nullopt; + } + } else { + setup.number_of_supported_versions = 1; + } + uint64_t value; + for (uint64_t i = 0; i < setup.number_of_supported_versions; ++i) { + if (!reader.ReadVarInt62(&value)) { + return absl::nullopt; + } + setup.supported_versions.push_back(value); + } + // Parse parameters + while (!reader.IsDoneReading()) { + if (!reader.ReadVarInt62(&value)) { + return absl::nullopt; + } + auto parameter_key = static_cast<MoqtSetupParameter>(value); + absl::string_view field; + switch (parameter_key) { + case MoqtSetupParameter::kRole: + if (setup.role.has_value()) { + ParseError("ROLE parameter appears twice in SETUP"); + return absl::nullopt; + } + if (perspective_ == quic::Perspective::IS_CLIENT) { + ParseError("ROLE parameter sent by server in SETUP"); + return absl::nullopt; + } + if (!ReadIntegerPieceVarInt62(reader, value)) { + return absl::nullopt; + } + setup.role = static_cast<MoqtRole>(value); + break; + case MoqtSetupParameter::kPath: + if (uses_web_transport_) { + ParseError( + "WebTransport connection is using PATH parameter in SETUP"); + return absl::nullopt; + } + if (perspective_ == quic::Perspective::IS_CLIENT) { + ParseError("PATH parameter sent by server in SETUP"); + return absl::nullopt; + } + if (setup.path.has_value()) { + ParseError("PATH parameter appears twice in SETUP"); + return absl::nullopt; + } + if (!reader.ReadStringPieceVarInt62(&field)) { + return absl::nullopt; + } + setup.path = field; + break; + default: + // Skip over the parameter. + if (!reader.ReadStringPieceVarInt62(&field)) { + return absl::nullopt; + } + break; + } + } + if (perspective_ == quic::Perspective::IS_SERVER) { + if (!setup.role.has_value()) { + ParseError("ROLE SETUP parameter missing from Client message"); + return absl::nullopt; + } + if (!uses_web_transport_ && !setup.path.has_value()) { + ParseError("PATH SETUP parameter missing from Client message over QUIC"); + return absl::nullopt; + } + } + visitor_.OnSetupMessage(setup); + return reader.PreviouslyReadPayload().length(); +} + +absl::optional<size_t> MoqtParser::ProcessSubscribeRequest( + absl::string_view data) { + MoqtSubscribeRequest subscribe_request; + quic::QuicDataReader reader(data); + absl::string_view field; + if (!reader.ReadStringPieceVarInt62(&subscribe_request.full_track_name)) { + return absl::nullopt; + } + uint64_t value; + while (!reader.IsDoneReading()) { + if (!reader.ReadVarInt62(&value)) { + return absl::nullopt; + } + auto parameter_key = static_cast<MoqtTrackRequestParameter>(value); + switch (parameter_key) { + case MoqtTrackRequestParameter::kGroupSequence: + if (subscribe_request.group_sequence.has_value()) { + ParseError( + "GROUP_SEQUENCE parameter appears twice in SUBSCRIBE_REQUEST"); + return absl::nullopt; + } + if (!ReadIntegerPieceVarInt62(reader, value)) { + return absl::nullopt; + } + subscribe_request.group_sequence = value; + break; + case MoqtTrackRequestParameter::kObjectSequence: + if (subscribe_request.object_sequence.has_value()) { + ParseError( + "OBJECT_SEQUENCE parameter appears twice in SUBSCRIBE_REQUEST"); + return absl::nullopt; + } + if (!ReadIntegerPieceVarInt62(reader, value)) { + return absl::nullopt; + } + subscribe_request.object_sequence = value; + break; + case MoqtTrackRequestParameter::kAuthorizationInfo: + if (subscribe_request.authorization_info.has_value()) { + ParseError( + "AUTHORIZATION_INFO parameter appears twice in " + "SUBSCRIBE_REQUEST"); + return absl::nullopt; + } + if (!reader.ReadStringPieceVarInt62(&field)) { + return absl::nullopt; + } + subscribe_request.authorization_info = field; + break; + default: + // Skip over the parameter. + if (!reader.ReadStringPieceVarInt62(&field)) { + return absl::nullopt; + } + break; + } + } + if (reader.IsDoneReading()) { + visitor_.OnSubscribeRequestMessage(subscribe_request); + } + return reader.PreviouslyReadPayload().length(); +} + +absl::optional<size_t> MoqtParser::ProcessSubscribeOk(absl::string_view data) { + MoqtSubscribeOk subscribe_ok; + quic::QuicDataReader reader(data); + if (!reader.ReadStringPieceVarInt62(&subscribe_ok.full_track_name)) { + return absl::nullopt; + } + if (!reader.ReadVarInt62(&subscribe_ok.track_id)) { + return absl::nullopt; + } + uint64_t milliseconds; + if (!reader.ReadVarInt62(&milliseconds)) { + return absl::nullopt; + } + subscribe_ok.expires = quic::QuicTimeDelta::FromMilliseconds(milliseconds); + if (reader.IsDoneReading()) { + visitor_.OnSubscribeOkMessage(subscribe_ok); + } + return reader.PreviouslyReadPayload().length(); +} + +absl::optional<size_t> MoqtParser::ProcessSubscribeError( + absl::string_view data) { + MoqtSubscribeError subscribe_error; + quic::QuicDataReader reader(data); + if (!reader.ReadStringPieceVarInt62(&subscribe_error.full_track_name)) { + return absl::nullopt; + } + if (!reader.ReadVarInt62(&subscribe_error.error_code)) { + return absl::nullopt; + } + if (!reader.ReadStringPieceVarInt62(&subscribe_error.reason_phrase)) { + return absl::nullopt; + } + if (reader.IsDoneReading()) { + visitor_.OnSubscribeErrorMessage(subscribe_error); + } + return reader.PreviouslyReadPayload().length(); +} + +absl::optional<size_t> MoqtParser::ProcessAnnounce(absl::string_view data) { + MoqtAnnounce announce; + quic::QuicDataReader reader(data); + absl::string_view field; + if (!reader.ReadStringPieceVarInt62(&field)) { + return absl::nullopt; + } + announce.track_namespace = field; + bool saw_group_sequence = false, saw_object_sequence = false; + while (!reader.IsDoneReading()) { + uint64_t value; + if (!reader.ReadVarInt62(&value)) { + return absl::nullopt; + } + auto parameter_key = static_cast<MoqtTrackRequestParameter>(value); + switch (parameter_key) { + case MoqtTrackRequestParameter::kGroupSequence: + // Not used, but check for duplicates. + if (saw_group_sequence) { + ParseError("GROUP_SEQUENCE parameter appears twice in ANNOUNCE"); + return absl::nullopt; + } + if (!reader.ReadStringPieceVarInt62(&field)) { + return absl::nullopt; + } + saw_group_sequence = true; + break; + case MoqtTrackRequestParameter::kObjectSequence: + // Not used, but check for duplicates. + if (saw_object_sequence) { + ParseError("OBJECT_SEQUENCE parameter appears twice in ANNOUNCE"); + return absl::nullopt; + } + if (!reader.ReadStringPieceVarInt62(&field)) { + return absl::nullopt; + } + saw_object_sequence = true; + break; + case MoqtTrackRequestParameter::kAuthorizationInfo: + if (announce.authorization_info.has_value()) { + ParseError("AUTHORIZATION_INFO parameter appears twice in ANNOUNCE"); + return absl::nullopt; + } + if (!reader.ReadStringPieceVarInt62(&field)) { + return absl::nullopt; + } + announce.authorization_info = field; + break; + default: + // Skip over the parameter. + if (!reader.ReadStringPieceVarInt62(&field)) { + return absl::nullopt; + } + break; + } + } + if (reader.IsDoneReading()) { + visitor_.OnAnnounceMessage(announce); + } + return reader.PreviouslyReadPayload().length(); +} + +absl::optional<size_t> MoqtParser::ProcessAnnounceOk(absl::string_view data) { + MoqtAnnounceOk announce_ok; + quic::QuicDataReader reader(data); + if (!reader.ReadStringPiece(&announce_ok.track_namespace, data.length())) { + return absl::nullopt; + } + if (reader.IsDoneReading()) { + visitor_.OnAnnounceOkMessage(announce_ok); + } + return reader.PreviouslyReadPayload().length(); +} + +absl::optional<size_t> MoqtParser::ProcessAnnounceError( + absl::string_view data) { + MoqtAnnounceError announce_error; + quic::QuicDataReader reader(data); + if (!reader.ReadStringPieceVarInt62(&announce_error.track_namespace)) { + return absl::nullopt; + } + if (!reader.ReadVarInt62(&announce_error.error_code)) { + return absl::nullopt; + } + if (!reader.ReadStringPieceVarInt62(&announce_error.reason_phrase)) { + return absl::nullopt; + } + if (reader.IsDoneReading()) { + visitor_.OnAnnounceErrorMessage(announce_error); + } + return reader.PreviouslyReadPayload().length(); +} + +absl::optional<size_t> MoqtParser::ProcessGoAway(absl::string_view data) { + if (!data.empty()) { + // GOAWAY can only be followed by end_of_stream. Anything else is an error. + ParseError("GOAWAY has data following"); + return absl::nullopt; + } + visitor_.OnGoAwayMessage(); + return 0; +} + +bool MoqtParser::GetMessageTypeAndLength(quic::QuicDataReader& reader) { + if (!message_type_.has_value()) { + uint64_t value; + if (!reader.ReadVarInt62(&value)) { + return false; + } + message_type_ = static_cast<MoqtMessageType>(value); + } + if (!message_length_.has_value()) { + uint64_t value; + if (!reader.ReadVarInt62(&value)) { + return false; + } + message_length_ = value; + } + return true; +} + +void MoqtParser::EndOfMessage() { + buffered_message_.clear(); + message_type_ = absl::nullopt; + message_length_ = absl::nullopt; + object_metadata_ = absl::nullopt; +} + +absl::string_view MoqtParser::FetchMessage(quic::QuicDataReader& reader) { + if (message_length_ == 0) { + return reader.PeekRemainingPayload(); + } + if (message_length_ > reader.BytesRemaining()) { + QUICHE_DCHECK(message_type_ == MoqtMessageType::kObject); + return reader.PeekRemainingPayload(); + } + return reader.PeekRemainingPayload().substr(0, *message_length_); +} + +void MoqtParser::ParseError(absl::string_view reason) { + if (parsing_error_) { + return; // Don't send multiple parse errors. + } + no_more_data_ = true; + parsing_error_ = true; + visitor_.OnParsingError(reason); +} + +bool MoqtParser::ReadIntegerPieceVarInt62(quic::QuicDataReader& reader, + uint64_t& result) { + absl::string_view field; + if (!reader.ReadStringPieceVarInt62(&field)) { + return false; + } + if (field.size() > sizeof(uint64_t)) { + ParseError("Cannot parse explicit length integers longer than 8 bytes"); + return false; + } + result = 0; + memcpy((uint8_t*)&result + sizeof(result) - field.size(), field.data(), + field.size()); + result = quiche::QuicheEndian::NetToHost64(result); + return true; +} + +} // namespace moqt
diff --git a/quiche/quic/moqt/moqt_parser.h b/quiche/quic/moqt/moqt_parser.h new file mode 100644 index 0000000..d931922 --- /dev/null +++ b/quiche/quic/moqt/moqt_parser.h
@@ -0,0 +1,137 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A parser for draft-ietf-moq-transport-00. + +#ifndef QUICHE_QUIC_MOQT_MOQT_PARSER_H_ +#define QUICHE_QUIC_MOQT_MOQT_PARSER_H_ + +#include <cstddef> +#include <cstdint> +#include <memory> +#include <string> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace moqt { + +class QUICHE_EXPORT MoqtParserVisitor { + public: + virtual ~MoqtParserVisitor() = default; + + // If |end_of_message| is true, |payload| contains the last bytes of the + // OBJECT payload. If not, there will be subsequent calls with further payload + // data. The parser retains ownership of |message| and |payload|, so the + // visitor needs to copy anything it wants to retain. + virtual void OnObjectMessage(const MoqtObject& message, + absl::string_view payload, + bool end_of_message) = 0; + // All of these are called only when the entire specified message length has + // arrived, which requires a stream FIN if the length is zero. The parser + // retains ownership of the memory. + virtual void OnSetupMessage(const MoqtSetup& message) = 0; + virtual void OnSubscribeRequestMessage( + const MoqtSubscribeRequest& message) = 0; + virtual void OnSubscribeOkMessage(const MoqtSubscribeOk& message) = 0; + virtual void OnSubscribeErrorMessage(const MoqtSubscribeError& message) = 0; + virtual void OnAnnounceMessage(const MoqtAnnounce& message) = 0; + virtual void OnAnnounceOkMessage(const MoqtAnnounceOk& message) = 0; + virtual void OnAnnounceErrorMessage(const MoqtAnnounceError& message) = 0; + // In an exception to the above, the parser calls this when it gets two bytes, + // whether or not it includes stream FIN. When a zero-length message has + // special meaning, a message with an actual length of zero is tricky! + virtual void OnGoAwayMessage() = 0; + + virtual void OnParsingError(absl::string_view reason) = 0; +}; + +class QUICHE_EXPORT MoqtParser { + public: + MoqtParser(quic::Perspective perspective, bool uses_web_transport, + MoqtParserVisitor& visitor) + : visitor_(visitor), + perspective_(perspective), + uses_web_transport_(uses_web_transport) {} + ~MoqtParser() = default; + + // Take a buffer from the transport in |data|. Parse each complete message and + // call the appropriate visitor function. If |end_of_stream| is true, there + // is no more data arriving on the stream, so the parser will deliver any + // message encoded as to run to the end of the stream. + // All bytes can be freed. Calls OnParsingError() when there is a parsing + // error. + // Any calls after sending |end_of_stream| = true will be ignored. + void ProcessData(absl::string_view data, bool end_of_stream); + + private: + // Copies the minimum amount of data in |reader| to buffered_message_ in order + // to process what is in there, and does the processing. Returns true if + // additional processing can occur, false otherwise. + bool MaybeMergeDataWithBuffer(quic::QuicDataReader& reader, + bool end_of_stream); + + // The central switch statement to dispatch a message to the correct + // Process* function. Returns nullopt if it could not parse the full messsage + // (except for object payload). Otherwise, returns the number of bytes + // processed. + absl::optional<size_t> ProcessMessage(absl::string_view data); + // A helper function to parse just the varints in an OBJECT. + absl::optional<size_t> ProcessObjectVarints(absl::string_view data); + // The Process* functions parse the serialized data into the appropriate + // structs, and call the relevant visitor function for further action. Returns + // the number of bytes consumed if the message is complete; returns nullopt + // otherwise. These functions can throw a fatal error if the message length + // is insufficient. + absl::optional<size_t> ProcessObject(absl::string_view data); + absl::optional<size_t> ProcessSetup(absl::string_view data); + absl::optional<size_t> ProcessSubscribeRequest(absl::string_view data); + absl::optional<size_t> ProcessSubscribeOk(absl::string_view data); + absl::optional<size_t> ProcessSubscribeError(absl::string_view data); + absl::optional<size_t> ProcessAnnounce(absl::string_view data); + absl::optional<size_t> ProcessAnnounceOk(absl::string_view data); + absl::optional<size_t> ProcessAnnounceError(absl::string_view data); + absl::optional<size_t> ProcessGoAway(absl::string_view data); + + // If the message length field is zero, it runs to the end of the stream. + bool NoMessageLength() { return *message_length_ == 0; } + // If type and or length are not already stored for this message, reads it out + // of the data in |reader| and stores it in the appropriate members. Returns + // false if length is not available. + bool GetMessageTypeAndLength(quic::QuicDataReader& reader); + void EndOfMessage(); + // Get a string_view of the part of the reader covered by message_length_, + // with exceptions for OBJECT messages. + absl::string_view FetchMessage(quic::QuicDataReader& reader); + void ParseError(absl::string_view reason); + + // Reads an integer whose length is specified by a preceding VarInt62 and + // returns it in |result|. Returns false if parsing fails. + bool ReadIntegerPieceVarInt62(quic::QuicDataReader& reader, uint64_t& result); + + MoqtParserVisitor& visitor_; + // Client or server? + quic::Perspective perspective_; + bool uses_web_transport_; + bool no_more_data_ = false; // Fatal error or end_of_stream. No more parsing. + bool parsing_error_ = false; + + std::string buffered_message_; + absl::optional<MoqtMessageType> message_type_ = absl::nullopt; + absl::optional<size_t> message_length_ = absl::nullopt; + + // Metadata for an object which is delivered in parts. + absl::optional<MoqtObject> object_metadata_ = absl::nullopt; + + bool processing_ = false; // True if currently in ProcessData(), to prevent + // re-entrancy. +}; + +} // namespace moqt + +#endif // QUICHE_QUIC_MOQT_MOQT_PARSER_H_
diff --git a/quiche/quic/moqt/moqt_parser_test.cc b/quiche/quic/moqt/moqt_parser_test.cc new file mode 100644 index 0000000..357be26 --- /dev/null +++ b/quiche/quic/moqt/moqt_parser_test.cc
@@ -0,0 +1,858 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/moqt/moqt_parser.h" + +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <memory> +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/test_tools/moqt_test_message.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace moqt::test { + +struct MoqtParserTestParams { + MoqtParserTestParams(MoqtMessageType message_type, + quic::Perspective perspective, bool uses_web_transport) + : message_type(message_type), + perspective(perspective), + uses_web_transport(uses_web_transport) {} + MoqtMessageType message_type; + quic::Perspective perspective; + bool uses_web_transport; +}; + +std::vector<MoqtParserTestParams> GetMoqtParserTestParams() { + std::vector<MoqtParserTestParams> params; + std::vector<MoqtMessageType> message_types = { + MoqtMessageType::kObject, MoqtMessageType::kSetup, + MoqtMessageType::kSubscribeRequest, MoqtMessageType::kSubscribeOk, + MoqtMessageType::kSubscribeError, MoqtMessageType::kAnnounce, + MoqtMessageType::kAnnounceOk, MoqtMessageType::kAnnounceError, + MoqtMessageType::kGoAway, + }; + std::vector<quic::Perspective> perspectives = { + quic::Perspective::IS_SERVER, + quic::Perspective::IS_CLIENT, + }; + std::vector<bool> uses_web_transport_bool = { + false, + true, + }; + for (const MoqtMessageType message_type : message_types) { + if (message_type == MoqtMessageType::kSetup) { + for (const quic::Perspective perspective : perspectives) { + for (const bool uses_web_transport : uses_web_transport_bool) { + params.push_back(MoqtParserTestParams(message_type, perspective, + uses_web_transport)); + } + } + } else { + // All other types are processed the same for either perspective or + // transport. + params.push_back(MoqtParserTestParams( + message_type, quic::Perspective::IS_SERVER, true)); + } + } + return params; +} + +std::string ParamNameFormatter( + const testing::TestParamInfo<MoqtParserTestParams>& info) { + return MoqtMessageTypeToString(info.param.message_type) + "_" + + (info.param.perspective == quic::Perspective::IS_SERVER ? "Server" + : "Client") + + "_" + (info.param.uses_web_transport ? "WebTransport" : "QUIC"); +} + +class MoqtParserTestVisitor : public MoqtParserVisitor { + public: + ~MoqtParserTestVisitor() = default; + + void OnObjectMessage(const MoqtObject& message, absl::string_view payload, + bool end_of_message) override { + object_payload_ = payload; + end_of_message_ = end_of_message; + messages_received_++; + last_message_ = TestMessageBase::MessageStructuredData(message); + } + void OnSetupMessage(const MoqtSetup& message) override { + end_of_message_ = true; + messages_received_++; + MoqtSetup setup = message; + if (setup.path.has_value()) { + string0_ = std::string(setup.path.value()); + setup.path = absl::string_view(string0_); + } + last_message_ = TestMessageBase::MessageStructuredData(setup); + } + void OnSubscribeRequestMessage(const MoqtSubscribeRequest& message) override { + end_of_message_ = true; + messages_received_++; + MoqtSubscribeRequest subscribe_request = message; + string0_ = std::string(subscribe_request.full_track_name); + subscribe_request.full_track_name = absl::string_view(string0_); + if (subscribe_request.authorization_info.has_value()) { + string1_ = std::string(subscribe_request.authorization_info.value()); + subscribe_request.authorization_info = absl::string_view(string1_); + } + last_message_ = TestMessageBase::MessageStructuredData(subscribe_request); + } + void OnSubscribeOkMessage(const MoqtSubscribeOk& message) override { + end_of_message_ = true; + messages_received_++; + MoqtSubscribeOk subscribe_ok = message; + string0_ = std::string(subscribe_ok.full_track_name); + subscribe_ok.full_track_name = absl::string_view(string0_); + last_message_ = TestMessageBase::MessageStructuredData(subscribe_ok); + } + void OnSubscribeErrorMessage(const MoqtSubscribeError& message) override { + end_of_message_ = true; + messages_received_++; + MoqtSubscribeError subscribe_error = message; + string0_ = std::string(subscribe_error.full_track_name); + subscribe_error.full_track_name = absl::string_view(string0_); + string1_ = std::string(subscribe_error.reason_phrase); + subscribe_error.reason_phrase = absl::string_view(string1_); + last_message_ = TestMessageBase::MessageStructuredData(subscribe_error); + } + void OnAnnounceMessage(const MoqtAnnounce& message) override { + end_of_message_ = true; + messages_received_++; + MoqtAnnounce announce = message; + string0_ = std::string(announce.track_namespace); + announce.track_namespace = absl::string_view(string0_); + if (announce.authorization_info.has_value()) { + string1_ = std::string(announce.authorization_info.value()); + announce.authorization_info = absl::string_view(string1_); + } + last_message_ = TestMessageBase::MessageStructuredData(announce); + } + void OnAnnounceOkMessage(const MoqtAnnounceOk& message) override { + end_of_message_ = true; + messages_received_++; + MoqtAnnounceOk announce_ok = message; + string0_ = std::string(announce_ok.track_namespace); + announce_ok.track_namespace = absl::string_view(string0_); + last_message_ = TestMessageBase::MessageStructuredData(announce_ok); + } + void OnAnnounceErrorMessage(const MoqtAnnounceError& message) override { + end_of_message_ = true; + messages_received_++; + MoqtAnnounceError announce_error = message; + string0_ = std::string(announce_error.track_namespace); + announce_error.track_namespace = absl::string_view(string0_); + string1_ = std::string(announce_error.reason_phrase); + announce_error.reason_phrase = absl::string_view(string1_); + last_message_ = TestMessageBase::MessageStructuredData(announce_error); + } + void OnGoAwayMessage() override { + got_goaway_ = true; + end_of_message_ = true; + messages_received_++; + last_message_ = TestMessageBase::MessageStructuredData(); + } + void OnParsingError(absl::string_view reason) override { + QUIC_LOG(INFO) << "Parsing error: " << reason; + parsing_error_ = reason; + } + + absl::optional<absl::string_view> object_payload_; + bool end_of_message_ = false; + bool got_goaway_ = false; + absl::optional<absl::string_view> parsing_error_; + uint64_t messages_received_ = 0; + absl::optional<TestMessageBase::MessageStructuredData> last_message_; + // Stored strings for last_message_. The visitor API does not promise the + // memory pointed to by string_views is persistent. + std::string string0_, string1_; +}; + +class MoqtParserTest + : public quic::test::QuicTestWithParam<MoqtParserTestParams> { + public: + MoqtParserTest() + : message_type_(GetParam().message_type), + is_client_(GetParam().perspective == quic::Perspective::IS_CLIENT), + webtrans_(GetParam().uses_web_transport), + parser_(GetParam().perspective, GetParam().uses_web_transport, + visitor_) {} + + std::unique_ptr<TestMessageBase> MakeMessage(MoqtMessageType message_type) { + switch (message_type) { + case MoqtMessageType::kObject: + return std::make_unique<ObjectMessage>(); + case MoqtMessageType::kSetup: + return std::make_unique<SetupMessage>(is_client_, webtrans_); + case MoqtMessageType::kSubscribeRequest: + return std::make_unique<SubscribeRequestMessage>(); + case MoqtMessageType::kSubscribeOk: + return std::make_unique<SubscribeOkMessage>(); + case MoqtMessageType::kSubscribeError: + return std::make_unique<SubscribeErrorMessage>(); + case MoqtMessageType::kAnnounce: + return std::make_unique<AnnounceMessage>(); + case moqt::MoqtMessageType::kAnnounceOk: + return std::make_unique<AnnounceOkMessage>(); + case moqt::MoqtMessageType::kAnnounceError: + return std::make_unique<AnnounceErrorMessage>(); + case moqt::MoqtMessageType::kGoAway: + return std::make_unique<GoAwayMessage>(); + default: + return nullptr; + } + } + + MoqtParserTestVisitor visitor_; + MoqtMessageType message_type_; + bool is_client_; + bool webtrans_; + MoqtParser parser_; +}; + +INSTANTIATE_TEST_SUITE_P(MoqtParserTests, MoqtParserTest, + testing::ValuesIn(GetMoqtParserTestParams()), + ParamNameFormatter); + +TEST_P(MoqtParserTest, OneMessage) { + std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); + parser_.ProcessData(message->PacketSample(), false); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.end_of_message_); + if (message_type_ == MoqtMessageType::kObject) { + // Check payload message. + EXPECT_TRUE(visitor_.object_payload_.has_value()); + EXPECT_EQ(*(visitor_.object_payload_), "foo"); + } +} + +TEST_P(MoqtParserTest, OneMessageWithLongVarints) { + std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); + message->ExpandVarints(); + parser_.ProcessData(message->PacketSample(), false); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.end_of_message_); + if (message_type_ == MoqtMessageType::kObject) { + // Check payload message. + EXPECT_EQ(visitor_.object_payload_, "foo"); + } + EXPECT_FALSE(visitor_.parsing_error_.has_value()); +} + +TEST_P(MoqtParserTest, MessageNoLengthWithFin) { + std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); + message->set_message_size(0); + parser_.ProcessData(message->PacketSample(), true); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.end_of_message_); + if (message_type_ == MoqtMessageType::kObject) { + // Check payload message. + EXPECT_TRUE(visitor_.object_payload_.has_value()); + EXPECT_EQ(*(visitor_.object_payload_), "foo"); + } + EXPECT_FALSE(visitor_.parsing_error_.has_value()); +} + +TEST_P(MoqtParserTest, MessageNoLengthSeparateFinObjectOrGoAway) { + // OBJECT and GOAWAY can return on a zero-length message even without + // receiving a FIN. + if (message_type_ != MoqtMessageType::kObject && + message_type_ != MoqtMessageType::kGoAway) { + return; + } + std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); + message->set_message_size(0); + parser_.ProcessData(message->PacketSample(), false); + EXPECT_EQ(visitor_.messages_received_, 1); + if (message_type_ == MoqtMessageType::kGoAway) { + EXPECT_TRUE(visitor_.got_goaway_); + EXPECT_TRUE(visitor_.end_of_message_); + return; + } + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + EXPECT_EQ(*(visitor_.object_payload_), "foo"); + EXPECT_FALSE(visitor_.end_of_message_); + + parser_.ProcessData(absl::string_view(), true); // send the FIN + EXPECT_EQ(visitor_.messages_received_, 2); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + EXPECT_EQ(*(visitor_.object_payload_), ""); + EXPECT_TRUE(visitor_.end_of_message_); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); +} + +TEST_P(MoqtParserTest, MessageNoLengthSeparateFinOtherTypes) { + if (message_type_ == MoqtMessageType::kObject || + message_type_ == MoqtMessageType::kGoAway) { + return; + } + std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); + message->set_message_size(0); + parser_.ProcessData(message->PacketSample(), false); + EXPECT_EQ(visitor_.messages_received_, 0); + parser_.ProcessData(absl::string_view(), true); // send the FIN + EXPECT_EQ(visitor_.messages_received_, 1); + + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.end_of_message_); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); +} + +TEST_P(MoqtParserTest, TwoPartMessage) { + std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); + // The test Object message has payload for less then half the message length, + // so splitting the message in half will prevent the first half from being + // processed. + size_t first_data_size = message->total_message_size() / 2; + parser_.ProcessData(message->PacketSample().substr(0, first_data_size), + false); + EXPECT_EQ(visitor_.messages_received_, 0); + parser_.ProcessData( + message->PacketSample().substr( + first_data_size, message->total_message_size() - first_data_size), + false); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.end_of_message_); + if (message_type_ == MoqtMessageType::kObject) { + EXPECT_EQ(visitor_.object_payload_, "foo"); + } + EXPECT_FALSE(visitor_.parsing_error_.has_value()); +} + +// Send the header + some payload, pure payload, then pure payload to end the +// message. +TEST_P(MoqtParserTest, ThreePartObject) { + if (message_type_ != MoqtMessageType::kObject) { + return; + } + std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); + message->set_message_size(0); + // The test Object message has payload for less then half the message length, + // so splitting the message in half will prevent the first half from being + // processed. + parser_.ProcessData(message->PacketSample(), false); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_FALSE(visitor_.end_of_message_); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + EXPECT_EQ(*(visitor_.object_payload_), "foo"); + + // second part + parser_.ProcessData("bar", false); + EXPECT_EQ(visitor_.messages_received_, 2); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_FALSE(visitor_.end_of_message_); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + EXPECT_EQ(*(visitor_.object_payload_), "bar"); + + // third part includes FIN + parser_.ProcessData("deadbeef", true); + EXPECT_EQ(visitor_.messages_received_, 3); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.end_of_message_); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + EXPECT_EQ(*(visitor_.object_payload_), "deadbeef"); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); +} + +// Send the part of header, rest of header + payload, plus payload. +TEST_P(MoqtParserTest, ThreePartObjectFirstIncomplete) { + if (message_type_ != MoqtMessageType::kObject) { + return; + } + std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); + message->set_message_size(0); + + // first part + parser_.ProcessData(message->PacketSample().substr(0, 4), false); + EXPECT_EQ(visitor_.messages_received_, 0); + + // second part. Add padding to it. + message->set_wire_image_size(100); + parser_.ProcessData( + message->PacketSample().substr(4, message->total_message_size() - 4), + false); + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_FALSE(visitor_.end_of_message_); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + EXPECT_EQ(visitor_.object_payload_->length(), 94); + + // third part includes FIN + parser_.ProcessData("bar", true); + EXPECT_EQ(visitor_.messages_received_, 2); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.end_of_message_); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + EXPECT_EQ(*(visitor_.object_payload_), "bar"); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); +} + +TEST_P(MoqtParserTest, OneByteAtATime) { + std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); + message->set_message_size(0); + constexpr size_t kObjectPrePayloadSize = 6; + for (size_t i = 0; i < message->total_message_size(); ++i) { + parser_.ProcessData(message->PacketSample().substr(i, 1), false); + if (message_type_ == MoqtMessageType::kGoAway && + i == message->total_message_size() - 1) { + // OnGoAway() is called before FIN. + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.end_of_message_); + break; + } + if (message_type_ != MoqtMessageType::kObject || + i < kObjectPrePayloadSize) { + // OBJECTs will have to buffer for the first 5 bytes (until the varints + // are done). The sixth byte is a bare OBJECT header, so the parser does + // not notify the visitor. + EXPECT_EQ(visitor_.messages_received_, 0); + } else { + // OBJECT payload processing. + EXPECT_EQ(visitor_.messages_received_, i - kObjectPrePayloadSize + 1); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + if (i == 5) { + EXPECT_EQ(visitor_.object_payload_->length(), 0); + } else { + EXPECT_EQ(visitor_.object_payload_->length(), 1); + EXPECT_EQ((*visitor_.object_payload_)[0], + message->PacketSample().substr(i, 1)[0]); + } + } + EXPECT_FALSE(visitor_.end_of_message_); + } + // Send FIN + parser_.ProcessData(absl::string_view(), true); + if (message_type_ == MoqtMessageType::kObject) { + EXPECT_EQ(visitor_.messages_received_, + message->total_message_size() - kObjectPrePayloadSize + 1); + } else { + EXPECT_EQ(visitor_.messages_received_, 1); + } + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.end_of_message_); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); +} + +TEST_P(MoqtParserTest, OneByteAtATimeLongerVarints) { + std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); + message->ExpandVarints(); + message->set_message_size(0); + constexpr size_t kObjectPrePayloadSize = 28; + for (size_t i = 0; i < message->total_message_size(); ++i) { + parser_.ProcessData(message->PacketSample().substr(i, 1), false); + if (message_type_ == MoqtMessageType::kGoAway && + i == message->total_message_size() - 1) { + // OnGoAway() is called before FIN. + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.end_of_message_); + break; + } + if (message_type_ != MoqtMessageType::kObject || + i < kObjectPrePayloadSize) { + // OBJECTs will have to buffer for the first 5 bytes (until the varints + // are done). The sixth byte is a bare OBJECT header, so the parser does + // not notify the visitor. + EXPECT_EQ(visitor_.messages_received_, 0); + } else { + // OBJECT payload processing. + EXPECT_EQ(visitor_.messages_received_, i - kObjectPrePayloadSize + 1); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + if (i == 5) { + EXPECT_EQ(visitor_.object_payload_->length(), 0); + } else { + EXPECT_EQ(visitor_.object_payload_->length(), 1); + EXPECT_EQ((*visitor_.object_payload_)[0], + message->PacketSample().substr(i, 1)[0]); + } + } + EXPECT_FALSE(visitor_.end_of_message_); + } + // Send FIN + parser_.ProcessData(absl::string_view(), true); + if (message_type_ == MoqtMessageType::kObject) { + EXPECT_EQ(visitor_.messages_received_, + message->total_message_size() - kObjectPrePayloadSize + 1); + } else { + EXPECT_EQ(visitor_.messages_received_, 1); + } + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.end_of_message_); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); +} + +TEST_P(MoqtParserTest, OneByteAtATimeKnownLength) { + std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); + constexpr size_t kObjectPrePayloadSize = 6; + // Send all but the last byte + for (size_t i = 0; i < message->total_message_size() - 1; ++i) { + parser_.ProcessData(message->PacketSample().substr(i, 1), false); + if (message_type_ != MoqtMessageType::kObject || + i < kObjectPrePayloadSize) { + // OBJECTs will have to buffer for the first 5 bytes (until the varints + // are done). The sixth byte is a bare OBJECT header, so the parser does + // not notify the visitor. + EXPECT_EQ(visitor_.messages_received_, 0); + } else { + // OBJECT payload processing. + EXPECT_EQ(visitor_.messages_received_, i - kObjectPrePayloadSize + 1); + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.object_payload_.has_value()); + if (i == 5) { + EXPECT_EQ(visitor_.object_payload_->length(), 0); + } else { + EXPECT_EQ(visitor_.object_payload_->length(), 1); + EXPECT_EQ((*visitor_.object_payload_)[0], + message->PacketSample().substr(i, 1)[0]); + } + } + EXPECT_FALSE(visitor_.end_of_message_); + } + // Send last byte + parser_.ProcessData( + message->PacketSample().substr(message->total_message_size() - 1, 1), + false); + if (message_type_ == MoqtMessageType::kObject) { + EXPECT_EQ(visitor_.messages_received_, + message->total_message_size() - kObjectPrePayloadSize); + EXPECT_EQ(visitor_.object_payload_->length(), 1); + EXPECT_EQ((*visitor_.object_payload_)[0], + message->PacketSample().substr(message->total_message_size() - 1, + 1)[0]); + } else { + EXPECT_EQ(visitor_.messages_received_, 1); + } + EXPECT_TRUE(message->EqualFieldValues(visitor_.last_message_.value())); + EXPECT_TRUE(visitor_.end_of_message_); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); +} + +TEST_P(MoqtParserTest, LengthTooShort) { + if (message_type_ == MoqtMessageType::kGoAway || + message_type_ == MoqtMessageType::kAnnounceOk) { + // GOAWAY already has length zero. ANNOUNCE_OK works for any message length. + return; + } + auto message = MakeMessage(message_type_); + if (message_type_ == MoqtMessageType::kSetup && + GetParam().perspective == quic::Perspective::IS_CLIENT) { + // Unless varints are longer than necessary, the message is only one byte + // long. + message->ExpandVarints(); + } + size_t truncate = (message_type_ == MoqtMessageType::kObject) ? 4 : 1; + message->set_message_size(message->message_size() - truncate); + parser_.ProcessData(message->PacketSample(), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "Not able to parse message given specified length"); +} + +// Buffered packets are a different code path, so test them separately. +TEST_P(MoqtParserTest, LengthTooShortInBufferedPacket) { + if (message_type_ == MoqtMessageType::kGoAway || + message_type_ == MoqtMessageType::kAnnounceOk) { + // GOAWAY already has length zero. ANNOUNCE_OK works for any message length. + return; + } + auto message = MakeMessage(message_type_); + if (message_type_ == MoqtMessageType::kSetup && + GetParam().perspective == quic::Perspective::IS_CLIENT) { + // Unless varints are longer than necessary, the message is only one byte + // long. + message->ExpandVarints(); + } + EXPECT_EQ(visitor_.messages_received_, 0); + size_t truncate = (message_type_ == MoqtMessageType::kObject) ? 5 : 2; + message->set_message_size(message->message_size() - truncate + 1); + parser_.ProcessData( + message->PacketSample().substr(0, message->total_message_size() - 1), + false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); + // send the last byte + parser_.ProcessData( + message->PacketSample().substr(message->total_message_size() - 1, 1), + false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "Not able to parse buffered message given specified length"); +} + +TEST_P(MoqtParserTest, LengthTooLong) { + if (message_type_ == MoqtMessageType::kAnnounceOk || + message_type_ == MoqtMessageType::kObject || + message_type_ == MoqtMessageType::kSetup || + message_type_ == MoqtMessageType::kSubscribeRequest || + message_type_ == MoqtMessageType::kAnnounce) { + // OBJECT and ANNOUNCE_OK work for any message length. + // SETUP, SUBSCRIBE_REQUEST, and ANNOUNCE have parameters, so an additional + // byte will cause the message to be interpreted as being too short. + return; + } + auto message = MakeMessage(message_type_); + message->set_message_size(message->message_size() + 1); + parser_.ProcessData(message->PacketSample(), false); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(visitor_.messages_received_, 0); + if (message_type_ == MoqtMessageType::kGoAway) { + EXPECT_EQ(*visitor_.parsing_error_, "GOAWAY has data following"); + } else { + EXPECT_EQ(*visitor_.parsing_error_, "Specified message length too long"); + } +} + +TEST_P(MoqtParserTest, LengthExceedsBufferSize) { + if (message_type_ == MoqtMessageType::kObject) { + // OBJECT works for any length. + return; + } + auto message = MakeMessage(message_type_); + message->set_message_size(kMaxMessageHeaderSize + 1); + parser_.ProcessData(message->PacketSample(), false); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(visitor_.messages_received_, 0); + if (message_type_ == MoqtMessageType::kGoAway) { + EXPECT_EQ(*visitor_.parsing_error_, "GOAWAY has data following"); + } else { + EXPECT_EQ(*visitor_.parsing_error_, "Message too long"); + } +} + +// Tests for message-specific error cases. +class MoqtParserErrorTest : public quic::test::QuicTest { + public: + MoqtParserErrorTest() {} + + MoqtParserTestVisitor visitor_; + + static constexpr bool kWebTrans = true; + static constexpr bool kRawQuic = false; +}; + +TEST_F(MoqtParserErrorTest, SetupRoleAppearsTwice) { + MoqtParser parser(quic::Perspective::IS_SERVER, kRawQuic, visitor_); + char setup[] = { + 0x01, 0x0e, 0x02, 0x01, 0x02, // versions + 0x00, 0x01, 0x03, // role = both + 0x00, 0x01, 0x03, // role = both + 0x01, 0x03, 0x66, 0x6f, 0x6f // path = "foo" + }; + parser.ProcessData(absl::string_view(setup, sizeof(setup)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, "ROLE parameter appears twice in SETUP"); +} + +TEST_F(MoqtParserErrorTest, SetupRoleFromServer) { + MoqtParser parser(quic::Perspective::IS_CLIENT, kWebTrans, visitor_); + char setup[] = { + 0x01, 0x04, + 0x01, // version = 1 + 0x00, 0x01, 0x03, // role = both + }; + parser.ProcessData(absl::string_view(setup, sizeof(setup)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, "ROLE parameter sent by server in SETUP"); +} + +TEST_F(MoqtParserErrorTest, SetupRoleIsMissing) { + MoqtParser parser(quic::Perspective::IS_SERVER, kRawQuic, visitor_); + char setup[] = { + 0x01, 0x08, 0x02, 0x01, 0x02, // versions = 1, 2 + 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" + }; + parser.ProcessData(absl::string_view(setup, sizeof(setup)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "ROLE SETUP parameter missing from Client message"); +} + +TEST_F(MoqtParserErrorTest, SetupPathFromServer) { + MoqtParser parser(quic::Perspective::IS_CLIENT, kRawQuic, visitor_); + char setup[] = { + 0x01, 0x06, + 0x01, // version = 1 + 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" + }; + parser.ProcessData(absl::string_view(setup, sizeof(setup)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, "PATH parameter sent by server in SETUP"); +} + +TEST_F(MoqtParserErrorTest, SetupPathAppearsTwice) { + MoqtParser parser(quic::Perspective::IS_SERVER, kRawQuic, visitor_); + char setup[] = { + 0x01, 0x10, 0x02, 0x01, 0x02, // versions = 1, 2 + 0x00, 0x01, 0x03, // role = both + 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" + 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" + }; + parser.ProcessData(absl::string_view(setup, sizeof(setup)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, "PATH parameter appears twice in SETUP"); +} + +TEST_F(MoqtParserErrorTest, SetupPathOverWebtrans) { + MoqtParser parser(quic::Perspective::IS_SERVER, kWebTrans, visitor_); + char setup[] = { + 0x01, 0x0b, 0x02, 0x01, 0x02, // versions = 1, 2 + 0x00, 0x01, 0x03, // role = both + 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" + }; + parser.ProcessData(absl::string_view(setup, sizeof(setup)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "WebTransport connection is using PATH parameter in SETUP"); +} + +TEST_F(MoqtParserErrorTest, SetupPathMissing) { + MoqtParser parser(quic::Perspective::IS_SERVER, kRawQuic, visitor_); + char setup[] = { + 0x01, 0x06, 0x02, 0x01, 0x02, // versions = 1, 2 + 0x00, 0x01, 0x03, // role = both + }; + parser.ProcessData(absl::string_view(setup, sizeof(setup)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "PATH SETUP parameter missing from Client message over QUIC"); +} + +TEST_F(MoqtParserErrorTest, SetupRoleTooLong) { + MoqtParser parser(quic::Perspective::IS_SERVER, kRawQuic, visitor_); + char setup[] = { + 0x01, 0x0e, 0x02, 0x01, 0x02, // versions + // role = both + 0x00, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x01, + 0x03, 0x66, 0x6f, 0x6f // path = "foo" + }; + parser.ProcessData(absl::string_view(setup, sizeof(setup)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "Cannot parse explicit length integers longer than 8 bytes"); +} + +TEST_F(MoqtParserErrorTest, SubscribeRequestGroupSequenceTwice) { + MoqtParser parser(quic::Perspective::IS_SERVER, kWebTrans, visitor_); + char subscribe_request[] = { + 0x03, 0x12, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" + 0x00, 0x01, 0x01, // group_sequence = 1 + 0x00, 0x01, 0x01, // group_sequence = 1 + 0x01, 0x01, 0x02, // object_sequence = 2 + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + parser.ProcessData( + absl::string_view(subscribe_request, sizeof(subscribe_request)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "GROUP_SEQUENCE parameter appears twice in SUBSCRIBE_REQUEST"); +} + +TEST_F(MoqtParserErrorTest, SubscribeRequestObjectSequenceTwice) { + MoqtParser parser(quic::Perspective::IS_SERVER, kWebTrans, visitor_); + char subscribe_request[] = { + 0x03, 0x12, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" + 0x00, 0x01, 0x01, // group_sequence = 1 + 0x01, 0x01, 0x02, // object_sequence = 2 + 0x01, 0x01, 0x02, // object_sequence = 2 + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + parser.ProcessData( + absl::string_view(subscribe_request, sizeof(subscribe_request)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "OBJECT_SEQUENCE parameter appears twice in SUBSCRIBE_REQUEST"); +} + +TEST_F(MoqtParserErrorTest, SubscribeRequestAuthorizationInfoTwice) { + MoqtParser parser(quic::Perspective::IS_SERVER, kWebTrans, visitor_); + char subscribe_request[] = { + 0x03, 0x14, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" + 0x00, 0x01, 0x01, // group_sequence = 1 + 0x01, 0x01, 0x02, // object_sequence = 2 + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + parser.ProcessData( + absl::string_view(subscribe_request, sizeof(subscribe_request)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "AUTHORIZATION_INFO parameter appears twice in SUBSCRIBE_REQUEST"); +} + +TEST_F(MoqtParserErrorTest, AnnounceGroupSequenceTwice) { + MoqtParser parser(quic::Perspective::IS_SERVER, kWebTrans, visitor_); + char announce[] = { + 0x06, 0x0f, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + 0x00, 0x01, 0x01, // group_sequence = 1 + 0x00, 0x01, 0x01, // group_sequence = 1 + }; + parser.ProcessData(absl::string_view(announce, sizeof(announce)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "GROUP_SEQUENCE parameter appears twice in ANNOUNCE"); +} + +TEST_F(MoqtParserErrorTest, AnnounceObjectSequenceTwice) { + MoqtParser parser(quic::Perspective::IS_SERVER, kWebTrans, visitor_); + char announce[] = { + 0x06, 0x0e, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x01, 0x01, 0x02, // object_sequence = 2 + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + 0x01, 0x01, 0x02, // object_sequence = 2 + }; + parser.ProcessData(absl::string_view(announce, sizeof(announce)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "OBJECT_SEQUENCE parameter appears twice in ANNOUNCE"); +} + +TEST_F(MoqtParserErrorTest, AnnounceAuthorizationInfoTwice) { + MoqtParser parser(quic::Perspective::IS_SERVER, kWebTrans, visitor_); + char announce[] = { + 0x06, 0x0e, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + parser.ProcessData(absl::string_view(announce, sizeof(announce)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "AUTHORIZATION_INFO parameter appears twice in ANNOUNCE"); +} + +} // namespace moqt::test
diff --git a/quiche/quic/moqt/test_tools/moqt_test_message.h b/quiche/quic/moqt/test_tools/moqt_test_message.h new file mode 100644 index 0000000..aa4c8c0 --- /dev/null +++ b/quiche/quic/moqt/test_tools/moqt_test_message.h
@@ -0,0 +1,526 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_TEST_MESSAGE_H_ +#define QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_TEST_MESSAGE_H_ + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_endian.h" + +namespace moqt::test { + +// Base class containing a wire image and the corresponding structured +// representation of an example of each message. It allows parser and framer +// tests to iterate through all message types without much specialized code. +class QUICHE_NO_EXPORT TestMessageBase { + public: + TestMessageBase(MoqtMessageType message_type) : message_type_(message_type) {} + virtual ~TestMessageBase() = default; + MoqtMessageType message_type() const { return message_type_; } + + typedef absl::variant<MoqtSetup, MoqtObject, MoqtSubscribeRequest, + MoqtSubscribeOk, MoqtSubscribeError, MoqtAnnounce, + MoqtAnnounceOk, MoqtAnnounceError, MoqtGoAway> + MessageStructuredData; + + // The total actual size of the message. + size_t total_message_size() const { return wire_image_size_; } + + // The message size indicated in the second varint in every message. + size_t message_size() const { + quic::QuicDataReader reader(PacketSample()); + uint64_t value; + if (!reader.ReadVarInt62(&value)) { + return 0; + } + if (!reader.ReadVarInt62(&value)) { + return 0; + } + return value; + } + + absl::string_view PacketSample() const { + return absl::string_view(wire_image_, wire_image_size_); + } + + void set_wire_image_size(size_t wire_image_size) { + wire_image_size_ = wire_image_size; + } + + // Sets the message length field. If |message_size| == 0, just change the + // field in the wire image. If another value, this will either truncate the + // message or increase its length (which adds uninitialized bytes). This can + // be useful for playing with different Object Payload lengths, for example. + void set_message_size(uint64_t message_size) { + char new_wire_image[sizeof(wire_image_)]; + quic::QuicDataReader reader(PacketSample()); + quic::QuicDataWriter writer(sizeof(new_wire_image), new_wire_image); + uint64_t type; + auto field_size = reader.PeekVarInt62Length(); + reader.ReadVarInt62(&type); + writer.WriteVarInt62WithForcedLength( + type, std::max(field_size, writer.GetVarInt62Len(type))); + uint64_t original_length; + field_size = reader.PeekVarInt62Length(); + reader.ReadVarInt62(&original_length); + // Try to preserve the original field length, unless it's too small. + writer.WriteVarInt62WithForcedLength( + message_size, + std::max(field_size, writer.GetVarInt62Len(message_size))); + writer.WriteStringPiece(reader.PeekRemainingPayload()); + memcpy(wire_image_, new_wire_image, writer.length()); + wire_image_size_ = writer.length(); + if (message_size > original_length) { + wire_image_size_ += (message_size - original_length); + } + if (message_size > 0 && message_size < original_length) { + wire_image_size_ -= (original_length - message_size); + } + } + + // Compares |values| to the derived class's structured data to make sure + // they are equal. + virtual bool EqualFieldValues(MessageStructuredData& values) const = 0; + + // Expand all varints in the message. This is pure virtual because each + // message has a different layout of varints. + virtual void ExpandVarints() = 0; + + protected: + void SetWireImage(uint8_t* wire_image, size_t wire_image_size) { + memcpy(wire_image_, wire_image, wire_image_size); + wire_image_size_ = wire_image_size; + } + + // Expands all the varints in the message, alternating between making them 2, + // 4, and 8 bytes long. Updates length fields accordingly. + // Each character in |varints| corresponds to a byte in the original message. + // If there is a 'v', it is a varint that should be expanded. If '-', skip + // to the next byte. + void ExpandVarintsImpl(absl::string_view varints) { + int next_varint_len = 2; + char new_wire_image[kMaxMessageHeaderSize + 1]; + quic::QuicDataReader reader( + absl::string_view(wire_image_, wire_image_size_)); + quic::QuicDataWriter writer(sizeof(new_wire_image), new_wire_image); + size_t message_length = 0; + int item = 0; + size_t i = 0; + while (!reader.IsDoneReading()) { + if (i >= varints.length() || varints[i++] == '-') { + uint8_t byte; + reader.ReadUInt8(&byte); + writer.WriteUInt8(byte); + continue; + } + uint64_t value; + item++; + reader.ReadVarInt62(&value); + writer.WriteVarInt62WithForcedLength( + value, static_cast<quiche::QuicheVariableLengthIntegerLength>( + next_varint_len)); + if (item == 2) { + // this is the message length field. + message_length = value; + } + next_varint_len *= 2; + if (next_varint_len == 16) { + next_varint_len = 2; + } + } + if (message_length > 0) { + // Update message length. Based on the progression of next_varint_len, + // the message_type is 2 bytes and message_length is 4 bytes. + message_length = writer.length() - 6; + auto new_writer = quic::QuicDataWriter(4, (char*)&new_wire_image[2]); + new_writer.WriteVarInt62WithForcedLength( + message_length, + static_cast<quiche::QuicheVariableLengthIntegerLength>(4)); + } + memcpy(wire_image_, new_wire_image, writer.length()); + wire_image_size_ = writer.length(); + } + + private: + MoqtMessageType message_type_; + char wire_image_[kMaxMessageHeaderSize + 20]; + size_t wire_image_size_; +}; + +class QUICHE_NO_EXPORT ObjectMessage : public TestMessageBase { + public: + ObjectMessage() : TestMessageBase(MoqtMessageType::kObject) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtObject>(values); + if (cast.track_id != object_.track_id) { + QUIC_LOG(INFO) << "OBJECT Track ID mismatch"; + return false; + } + if (cast.group_sequence != object_.group_sequence) { + QUIC_LOG(INFO) << "OBJECT Group Sequence mismatch"; + return false; + } + if (cast.object_sequence != object_.object_sequence) { + QUIC_LOG(INFO) << "OBJECT Object Sequence mismatch"; + return false; + } + if (cast.object_send_order != object_.object_send_order) { + QUIC_LOG(INFO) << "OBJECT Object Send Order mismatch"; + return false; + } + return true; + } + + void ExpandVarints() override { + ExpandVarintsImpl("vvvvvv"); // first six fields are varints + } + + private: + uint8_t raw_packet_[9] = { + 0x00, 0x07, 0x04, 0x05, 0x06, 0x07, // varints + 0x66, 0x6f, 0x6f, // payload = "foo" + }; + MoqtObject object_ = { + /*track_id=*/4, + /*group_sequence=*/5, + /*object_sequence=*/6, + /*object_send_order=*/7, + }; +}; + +class QUICHE_NO_EXPORT SetupMessage : public TestMessageBase { + public: + explicit SetupMessage(bool client_parser, bool webtrans) + : TestMessageBase(MoqtMessageType::kSetup), client_(client_parser) { + if (client_parser) { + SetWireImage(server_raw_packet_, sizeof(server_raw_packet_)); + } else { + SetWireImage(client_raw_packet_, sizeof(client_raw_packet_)); + if (webtrans) { + // Should not send PATH. + set_message_size(message_size() - 5); + client_setup_.path = absl::nullopt; + } + } + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtSetup>(values); + const MoqtSetup* compare = client_ ? &server_setup_ : &client_setup_; + if (cast.number_of_supported_versions != + compare->number_of_supported_versions) { + QUIC_LOG(INFO) << "SETUP number of supported versions mismatch"; + return false; + } + for (uint64_t i = 0; i < cast.number_of_supported_versions; ++i) { + // Listed versions are 1 and 2, in that order. + if (cast.supported_versions[i] != compare->supported_versions[i]) { + QUIC_LOG(INFO) << "SETUP supported version mismatch"; + return false; + } + } + if (cast.role != compare->role) { + QUIC_LOG(INFO) << "SETUP role mismatch"; + return false; + } + if (cast.path != compare->path) { + QUIC_LOG(INFO) << "SETUP path mismatch"; + return false; + } + return true; + } + + void ExpandVarints() override { + if (client_) { + ExpandVarintsImpl("vvvvvvv-vv---"); // skip one byte for Role value + } else { + ExpandVarintsImpl("vvv"); // all three are varints + } + } + + private: + bool client_; + uint8_t client_raw_packet_[13] = { + 0x01, 0x0b, 0x02, 0x01, 0x02, // versions + 0x00, 0x01, 0x03, // role = both + 0x01, 0x03, 0x66, 0x6f, 0x6f // path = "foo" + }; + uint8_t server_raw_packet_[3] = { + 0x01, 0x01, + 0x01, // version + }; + MoqtSetup client_setup_ = { + /*number_of_supported_versions=*/2, + /*supported_versions=*/std::vector<uint64_t>({1, 2}), + /*role=*/MoqtRole::kBoth, + /*path=*/"foo", + }; + MoqtSetup server_setup_ = { + /*number_of_supported_versions=*/1, + /*supported_versions=*/std::vector<uint64_t>({1}), + /*role=*/absl::nullopt, + /*path=*/absl::nullopt, + }; +}; + +class QUICHE_NO_EXPORT SubscribeRequestMessage : public TestMessageBase { + public: + SubscribeRequestMessage() + : TestMessageBase(MoqtMessageType::kSubscribeRequest) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtSubscribeRequest>(values); + if (cast.full_track_name != subscribe_request_.full_track_name) { + QUIC_LOG(INFO) << "SUBSCRIBE REQUEST full track name mismatch"; + return false; + } + if (cast.group_sequence != subscribe_request_.group_sequence) { + QUIC_LOG(INFO) << "SUBSCRIBE REQUEST group sequence mismatch"; + return false; + } + if (cast.object_sequence != subscribe_request_.object_sequence) { + QUIC_LOG(INFO) << "SUBSCRIBE REQUEST object sequence mismatch"; + return false; + } + if (cast.authorization_info != subscribe_request_.authorization_info) { + QUIC_LOG(INFO) << "SUBSCRIBE REQUEST authorization info mismatch"; + return false; + } + return true; + } + + void ExpandVarints() override { ExpandVarintsImpl("vvv---vv-vv-vv"); } + + private: + uint8_t raw_packet_[17] = { + 0x03, 0x0f, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" + 0x00, 0x01, 0x01, // group_sequence = 1 + 0x01, 0x01, 0x02, // object_sequence = 2 + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + + MoqtSubscribeRequest subscribe_request_ = { + /*full_track_name=*/"foo", + /*group_sequence=*/1, + /*object_sequence=*/2, + /*authorization_info=*/"bar", + }; +}; + +class QUICHE_NO_EXPORT SubscribeOkMessage : public TestMessageBase { + public: + SubscribeOkMessage() : TestMessageBase(MoqtMessageType::kSubscribeOk) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtSubscribeOk>(values); + if (cast.full_track_name != subscribe_ok_.full_track_name) { + return false; + } + if (cast.track_id != subscribe_ok_.track_id) { + return false; + } + if (cast.expires != subscribe_ok_.expires) { + return false; + } + return true; + } + + void ExpandVarints() override { ExpandVarintsImpl("vvv---vv"); } + + private: + uint8_t raw_packet_[8] = { + 0x04, 0x06, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" + 0x01, // track_id = 1 + 0x02, // expires = 2 + }; + + MoqtSubscribeOk subscribe_ok_ = { + /*full_track_name=*/"foo", + /*track_id=*/1, + /*expires=*/quic::QuicTimeDelta::FromMilliseconds(2), + }; +}; + +class QUICHE_NO_EXPORT SubscribeErrorMessage : public TestMessageBase { + public: + SubscribeErrorMessage() : TestMessageBase(MoqtMessageType::kSubscribeError) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtSubscribeError>(values); + if (cast.full_track_name != subscribe_error_.full_track_name) { + QUIC_LOG(INFO) << "SUBSCRIBE ERROR full track name mismatch"; + return false; + } + if (cast.error_code != subscribe_error_.error_code) { + QUIC_LOG(INFO) << "SUBSCRIBE ERROR error code mismatch"; + return false; + } + if (cast.reason_phrase != subscribe_error_.reason_phrase) { + QUIC_LOG(INFO) << "SUBSCRIBE ERROR reason phrase mismatch"; + return false; + } + return true; + } + + void ExpandVarints() override { ExpandVarintsImpl("vvv---vv---"); } + + private: + uint8_t raw_packet_[11] = { + 0x05, 0x09, 0x03, 0x66, 0x6f, 0x6f, // track_name = "foo" + 0x01, // error_code = 1 + 0x03, 0x62, 0x61, 0x72, // reason_phrase = "bar" + }; + + MoqtSubscribeError subscribe_error_ = { + /*full_track_name=*/"foo", + /*subscribe=*/1, + /*reason_phrase=*/"bar", + }; +}; + +class QUICHE_NO_EXPORT AnnounceMessage : public TestMessageBase { + public: + AnnounceMessage() : TestMessageBase(MoqtMessageType::kAnnounce) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtAnnounce>(values); + if (cast.track_namespace != announce_.track_namespace) { + QUIC_LOG(INFO) << "ANNOUNCE MESSAGE track namespace mismatch"; + return false; + } + if (cast.authorization_info != announce_.authorization_info) { + QUIC_LOG(INFO) << "ANNOUNCE MESSAGE authorization info mismatch"; + return false; + } + return true; + } + + void ExpandVarints() override { ExpandVarintsImpl("vvv---vv---"); } + + private: + uint8_t raw_packet_[11] = { + 0x06, 0x09, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar" + }; + + MoqtAnnounce announce_ = { + /*track_namespace=*/"foo", + /*authorization_info=*/"bar", + }; +}; + +class QUICHE_NO_EXPORT AnnounceOkMessage : public TestMessageBase { + public: + AnnounceOkMessage() : TestMessageBase(MoqtMessageType::kAnnounceOk) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtAnnounceOk>(values); + if (cast.track_namespace != announce_ok_.track_namespace) { + QUIC_LOG(INFO) << "ANNOUNCE OK MESSAGE track namespace mismatch"; + return false; + } + return true; + } + + void ExpandVarints() override { ExpandVarintsImpl("vv---"); } + + private: + uint8_t raw_packet_[5] = { + 0x07, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + }; + + MoqtAnnounceOk announce_ok_ = { + /*track_namespace=*/"foo", + }; +}; + +class QUICHE_NO_EXPORT AnnounceErrorMessage : public TestMessageBase { + public: + AnnounceErrorMessage() : TestMessageBase(MoqtMessageType::kAnnounceError) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtAnnounceError>(values); + if (cast.track_namespace != announce_error_.track_namespace) { + QUIC_LOG(INFO) << "ANNOUNCE ERROR track namespace mismatch"; + return false; + } + if (cast.error_code != announce_error_.error_code) { + QUIC_LOG(INFO) << "ANNOUNCE ERROR error code mismatch"; + return false; + } + if (cast.reason_phrase != announce_error_.reason_phrase) { + QUIC_LOG(INFO) << "ANNOUNCE ERROR reason phrase mismatch"; + return false; + } + return true; + } + + void ExpandVarints() override { ExpandVarintsImpl("vvv---vv---"); } + + private: + uint8_t raw_packet_[11] = { + 0x08, 0x09, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" + 0x01, // error__code = 1 + 0x03, 0x62, 0x61, 0x72, // reason_phrase = "bar" + }; + + MoqtAnnounceError announce_error_ = { + /*track_namespace=*/"foo", + /*error_code=*/1, + /*reason_phrase=*/"bar", + }; +}; + +class QUICHE_NO_EXPORT GoAwayMessage : public TestMessageBase { + public: + GoAwayMessage() : TestMessageBase(MoqtMessageType::kGoAway) { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& /*values*/) const override { + return true; + } + + void ExpandVarints() override { ExpandVarintsImpl("vv"); } + + private: + uint8_t raw_packet_[2] = { + 0x10, + 0x00, + }; +}; + +} // namespace moqt::test + +#endif // QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_TEST_MESSAGE_H_