Rewrite MOQT control message parser. Instead of the old parser there are now two classes: * MoqtControlStreamParser unframes control messages. It will be transferred from Unknown stream to a specific bidi stream subtype. * MoqtControlMessageParser is ~stateless parser that actually parses the payload of unframed messages. There's a MoqtControlParser polyfill that will be removed in the follow-up CL, as fixing the call sites would make this CL 2x larger. This also sharpens error handling around various FIN-related edge cases (for instance, an empty bidi stream or a bidi stream with an incomplete type is now a fatal error). PiperOrigin-RevId: 910348826
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc index 38e1743..19d864e 100644 --- a/quiche/quic/moqt/moqt_parser.cc +++ b/quiche/quic/moqt/moqt_parser.cc
@@ -35,6 +35,7 @@ #include "quiche/common/platform/api/quiche_bug_tracker.h" #include "quiche/common/platform/api/quiche_logging.h" #include "quiche/common/quiche_data_reader.h" +#include "quiche/common/quiche_endian.h" #include "quiche/common/quiche_status_utils.h" #include "quiche/web_transport/web_transport.h" @@ -481,189 +482,134 @@ return status; } -bool MoqtMessageTypeParser::ReadUntilMessageTypeKnown() { - if (message_type_.has_value()) { - return true; +absl::StatusOr<MoqtRawControlMessage> +MoqtControlStreamParser::ReadNextMessage() { + if (error_encountered_ || fin_read_) { + return absl::FailedPreconditionError( + "Trying to read from a control stream after an error or an EOF " + "occurred."); } - bool fin_read = false; - message_type_ = ReadVarInt62FromStream(stream_, fin_read); - if (fin_read) { - return false; + absl::StatusOr<MoqtRawControlMessage> result = ReadNextMessageInner(); + if (!result.ok() && !absl::IsUnavailable(result.status())) { + error_encountered_ = true; + } else { + if (fin_read_ && !allow_fin_) { + result = absl::InvalidArgumentError( + "Unexpected FIN on a control stream (no FINs are allowed on this " + "stream)"); + error_encountered_ = true; + } } - return true; + return result; } -void MoqtControlParser::ReadAndDispatchMessages() { - if (no_more_data_) { - ParseError("Data after end of stream"); - return; +absl::StatusOr<MoqtMessageType> +MoqtControlStreamParser::ReadFirstMessageType() { + if (first_message_type_.has_value()) { + return static_cast<MoqtMessageType>(*first_message_type_); } - if (processing_) { - return; + if (error_encountered_ || fin_read_) { + return absl::FailedPreconditionError( + "Trying to read from a control stream after an error or an EOF " + "occurred."); } - processing_ = true; - auto on_return = absl::MakeCleanup([&] { processing_ = false; }); - while (!no_more_data_) { - bool fin_read = false; - // Read the message type. - if (!message_type_.has_value()) { - message_type_ = ReadVarInt62FromStream(stream_, fin_read); - if (fin_read) { - ParseError("FIN on control stream"); - return; - } - if (!message_type_.has_value()) { - return; - } - } - QUICHE_DCHECK(message_type_.has_value()); - - // Read the message length. - if (!message_size_.has_value()) { - if (stream_.ReadableBytes() < 2) { - return; - } - std::array<char, 2> size_bytes; - webtransport::Stream::ReadResult result = - stream_.Read(absl::MakeSpan(size_bytes)); - if (result.bytes_read != 2) { - ParseError(MoqtError::kInternalError, - "Stream returned incorrect ReadableBytes"); - return; - } - if (result.fin) { - ParseError("FIN on control stream"); - return; - } - message_size_ = static_cast<uint16_t>(size_bytes[0]) << 8 | - static_cast<uint16_t>(size_bytes[1]); - if (*message_size_ > kMaxMessageHeaderSize) { - ParseError(MoqtError::kInternalError, - absl::StrCat("Cannot parse control messages more than ", - kMaxMessageHeaderSize, " bytes")); - return; - } - } - QUICHE_DCHECK(message_size_.has_value()); - - // Read the message if it's fully received. - // - // CAUTION: if the flow control windows are too low, and - // kMaxMessageHeaderSize is too high, this will cause a deadlock. - if (stream_.ReadableBytes() < *message_size_) { - return; - } - absl::FixedArray<char> message(*message_size_); - webtransport::Stream::ReadResult result = - stream_.Read(absl::MakeSpan(message)); - if (result.bytes_read != *message_size_) { - ParseError("Stream returned incorrect ReadableBytes"); - return; - } - if (result.fin) { - ParseError("FIN on control stream"); - return; - } - ProcessMessage(absl::string_view(message.data(), message.size()), - static_cast<MoqtMessageType>(*message_type_)); - message_type_.reset(); - message_size_.reset(); + absl::Status read_status = ReadMessageType(); + if (absl::IsUnavailable(read_status) && fin_read_) { + return absl::InvalidArgumentError("FIN received before any type"); } + QUICHE_RETURN_IF_ERROR(read_status); + return static_cast<MoqtMessageType>(*first_message_type_); } -void MoqtControlParser::ProcessMessage(absl::string_view data, - MoqtMessageType message_type) { - absl::Status status; - switch (message_type) { - case MoqtMessageType::kClientSetup: - status = ProcessClientSetup(data); - break; - case MoqtMessageType::kServerSetup: - status = ProcessServerSetup(data); - break; - case MoqtMessageType::kRequestOk: - status = ProcessRequestOk(data); - break; - case MoqtMessageType::kRequestError: - status = ProcessRequestError(data); - break; - case MoqtMessageType::kSubscribe: - status = ProcessSubscribe(data); - break; - case MoqtMessageType::kSubscribeOk: - status = ProcessSubscribeOk(data); - break; - case MoqtMessageType::kUnsubscribe: - status = ProcessUnsubscribe(data); - break; - case MoqtMessageType::kPublishDone: - status = ProcessPublishDone(data); - break; - case MoqtMessageType::kRequestUpdate: - status = ProcessRequestUpdate(data); - break; - case MoqtMessageType::kPublishNamespace: - status = ProcessPublishNamespace(data); - break; - case MoqtMessageType::kPublishNamespaceDone: - status = ProcessPublishNamespaceDone(data); - break; - case MoqtMessageType::kNamespace: - status = ProcessNamespace(data); - break; - case MoqtMessageType::kNamespaceDone: - status = ProcessNamespaceDone(data); - break; - case MoqtMessageType::kPublishNamespaceCancel: - status = ProcessPublishNamespaceCancel(data); - break; - case MoqtMessageType::kTrackStatus: - status = ProcessTrackStatus(data); - break; - case MoqtMessageType::kGoAway: - status = ProcessGoAway(data); - break; - case MoqtMessageType::kSubscribeNamespace: - status = ProcessSubscribeNamespace(data); - break; - case MoqtMessageType::kMaxRequestId: - status = ProcessMaxRequestId(data); - break; - case MoqtMessageType::kFetch: - status = ProcessFetch(data); - break; - case MoqtMessageType::kFetchCancel: - status = ProcessFetchCancel(data); - break; - case MoqtMessageType::kFetchOk: - status = ProcessFetchOk(data); - break; - case MoqtMessageType::kRequestsBlocked: - status = ProcessRequestsBlocked(data); - break; - case MoqtMessageType::kPublish: - status = ProcessPublish(data); - break; - case MoqtMessageType::kPublishOk: - status = ProcessPublishOk(data); - break; - case moqt::MoqtMessageType::kObjectAck: - status = ProcessObjectAck(data); - break; - default: - ParseError(absl::InvalidArgumentError( - absl::StrCat("Unknown control message type 0x", - absl::Hex(static_cast<uint64_t>(message_type))))); - return; +absl::Status MoqtControlStreamParser::ReadMessageType() { + if (current_message_type_.has_value()) { + QUICHE_BUG(MoqtControlStreamParser_ReadMessageType_bad_state) + << "ReadMessageType() called in an invalid state"; + return absl::InternalError("ReadMessageType() called in an invalid state"); } - if (!status.ok()) { - ParseError( - quiche::AppendToStatus(status, " while parsing a message of type 0x", - absl::Hex(static_cast<uint64_t>(message_type)))); + current_message_type_ = ReadVarInt62FromStream(stream_, fin_read_); + if (!current_message_type_.has_value()) { + webtransport::Stream::PeekResult peek_result = + stream_.PeekNextReadableRegion(); + if (peek_result.all_data_received && !peek_result.peeked_data.empty()) { + return absl::InvalidArgumentError( + "Unexpected FIN on a control stream (FIN received in the middle of " + "type)"); + } + return absl::UnavailableError("No complete message available"); } + if (fin_read_) { + return absl::InvalidArgumentError( + "Unexpected FIN on a control stream (FIN received immediately after " + "type)"); + } + if (!first_message_type_.has_value()) { + first_message_type_ = *current_message_type_; + } + return absl::OkStatus(); } -absl::Status MoqtControlParser::ProcessClientSetup(absl::string_view data) { +absl::StatusOr<MoqtRawControlMessage> +MoqtControlStreamParser::ReadNextMessageInner() { + if (!current_message_type_.has_value()) { + QUICHE_RETURN_IF_ERROR(ReadMessageType()); + } + + if (!current_message_remaining_.has_value()) { + uint16_t message_size = 0; + std::array<char, sizeof(message_size)> buffer; + if (stream_.ReadableBytes() < buffer.size()) { + if (stream_.PeekNextReadableRegion().all_data_received) { + return absl::InvalidArgumentError( + "Unexpected FIN on a control stream (FIN received in the middle of " + "the message size)"); + } + return absl::UnavailableError("No complete message available"); + } + webtransport::Stream::ReadResult read_result = + stream_.Read(absl::MakeSpan(buffer)); + fin_read_ |= read_result.fin; + QUICHE_DCHECK_EQ(read_result.bytes_read, buffer.size()); + + memcpy(&message_size, buffer.data(), buffer.size()); + message_size = quiche::QuicheEndian::NetToHost16(message_size); + if (message_size > kMaxMessageHeaderSize) { + return absl::InvalidArgumentError( + absl::StrCat("A control message exceeds the maximum allowed size of ", + kMaxMessageHeaderSize, " bytes")); + } + current_message_.resize(message_size); + current_message_remaining_ = absl::MakeSpan(current_message_); + } + + QUICHE_DCHECK(current_message_remaining_.has_value()); + if (!current_message_remaining_->empty()) { + webtransport::Stream::ReadResult read_result = + stream_.Read(*current_message_remaining_); + current_message_remaining_->remove_prefix(read_result.bytes_read); + fin_read_ |= read_result.fin; + } + if (!current_message_remaining_->empty()) { + if (fin_read_) { + return absl::InvalidArgumentError(absl::StrCat( + "FIN encountered when there are ", current_message_remaining_->size(), + " bytes left in the current message")); + } + return absl::UnavailableError("No complete message available"); + } + MoqtRawControlMessage message{ + .type = static_cast<MoqtMessageType>(*current_message_type_), + .payload = std::move(current_message_)}; + current_message_type_.reset(); + current_message_remaining_.reset(); + // Technically, std::move() leaves `current_message_` in a + // "valid but undefined state"; clear it out explicitly. + current_message_.clear(); + return message; +} + +absl::StatusOr<MoqtClientSetup> MoqtControlMessageParser::ProcessClientSetup( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtClientSetup setup; KeyValuePairList parameters; @@ -671,23 +617,24 @@ QUICHE_RETURN_IF_ERROR(FillAndValidateSetupParameters( parameters, setup.parameters, MoqtMessageType::kClientSetup)); // TODO(martinduke): Validate construction of the PATH (Sec 8.3.2.1) - visitor_.OnClientSetupMessage(setup); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return setup; } -absl::Status MoqtControlParser::ProcessServerSetup(absl::string_view data) { +absl::StatusOr<MoqtServerSetup> MoqtControlMessageParser::ProcessServerSetup( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtServerSetup setup; KeyValuePairList parameters; QUICHE_RETURN_IF_ERROR(ParseKeyValuePairList(reader, parameters)); QUICHE_RETURN_IF_ERROR(FillAndValidateSetupParameters( parameters, setup.parameters, MoqtMessageType::kServerSetup)); - visitor_.OnServerSetupMessage(setup); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return setup; } -absl::Status MoqtControlParser::ProcessSubscribe(absl::string_view data, - MoqtMessageType message_type) { +absl::StatusOr<MoqtSubscribe> MoqtControlMessageParser::ProcessSubscribe( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtSubscribe subscribe; if (!reader.ReadVarInt62(&subscribe.request_id)) { @@ -696,15 +643,12 @@ QUICHE_RETURN_IF_ERROR(ReadFullTrackName(reader, subscribe.full_track_name)); QUICHE_RETURN_IF_ERROR( FillAndValidateMessageParameters(reader, subscribe.parameters)); - if (message_type == MoqtMessageType::kTrackStatus) { - visitor_.OnTrackStatusMessage(subscribe); - } else { - visitor_.OnSubscribeMessage(subscribe); - } - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return subscribe; } -absl::Status MoqtControlParser::ProcessSubscribeOk(absl::string_view data) { +absl::StatusOr<MoqtSubscribeOk> MoqtControlMessageParser::ProcessSubscribeOk( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtSubscribeOk subscribe_ok; if (!reader.ReadVarInt62(&subscribe_ok.request_id)) { @@ -721,11 +665,12 @@ if (!subscribe_ok.extensions.Validate()) { return absl::InvalidArgumentError("Invalid SUBSCRIBE_OK track extensions"); } - visitor_.OnSubscribeOkMessage(subscribe_ok); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return subscribe_ok; } -absl::Status MoqtControlParser::ProcessRequestError(absl::string_view data) { +absl::StatusOr<MoqtRequestError> MoqtControlMessageParser::ProcessRequestError( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtRequestError request_error; uint64_t error_code; @@ -742,21 +687,23 @@ ? std::nullopt : std::make_optional( quic::QuicTimeDelta::FromMilliseconds(raw_interval - 1)); - visitor_.OnRequestErrorMessage(request_error); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return request_error; } -absl::Status MoqtControlParser::ProcessUnsubscribe(absl::string_view data) { +absl::StatusOr<MoqtUnsubscribe> MoqtControlMessageParser::ProcessUnsubscribe( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtUnsubscribe unsubscribe; if (!reader.ReadVarInt62(&unsubscribe.request_id)) { return absl::InvalidArgumentError("Message missing fields"); } - visitor_.OnUnsubscribeMessage(unsubscribe); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return unsubscribe; } -absl::Status MoqtControlParser::ProcessPublishDone(absl::string_view data) { +absl::StatusOr<MoqtPublishDone> MoqtControlMessageParser::ProcessPublishDone( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtPublishDone publish_done; uint64_t value; @@ -767,11 +714,12 @@ return absl::InvalidArgumentError("Message missing fields"); } publish_done.status_code = static_cast<PublishDoneCode>(value); - visitor_.OnPublishDoneMessage(publish_done); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return publish_done; } -absl::Status MoqtControlParser::ProcessRequestUpdate(absl::string_view data) { +absl::StatusOr<MoqtRequestUpdate> +MoqtControlMessageParser::ProcessRequestUpdate(absl::string_view data) const { quic::QuicDataReader reader(data); MoqtRequestUpdate request_update; if (!reader.ReadVarInt62(&request_update.request_id) || @@ -780,12 +728,13 @@ } QUICHE_RETURN_IF_ERROR( FillAndValidateMessageParameters(reader, request_update.parameters)); - visitor_.OnRequestUpdateMessage(request_update); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return request_update; } -absl::Status MoqtControlParser::ProcessPublishNamespace( - absl::string_view data) { +absl::StatusOr<MoqtPublishNamespace> +MoqtControlMessageParser::ProcessPublishNamespace( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtPublishNamespace publish_namespace; if (!reader.ReadVarInt62(&publish_namespace.request_id)) { @@ -795,29 +744,32 @@ ReadTrackNamespace(reader, publish_namespace.track_namespace)); QUICHE_RETURN_IF_ERROR( FillAndValidateMessageParameters(reader, publish_namespace.parameters)); - visitor_.OnPublishNamespaceMessage(publish_namespace); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return publish_namespace; } -absl::Status MoqtControlParser::ProcessNamespace(absl::string_view data) { +absl::StatusOr<MoqtNamespace> MoqtControlMessageParser::ProcessNamespace( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtNamespace _namespace; QUICHE_RETURN_IF_ERROR( ReadTrackNamespace(reader, _namespace.track_namespace_suffix)); - visitor_.OnNamespaceMessage(_namespace); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return _namespace; } -absl::Status MoqtControlParser::ProcessNamespaceDone(absl::string_view data) { +absl::StatusOr<MoqtNamespaceDone> +MoqtControlMessageParser::ProcessNamespaceDone(absl::string_view data) const { quic::QuicDataReader reader(data); MoqtNamespaceDone namespace_done; QUICHE_RETURN_IF_ERROR( ReadTrackNamespace(reader, namespace_done.track_namespace_suffix)); - visitor_.OnNamespaceDoneMessage(namespace_done); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return namespace_done; } -absl::Status MoqtControlParser::ProcessRequestOk(absl::string_view data) { +absl::StatusOr<MoqtRequestOk> MoqtControlMessageParser::ProcessRequestOk( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtRequestOk request_ok; if (!reader.ReadVarInt62(&request_ok.request_id)) { @@ -825,23 +777,25 @@ } QUICHE_RETURN_IF_ERROR( FillAndValidateMessageParameters(reader, request_ok.parameters)); - visitor_.OnRequestOkMessage(request_ok); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return request_ok; } -absl::Status MoqtControlParser::ProcessPublishNamespaceDone( - absl::string_view data) { +absl::StatusOr<MoqtPublishNamespaceDone> +MoqtControlMessageParser::ProcessPublishNamespaceDone( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtPublishNamespaceDone pn_done; if (!reader.ReadVarInt62(&pn_done.request_id)) { return absl::InvalidArgumentError("Request ID missing"); } - visitor_.OnPublishNamespaceDoneMessage(pn_done); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return pn_done; } -absl::Status MoqtControlParser::ProcessPublishNamespaceCancel( - absl::string_view data) { +absl::StatusOr<MoqtPublishNamespaceCancel> +MoqtControlMessageParser::ProcessPublishNamespaceCancel( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtPublishNamespaceCancel publish_namespace_cancel; uint64_t error_code; @@ -852,26 +806,29 @@ } publish_namespace_cancel.error_code = static_cast<RequestErrorCode>(error_code); - visitor_.OnPublishNamespaceCancelMessage(publish_namespace_cancel); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return publish_namespace_cancel; } -absl::Status MoqtControlParser::ProcessTrackStatus(absl::string_view data) { - return ProcessSubscribe(data, MoqtMessageType::kTrackStatus); +absl::StatusOr<MoqtTrackStatus> MoqtControlMessageParser::ProcessTrackStatus( + absl::string_view data) const { + return ProcessSubscribe(data); } -absl::Status MoqtControlParser::ProcessGoAway(absl::string_view data) { +absl::StatusOr<MoqtGoAway> MoqtControlMessageParser::ProcessGoAway( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtGoAway goaway; if (!reader.ReadStringVarInt62(goaway.new_session_uri)) { return absl::InvalidArgumentError("Missing new session URI"); } - visitor_.OnGoAwayMessage(goaway); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return goaway; } -absl::Status MoqtControlParser::ProcessSubscribeNamespace( - absl::string_view data) { +absl::StatusOr<MoqtSubscribeNamespace> +MoqtControlMessageParser::ProcessSubscribeNamespace( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtSubscribeNamespace subscribe_namespace; uint64_t raw_option; @@ -890,21 +847,23 @@ static_cast<SubscribeNamespaceOption>(raw_option); QUICHE_RETURN_IF_ERROR( FillAndValidateMessageParameters(reader, subscribe_namespace.parameters)); - visitor_.OnSubscribeNamespaceMessage(subscribe_namespace); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return subscribe_namespace; } -absl::Status MoqtControlParser::ProcessMaxRequestId(absl::string_view data) { +absl::StatusOr<MoqtMaxRequestId> MoqtControlMessageParser::ProcessMaxRequestId( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtMaxRequestId max_request_id; if (!reader.ReadVarInt62(&max_request_id.max_request_id)) { return absl::InvalidArgumentError("Max request ID missing"); } - visitor_.OnMaxRequestIdMessage(max_request_id); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return max_request_id; } -absl::Status MoqtControlParser::ProcessFetch(absl::string_view data) { +absl::StatusOr<MoqtFetch> MoqtControlMessageParser::ProcessFetch( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtFetch fetch; uint64_t type; @@ -963,11 +922,12 @@ } QUICHE_RETURN_IF_ERROR( FillAndValidateMessageParameters(reader, fetch.parameters)); - visitor_.OnFetchMessage(fetch); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return fetch; } -absl::Status MoqtControlParser::ProcessFetchOk(absl::string_view data) { +absl::StatusOr<MoqtFetchOk> MoqtControlMessageParser::ProcessFetchOk( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtFetchOk fetch_ok; uint8_t end_of_track; @@ -993,31 +953,34 @@ if (!fetch_ok.extensions.Validate()) { return absl::InvalidArgumentError("Invalid FETCH_OK track extensions"); } - visitor_.OnFetchOkMessage(fetch_ok); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return fetch_ok; } -absl::Status MoqtControlParser::ProcessFetchCancel(absl::string_view data) { +absl::StatusOr<MoqtFetchCancel> MoqtControlMessageParser::ProcessFetchCancel( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtFetchCancel fetch_cancel; if (!reader.ReadVarInt62(&fetch_cancel.request_id)) { return absl::InvalidArgumentError("Request ID missing"); } - visitor_.OnFetchCancelMessage(fetch_cancel); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return fetch_cancel; } -absl::Status MoqtControlParser::ProcessRequestsBlocked(absl::string_view data) { +absl::StatusOr<MoqtRequestsBlocked> +MoqtControlMessageParser::ProcessRequestsBlocked(absl::string_view data) const { quic::QuicDataReader reader(data); MoqtRequestsBlocked requests_blocked; if (!reader.ReadVarInt62(&requests_blocked.max_request_id)) { return absl::InvalidArgumentError("Max request ID missing"); } - visitor_.OnRequestsBlockedMessage(requests_blocked); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return requests_blocked; } -absl::Status MoqtControlParser::ProcessPublish(absl::string_view data) { +absl::StatusOr<MoqtPublish> MoqtControlMessageParser::ProcessPublish( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtPublish publish; QUICHE_DCHECK(reader.PreviouslyReadPayload().empty()); @@ -1035,11 +998,12 @@ if (!publish.extensions.Validate()) { return absl::InvalidArgumentError("Invalid PUBLISH track extensions"); } - visitor_.OnPublishMessage(publish); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return publish; } -absl::Status MoqtControlParser::ProcessPublishOk(absl::string_view data) { +absl::StatusOr<MoqtPublishOk> MoqtControlMessageParser::ProcessPublishOk( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtPublishOk publish_ok; if (!reader.ReadVarInt62(&publish_ok.request_id)) { @@ -1047,11 +1011,12 @@ } QUICHE_RETURN_IF_ERROR( FillAndValidateMessageParameters(reader, publish_ok.parameters)); - visitor_.OnPublishOkMessage(publish_ok); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return publish_ok; } -absl::Status MoqtControlParser::ProcessObjectAck(absl::string_view data) { +absl::StatusOr<MoqtObjectAck> MoqtControlMessageParser::ProcessObjectAck( + absl::string_view data) const { quic::QuicDataReader reader(data); MoqtObjectAck object_ack; uint64_t raw_delta; @@ -1063,32 +1028,12 @@ } object_ack.delta_from_deadline = quic::QuicTimeDelta::FromMicroseconds( SignedVarintUnserializedForm(raw_delta)); - visitor_.OnObjectAckMessage(object_ack); - return CheckForTrailingData(reader); + QUICHE_RETURN_IF_ERROR(CheckForTrailingData(reader)); + return object_ack; } -void MoqtControlParser::ParseError(absl::string_view reason) { - ParseError(MoqtError::kProtocolViolation, reason); -} - -void MoqtControlParser::ParseError(const absl::Status& status) { - ParseError( - GetMoqtErrorForStatus(status).value_or(MoqtError::kProtocolViolation), - status.message()); -} - -void MoqtControlParser::ParseError(MoqtError error_code, - absl::string_view reason) { - if (parsing_error_) { - return; // Don't send multiple parse errors. - } - no_more_data_ = true; - parsing_error_ = true; - visitor_.OnParsingError(error_code, reason); -} - -absl::Status MoqtControlParser::ReadTrackNamespace( - quic::QuicDataReader& reader, TrackNamespace& track_namespace) { +absl::Status MoqtControlMessageParser::ReadTrackNamespace( + quic::QuicDataReader& reader, TrackNamespace& track_namespace) const { QUICHE_DCHECK(track_namespace.empty()); uint64_t num_elements; if (!reader.ReadVarInt62(&num_elements)) { @@ -1111,8 +1056,8 @@ return absl::OkStatus(); } -absl::Status MoqtControlParser::ReadFullTrackName( - quic::QuicDataReader& reader, FullTrackName& full_track_name) { +absl::Status MoqtControlMessageParser::ReadFullTrackName( + quic::QuicDataReader& reader, FullTrackName& full_track_name) const { QUICHE_DCHECK(!full_track_name.IsValid()); TrackNamespace track_namespace; QUICHE_RETURN_IF_ERROR(ReadTrackNamespace(reader, track_namespace)); @@ -1127,9 +1072,9 @@ return absl::OkStatus(); } -absl::Status MoqtControlParser::FillAndValidateSetupParameters( +absl::Status MoqtControlMessageParser::FillAndValidateSetupParameters( const KeyValuePairList& in, SetupParameters& out, - MoqtMessageType message_type) { + MoqtMessageType message_type) const { QUICHE_RETURN_IF_ERROR(out.FromKeyValuePairList(in)); MoqtError error = SetupParametersAllowedByMessage(out, message_type, uses_web_transport_); @@ -1139,8 +1084,8 @@ return absl::OkStatus(); } -absl::Status MoqtControlParser::FillAndValidateMessageParameters( - quic::QuicDataReader& reader, MessageParameters& out) { +absl::Status MoqtControlMessageParser::FillAndValidateMessageParameters( + quic::QuicDataReader& reader, MessageParameters& out) const { KeyValuePairList pairs; QUICHE_RETURN_IF_ERROR(ParseKeyValuePairList(reader, pairs)); // All parameter types are allowed in all messages.
diff --git a/quiche/quic/moqt/moqt_parser.h b/quiche/quic/moqt/moqt_parser.h index e04dbcb..9a62066 100644 --- a/quiche/quic/moqt/moqt_parser.h +++ b/quiche/quic/moqt/moqt_parser.h
@@ -13,16 +13,24 @@ #include <optional> #include <string> +#include "absl/base/nullability.h" +#include "absl/cleanup/cleanup.h" +#include "absl/functional/overload.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "quiche/quic/core/quic_data_reader.h" #include "quiche/quic/moqt/moqt_error.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_priority.h" +#include "quiche/quic/moqt/moqt_session_interface.h" #include "quiche/common/platform/api/quiche_export.h" #include "quiche/common/quiche_callbacks.h" +#include "quiche/common/quiche_status_utils.h" #include "quiche/web_transport/web_transport.h" namespace moqt { @@ -31,6 +39,7 @@ class MoqtDataParserPeer; } +// TODO(vasilvv): remove once all uses are switched to a new parser. class QUICHE_EXPORT MoqtControlParserVisitor { public: virtual ~MoqtControlParserVisitor() = default; @@ -70,6 +79,13 @@ virtual void OnParsingError(MoqtError code, absl::string_view reason) = 0; }; +// MoqtRawControlMessage represents an MOQT control message that has been +// unframed from the control stream, but not parsed yet. +struct MoqtRawControlMessage { + MoqtMessageType type; + std::string payload; +}; + class MoqtDataParserVisitor { public: virtual ~MoqtDataParserVisitor() = default; @@ -88,18 +104,201 @@ virtual void OnParsingError(MoqtError code, absl::string_view reason) = 0; }; -class QUICHE_EXPORT MoqtMessageTypeParser { +// MoqtControlStreamParser unframes MoQT control messages from the control +// stream without parsing the payload. +class QUICHE_EXPORT MoqtControlStreamParser { public: - MoqtMessageTypeParser(webtransport::Stream* stream) : stream_(*stream) {} - ~MoqtMessageTypeParser() = default; + explicit MoqtControlStreamParser(webtransport::Stream* absl_nonnull stream) + : stream_(*stream) {} - // Returns false if there was a FIN. - bool ReadUntilMessageTypeKnown(); - std::optional<uint64_t> message_type() const { return message_type_; } + // MoqtControlStreamParser is not movable, since reading from the same stream + // through two different parsers would corrupt the state. + MoqtControlStreamParser(const MoqtControlStreamParser&) = delete; + MoqtControlStreamParser(MoqtControlStreamParser&& other) = delete; + MoqtControlStreamParser& operator=(const MoqtControlStreamParser&) = delete; + MoqtControlStreamParser& operator=(MoqtControlStreamParser&&) = delete; + + // TODO(vasilvv): remove once nothing calls this. + void set_message_type(uint64_t message_type) { + current_message_type_ = message_type; + } + + // Reads the next available message on the stream. Returns kUnavailable + // status if no complete message can be read; if FIN is read, `fin_read` will + // be set to true. + absl::StatusOr<MoqtRawControlMessage> ReadNextMessage(); + // Reads the type of the first message on the stream. + absl::StatusOr<MoqtMessageType> ReadFirstMessageType(); + + bool fin_read() const { return fin_read_; } + webtransport::Stream* stream() const { return &stream_; } + + // Initially, MoqtControlStreamParser does not allow a control stream to have + // a FIN. Once the type of the stream is established, that restriction can be + // lifted. + bool allow_fin() const { return allow_fin_; } + void set_allow_fin(bool allow_fin) { allow_fin_ = allow_fin; } private: + absl::StatusOr<MoqtRawControlMessage> ReadNextMessageInner(); + // Reads the message type from the stream. + absl::Status ReadMessageType(); + webtransport::Stream& stream_; - std::optional<uint64_t> message_type_; + std::optional<uint64_t> first_message_type_; + std::optional<uint64_t> current_message_type_; + std::optional<absl::Span<char>> current_message_remaining_; + std::string current_message_; + bool allow_fin_ = false; + bool error_encountered_ = false; + bool fin_read_ = false; +}; + +// MoqtControlMessageParser parses MOQT control messages. The parsing is +// stateless; the object itself only carries the context (protocol version and +// parameters) required to parse messages. +class MoqtControlMessageParser { + public: + // `moqt_version` is not currently used, as we only support one version. + MoqtControlMessageParser(absl::string_view /*moqt_version*/, + bool uses_web_transport) + : uses_web_transport_(uses_web_transport) {} + + // Parsers for individual messages. + absl::StatusOr<MoqtClientSetup> ProcessClientSetup( + absl::string_view data) const; + absl::StatusOr<MoqtServerSetup> ProcessServerSetup( + absl::string_view data) const; + absl::StatusOr<MoqtRequestOk> ProcessRequestOk(absl::string_view data) const; + absl::StatusOr<MoqtRequestError> ProcessRequestError( + absl::string_view data) const; + absl::StatusOr<MoqtSubscribe> ProcessSubscribe(absl::string_view data) const; + absl::StatusOr<MoqtSubscribeOk> ProcessSubscribeOk( + absl::string_view data) const; + absl::StatusOr<MoqtUnsubscribe> ProcessUnsubscribe( + absl::string_view data) const; + absl::StatusOr<MoqtPublishDone> ProcessPublishDone( + absl::string_view data) const; + absl::StatusOr<MoqtRequestUpdate> ProcessRequestUpdate( + absl::string_view data) const; + absl::StatusOr<MoqtPublishNamespace> ProcessPublishNamespace( + absl::string_view data) const; + absl::StatusOr<MoqtPublishNamespaceDone> ProcessPublishNamespaceDone( + absl::string_view data) const; + absl::StatusOr<MoqtNamespace> ProcessNamespace(absl::string_view data) const; + absl::StatusOr<MoqtNamespaceDone> ProcessNamespaceDone( + absl::string_view data) const; + absl::StatusOr<MoqtPublishNamespaceCancel> ProcessPublishNamespaceCancel( + absl::string_view data) const; + absl::StatusOr<MoqtTrackStatus> ProcessTrackStatus( + absl::string_view data) const; + absl::StatusOr<MoqtGoAway> ProcessGoAway(absl::string_view data) const; + absl::StatusOr<MoqtSubscribeNamespace> ProcessSubscribeNamespace( + absl::string_view data) const; + absl::StatusOr<MoqtMaxRequestId> ProcessMaxRequestId( + absl::string_view data) const; + absl::StatusOr<MoqtFetch> ProcessFetch(absl::string_view data) const; + absl::StatusOr<MoqtFetchCancel> ProcessFetchCancel( + absl::string_view data) const; + absl::StatusOr<MoqtFetchOk> ProcessFetchOk(absl::string_view data) const; + absl::StatusOr<MoqtRequestsBlocked> ProcessRequestsBlocked( + absl::string_view data) const; + absl::StatusOr<MoqtPublish> ProcessPublish(absl::string_view data) const; + absl::StatusOr<MoqtPublishOk> ProcessPublishOk(absl::string_view data) const; + absl::StatusOr<MoqtObjectAck> ProcessObjectAck(absl::string_view data) const; + + // Parse a raw message and call a callback on it if successful. + // Example usage: + // + // parser_.ParseMessage(message, [] (const auto& message) { + // QUICHE_LOG(INFO) << "Received message: " << message; + // return absl::OkStatus(); + // }); + template <typename F> + absl::Status ParseMessage(const MoqtRawControlMessage& message, + const F& callback) const { + const auto parse = [&](auto parse_method) -> absl::Status { + auto parsed_message = (this->*parse_method)(message.payload); + QUICHE_RETURN_IF_ERROR(parsed_message.status()); + return callback(*std::move(parsed_message)); + }; + switch (message.type) { + case MoqtMessageType::kClientSetup: + return parse(&MoqtControlMessageParser::ProcessClientSetup); + case MoqtMessageType::kServerSetup: + return parse(&MoqtControlMessageParser::ProcessServerSetup); + case MoqtMessageType::kRequestOk: + return parse(&MoqtControlMessageParser::ProcessRequestOk); + case MoqtMessageType::kRequestError: + return parse(&MoqtControlMessageParser::ProcessRequestError); + case MoqtMessageType::kSubscribe: + return parse(&MoqtControlMessageParser::ProcessSubscribe); + case MoqtMessageType::kSubscribeOk: + return parse(&MoqtControlMessageParser::ProcessSubscribeOk); + case MoqtMessageType::kUnsubscribe: + return parse(&MoqtControlMessageParser::ProcessUnsubscribe); + case MoqtMessageType::kPublishDone: + return parse(&MoqtControlMessageParser::ProcessPublishDone); + case MoqtMessageType::kRequestUpdate: + return parse(&MoqtControlMessageParser::ProcessRequestUpdate); + case MoqtMessageType::kPublishNamespace: + return parse(&MoqtControlMessageParser::ProcessPublishNamespace); + case MoqtMessageType::kPublishNamespaceDone: + return parse(&MoqtControlMessageParser::ProcessPublishNamespaceDone); + case MoqtMessageType::kNamespace: + return parse(&MoqtControlMessageParser::ProcessNamespace); + case MoqtMessageType::kNamespaceDone: + return parse(&MoqtControlMessageParser::ProcessNamespaceDone); + case MoqtMessageType::kPublishNamespaceCancel: + return parse(&MoqtControlMessageParser::ProcessPublishNamespaceCancel); + case MoqtMessageType::kTrackStatus: + return parse(&MoqtControlMessageParser::ProcessTrackStatus); + case MoqtMessageType::kGoAway: + return parse(&MoqtControlMessageParser::ProcessGoAway); + case MoqtMessageType::kSubscribeNamespace: + return parse(&MoqtControlMessageParser::ProcessSubscribeNamespace); + case MoqtMessageType::kMaxRequestId: + return parse(&MoqtControlMessageParser::ProcessMaxRequestId); + case MoqtMessageType::kFetch: + return parse(&MoqtControlMessageParser::ProcessFetch); + case MoqtMessageType::kFetchCancel: + return parse(&MoqtControlMessageParser::ProcessFetchCancel); + case MoqtMessageType::kFetchOk: + return parse(&MoqtControlMessageParser::ProcessFetchOk); + case MoqtMessageType::kRequestsBlocked: + return parse(&MoqtControlMessageParser::ProcessRequestsBlocked); + case MoqtMessageType::kPublish: + return parse(&MoqtControlMessageParser::ProcessPublish); + case MoqtMessageType::kPublishOk: + return parse(&MoqtControlMessageParser::ProcessPublishOk); + case MoqtMessageType::kObjectAck: + return parse(&MoqtControlMessageParser::ProcessObjectAck); + default: + return absl::InvalidArgumentError( + absl::StrCat("Unknown control message type 0x", + absl::Hex(static_cast<uint64_t>(message.type)))); + } + } + + private: + // Reads a TrackNamespace from the reader. Returns false if the namespace is + // too large. Sets a ParseError if the namespace is malformed. + absl::Status ReadTrackNamespace(quic::QuicDataReader& reader, + TrackNamespace& track_namespace) const; + // Reads a FullTrackName from the reader. Returns false if the name is too + // large. Sets a ParseError if the name is malformed. + absl::Status ReadFullTrackName(quic::QuicDataReader& reader, + FullTrackName& full_track_name) const; + absl::Status FillAndValidateSetupParameters( + const KeyValuePairList& in, SetupParameters& out, + MoqtMessageType message_type) const; + // |reader| points to the beginning of a KeyValuePairList. Returns false if + // there is any sort of error. (The function calls ParseError(), so the + // caller has no need to do so.) + absl::Status FillAndValidateMessageParameters(quic::QuicDataReader& reader, + MessageParameters& out) const; + + bool uses_web_transport_; }; class QUICHE_EXPORT MoqtControlParser { @@ -107,87 +306,147 @@ MoqtControlParser(bool uses_web_transport, webtransport::Stream* stream, MoqtControlParserVisitor& visitor) : visitor_(visitor), - stream_(*stream), - uses_web_transport_(uses_web_transport) {} + stream_parser_(stream), + message_parser_(kDefaultMoqtVersion, uses_web_transport) {} ~MoqtControlParser() = default; + void set_message_type(uint64_t message_type) { + stream_parser_.set_message_type(message_type); + } - void set_message_type(uint64_t message_type) { message_type_ = message_type; } - void ReadAndDispatchMessages(); + void ReadAndDispatchMessages() { + if (processing_) { + return; + } + processing_ = true; + auto cleanup = absl::MakeCleanup([this] { processing_ = false; }); + while (true) { + absl::StatusOr<MoqtRawControlMessage> raw_message = + stream_parser_.ReadNextMessage(); + if (absl::IsUnavailable(raw_message.status())) { + return; + } + if (!raw_message.ok()) { + visitor_.OnParsingError(GetMoqtErrorForStatus(raw_message.status()) + .value_or(MoqtError::kProtocolViolation), + raw_message.status().message()); + return; + } + absl::Status status = message_parser_.ParseMessage( + *raw_message, + absl::Overload{[&](const MoqtClientSetup& message) { + visitor_.OnClientSetupMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtServerSetup& message) { + visitor_.OnServerSetupMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtRequestOk& message) { + visitor_.OnRequestOkMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtRequestError& message) { + visitor_.OnRequestErrorMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtSubscribe& message) { + visitor_.OnSubscribeMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtSubscribeOk& message) { + visitor_.OnSubscribeOkMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtUnsubscribe& message) { + visitor_.OnUnsubscribeMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtPublishDone& message) { + visitor_.OnPublishDoneMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtRequestUpdate& message) { + visitor_.OnRequestUpdateMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtPublishNamespace& message) { + visitor_.OnPublishNamespaceMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtPublishNamespaceDone& message) { + visitor_.OnPublishNamespaceDoneMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtNamespace& message) { + visitor_.OnNamespaceMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtNamespaceDone& message) { + visitor_.OnNamespaceDoneMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtPublishNamespaceCancel& message) { + visitor_.OnPublishNamespaceCancelMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtTrackStatus& message) { + visitor_.OnTrackStatusMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtGoAway& message) { + visitor_.OnGoAwayMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtSubscribeNamespace& message) { + visitor_.OnSubscribeNamespaceMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtMaxRequestId& message) { + visitor_.OnMaxRequestIdMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtFetch& message) { + visitor_.OnFetchMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtFetchCancel& message) { + visitor_.OnFetchCancelMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtFetchOk& message) { + visitor_.OnFetchOkMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtRequestsBlocked& message) { + visitor_.OnRequestsBlockedMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtPublish& message) { + visitor_.OnPublishMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtPublishOk& message) { + visitor_.OnPublishOkMessage(message); + return absl::OkStatus(); + }, + [&](const MoqtObjectAck& message) { + visitor_.OnObjectAckMessage(message); + return absl::OkStatus(); + }}); + if (!status.ok()) { + visitor_.OnParsingError(GetMoqtErrorForStatus(status).value_or( + MoqtError::kProtocolViolation), + status.message()); + return; + } + } + } private: - // The central switch statement to dispatch a message to the correct - // Process* function. Invokles an error callback if parsing fails. - void ProcessMessage(absl::string_view data, MoqtMessageType message_type); - - // The Process* functions parse the serialized data into the appropriate - // structs, and call the relevant visitor function for further action. Returns - // the number of bytes consumed if the message is complete; returns 0 - // otherwise. - absl::Status ProcessClientSetup(absl::string_view data); - absl::Status ProcessServerSetup(absl::string_view data); - absl::Status ProcessRequestOk(absl::string_view data); - absl::Status ProcessRequestError(absl::string_view data); - // Subscribe formats are used for TrackStatus as well, so take the message - // type as an argument, defaulting to the subscribe version. - absl::Status ProcessSubscribe( - absl::string_view data, - MoqtMessageType message_type = MoqtMessageType::kSubscribe); - absl::Status ProcessSubscribeOk(absl::string_view data); - absl::Status ProcessUnsubscribe(absl::string_view data); - absl::Status ProcessPublishDone(absl::string_view data); - absl::Status ProcessRequestUpdate(absl::string_view data); - absl::Status ProcessPublishNamespace(absl::string_view data); - absl::Status ProcessPublishNamespaceDone(absl::string_view data); - absl::Status ProcessNamespace(absl::string_view data); - absl::Status ProcessNamespaceDone(absl::string_view data); - absl::Status ProcessPublishNamespaceCancel(absl::string_view data); - absl::Status ProcessTrackStatus(absl::string_view data); - absl::Status ProcessGoAway(absl::string_view data); - absl::Status ProcessSubscribeNamespace(absl::string_view data); - absl::Status ProcessMaxRequestId(absl::string_view data); - absl::Status ProcessFetch(absl::string_view data); - absl::Status ProcessFetchCancel(absl::string_view data); - absl::Status ProcessFetchOk(absl::string_view data); - absl::Status ProcessRequestsBlocked(absl::string_view data); - absl::Status ProcessPublish(absl::string_view data); - absl::Status ProcessPublishOk(absl::string_view data); - absl::Status ProcessObjectAck(absl::string_view data); - - // If |error| is not provided, assumes kProtocolViolation. - void ParseError(absl::string_view reason); - void ParseError(const absl::Status& status); - void ParseError(MoqtError error, absl::string_view reason); - - // Reads a TrackNamespace from the reader. Returns false if the namespace is - // too large. Sets a ParseError if the namespace is malformed. - absl::Status ReadTrackNamespace(quic::QuicDataReader& reader, - TrackNamespace& track_namespace); - // Reads a FullTrackName from the reader. Returns false if the name is too - // large. Sets a ParseError if the name is malformed. - absl::Status ReadFullTrackName(quic::QuicDataReader& reader, - FullTrackName& full_track_name); - absl::Status FillAndValidateSetupParameters(const KeyValuePairList& in, - SetupParameters& out, - MoqtMessageType message_type); - // |reader| points to the beginning of a KeyValuePairList. Returns false if - // there is any sort of error. (The function calls ParseError(), so the - // caller has no need to do so.) - absl::Status FillAndValidateMessageParameters(quic::QuicDataReader& reader, - MessageParameters& out); - MoqtControlParserVisitor& visitor_; - webtransport::Stream& stream_; - bool uses_web_transport_; - bool no_more_data_ = false; // Fatal error or fin. No more parsing. - bool parsing_error_ = false; - - std::optional<uint64_t> message_type_; - std::optional<uint16_t> message_size_; - - uint64_t max_auth_token_cache_size_ = 0; - uint64_t auth_token_cache_size_ = 0; - bool processing_ = false; // True if currently in ProcessData(), to prevent - // re-entrancy. + MoqtControlStreamParser stream_parser_; + MoqtControlMessageParser message_parser_; + bool processing_ = false; }; // Parses an MoQT datagram. Returns the payload bytes, or std::nullopt on error.
diff --git a/quiche/quic/moqt/moqt_parser_fuzz_test.cc b/quiche/quic/moqt/moqt_parser_fuzz_test.cc index f9b7958..ee73052 100644 --- a/quiche/quic/moqt/moqt_parser_fuzz_test.cc +++ b/quiche/quic/moqt/moqt_parser_fuzz_test.cc
@@ -5,8 +5,11 @@ #include <array> #include <string> +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "quiche/quic/moqt/moqt_parser.h" +#include "quiche/quic/moqt/moqt_session_interface.h" #include "quiche/quic/moqt/test_tools/moqt_parser_test_visitor.h" #include "quiche/common/platform/api/quiche_fuzztest.h" #include "quiche/common/platform/api/quiche_test.h" @@ -20,7 +23,9 @@ webtransport::test::InMemoryStream stream(/*stream_id=*/0); MoqtParserTestVisitor visitor(/*enable_logging=*/false); - MoqtControlParser control_parser(uses_web_transport, &stream, visitor); + MoqtControlStreamParser control_stream_parser(&stream); + MoqtControlMessageParser control_message_parser(kDefaultMoqtVersion, + uses_web_transport); MoqtDataParser data_parser(&stream, &visitor); if (is_data_stream) { @@ -28,7 +33,15 @@ data_parser.ReadAllData(); } else { stream.Receive(stream_data, /*fin=*/false); - control_parser.ReadAndDispatchMessages(); + while (true) { + absl::StatusOr<MoqtRawControlMessage> message = + control_stream_parser.ReadNextMessage(); + if (!message.ok()) { + break; + } + (void)control_message_parser.ParseMessage( + *message, [](auto) { return absl::OkStatus(); }); + } } }
diff --git a/quiche/quic/moqt/moqt_parser_test.cc b/quiche/quic/moqt/moqt_parser_test.cc index 5a903ec..e395a82 100644 --- a/quiche/quic/moqt/moqt_parser_test.cc +++ b/quiche/quic/moqt/moqt_parser_test.cc
@@ -15,26 +15,34 @@ #include <variant> #include <vector> +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "quiche/quic/core/quic_data_writer.h" #include "quiche/quic/core/quic_time.h" #include "quiche/quic/moqt/moqt_error.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_session_interface.h" #include "quiche/quic/moqt/moqt_types.h" +#include "quiche/quic/moqt/test_tools/moqt_framer_utils.h" #include "quiche/quic/moqt/test_tools/moqt_parser_test_visitor.h" #include "quiche/quic/moqt/test_tools/moqt_test_message.h" #include "quiche/quic/platform/api/quic_test.h" +#include "quiche/common/platform/api/quiche_logging.h" #include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_status_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" #include "quiche/web_transport/test_tools/in_memory_stream.h" namespace moqt::test { namespace { +using ::quiche::test::IsOkAndHolds; +using ::quiche::test::StatusIs; using ::testing::AnyOf; using ::testing::HasSubstr; -using ::testing::Optional; constexpr std::array kMessageTypes{ MoqtMessageType::kRequestOk, @@ -111,6 +119,13 @@ "_" + (info.param.uses_web_transport ? "WebTransport" : "QUIC"); } +std::optional<MoqtError> ExtractMoqtErrorForStatus(const absl::Status& status) { + if (!absl::IsInvalidArgument(status)) { + return std::nullopt; + } + return GetMoqtErrorForStatus(status).value_or(MoqtError::kProtocolViolation); +} + class MoqtParserTest : public quic::test::QuicTestWithParam<MoqtParserTestParams> { public: @@ -118,16 +133,16 @@ : message_type_(GetParam().message_type), webtrans_(GetParam().uses_web_transport), control_stream_(/*stream_id=*/0), - control_parser_(GetParam().uses_web_transport, &control_stream_, - visitor_), + control_parser_(&control_stream_), + message_parser_(kDefaultMoqtVersion, webtrans_), data_stream_(/*stream_id=*/0), - data_parser_(&data_stream_, &visitor_) { + data_parser_(&data_stream_, &data_visitor_) { // The default object has priority 0x07, so setting this will let the // parser set the correct value when absent. data_parser_.set_default_publisher_priority(0x07); } - bool IsDataStream() { + bool IsDataStream() const { return std::holds_alternative<MoqtDataStreamType>(message_type_); } @@ -146,17 +161,67 @@ return; } control_stream_.Receive(data, /*fin=*/false); - control_parser_.ReadAndDispatchMessages(); + for (;;) { + absl::StatusOr<MoqtRawControlMessage> message = + control_parser_.ReadNextMessage(); + if (!message.ok()) { + if (!absl::IsUnavailable(message.status())) { + control_parsing_error_ = message.status().message(); + } + break; + } + absl::Status status = + message_parser_.ParseMessage(*message, [&](auto message) { + control_messages_.push_back(std::move(message)); + return absl::OkStatus(); + }); + if (!status.ok()) { + control_parsing_error_ = status.message(); + break; + } + } } protected: - MoqtParserTestVisitor visitor_; + size_t messages_received() const { + return IsDataStream() ? data_visitor_.messages_received() + : control_messages_.size(); + } + + std::optional<TestMessageBase::MessageStructuredData> last_message() const { + if (IsDataStream()) { + return data_visitor_.last_message(); + } + if (control_messages_.empty()) { + return std::nullopt; + } + return control_messages_.back(); + } + bool end_of_message() const { + return IsDataStream() ? data_visitor_.end_of_message() + : !control_messages_.empty(); + } + std::optional<std::string> parsing_error() const { + return IsDataStream() ? data_visitor_.parsing_error() + : control_parsing_error_; + } + std::string object_payload() const { + QUICHE_DCHECK(IsDataStream()); + return data_visitor_.object_payload(); + } + GeneralizedMessageType message_type_; bool webtrans_; webtransport::test::InMemoryStream control_stream_; - MoqtControlParser control_parser_; + MoqtControlStreamParser control_parser_; + MoqtControlMessageParser message_parser_; webtransport::test::InMemoryStream data_stream_; MoqtDataParser data_parser_; + + private: + std::vector<TestMessageBase::MessageStructuredData> control_messages_; + std::optional<std::string> control_parsing_error_; + MoqtParserTestVisitor data_visitor_; }; INSTANTIATE_TEST_SUITE_P(MoqtParserTests, MoqtParserTest, @@ -167,11 +232,11 @@ std::unique_ptr<TestMessageBase> message = MakeMessage(); message->MakeObjectEndOfStream(); ProcessData(message->PacketSample(), true); - ASSERT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); - EXPECT_TRUE(visitor_.end_of_message_); + ASSERT_EQ(messages_received(), 1); + EXPECT_TRUE(message->EqualFieldValues(*last_message())); + EXPECT_TRUE(end_of_message()); if (IsDataStream()) { - EXPECT_EQ(visitor_.object_payload(), "foo"); + EXPECT_EQ(object_payload(), "foo"); } } @@ -179,12 +244,12 @@ std::unique_ptr<TestMessageBase> message = MakeMessage(); message->ExpandVarints(); ProcessData(message->PacketSample(), false); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); - EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_EQ(visitor_.parsing_error_, std::nullopt); + EXPECT_EQ(messages_received(), 1); + EXPECT_TRUE(message->EqualFieldValues(*last_message())); + EXPECT_TRUE(end_of_message()); + EXPECT_EQ(parsing_error(), std::nullopt); if (IsDataStream()) { - EXPECT_EQ(visitor_.object_payload(), "foo"); + EXPECT_EQ(object_payload(), "foo"); } } @@ -196,17 +261,17 @@ // processed. size_t first_data_size = message->total_message_size() / 2; ProcessData(message->PacketSample().substr(0, first_data_size), false); - EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_EQ(messages_received(), 0); ProcessData( message->PacketSample().substr( first_data_size, message->total_message_size() - first_data_size), true); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); - EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(messages_received(), 1); + EXPECT_TRUE(message->EqualFieldValues(*last_message())); + EXPECT_TRUE(end_of_message()); + EXPECT_FALSE(parsing_error().has_value()); if (IsDataStream()) { - EXPECT_EQ(visitor_.object_payload(), "foo"); + EXPECT_EQ(object_payload(), "foo"); } } @@ -214,17 +279,17 @@ std::unique_ptr<TestMessageBase> message = MakeMessage(); message->MakeObjectEndOfStream(); for (size_t i = 0; i < message->total_message_size(); ++i) { - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_FALSE(visitor_.end_of_message_); + EXPECT_EQ(messages_received(), 0); + EXPECT_FALSE(end_of_message()); bool last = i == (message->total_message_size() - 1); ProcessData(message->PacketSample().substr(i, 1), last); } - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); - EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(messages_received(), 1); + EXPECT_TRUE(message->EqualFieldValues(*last_message())); + EXPECT_TRUE(end_of_message()); + EXPECT_FALSE(parsing_error().has_value()); if (IsDataStream()) { - EXPECT_EQ(visitor_.object_payload(), "foo"); + EXPECT_EQ(object_payload(), "foo"); } } @@ -237,11 +302,11 @@ std::unique_ptr<TestMessageBase> message = MakeMessage(); message->MakeObjectEndOfStream(); ProcessData(message->PacketSample(), true); - ASSERT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); - EXPECT_TRUE(visitor_.end_of_message_); + ASSERT_EQ(messages_received(), 1); + EXPECT_TRUE(message->EqualFieldValues(*last_message())); + EXPECT_TRUE(end_of_message()); if (IsDataStream()) { - EXPECT_EQ(visitor_.object_payload(), "foo"); + EXPECT_EQ(object_payload(), "foo"); } } @@ -250,17 +315,17 @@ message->ExpandVarints(); message->MakeObjectEndOfStream(); for (size_t i = 0; i < message->total_message_size(); ++i) { - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_FALSE(visitor_.end_of_message_); + EXPECT_EQ(messages_received(), 0); + EXPECT_FALSE(end_of_message()); bool last = i == (message->total_message_size() - 1); ProcessData(message->PacketSample().substr(i, 1), last); } - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); - EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(messages_received(), 1); + EXPECT_TRUE(message->EqualFieldValues(*last_message())); + EXPECT_TRUE(end_of_message()); + EXPECT_FALSE(parsing_error().has_value()); if (IsDataStream()) { - EXPECT_EQ(visitor_.object_payload(), "foo"); + EXPECT_EQ(object_payload(), "foo"); } } @@ -268,17 +333,17 @@ std::unique_ptr<TestMessageBase> message = MakeMessage(); message->MakeObjectEndOfStream(); for (size_t i = 0; i < message->total_message_size(); i += 3) { - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_FALSE(visitor_.end_of_message_); + EXPECT_EQ(messages_received(), 0); + EXPECT_FALSE(end_of_message()); bool last = (i + 3) >= message->total_message_size(); ProcessData(message->PacketSample().substr(i, 3), last); } - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); - EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(messages_received(), 1); + EXPECT_TRUE(message->EqualFieldValues(*last_message())); + EXPECT_TRUE(end_of_message()); + EXPECT_FALSE(parsing_error().has_value()); if (IsDataStream()) { - EXPECT_EQ(visitor_.object_payload(), "foo"); + EXPECT_EQ(object_payload(), "foo"); } } @@ -289,8 +354,8 @@ std::unique_ptr<TestMessageBase> message = MakeMessage(); size_t first_data_size = message->total_message_size() - 1; ProcessData(message->PacketSample().substr(0, first_data_size), true); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_THAT(visitor_.parsing_error_, + EXPECT_EQ(messages_received(), 0); + EXPECT_THAT(parsing_error(), AnyOf("FIN after incomplete message", "FIN received at an unexpected point in the stream")); } @@ -303,8 +368,8 @@ size_t first_data_size = message->total_message_size() - 1; ProcessData(message->PacketSample().substr(0, first_data_size), false); ProcessData(absl::string_view(), true); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_THAT(visitor_.parsing_error_, + EXPECT_EQ(messages_received(), 0); + EXPECT_THAT(parsing_error(), AnyOf("FIN after incomplete message", "FIN received at an unexpected point in the stream")); } @@ -323,9 +388,8 @@ std::unique_ptr<TestMessageBase> message = MakeMessage(); message->IncreasePayloadLengthByOne(); ProcessData(message->PacketSample(), false); - // The parser will actually report a message, because it's all there. - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(messages_received(), 0); + EXPECT_TRUE(parsing_error().has_value()); } TEST_P(MoqtParserTest, PayloadLengthTooShort) { @@ -335,8 +399,8 @@ std::unique_ptr<TestMessageBase> message = MakeMessage(); message->DecreasePayloadLengthByOne(); ProcessData(message->PacketSample(), false); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(messages_received(), 0); + EXPECT_TRUE(parsing_error().has_value()); } // Tests for message-specific error cases, and behaviors for a single message @@ -345,7 +409,32 @@ public: MoqtMessageSpecificTest() {} - MoqtParserTestVisitor visitor_; + absl::StatusOr<std::vector<AnyMoqtControlMessage>> ParseAllMessages( + absl::string_view data, + absl::string_view moqt_version = kDefaultMoqtVersion, + bool uses_web_transport = true) { + webtransport::test::InMemoryStream stream(/*stream_id=*/0); + stream.Receive(data, /*fin=*/true); + MoqtControlStreamParser stream_parser(&stream); + stream_parser.set_allow_fin(true); + MoqtControlMessageParser message_parser(moqt_version, uses_web_transport); + std::vector<AnyMoqtControlMessage> result; + while (!stream_parser.fin_read()) { + absl::StatusOr<MoqtRawControlMessage> raw_message = + stream_parser.ReadNextMessage(); + // ParseAllMessages expects a sequence of complete messages. + if (absl::IsUnavailable(raw_message.status())) { + return absl::InvalidArgumentError("Incomplete control message"); + } + QUICHE_RETURN_IF_ERROR(raw_message.status()); + QUICHE_RETURN_IF_ERROR( + message_parser.ParseMessage(*raw_message, [&](auto message) { + result.push_back(std::move(message)); + return absl::OkStatus(); + })); + } + return std::move(result); + } static constexpr bool kWebTrans = true; static constexpr bool kRawQuic = false; @@ -355,42 +444,44 @@ // message. TEST_F(MoqtMessageSpecificTest, ThreePartObject) { webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtDataParser parser(&stream, &visitor_); + MoqtParserTestVisitor data_visitor; + MoqtDataParser parser(&stream, &data_visitor); MoqtDataStreamType type = MoqtDataStreamType::Subgroup(1, 1, true, false); auto message = std::make_unique<StreamHeaderSubgroupMessage>(type); EXPECT_TRUE(message->SetPayloadLength(14)); message->set_wire_image_size(message->total_message_size() - 11); stream.Receive(message->PacketSample(), false); parser.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); - EXPECT_FALSE(visitor_.end_of_message_); - EXPECT_EQ(visitor_.object_payload(), "foo"); + EXPECT_EQ(data_visitor.messages_received(), 0); + EXPECT_TRUE(message->EqualFieldValues(*data_visitor.last_message())); + EXPECT_FALSE(data_visitor.end_of_message()); + EXPECT_EQ(data_visitor.object_payload(), "foo"); // second part stream.Receive("bar", false); parser.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); - EXPECT_FALSE(visitor_.end_of_message_); - EXPECT_EQ(visitor_.object_payload(), "foobar"); + EXPECT_EQ(data_visitor.messages_received(), 0); + EXPECT_TRUE(message->EqualFieldValues(*data_visitor.last_message())); + EXPECT_FALSE(data_visitor.end_of_message()); + EXPECT_EQ(data_visitor.object_payload(), "foobar"); // third part includes FIN stream.Receive("deadbeef", true); parser.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); - EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_TRUE(visitor_.fin_received_); - EXPECT_EQ(visitor_.object_payload(), "foobardeadbeef"); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(data_visitor.messages_received(), 1); + EXPECT_TRUE(message->EqualFieldValues(*data_visitor.last_message())); + EXPECT_TRUE(data_visitor.end_of_message()); + EXPECT_TRUE(data_visitor.fin_received()); + EXPECT_EQ(data_visitor.object_payload(), "foobardeadbeef"); + EXPECT_FALSE(data_visitor.parsing_error().has_value()); } // Send the part of header, rest of header + payload, plus payload. TEST_F(MoqtMessageSpecificTest, ThreePartObjectFirstIncomplete) { uint8_t payload_length = 51; webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtDataParser parser(&stream, &visitor_); + MoqtParserTestVisitor data_visitor; + MoqtDataParser parser(&stream, &data_visitor); MoqtDataStreamType type = MoqtDataStreamType::Subgroup(2, 1, false, false); auto message = std::make_unique<StreamHeaderSubgroupMessage>(type); EXPECT_TRUE(message->SetPayloadLength(payload_length)); @@ -398,106 +489,107 @@ // first part stream.Receive(message->PacketSample().substr(0, 4), false); parser.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_EQ(data_visitor.messages_received(), 0); // second part. Add padding to it. stream.Receive( message->PacketSample().substr(4, message->total_message_size() - 7), false); parser.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); - EXPECT_FALSE(visitor_.end_of_message_); - EXPECT_EQ(visitor_.object_payload().length(), payload_length - 3); + EXPECT_EQ(data_visitor.messages_received(), 0); + EXPECT_TRUE(message->EqualFieldValues(*data_visitor.last_message())); + EXPECT_FALSE(data_visitor.end_of_message()); + EXPECT_EQ(data_visitor.object_payload().length(), payload_length - 3); // third part includes FIN stream.Receive("bar", true); parser.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); - EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_TRUE(visitor_.fin_received_); - EXPECT_EQ(*visitor_.object_payloads_.crbegin(), "bar"); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(data_visitor.messages_received(), 1); + EXPECT_TRUE(message->EqualFieldValues(*data_visitor.last_message())); + EXPECT_TRUE(data_visitor.end_of_message()); + EXPECT_TRUE(data_visitor.fin_received()); + EXPECT_EQ(*data_visitor.object_payloads().crbegin(), "bar"); + EXPECT_FALSE(data_visitor.parsing_error().has_value()); } TEST_F(MoqtMessageSpecificTest, ObjectSplitInExtension) { webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtDataParser parser(&stream, &visitor_); + MoqtParserTestVisitor data_visitor; + MoqtDataParser parser(&stream, &data_visitor); MoqtDataStreamType type = MoqtDataStreamType::Subgroup(2, 1, false, false); auto message = std::make_unique<StreamHeaderSubgroupMessage>(type); // first part stream.Receive(message->PacketSample().substr(0, 10), false); parser.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_EQ(data_visitor.messages_received(), 0); // second part stream.Receive( message->PacketSample().substr(10, sizeof(message->total_message_size())), false); parser.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(visitor_.last_message_.has_value() && - message->EqualFieldValues(*visitor_.last_message_)); - EXPECT_TRUE(visitor_.end_of_message_); + EXPECT_EQ(data_visitor.messages_received(), 1); + EXPECT_TRUE(data_visitor.last_message().has_value() && + message->EqualFieldValues(*data_visitor.last_message())); + EXPECT_TRUE(data_visitor.end_of_message()); } TEST_F(MoqtMessageSpecificTest, StreamHeaderSubgroupFollowOn) { webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtDataParser parser(&stream, &visitor_); + MoqtParserTestVisitor data_visitor; + MoqtDataParser parser(&stream, &data_visitor); // first part MoqtDataStreamType type = MoqtDataStreamType::Subgroup(0, 1, false, false); auto message1 = std::make_unique<StreamHeaderSubgroupMessage>(type); stream.Receive(message1->PacketSample(), false); parser.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(message1->EqualFieldValues(*visitor_.last_message_)); - EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_EQ(visitor_.object_payload(), "foo"); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(data_visitor.messages_received(), 1); + EXPECT_TRUE(message1->EqualFieldValues(*data_visitor.last_message())); + EXPECT_TRUE(data_visitor.end_of_message()); + EXPECT_EQ(data_visitor.object_payload(), "foo"); + EXPECT_FALSE(data_visitor.parsing_error().has_value()); // second part - visitor_.object_payloads_.clear(); + data_visitor.object_payloads().clear(); auto message2 = std::make_unique<StreamMiddlerSubgroupMessage>(type); stream.Receive(message2->PacketSample(), false); parser.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 2); - EXPECT_TRUE(message2->EqualFieldValues(*visitor_.last_message_)); - EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_EQ(visitor_.object_payload(), "bar"); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(data_visitor.messages_received(), 2); + EXPECT_TRUE(message2->EqualFieldValues(*data_visitor.last_message())); + EXPECT_TRUE(data_visitor.end_of_message()); + EXPECT_EQ(data_visitor.object_payload(), "bar"); + EXPECT_FALSE(data_visitor.parsing_error().has_value()); } TEST_F(MoqtMessageSpecificTest, StreamHeaderSubgroupFollowOnExpandedVarInts) { webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtDataParser parser(&stream, &visitor_); + MoqtParserTestVisitor data_visitor; + MoqtDataParser parser(&stream, &data_visitor); // first part MoqtDataStreamType type = MoqtDataStreamType::Subgroup(0, 1, false, false); auto message1 = std::make_unique<StreamHeaderSubgroupMessage>(type); message1->ExpandVarints(); stream.Receive(message1->PacketSample(), false); parser.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(message1->EqualFieldValues(*visitor_.last_message_)); - EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_EQ(visitor_.object_payload(), "foo"); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(data_visitor.messages_received(), 1); + EXPECT_TRUE(message1->EqualFieldValues(*data_visitor.last_message())); + EXPECT_TRUE(data_visitor.end_of_message()); + EXPECT_EQ(data_visitor.object_payload(), "foo"); + EXPECT_FALSE(data_visitor.parsing_error().has_value()); // second part - visitor_.object_payloads_.clear(); + data_visitor.object_payloads().clear(); auto message2 = std::make_unique<StreamMiddlerSubgroupMessage>(type); message2->ExpandVarints(); stream.Receive(message2->PacketSample(), false); parser.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 2); - EXPECT_TRUE(message2->EqualFieldValues(*visitor_.last_message_)); - EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_EQ(visitor_.object_payload(), "bar"); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(data_visitor.messages_received(), 2); + EXPECT_TRUE(message2->EqualFieldValues(*data_visitor.last_message())); + EXPECT_TRUE(data_visitor.end_of_message()); + EXPECT_EQ(data_visitor.object_payload(), "bar"); + EXPECT_FALSE(data_visitor.parsing_error().has_value()); } TEST_F(MoqtMessageSpecificTest, ClientSetupMaxRequestIdAppearsTwice) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char setup[] = { 0x20, 0x00, 0x0a, 0x03, // 3 params @@ -505,182 +597,155 @@ 0x01, 0x32, // max_request_id = 50 0x00, 0x32, // max_request_id = 50 }; - stream.Receive(absl::string_view(setup, sizeof(setup)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_THAT(visitor_.parsing_error_, - Optional(HasSubstr("Duplicate Setup Parameter"))); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(setup, sizeof(setup))); + EXPECT_THAT(parsed, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Duplicate Setup Parameter"))); } TEST_F(MoqtMessageSpecificTest, ServerSetupAuthorizationTokenTagRegister) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kWebTrans, &stream, visitor_); char setup[] = { 0x21, 0x00, 0x0b, 0x02, // 2 params 0x02, 0x32, // max_request_id = 50 0x01, 0x06, 0x01, 0x10, 0x00, 0x62, 0x61, 0x72, // REGISTER 0x01 }; - stream.Receive(absl::string_view(setup, sizeof(setup)), false); - parser.ReadAndDispatchMessages(); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(setup, sizeof(setup))); // No error even though the registration exceeds the max cache size of 0. - EXPECT_EQ(visitor_.messages_received_, 1); + QUICHE_EXPECT_OK(parsed.status()); } TEST_F(MoqtMessageSpecificTest, SetupPathFromServer) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char setup[] = { 0x21, 0x00, 0x06, 0x01, // 1 param 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" }; - stream.Receive(absl::string_view(setup, sizeof(setup)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kInvalidPath); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(setup, sizeof(setup))); + ASSERT_THAT(parsed.status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Setup parameter parsing error"))); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kInvalidPath); } TEST_F(MoqtMessageSpecificTest, SetupAuthorityFromServer) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char setup[] = { 0x21, 0x00, 0x06, 0x01, // 1 param 0x05, 0x03, 0x66, 0x6f, 0x6f, // authority = "foo" }; - stream.Receive(absl::string_view(setup, sizeof(setup)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kInvalidAuthority); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(setup, sizeof(setup))); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kInvalidAuthority); } TEST_F(MoqtMessageSpecificTest, SetupPathAppearsTwice) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char setup[] = { 0x20, 0x00, 0x0b, 0x02, // 2 params 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" 0x00, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" }; - stream.Receive(absl::string_view(setup, sizeof(setup)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_THAT(visitor_.parsing_error_, - Optional(HasSubstr("Duplicate Setup Parameter"))); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = ParseAllMessages( + absl::string_view(setup, sizeof(setup)), kDefaultMoqtVersion, kRawQuic); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kProtocolViolation); } TEST_F(MoqtMessageSpecificTest, SetupPathOverWebtrans) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kWebTrans, &stream, visitor_); char setup[] = { 0x20, 0x00, 0x06, 0x01, // 1 param 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" }; - stream.Receive(absl::string_view(setup, sizeof(setup)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kInvalidPath); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = ParseAllMessages( + absl::string_view(setup, sizeof(setup)), kDefaultMoqtVersion, kWebTrans); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kInvalidPath); } TEST_F(MoqtMessageSpecificTest, SetupAuthorityOverWebtrans) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kWebTrans, &stream, visitor_); char setup[] = { 0x20, 0x00, 0x06, 0x01, // 1 param 0x05, 0x03, 0x66, 0x6f, 0x6f, // authority = "foo" }; - stream.Receive(absl::string_view(setup, sizeof(setup)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kInvalidAuthority); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = ParseAllMessages( + absl::string_view(setup, sizeof(setup)), kDefaultMoqtVersion, kWebTrans); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kInvalidAuthority); } TEST_F(MoqtMessageSpecificTest, SetupPathMissing) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char setup[] = { 0x20, 0x00, 0x01, 0x00, // no param }; - stream.Receive(absl::string_view(setup, sizeof(setup)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kInvalidPath); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = ParseAllMessages( + absl::string_view(setup, sizeof(setup)), kDefaultMoqtVersion, kRawQuic); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kInvalidPath); } TEST_F(MoqtMessageSpecificTest, ServerSetupMaxRequestIdAppearsTwice) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char setup[] = { 0x21, 0x00, 0x05, 0x02, // 2 params 0x02, 0x32, // max_request_id = 50 0x00, 0x32, // max_request_id = 50 }; - stream.Receive(absl::string_view(setup, sizeof(setup)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_THAT(visitor_.parsing_error_, - Optional(HasSubstr("Duplicate Setup Parameter"))); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = ParseAllMessages( + absl::string_view(setup, sizeof(setup)), kDefaultMoqtVersion, kRawQuic); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kProtocolViolation); } TEST_F(MoqtMessageSpecificTest, ClientSetupMalformedPath) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char setup[] = { 0x20, 0x00, 0x06, 0x01, // 1 param 0x01, 0x03, 0x66, 0x5c, 0x6f, // path = "f\o" }; - stream.Receive(absl::string_view(setup, sizeof(setup)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kMalformedPath); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = ParseAllMessages( + absl::string_view(setup, sizeof(setup)), kDefaultMoqtVersion, kRawQuic); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kMalformedPath); } TEST_F(MoqtMessageSpecificTest, ClientSetupMalformedAuthority) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char setup[] = { 0x20, 0x00, 0x0b, 0x02, // 2 params 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" 0x04, 0x03, 0x66, 0x5c, 0x6f, // authority = "f\o" }; - stream.Receive(absl::string_view(setup, sizeof(setup)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kMalformedAuthority); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = ParseAllMessages( + absl::string_view(setup, sizeof(setup)), kDefaultMoqtVersion, kRawQuic); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kMalformedAuthority); } TEST_F(MoqtMessageSpecificTest, ServerSetupUnknownParameterIsOk) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char setup[] = { 0x21, 0x00, 0x0b, 0x02, // 2 params 0x1f, 0x03, 0x62, 0x61, 0x72, // 0x1f = "bar" 0x00, 0x03, 0x62, 0x61, 0x72, // 0x1f = "bar" }; - stream.Receive(absl::string_view(setup, sizeof(setup)), false); - parser.ReadAndDispatchMessages(); - ASSERT_EQ(visitor_.messages_received_, 1); - MoqtServerSetup message = - std::get<MoqtServerSetup>(visitor_.last_message_.value()); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = ParseAllMessages( + absl::string_view(setup, sizeof(setup)), kDefaultMoqtVersion, kRawQuic); + ASSERT_TRUE(parsed.ok()); + ASSERT_EQ(parsed->size(), 1); + MoqtServerSetup message = std::get<MoqtServerSetup>((*parsed)[0]); EXPECT_EQ(message.parameters, SetupParameters()); } TEST_F(MoqtMessageSpecificTest, SubscribeDeliveryTimeoutTwice) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe[] = { 0x03, 0x00, 0x12, 0x01, 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -689,17 +754,14 @@ 0x02, 0x67, 0x10, // delivery_timeout = 10000 0x00, 0x67, 0x10, // delivery_timeout = 10000 }; - stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_THAT(visitor_.parsing_error_, - Optional(HasSubstr("Duplicate Message Parameter"))); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(subscribe, sizeof(subscribe)), + kDefaultMoqtVersion, kRawQuic); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kProtocolViolation); } TEST_F(MoqtMessageSpecificTest, SubscribeAuthorizationTokenTagDelete) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe[] = { 0x03, 0x00, 0x10, 0x01, 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -707,19 +769,18 @@ 0x01, // one param 0x03, 0x02, 0x00, 0x00 // authorization_token = DELETE 0; }; - stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 1); - MoqtSubscribe message = - std::get<MoqtSubscribe>(visitor_.last_message_.value()); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(subscribe, sizeof(subscribe)), + kDefaultMoqtVersion, kRawQuic); + ASSERT_TRUE(parsed.ok()); + ASSERT_EQ(parsed->size(), 1); + MoqtSubscribe message = std::get<MoqtSubscribe>((*parsed)[0]); ASSERT_FALSE(message.parameters.authorization_tokens.empty()); EXPECT_EQ(message.parameters.authorization_tokens[0].alias_type, AuthTokenAliasType::kDelete); } TEST_F(MoqtMessageSpecificTest, SubscribeAuthorizationTokenTagRegister) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe[] = { 0x03, 0x00, 0x14, 0x01, 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -727,11 +788,12 @@ 0x01, // one param 0x03, 0x06, 0x01, 0x10, 0x00, 0x62, 0x61, 0x72, // REGISTER 0x01 }; - stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 1); - MoqtSubscribe message = - std::get<MoqtSubscribe>(visitor_.last_message_.value()); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(subscribe, sizeof(subscribe)), + kDefaultMoqtVersion, kRawQuic); + ASSERT_TRUE(parsed.ok()); + ASSERT_EQ(parsed->size(), 1); + MoqtSubscribe message = std::get<MoqtSubscribe>((*parsed)[0]); ASSERT_FALSE(message.parameters.authorization_tokens.empty()); EXPECT_EQ(message.parameters.authorization_tokens[0].alias_type, AuthTokenAliasType::kRegister); @@ -739,8 +801,6 @@ TEST_F(MoqtMessageSpecificTest, SubscribeAuthorizationTokenTagUnknownAliasType) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe[] = { 0x03, 0x00, 0x10, 0x01, 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -748,16 +808,15 @@ 0x01, // one param 0x03, 0x02, 0x04, 0x07, // authorization_token type 4 }; - stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kKeyValueFormattingError); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(subscribe, sizeof(subscribe)), + kDefaultMoqtVersion, kRawQuic); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kKeyValueFormattingError); } TEST_F(MoqtMessageSpecificTest, SubscribeAuthorizationTokenTagUnknownTokenType) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe[] = { 0x03, 0x00, 0x12, 0x01, 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -765,15 +824,14 @@ 0x01, // one param 0x03, 0x04, 0x03, 0x01, 0x00, 0x00 // authorization_token type 1 }; - stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kKeyValueFormattingError); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(subscribe, sizeof(subscribe)), + kDefaultMoqtVersion, kRawQuic); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kKeyValueFormattingError); } TEST_F(MoqtMessageSpecificTest, SubscribeInvalidForward) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe[] = { 0x03, 0x00, 0x0e, 0x01, // id 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -781,15 +839,14 @@ 0x01, // 2 parameters 0x10, 0x02 // forward = 2 }; - stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(subscribe, sizeof(subscribe)), + kDefaultMoqtVersion, kRawQuic); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kProtocolViolation); } TEST_F(MoqtMessageSpecificTest, SubscribeInvalidFilter) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe[] = { 0x03, 0x00, 0x0f, 0x01, // id 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -797,15 +854,14 @@ 0x01, // 1 parameter 0x21, 0x01, 0x10 // filter_type = 0x10 }; - stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(subscribe, sizeof(subscribe)), + kDefaultMoqtVersion, kRawQuic); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kProtocolViolation); } TEST_F(MoqtMessageSpecificTest, PublishNamespaceAuthorizationTokenTwice) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kWebTrans, &stream, visitor_); char publish_namespace[] = { 0x06, 0x00, 0x15, 0x02, 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -813,44 +869,106 @@ 0x03, 0x05, 0x03, 0x00, 0x62, 0x61, 0x72, // authorization = "bar" 0x00, 0x05, 0x03, 0x00, 0x62, 0x61, 0x72, // authorization = "bar" }; - stream.Receive( - absl::string_view(publish_namespace, sizeof(publish_namespace)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 1); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = ParseAllMessages( + absl::string_view(publish_namespace, sizeof(publish_namespace)), + kDefaultMoqtVersion, kWebTrans); + EXPECT_TRUE(parsed.ok()); + EXPECT_EQ(parsed->size(), 1); } -TEST_F(MoqtMessageSpecificTest, FinMidPayload) { +TEST_F(MoqtMessageSpecificTest, CannotAccessAfterError1) { webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtDataParser parser(&stream, &visitor_); + stream.Receive("\xff", /*fin=*/true); + MoqtControlStreamParser parser(&stream); + EXPECT_THAT(parser.ReadNextMessage().status(), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(parser.ReadNextMessage().status(), + StatusIs(absl::StatusCode::kFailedPrecondition)); + EXPECT_THAT(parser.ReadFirstMessageType().status(), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(MoqtMessageSpecificTest, CannotAccessAfterError2) { + webtransport::test::InMemoryStream stream(/*stream_id=*/0); + stream.Receive("\x03\xff\xff"); + MoqtControlStreamParser parser(&stream); + EXPECT_THAT(parser.ReadNextMessage().status(), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(parser.ReadNextMessage().status(), + StatusIs(absl::StatusCode::kFailedPrecondition)); + EXPECT_THAT(parser.ReadFirstMessageType(), + IsOkAndHolds(MoqtMessageType::kSubscribe)); +} + +TEST_F(MoqtMessageSpecificTest, FinMidType) { + webtransport::test::InMemoryStream stream(/*stream_id=*/0); + stream.Receive("\xff", /*fin=*/true); + MoqtControlStreamParser parser(&stream); + parser.set_allow_fin(true); + EXPECT_THAT(parser.ReadNextMessage().status(), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(MoqtMessageSpecificTest, FinMidLength) { + webtransport::test::InMemoryStream stream(/*stream_id=*/0); + stream.Receive(absl::string_view("\0\0", 2), /*fin=*/true); + MoqtControlStreamParser parser(&stream); + EXPECT_THAT(parser.ReadNextMessage().status(), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(MoqtMessageSpecificTest, FinMidControlPayload) { + webtransport::test::InMemoryStream stream(/*stream_id=*/0); + stream.Receive(absl::string_view("\x00\x00\xff ", 4), /*fin=*/false); + MoqtControlStreamParser parser(&stream); + ASSERT_THAT(parser.ReadNextMessage().status(), + StatusIs(absl::StatusCode::kUnavailable)); + + stream.Receive("test", /*fin=*/true); + EXPECT_THAT(parser.ReadNextMessage().status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("250 bytes left in the current message"))); +} + +TEST_F(MoqtMessageSpecificTest, FinMidDataPayload) { + webtransport::test::InMemoryStream stream(/*stream_id=*/0); + MoqtParserTestVisitor data_visitor; + MoqtDataParser parser(&stream, &data_visitor); MoqtDataStreamType type = MoqtDataStreamType::Subgroup(0, 1, true, false); auto message = std::make_unique<StreamHeaderSubgroupMessage>(type); stream.Receive( message->PacketSample().substr(0, message->total_message_size() - 1), true); parser.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_THAT(visitor_.parsing_error_, - AnyOf("FIN after incomplete message", - "FIN received at an unexpected point in the stream")); + EXPECT_EQ(data_visitor.messages_received(), 0); + ASSERT_TRUE(data_visitor.parsing_error().has_value()); + EXPECT_THAT( + data_visitor.parsing_error().value(), + AnyOf(HasSubstr("FIN after incomplete message"), + HasSubstr("FIN received at an unexpected point in the stream"))); } TEST_F(MoqtMessageSpecificTest, FinMidExtension) { webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtDataParser parser(&stream, &visitor_); + MoqtParserTestVisitor data_visitor; + MoqtDataParser parser(&stream, &data_visitor); MoqtDataStreamType type = MoqtDataStreamType::Subgroup(0, 1, false, false); auto message = std::make_unique<StreamHeaderSubgroupMessage>(type); // Read up to the extension body and then FIN. stream.Receive(message->PacketSample().substr(0, 7), true); parser.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_THAT(visitor_.parsing_error_, - AnyOf("FIN after incomplete message", - "FIN received at an unexpected point in the stream")); + EXPECT_EQ(data_visitor.messages_received(), 0); + ASSERT_TRUE(data_visitor.parsing_error().has_value()); + EXPECT_THAT( + data_visitor.parsing_error().value(), + AnyOf(HasSubstr("FIN after incomplete message"), + HasSubstr("FIN received at an unexpected point in the stream"))); } TEST_F(MoqtMessageSpecificTest, PartialPayloadThenFin) { webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtDataParser parser(&stream, &visitor_); + MoqtParserTestVisitor data_visitor; + MoqtDataParser parser(&stream, &data_visitor); MoqtDataStreamType type = MoqtDataStreamType::Subgroup(1, 1, false, false); auto message = std::make_unique<StreamHeaderSubgroupMessage>(type); stream.Receive( @@ -859,35 +977,99 @@ parser.ReadAllData(); stream.Receive(absl::string_view(), true); parser.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_THAT(visitor_.parsing_error_, - AnyOf("FIN after incomplete message", - "FIN received at an unexpected point in the stream")); + EXPECT_EQ(data_visitor.messages_received(), 0); + ASSERT_TRUE(data_visitor.parsing_error().has_value()); + EXPECT_THAT( + data_visitor.parsing_error().value(), + AnyOf(HasSubstr("FIN after incomplete message"), + HasSubstr("FIN received at an unexpected point in the stream"))); } TEST_F(MoqtMessageSpecificTest, FinMidVarint) { webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtDataParser parser(&stream, &visitor_); + MoqtParserTestVisitor data_visitor; + MoqtDataParser parser(&stream, &data_visitor); stream.Receive("\x40", true); parser.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_THAT(visitor_.parsing_error_, - AnyOf("FIN after incomplete message", - "FIN received at an unexpected point in the stream")); + EXPECT_EQ(data_visitor.messages_received(), 0); + ASSERT_TRUE(data_visitor.parsing_error().has_value()); + EXPECT_THAT( + data_visitor.parsing_error().value(), + AnyOf(HasSubstr("FIN after incomplete message"), + HasSubstr("FIN received at an unexpected point in the stream"))); } -TEST_F(MoqtMessageSpecificTest, ControlStreamFin) { +TEST_F(MoqtMessageSpecificTest, ControlStreamFinWhenAllowed) { webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); - stream.Receive(absl::string_view(), true); // Find FIN - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.parsing_error_, "FIN on control stream"); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); + MoqtControlStreamParser parser(&stream); + parser.set_allow_fin(true); + stream.Receive(absl::string_view("\0\0\0", 3), true); + EXPECT_FALSE(parser.fin_read()); + EXPECT_THAT(parser.ReadNextMessage().status(), + StatusIs(absl::StatusCode::kOk)); + EXPECT_TRUE(parser.fin_read()); + EXPECT_THAT(parser.ReadNextMessage().status(), + StatusIs(absl::StatusCode::kFailedPrecondition)); + EXPECT_TRUE(parser.fin_read()); +} + +TEST_F(MoqtMessageSpecificTest, ControlStreamFinWhenAllowedSeparateFin) { + webtransport::test::InMemoryStream stream(/*stream_id=*/0); + MoqtControlStreamParser parser(&stream); + parser.set_allow_fin(true); + stream.Receive(absl::string_view("\0\0\0", 3), false); + EXPECT_THAT(parser.ReadNextMessage().status(), + StatusIs(absl::StatusCode::kOk)); + EXPECT_THAT(parser.ReadNextMessage().status(), + StatusIs(absl::StatusCode::kUnavailable)); + EXPECT_FALSE(parser.fin_read()); + + stream.Receive(absl::string_view(), true); + EXPECT_THAT(parser.ReadNextMessage().status(), + StatusIs(absl::StatusCode::kUnavailable)); + EXPECT_TRUE(parser.fin_read()); + EXPECT_THAT(parser.ReadNextMessage().status(), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(MoqtMessageSpecificTest, ControlStreamFinWhenDisallowed) { + webtransport::test::InMemoryStream stream(/*stream_id=*/0); + MoqtControlStreamParser parser(&stream); + stream.Receive(absl::string_view(), true); + EXPECT_FALSE(parser.fin_read()); + EXPECT_THAT(parser.ReadNextMessage().status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("FIN on a control stream"))); +} + +TEST_F(MoqtMessageSpecificTest, ControlStreamReadType) { + webtransport::test::InMemoryStream stream(/*stream_id=*/0); + MoqtControlStreamParser parser(&stream); + stream.Receive("\x03", false); + absl::StatusOr<MoqtMessageType> type = parser.ReadFirstMessageType(); + EXPECT_THAT(type, IsOkAndHolds(MoqtMessageType::kSubscribe)); +} + +TEST_F(MoqtMessageSpecificTest, ControlStreamFinBeforeType) { + webtransport::test::InMemoryStream stream(/*stream_id=*/0); + MoqtControlStreamParser parser(&stream); + stream.Receive("", true); + absl::StatusOr<MoqtMessageType> type = parser.ReadFirstMessageType(); + EXPECT_EQ(type.status().code(), absl::StatusCode::kInvalidArgument); +} + +TEST_F(MoqtMessageSpecificTest, ControlStreamFinInTheMiddleOfType) { + webtransport::test::InMemoryStream stream(/*stream_id=*/0); + MoqtControlStreamParser parser(&stream); + stream.Receive("\xff", true); + absl::StatusOr<MoqtMessageType> type = parser.ReadFirstMessageType(); + EXPECT_EQ(type.status().code(), absl::StatusCode::kInvalidArgument); } TEST_F(MoqtMessageSpecificTest, InvalidObjectStatus) { webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtDataParser parser(&stream, &visitor_); + MoqtParserTestVisitor data_visitor; + MoqtDataParser parser(&stream, &data_visitor); char stream_header_subgroup[] = { 0x15, // type field 0x04, 0x05, 0x08, // varints @@ -898,13 +1080,12 @@ absl::string_view(stream_header_subgroup, sizeof(stream_header_subgroup)), false); parser.ReadAllData(); - EXPECT_EQ(visitor_.parsing_error_, "Invalid object status provided"); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); + ASSERT_TRUE(data_visitor.parsing_error().has_value()); + EXPECT_THAT(data_visitor.parsing_error().value(), + HasSubstr("Invalid object status provided")); } TEST_F(MoqtMessageSpecificTest, Setup2KB) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char big_message[2 * kMaxMessageHeaderSize]; quic::QuicDataWriter writer(sizeof(big_message), big_message); writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kServerSetup)); @@ -915,43 +1096,40 @@ writer.WriteVarInt62(kMaxMessageHeaderSize); // very long parameter writer.WriteRepeatedByte(0x04, kMaxMessageHeaderSize); // Send incomplete message - stream.Receive(absl::string_view(big_message, writer.length() - 1), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_, - "Cannot parse control messages more than 2048 bytes"); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kInternalError); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(big_message, writer.length())); + EXPECT_THAT( + parsed.status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("control message exceeds the maximum allowed size"))); } TEST_F(MoqtMessageSpecificTest, UnknownMessageType) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char message[7]; quic::QuicDataWriter writer(sizeof(message), message); writer.WriteVarInt62(0xbeef); // unknown message type writer.WriteUInt16(0x1); // length writer.WriteVarInt62(0x1); // payload - stream.Receive(absl::string_view(message, writer.length()), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_, "Unknown control message type 0xbeef"); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(message, writer.length())); + EXPECT_THAT(parsed.status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unknown control message type 0xbeef"))); } TEST_F(MoqtMessageSpecificTest, SubscribeNoParameters) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe[] = { 0x03, 0x00, 0x0c, 0x01, // request_id = 1 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" 0x04, 0x61, 0x62, 0x63, 0x64, // track_name = "abcd" 0x00, // 0 parameters }; - stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); - parser.ReadAndDispatchMessages(); - ASSERT_EQ(visitor_.messages_received_, 1); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); - MoqtSubscribe message = - std::get<MoqtSubscribe>(visitor_.last_message_.value()); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(subscribe, sizeof(subscribe)), + kDefaultMoqtVersion, kRawQuic); + ASSERT_TRUE(parsed.ok()); + ASSERT_EQ(parsed->size(), 1); + MoqtSubscribe message = std::get<MoqtSubscribe>((*parsed)[0]); EXPECT_FALSE(message.parameters.delivery_timeout.has_value()); EXPECT_FALSE(message.parameters.forward_has_value()); EXPECT_FALSE(message.parameters.subscription_filter.has_value()); @@ -965,8 +1143,6 @@ } TEST_F(MoqtMessageSpecificTest, SubscribeUnknownParameter) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe[] = { 0x03, 0x00, 0x0f, 0x01, // request_id = 1 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -974,15 +1150,14 @@ 0x01, // 0 parameters 0x40, 0x60, 0x01, // unknown parameter = 0x60 }; - stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); - parser.ReadAndDispatchMessages(); - ASSERT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(subscribe, sizeof(subscribe)), + kDefaultMoqtVersion, kRawQuic); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kProtocolViolation); } TEST_F(MoqtMessageSpecificTest, LargestObject) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe[] = { 0x03, 0x00, 0x0f, 0x01, // request_id = 1 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -990,20 +1165,18 @@ 0x01, // 1 parameter 0x21, 0x01, 0x02, // filter_type = kLargestObject }; - stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); - parser.ReadAndDispatchMessages(); - ASSERT_EQ(visitor_.messages_received_, 1); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); - MoqtSubscribe message = - std::get<MoqtSubscribe>(visitor_.last_message_.value()); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(subscribe, sizeof(subscribe)), + kDefaultMoqtVersion, kRawQuic); + ASSERT_TRUE(parsed.ok()); + ASSERT_EQ(parsed->size(), 1); + MoqtSubscribe message = std::get<MoqtSubscribe>((*parsed)[0]); ASSERT_TRUE(message.parameters.subscription_filter.has_value()); SubscriptionFilter& filter = *message.parameters.subscription_filter; EXPECT_TRUE(filter.type() == MoqtFilterType::kLargestObject); } TEST_F(MoqtMessageSpecificTest, InvalidDeliveryOrder) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe[] = { 0x03, 0x00, 0x0e, 0x01, // id 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -1011,15 +1184,14 @@ 0x01, // 1 parameter 0x22, 0x03, // invalid group order = 3 }; - stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(subscribe, sizeof(subscribe)), + kDefaultMoqtVersion, kRawQuic); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kProtocolViolation); } TEST_F(MoqtMessageSpecificTest, NextGroupStart) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe[] = { 0x03, 0x00, 0x0f, 0x01, // id 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -1027,20 +1199,18 @@ 0x01, // 1 parameter 0x21, 0x01, 0x01, // filter_type = kNextGroupStart }; - stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); - parser.ReadAndDispatchMessages(); - ASSERT_EQ(visitor_.messages_received_, 1); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); - MoqtSubscribe message = - std::get<MoqtSubscribe>(visitor_.last_message_.value()); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(subscribe, sizeof(subscribe)), + kDefaultMoqtVersion, kRawQuic); + ASSERT_TRUE(parsed.ok()); + ASSERT_EQ(parsed->size(), 1); + MoqtSubscribe message = std::get<MoqtSubscribe>((*parsed)[0]); ASSERT_TRUE(message.parameters.subscription_filter.has_value()); SubscriptionFilter& filter = *message.parameters.subscription_filter; EXPECT_TRUE(filter.type() == MoqtFilterType::kNextGroupStart); } TEST_F(MoqtMessageSpecificTest, AbsoluteRange) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe[] = { 0x03, 0x00, 0x12, 0x01, // id 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -1050,12 +1220,12 @@ 0x07 // filter_type = kAbsoluteRange // (4,1) to 7 }; - stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); - parser.ReadAndDispatchMessages(); - ASSERT_EQ(visitor_.messages_received_, 1); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); - MoqtSubscribe message = - std::get<MoqtSubscribe>(visitor_.last_message_.value()); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(subscribe, sizeof(subscribe)), + kDefaultMoqtVersion, kRawQuic); + ASSERT_TRUE(parsed.ok()); + ASSERT_EQ(parsed->size(), 1); + MoqtSubscribe message = std::get<MoqtSubscribe>((*parsed)[0]); ASSERT_TRUE(message.parameters.subscription_filter.has_value()); SubscriptionFilter& filter = *message.parameters.subscription_filter; EXPECT_TRUE(filter.type() == MoqtFilterType::kAbsoluteRange && @@ -1063,8 +1233,6 @@ } TEST_F(MoqtMessageSpecificTest, AbsoluteRangeEndGroupTooLow) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe[] = { 0x03, 0x00, 0x12, 0x01, // id 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -1074,15 +1242,14 @@ 0x03 // filter_type = kAbsoluteRange // (4,1) to 3 }; - stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(subscribe, sizeof(subscribe)), + kDefaultMoqtVersion, kRawQuic); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kProtocolViolation); } TEST_F(MoqtMessageSpecificTest, AbsoluteRangeExactlyOneGroup) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe[] = { 0x03, 0x00, 0x12, 0x01, // id 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -1092,42 +1259,37 @@ 0x04 // filter_type = kAbsoluteRange // (4,1) to 4 }; - stream.Receive(absl::string_view(subscribe, sizeof(subscribe)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 1); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(subscribe, sizeof(subscribe)), + kDefaultMoqtVersion, kRawQuic); + ASSERT_TRUE(parsed.ok()); + ASSERT_EQ(parsed->size(), 1); } TEST_F(MoqtMessageSpecificTest, RequestUpdateEndGroupTooLow) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char request_update[] = { 0x02, 0x00, 0x09, 0x02, 0x00, // request IDs 0x01, 0x21, 0x04, 0x04, 0x04, 0x01, 0x03, // filter }; - stream.Receive(absl::string_view(request_update, sizeof(request_update)), - false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_THAT( - visitor_.parsing_error_, - Optional(HasSubstr( - "AbsoluteRange filter specified with a start after the end"))); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = ParseAllMessages( + absl::string_view(request_update, sizeof(request_update)), + kDefaultMoqtVersion, kRawQuic); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kProtocolViolation); } TEST_F(MoqtMessageSpecificTest, ObjectAckNegativeDelta) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char object_ack[] = { 0x71, 0x84, 0x00, 0x05, // type 0x01, 0x10, 0x20, // subscribe ID, group, object 0x40, 0x81, // -0x40 time delta }; - stream.Receive(absl::string_view(object_ack, sizeof(object_ack)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.parsing_error_, std::nullopt); - ASSERT_EQ(visitor_.messages_received_, 1); - MoqtObjectAck message = - std::get<MoqtObjectAck>(visitor_.last_message_.value()); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(object_ack, sizeof(object_ack)), + kDefaultMoqtVersion, kRawQuic); + ASSERT_TRUE(parsed.ok()); + ASSERT_EQ(parsed->size(), 1); + MoqtObjectAck message = std::get<MoqtObjectAck>((*parsed)[0]); EXPECT_EQ(message.subscribe_id, 0x01); EXPECT_EQ(message.group_id, 0x10); EXPECT_EQ(message.object_id, 0x20); @@ -1136,64 +1298,15 @@ } TEST_F(MoqtMessageSpecificTest, AllMessagesTogether) { - char buffer[5000]; - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); - size_t write = 0; - size_t read = 0; - int fully_received = 0; - std::unique_ptr<TestMessageBase> prev_message = nullptr; + std::string buffer; for (MoqtMessageType type : kMessageTypes) { - // Each iteration, process from the halfway point of one message to the - // halfway point of the next. std::unique_ptr<TestMessageBase> message = CreateTestMessage(type, kRawQuic); - memcpy(buffer + write, message->PacketSample().data(), - message->total_message_size()); - size_t new_read = write + message->total_message_size() / 2; - stream.Receive(absl::string_view(buffer + read, new_read - read), false); - parser.ReadAndDispatchMessages(); - ASSERT_EQ(visitor_.messages_received_, fully_received); - if (prev_message != nullptr) { - EXPECT_TRUE(prev_message->EqualFieldValues(*visitor_.last_message_)); - } - fully_received++; - read = new_read; - write += message->total_message_size(); - prev_message = std::move(message); + buffer += message->PacketSample(); } - // Deliver the rest - stream.Receive(absl::string_view(buffer + read, write - read), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, fully_received); - EXPECT_TRUE(prev_message->EqualFieldValues(*visitor_.last_message_)); - EXPECT_FALSE(visitor_.parsing_error_.has_value()); -} - -TEST_F(MoqtMessageSpecificTest, ReadOnlyMessageType) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtMessageTypeParser parser(&stream); - char buffer[] = {0x40, 0x03}; - stream.Receive(absl::string_view(buffer, sizeof(buffer)), false); - EXPECT_TRUE(parser.ReadUntilMessageTypeKnown()); - EXPECT_EQ(parser.message_type(), 0x03); -} - -TEST_F(MoqtMessageSpecificTest, ReadOnlyMessageTypeIncomplete) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtMessageTypeParser parser(&stream); - char buffer[] = {0x40}; - stream.Receive(absl::string_view(buffer, sizeof(buffer)), false); - EXPECT_TRUE(parser.ReadUntilMessageTypeKnown()); - EXPECT_FALSE(parser.message_type().has_value()); -} - -TEST_F(MoqtMessageSpecificTest, ReadOnlyMessageTypeEarlyFin) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtMessageTypeParser parser(&stream); - char buffer[] = {0x03}; - stream.Receive(absl::string_view(buffer, sizeof(buffer)), true); - EXPECT_FALSE(parser.ReadUntilMessageTypeKnown()); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(buffer, kDefaultMoqtVersion, kRawQuic); + ASSERT_TRUE(parsed.ok()); } TEST_F(MoqtMessageSpecificTest, DatagramSuccessful) { @@ -1272,167 +1385,153 @@ } TEST_F(MoqtMessageSpecificTest, SubscribeOkInvalidDeliveryOrder) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); SubscribeOkMessage subscribe_ok; subscribe_ok.SetInvalidDeliveryOrder(); - stream.Receive(subscribe_ok.PacketSample(), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_THAT(visitor_.parsing_error_, - Optional(HasSubstr("Invalid SUBSCRIBE_OK track extensions"))); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(subscribe_ok.PacketSample(), kDefaultMoqtVersion, + /*uses_web_transport=*/false); + EXPECT_FALSE(parsed.ok()); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kProtocolViolation); + EXPECT_THAT(parsed.status().message(), + HasSubstr("Invalid SUBSCRIBE_OK track extensions")); } TEST_F(MoqtMessageSpecificTest, SubscribeOkExpirationIsZero) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe_ok[] = { 0x04, 0x00, 0x05, 0x02, 0x01, // request_id = 2, track_alias = 1 0x01, 0x08, 0x00 // expires = 0 }; - stream.Receive(absl::string_view(subscribe_ok, sizeof(subscribe_ok)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.parsing_error_, std::nullopt); - ASSERT_EQ(visitor_.messages_received_, 1); - MoqtSubscribeOk message = - std::get<MoqtSubscribeOk>(visitor_.last_message_.value()); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(absl::string_view(subscribe_ok, sizeof(subscribe_ok)), + kDefaultMoqtVersion, /*uses_web_transport=*/false); + ASSERT_TRUE(parsed.ok()); + ASSERT_EQ(parsed->size(), 1u); + MoqtSubscribeOk message = std::get<MoqtSubscribeOk>((*parsed)[0]); EXPECT_EQ(message.parameters.expires, quic::QuicTimeDelta::Infinite()); } TEST_F(MoqtMessageSpecificTest, FetchWholeGroup) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); FetchMessage fetch; fetch.SetEndObject(5, std::nullopt); - stream.Receive(fetch.PacketSample(), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(visitor_.last_message_.has_value()); - if (!visitor_.last_message_.has_value()) { - return; - } - MoqtFetch parse_result = std::get<MoqtFetch>(*visitor_.last_message_); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(fetch.PacketSample(), kDefaultMoqtVersion, + /*uses_web_transport=*/false); + ASSERT_TRUE(parsed.ok()); + ASSERT_EQ(parsed->size(), 1u); + MoqtFetch parse_result = std::get<MoqtFetch>((*parsed)[0]); auto standalone = std::get<StandaloneFetch>(parse_result.fetch); EXPECT_EQ(standalone.end_location, Location(5, kMaxObjectId)); } TEST_F(MoqtMessageSpecificTest, FetchInvalidRange) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); FetchMessage fetch; fetch.SetEndObject(1, 1); - stream.Receive(fetch.PacketSample(), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_THAT( - visitor_.parsing_error_, - Optional(HasSubstr("End object comes before start object in FETCH"))); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(fetch.PacketSample(), kDefaultMoqtVersion, + /*uses_web_transport=*/false); + EXPECT_FALSE(parsed.ok()); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kProtocolViolation); + EXPECT_THAT(parsed.status().message(), + HasSubstr("End object comes before start object in FETCH")); } TEST_F(MoqtMessageSpecificTest, FetchInvalidRange2) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); FetchMessage fetch; fetch.SetEndObject(0, std::nullopt); - stream.Receive(fetch.PacketSample(), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_THAT( - visitor_.parsing_error_, - Optional(HasSubstr("End object comes before start object in FETCH"))); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(fetch.PacketSample(), kDefaultMoqtVersion, + /*uses_web_transport=*/false); + EXPECT_FALSE(parsed.ok()); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kProtocolViolation); + EXPECT_THAT(parsed.status().message(), + HasSubstr("End object comes before start object in FETCH")); } TEST_F(MoqtMessageSpecificTest, PaddingStream) { + MoqtParserTestVisitor visitor; webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtDataParser parser(&stream, &visitor_); + MoqtDataParser parser(&stream, &visitor); std::string buffer(32, '\0'); quic::QuicDataWriter writer(buffer.size(), buffer.data()); ASSERT_TRUE(writer.WriteVarInt62(MoqtDataStreamType::Padding().value())); for (int i = 0; i < 100; ++i) { stream.Receive(buffer, false); parser.ReadAllData(); - ASSERT_EQ(visitor_.messages_received_, 0); - ASSERT_EQ(visitor_.parsing_error_, std::nullopt); + ASSERT_EQ(visitor.messages_received(), 0); + ASSERT_EQ(visitor.parsing_error(), std::nullopt); } } // All messages with TrackNamespace use ReadTrackNamespace too check this. Use // PUBLISH_NAMESPACE. TEST_F(MoqtMessageSpecificTest, NamespaceTooSmall) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char publish_namespace[7] = { 0x06, 0x00, 0x04, 0x02, // request_id = 2 0x01, 0x00, // one empty namespace element 0x00, // no parameters }; - stream.Receive( - absl::string_view(publish_namespace, sizeof(publish_namespace)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_EQ(visitor_.parsing_error_, std::nullopt); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = ParseAllMessages( + absl::string_view(publish_namespace, sizeof(publish_namespace)), + kDefaultMoqtVersion, /*uses_web_transport=*/false); + ASSERT_TRUE(parsed.ok()); + ASSERT_EQ(parsed->size(), 1u); + --publish_namespace[2]; // Remove one element. --publish_namespace[4]; - stream.Receive( + parsed = ParseAllMessages( absl::string_view(publish_namespace, sizeof(publish_namespace) - 1), - false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_THAT(visitor_.parsing_error_, - Optional(HasSubstr("Invalid number of namespace elements"))); + kDefaultMoqtVersion, /*uses_web_transport=*/false); + EXPECT_FALSE(parsed.ok()); + EXPECT_THAT(parsed.status().message(), + HasSubstr("Invalid number of namespace elements")); } TEST_F(MoqtMessageSpecificTest, NamespaceTooLarge) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char publish_namespace[39] = { 0x06, 0x00, 0x23, 0x02, // type, length = 35, request_id = 2 0x20, // 32 namespace elements. This is the maximum. }; // 32 empty namespace elements + no parameters. - stream.Receive( + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = ParseAllMessages( absl::string_view(publish_namespace, sizeof(publish_namespace) - 1), - false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_EQ(visitor_.parsing_error_, std::nullopt); + kDefaultMoqtVersion, /*uses_web_transport=*/false); + ASSERT_TRUE(parsed.ok()); + ASSERT_EQ(parsed->size(), 1u); + ++publish_namespace[2]; // Add one element. ++publish_namespace[4]; - stream.Receive( - absl::string_view(publish_namespace, sizeof(publish_namespace)), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_THAT(visitor_.parsing_error_, - Optional(HasSubstr("Invalid number of namespace elements"))); + parsed = ParseAllMessages( + absl::string_view(publish_namespace, sizeof(publish_namespace)), + kDefaultMoqtVersion, /*uses_web_transport=*/false); + EXPECT_FALSE(parsed.ok()); + EXPECT_THAT(parsed.status().message(), + HasSubstr("Invalid number of namespace elements")); } TEST_F(MoqtMessageSpecificTest, RelativeJoiningFetch) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); RelativeJoiningFetchMessage message; - stream.Receive(message.PacketSample(), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_EQ(visitor_.parsing_error_, std::nullopt); - EXPECT_TRUE(visitor_.last_message_.has_value() && - message.EqualFieldValues(*visitor_.last_message_)); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(message.PacketSample(), kDefaultMoqtVersion, + /*uses_web_transport=*/false); + ASSERT_TRUE(parsed.ok()); + ASSERT_EQ(parsed->size(), 1u); + EXPECT_TRUE(std::holds_alternative<MoqtFetch>((*parsed)[0])); } TEST_F(MoqtMessageSpecificTest, AbsoluteJoiningFetch) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); AbsoluteJoiningFetchMessage message; - stream.Receive(message.PacketSample(), false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_EQ(visitor_.parsing_error_, std::nullopt); - EXPECT_TRUE(visitor_.last_message_.has_value() && - message.EqualFieldValues(*visitor_.last_message_)); + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = + ParseAllMessages(message.PacketSample(), kDefaultMoqtVersion, + /*uses_web_transport=*/false); + ASSERT_TRUE(parsed.ok()); + ASSERT_EQ(parsed->size(), 1u); + EXPECT_TRUE(std::holds_alternative<MoqtFetch>((*parsed)[0])); } TEST_F(MoqtMessageSpecificTest, InvalidSubscribeNamespaceOption) { - webtransport::test::InMemoryStream stream(/*stream_id=*/0); - MoqtControlParser parser(kRawQuic, &stream, visitor_); char subscribe_namespace[] = { 0x11, 0x00, 0x11, 0x01, // request_id = 1 0x01, 0x03, 0x66, 0x6f, 0x6f, // namespace = "foo" @@ -1441,12 +1540,12 @@ 0x03, 0x05, 0x03, 0x00, 0x62, 0x61, 0x72, // authorization_tag = "bar" 0x0d, 0x01, // forward = true }; - stream.Receive( + absl::StatusOr<std::vector<AnyMoqtControlMessage>> parsed = ParseAllMessages( absl::string_view(subscribe_namespace, sizeof(subscribe_namespace)), - false); - parser.ReadAndDispatchMessages(); - EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); + kDefaultMoqtVersion, /*uses_web_transport=*/false); + EXPECT_FALSE(parsed.ok()); + EXPECT_EQ(ExtractMoqtErrorForStatus(parsed.status()), + MoqtError::kProtocolViolation); } class MoqtDataParserStateMachineTest : public quic::test::QuicTest { @@ -1464,13 +1563,13 @@ stream_.Receive(StreamHeaderSubgroupMessage(type).PacketSample()); stream_.Receive(StreamMiddlerSubgroupMessage(type).PacketSample()); parser_.ReadAllData(); - ASSERT_EQ(visitor_.messages_received_, 2); - EXPECT_EQ(visitor_.object_payloads_[0], "foo"); - EXPECT_EQ(visitor_.object_payloads_[1], "bar"); + ASSERT_EQ(visitor_.messages_received(), 2); + EXPECT_EQ(visitor_.object_payloads()[0], "foo"); + EXPECT_EQ(visitor_.object_payloads()[1], "bar"); stream_.Receive("", /*fin=*/true); parser_.ReadAllData(); - EXPECT_EQ(visitor_.parsing_error_, std::nullopt); - EXPECT_TRUE(visitor_.fin_received_); + EXPECT_EQ(visitor_.parsing_error(), std::nullopt); + EXPECT_TRUE(visitor_.fin_received()); } TEST_F(MoqtDataParserStateMachineTest, ReadObjects) { @@ -1479,13 +1578,13 @@ stream_.Receive(StreamMiddlerSubgroupMessage(type).PacketSample(), /*fin=*/true); parser_.ReadAtMostOneObject(); - ASSERT_EQ(visitor_.messages_received_, 1); - EXPECT_EQ(visitor_.object_payloads_[0], "foo"); + ASSERT_EQ(visitor_.messages_received(), 1); + EXPECT_EQ(visitor_.object_payloads()[0], "foo"); parser_.ReadAtMostOneObject(); - ASSERT_EQ(visitor_.messages_received_, 2); - EXPECT_EQ(visitor_.object_payloads_[1], "bar"); - EXPECT_EQ(visitor_.parsing_error_, std::nullopt); - EXPECT_TRUE(visitor_.fin_received_); + ASSERT_EQ(visitor_.messages_received(), 2); + EXPECT_EQ(visitor_.object_payloads()[1], "bar"); + EXPECT_EQ(visitor_.parsing_error(), std::nullopt); + EXPECT_TRUE(visitor_.fin_received()); } TEST_F(MoqtDataParserStateMachineTest, ReadTypeThenObjects) { @@ -1494,17 +1593,17 @@ stream_.Receive(StreamMiddlerSubgroupMessage(type).PacketSample(), /*fin=*/true); parser_.ReadStreamType(); - ASSERT_EQ(visitor_.messages_received_, 0); + ASSERT_EQ(visitor_.messages_received(), 0); EXPECT_TRUE(parser_.stream_type().has_value() && parser_.stream_type()->IsSubgroup()); parser_.ReadAtMostOneObject(); - ASSERT_EQ(visitor_.messages_received_, 1); - EXPECT_EQ(visitor_.object_payloads_[0], "foo"); + ASSERT_EQ(visitor_.messages_received(), 1); + EXPECT_EQ(visitor_.object_payloads()[0], "foo"); parser_.ReadAtMostOneObject(); - ASSERT_EQ(visitor_.messages_received_, 2); - EXPECT_EQ(visitor_.object_payloads_[1], "bar"); - EXPECT_EQ(visitor_.parsing_error_, std::nullopt); - EXPECT_TRUE(visitor_.fin_received_); + ASSERT_EQ(visitor_.messages_received(), 2); + EXPECT_EQ(visitor_.object_payloads()[1], "bar"); + EXPECT_EQ(visitor_.parsing_error(), std::nullopt); + EXPECT_TRUE(visitor_.fin_received()); } TEST_F(MoqtDataParserStateMachineTest, ReadTypeThenObjectsFetch) { @@ -1518,17 +1617,17 @@ stream.Receive(header.PacketSample()); stream.Receive(middler.PacketSample(), /*fin=*/true); parser.ReadStreamType(); - ASSERT_EQ(visitor.messages_received_, 0); + ASSERT_EQ(visitor.messages_received(), 0); parser.ReadAtMostOneObject(); - ASSERT_EQ(visitor.messages_received_, 1); - EXPECT_TRUE(header.EqualFieldValues(visitor.last_message_.value())); - EXPECT_EQ(visitor.object_payloads_[0], "foo"); + ASSERT_EQ(visitor.messages_received(), 1); + EXPECT_TRUE(header.EqualFieldValues(visitor.last_message().value())); + EXPECT_EQ(visitor.object_payloads()[0], "foo"); parser.ReadAtMostOneObject(); - ASSERT_EQ(visitor.messages_received_, 2); - EXPECT_TRUE(middler.EqualFieldValues(visitor.last_message_.value())); - EXPECT_EQ(visitor.object_payloads_[1], "bar"); - EXPECT_EQ(visitor.parsing_error_, std::nullopt); - EXPECT_TRUE(visitor.fin_received_); + ASSERT_EQ(visitor.messages_received(), 2); + EXPECT_TRUE(middler.EqualFieldValues(visitor.last_message().value())); + EXPECT_EQ(visitor.object_payloads()[1], "bar"); + EXPECT_EQ(visitor.parsing_error(), std::nullopt); + EXPECT_TRUE(visitor.fin_received()); } } @@ -1543,9 +1642,8 @@ stream.Receive(absl::string_view(data, sizeof(data))); parser.ReadStreamType(); parser.ReadAtMostOneObject(); - EXPECT_EQ(visitor.parsing_error_, + EXPECT_EQ(visitor.parsing_error(), "Invalid serialization flags for first object"); - EXPECT_EQ(visitor.parsing_error_code_, MoqtError::kProtocolViolation); } } @@ -1563,9 +1661,8 @@ parser.ReadStreamType(); parser.ReadAtMostOneObject(); parser.ReadAtMostOneObject(); - EXPECT_EQ(visitor.parsing_error_, + EXPECT_EQ(visitor.parsing_error(), "reference to subgroup ID of prior datagram"); - EXPECT_EQ(visitor.parsing_error_code_, MoqtError::kProtocolViolation); } } @@ -1574,8 +1671,7 @@ stream_.Receive(absl::string_view(data, sizeof(data))); parser_.ReadStreamType(); parser_.ReadAtMostOneObject(); - EXPECT_EQ(visitor_.parsing_error_, "Invalid serialization flags"); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); + EXPECT_EQ(visitor_.parsing_error(), "Invalid serialization flags"); } TEST_F(MoqtDataParserStateMachineTest, InvalidNonexistentRangeUnknownRange) { @@ -1583,8 +1679,7 @@ stream_.Receive(absl::string_view(data, sizeof(data))); parser_.ReadStreamType(); parser_.ReadAtMostOneObject(); - EXPECT_EQ(visitor_.parsing_error_, "Invalid serialization flags"); - EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); + EXPECT_EQ(visitor_.parsing_error(), "Invalid serialization flags"); } TEST_F(MoqtDataParserStateMachineTest, IgnoresEndRangeIndicators) { @@ -1599,7 +1694,7 @@ StreamMiddlerFetchMessage middler(*serialization); stream_.Receive(middler.PacketSample(), /*fin=*/true); parser_.ReadAllData(); - EXPECT_EQ(visitor_.messages_received_, 2); + EXPECT_EQ(visitor_.messages_received(), 2); // TODO(martinduke): Once Issue #1506 is resolved, check that the values // are reported correctly. }
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index c5c2ab8..499666d 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -23,6 +23,7 @@ #include "absl/functional/bind_front.h" #include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -910,17 +911,24 @@ } void MoqtSession::UnknownBidiStream::OnCanRead() { - if (!parser_.ReadUntilMessageTypeKnown()) { - // Got an early FIN. - stream_->ResetWithUserCode(kResetCodeCancelled); + absl::StatusOr<MoqtMessageType> message_type = + parser_->ReadFirstMessageType(); + if (absl::IsUnavailable(message_type.status())) { return; } - if (!parser_.message_type().has_value()) { + if (absl::IsInvalidArgument(message_type.status())) { + // Received a FIN before any type has been available, which is malformed. + session_->Error(MoqtError::kProtocolViolation, + message_type.status().message()); return; } - MoqtMessageType message_type = - static_cast<MoqtMessageType>(*parser_.message_type()); - switch (message_type) { + if (!message_type.ok()) { + // The result is neither of "OK", "no type available", or "parse error". + // This is unexpected; treat it as an internal error, and reset the stream. + stream_->ResetWithUserCode(kResetCodeInternalError); + return; + } + switch (*message_type) { case MoqtMessageType::kClientSetup: { if (session_->control_stream_.GetIfAvailable() != nullptr) { session_->Error(MoqtError::kProtocolViolation,
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index fbb194b..59e2084 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -220,7 +220,9 @@ // responsible for calling stream->SetVisitor(). UnknownBidiStream(MoqtSession* session, webtransport::Stream* absl_nonnull stream) - : session_(session), stream_(stream), parser_(stream) {} + : session_(session), + stream_(stream), + parser_(std::make_unique<MoqtControlStreamParser>(stream)) {} ~UnknownBidiStream() {} // webtransport::StreamVisitor overrides. @@ -233,7 +235,7 @@ private: MoqtSession* session_; webtransport::Stream* stream_; - MoqtMessageTypeParser parser_; + std::unique_ptr<MoqtControlStreamParser> parser_; }; class QUICHE_EXPORT ControlStream : public MoqtBidiStreamBase {
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 04af379..ecf0e53 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -331,7 +331,7 @@ EXPECT_CALL(mock_stream_, visitor).WillOnce(Return(visitor.get())); }); EXPECT_CALL(mock_stream_, PeekNextReadableRegion()) - .WillOnce(Return( + .WillRepeatedly(Return( webtransport::Stream::PeekResult(absl::string_view(), false, false))); server_session.OnIncomingBidirectionalStreamAvailable(); }
diff --git a/quiche/quic/moqt/test_tools/moqt_framer_utils.cc b/quiche/quic/moqt/test_tools/moqt_framer_utils.cc index 970ee46..39ea217 100644 --- a/quiche/quic/moqt/test_tools/moqt_framer_utils.cc +++ b/quiche/quic/moqt/test_tools/moqt_framer_utils.cc
@@ -6,19 +6,10 @@ #include <string> #include <variant> -#include <vector> -#include "absl/status/status.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" #include "quiche/quic/moqt/moqt_framer.h" #include "quiche/quic/moqt/moqt_messages.h" -#include "quiche/quic/moqt/moqt_parser.h" -#include "quiche/common/platform/api/quiche_test.h" #include "quiche/common/quiche_buffer_allocator.h" -#include "quiche/web_transport/test_tools/in_memory_stream.h" -#include "quiche/web_transport/web_transport.h" namespace moqt::test { @@ -105,123 +96,12 @@ bool is_track_status; }; -class GenericMessageParseVisitor : public MoqtControlParserVisitor { - public: - explicit GenericMessageParseVisitor(std::vector<MoqtGenericFrame>* frames) - : frames_(*frames) {} - - void OnClientSetupMessage(const MoqtClientSetup& message) { - frames_.push_back(message); - } - void OnServerSetupMessage(const MoqtServerSetup& message) { - frames_.push_back(message); - } - void OnRequestOkMessage(const MoqtRequestOk& message) { - frames_.push_back(message); - } - void OnRequestErrorMessage(const MoqtRequestError& message) { - frames_.push_back(message); - } - void OnSubscribeMessage(const MoqtSubscribe& message) { - frames_.push_back(message); - } - void OnSubscribeOkMessage(const MoqtSubscribeOk& message) { - frames_.push_back(message); - } - void OnUnsubscribeMessage(const MoqtUnsubscribe& message) { - frames_.push_back(message); - } - void OnPublishDoneMessage(const MoqtPublishDone& message) { - frames_.push_back(message); - } - void OnRequestUpdateMessage(const MoqtRequestUpdate& message) { - frames_.push_back(message); - } - void OnPublishNamespaceMessage(const MoqtPublishNamespace& message) { - frames_.push_back(message); - } - void OnPublishNamespaceDoneMessage(const MoqtPublishNamespaceDone& message) { - frames_.push_back(message); - } - void OnNamespaceMessage(const MoqtNamespace& message) { - frames_.push_back(message); - } - void OnNamespaceDoneMessage(const MoqtNamespaceDone& message) { - frames_.push_back(message); - } - void OnPublishNamespaceCancelMessage( - const MoqtPublishNamespaceCancel& message) { - frames_.push_back(message); - } - void OnTrackStatusMessage(const MoqtTrackStatus& message) { - frames_.push_back(message); - } - void OnGoAwayMessage(const MoqtGoAway& message) { - frames_.push_back(message); - } - void OnSubscribeNamespaceMessage(const MoqtSubscribeNamespace& message) { - frames_.push_back(message); - } - void OnMaxRequestIdMessage(const MoqtMaxRequestId& message) { - frames_.push_back(message); - } - void OnFetchMessage(const MoqtFetch& message) { frames_.push_back(message); } - void OnFetchCancelMessage(const MoqtFetchCancel& message) { - frames_.push_back(message); - } - void OnFetchOkMessage(const MoqtFetchOk& message) { - frames_.push_back(message); - } - void OnRequestsBlockedMessage(const MoqtRequestsBlocked& message) { - frames_.push_back(message); - } - void OnPublishMessage(const MoqtPublish& message) { - frames_.push_back(message); - } - void OnPublishOkMessage(const MoqtPublishOk& message) { - frames_.push_back(message); - } - void OnObjectAckMessage(const MoqtObjectAck& message) { - frames_.push_back(message); - } - - void OnParsingError(MoqtError code, absl::string_view reason) { - ADD_FAILURE() << "Parsing failed: " << reason; - } - - private: - std::vector<MoqtGenericFrame>& frames_; -}; - } // namespace -std::string SerializeGenericMessage(const MoqtGenericFrame& frame, +std::string SerializeGenericMessage(const AnyMoqtControlMessage& frame, bool use_webtrans) { MoqtFramer framer(use_webtrans); return std::string(std::visit(FramingVisitor{framer}, frame).AsStringView()); } -std::vector<MoqtGenericFrame> ParseGenericMessage(absl::string_view body) { - std::vector<MoqtGenericFrame> result; - GenericMessageParseVisitor visitor(&result); - webtransport::test::InMemoryStream stream(/*id=*/0); - MoqtControlParser parser(/*uses_web_transport=*/true, &stream, visitor); - stream.Receive(body, /*fin=*/false); - parser.ReadAndDispatchMessages(); - return result; -} - -absl::Status StoreSubscribe::operator()( - absl::Span<const absl::string_view> data, - const webtransport::StreamWriteOptions& options) const { - std::string merged_message = absl::StrJoin(data, ""); - std::vector<MoqtGenericFrame> frames = ParseGenericMessage(merged_message); - if (frames.size() != 1 || !std::holds_alternative<MoqtSubscribe>(frames[0])) { - ADD_FAILURE() << "Expected one SUBSCRIBE frame in a write"; - return absl::InternalError("Expected one SUBSCRIBE frame in a write"); - } - subscribe_->emplace(std::get<MoqtSubscribe>(frames[0])); - return absl::OkStatus(); -} - } // namespace moqt::test
diff --git a/quiche/quic/moqt/test_tools/moqt_framer_utils.h b/quiche/quic/moqt/test_tools/moqt_framer_utils.h index 6aa81aa..0e4326b 100644 --- a/quiche/quic/moqt/test_tools/moqt_framer_utils.h +++ b/quiche/quic/moqt/test_tools/moqt_framer_utils.h
@@ -6,41 +6,31 @@ #define QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_FRAMER_UTILS_H_ #include <cstdint> -#include <optional> #include <string> #include <variant> #include <vector> -#include "absl/status/status.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" -#include "absl/types/span.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/common/platform/api/quiche_test.h" #include "quiche/common/quiche_data_reader.h" #include "quiche/common/quiche_mem_slice.h" -#include "quiche/web_transport/web_transport.h" namespace moqt::test { -// TODO: remove MoqtObject from TestMessageBase::MessageStructuredData and merge -// those two types. -using MoqtGenericFrame = - std::variant<MoqtClientSetup, MoqtServerSetup, MoqtRequestOk, - MoqtRequestError, MoqtSubscribe, MoqtSubscribeOk, - MoqtUnsubscribe, MoqtPublishDone, MoqtRequestUpdate, - MoqtPublishNamespace, MoqtPublishNamespaceDone, MoqtNamespace, - MoqtNamespaceDone, MoqtPublishNamespaceCancel, MoqtTrackStatus, - MoqtGoAway, MoqtSubscribeNamespace, MoqtMaxRequestId, - MoqtFetch, MoqtFetchCancel, MoqtFetchOk, MoqtRequestsBlocked, - MoqtPublish, MoqtPublishOk, MoqtObjectAck>; +using AnyMoqtControlMessage = std::variant< + MoqtClientSetup, MoqtServerSetup, MoqtRequestOk, MoqtRequestError, + MoqtSubscribe, MoqtSubscribeOk, MoqtUnsubscribe, MoqtPublishDone, + MoqtRequestUpdate, MoqtPublishNamespace, MoqtPublishNamespaceDone, + MoqtPublishNamespaceCancel, MoqtTrackStatus, MoqtGoAway, + MoqtSubscribeNamespace, MoqtMaxRequestId, MoqtFetch, MoqtFetchCancel, + MoqtFetchOk, MoqtRequestsBlocked, MoqtPublish, MoqtPublishOk, MoqtNamespace, + MoqtNamespaceDone, MoqtObjectAck>; -std::string SerializeGenericMessage(const MoqtGenericFrame& frame, +std::string SerializeGenericMessage(const AnyMoqtControlMessage& frame, bool use_webtrans = false); -// Parses a concatenation of one or more MoQT control messages. -std::vector<MoqtGenericFrame> ParseGenericMessage(absl::string_view body); - MATCHER_P(SerializedControlMessage, message, "Matches against a specific expected MoQT message") { std::vector<absl::string_view> data_written; @@ -76,21 +66,6 @@ return true; } -// gmock action for extracting an SUBSCRIBE message written onto a stream. -class StoreSubscribe { - public: - explicit StoreSubscribe(std::optional<MoqtSubscribe>* subscribe) - : subscribe_(subscribe) {} - - // quiche::WriteStream::Writev() implementation. - absl::Status operator()( - absl::Span<const absl::string_view> data, - const webtransport::StreamWriteOptions& options) const; - - private: - std::optional<MoqtSubscribe>* subscribe_; -}; - } // namespace moqt::test #endif // QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_FRAMER_UTILS_H_
diff --git a/quiche/quic/moqt/test_tools/moqt_parser_test_visitor.h b/quiche/quic/moqt/test_tools/moqt_parser_test_visitor.h index 10451df..51cb39e 100644 --- a/quiche/quic/moqt/test_tools/moqt_parser_test_visitor.h +++ b/quiche/quic/moqt/test_tools/moqt_parser_test_visitor.h
@@ -12,19 +12,17 @@ #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "quiche/quic/moqt/moqt_error.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_parser.h" -#include "quiche/quic/moqt/test_tools/moqt_test_message.h" #include "quiche/common/platform/api/quiche_logging.h" namespace moqt::test { -class MoqtParserTestVisitor : public MoqtControlParserVisitor, - public MoqtDataParserVisitor { +class MoqtParserTestVisitor : public MoqtDataParserVisitor { public: explicit MoqtParserTestVisitor(bool enable_logging = true) : enable_logging_(enable_logging) {} - ~MoqtParserTestVisitor() = default; void OnObjectMessage(const MoqtObject& message, absl::string_view payload, bool end_of_message) override { @@ -34,101 +32,26 @@ if (end_of_message) { ++messages_received_; } - last_message_.emplace(TestMessageBase::MessageStructuredData(object)); + last_message_.emplace(object); } void OnFin() override { fin_received_ = true; } - template <typename Message> - void OnControlMessage(const Message& message) { - end_of_message_ = true; - ++messages_received_; - last_message_.emplace(TestMessageBase::MessageStructuredData(message)); - } - void OnClientSetupMessage(const MoqtClientSetup& message) override { - OnControlMessage(message); - } - void OnServerSetupMessage(const MoqtServerSetup& message) override { - OnControlMessage(message); - } - void OnRequestOkMessage(const MoqtRequestOk& message) override { - OnControlMessage(message); - } - void OnRequestErrorMessage(const MoqtRequestError& message) override { - OnControlMessage(message); - } - void OnSubscribeMessage(const MoqtSubscribe& message) override { - OnControlMessage(message); - } - void OnSubscribeOkMessage(const MoqtSubscribeOk& message) override { - OnControlMessage(message); - } - void OnRequestUpdateMessage(const MoqtRequestUpdate& message) override { - OnControlMessage(message); - } - void OnUnsubscribeMessage(const MoqtUnsubscribe& message) override { - OnControlMessage(message); - } - void OnPublishDoneMessage(const MoqtPublishDone& message) override { - OnControlMessage(message); - } - void OnPublishNamespaceMessage(const MoqtPublishNamespace& message) override { - OnControlMessage(message); - } - void OnPublishNamespaceDoneMessage( - const MoqtPublishNamespaceDone& message) override { - OnControlMessage(message); - } - void OnNamespaceMessage(const MoqtNamespace& message) override { - OnControlMessage(message); - } - void OnNamespaceDoneMessage(const MoqtNamespaceDone& message) override { - OnControlMessage(message); - } - void OnPublishNamespaceCancelMessage( - const MoqtPublishNamespaceCancel& message) override { - OnControlMessage(message); - } - void OnTrackStatusMessage(const MoqtTrackStatus& message) override { - OnControlMessage(message); - } - void OnGoAwayMessage(const MoqtGoAway& message) override { - OnControlMessage(message); - } - void OnSubscribeNamespaceMessage( - const MoqtSubscribeNamespace& message) override { - OnControlMessage(message); - } - void OnMaxRequestIdMessage(const MoqtMaxRequestId& message) override { - OnControlMessage(message); - } - void OnFetchMessage(const MoqtFetch& message) override { - OnControlMessage(message); - } - void OnFetchCancelMessage(const MoqtFetchCancel& message) override { - OnControlMessage(message); - } - void OnFetchOkMessage(const MoqtFetchOk& message) override { - OnControlMessage(message); - } - void OnRequestsBlockedMessage(const MoqtRequestsBlocked& message) override { - OnControlMessage(message); - } - void OnPublishMessage(const MoqtPublish& message) override { - OnControlMessage(message); - } - void OnPublishOkMessage(const MoqtPublishOk& message) override { - OnControlMessage(message); - } - void OnObjectAckMessage(const MoqtObjectAck& message) override { - OnControlMessage(message); - } void OnParsingError(MoqtError code, absl::string_view reason) override { QUICHE_LOG_IF(INFO, enable_logging_) << "Parsing error: " << reason; parsing_error_ = reason; parsing_error_code_ = code; } - std::string object_payload() { return absl::StrJoin(object_payloads_, ""); } + std::string object_payload() const { + return absl::StrJoin(object_payloads_, ""); + } + std::vector<std::string>& object_payloads() { return object_payloads_; } + uint64_t messages_received() const { return messages_received_; } + bool end_of_message() const { return end_of_message_; } + bool fin_received() const { return fin_received_; } + std::optional<MoqtObject> last_message() const { return last_message_; } + std::optional<std::string> parsing_error() const { return parsing_error_; } + private: bool enable_logging_ = true; std::vector<std::string> object_payloads_; bool end_of_message_ = false; @@ -136,7 +59,7 @@ std::optional<std::string> parsing_error_; MoqtError parsing_error_code_; uint64_t messages_received_ = 0; - std::optional<TestMessageBase::MessageStructuredData> last_message_; + std::optional<MoqtObject> last_message_; }; } // namespace moqt::test
diff --git a/quiche/quic/moqt/test_tools/moqt_test_message.h b/quiche/quic/moqt/test_tools/moqt_test_message.h index 813e969..f1da5ad 100644 --- a/quiche/quic/moqt/test_tools/moqt_test_message.h +++ b/quiche/quic/moqt/test_tools/moqt_test_message.h
@@ -131,7 +131,7 @@ // Compares |values| to the derived class's structured data to make sure // they are equal. - virtual bool EqualFieldValues(MessageStructuredData& values) const = 0; + virtual bool EqualFieldValues(const MessageStructuredData& values) const = 0; // Expand all varints in the message. This is pure virtual because each // message has a different layout of varints. @@ -230,7 +230,7 @@ // Base class for the two subtypes of Object Message. class QUICHE_NO_EXPORT ObjectMessage : public TestMessageBase { public: - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::move(std::get<MoqtObject>(values)); if (cast.track_alias != object_.track_alias) { QUIC_LOG(INFO) << "OBJECT Track ID mismatch"; @@ -620,7 +620,7 @@ } } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtClientSetup>(values); if (cast.parameters != client_setup_.parameters) { QUIC_LOG(INFO) << "CLIENT_SETUP parameter mismatch"; @@ -669,7 +669,7 @@ SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtServerSetup>(values); if (cast.parameters != server_setup_.parameters) { QUIC_LOG(INFO) << "SERVER_SETUP parameter mismatch"; @@ -704,7 +704,7 @@ SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtSubscribe>(values); if (cast.request_id != subscribe_.request_id) { QUIC_LOG(INFO) << "SUBSCRIBE subscribe ID mismatch"; @@ -758,7 +758,7 @@ subscribe_ok_.parameters.largest_object = Location(12, 20); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtSubscribeOk>(values); if (cast.request_id != subscribe_ok_.request_id) { QUIC_LOG(INFO) << "SUBSCRIBE OK subscribe ID mismatch"; @@ -823,7 +823,7 @@ SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtRequestError>(values); if (cast.request_id != request_error_.request_id) { QUIC_LOG(INFO) << "REQUEST_ERROR request_id mismatch"; @@ -874,7 +874,7 @@ SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtUnsubscribe>(values); if (cast.request_id != unsubscribe_.request_id) { QUIC_LOG(INFO) << "UNSUBSCRIBE request ID mismatch"; @@ -905,7 +905,7 @@ SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtPublishDone>(values); if (cast.request_id != publish_done_.request_id) { QUIC_LOG(INFO) << "PUBLISH_DONE request ID mismatch"; @@ -959,7 +959,7 @@ request_update_.parameters.subscription_filter.emplace(Location(3, 1), 5); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtRequestUpdate>(values); if (cast.request_id != request_update_.request_id) { QUIC_LOG(INFO) << "REQUEST_UPDATE request ID mismatch"; @@ -1007,7 +1007,7 @@ AuthToken(AuthTokenType::kOutOfBand, "bar")); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtPublishNamespace>(values); if (cast.request_id != publish_namespace_.request_id) { QUIC_LOG(INFO) << "PUBLISH_NAMESPACE request ID mismatch"; @@ -1051,7 +1051,7 @@ SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtNamespace>(values); if (cast.track_namespace_suffix != namespace_.track_namespace_suffix) { QUIC_LOG(INFO) << "NAMESPACE suffix mismatch"; @@ -1083,7 +1083,7 @@ SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtNamespaceDone>(values); if (cast.track_namespace_suffix != namespace_done_.track_namespace_suffix) { QUIC_LOG(INFO) << "NAMESPACE_DONE suffix mismatch"; @@ -1116,7 +1116,7 @@ SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtRequestOk>(values); if (cast.request_id != request_ok_.request_id) { QUIC_LOG(INFO) << "REQUEST_OK request ID mismatch"; @@ -1154,7 +1154,7 @@ SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtPublishNamespaceDone>(values); if (cast.request_id != publish_namespace_done_.request_id) { QUIC_LOG(INFO) << "PUBLISH_NAMESPACE_DONE request ID mismatch"; @@ -1188,7 +1188,7 @@ SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtPublishNamespaceCancel>(values); if (cast.request_id != publish_namespace_cancel_.request_id) { QUIC_LOG(INFO) << "PUBLISH_NAMESPACE CANCEL request ID mismatch"; @@ -1231,7 +1231,7 @@ SetByte(0, static_cast<uint8_t>(MoqtMessageType::kTrackStatus)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto value = std::get<MoqtTrackStatus>(values); auto* subscribe = reinterpret_cast<MoqtSubscribe*>(&value); MessageStructuredData structured_data = @@ -1250,7 +1250,7 @@ SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtGoAway>(values); if (cast.new_session_uri != goaway_.new_session_uri) { QUIC_LOG(INFO) << "GOAWAY full track name mismatch"; @@ -1284,7 +1284,7 @@ SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtSubscribeNamespace>(values); if (cast.request_id != subscribe_namespace_.request_id) { QUIC_LOG(INFO) << "SUBSCRIBE_NAMESPACE request_id mismatch"; @@ -1336,7 +1336,7 @@ SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtMaxRequestId>(values); if (cast.max_request_id != max_request_id_.max_request_id) { QUIC_LOG(INFO) << "MAX_REQUEST_ID mismatch"; @@ -1373,7 +1373,7 @@ fetch_.parameters.group_order = MoqtDeliveryOrder::kAscending; fetch_.parameters.subscriber_priority = 2; } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtFetch>(values); if (cast.request_id != fetch_.request_id) { QUIC_LOG(INFO) << "FETCH request_id mismatch"; @@ -1452,7 +1452,7 @@ fetch_.parameters.group_order = MoqtDeliveryOrder::kAscending; fetch_.parameters.subscriber_priority = 2; } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtFetch>(values); if (cast.request_id != fetch_.request_id) { QUIC_LOG(INFO) << "FETCH request_id mismatch"; @@ -1510,7 +1510,7 @@ fetch_.parameters.group_order = MoqtDeliveryOrder::kAscending; fetch_.parameters.subscriber_priority = 2; } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtFetch>(values); if (cast.request_id != fetch_.request_id) { QUIC_LOG(INFO) << "FETCH request_id mismatch"; @@ -1562,7 +1562,7 @@ FetchOkMessage() : TestMessageBase() { SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtFetchOk>(values); if (cast.request_id != fetch_ok_.request_id) { QUIC_LOG(INFO) << "FETCH_OK request_id mismatch"; @@ -1621,7 +1621,7 @@ FetchCancelMessage() : TestMessageBase() { SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtFetchCancel>(values); if (cast.request_id != fetch_cancel_.request_id) { QUIC_LOG(INFO) << "FETCH_CANCEL subscribe_id mismatch"; @@ -1652,7 +1652,7 @@ RequestsBlockedMessage() : TestMessageBase() { SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtRequestsBlocked>(values); if (cast.max_request_id != requests_blocked_.max_request_id) { QUIC_LOG(INFO) << "SUBSCRIBES_BLOCKED max_subscribe_id mismatch"; @@ -1687,7 +1687,7 @@ publish_.parameters.largest_object = Location(10, 1); publish_.parameters.set_forward(true); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtPublish>(values); if (cast.request_id != publish_.request_id) { QUIC_LOG(INFO) << "PUBLISH request_id mismatch"; @@ -1757,7 +1757,7 @@ publish_ok_.parameters.subscription_filter = SubscriptionFilter(Location(5, 4), 6); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtPublishOk>(values); if (cast.request_id != publish_ok_.request_id) { QUIC_LOG(INFO) << "PUBLISH_OK request_id mismatch"; @@ -1799,7 +1799,7 @@ SetWireImage(raw_packet_, sizeof(raw_packet_)); } - bool EqualFieldValues(MessageStructuredData& values) const override { + bool EqualFieldValues(const MessageStructuredData& values) const override { auto cast = std::get<MoqtObjectAck>(values); if (cast.subscribe_id != object_ack_.subscribe_id) { QUIC_LOG(INFO) << "OBJECT_ACK subscribe ID mismatch";