Refactor how MOQT control messages are parsed, part 1. Instead of returning zero or length parsed, all message parse functions now return absl::Status. This allows for more detailed error messages. Also delete the obsolete test for MAX_CACHE_DURATION message parameter, and make location parser more strict. PiperOrigin-RevId: 881983396
diff --git a/quiche/quic/moqt/moqt_error.cc b/quiche/quic/moqt/moqt_error.cc index 4e46e5f..3466337 100644 --- a/quiche/quic/moqt/moqt_error.cc +++ b/quiche/quic/moqt/moqt_error.cc
@@ -4,13 +4,23 @@ #include "quiche/quic/moqt/moqt_error.h" +#include <cstring> +#include <optional> + +#include "absl/base/casts.h" #include "absl/status/status.h" +#include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "quiche/common/platform/api/quiche_logging.h" #include "quiche/web_transport/web_transport.h" namespace moqt { +namespace { +constexpr absl::string_view kMoqtErrorStatusPayloadUrl = + "quiche.googlesource.com/MoqtError"; +} + RequestErrorCode StatusToRequestErrorCode(absl::Status status) { QUICHE_DCHECK(!status.ok()); switch (status.code()) { @@ -84,4 +94,33 @@ } } +std::optional<MoqtError> GetMoqtErrorForStatus(const absl::Status& status) { + std::optional<absl::Cord> raw_code_cord = + status.GetPayload(kMoqtErrorStatusPayloadUrl); + if (!raw_code_cord.has_value()) { + return std::nullopt; + } + absl::string_view raw_code = raw_code_cord->Flatten(); + if (raw_code.size() != sizeof(MoqtError)) { + QUICHE_LOG(DFATAL) << "MoqtError is incorrect size"; + return std::nullopt; + } + MoqtError error; + memcpy(&error, raw_code.data(), sizeof(MoqtError)); + return error; +} + +void SetMoqtErrorForStatus(absl::Status& status, MoqtError error) { + char buffer[sizeof(error)]; + memcpy(buffer, &error, sizeof(error)); + status.SetPayload(kMoqtErrorStatusPayloadUrl, + absl::Cord(absl::string_view(buffer, sizeof(buffer)))); +} + +absl::Status MoqtErrorStatusWithCode(absl::string_view data, MoqtError error) { + absl::Status status = absl::InvalidArgumentError(data); + SetMoqtErrorForStatus(status, error); + return status; +} + } // namespace moqt
diff --git a/quiche/quic/moqt/moqt_error.h b/quiche/quic/moqt/moqt_error.h index 7542832..2b08442 100644 --- a/quiche/quic/moqt/moqt_error.h +++ b/quiche/quic/moqt/moqt_error.h
@@ -42,6 +42,11 @@ kMalformedAuthority = 0x1a, }; +// Utility functions to attach MoqtError to an absl::Status. +std::optional<MoqtError> GetMoqtErrorForStatus(const absl::Status& status); +void SetMoqtErrorForStatus(absl::Status& status, MoqtError error); +absl::Status MoqtErrorStatusWithCode(absl::string_view data, MoqtError error); + // Error codes used by MoQT to reset streams. inline constexpr webtransport::StreamErrorCode kResetCodeInternalError = 0x00; inline constexpr webtransport::StreamErrorCode kResetCodeCancelled = 0x01;
diff --git a/quiche/quic/moqt/moqt_key_value_pair.h b/quiche/quic/moqt/moqt_key_value_pair.h index b0fc0f6..ca038af 100644 --- a/quiche/quic/moqt/moqt_key_value_pair.h +++ b/quiche/quic/moqt/moqt_key_value_pair.h
@@ -13,6 +13,7 @@ #include <vector> #include "absl/container/btree_map.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "quiche/quic/core/quic_time.h" #include "quiche/quic/moqt/moqt_error.h" @@ -184,7 +185,7 @@ // Defined in moqt_parser.cc. // If the class is not initialized with the default constructor, it is likely // to return an error if a non-default field duplicates what is in |list|. - MoqtError FromKeyValuePairList(const KeyValuePairList& list); + absl::Status FromKeyValuePairList(const KeyValuePairList& list); }; enum class MessageParameter : uint64_t { @@ -242,7 +243,7 @@ // Defined in moqt_parser.cc. // If the class is not initialized with the default constructor, it is likely // to return an error if a non-default field duplicates what is in |list|. - MoqtError FromKeyValuePairList(const KeyValuePairList& list); + absl::Status FromKeyValuePairList(const KeyValuePairList& list); private: // "if (forward)" is bug-prone because it returns forward_.has_value(). Make
diff --git a/quiche/quic/moqt/moqt_key_value_pair_test.cc b/quiche/quic/moqt/moqt_key_value_pair_test.cc index 135ba0f..56f54bb 100644 --- a/quiche/quic/moqt/moqt_key_value_pair_test.cc +++ b/quiche/quic/moqt/moqt_key_value_pair_test.cc
@@ -7,14 +7,18 @@ #include <cstdint> #include <optional> +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "quiche/quic/core/quic_time.h" #include "quiche/quic/moqt/moqt_error.h" #include "quiche/quic/moqt/moqt_priority.h" #include "quiche/quic/platform/api/quic_test.h" +#include "quiche/common/test_tools/quiche_test_utils.h" namespace moqt::test { +using ::quiche::test::StatusIs; + class LocationTest : public quic::test::QuicTest {}; TEST_F(LocationTest, LocationTests) { @@ -137,7 +141,7 @@ list.insert(static_cast<uint64_t>(MessageParameter::kOackWindowSize), 12345678ULL); MessageParameters parameters; - parameters.FromKeyValuePairList(list); + QUICHE_EXPECT_OK(parameters.FromKeyValuePairList(list)); EXPECT_EQ(parameters.delivery_timeout, quic::QuicTimeDelta::FromMilliseconds(1)); EXPECT_FALSE(parameters.forward()); @@ -149,30 +153,30 @@ KeyValuePairList list; MessageParameters parameters; list.insert(static_cast<uint64_t>(MessageParameter::kDeliveryTimeout), 0ULL); - EXPECT_EQ(parameters.FromKeyValuePairList(list), - MoqtError::kProtocolViolation); + EXPECT_THAT(parameters.FromKeyValuePairList(list), + StatusIs(absl::StatusCode::kInvalidArgument)); list.clear(); list.insert(static_cast<uint64_t>(MessageParameter::kForward), 2ULL); - EXPECT_EQ(parameters.FromKeyValuePairList(list), - MoqtError::kProtocolViolation); + EXPECT_THAT(parameters.FromKeyValuePairList(list), + StatusIs(absl::StatusCode::kInvalidArgument)); list.clear(); list.insert(static_cast<uint64_t>(MessageParameter::kSubscriberPriority), 256ULL); - EXPECT_EQ(parameters.FromKeyValuePairList(list), - MoqtError::kProtocolViolation); + EXPECT_THAT(parameters.FromKeyValuePairList(list), + StatusIs(absl::StatusCode::kInvalidArgument)); list.clear(); list.insert(static_cast<uint64_t>(MessageParameter::kGroupOrder), 0ULL); - EXPECT_EQ(parameters.FromKeyValuePairList(list), - MoqtError::kProtocolViolation); + EXPECT_THAT(parameters.FromKeyValuePairList(list), + StatusIs(absl::StatusCode::kInvalidArgument)); list.clear(); list.insert(static_cast<uint64_t>(MessageParameter::kGroupOrder), 3ULL); - EXPECT_EQ(parameters.FromKeyValuePairList(list), - MoqtError::kProtocolViolation); + EXPECT_THAT(parameters.FromKeyValuePairList(list), + StatusIs(absl::StatusCode::kInvalidArgument)); // Unknown MessageParameter. list.clear(); list.insert(0x12345678, 12345678ULL); - EXPECT_EQ(parameters.FromKeyValuePairList(list), - MoqtError::kProtocolViolation); + EXPECT_THAT(parameters.FromKeyValuePairList(list), + StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_F(MessageParametersTest, DuplicateParameters) { @@ -224,8 +228,8 @@ break; } } - EXPECT_EQ(parameters.FromKeyValuePairList(list), - MoqtError::kProtocolViolation); + EXPECT_THAT(parameters.FromKeyValuePairList(list), + StatusIs(absl::StatusCode::kInvalidArgument)); } }
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc index 3a874e1..8b52e6e 100644 --- a/quiche/quic/moqt/moqt_parser.cc +++ b/quiche/quic/moqt/moqt_parser.cc
@@ -18,6 +18,7 @@ #include "absl/base/casts.h" #include "absl/cleanup/cleanup.h" #include "absl/container/fixed_array.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -34,6 +35,7 @@ #include "quiche/common/platform/api/quiche_bug_tracker.h" #include "quiche/common/platform/api/quiche_logging.h" #include "quiche/common/quiche_data_reader.h" +#include "quiche/common/quiche_status_utils.h" #include "quiche/web_transport/web_transport.h" namespace moqt { @@ -47,6 +49,19 @@ return value >> 1; } +absl::Status KeyValueFormatError(absl::string_view message) { + return MoqtErrorStatusWithCode(message, MoqtError::kKeyValueFormattingError); +} + +absl::Status CheckForTrailingData(const quic::QuicDataReader& reader) { + if (!reader.IsDoneReading()) { + return absl::InvalidArgumentError( + absl::StrCat("Control message has excess data of ", + reader.BytesRemaining(), " bytes at the end")); + } + return absl::OkStatus(); +} + // |fin_read| is set to true if there is a FIN anywhere before the end of the // varint. std::optional<uint64_t> ReadVarInt62FromStream(webtransport::Stream& stream, @@ -88,66 +103,73 @@ } // Reads from |reader| to list. Returns false if there is a read error. -bool ParseKeyValuePairList(quic::QuicDataReader& reader, - KeyValuePairList& list) { +absl::Status ParseKeyValuePairList(quic::QuicDataReader& reader, + KeyValuePairList& list) { list.clear(); uint64_t num_params; if (!reader.ReadVarInt62(&num_params)) { - return false; + return absl::InvalidArgumentError( + "Unable to parse key-value pair list element count"); } uint64_t type = 0; for (uint64_t i = 0; i < num_params; ++i) { uint64_t type_diff; if (!reader.ReadVarInt62(&type_diff)) { - return false; + return absl::InvalidArgumentError( + "Unable to parse the key in a key-value pair"); } type += type_diff; if (type % 2 == 1) { absl::string_view bytes; if (!reader.ReadStringPieceVarInt62(&bytes)) { - return false; + return absl::InvalidArgumentError( + "Unable to read the string value in a key-value pair"); } list.insert(type, bytes); continue; } uint64_t value; if (!reader.ReadVarInt62(&value)) { - return false; + return absl::InvalidArgumentError( + "Unable to read the integer value in a key-value pair"); } list.insert(type, value); } - return true; + return absl::OkStatus(); } -bool ParseKeyValuePairListWithNoPrefix(quic::QuicDataReader& reader, - KeyValuePairList& list) { +absl::Status ParseKeyValuePairListWithNoPrefix(quic::QuicDataReader& reader, + KeyValuePairList& list) { list.clear(); uint64_t type = 0; while (reader.BytesRemaining() > 0) { uint64_t type_diff; if (!reader.ReadVarInt62(&type_diff)) { - return false; + return absl::InvalidArgumentError( + "Unable to parse the key in a key-value pair"); } type += type_diff; if (type % 2 == 1) { absl::string_view bytes; if (!reader.ReadStringPieceVarInt62(&bytes)) { - return false; + return absl::InvalidArgumentError( + "Unable to read the string value in a key-value pair"); } list.insert(type, bytes); continue; } uint64_t value; if (!reader.ReadVarInt62(&value)) { - return false; + return absl::InvalidArgumentError( + "Unable to read the integer value in a key-value pair"); } list.insert(type, value); } - return true; + return absl::OkStatus(); } -MoqtError ParseAuthTokenParameter(absl::string_view field, - std::vector<AuthToken>& out) { +bool ParseAuthTokenParameter(absl::string_view field, + std::vector<AuthToken>& out) { quic::QuicDataReader reader(field); AuthTokenAliasType alias_type; uint64_t alias; @@ -155,14 +177,14 @@ absl::string_view token; uint64_t value; if (!reader.ReadVarInt62(&value)) { - return MoqtError::kKeyValueFormattingError; + return false; } alias_type = static_cast<AuthTokenAliasType>(value); switch (alias_type) { case AuthTokenAliasType::kUseValue: if (!reader.ReadVarInt62(&value) || value > AuthTokenType::kMaxAuthTokenType) { - return MoqtError::kKeyValueFormattingError; + return false; } type = static_cast<AuthTokenType>(value); token = reader.PeekRemainingPayload(); @@ -170,13 +192,13 @@ break; case AuthTokenAliasType::kUseAlias: if (!reader.ReadVarInt62(&value)) { - return MoqtError::kKeyValueFormattingError; + return false; } out.push_back(AuthToken(value, alias_type)); break; case AuthTokenAliasType::kRegister: if (!reader.ReadVarInt62(&alias) || !reader.ReadVarInt62(&value)) { - return MoqtError::kKeyValueFormattingError; + return false; } type = static_cast<AuthTokenType>(value); token = reader.PeekRemainingPayload(); @@ -184,30 +206,28 @@ break; case AuthTokenAliasType::kDelete: if (!reader.ReadVarInt62(&alias)) { - return MoqtError::kKeyValueFormattingError; + return false; } out.push_back(AuthToken(alias, alias_type)); break; default: // invalid alias type - return MoqtError::kKeyValueFormattingError; + return false; } - return MoqtError::kNoError; + return true; } -MoqtError ParseLocation(absl::string_view field, Location& out) { +bool ParseLocation(absl::string_view field, Location& out) { quic::QuicDataReader reader(field); - if (!reader.ReadVarInt62(&out.group) || !reader.ReadVarInt62(&out.object)) { - return MoqtError::kKeyValueFormattingError; - } - return MoqtError::kNoError; + return reader.ReadVarInt62(&out.group) && reader.ReadVarInt62(&out.object) && + reader.IsDoneReading(); } -MoqtError ParseSubscriptionFilter(absl::string_view field, - std::optional<SubscriptionFilter>& out) { +absl::Status ParseSubscriptionFilter(absl::string_view field, + std::optional<SubscriptionFilter>& out) { quic::QuicDataReader reader(field); uint64_t value; if (!reader.ReadVarInt62(&value)) { - return MoqtError::kKeyValueFormattingError; + return KeyValueFormatError("Unable to read subscription filter type"); } uint64_t group, object; switch (static_cast<MoqtFilterType>(value)) { @@ -217,76 +237,84 @@ break; case MoqtFilterType::kAbsoluteStart: if (!reader.ReadVarInt62(&group) || !reader.ReadVarInt62(&object)) { - return MoqtError::kKeyValueFormattingError; + return KeyValueFormatError("Invalid AbsoluteStart filter"); } out.emplace(Location(group, object)); break; case MoqtFilterType::kAbsoluteRange: if (!reader.ReadVarInt62(&group) || !reader.ReadVarInt62(&object) || !reader.ReadVarInt62(&value)) { - return MoqtError::kKeyValueFormattingError; + return KeyValueFormatError("Invalid AbsoluteRange filter"); } if (value < group) { // end before start - return MoqtError::kProtocolViolation; + return absl::InvalidArgumentError( + "AbsoluteRange filter specified with a start after the end"); } out.emplace(Location(group, object), value); break; default: // invalid filter type - return MoqtError::kProtocolViolation; + return absl::InvalidArgumentError("Invalid filter type"); } - return MoqtError::kNoError; + return absl::OkStatus(); } } // namespace -MoqtError SetupParameters::FromKeyValuePairList(const KeyValuePairList& list) { - MoqtError error = MoqtError::kNoError; - // If this callback returns false without explicitly setting an error, then - // the error is a kProtocolViolation. +absl::Status SetupParameters::FromKeyValuePairList( + const KeyValuePairList& list) { + absl::Status status = absl::OkStatus(); + uint64_t last_key; bool result = list.ForEach( [&](uint64_t key, std::variant<uint64_t, absl::string_view> value) { + last_key = key; switch (static_cast<SetupParameter>(key)) { case SetupParameter::kMaxRequestId: if (max_request_id.has_value()) { + status = absl::InvalidArgumentError("Duplicate Setup Parameter"); return false; } max_request_id = std::get<uint64_t>(value); break; case SetupParameter::kMaxAuthTokenCacheSize: if (max_auth_token_cache_size.has_value()) { + status = absl::InvalidArgumentError("Duplicate Setup Parameter"); return false; } max_auth_token_cache_size = std::get<uint64_t>(value); break; case SetupParameter::kPath: if (path.has_value()) { + status = absl::InvalidArgumentError("Duplicate Setup Parameter"); return false; } if (!http2::adapter::HeaderValidator::IsValidPath( std::get<absl::string_view>(value), /*allow_fragment=*/false)) { - error = MoqtError::kMalformedPath; + status = MoqtErrorStatusWithCode("Malformed path", + MoqtError::kMalformedPath); return false; } path = std::get<absl::string_view>(value); break; case SetupParameter::kAuthorizationToken: - error = ParseAuthTokenParameter(std::get<absl::string_view>(value), - authorization_tokens); - if (error != MoqtError::kNoError) { + if (!ParseAuthTokenParameter(std::get<absl::string_view>(value), + authorization_tokens)) { + status = KeyValueFormatError("Malformed auth token parameter"); return false; } break; case SetupParameter::kAuthority: if (!http2::adapter::HeaderValidator::IsValidAuthority( std::get<absl::string_view>(value))) { - error = MoqtError::kMalformedAuthority; + status = MoqtErrorStatusWithCode("Invalid authority field", + MoqtError::kMalformedAuthority); return false; } authority = std::get<absl::string_view>(value); break; case SetupParameter::kMoqtImplementation: if (moqt_implementation.has_value()) { + status = absl::InvalidArgumentError("Duplicate Setup Parameter"); return false; } QUICHE_LOG(INFO) << "Peer MOQT implementation: " @@ -295,10 +323,12 @@ break; case SetupParameter::kSupportObjectAcks: if (support_object_acks.has_value()) { + status = absl::InvalidArgumentError("Duplicate Setup Parameter"); return false; } if (std::get<uint64_t>(value) > 1) { - error = MoqtError::kKeyValueFormattingError; + status = + KeyValueFormatError("SUPPORT_OBJECT_ACKS has to be 0 or 1"); return false; } support_object_acks = (std::get<uint64_t>(value) == 1); @@ -308,114 +338,147 @@ } return true; }); - if (!result && error == MoqtError::kNoError) { - return MoqtError::kProtocolViolation; + if (!result && status.ok()) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse the value for the setup parameter key 0x", + absl::Hex(static_cast<uint64_t>(last_key)))); } - return error; + return status; } -MoqtError MessageParameters::FromKeyValuePairList( +absl::Status MessageParameters::FromKeyValuePairList( const KeyValuePairList& list) { - MoqtError error = MoqtError::kNoError; - bool error_occurred = !list.ForEach( - [&](uint64_t key, std::variant<uint64_t, absl::string_view> value) { - switch (static_cast<MessageParameter>(key)) { - case MessageParameter::kDeliveryTimeout: - if (delivery_timeout.has_value() || - std::get<uint64_t>(value) == 0) { - return false; - } - delivery_timeout = quic::QuicTimeDelta::TryFromMilliseconds( - std::get<uint64_t>(value)) - .value_or(quic::QuicTimeDelta::Infinite()); - break; - case MessageParameter::kAuthorizationToken: - error = ParseAuthTokenParameter(std::get<absl::string_view>(value), - authorization_tokens); - if (error != MoqtError::kNoError) { - return false; - } - break; - case MessageParameter::kExpires: - if (expires.has_value()) { - return false; - } - expires = quic::QuicTimeDelta::TryFromMilliseconds( - std::get<uint64_t>(value)) - .value_or(quic::QuicTimeDelta::Infinite()); - if (expires->IsZero()) { - expires = quic::QuicTimeDelta::Infinite(); - } - break; - case MessageParameter::kLargestObject: - if (largest_object.has_value()) { - return false; - } - largest_object = Location(); - error = ParseLocation(std::get<absl::string_view>(value), - *largest_object); - if (error != MoqtError::kNoError) { - return false; - } - break; - case MessageParameter::kForward: - if (forward_has_value() || std::get<uint64_t>(value) > 1) { - return false; - } - set_forward(std::get<uint64_t>(value) != 0); - break; - case MessageParameter::kSubscriberPriority: - if (subscriber_priority.has_value() || - std::get<uint64_t>(value) > kMaxPriority) { - return false; - } - subscriber_priority = - static_cast<MoqtPriority>(std::get<uint64_t>(value)); - break; - case MessageParameter::kSubscriptionFilter: - if (subscription_filter.has_value()) { - // TODO(martinduke): Support multiple subscription filters. - return false; - } - error = ParseSubscriptionFilter(std::get<absl::string_view>(value), - subscription_filter); - if (error != MoqtError::kNoError) { - return false; - } - break; - case MessageParameter::kGroupOrder: - if (group_order.has_value() || - std::get<uint64_t>(value) > kMaxMoqtDeliveryOrder || - std::get<uint64_t>(value) < kMinMoqtDeliveryOrder) { - return false; - } - group_order = - static_cast<MoqtDeliveryOrder>(std::get<uint64_t>(value)); - break; - case MessageParameter::kNewGroupRequest: - if (new_group_request.has_value()) { - return false; - } - new_group_request = std::get<uint64_t>(value); - break; - case MessageParameter::kOackWindowSize: - if (oack_window_size.has_value()) { - return false; - } - oack_window_size = quic::QuicTimeDelta::FromMicroseconds( - std::get<uint64_t>(value)); - break; - default: - // Unknown MessageParameters not allowed! - return false; + absl::Status status = absl::OkStatus(); + uint64_t last_key; + bool result = list.ForEach([&](uint64_t key, + std::variant<uint64_t, absl::string_view> + value) { + last_key = key; + switch (static_cast<MessageParameter>(key)) { + case MessageParameter::kDeliveryTimeout: + if (delivery_timeout.has_value()) { + status = absl::InvalidArgumentError("Duplicate Message Parameter"); + return false; } - return true; - }); - if (error_occurred && error == MoqtError::kNoError) { - // Illegal duplicate parameter. - return MoqtError::kProtocolViolation; + if (std::get<uint64_t>(value) == 0) { + status = absl::InvalidArgumentError("DELIVERY_TIMEOUT cannot be 0"); + return false; + } + delivery_timeout = + quic::QuicTimeDelta::TryFromMilliseconds(std::get<uint64_t>(value)) + .value_or(quic::QuicTimeDelta::Infinite()); + break; + case MessageParameter::kAuthorizationToken: + if (!ParseAuthTokenParameter(std::get<absl::string_view>(value), + authorization_tokens)) { + status = KeyValueFormatError("Malformed auth token parameter"); + return false; + } + break; + case MessageParameter::kExpires: + if (expires.has_value()) { + status = absl::InvalidArgumentError("Duplicate Message Parameter"); + return false; + } + expires = + quic::QuicTimeDelta::TryFromMilliseconds(std::get<uint64_t>(value)) + .value_or(quic::QuicTimeDelta::Infinite()); + if (expires->IsZero()) { + expires = quic::QuicTimeDelta::Infinite(); + } + break; + case MessageParameter::kLargestObject: + if (largest_object.has_value()) { + status = absl::InvalidArgumentError("Duplicate Message Parameter"); + return false; + } + largest_object = Location(); + if (!ParseLocation(std::get<absl::string_view>(value), + *largest_object)) { + status = KeyValueFormatError( + "Failed to parse location of the largest object"); + return false; + } + break; + case MessageParameter::kForward: + if (forward_has_value()) { + status = absl::InvalidArgumentError("Duplicate Message Parameter"); + return false; + } + if (std::get<uint64_t>(value) > 1) { + status = absl::InvalidArgumentError("FORWARD must be 0 or 1"); + return false; + } + set_forward(std::get<uint64_t>(value) != 0); + break; + case MessageParameter::kSubscriberPriority: + if (subscriber_priority.has_value()) { + status = absl::InvalidArgumentError("Duplicate Message Parameter"); + return false; + } + if (std::get<uint64_t>(value) > kMaxPriority) { + status = + absl::InvalidArgumentError("Subscriber priority exceeds maximum"); + return false; + } + subscriber_priority = + static_cast<MoqtPriority>(std::get<uint64_t>(value)); + break; + case MessageParameter::kSubscriptionFilter: + if (subscription_filter.has_value()) { + status = absl::InvalidArgumentError("Duplicate Message Parameter"); + // TODO(martinduke): Support multiple subscription filters. + return false; + } + status = ParseSubscriptionFilter(std::get<absl::string_view>(value), + subscription_filter); + if (!status.ok()) { + return false; + } + break; + case MessageParameter::kGroupOrder: + if (group_order.has_value()) { + status = absl::InvalidArgumentError("Duplicate Message Parameter"); + return false; + } + if (std::get<uint64_t>(value) > kMaxMoqtDeliveryOrder || + std::get<uint64_t>(value) < kMinMoqtDeliveryOrder) { + status = absl::InvalidArgumentError( + "GROUP_ORDER is outside the valid range"); + return false; + } + group_order = static_cast<MoqtDeliveryOrder>(std::get<uint64_t>(value)); + break; + case MessageParameter::kNewGroupRequest: + if (new_group_request.has_value()) { + status = absl::InvalidArgumentError("Duplicate Message Parameter"); + return false; + } + new_group_request = std::get<uint64_t>(value); + break; + case MessageParameter::kOackWindowSize: + if (oack_window_size.has_value()) { + status = absl::InvalidArgumentError("Duplicate Message Parameter"); + return false; + } + oack_window_size = + quic::QuicTimeDelta::FromMicroseconds(std::get<uint64_t>(value)); + break; + default: + // Unknown MessageParameters not allowed! + status = absl::InvalidArgumentError( + absl::StrCat("Unknown message parameter 0x", + absl::Hex(static_cast<uint64_t>(key)))); + return false; + } + return true; + }); + if (!result && status.ok()) { + return absl::InvalidArgumentError(absl::StrCat( + "Failed to parse the value for the message parameter key 0x", + absl::Hex(static_cast<uint64_t>(last_key)))); } - return error; + return status; } bool MoqtMessageTypeParser::ReadUntilMessageTypeKnown() { @@ -508,172 +571,162 @@ } } -size_t MoqtControlParser::ProcessMessage(absl::string_view data, - MoqtMessageType message_type) { - quic::QuicDataReader reader(data); - size_t bytes_read; +void MoqtControlParser::ProcessMessage(absl::string_view data, + MoqtMessageType message_type) { + absl::Status status; switch (message_type) { case MoqtMessageType::kClientSetup: - bytes_read = ProcessClientSetup(reader); + status = ProcessClientSetup(data); break; case MoqtMessageType::kServerSetup: - bytes_read = ProcessServerSetup(reader); + status = ProcessServerSetup(data); break; case MoqtMessageType::kRequestOk: - bytes_read = ProcessRequestOk(reader); + status = ProcessRequestOk(data); break; case MoqtMessageType::kRequestError: - bytes_read = ProcessRequestError(reader); + status = ProcessRequestError(data); break; case MoqtMessageType::kSubscribe: - bytes_read = ProcessSubscribe(reader); + status = ProcessSubscribe(data); break; case MoqtMessageType::kSubscribeOk: - bytes_read = ProcessSubscribeOk(reader); + status = ProcessSubscribeOk(data); break; case MoqtMessageType::kUnsubscribe: - bytes_read = ProcessUnsubscribe(reader); + status = ProcessUnsubscribe(data); break; case MoqtMessageType::kPublishDone: - bytes_read = ProcessPublishDone(reader); + status = ProcessPublishDone(data); break; case MoqtMessageType::kRequestUpdate: - bytes_read = ProcessRequestUpdate(reader); + status = ProcessRequestUpdate(data); break; case MoqtMessageType::kPublishNamespace: - bytes_read = ProcessPublishNamespace(reader); + status = ProcessPublishNamespace(data); break; case MoqtMessageType::kPublishNamespaceDone: - bytes_read = ProcessPublishNamespaceDone(reader); + status = ProcessPublishNamespaceDone(data); break; case MoqtMessageType::kNamespace: - bytes_read = ProcessNamespace(reader); + status = ProcessNamespace(data); break; case MoqtMessageType::kNamespaceDone: - bytes_read = ProcessNamespaceDone(reader); + status = ProcessNamespaceDone(data); break; case MoqtMessageType::kPublishNamespaceCancel: - bytes_read = ProcessPublishNamespaceCancel(reader); + status = ProcessPublishNamespaceCancel(data); break; case MoqtMessageType::kTrackStatus: - bytes_read = ProcessTrackStatus(reader); + status = ProcessTrackStatus(data); break; case MoqtMessageType::kGoAway: - bytes_read = ProcessGoAway(reader); + status = ProcessGoAway(data); break; case MoqtMessageType::kSubscribeNamespace: - bytes_read = ProcessSubscribeNamespace(reader); + status = ProcessSubscribeNamespace(data); break; case MoqtMessageType::kMaxRequestId: - bytes_read = ProcessMaxRequestId(reader); + status = ProcessMaxRequestId(data); break; case MoqtMessageType::kFetch: - bytes_read = ProcessFetch(reader); + status = ProcessFetch(data); break; case MoqtMessageType::kFetchCancel: - bytes_read = ProcessFetchCancel(reader); + status = ProcessFetchCancel(data); break; case MoqtMessageType::kFetchOk: - bytes_read = ProcessFetchOk(reader); + status = ProcessFetchOk(data); break; case MoqtMessageType::kRequestsBlocked: - bytes_read = ProcessRequestsBlocked(reader); + status = ProcessRequestsBlocked(data); break; case MoqtMessageType::kPublish: - bytes_read = ProcessPublish(reader); + status = ProcessPublish(data); break; case MoqtMessageType::kPublishOk: - bytes_read = ProcessPublishOk(reader); + status = ProcessPublishOk(data); break; case moqt::MoqtMessageType::kObjectAck: - bytes_read = ProcessObjectAck(reader); + status = ProcessObjectAck(data); break; default: - ParseError("Unknown message type"); - bytes_read = 0; - break; + ParseError(absl::InvalidArgumentError( + absl::StrCat("Unknown control message type 0x", + absl::Hex(static_cast<uint64_t>(message_type))))); + return; } - if (bytes_read != data.size() || bytes_read == 0) { - ParseError("Message length does not match payload length"); - return 0; + if (!status.ok()) { + ParseError( + quiche::AppendToStatus(status, " while parsing a message of type 0x", + absl::Hex(static_cast<uint64_t>(message_type)))); } - return bytes_read; } -size_t MoqtControlParser::ProcessClientSetup(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessClientSetup(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtClientSetup setup; KeyValuePairList parameters; - if (!ParseKeyValuePairList(reader, parameters)) { - return 0; - } - if (!FillAndValidateSetupParameters(parameters, setup.parameters, - MoqtMessageType::kClientSetup)) { - return 0; - } + QUICHE_RETURN_IF_ERROR(ParseKeyValuePairList(reader, parameters)); + QUICHE_RETURN_IF_ERROR(FillAndValidateSetupParameters( + parameters, setup.parameters, MoqtMessageType::kClientSetup)); // TODO(martinduke): Validate construction of the PATH (Sec 8.3.2.1) visitor_.OnClientSetupMessage(setup); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessServerSetup(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessServerSetup(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtServerSetup setup; KeyValuePairList parameters; - if (!ParseKeyValuePairList(reader, parameters)) { - return 0; - } - if (!FillAndValidateSetupParameters(parameters, setup.parameters, - MoqtMessageType::kServerSetup)) { - return 0; - } + QUICHE_RETURN_IF_ERROR(ParseKeyValuePairList(reader, parameters)); + QUICHE_RETURN_IF_ERROR(FillAndValidateSetupParameters( + parameters, setup.parameters, MoqtMessageType::kServerSetup)); visitor_.OnServerSetupMessage(setup); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessSubscribe(quic::QuicDataReader& reader, - MoqtMessageType message_type) { +absl::Status MoqtControlParser::ProcessSubscribe(absl::string_view data, + MoqtMessageType message_type) { + quic::QuicDataReader reader(data); MoqtSubscribe subscribe; - if (!reader.ReadVarInt62(&subscribe.request_id) || - !ReadFullTrackName(reader, subscribe.full_track_name)) { - return 0; + if (!reader.ReadVarInt62(&subscribe.request_id)) { + return absl::InvalidArgumentError("Failed to read request ID"); } - if (!FillAndValidateMessageParameters(reader, subscribe.parameters)) { - return 0; - } + QUICHE_RETURN_IF_ERROR(ReadFullTrackName(reader, subscribe.full_track_name)); + QUICHE_RETURN_IF_ERROR( + FillAndValidateMessageParameters(reader, subscribe.parameters)); if (message_type == MoqtMessageType::kTrackStatus) { visitor_.OnTrackStatusMessage(subscribe); } else { visitor_.OnSubscribeMessage(subscribe); } - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessSubscribeOk(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessSubscribeOk(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtSubscribeOk subscribe_ok; - if (!reader.ReadVarInt62(&subscribe_ok.request_id) || - !reader.ReadVarInt62(&subscribe_ok.track_alias)) { - return 0; + if (!reader.ReadVarInt62(&subscribe_ok.request_id)) { + return absl::InvalidArgumentError("Failed to read the request ID"); + } + if (!reader.ReadVarInt62(&subscribe_ok.track_alias)) { + return absl::InvalidArgumentError("Failed to read the track alias"); } KeyValuePairList pairs; - if (!ParseKeyValuePairList(reader, pairs)) { - return 0; - } - MoqtError error = subscribe_ok.parameters.FromKeyValuePairList(pairs); - if (error != MoqtError::kNoError) { - ParseError(error, "Failed to parse SUBSCRIBE_OK message parameters"); - return 0; - } - if (!ParseKeyValuePairListWithNoPrefix(reader, subscribe_ok.extensions)) { - return 0; - } + QUICHE_RETURN_IF_ERROR(ParseKeyValuePairList(reader, pairs)); + QUICHE_RETURN_IF_ERROR(subscribe_ok.parameters.FromKeyValuePairList(pairs)); + QUICHE_RETURN_IF_ERROR( + ParseKeyValuePairListWithNoPrefix(reader, subscribe_ok.extensions)); if (!subscribe_ok.extensions.Validate()) { - ParseError("Invalid SUBSCRIBE_OK track extensions"); - return 0; + return absl::InvalidArgumentError("Invalid SUBSCRIBE_OK track extensions"); } visitor_.OnSubscribeOkMessage(subscribe_ok); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessRequestError(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessRequestError(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtRequestError request_error; uint64_t error_code; uint64_t raw_interval; @@ -681,7 +734,7 @@ !reader.ReadVarInt62(&error_code) || !reader.ReadVarInt62(&raw_interval) || !reader.ReadStringVarInt62(request_error.reason_phrase)) { - return 0; + return absl::InvalidArgumentError("Message missing fields"); } request_error.error_code = static_cast<RequestErrorCode>(error_code); request_error.retry_interval = @@ -690,164 +743,173 @@ : std::make_optional( quic::QuicTimeDelta::FromMilliseconds(raw_interval - 1)); visitor_.OnRequestErrorMessage(request_error); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessUnsubscribe(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessUnsubscribe(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtUnsubscribe unsubscribe; if (!reader.ReadVarInt62(&unsubscribe.request_id)) { - return 0; + return absl::InvalidArgumentError("Message missing fields"); } visitor_.OnUnsubscribeMessage(unsubscribe); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessPublishDone(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessPublishDone(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtPublishDone publish_done; uint64_t value; if (!reader.ReadVarInt62(&publish_done.request_id) || !reader.ReadVarInt62(&value) || !reader.ReadVarInt62(&publish_done.stream_count) || !reader.ReadStringVarInt62(publish_done.error_reason)) { - return 0; + return absl::InvalidArgumentError("Message missing fields"); } publish_done.status_code = static_cast<PublishDoneCode>(value); visitor_.OnPublishDoneMessage(publish_done); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessRequestUpdate(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessRequestUpdate(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtRequestUpdate request_update; if (!reader.ReadVarInt62(&request_update.request_id) || !reader.ReadVarInt62(&request_update.existing_request_id)) { - return 0; + return absl::InvalidArgumentError("Message missing request IDs"); } - if (!FillAndValidateMessageParameters(reader, request_update.parameters)) { - return 0; - } + QUICHE_RETURN_IF_ERROR( + FillAndValidateMessageParameters(reader, request_update.parameters)); visitor_.OnRequestUpdateMessage(request_update); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessPublishNamespace( - quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessPublishNamespace( + absl::string_view data) { + quic::QuicDataReader reader(data); MoqtPublishNamespace publish_namespace; - if (!reader.ReadVarInt62(&publish_namespace.request_id) || - !ReadTrackNamespace(reader, publish_namespace.track_namespace)) { - return 0; + if (!reader.ReadVarInt62(&publish_namespace.request_id)) { + return absl::InvalidArgumentError("Request ID missing"); } - if (!FillAndValidateMessageParameters(reader, publish_namespace.parameters)) { - return 0; - } + QUICHE_RETURN_IF_ERROR( + ReadTrackNamespace(reader, publish_namespace.track_namespace)); + QUICHE_RETURN_IF_ERROR( + FillAndValidateMessageParameters(reader, publish_namespace.parameters)); visitor_.OnPublishNamespaceMessage(publish_namespace); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessNamespace(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessNamespace(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtNamespace _namespace; - if (!ReadTrackNamespace(reader, _namespace.track_namespace_suffix)) { - return 0; - } + QUICHE_RETURN_IF_ERROR( + ReadTrackNamespace(reader, _namespace.track_namespace_suffix)); visitor_.OnNamespaceMessage(_namespace); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessNamespaceDone(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessNamespaceDone(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtNamespaceDone namespace_done; - if (!ReadTrackNamespace(reader, namespace_done.track_namespace_suffix)) { - return 0; - } + QUICHE_RETURN_IF_ERROR( + ReadTrackNamespace(reader, namespace_done.track_namespace_suffix)); visitor_.OnNamespaceDoneMessage(namespace_done); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessRequestOk(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessRequestOk(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtRequestOk request_ok; if (!reader.ReadVarInt62(&request_ok.request_id)) { - return 0; + return absl::InvalidArgumentError("Request ID missing"); } - if (!FillAndValidateMessageParameters(reader, request_ok.parameters)) { - return 0; - } + QUICHE_RETURN_IF_ERROR( + FillAndValidateMessageParameters(reader, request_ok.parameters)); visitor_.OnRequestOkMessage(request_ok); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessPublishNamespaceDone( - quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessPublishNamespaceDone( + absl::string_view data) { + quic::QuicDataReader reader(data); MoqtPublishNamespaceDone pn_done; if (!reader.ReadVarInt62(&pn_done.request_id)) { - return 0; + return absl::InvalidArgumentError("Request ID missing"); } visitor_.OnPublishNamespaceDoneMessage(pn_done); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessPublishNamespaceCancel( - quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessPublishNamespaceCancel( + absl::string_view data) { + quic::QuicDataReader reader(data); MoqtPublishNamespaceCancel publish_namespace_cancel; uint64_t error_code; if (!reader.ReadVarInt62(&publish_namespace_cancel.request_id) || !reader.ReadVarInt62(&error_code) || !reader.ReadStringVarInt62(publish_namespace_cancel.error_reason)) { - return 0; + return absl::InvalidArgumentError("Message missing fields"); } publish_namespace_cancel.error_code = static_cast<RequestErrorCode>(error_code); visitor_.OnPublishNamespaceCancelMessage(publish_namespace_cancel); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessTrackStatus(quic::QuicDataReader& reader) { - return ProcessSubscribe(reader, MoqtMessageType::kTrackStatus); +absl::Status MoqtControlParser::ProcessTrackStatus(absl::string_view data) { + return ProcessSubscribe(data, MoqtMessageType::kTrackStatus); } -size_t MoqtControlParser::ProcessGoAway(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessGoAway(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtGoAway goaway; if (!reader.ReadStringVarInt62(goaway.new_session_uri)) { - return 0; + return absl::InvalidArgumentError("Missing new session URI"); } visitor_.OnGoAwayMessage(goaway); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessSubscribeNamespace( - quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessSubscribeNamespace( + absl::string_view data) { + quic::QuicDataReader reader(data); MoqtSubscribeNamespace subscribe_namespace; uint64_t raw_option; - if (!reader.ReadVarInt62(&subscribe_namespace.request_id) || - !ReadTrackNamespace(reader, subscribe_namespace.track_namespace_prefix) || - !reader.ReadVarInt62(&raw_option)) { - return 0; + if (!reader.ReadVarInt62(&subscribe_namespace.request_id)) { + return absl::InvalidArgumentError("Request ID missing"); + } + QUICHE_RETURN_IF_ERROR( + ReadTrackNamespace(reader, subscribe_namespace.track_namespace_prefix)); + if (!reader.ReadVarInt62(&raw_option)) { + return absl::InvalidArgumentError("SUBSCRIBE_NAMESPACE option missing"); } if (raw_option > kMaxSubscribeOption) { - ParseError("Invalid SUBSCRIBE_NAMESPACE option"); - return 0; + return absl::InvalidArgumentError("Invalid SUBSCRIBE_NAMESPACE option"); } subscribe_namespace.subscribe_options = static_cast<SubscribeNamespaceOption>(raw_option); - if (!FillAndValidateMessageParameters(reader, - subscribe_namespace.parameters)) { - return 0; - } + QUICHE_RETURN_IF_ERROR( + FillAndValidateMessageParameters(reader, subscribe_namespace.parameters)); visitor_.OnSubscribeNamespaceMessage(subscribe_namespace); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessMaxRequestId(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessMaxRequestId(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtMaxRequestId max_request_id; if (!reader.ReadVarInt62(&max_request_id.max_request_id)) { - return 0; + return absl::InvalidArgumentError("Max request ID missing"); } visitor_.OnMaxRequestIdMessage(max_request_id); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessFetch(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessFetch(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtFetch fetch; uint64_t type; if (!reader.ReadVarInt62(&fetch.request_id) || !reader.ReadVarInt62(&type)) { - return 0; + return absl::InvalidArgumentError("Message missing fields"); } switch (static_cast<FetchType>(type)) { case FetchType::kAbsoluteJoining: { @@ -855,7 +917,8 @@ uint64_t joining_start; if (!reader.ReadVarInt62(&joining_request_id) || !reader.ReadVarInt62(&joining_start)) { - return 0; + return absl::InvalidArgumentError( + "Absolute joining parameters invalid"); } fetch.fetch = JoiningFetchAbsolute{joining_request_id, joining_start}; break; @@ -865,7 +928,8 @@ uint64_t joining_start; if (!reader.ReadVarInt62(&joining_request_id) || !reader.ReadVarInt62(&joining_start)) { - return 0; + return absl::InvalidArgumentError( + "Relative joining parameters invalid"); } fetch.fetch = JoiningFetchRelative{joining_request_id, joining_start}; break; @@ -874,12 +938,14 @@ fetch.fetch = StandaloneFetch(); StandaloneFetch& standalone_fetch = std::get<StandaloneFetch>(fetch.fetch); - if (!ReadFullTrackName(reader, standalone_fetch.full_track_name) || - !reader.ReadVarInt62(&standalone_fetch.start_location.group) || + QUICHE_RETURN_IF_ERROR( + ReadFullTrackName(reader, standalone_fetch.full_track_name)); + if (!reader.ReadVarInt62(&standalone_fetch.start_location.group) || !reader.ReadVarInt62(&standalone_fetch.start_location.object) || !reader.ReadVarInt62(&standalone_fetch.end_location.group) || !reader.ReadVarInt62(&standalone_fetch.end_location.object)) { - return 0; + return absl::InvalidArgumentError( + "Standalone fetch parameters invalid"); } if (standalone_fetch.end_location.object == 0) { standalone_fetch.end_location.object = kMaxObjectId; @@ -887,34 +953,32 @@ --standalone_fetch.end_location.object; } if (standalone_fetch.end_location < standalone_fetch.start_location) { - ParseError("End object comes before start object in FETCH"); - return 0; + return absl::InvalidArgumentError( + "End object comes before start object in FETCH"); } break; } default: - ParseError("Invalid FETCH type"); - return 0; + return absl::InvalidArgumentError("Invalid FETCH type"); } - if (!FillAndValidateMessageParameters(reader, fetch.parameters)) { - return 0; - } + QUICHE_RETURN_IF_ERROR( + FillAndValidateMessageParameters(reader, fetch.parameters)); visitor_.OnFetchMessage(fetch); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessFetchOk(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessFetchOk(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtFetchOk fetch_ok; uint8_t end_of_track; if (!reader.ReadVarInt62(&fetch_ok.request_id) || !reader.ReadUInt8(&end_of_track) || !reader.ReadVarInt62(&fetch_ok.end_location.group) || !reader.ReadVarInt62(&fetch_ok.end_location.object)) { - return 0; + return absl::InvalidArgumentError("Message missing fields"); } if (end_of_track > 0x01) { - ParseError("Invalid end of track value in FETCH_OK"); - return 0; + return absl::InvalidArgumentError("Invalid end of track value in FETCH_OK"); } if (fetch_ok.end_location.object == 0) { fetch_ok.end_location.object = kMaxObjectId; @@ -922,91 +986,97 @@ --fetch_ok.end_location.object; } fetch_ok.end_of_track = end_of_track == 1; - if (!FillAndValidateMessageParameters(reader, fetch_ok.parameters)) { - return 0; - } - if (!ParseKeyValuePairListWithNoPrefix(reader, fetch_ok.extensions)) { - return 0; - } + QUICHE_RETURN_IF_ERROR( + FillAndValidateMessageParameters(reader, fetch_ok.parameters)); + QUICHE_RETURN_IF_ERROR( + ParseKeyValuePairListWithNoPrefix(reader, fetch_ok.extensions)); if (!fetch_ok.extensions.Validate()) { - ParseError("Invalid FETCH_OK track extensions"); - return 0; + return absl::InvalidArgumentError("Invalid FETCH_OK track extensions"); } visitor_.OnFetchOkMessage(fetch_ok); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessFetchCancel(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessFetchCancel(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtFetchCancel fetch_cancel; if (!reader.ReadVarInt62(&fetch_cancel.request_id)) { - return 0; + return absl::InvalidArgumentError("Request ID missing"); } visitor_.OnFetchCancelMessage(fetch_cancel); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessRequestsBlocked(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessRequestsBlocked(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtRequestsBlocked requests_blocked; if (!reader.ReadVarInt62(&requests_blocked.max_request_id)) { - return 0; + return absl::InvalidArgumentError("Max request ID missing"); } visitor_.OnRequestsBlockedMessage(requests_blocked); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessPublish(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessPublish(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtPublish publish; QUICHE_DCHECK(reader.PreviouslyReadPayload().empty()); - if (!reader.ReadVarInt62(&publish.request_id) || - !ReadFullTrackName(reader, publish.full_track_name) || - !reader.ReadVarInt62(&publish.track_alias)) { - return 0; + if (!reader.ReadVarInt62(&publish.request_id)) { + return absl::InvalidArgumentError("Request ID missing"); } - if (!FillAndValidateMessageParameters(reader, publish.parameters)) { - return 0; + QUICHE_RETURN_IF_ERROR(ReadFullTrackName(reader, publish.full_track_name)); + if (!reader.ReadVarInt62(&publish.track_alias)) { + return absl::InvalidArgumentError("Track alias missing"); } - if (!ParseKeyValuePairListWithNoPrefix(reader, publish.extensions)) { - return 0; - } + QUICHE_RETURN_IF_ERROR( + FillAndValidateMessageParameters(reader, publish.parameters)); + QUICHE_RETURN_IF_ERROR( + ParseKeyValuePairListWithNoPrefix(reader, publish.extensions)); if (!publish.extensions.Validate()) { - ParseError("Invalid PUBLISH track extensions"); - return 0; + return absl::InvalidArgumentError("Invalid PUBLISH track extensions"); } visitor_.OnPublishMessage(publish); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessPublishOk(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessPublishOk(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtPublishOk publish_ok; if (!reader.ReadVarInt62(&publish_ok.request_id)) { - return 0; + return absl::InvalidArgumentError("Message missing fields"); } - if (!FillAndValidateMessageParameters(reader, publish_ok.parameters)) { - return 0; - } + QUICHE_RETURN_IF_ERROR( + FillAndValidateMessageParameters(reader, publish_ok.parameters)); visitor_.OnPublishOkMessage(publish_ok); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } -size_t MoqtControlParser::ProcessObjectAck(quic::QuicDataReader& reader) { +absl::Status MoqtControlParser::ProcessObjectAck(absl::string_view data) { + quic::QuicDataReader reader(data); MoqtObjectAck object_ack; uint64_t raw_delta; if (!reader.ReadVarInt62(&object_ack.subscribe_id) || !reader.ReadVarInt62(&object_ack.group_id) || !reader.ReadVarInt62(&object_ack.object_id) || !reader.ReadVarInt62(&raw_delta)) { - return 0; + return absl::InvalidArgumentError("Message missing fields"); } object_ack.delta_from_deadline = quic::QuicTimeDelta::FromMicroseconds( SignedVarintUnserializedForm(raw_delta)); visitor_.OnObjectAckMessage(object_ack); - return reader.PreviouslyReadPayload().length(); + return CheckForTrailingData(reader); } void MoqtControlParser::ParseError(absl::string_view reason) { ParseError(MoqtError::kProtocolViolation, reason); } +void MoqtControlParser::ParseError(const absl::Status& status) { + ParseError( + GetMoqtErrorForStatus(status).value_or(MoqtError::kProtocolViolation), + status.message()); +} + void MoqtControlParser::ParseError(MoqtError error_code, absl::string_view reason) { if (parsing_error_) { @@ -1017,89 +1087,65 @@ visitor_.OnParsingError(error_code, reason); } -bool MoqtControlParser::ReadTrackNamespace(quic::QuicDataReader& reader, - TrackNamespace& track_namespace) { +absl::Status MoqtControlParser::ReadTrackNamespace( + quic::QuicDataReader& reader, TrackNamespace& track_namespace) { QUICHE_DCHECK(track_namespace.empty()); uint64_t num_elements; if (!reader.ReadVarInt62(&num_elements)) { - return false; + return absl::InvalidArgumentError( + "Unable to parse the number of namespace elements"); } if (num_elements == 0 || num_elements > kMaxNamespaceElements) { - ParseError(MoqtError::kProtocolViolation, - "Invalid number of namespace elements"); - return false; + return absl::InvalidArgumentError("Invalid number of namespace elements"); } absl::FixedArray<absl::string_view> elements(num_elements); for (uint64_t i = 0; i < num_elements; ++i) { if (!reader.ReadStringPieceVarInt62(&elements[i])) { - return false; + return absl::InvalidArgumentError( + "Namespace element shorter than specified"); } } if (!track_namespace.Append(elements)) { - ParseError(MoqtError::kProtocolViolation, "Track namespace is too large"); - return false; + return absl::InvalidArgumentError("Track namespace is too large"); } - return true; + return absl::OkStatus(); } -bool MoqtControlParser::ReadFullTrackName(quic::QuicDataReader& reader, - FullTrackName& full_track_name) { +absl::Status MoqtControlParser::ReadFullTrackName( + quic::QuicDataReader& reader, FullTrackName& full_track_name) { QUICHE_DCHECK(!full_track_name.IsValid()); TrackNamespace track_namespace; - if (!ReadTrackNamespace(reader, track_namespace)) { - return false; - } + QUICHE_RETURN_IF_ERROR(ReadTrackNamespace(reader, track_namespace)); absl::string_view name; if (!reader.ReadStringPieceVarInt62(&name)) { - return false; + return absl::InvalidArgumentError("Unable to parse track name"); } absl::StatusOr<FullTrackName> full_track_name_or = FullTrackName::Create(std::move(track_namespace), std::string(name)); - if (!full_track_name_or.ok()) { - ParseError(MoqtError::kProtocolViolation, - full_track_name_or.status().message()); - return false; - } + QUICHE_RETURN_IF_ERROR(full_track_name_or.status()); full_track_name = *std::move(full_track_name_or); - return true; + return absl::OkStatus(); } -bool MoqtControlParser::FillAndValidateSetupParameters( +absl::Status MoqtControlParser::FillAndValidateSetupParameters( const KeyValuePairList& in, SetupParameters& out, MoqtMessageType message_type) { - MoqtError error = out.FromKeyValuePairList(in); - if (error != MoqtError::kNoError) { - absl::string_view error_message = (error == MoqtError::kProtocolViolation) - ? "Duplicate Setup Parameter" - : "Setup Parameter parsing error"; - ParseError(error, error_message); - return false; - } - error = + QUICHE_RETURN_IF_ERROR(out.FromKeyValuePairList(in)); + MoqtError error = SetupParametersAllowedByMessage(out, message_type, uses_web_transport_); if (error != MoqtError::kNoError) { - ParseError(error, ""); - return false; + return MoqtErrorStatusWithCode("Setup parameter parsing error", error); } - return true; + return absl::OkStatus(); } -bool MoqtControlParser::FillAndValidateMessageParameters( +absl::Status MoqtControlParser::FillAndValidateMessageParameters( quic::QuicDataReader& reader, MessageParameters& out) { KeyValuePairList pairs; - if (!ParseKeyValuePairList(reader, pairs)) { - return false; - } - MoqtError error = out.FromKeyValuePairList(pairs); - if (error != MoqtError::kNoError) { - absl::string_view error_message = (error == MoqtError::kProtocolViolation) - ? "Duplicate Message Parameter" - : "Message Parameter parsing error"; - ParseError(error, error_message); - return false; - } + QUICHE_RETURN_IF_ERROR(ParseKeyValuePairList(reader, pairs)); // All parameter types are allowed in all messages. - return true; + QUICHE_RETURN_IF_ERROR(out.FromKeyValuePairList(pairs)); + return absl::OkStatus(); } void MoqtDataParser::ParseError(absl::string_view reason) {
diff --git a/quiche/quic/moqt/moqt_parser.h b/quiche/quic/moqt/moqt_parser.h index 49248a7..e04dbcb 100644 --- a/quiche/quic/moqt/moqt_parser.h +++ b/quiche/quic/moqt/moqt_parser.h
@@ -13,6 +13,7 @@ #include <optional> #include <string> +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "quiche/quic/core/quic_data_reader.h" #include "quiche/quic/moqt/moqt_error.h" @@ -115,66 +116,64 @@ private: // The central switch statement to dispatch a message to the correct - // Process* function. Returns 0 if it could not parse the full messsage - // (except for object payload). Otherwise, returns the number of bytes - // processed. - size_t ProcessMessage(absl::string_view data, MoqtMessageType message_type); + // Process* function. Invokles an error callback if parsing fails. + void ProcessMessage(absl::string_view data, MoqtMessageType message_type); // The Process* functions parse the serialized data into the appropriate // structs, and call the relevant visitor function for further action. Returns // the number of bytes consumed if the message is complete; returns 0 // otherwise. - size_t ProcessClientSetup(quic::QuicDataReader& reader); - size_t ProcessServerSetup(quic::QuicDataReader& reader); - size_t ProcessRequestOk(quic::QuicDataReader& reader); - size_t ProcessRequestError(quic::QuicDataReader& reader); + absl::Status ProcessClientSetup(absl::string_view data); + absl::Status ProcessServerSetup(absl::string_view data); + absl::Status ProcessRequestOk(absl::string_view data); + absl::Status ProcessRequestError(absl::string_view data); // Subscribe formats are used for TrackStatus as well, so take the message // type as an argument, defaulting to the subscribe version. - size_t ProcessSubscribe( - quic::QuicDataReader& reader, + absl::Status ProcessSubscribe( + absl::string_view data, MoqtMessageType message_type = MoqtMessageType::kSubscribe); - size_t ProcessSubscribeOk(quic::QuicDataReader& reader); - size_t ProcessUnsubscribe(quic::QuicDataReader& reader); - size_t ProcessPublishDone(quic::QuicDataReader& reader); - size_t ProcessRequestUpdate(quic::QuicDataReader& reader); - size_t ProcessPublishNamespace(quic::QuicDataReader& reader); - size_t ProcessPublishNamespaceDone(quic::QuicDataReader& reader); - size_t ProcessNamespace(quic::QuicDataReader& reader); - size_t ProcessNamespaceDone(quic::QuicDataReader& reader); - size_t ProcessPublishNamespaceCancel(quic::QuicDataReader& reader); - size_t ProcessTrackStatus(quic::QuicDataReader& reader); - size_t ProcessGoAway(quic::QuicDataReader& reader); - size_t ProcessSubscribeNamespace(quic::QuicDataReader& reader); - size_t ProcessUnsubscribeNamespace(quic::QuicDataReader& reader); - size_t ProcessMaxRequestId(quic::QuicDataReader& reader); - size_t ProcessFetch(quic::QuicDataReader& reader); - size_t ProcessFetchCancel(quic::QuicDataReader& reader); - size_t ProcessFetchOk(quic::QuicDataReader& reader); - size_t ProcessRequestsBlocked(quic::QuicDataReader& reader); - size_t ProcessPublish(quic::QuicDataReader& reader); - size_t ProcessPublishOk(quic::QuicDataReader& reader); - size_t ProcessObjectAck(quic::QuicDataReader& reader); + absl::Status ProcessSubscribeOk(absl::string_view data); + absl::Status ProcessUnsubscribe(absl::string_view data); + absl::Status ProcessPublishDone(absl::string_view data); + absl::Status ProcessRequestUpdate(absl::string_view data); + absl::Status ProcessPublishNamespace(absl::string_view data); + absl::Status ProcessPublishNamespaceDone(absl::string_view data); + absl::Status ProcessNamespace(absl::string_view data); + absl::Status ProcessNamespaceDone(absl::string_view data); + absl::Status ProcessPublishNamespaceCancel(absl::string_view data); + absl::Status ProcessTrackStatus(absl::string_view data); + absl::Status ProcessGoAway(absl::string_view data); + absl::Status ProcessSubscribeNamespace(absl::string_view data); + absl::Status ProcessMaxRequestId(absl::string_view data); + absl::Status ProcessFetch(absl::string_view data); + absl::Status ProcessFetchCancel(absl::string_view data); + absl::Status ProcessFetchOk(absl::string_view data); + absl::Status ProcessRequestsBlocked(absl::string_view data); + absl::Status ProcessPublish(absl::string_view data); + absl::Status ProcessPublishOk(absl::string_view data); + absl::Status ProcessObjectAck(absl::string_view data); // If |error| is not provided, assumes kProtocolViolation. void ParseError(absl::string_view reason); + void ParseError(const absl::Status& status); void ParseError(MoqtError error, absl::string_view reason); // Reads a TrackNamespace from the reader. Returns false if the namespace is // too large. Sets a ParseError if the namespace is malformed. - bool ReadTrackNamespace(quic::QuicDataReader& reader, - TrackNamespace& track_namespace); + absl::Status ReadTrackNamespace(quic::QuicDataReader& reader, + TrackNamespace& track_namespace); // Reads a FullTrackName from the reader. Returns false if the name is too // large. Sets a ParseError if the name is malformed. - bool ReadFullTrackName(quic::QuicDataReader& reader, - FullTrackName& full_track_name); - bool FillAndValidateSetupParameters(const KeyValuePairList& in, - SetupParameters& out, - MoqtMessageType message_type); + absl::Status ReadFullTrackName(quic::QuicDataReader& reader, + FullTrackName& full_track_name); + absl::Status FillAndValidateSetupParameters(const KeyValuePairList& in, + SetupParameters& out, + MoqtMessageType message_type); // |reader| points to the beginning of a KeyValuePairList. Returns false if // there is any sort of error. (The function calls ParseError(), so the // caller has no need to do so.) - bool FillAndValidateMessageParameters(quic::QuicDataReader& reader, - MessageParameters& out); + absl::Status FillAndValidateMessageParameters(quic::QuicDataReader& reader, + MessageParameters& out); MoqtControlParserVisitor& visitor_; webtransport::Stream& stream_;
diff --git a/quiche/quic/moqt/moqt_parser_test.cc b/quiche/quic/moqt/moqt_parser_test.cc index 659006c..5a903ec 100644 --- a/quiche/quic/moqt/moqt_parser_test.cc +++ b/quiche/quic/moqt/moqt_parser_test.cc
@@ -33,6 +33,8 @@ namespace { using ::testing::AnyOf; +using ::testing::HasSubstr; +using ::testing::Optional; constexpr std::array kMessageTypes{ MoqtMessageType::kRequestOk, @@ -323,8 +325,7 @@ ProcessData(message->PacketSample(), false); // The parser will actually report a message, because it's all there. EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_EQ(visitor_.parsing_error_, - "Message length does not match payload length"); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); } TEST_P(MoqtParserTest, PayloadLengthTooShort) { @@ -335,8 +336,7 @@ message->DecreasePayloadLengthByOne(); ProcessData(message->PacketSample(), false); EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_, - "Message length does not match payload length"); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); } // Tests for message-specific error cases, and behaviors for a single message @@ -508,7 +508,8 @@ stream.Receive(absl::string_view(setup, sizeof(setup)), false); parser.ReadAndDispatchMessages(); EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_, "Duplicate Setup Parameter"); + EXPECT_THAT(visitor_.parsing_error_, + Optional(HasSubstr("Duplicate Setup Parameter"))); EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); } @@ -567,7 +568,8 @@ stream.Receive(absl::string_view(setup, sizeof(setup)), false); parser.ReadAndDispatchMessages(); EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_, "Duplicate Setup Parameter"); + EXPECT_THAT(visitor_.parsing_error_, + Optional(HasSubstr("Duplicate Setup Parameter"))); EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); } @@ -625,7 +627,8 @@ stream.Receive(absl::string_view(setup, sizeof(setup)), false); parser.ReadAndDispatchMessages(); EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_, "Duplicate Setup Parameter"); + EXPECT_THAT(visitor_.parsing_error_, + Optional(HasSubstr("Duplicate Setup Parameter"))); EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); } @@ -689,25 +692,8 @@ stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); parser.ReadAndDispatchMessages(); EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_, "Duplicate Message Parameter"); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); -} - -TEST_F(MoqtMessageSpecificTest, SubscribeMaxCacheDurationTwice) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); - char subscribe[] = { - 0x03, 0x00, 0x12, 0x01, 0x01, - 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" - 0x04, 0x61, 0x62, 0x63, 0x64, // track_name = "abcd" - 0x02, // two params - 0x04, 0x67, 0x10, // max_cache_duration = 10000 - 0x00, 0x67, 0x10 // max_cache_duration = 10000 - }; - stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_, "Duplicate Message Parameter"); + EXPECT_THAT(visitor_.parsing_error_, + Optional(HasSubstr("Duplicate Message Parameter"))); EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); } @@ -948,7 +934,7 @@ stream.Receive(absl::string_view(message, writer.length()), false); parser.ReadAndDispatchMessages(); EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_, "Unknown message type"); + EXPECT_EQ(visitor_.parsing_error_, "Unknown control message type 0xbeef"); } TEST_F(MoqtMessageSpecificTest, SubscribeNoParameters) { @@ -1122,7 +1108,10 @@ false); parser.ReadAndDispatchMessages(); EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_, "Duplicate Message Parameter"); + EXPECT_THAT( + visitor_.parsing_error_, + Optional(HasSubstr( + "AbsoluteRange filter specified with a start after the end"))); } TEST_F(MoqtMessageSpecificTest, ObjectAckNegativeDelta) { @@ -1290,7 +1279,8 @@ stream.Receive(subscribe_ok.PacketSample(), false); parser.ReadAndDispatchMessages(); EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_, "Invalid SUBSCRIBE_OK track extensions"); + EXPECT_THAT(visitor_.parsing_error_, + Optional(HasSubstr("Invalid SUBSCRIBE_OK track extensions"))); } TEST_F(MoqtMessageSpecificTest, SubscribeOkExpirationIsZero) { @@ -1334,8 +1324,9 @@ stream.Receive(fetch.PacketSample(), false); parser.ReadAndDispatchMessages(); EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_, - "End object comes before start object in FETCH"); + EXPECT_THAT( + visitor_.parsing_error_, + Optional(HasSubstr("End object comes before start object in FETCH"))); } TEST_F(MoqtMessageSpecificTest, FetchInvalidRange2) { @@ -1346,8 +1337,9 @@ stream.Receive(fetch.PacketSample(), false); parser.ReadAndDispatchMessages(); EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_, - "End object comes before start object in FETCH"); + EXPECT_THAT( + visitor_.parsing_error_, + Optional(HasSubstr("End object comes before start object in FETCH"))); } TEST_F(MoqtMessageSpecificTest, PaddingStream) { @@ -1386,7 +1378,8 @@ false); parser.ReadAndDispatchMessages(); EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_EQ(visitor_.parsing_error_, "Invalid number of namespace elements"); + EXPECT_THAT(visitor_.parsing_error_, + Optional(HasSubstr("Invalid number of namespace elements"))); } TEST_F(MoqtMessageSpecificTest, NamespaceTooLarge) { @@ -1409,7 +1402,8 @@ absl::string_view(publish_namespace, sizeof(publish_namespace)), false); parser.ReadAndDispatchMessages(); EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_EQ(visitor_.parsing_error_, "Invalid number of namespace elements"); + EXPECT_THAT(visitor_.parsing_error_, + Optional(HasSubstr("Invalid number of namespace elements"))); } TEST_F(MoqtMessageSpecificTest, RelativeJoiningFetch) {