Split MoqtParser into MoqtControlParser and MoqtDataParser. MoqtControlParser is the same parser as before, but with object-specific logic removed. MoqtDataParser is new code, optimized for parsing data streams. There is quite a bit more cleanup that can be done here, but for now, this CL is already too large. PiperOrigin-RevId: 676603941
diff --git a/quiche/common/quiche_data_reader.cc b/quiche/common/quiche_data_reader.cc index 76f2de9..8f5bf8e 100644 --- a/quiche/common/quiche_data_reader.cc +++ b/quiche/common/quiche_data_reader.cc
@@ -4,6 +4,7 @@ #include "quiche/common/quiche_data_reader.h" +#include <algorithm> #include <cstring> #include <string> @@ -131,6 +132,13 @@ return true; } +absl::string_view QuicheDataReader::ReadAtMost(size_t size) { + size_t actual_size = std::min(size, BytesRemaining()); + absl::string_view result = absl::string_view(data_ + pos_, actual_size); + AdvancePos(actual_size); + return result; +} + bool QuicheDataReader::ReadTag(uint32_t* tag) { return ReadBytes(tag, sizeof(*tag)); }
diff --git a/quiche/common/quiche_data_reader.h b/quiche/common/quiche_data_reader.h index 4f8d46a..43e9c1d 100644 --- a/quiche/common/quiche_data_reader.h +++ b/quiche/common/quiche_data_reader.h
@@ -82,6 +82,9 @@ // Returns true on success, false otherwise. bool ReadStringPiece(absl::string_view* result, size_t size); + // Reads at most a given number of bytes into the provided view. + absl::string_view ReadAtMost(size_t size); + // Reads tag represented as 32-bit unsigned integer into given output // parameter. Tags are in big endian on the wire (e.g., CHLO is // 'C','H','L','O') and are read in byte order, so tags in memory are in big
diff --git a/quiche/common/quiche_data_reader_test.cc b/quiche/common/quiche_data_reader_test.cc index d65dd88..16378d8 100644 --- a/quiche/common/quiche_data_reader_test.cc +++ b/quiche/common/quiche_data_reader_test.cc
@@ -6,6 +6,7 @@ #include <cstdint> +#include "absl/strings/string_view.h" #include "quiche/common/platform/api/quiche_test.h" #include "quiche/common/quiche_endian.h" @@ -184,4 +185,13 @@ EXPECT_STREQ("", dest); } +TEST(QuicheDataReaderTest, ReadAtMost) { + constexpr absl::string_view kData = "foobar"; + QuicheDataReader reader(kData); + EXPECT_EQ(reader.ReadAtMost(0), ""); + EXPECT_EQ(reader.ReadAtMost(3), "foo"); + EXPECT_EQ(reader.ReadAtMost(6), "bar"); + EXPECT_EQ(reader.ReadAtMost(1000), ""); +} + } // namespace quiche
diff --git a/quiche/quic/moqt/moqt_framer.cc b/quiche/quic/moqt/moqt_framer.cc index dabdea4..32d2a5f 100644 --- a/quiche/quic/moqt/moqt_framer.cc +++ b/quiche/quic/moqt/moqt_framer.cc
@@ -191,7 +191,7 @@ return quiche::QuicheBuffer(); } } - MoqtMessageType message_type = + MoqtDataStreamType message_type = GetMessageTypeForForwardingPreference(message.forwarding_preference); switch (message.forwarding_preference) { case MoqtForwardingPreference::kTrack: @@ -247,7 +247,7 @@ return quiche::QuicheBuffer(); } return Serialize( - WireVarInt62(MoqtMessageType::kObjectDatagram), + WireVarInt62(MoqtDataStreamType::kObjectDatagram), WireVarInt62(message.subscribe_id), WireVarInt62(message.track_alias), WireVarInt62(message.group_id), WireVarInt62(message.object_id), WireUint8(message.publisher_priority),
diff --git a/quiche/quic/moqt/moqt_framer_test.cc b/quiche/quic/moqt/moqt_framer_test.cc index 0afeb2a..cb89551 100644 --- a/quiche/quic/moqt/moqt_framer_test.cc +++ b/quiche/quic/moqt/moqt_framer_test.cc
@@ -35,24 +35,26 @@ std::vector<MoqtFramerTestParams> GetMoqtFramerTestParams() { std::vector<MoqtFramerTestParams> params; std::vector<MoqtMessageType> message_types = { - MoqtMessageType::kObjectStream, MoqtMessageType::kSubscribe, - MoqtMessageType::kSubscribeOk, MoqtMessageType::kSubscribeError, - MoqtMessageType::kUnsubscribe, MoqtMessageType::kSubscribeDone, - MoqtMessageType::kAnnounceCancel, MoqtMessageType::kTrackStatusRequest, - MoqtMessageType::kTrackStatus, MoqtMessageType::kAnnounce, - MoqtMessageType::kAnnounceOk, MoqtMessageType::kAnnounceError, - MoqtMessageType::kUnannounce, MoqtMessageType::kGoAway, - MoqtMessageType::kObjectAck, MoqtMessageType::kClientSetup, - MoqtMessageType::kServerSetup, MoqtMessageType::kStreamHeaderTrack, - MoqtMessageType::kStreamHeaderGroup, - }; - std::vector<bool> uses_web_transport_bool = { - false, - true, + MoqtMessageType::kSubscribe, + MoqtMessageType::kSubscribeOk, + MoqtMessageType::kSubscribeError, + MoqtMessageType::kUnsubscribe, + MoqtMessageType::kSubscribeDone, + MoqtMessageType::kAnnounceCancel, + MoqtMessageType::kTrackStatusRequest, + MoqtMessageType::kTrackStatus, + MoqtMessageType::kAnnounce, + MoqtMessageType::kAnnounceOk, + MoqtMessageType::kAnnounceError, + MoqtMessageType::kUnannounce, + MoqtMessageType::kGoAway, + MoqtMessageType::kObjectAck, + MoqtMessageType::kClientSetup, + MoqtMessageType::kServerSetup, }; for (const MoqtMessageType message_type : message_types) { if (message_type == MoqtMessageType::kClientSetup) { - for (const bool uses_web_transport : uses_web_transport_bool) { + for (const bool uses_web_transport : {false, true}) { params.push_back( MoqtFramerTestParams(message_type, uses_web_transport)); } @@ -103,12 +105,6 @@ quiche::QuicheBuffer SerializeMessage( TestMessageBase::MessageStructuredData& structured_data) { switch (message_type_) { - case MoqtMessageType::kObjectStream: - case MoqtMessageType::kStreamHeaderTrack: - case MoqtMessageType::kStreamHeaderGroup: { - MoqtObject data = std::get<MoqtObject>(structured_data); - return SerializeObject(framer_, data, "foo", true); - } case MoqtMessageType::kSubscribe: { auto data = std::get<MoqtSubscribe>(structured_data); return framer_.SerializeSubscribe(data);
diff --git a/quiche/quic/moqt/moqt_messages.cc b/quiche/quic/moqt/moqt_messages.cc index 7e3e616..a1a7729 100644 --- a/quiche/quic/moqt/moqt_messages.cc +++ b/quiche/quic/moqt/moqt_messages.cc
@@ -6,6 +6,7 @@ #include <string> +#include "absl/strings/str_cat.h" #include "quiche/quic/platform/api/quic_bug_tracker.h" namespace moqt { @@ -55,10 +56,6 @@ std::string MoqtMessageTypeToString(const MoqtMessageType message_type) { switch (message_type) { - case MoqtMessageType::kObjectStream: - return "OBJECT_STREAM"; - case MoqtMessageType::kObjectDatagram: - return "OBJECT_PREFER_DATAGRAM"; case MoqtMessageType::kClientSetup: return "CLIENT_SETUP"; case MoqtMessageType::kServerSetup: @@ -91,16 +88,26 @@ return "UNANNOUNCE"; case MoqtMessageType::kGoAway: return "GOAWAY"; - case MoqtMessageType::kStreamHeaderTrack: - return "STREAM_HEADER_TRACK"; - case MoqtMessageType::kStreamHeaderGroup: - return "STREAM_HEADER_GROUP"; case MoqtMessageType::kObjectAck: return "OBJECT_ACK"; } return "Unknown message " + std::to_string(static_cast<int>(message_type)); } +std::string MoqtDataStreamTypeToString(MoqtDataStreamType type) { + switch (type) { + case MoqtDataStreamType::kObjectStream: + return "OBJECT_STREAM"; + case MoqtDataStreamType::kObjectDatagram: + return "OBJECT_PREFER_DATAGRAM"; + case MoqtDataStreamType::kStreamHeaderTrack: + return "STREAM_HEADER_TRACK"; + case MoqtDataStreamType::kStreamHeaderGroup: + return "STREAM_HEADER_GROUP"; + } + return "Unknown stream type " + absl::StrCat(static_cast<int>(type)); +} + std::string MoqtForwardingPreferenceToString( MoqtForwardingPreference preference) { switch (preference) { @@ -118,15 +125,15 @@ return "Unknown preference " + std::to_string(static_cast<int>(preference)); } -MoqtForwardingPreference GetForwardingPreference(MoqtMessageType type) { +MoqtForwardingPreference GetForwardingPreference(MoqtDataStreamType type) { switch (type) { - case MoqtMessageType::kObjectStream: + case MoqtDataStreamType::kObjectStream: return MoqtForwardingPreference::kObject; - case MoqtMessageType::kObjectDatagram: + case MoqtDataStreamType::kObjectDatagram: return MoqtForwardingPreference::kDatagram; - case MoqtMessageType::kStreamHeaderTrack: + case MoqtDataStreamType::kStreamHeaderTrack: return MoqtForwardingPreference::kTrack; - case MoqtMessageType::kStreamHeaderGroup: + case MoqtDataStreamType::kStreamHeaderGroup: return MoqtForwardingPreference::kGroup; default: break; @@ -136,21 +143,21 @@ return MoqtForwardingPreference::kObject; }; -MoqtMessageType GetMessageTypeForForwardingPreference( +MoqtDataStreamType GetMessageTypeForForwardingPreference( MoqtForwardingPreference preference) { switch (preference) { case MoqtForwardingPreference::kObject: - return MoqtMessageType::kObjectStream; + return MoqtDataStreamType::kObjectStream; case MoqtForwardingPreference::kDatagram: - return MoqtMessageType::kObjectDatagram; + return MoqtDataStreamType::kObjectDatagram; case MoqtForwardingPreference::kTrack: - return MoqtMessageType::kStreamHeaderTrack; + return MoqtDataStreamType::kStreamHeaderTrack; case MoqtForwardingPreference::kGroup: - return MoqtMessageType::kStreamHeaderGroup; + return MoqtDataStreamType::kStreamHeaderGroup; } QUIC_BUG(quic_bug_bad_moqt_message_type_03) << "Forwarding preference does not indicate message type"; - return MoqtMessageType::kObjectStream; + return MoqtDataStreamType::kObjectStream; } } // namespace moqt
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h index c25231b..87f47ac 100644 --- a/quiche/quic/moqt/moqt_messages.h +++ b/quiche/quic/moqt/moqt_messages.h
@@ -59,9 +59,14 @@ // are not buffered by the parser). inline constexpr size_t kMaxMessageHeaderSize = 2048; -enum class QUICHE_EXPORT MoqtMessageType : uint64_t { +enum class QUICHE_EXPORT MoqtDataStreamType : uint64_t { kObjectStream = 0x00, kObjectDatagram = 0x01, + kStreamHeaderTrack = 0x50, + kStreamHeaderGroup = 0x51, +}; + +enum class QUICHE_EXPORT MoqtMessageType : uint64_t { kSubscribeUpdate = 0x02, kSubscribe = 0x03, kSubscribeOk = 0x04, @@ -78,8 +83,6 @@ kGoAway = 0x10, kClientSetup = 0x40, kServerSetup = 0x41, - kStreamHeaderTrack = 0x50, - kStreamHeaderGroup = 0x51, // QUICHE-specific extensions. @@ -429,13 +432,14 @@ }; std::string MoqtMessageTypeToString(MoqtMessageType message_type); +std::string MoqtDataStreamTypeToString(MoqtDataStreamType type); std::string MoqtForwardingPreferenceToString( MoqtForwardingPreference preference); -MoqtForwardingPreference GetForwardingPreference(MoqtMessageType type); +MoqtForwardingPreference GetForwardingPreference(MoqtDataStreamType type); -MoqtMessageType GetMessageTypeForForwardingPreference( +MoqtDataStreamType GetMessageTypeForForwardingPreference( MoqtForwardingPreference preference); } // namespace moqt
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc index 3e6b0e5..9e9b079 100644 --- a/quiche/quic/moqt/moqt_parser.cc +++ b/quiche/quic/moqt/moqt_parser.cc
@@ -4,6 +4,7 @@ #include "quiche/quic/moqt/moqt_parser.h" +#include <array> #include <cstddef> #include <cstdint> #include <cstring> @@ -17,6 +18,7 @@ #include "quiche/quic/core/quic_time.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_priority.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" #include "quiche/common/platform/api/quiche_logging.h" namespace moqt { @@ -47,6 +49,77 @@ return value >> 1; } +bool IsAllowedStreamType(uint64_t value) { + constexpr std::array kAllowedStreamTypes = { + MoqtDataStreamType::kObjectStream, MoqtDataStreamType::kStreamHeaderGroup, + MoqtDataStreamType::kStreamHeaderTrack}; + for (MoqtDataStreamType type : kAllowedStreamTypes) { + if (static_cast<uint64_t>(type) == value) { + return true; + } + } + return false; +} + +size_t ParseObjectHeader(quic::QuicDataReader& reader, MoqtObject& object, + MoqtDataStreamType type) { + if (!reader.ReadVarInt62(&object.subscribe_id) || + !reader.ReadVarInt62(&object.track_alias)) { + return 0; + } + if (type != MoqtDataStreamType::kStreamHeaderTrack && + !reader.ReadVarInt62(&object.group_id)) { + return 0; + } + if (type != MoqtDataStreamType::kStreamHeaderTrack && + type != MoqtDataStreamType::kStreamHeaderGroup && + !reader.ReadVarInt62(&object.object_id)) { + return 0; + } + if (!reader.ReadUInt8(&object.publisher_priority)) { + return 0; + } + uint64_t status = 0; + if ((type == MoqtDataStreamType::kObjectStream || + type == MoqtDataStreamType::kObjectDatagram) && + !reader.ReadVarInt62(&status)) { + return 0; + } + object.object_status = IntegerToObjectStatus(status); + object.forwarding_preference = GetForwardingPreference(type); + return reader.PreviouslyReadPayload().size(); +} + +size_t ParseObjectSubheader(quic::QuicDataReader& reader, MoqtObject& object, + MoqtDataStreamType type) { + switch (type) { + case MoqtDataStreamType::kStreamHeaderTrack: + if (!reader.ReadVarInt62(&object.group_id)) { + return 0; + } + [[fallthrough]]; + + case MoqtDataStreamType::kStreamHeaderGroup: { + uint64_t length; + if (!reader.ReadVarInt62(&object.object_id) || + !reader.ReadVarInt62(&length)) { + return 0; + } + object.payload_length = length; + uint64_t status = 0; + if (length == 0 && !reader.ReadVarInt62(&status)) { + return 0; + } + object.object_status = IntegerToObjectStatus(status); + return reader.PreviouslyReadPayload().size(); + } + + default: + QUICHE_NOTREACHED(); + return 0; + } +} + } // namespace // The buffering philosophy is complicated, to minimize copying. Here is an @@ -57,7 +130,7 @@ // Any OBJECT payload is always delivered to the application without copying. // If something has been buffered, when more data arrives copy just enough of it // to finish parsing that thing, then resume normal processing. -void MoqtParser::ProcessData(absl::string_view data, bool fin) { +void MoqtControlParser::ProcessData(absl::string_view data, bool fin) { if (no_more_data_) { ParseError("Data after end of stream"); } @@ -69,11 +142,6 @@ // Check for early fin if (fin) { no_more_data_ = true; - if (ObjectPayloadInProgress() && - payload_length_remaining_ > data.length()) { - ParseError("End of stream before complete OBJECT PAYLOAD"); - return; - } if (!buffered_message_.empty() && data.empty()) { ParseError("End of stream before complete message"); return; @@ -81,33 +149,7 @@ } std::optional<quic::QuicDataReader> reader = std::nullopt; size_t original_buffer_size = buffered_message_.size(); - // There are three cases: the parser has already delivered an OBJECT header - // and is now delivering payload; part of a message is in the buffer; or - // no message is in progress. - if (ObjectPayloadInProgress()) { - // This is additional payload for an OBJECT. - QUICHE_DCHECK(buffered_message_.empty()); - if (!object_metadata_->payload_length.has_value()) { - // Deliver the data and exit. - visitor_.OnObjectMessage(*object_metadata_, data, fin); - if (fin) { - object_metadata_.reset(); - } - return; - } - if (data.length() < payload_length_remaining_) { - // Does not finish the payload; deliver and exit. - visitor_.OnObjectMessage(*object_metadata_, data, false); - payload_length_remaining_ -= data.length(); - return; - } - // Finishes the payload. Deliver and continue. - reader.emplace(data); - visitor_.OnObjectMessage(*object_metadata_, - data.substr(0, payload_length_remaining_), true); - reader->Seek(payload_length_remaining_); - payload_length_remaining_ = 0; // Expect a new object. - } else if (!buffered_message_.empty()) { + if (!buffered_message_.empty()) { absl::StrAppend(&buffered_message_, data); reader.emplace(buffered_message_); } else { @@ -116,7 +158,7 @@ } size_t total_processed = 0; while (!reader->IsDoneReading()) { - size_t message_len = ProcessMessage(reader->PeekRemainingPayload(), fin); + size_t message_len = ProcessMessage(reader->PeekRemainingPayload()); if (message_len == 0) { if (reader->BytesRemaining() > kMaxMessageHeaderSize) { ParseError(MoqtError::kInternalError, @@ -142,47 +184,14 @@ } } -// static -absl::string_view MoqtParser::ProcessDatagram(absl::string_view data, - MoqtObject& object_metadata) { +size_t MoqtControlParser::ProcessMessage(absl::string_view data) { uint64_t value; quic::QuicDataReader reader(data); if (!reader.ReadVarInt62(&value)) { - return absl::string_view(); - } - if (static_cast<MoqtMessageType>(value) != MoqtMessageType::kObjectDatagram) { - return absl::string_view(); - } - size_t processed_data = ParseObjectHeader(reader, object_metadata, - MoqtMessageType::kObjectDatagram); - if (processed_data == 0) { // Incomplete header - return absl::string_view(); - } - return reader.PeekRemainingPayload(); -} - -size_t MoqtParser::ProcessMessage(absl::string_view data, bool fin) { - uint64_t value; - quic::QuicDataReader reader(data); - if (ObjectStreamInitialized() && !ObjectPayloadInProgress()) { - // This is a follow-on object in a stream. - return ProcessObject(reader, - GetMessageTypeForForwardingPreference( - object_metadata_->forwarding_preference), - fin); - } - if (!reader.ReadVarInt62(&value)) { return 0; } auto type = static_cast<MoqtMessageType>(value); switch (type) { - case MoqtMessageType::kObjectDatagram: - ParseError("Received OBJECT_DATAGRAM on stream"); - return 0; - case MoqtMessageType::kObjectStream: - case MoqtMessageType::kStreamHeaderTrack: - case MoqtMessageType::kStreamHeaderGroup: - return ProcessObject(reader, type, fin); case MoqtMessageType::kClientSetup: return ProcessClientSetup(reader); case MoqtMessageType::kServerSetup: @@ -217,98 +226,12 @@ return ProcessGoAway(reader); case moqt::MoqtMessageType::kObjectAck: return ProcessObjectAck(reader); - default: - ParseError("Unknown message type"); - return 0; } + ParseError("Unknown message type"); + return 0; } -size_t MoqtParser::ProcessObject(quic::QuicDataReader& reader, - MoqtMessageType type, bool fin) { - size_t processed_data = 0; - QUICHE_DCHECK(!ObjectPayloadInProgress()); - if (!ObjectStreamInitialized()) { - object_metadata_ = MoqtObject(); - processed_data = ParseObjectHeader(reader, object_metadata_.value(), type); - if (processed_data == 0) { - object_metadata_.reset(); - return 0; - } - } - // At this point, enough data has been processed to store in object_metadata_, - // even if there's nothing else in the buffer. - QUICHE_DCHECK(payload_length_remaining_ == 0); - switch (type) { - case MoqtMessageType::kStreamHeaderTrack: - if (!reader.ReadVarInt62(&object_metadata_->group_id)) { - return processed_data; - } - [[fallthrough]]; - case MoqtMessageType::kStreamHeaderGroup: { - uint64_t length; - if (!reader.ReadVarInt62(&object_metadata_->object_id) || - !reader.ReadVarInt62(&length)) { - return processed_data; - } - object_metadata_->payload_length = length; - uint64_t status = 0; // Defaults to kNormal. - if (length == 0 && !reader.ReadVarInt62(&status)) { - return processed_data; - } - object_metadata_->object_status = IntegerToObjectStatus(status); - break; - } - default: - break; - } - if (object_metadata_->object_status == - MoqtObjectStatus::kInvalidObjectStatus) { - ParseError("Invalid object status"); - return processed_data; - } - if (object_metadata_->object_status != MoqtObjectStatus::kNormal) { - // It is impossible to express an explicit length with this status. - if ((type == MoqtMessageType::kObjectStream || - type == MoqtMessageType::kObjectDatagram) && - reader.BytesRemaining() > 0) { - // There is additional data in the stream/datagram, which is an error. - ParseError("Object with non-normal status has payload"); - return processed_data; - } - visitor_.OnObjectMessage(*object_metadata_, "", true); - return reader.PreviouslyReadPayload().length(); - } - bool has_length = object_metadata_->payload_length.has_value(); - bool received_complete_message = false; - size_t payload_to_draw = reader.BytesRemaining(); - if (fin && has_length && - *object_metadata_->payload_length > reader.BytesRemaining()) { - ParseError("Received FIN mid-payload"); - return processed_data; - } - received_complete_message = - fin || (has_length && - *object_metadata_->payload_length <= reader.BytesRemaining()); - if (received_complete_message && has_length && - *object_metadata_->payload_length < reader.BytesRemaining()) { - payload_to_draw = *object_metadata_->payload_length; - } - // The error case where there's a fin before the explicit length is complete - // is handled in ProcessData() in two separate places. Even though the - // message is "done" if fin regardless of has_length, it's bad to report to - // the application that the object is done if it hasn't reached the promised - // length. - visitor_.OnObjectMessage( - *object_metadata_, - reader.PeekRemainingPayload().substr(0, payload_to_draw), - received_complete_message); - reader.Seek(payload_to_draw); - payload_length_remaining_ = - has_length ? *object_metadata_->payload_length - payload_to_draw : 0; - return reader.PreviouslyReadPayload().length(); -} - -size_t MoqtParser::ProcessClientSetup(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessClientSetup(quic::QuicDataReader& reader) { MoqtClientSetup setup; uint64_t number_of_supported_versions; if (!reader.ReadVarInt62(&number_of_supported_versions)) { @@ -386,7 +309,7 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtParser::ProcessServerSetup(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessServerSetup(quic::QuicDataReader& reader) { MoqtServerSetup setup; uint64_t version; if (!reader.ReadVarInt62(&version)) { @@ -445,7 +368,7 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtParser::ProcessSubscribe(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessSubscribe(quic::QuicDataReader& reader) { MoqtSubscribe subscribe_request; uint64_t filter, group, object; uint8_t group_order; @@ -545,7 +468,7 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtParser::ProcessSubscribeOk(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessSubscribeOk(quic::QuicDataReader& reader) { MoqtSubscribeOk subscribe_ok; uint64_t milliseconds; uint8_t group_order; @@ -576,7 +499,7 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtParser::ProcessSubscribeError(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessSubscribeError(quic::QuicDataReader& reader) { MoqtSubscribeError subscribe_error; uint64_t error_code; if (!reader.ReadVarInt62(&subscribe_error.subscribe_id) || @@ -590,7 +513,7 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtParser::ProcessUnsubscribe(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessUnsubscribe(quic::QuicDataReader& reader) { MoqtUnsubscribe unsubscribe; if (!reader.ReadVarInt62(&unsubscribe.subscribe_id)) { return 0; @@ -599,7 +522,7 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtParser::ProcessSubscribeDone(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessSubscribeDone(quic::QuicDataReader& reader) { MoqtSubscribeDone subscribe_done; uint8_t content_exists; uint64_t value; @@ -625,7 +548,7 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtParser::ProcessSubscribeUpdate(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessSubscribeUpdate(quic::QuicDataReader& reader) { MoqtSubscribeUpdate subscribe_update; uint64_t end_group, end_object, num_params; if (!reader.ReadVarInt62(&subscribe_update.subscribe_id) || @@ -686,7 +609,7 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtParser::ProcessAnnounce(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessAnnounce(quic::QuicDataReader& reader) { MoqtAnnounce announce; if (!reader.ReadStringVarInt62(announce.track_namespace)) { return 0; @@ -719,7 +642,7 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtParser::ProcessAnnounceOk(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessAnnounceOk(quic::QuicDataReader& reader) { MoqtAnnounceOk announce_ok; if (!reader.ReadStringVarInt62(announce_ok.track_namespace)) { return 0; @@ -728,7 +651,7 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtParser::ProcessAnnounceError(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessAnnounceError(quic::QuicDataReader& reader) { MoqtAnnounceError announce_error; if (!reader.ReadStringVarInt62(announce_error.track_namespace)) { return 0; @@ -745,7 +668,7 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtParser::ProcessAnnounceCancel(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessAnnounceCancel(quic::QuicDataReader& reader) { MoqtAnnounceCancel announce_cancel; if (!reader.ReadStringVarInt62(announce_cancel.track_namespace)) { return 0; @@ -754,7 +677,8 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtParser::ProcessTrackStatusRequest(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessTrackStatusRequest( + quic::QuicDataReader& reader) { MoqtTrackStatusRequest track_status_request; if (!reader.ReadStringVarInt62(track_status_request.track_namespace)) { return 0; @@ -766,7 +690,7 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtParser::ProcessUnannounce(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessUnannounce(quic::QuicDataReader& reader) { MoqtUnannounce unannounce; if (!reader.ReadStringVarInt62(unannounce.track_namespace)) { return 0; @@ -775,7 +699,7 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtParser::ProcessTrackStatus(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessTrackStatus(quic::QuicDataReader& reader) { MoqtTrackStatus track_status; uint64_t value; if (!reader.ReadStringVarInt62(track_status.track_namespace) || @@ -790,7 +714,7 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtParser::ProcessGoAway(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessGoAway(quic::QuicDataReader& reader) { MoqtGoAway goaway; if (!reader.ReadStringVarInt62(goaway.new_session_uri)) { return 0; @@ -799,7 +723,7 @@ return reader.PreviouslyReadPayload().length(); } -size_t MoqtParser::ProcessObjectAck(quic::QuicDataReader& reader) { +size_t MoqtControlParser::ProcessObjectAck(quic::QuicDataReader& reader) { MoqtObjectAck object_ack; uint64_t raw_delta; if (!reader.ReadVarInt62(&object_ack.subscribe_id) || @@ -814,41 +738,12 @@ return reader.PreviouslyReadPayload().length(); } -// static -size_t MoqtParser::ParseObjectHeader(quic::QuicDataReader& reader, - MoqtObject& object, MoqtMessageType type) { - if (!reader.ReadVarInt62(&object.subscribe_id) || - !reader.ReadVarInt62(&object.track_alias)) { - return 0; - } - if (type != MoqtMessageType::kStreamHeaderTrack && - !reader.ReadVarInt62(&object.group_id)) { - return 0; - } - if (type != MoqtMessageType::kStreamHeaderTrack && - type != MoqtMessageType::kStreamHeaderGroup && - !reader.ReadVarInt62(&object.object_id)) { - return 0; - } - if (!reader.ReadUInt8(&object.publisher_priority)) { - return 0; - } - uint64_t status = 0; // Defaults to kNormal. - if ((type == MoqtMessageType::kObjectStream || - type == MoqtMessageType::kObjectDatagram) && - !reader.ReadVarInt62(&status)) { - return 0; - } - object.object_status = IntegerToObjectStatus(status); - object.forwarding_preference = GetForwardingPreference(type); - return reader.PreviouslyReadPayload().length(); -} - -void MoqtParser::ParseError(absl::string_view reason) { +void MoqtControlParser::ParseError(absl::string_view reason) { ParseError(MoqtError::kProtocolViolation, reason); } -void MoqtParser::ParseError(MoqtError error_code, absl::string_view reason) { +void MoqtControlParser::ParseError(MoqtError error_code, + absl::string_view reason) { if (parsing_error_) { return; // Don't send multiple parse errors. } @@ -857,8 +752,8 @@ visitor_.OnParsingError(error_code, reason); } -bool MoqtParser::ReadVarIntPieceVarInt62(quic::QuicDataReader& reader, - uint64_t& result) { +bool MoqtControlParser::ReadVarIntPieceVarInt62(quic::QuicDataReader& reader, + uint64_t& result) { uint64_t length; if (!reader.ReadVarInt62(&length)) { return false; @@ -874,15 +769,17 @@ return true; } -bool MoqtParser::ReadParameter(quic::QuicDataReader& reader, uint64_t& type, - absl::string_view& value) { +bool MoqtControlParser::ReadParameter(quic::QuicDataReader& reader, + uint64_t& type, + absl::string_view& value) { if (!reader.ReadVarInt62(&type)) { return false; } return reader.ReadStringPieceVarInt62(&value); } -bool MoqtParser::StringViewToVarInt(absl::string_view& sv, uint64_t& vi) { +bool MoqtControlParser::StringViewToVarInt(absl::string_view& sv, + uint64_t& vi) { quic::QuicDataReader reader(sv); if (static_cast<size_t>(reader.PeekVarInt62Length()) != sv.length()) { ParseError(MoqtError::kParameterLengthMismatch, @@ -893,4 +790,152 @@ return true; } +void MoqtDataParser::ParseError(absl::string_view reason) { + if (parsing_error_) { + return; // Don't send multiple parse errors. + } + no_more_data_ = true; + parsing_error_ = true; + visitor_.OnParsingError(MoqtError::kProtocolViolation, reason); +} + +absl::string_view ParseDatagram(absl::string_view data, + MoqtObject& object_metadata) { + uint64_t value; + quic::QuicDataReader reader(data); + if (!reader.ReadVarInt62(&value)) { + return absl::string_view(); + } + if (static_cast<MoqtDataStreamType>(value) != + MoqtDataStreamType::kObjectDatagram) { + return absl::string_view(); + } + size_t processed_data = ParseObjectHeader( + reader, object_metadata, MoqtDataStreamType::kObjectDatagram); + if (processed_data == 0) { // Incomplete header + return absl::string_view(); + } + return reader.PeekRemainingPayload(); +} + +void MoqtDataParser::ProcessData(absl::string_view data, bool fin) { + if (processing_) { + QUICHE_BUG(MoqtDataParser_reentry) + << "Calling ProcessData() when ProcessData() is already in progress."; + return; + } + processing_ = true; + auto on_return = absl::MakeCleanup([&] { processing_ = false; }); + + if (no_more_data_) { + ParseError("Data after end of stream"); + return; + } + + // Annoying path (going away soon): handle kObjectStream receiving a FIN. + if (data.empty() && fin && type_ == MoqtDataStreamType::kObjectStream) { + visitor_.OnObjectMessage(*metadata_, "", true); + } + + // Sad path: there is already data buffered. Attempt to transfer a small + // chunk from `data` into the buffer, in hope that it will make the contents + // of the buffer parsable without any leftover data. This is a reasonable + // expectation, since object headers are small, and are often followed by + // large blobs of data. + while (!buffered_message_.empty() && !data.empty()) { + absl::string_view chunk = data.substr(0, chunk_size_); + absl::StrAppend(&buffered_message_, chunk); + data.remove_prefix(chunk.size()); + + buffered_message_.assign( + ProcessDataInner(buffered_message_, fin && data.empty())); + } + + // Happy path: there is no buffered data. + if (buffered_message_.empty()) { + buffered_message_.assign(ProcessDataInner(data, fin)); + } + + if (fin) { + if (!buffered_message_.empty() || !metadata_.has_value() || + payload_length_remaining_ > 0) { + ParseError("FIN received at an unexpected point in the stream"); + return; + } + no_more_data_ = true; + } +} + +absl::string_view MoqtDataParser::ProcessDataInner(absl::string_view data, + bool fin) { + quic::QuicDataReader reader(data); + while (!reader.IsDoneReading()) { + absl::string_view remainder = reader.PeekRemainingPayload(); + switch (GetNextInput()) { + case kStreamType: { + uint64_t value; + if (!reader.ReadVarInt62(&value)) { + return remainder; + } + if (!IsAllowedStreamType(value)) { + ParseError(absl::StrCat("Unknown stream type: ", value)); + return ""; + } + type_ = static_cast<MoqtDataStreamType>(value); + continue; + } + + case kHeader: { + MoqtObject header; + size_t bytes_read = ParseObjectHeader(reader, header, *type_); + if (bytes_read == 0) { + return remainder; + } + if (type_ == MoqtDataStreamType::kObjectStream && + header.object_status == MoqtObjectStatus::kInvalidObjectStatus) { + ParseError("Invalid object status"); + return ""; + } + metadata_ = header; + continue; + } + + case kSubheader: { + size_t bytes_read = ParseObjectSubheader(reader, *metadata_, *type_); + if (bytes_read == 0) { + return remainder; + } + if (metadata_->object_status == + MoqtObjectStatus::kInvalidObjectStatus) { + ParseError("Invalid object status provided"); + return ""; + } + payload_length_remaining_ = *metadata_->payload_length; + continue; + } + + case kData: + if (payload_length_remaining_ == 0) { + // Special case: kObject, which does not have explicit length. + if (metadata_->object_status != MoqtObjectStatus::kNormal) { + ParseError("Object with non-normal status has payload"); + return ""; + } + visitor_.OnObjectMessage(*metadata_, reader.PeekRemainingPayload(), + fin); + return ""; + } + + absl::string_view payload = + reader.ReadAtMost(payload_length_remaining_); + visitor_.OnObjectMessage(*metadata_, payload, + payload.size() == payload_length_remaining_); + payload_length_remaining_ -= payload.size(); + + continue; + } + } + return ""; +} + } // namespace moqt
diff --git a/quiche/quic/moqt/moqt_parser.h b/quiche/quic/moqt/moqt_parser.h index b10c22f..2f25f02 100644 --- a/quiche/quic/moqt/moqt_parser.h +++ b/quiche/quic/moqt/moqt_parser.h
@@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -// A parser for draft-ietf-moq-transport-01. +// A parser for draft-ietf-moq-transport. +// TODO(vasilvv): possibly split this header into two. #ifndef QUICHE_QUIC_MOQT_MOQT_PARSER_H_ #define QUICHE_QUIC_MOQT_MOQT_PARSER_H_ @@ -19,17 +20,10 @@ namespace moqt { -class QUICHE_EXPORT MoqtParserVisitor { +class QUICHE_EXPORT MoqtControlParserVisitor { public: - virtual ~MoqtParserVisitor() = default; + virtual ~MoqtControlParserVisitor() = default; - // If |end_of_message| is true, |payload| contains the last bytes of the - // OBJECT payload. If not, there will be subsequent calls with further payload - // data. The parser retains ownership of |message| and |payload|, so the - // visitor needs to copy anything it wants to retain. - virtual void OnObjectMessage(const MoqtObject& message, - absl::string_view payload, - bool end_of_message) = 0; // All of these are called only when the entire message has arrived. The // parser retains ownership of the memory. virtual void OnClientSetupMessage(const MoqtClientSetup& message) = 0; @@ -54,11 +48,26 @@ virtual void OnParsingError(MoqtError code, absl::string_view reason) = 0; }; -class QUICHE_EXPORT MoqtParser { +class MoqtDataParserVisitor { public: - MoqtParser(bool uses_web_transport, MoqtParserVisitor& visitor) + virtual ~MoqtDataParserVisitor() = default; + + // If |end_of_message| is true, |payload| contains the last bytes of the + // OBJECT payload. If not, there will be subsequent calls with further payload + // data. The parser retains ownership of |message| and |payload|, so the + // visitor needs to copy anything it wants to retain. + virtual void OnObjectMessage(const MoqtObject& message, + absl::string_view payload, + bool end_of_message) = 0; + + virtual void OnParsingError(MoqtError code, absl::string_view reason) = 0; +}; + +class QUICHE_EXPORT MoqtControlParser { + public: + MoqtControlParser(bool uses_web_transport, MoqtControlParserVisitor& visitor) : visitor_(visitor), uses_web_transport_(uses_web_transport) {} - ~MoqtParser() = default; + ~MoqtControlParser() = default; // Take a buffer from the transport in |data|. Parse each complete message and // call the appropriate visitor function. If |fin| is true, there @@ -71,24 +80,17 @@ // datagram rather than a stream. void ProcessData(absl::string_view data, bool fin); - // Provide a separate path for datagrams. Returns the payload bytes, or empty - // string_view on error. The caller provides the whole datagram in |data|. - // The function puts the object metadata in |object_metadata|. - static absl::string_view ProcessDatagram(absl::string_view data, - MoqtObject& object_metadata); - private: // The central switch statement to dispatch a message to the correct // Process* function. Returns 0 if it could not parse the full messsage // (except for object payload). Otherwise, returns the number of bytes // processed. - size_t ProcessMessage(absl::string_view data, bool fin); + size_t ProcessMessage(absl::string_view data); + // The Process* functions parse the serialized data into the appropriate // structs, and call the relevant visitor function for further action. Returns // the number of bytes consumed if the message is complete; returns 0 // otherwise. - size_t ProcessObject(quic::QuicDataReader& reader, MoqtMessageType type, - bool fin); size_t ProcessClientSetup(quic::QuicDataReader& reader); size_t ProcessServerSetup(quic::QuicDataReader& reader); size_t ProcessSubscribe(quic::QuicDataReader& reader); @@ -107,9 +109,6 @@ size_t ProcessGoAway(quic::QuicDataReader& reader); size_t ProcessObjectAck(quic::QuicDataReader& reader); - static size_t ParseObjectHeader(quic::QuicDataReader& reader, - MoqtObject& object, MoqtMessageType type); - // If |error| is not provided, assumes kProtocolViolation. void ParseError(absl::string_view reason); void ParseError(MoqtError error, absl::string_view reason); @@ -125,40 +124,94 @@ // string_view is not exactly the right length. bool StringViewToVarInt(absl::string_view& sv, uint64_t& vi); - // Simplify understanding of state. - // Returns true if the stream has delivered all object metadata common to all - // objects on that stream. - bool ObjectStreamInitialized() const { return object_metadata_.has_value(); } - // Returns true if the stream has delivered all metadata but not all payload - // for the most recent object. - bool ObjectPayloadInProgress() const { - return (object_metadata_.has_value() && - object_metadata_->object_status == MoqtObjectStatus::kNormal && - (object_metadata_->forwarding_preference == - MoqtForwardingPreference::kObject || - object_metadata_->forwarding_preference == - MoqtForwardingPreference::kDatagram || - payload_length_remaining_ > 0)); - } - - MoqtParserVisitor& visitor_; + MoqtControlParserVisitor& visitor_; bool uses_web_transport_; bool no_more_data_ = false; // Fatal error or fin. No more parsing. bool parsing_error_ = false; std::string buffered_message_; - // Metadata for an object which is delivered in parts. - // If object_metadata_ is nullopt, nothing has been processed on the stream. - // If object_metadata_ exists but payload_length is nullopt or - // payload_length_remaining_ is nonzero, the object payload is in mid- - // delivery. - // If object_metadata_ exists and payload_length_remaining_ is zero, an object - // has been completely delivered and the next object header on the stream has - // not been delivered. - // Use ObjectStreamInitialized() and ObjectPayloadInProgress() to keep the - // state straight. - std::optional<MoqtObject> object_metadata_ = std::nullopt; + bool processing_ = false; // True if currently in ProcessData(), to prevent + // re-entrancy. +}; + +// Parses an MoQT datagram. Returns the payload bytes, or empty string_view on +// error. The caller provides the whole datagram in `data`. The function puts +// the object metadata in `object_metadata`. +absl::string_view ParseDatagram(absl::string_view data, + MoqtObject& object_metadata); + +// Parser for MoQT unidirectional data stream. +class QUICHE_EXPORT MoqtDataParser { + public: + explicit MoqtDataParser(MoqtDataParserVisitor* visitor) + : visitor_(*visitor) {} + ~MoqtDataParser() = default; + + // Take a buffer from the transport in |data|. Parse each complete message and + // call the appropriate visitor function. If |fin| is true, there + // is no more data arriving on the stream, so the parser will deliver any + // message encoded as to run to the end of the stream. + // All bytes can be freed. Calls OnParsingError() when there is a parsing + // error. + void ProcessData(absl::string_view data, bool fin); + + // Alters `chunk_size_` value (see discussion below). Primarily intended to + // be used for testing. + void set_chunk_size(size_t size) { chunk_size_ = size; } + + private: + // If there is buffered data from the previous attempt at parsing it, new data + // will be added in `chunk_size_`-sized chunks. + constexpr static size_t kDefaultChunkSize = 64; + + // Current state of the parser. + enum NextInput { + // Nothing has been read yet; the next thing to be read is the stream type + // varint. + kStreamType, + // The next thing to be read is the stream header. + kHeader, + // The next thing to be read is the stream subheader for the given object. + kSubheader, + // The next thing to be read is the object payload. + kData, + }; + + // Infers the current state of the parser. + NextInput GetNextInput() const { + if (!type_.has_value()) { + return kStreamType; + } + if (!metadata_.has_value()) { + return kHeader; + } + if (payload_length_remaining_ > 0 || + *type_ == MoqtDataStreamType::kObjectStream) { + return kData; + } + return kSubheader; + } + + // Processes all that can be entirely processed, and returns the view for the + // data that needs to be buffered. + // TODO: remove the `fin` argument once kObjectStream is gone. + absl::string_view ProcessDataInner(absl::string_view data, bool fin); + + void ParseError(absl::string_view reason); + + MoqtDataParserVisitor& visitor_; + size_t chunk_size_ = kDefaultChunkSize; + + bool no_more_data_ = false; // Fatal error or fin. No more parsing. + bool parsing_error_ = false; + + std::string buffered_message_; + + // The three variables below implicitly drive the state machine; see + // `GetNextInput()` for how the state is derived. + std::optional<MoqtDataStreamType> type_ = std::nullopt; + std::optional<MoqtObject> metadata_ = std::nullopt; size_t payload_length_remaining_ = 0; bool processing_ = false; // True if currently in ProcessData(), to prevent
diff --git a/quiche/quic/moqt/moqt_parser_test.cc b/quiche/quic/moqt/moqt_parser_test.cc index 11f6f29..a074f16 100644 --- a/quiche/quic/moqt/moqt_parser_test.cc +++ b/quiche/quic/moqt/moqt_parser_test.cc
@@ -4,16 +4,18 @@ #include "quiche/quic/moqt/moqt_parser.h" +#include <array> #include <cstddef> #include <cstdint> #include <cstring> #include <memory> #include <optional> #include <string> -#include <utility> #include <vector> +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "absl/types/variant.h" #include "quiche/quic/core/quic_data_writer.h" #include "quiche/quic/core/quic_time.h" #include "quiche/quic/moqt/moqt_messages.h" @@ -24,64 +26,44 @@ namespace { +using ::testing::AnyOf; using ::testing::HasSubstr; using ::testing::Optional; -inline bool IsObjectMessage(MoqtMessageType type) { - return (type == MoqtMessageType::kObjectStream || - type == MoqtMessageType::kObjectDatagram || - type == MoqtMessageType::kStreamHeaderTrack || - type == MoqtMessageType::kStreamHeaderGroup); -} - -inline bool IsObjectWithoutPayloadLength(MoqtMessageType type) { - return (type == MoqtMessageType::kObjectStream || - type == MoqtMessageType::kObjectDatagram); -} - -std::vector<MoqtMessageType> message_types = { - MoqtMessageType::kObjectStream, - // kObjectDatagram is a unique set of tests. - MoqtMessageType::kSubscribe, - MoqtMessageType::kSubscribeOk, - MoqtMessageType::kSubscribeError, - MoqtMessageType::kSubscribeUpdate, - MoqtMessageType::kUnsubscribe, - MoqtMessageType::kSubscribeDone, - MoqtMessageType::kAnnounceCancel, - MoqtMessageType::kTrackStatusRequest, - MoqtMessageType::kTrackStatus, - MoqtMessageType::kAnnounce, - MoqtMessageType::kAnnounceOk, - MoqtMessageType::kAnnounceError, - MoqtMessageType::kUnannounce, - MoqtMessageType::kClientSetup, - MoqtMessageType::kServerSetup, - MoqtMessageType::kStreamHeaderTrack, - MoqtMessageType::kStreamHeaderGroup, - MoqtMessageType::kGoAway, +constexpr std::array kMessageTypes{ + MoqtMessageType::kSubscribe, MoqtMessageType::kSubscribeOk, + MoqtMessageType::kSubscribeError, MoqtMessageType::kSubscribeUpdate, + MoqtMessageType::kUnsubscribe, MoqtMessageType::kSubscribeDone, + MoqtMessageType::kAnnounceCancel, MoqtMessageType::kTrackStatusRequest, + MoqtMessageType::kTrackStatus, MoqtMessageType::kAnnounce, + MoqtMessageType::kAnnounceOk, MoqtMessageType::kAnnounceError, + MoqtMessageType::kUnannounce, MoqtMessageType::kClientSetup, + MoqtMessageType::kServerSetup, MoqtMessageType::kGoAway, MoqtMessageType::kObjectAck, }; +constexpr std::array kDataStreamTypes{MoqtDataStreamType::kObjectStream, + MoqtDataStreamType::kStreamHeaderTrack, + MoqtDataStreamType::kStreamHeaderGroup}; +using GeneralizedMessageType = + absl::variant<MoqtMessageType, MoqtDataStreamType>; } // namespace struct MoqtParserTestParams { MoqtParserTestParams(MoqtMessageType message_type, bool uses_web_transport) : message_type(message_type), uses_web_transport(uses_web_transport) {} - MoqtMessageType message_type; + explicit MoqtParserTestParams(MoqtDataStreamType message_type) + : message_type(message_type), uses_web_transport(true) {} + GeneralizedMessageType message_type; bool uses_web_transport; }; std::vector<MoqtParserTestParams> GetMoqtParserTestParams() { std::vector<MoqtParserTestParams> params; - std::vector<bool> uses_web_transport_bool = { - false, - true, - }; - for (const MoqtMessageType message_type : message_types) { + for (MoqtMessageType message_type : kMessageTypes) { if (message_type == MoqtMessageType::kClientSetup) { - for (const bool uses_web_transport : uses_web_transport_bool) { + for (const bool uses_web_transport : {false, true}) { params.push_back( MoqtParserTestParams(message_type, uses_web_transport)); } @@ -91,28 +73,40 @@ params.push_back(MoqtParserTestParams(message_type, true)); } } + for (MoqtDataStreamType type : kDataStreamTypes) { + params.push_back(MoqtParserTestParams(type)); + } return params; } +std::string TypeFormatter(MoqtMessageType type) { + return MoqtMessageTypeToString(type); +} +std::string TypeFormatter(MoqtDataStreamType type) { + return MoqtDataStreamTypeToString(type); +} std::string ParamNameFormatter( const testing::TestParamInfo<MoqtParserTestParams>& info) { - return MoqtMessageTypeToString(info.param.message_type) + "_" + - (info.param.uses_web_transport ? "WebTransport" : "QUIC"); + return absl::visit([](auto x) { return TypeFormatter(x); }, + info.param.message_type) + + "_" + (info.param.uses_web_transport ? "WebTransport" : "QUIC"); } -class MoqtParserTestVisitor : public MoqtParserVisitor { +class MoqtParserTestVisitor : public MoqtControlParserVisitor, + public MoqtDataParserVisitor { public: ~MoqtParserTestVisitor() = default; void OnObjectMessage(const MoqtObject& message, absl::string_view payload, bool end_of_message) override { MoqtObject object = message; - object_payload_ = payload; + object_payloads_.push_back(std::string(payload)); end_of_message_ = end_of_message; - messages_received_++; + if (end_of_message) { + ++messages_received_; + } last_message_ = TestMessageBase::MessageStructuredData(object); } - template <typename Message> void OnControlMessage(const Message& message) { end_of_message_ = true; @@ -177,7 +171,9 @@ parsing_error_code_ = code; } - std::optional<absl::string_view> object_payload_; + std::string object_payload() { return absl::StrJoin(object_payloads_, ""); } + + std::vector<std::string> object_payloads_; bool end_of_message_ = false; std::optional<absl::string_view> parsing_error_; MoqtError parsing_error_code_; @@ -191,16 +187,36 @@ MoqtParserTest() : message_type_(GetParam().message_type), webtrans_(GetParam().uses_web_transport), - parser_(GetParam().uses_web_transport, visitor_) {} + control_parser_(GetParam().uses_web_transport, visitor_), + data_parser_(&visitor_) {} - std::unique_ptr<TestMessageBase> MakeMessage(MoqtMessageType message_type) { - return CreateTestMessage(message_type, webtrans_); + bool IsDataStream() { + return absl::holds_alternative<MoqtDataStreamType>(message_type_); } + std::unique_ptr<TestMessageBase> MakeMessage() { + if (IsDataStream()) { + return CreateTestDataStream(absl::get<MoqtDataStreamType>(message_type_)); + } else { + return CreateTestMessage(absl::get<MoqtMessageType>(message_type_), + webtrans_); + } + } + + void ProcessData(absl::string_view data, bool fin) { + if (IsDataStream()) { + data_parser_.ProcessData(data, fin); + } else { + control_parser_.ProcessData(data, fin); + } + } + + protected: MoqtParserTestVisitor visitor_; - MoqtMessageType message_type_; + GeneralizedMessageType message_type_; bool webtrans_; - MoqtParser parser_; + MoqtControlParser control_parser_; + MoqtDataParser data_parser_; }; INSTANTIATE_TEST_SUITE_P(MoqtParserTests, MoqtParserTest, @@ -208,129 +224,130 @@ ParamNameFormatter); TEST_P(MoqtParserTest, OneMessage) { - std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); - parser_.ProcessData(message->PacketSample(), true); + std::unique_ptr<TestMessageBase> message = MakeMessage(); + ProcessData(message->PacketSample(), true); EXPECT_EQ(visitor_.messages_received_, 1); EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); EXPECT_TRUE(visitor_.end_of_message_); - if (IsObjectMessage(message_type_)) { - // Check payload message. - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), "foo"); + if (IsDataStream()) { + EXPECT_EQ(visitor_.object_payload(), "foo"); } } TEST_P(MoqtParserTest, OneMessageWithLongVarints) { - std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); + std::unique_ptr<TestMessageBase> message = MakeMessage(); message->ExpandVarints(); - parser_.ProcessData(message->PacketSample(), true); + ProcessData(message->PacketSample(), true); EXPECT_EQ(visitor_.messages_received_, 1); EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); EXPECT_TRUE(visitor_.end_of_message_); - if (IsObjectMessage(message_type_)) { - // Check payload message. - EXPECT_EQ(visitor_.object_payload_, "foo"); - } EXPECT_FALSE(visitor_.parsing_error_.has_value()); + if (IsDataStream()) { + EXPECT_EQ(visitor_.object_payload(), "foo"); + } } TEST_P(MoqtParserTest, TwoPartMessage) { - std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); + std::unique_ptr<TestMessageBase> message = MakeMessage(); // The test Object message has payload for less then half the message length, // so splitting the message in half will prevent the first half from being // processed. size_t first_data_size = message->total_message_size() / 2; - if (message_type_ == MoqtMessageType::kStreamHeaderTrack) { - // The boundary happens to fall right after the stream header, so move it. - ++first_data_size; - } - parser_.ProcessData(message->PacketSample().substr(0, first_data_size), - false); + ProcessData(message->PacketSample().substr(0, first_data_size), false); EXPECT_EQ(visitor_.messages_received_, 0); - parser_.ProcessData( + ProcessData( message->PacketSample().substr( first_data_size, message->total_message_size() - first_data_size), true); EXPECT_EQ(visitor_.messages_received_, 1); EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); - if (IsObjectMessage(message_type_)) { - EXPECT_EQ(visitor_.object_payload_, "foo"); - } EXPECT_TRUE(visitor_.end_of_message_); EXPECT_FALSE(visitor_.parsing_error_.has_value()); + if (IsDataStream()) { + EXPECT_EQ(visitor_.object_payload(), "foo"); + } } TEST_P(MoqtParserTest, OneByteAtATime) { - std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); - size_t kObjectPayloadSize = 3; + std::unique_ptr<TestMessageBase> message = MakeMessage(); for (size_t i = 0; i < message->total_message_size(); ++i) { - if (!IsObjectMessage(message_type_)) { - EXPECT_EQ(visitor_.messages_received_, 0); - } + EXPECT_EQ(visitor_.messages_received_, 0); EXPECT_FALSE(visitor_.end_of_message_); - parser_.ProcessData(message->PacketSample().substr(i, 1), false); + bool last = i == (message->total_message_size() - 1); + ProcessData(message->PacketSample().substr(i, 1), last); } - EXPECT_EQ(visitor_.messages_received_, - (IsObjectMessage(message_type_) ? (kObjectPayloadSize + 1) : 1)); - if (IsObjectWithoutPayloadLength(message_type_)) { - EXPECT_FALSE(visitor_.end_of_message_); - parser_.ProcessData(absl::string_view(), true); // Needs the FIN - EXPECT_EQ(visitor_.messages_received_, kObjectPayloadSize + 2); - } + EXPECT_EQ(visitor_.messages_received_, 1); EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); EXPECT_TRUE(visitor_.end_of_message_); EXPECT_FALSE(visitor_.parsing_error_.has_value()); + if (IsDataStream()) { + EXPECT_EQ(visitor_.object_payload(), "foo"); + } } TEST_P(MoqtParserTest, OneByteAtATimeLongerVarints) { - std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); + std::unique_ptr<TestMessageBase> message = MakeMessage(); message->ExpandVarints(); - size_t kObjectPayloadSize = 3; for (size_t i = 0; i < message->total_message_size(); ++i) { - if (!IsObjectMessage(message_type_)) { - EXPECT_EQ(visitor_.messages_received_, 0); - } + EXPECT_EQ(visitor_.messages_received_, 0); EXPECT_FALSE(visitor_.end_of_message_); - parser_.ProcessData(message->PacketSample().substr(i, 1), false); + bool last = i == (message->total_message_size() - 1); + ProcessData(message->PacketSample().substr(i, 1), last); } - EXPECT_EQ(visitor_.messages_received_, - (IsObjectMessage(message_type_) ? (kObjectPayloadSize + 1) : 1)); - if (IsObjectWithoutPayloadLength(message_type_)) { - EXPECT_FALSE(visitor_.end_of_message_); - parser_.ProcessData(absl::string_view(), true); // Needs the FIN - EXPECT_EQ(visitor_.messages_received_, kObjectPayloadSize + 2); - } + EXPECT_EQ(visitor_.messages_received_, 1); EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); EXPECT_TRUE(visitor_.end_of_message_); EXPECT_FALSE(visitor_.parsing_error_.has_value()); + if (IsDataStream()) { + EXPECT_EQ(visitor_.object_payload(), "foo"); + } +} + +TEST_P(MoqtParserTest, TwoBytesAtATime) { + std::unique_ptr<TestMessageBase> message = MakeMessage(); + data_parser_.set_chunk_size(1); + for (size_t i = 0; i < message->total_message_size(); i += 3) { + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_FALSE(visitor_.end_of_message_); + bool last = (i + 2) >= message->total_message_size(); + ProcessData(message->PacketSample().substr(i, 3), last); + } + EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); + EXPECT_TRUE(visitor_.end_of_message_); + EXPECT_FALSE(visitor_.parsing_error_.has_value()); + if (IsDataStream()) { + EXPECT_EQ(visitor_.object_payload(), "foo"); + } } TEST_P(MoqtParserTest, EarlyFin) { - std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); - size_t first_data_size = message->total_message_size() / 2; - if (message_type_ == MoqtMessageType::kStreamHeaderTrack) { - // The boundary happens to fall right after the stream header, so move it. - ++first_data_size; + if (message_type_ == + GeneralizedMessageType(MoqtDataStreamType::kObjectStream)) { + return; } - parser_.ProcessData(message->PacketSample().substr(0, first_data_size), true); + std::unique_ptr<TestMessageBase> message = MakeMessage(); + size_t first_data_size = message->total_message_size() - 1; + ProcessData(message->PacketSample().substr(0, first_data_size), true); EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, "FIN after incomplete message"); + EXPECT_THAT(visitor_.parsing_error_, + AnyOf("FIN after incomplete message", + "FIN received at an unexpected point in the stream")); } TEST_P(MoqtParserTest, SeparateEarlyFin) { - std::unique_ptr<TestMessageBase> message = MakeMessage(message_type_); - size_t first_data_size = message->total_message_size() / 2; - if (message_type_ == MoqtMessageType::kStreamHeaderTrack) { - // The boundary happens to fall right after the stream header, so move it. - ++first_data_size; + if (message_type_ == + GeneralizedMessageType(MoqtDataStreamType::kObjectStream)) { + return; } - parser_.ProcessData(message->PacketSample().substr(0, first_data_size), - false); - parser_.ProcessData(absl::string_view(), true); + std::unique_ptr<TestMessageBase> message = MakeMessage(); + size_t first_data_size = message->total_message_size() - 1; + ProcessData(message->PacketSample().substr(0, first_data_size), false); + ProcessData(absl::string_view(), true); EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, "End of stream before complete message"); + EXPECT_THAT(visitor_.parsing_error_, + AnyOf("End of stream before complete message", + "FIN received at an unexpected point in the stream")); EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); } @@ -349,20 +366,17 @@ TEST_F(MoqtMessageSpecificTest, ObjectStreamSeparateFin) { // OBJECT can return on an unknown-length message even without receiving a // FIN. - MoqtParser parser(kRawQuic, visitor_); + MoqtDataParser parser(&visitor_); auto message = std::make_unique<ObjectStreamMessage>(); parser.ProcessData(message->PacketSample(), false); - EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_EQ(visitor_.messages_received_, 0); EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), "foo"); + EXPECT_EQ(visitor_.object_payload(), "foo"); EXPECT_FALSE(visitor_.end_of_message_); parser.ProcessData(absl::string_view(), true); // send the FIN - EXPECT_EQ(visitor_.messages_received_, 2); + EXPECT_EQ(visitor_.messages_received_, 1); EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), ""); EXPECT_TRUE(visitor_.end_of_message_); EXPECT_FALSE(visitor_.parsing_error_.has_value()); } @@ -370,36 +384,33 @@ // Send the header + some payload, pure payload, then pure payload to end the // message. TEST_F(MoqtMessageSpecificTest, ThreePartObject) { - MoqtParser parser(kRawQuic, visitor_); + MoqtDataParser parser(&visitor_); auto message = std::make_unique<ObjectStreamMessage>(); parser.ProcessData(message->PacketSample(), false); - EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_EQ(visitor_.messages_received_, 0); EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); EXPECT_FALSE(visitor_.end_of_message_); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), "foo"); + EXPECT_EQ(visitor_.object_payload(), "foo"); // second part parser.ProcessData("bar", false); - EXPECT_EQ(visitor_.messages_received_, 2); + EXPECT_EQ(visitor_.messages_received_, 0); EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); EXPECT_FALSE(visitor_.end_of_message_); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), "bar"); + EXPECT_EQ(visitor_.object_payload(), "foobar"); // third part includes FIN parser.ProcessData("deadbeef", true); - EXPECT_EQ(visitor_.messages_received_, 3); + EXPECT_EQ(visitor_.messages_received_, 1); EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), "deadbeef"); + EXPECT_EQ(visitor_.object_payload(), "foobardeadbeef"); EXPECT_FALSE(visitor_.parsing_error_.has_value()); } // Send the part of header, rest of header + payload, plus payload. TEST_F(MoqtMessageSpecificTest, ThreePartObjectFirstIncomplete) { - MoqtParser parser(kRawQuic, visitor_); + MoqtDataParser parser(&visitor_); auto message = std::make_unique<ObjectStreamMessage>(); // first part @@ -411,70 +422,66 @@ parser.ProcessData( message->PacketSample().substr(4, message->total_message_size() - 4), false); - EXPECT_EQ(visitor_.messages_received_, 1); + EXPECT_EQ(visitor_.messages_received_, 0); EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); EXPECT_FALSE(visitor_.end_of_message_); - EXPECT_TRUE(visitor_.object_payload_.has_value()); // The value "93" is the overall wire image size of 100 minus the non-payload // part of the message. - EXPECT_EQ(visitor_.object_payload_->length(), 93); + EXPECT_EQ(visitor_.object_payload().length(), 93); // third part includes FIN parser.ProcessData("bar", true); - EXPECT_EQ(visitor_.messages_received_, 2); + EXPECT_EQ(visitor_.messages_received_, 1); EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_)); EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), "bar"); + EXPECT_EQ(*visitor_.object_payloads_.crbegin(), "bar"); EXPECT_FALSE(visitor_.parsing_error_.has_value()); } TEST_F(MoqtMessageSpecificTest, StreamHeaderGroupFollowOn) { - MoqtParser parser(kRawQuic, visitor_); + MoqtDataParser parser(&visitor_); // first part auto message1 = std::make_unique<StreamHeaderGroupMessage>(); parser.ProcessData(message1->PacketSample(), false); EXPECT_EQ(visitor_.messages_received_, 1); EXPECT_TRUE(message1->EqualFieldValues(*visitor_.last_message_)); EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), "foo"); + EXPECT_EQ(visitor_.object_payload(), "foo"); EXPECT_FALSE(visitor_.parsing_error_.has_value()); // second part + visitor_.object_payloads_.clear(); auto message2 = std::make_unique<StreamMiddlerGroupMessage>(); parser.ProcessData(message2->PacketSample(), false); EXPECT_EQ(visitor_.messages_received_, 2); EXPECT_TRUE(message2->EqualFieldValues(*visitor_.last_message_)); EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), "bar"); + EXPECT_EQ(visitor_.object_payload(), "bar"); EXPECT_FALSE(visitor_.parsing_error_.has_value()); } TEST_F(MoqtMessageSpecificTest, StreamHeaderTrackFollowOn) { - MoqtParser parser(kRawQuic, visitor_); + MoqtDataParser parser(&visitor_); // first part auto message1 = std::make_unique<StreamHeaderTrackMessage>(); parser.ProcessData(message1->PacketSample(), false); EXPECT_EQ(visitor_.messages_received_, 1); EXPECT_TRUE(message1->EqualFieldValues(*visitor_.last_message_)); EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), "foo"); + EXPECT_EQ(visitor_.object_payload(), "foo"); EXPECT_FALSE(visitor_.parsing_error_.has_value()); // second part + visitor_.object_payloads_.clear(); auto message2 = std::make_unique<StreamMiddlerTrackMessage>(); parser.ProcessData(message2->PacketSample(), false); EXPECT_EQ(visitor_.messages_received_, 2); EXPECT_TRUE(message2->EqualFieldValues(*visitor_.last_message_)); EXPECT_TRUE(visitor_.end_of_message_); - EXPECT_TRUE(visitor_.object_payload_.has_value()); - EXPECT_EQ(*(visitor_.object_payload_), "bar"); + EXPECT_EQ(visitor_.object_payload(), "bar"); EXPECT_FALSE(visitor_.parsing_error_.has_value()); } TEST_F(MoqtMessageSpecificTest, ClientSetupRoleIsInvalid) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char setup[] = { 0x40, 0x40, 0x02, 0x01, 0x02, // versions 0x03, // 3 params @@ -489,7 +496,7 @@ } TEST_F(MoqtMessageSpecificTest, ServerSetupRoleIsInvalid) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char setup[] = { 0x40, 0x41, 0x01, 0x01, // 1 param @@ -504,7 +511,7 @@ } TEST_F(MoqtMessageSpecificTest, SetupRoleAppearsTwice) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char setup[] = { 0x40, 0x40, 0x02, 0x01, 0x02, // versions 0x03, // 3 params @@ -520,7 +527,7 @@ } TEST_F(MoqtMessageSpecificTest, ClientSetupRoleIsMissing) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char setup[] = { 0x40, 0x40, 0x02, 0x01, 0x02, // versions = 1, 2 0x01, // 1 param @@ -535,7 +542,7 @@ } TEST_F(MoqtMessageSpecificTest, ServerSetupRoleIsMissing) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char setup[] = { 0x40, 0x41, 0x01, 0x00, // 1 param }; @@ -548,7 +555,7 @@ } TEST_F(MoqtMessageSpecificTest, SetupRoleVarintLengthIsWrong) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char setup[] = { 0x40, 0x40, // type 0x02, 0x01, 0x02, // versions @@ -566,7 +573,7 @@ } TEST_F(MoqtMessageSpecificTest, SetupPathFromServer) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char setup[] = { 0x40, 0x41, 0x01, // version = 1 @@ -581,7 +588,7 @@ } TEST_F(MoqtMessageSpecificTest, SetupPathAppearsTwice) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char setup[] = { 0x40, 0x40, 0x02, 0x01, 0x02, // versions = 1, 2 0x03, // 3 params @@ -598,7 +605,7 @@ } TEST_F(MoqtMessageSpecificTest, SetupPathOverWebtrans) { - MoqtParser parser(kWebTrans, visitor_); + MoqtControlParser parser(kWebTrans, visitor_); char setup[] = { 0x40, 0x40, 0x02, 0x01, 0x02, // versions = 1, 2 0x02, // 2 params @@ -614,7 +621,7 @@ } TEST_F(MoqtMessageSpecificTest, SetupPathMissing) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char setup[] = { 0x40, 0x40, 0x02, 0x01, 0x02, // versions = 1, 2 0x01, // 1 param @@ -629,7 +636,7 @@ } TEST_F(MoqtMessageSpecificTest, SubscribeAuthorizationInfoTwice) { - MoqtParser parser(kWebTrans, visitor_); + MoqtControlParser parser(kWebTrans, visitor_); char subscribe[] = { 0x03, 0x01, 0x02, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" 0x04, 0x61, 0x62, 0x63, 0x64, // track_name = "abcd" @@ -647,7 +654,7 @@ } TEST_F(MoqtMessageSpecificTest, SubscribeUpdateAuthorizationInfoTwice) { - MoqtParser parser(kWebTrans, visitor_); + MoqtControlParser parser(kWebTrans, visitor_); char subscribe_update[] = { 0x02, 0x02, 0x03, 0x01, 0x05, 0x06, // start and end sequences 0xaa, // priority = 0xaa @@ -664,7 +671,7 @@ } TEST_F(MoqtMessageSpecificTest, AnnounceAuthorizationInfoTwice) { - MoqtParser parser(kWebTrans, visitor_); + MoqtControlParser parser(kWebTrans, visitor_); char announce[] = { 0x06, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" 0x02, // 2 params @@ -680,69 +687,65 @@ } TEST_F(MoqtMessageSpecificTest, FinMidPayload) { - MoqtParser parser(kRawQuic, visitor_); + MoqtDataParser parser(&visitor_); auto message = std::make_unique<StreamHeaderGroupMessage>(); parser.ProcessData( message->PacketSample().substr(0, message->total_message_size() - 1), true); EXPECT_EQ(visitor_.messages_received_, 0); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, "Received FIN mid-payload"); + EXPECT_EQ(visitor_.parsing_error_, + "FIN received at an unexpected point in the stream"); EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); } TEST_F(MoqtMessageSpecificTest, PartialPayloadThenFin) { - MoqtParser parser(kRawQuic, visitor_); + MoqtDataParser parser(&visitor_); auto message = std::make_unique<StreamHeaderTrackMessage>(); parser.ProcessData( message->PacketSample().substr(0, message->total_message_size() - 1), false); parser.ProcessData(absl::string_view(), true); - EXPECT_EQ(visitor_.messages_received_, 1); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, - "End of stream before complete OBJECT PAYLOAD"); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_EQ(visitor_.parsing_error_, + "FIN received at an unexpected point in the stream"); EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); } TEST_F(MoqtMessageSpecificTest, DataAfterFin) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); parser.ProcessData(absl::string_view(), true); // Find FIN parser.ProcessData("foo", false); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, "Data after end of stream"); + EXPECT_EQ(visitor_.parsing_error_, "Data after end of stream"); EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); } TEST_F(MoqtMessageSpecificTest, NonNormalObjectHasPayload) { - MoqtParser parser(kRawQuic, visitor_); + MoqtDataParser parser(&visitor_); char object_stream[] = { 0x00, 0x03, 0x04, 0x05, 0x06, 0x07, 0x02, // varints 0x66, 0x6f, 0x6f, // payload = "foo" }; parser.ProcessData(absl::string_view(object_stream, sizeof(object_stream)), false); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, + EXPECT_EQ(visitor_.parsing_error_, "Object with non-normal status has payload"); EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); } TEST_F(MoqtMessageSpecificTest, InvalidObjectStatus) { - MoqtParser parser(kRawQuic, visitor_); + MoqtDataParser parser(&visitor_); char object_stream[] = { 0x00, 0x03, 0x04, 0x05, 0x06, 0x07, 0x06, // varints 0x66, 0x6f, 0x6f, // payload = "foo" }; parser.ProcessData(absl::string_view(object_stream, sizeof(object_stream)), false); - EXPECT_TRUE(visitor_.parsing_error_.has_value()); - EXPECT_EQ(*visitor_.parsing_error_, "Invalid object status"); + EXPECT_EQ(visitor_.parsing_error_, "Invalid object status"); EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); } TEST_F(MoqtMessageSpecificTest, Setup2KB) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char big_message[2 * kMaxMessageHeaderSize]; quic::QuicDataWriter writer(sizeof(big_message), big_message); writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kServerSetup)); @@ -761,7 +764,7 @@ } TEST_F(MoqtMessageSpecificTest, UnknownMessageType) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char message[4]; quic::QuicDataWriter writer(sizeof(message), message); writer.WriteVarInt62(0xbeef); // unknown message type @@ -772,7 +775,7 @@ } TEST_F(MoqtMessageSpecificTest, LatestGroup) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char subscribe[] = { 0x03, 0x01, 0x02, // id and alias 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -794,7 +797,7 @@ } TEST_F(MoqtMessageSpecificTest, LatestObject) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char subscribe[] = { 0x03, 0x01, 0x02, // id and alias 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -816,7 +819,7 @@ } TEST_F(MoqtMessageSpecificTest, InvalidDeliveryOrder) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char subscribe[] = { 0x03, 0x01, 0x02, // id and alias 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -832,7 +835,7 @@ } TEST_F(MoqtMessageSpecificTest, AbsoluteStart) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char subscribe[] = { 0x03, 0x01, 0x02, // id and alias 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -856,7 +859,7 @@ } TEST_F(MoqtMessageSpecificTest, AbsoluteRangeExplicitEndObject) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char subscribe[] = { 0x03, 0x01, 0x02, // id and alias 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -882,7 +885,7 @@ } TEST_F(MoqtMessageSpecificTest, AbsoluteRangeWholeEndGroup) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char subscribe[] = { 0x03, 0x01, 0x02, // id and alias 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -908,7 +911,7 @@ } TEST_F(MoqtMessageSpecificTest, AbsoluteRangeEndGroupTooLow) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char subscribe[] = { 0x03, 0x01, 0x02, // id and alias 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -929,7 +932,7 @@ } TEST_F(MoqtMessageSpecificTest, AbsoluteRangeExactlyOneObject) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char subscribe[] = { 0x03, 0x01, 0x02, // id and alias 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -947,7 +950,7 @@ } TEST_F(MoqtMessageSpecificTest, SubscribeUpdateExactlyOneObject) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char subscribe_update[] = { 0x02, 0x02, 0x03, 0x01, 0x04, 0x07, // start and end sequences 0x20, // priority @@ -959,7 +962,7 @@ } TEST_F(MoqtMessageSpecificTest, SubscribeUpdateEndGroupTooLow) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char subscribe_update[] = { 0x02, 0x02, 0x03, 0x01, 0x03, 0x06, // start and end sequences 0x20, // priority @@ -974,7 +977,7 @@ } TEST_F(MoqtMessageSpecificTest, AbsoluteRangeEndObjectTooLow) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char subscribe[] = { 0x03, 0x01, 0x02, // id and alias 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" @@ -995,7 +998,7 @@ } TEST_F(MoqtMessageSpecificTest, SubscribeUpdateEndObjectTooLow) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char subscribe_update[] = { 0x02, 0x02, 0x03, 0x02, 0x04, 0x01, // start and end sequences 0x01, // 1 parameter @@ -1009,7 +1012,7 @@ } TEST_F(MoqtMessageSpecificTest, SubscribeUpdateNoEndGroup) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char subscribe_update[] = { 0x02, 0x02, 0x03, 0x02, 0x00, 0x01, // start and end sequences 0x20, // priority @@ -1025,7 +1028,7 @@ } TEST_F(MoqtMessageSpecificTest, ObjectAckNegativeDelta) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); char object_ack[] = { 0x71, 0x84, // type 0x01, 0x10, 0x20, // subscribe ID, group, object @@ -1045,17 +1048,14 @@ TEST_F(MoqtMessageSpecificTest, AllMessagesTogether) { char buffer[5000]; - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); size_t write = 0; size_t read = 0; int fully_received = 0; std::unique_ptr<TestMessageBase> prev_message = nullptr; - for (MoqtMessageType type : message_types) { + for (MoqtMessageType type : kMessageTypes) { // Each iteration, process from the halfway point of one message to the // halfway point of the next. - if (IsObjectMessage(type)) { - continue; // Objects cannot share a stream with other messages. - } std::unique_ptr<TestMessageBase> message = CreateTestMessage(type, kRawQuic); memcpy(buffer + write, message->PacketSample().data(), @@ -1082,8 +1082,7 @@ TEST_F(MoqtMessageSpecificTest, DatagramSuccessful) { ObjectDatagramMessage message; MoqtObject object; - absl::string_view payload = - MoqtParser::ProcessDatagram(message.PacketSample(), object); + absl::string_view payload = ParseDatagram(message.PacketSample(), object); TestMessageBase::MessageStructuredData object_metadata = TestMessageBase::MessageStructuredData(object); EXPECT_TRUE(message.EqualFieldValues(object_metadata)); @@ -1091,35 +1090,30 @@ } TEST_F(MoqtMessageSpecificTest, WrongMessageInDatagram) { - MoqtParser parser(kRawQuic, visitor_); ObjectStreamMessage message; MoqtObject object; - absl::string_view payload = - MoqtParser::ProcessDatagram(message.PacketSample(), object); + absl::string_view payload = ParseDatagram(message.PacketSample(), object); EXPECT_TRUE(payload.empty()); } TEST_F(MoqtMessageSpecificTest, TruncatedDatagram) { - MoqtParser parser(kRawQuic, visitor_); ObjectDatagramMessage message; message.set_wire_image_size(4); MoqtObject object; - absl::string_view payload = - MoqtParser::ProcessDatagram(message.PacketSample(), object); + absl::string_view payload = ParseDatagram(message.PacketSample(), object); EXPECT_TRUE(payload.empty()); } TEST_F(MoqtMessageSpecificTest, VeryTruncatedDatagram) { - MoqtParser parser(kRawQuic, visitor_); char message = 0x40; MoqtObject object; - absl::string_view payload = MoqtParser::ProcessDatagram( - absl::string_view(&message, sizeof(message)), object); + absl::string_view payload = + ParseDatagram(absl::string_view(&message, sizeof(message)), object); EXPECT_TRUE(payload.empty()); } TEST_F(MoqtMessageSpecificTest, SubscribeOkInvalidContentExists) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); SubscribeOkMessage subscribe_ok; subscribe_ok.SetInvalidContentExists(); parser.ProcessData(subscribe_ok.PacketSample(), false); @@ -1130,7 +1124,7 @@ } TEST_F(MoqtMessageSpecificTest, SubscribeOkInvalidDeliveryOrder) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); SubscribeOkMessage subscribe_ok; subscribe_ok.SetInvalidDeliveryOrder(); parser.ProcessData(subscribe_ok.PacketSample(), false); @@ -1141,7 +1135,7 @@ } TEST_F(MoqtMessageSpecificTest, SubscribeDoneInvalidContentExists) { - MoqtParser parser(kRawQuic, visitor_); + MoqtControlParser parser(kRawQuic, visitor_); SubscribeDoneMessage subscribe_done; subscribe_done.SetInvalidContentExists(); parser.ProcessData(subscribe_done.PacketSample(), false);
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index c74bc39..afcee57 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -188,7 +188,7 @@ void MoqtSession::OnDatagramReceived(absl::string_view datagram) { MoqtObject message; - absl::string_view payload = MoqtParser::ProcessDatagram(datagram, message); + absl::string_view payload = ParseDatagram(datagram, message); if (payload.empty()) { Error(MoqtError::kProtocolViolation, "Malformed datagram"); return; @@ -533,8 +533,9 @@ track.visitor()}); } +template <class Parser> static void ForwardStreamDataToParser(webtransport::Stream& stream, - MoqtParser& parser) { + Parser& parser) { bool fin = quiche::ProcessAllReadableRegions(stream, [&](absl::string_view chunk) { parser.ProcessData(chunk, /*end_of_stream=*/false); @@ -573,13 +574,6 @@ absl::StrCat("Control stream reset with error code ", error)); } -void MoqtSession::ControlStream::OnObjectMessage(const MoqtObject& message, - absl::string_view payload, - bool end_of_message) { - session_->Error(MoqtError::kProtocolViolation, - "Received OBJECT message on control stream"); -} - void MoqtSession::ControlStream::OnClientSetupMessage( const MoqtClientSetup& message) { session_->control_stream_ = stream_->GetStreamId(); @@ -880,6 +874,11 @@ << (end_of_message ? "F" : ""); if (!session_->parameters_.deliver_partial_objects) { if (!end_of_message) { // Buffer partial object. + if (partial_object_.empty() && message.payload_length.has_value()) { + // Avoid redundant allocations by reserving the appropriate amount of + // memory if known. + partial_object_.reserve(*message.payload_length); + } absl::StrAppend(&partial_object_, payload); return; }
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index c8d29c1..d437dc3 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -176,7 +176,7 @@ friend class test::MoqtSessionPeer; class QUICHE_EXPORT ControlStream : public webtransport::StreamVisitor, - public MoqtParserVisitor { + public MoqtControlParserVisitor { public: ControlStream(MoqtSession* session, webtransport::Stream* stream); @@ -187,10 +187,7 @@ void OnStopSendingReceived(webtransport::StreamErrorCode error) override; void OnWriteSideInDataRecvdState() override {} - // MoqtParserVisitor implementation. - // TODO: Handle a stream FIN. - void OnObjectMessage(const MoqtObject& message, absl::string_view payload, - bool end_of_message) override; + // MoqtControlParserVisitor implementation. void OnClientSetupMessage(const MoqtClientSetup& message) override; void OnServerSetupMessage(const MoqtServerSetup& message) override; void OnSubscribeMessage(const MoqtSubscribe& message) override; @@ -240,15 +237,13 @@ MoqtSession* session_; webtransport::Stream* stream_; - MoqtParser parser_; + MoqtControlParser parser_; }; class QUICHE_EXPORT IncomingDataStream : public webtransport::StreamVisitor, - public MoqtParserVisitor { + public MoqtDataParserVisitor { public: IncomingDataStream(MoqtSession* session, webtransport::Stream* stream) - : session_(session), - stream_(stream), - parser_(session->parameters_.using_webtrans, *this) {} + : session_(session), stream_(stream), parser_(this) {} // webtransport::StreamVisitor implementation. void OnCanRead() override; @@ -261,58 +256,6 @@ // TODO: Handle a stream FIN. void OnObjectMessage(const MoqtObject& message, absl::string_view payload, bool end_of_message) override; - void OnClientSetupMessage(const MoqtClientSetup&) override { - OnControlMessageReceived(); - } - void OnServerSetupMessage(const MoqtServerSetup&) override { - OnControlMessageReceived(); - } - void OnSubscribeMessage(const MoqtSubscribe&) override { - OnControlMessageReceived(); - } - void OnSubscribeOkMessage(const MoqtSubscribeOk&) override { - OnControlMessageReceived(); - } - void OnSubscribeErrorMessage(const MoqtSubscribeError&) override { - OnControlMessageReceived(); - } - void OnUnsubscribeMessage(const MoqtUnsubscribe&) override { - OnControlMessageReceived(); - } - void OnSubscribeDoneMessage(const MoqtSubscribeDone&) override { - OnControlMessageReceived(); - } - void OnSubscribeUpdateMessage(const MoqtSubscribeUpdate&) override { - OnControlMessageReceived(); - } - void OnAnnounceMessage(const MoqtAnnounce&) override { - OnControlMessageReceived(); - } - void OnAnnounceOkMessage(const MoqtAnnounceOk&) override { - OnControlMessageReceived(); - } - void OnAnnounceErrorMessage(const MoqtAnnounceError&) override { - OnControlMessageReceived(); - } - void OnAnnounceCancelMessage(const MoqtAnnounceCancel& message) override { - OnControlMessageReceived(); - } - void OnTrackStatusRequestMessage( - const MoqtTrackStatusRequest& message) override { - OnControlMessageReceived(); - } - void OnUnannounceMessage(const MoqtUnannounce&) override { - OnControlMessageReceived(); - } - void OnTrackStatusMessage(const MoqtTrackStatus&) override { - OnControlMessageReceived(); - } - void OnGoAwayMessage(const MoqtGoAway&) override { - OnControlMessageReceived(); - } - void OnObjectAckMessage(const MoqtObjectAck&) override { - OnControlMessageReceived(); - } void OnParsingError(MoqtError error_code, absl::string_view reason) override; @@ -328,7 +271,7 @@ MoqtSession* session_; webtransport::Stream* stream_; - MoqtParser parser_; + MoqtDataParser parser_; std::string partial_object_; }; // Represents a record for a single subscription to a local track that is
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index a129a9d..cd7d49c 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -75,7 +75,7 @@ class MoqtSessionPeer { public: - static std::unique_ptr<MoqtParserVisitor> CreateControlStream( + static std::unique_ptr<MoqtControlParserVisitor> CreateControlStream( MoqtSession* session, webtransport::test::MockStream* stream) { auto new_stream = std::make_unique<MoqtSession::ControlStream>(session, stream); @@ -86,7 +86,7 @@ return new_stream; } - static std::unique_ptr<MoqtParserVisitor> CreateIncomingDataStream( + static std::unique_ptr<MoqtDataParserVisitor> CreateIncomingDataStream( MoqtSession* session, webtransport::Stream* stream) { auto new_stream = std::make_unique<MoqtSession::IncomingDataStream>(session, stream); @@ -100,9 +100,10 @@ // can inject packets into that stream. // This function is useful for any test that wants to inject packets on a // stream created by the MoqtSession. - static MoqtParserVisitor* FetchParserVisitorFromWebtransportStreamVisitor( + static MoqtControlParserVisitor* + FetchParserVisitorFromWebtransportStreamVisitor( MoqtSession* session, webtransport::StreamVisitor* visitor) { - return (MoqtSession::ControlStream*)visitor; + return static_cast<MoqtSession::ControlStream*>(visitor); } static void CreateRemoteTrack(MoqtSession* session, const FullTrackName& name, @@ -207,7 +208,7 @@ EXPECT_TRUE(correct_message); // Receive SERVER_SETUP - MoqtParserVisitor* stream_input = + MoqtControlParserVisitor* stream_input = MoqtSessionPeer::FetchParserVisitorFromWebtransportStreamVisitor( &session_, visitor.get()); // Handle the server setup @@ -224,7 +225,7 @@ &mock_session_, MoqtSessionParameters(quic::Perspective::IS_SERVER), session_callbacks_.AsSessionCallbacks()); webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> stream_input = + std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&server_session, &mock_stream); MoqtClientSetup setup = { /*supported_versions=*/{MoqtVersion::kDraft05}, @@ -313,7 +314,7 @@ /*parameters=*/MoqtSubscribeParameters(), }; webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> stream_input = + std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); // Request for track returns SUBSCRIBE_ERROR. bool correct_message = false; @@ -352,7 +353,7 @@ std::optional<MoqtAnnounceErrorReason> error_message)> announce_resolved_callback; webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> stream_input = + std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); bool correct_message = true; @@ -387,7 +388,7 @@ std::optional<MoqtAnnounceErrorReason> error_message)> announce_resolved_callback; webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> stream_input = + std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); bool correct_message = true; @@ -449,7 +450,7 @@ /*parameters=*/MoqtSubscribeParameters(), }; webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> stream_input = + std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); bool correct_message = true; EXPECT_CALL(mock_stream, Writev(_, _)) @@ -465,7 +466,7 @@ TEST_F(MoqtSessionTest, SubscribeWithOk) { webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> stream_input = + std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); MockRemoteTrackVisitor remote_track_visitor; EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); @@ -497,7 +498,7 @@ TEST_F(MoqtSessionTest, SubscribeWithError) { webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> stream_input = + std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); MockRemoteTrackVisitor remote_track_visitor; EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); @@ -531,7 +532,7 @@ TEST_F(MoqtSessionTest, ReplyToAnnounce) { webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> stream_input = + std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); MoqtAnnounce announce = { /*track_namespace=*/"foo", @@ -566,7 +567,7 @@ /*payload_length=*/8, }; webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> object_stream = + std::unique_ptr<MoqtDataParserVisitor> object_stream = MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream); EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _, _, _)).Times(1); @@ -591,7 +592,7 @@ /*payload_length=*/16, }; webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> object_stream = + std::unique_ptr<MoqtDataParserVisitor> object_stream = MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream); EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _, _, _)).Times(1); @@ -621,7 +622,7 @@ /*payload_length=*/16, }; webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> object_stream = + std::unique_ptr<MoqtDataParserVisitor> object_stream = MoqtSessionPeer::CreateIncomingDataStream(&session, &mock_stream); EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _, _, _)).Times(2); @@ -659,7 +660,7 @@ /*payload_length=*/8, }; webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> object_stream = + std::unique_ptr<MoqtDataParserVisitor> object_stream = MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream); EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _, _, _)) @@ -684,7 +685,7 @@ /*largest_id=*/std::nullopt, }; webtransport::test::MockStream mock_control_stream; - std::unique_ptr<MoqtParserVisitor> control_stream = + std::unique_ptr<MoqtControlParserVisitor> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream); EXPECT_CALL(visitor_, OnReply(_, _)).Times(1); control_stream->OnSubscribeOkMessage(ok); @@ -718,7 +719,7 @@ /*payload_length=*/8, }; webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> object_stream = + std::unique_ptr<MoqtDataParserVisitor> object_stream = MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream); EXPECT_CALL(visitor, OnObjectFragment(_, _, _, _, _, _, _, _)) @@ -743,7 +744,7 @@ /*track_alias =*/3, }; webtransport::test::MockStream mock_control_stream; - std::unique_ptr<MoqtParserVisitor> control_stream = + std::unique_ptr<MoqtControlParserVisitor> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream); EXPECT_CALL(mock_session_, CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation), @@ -780,7 +781,7 @@ /*payload_length=*/8, }; webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> object_stream = + std::unique_ptr<MoqtDataParserVisitor> object_stream = MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream); EXPECT_CALL(visitor, OnObjectFragment(_, _, _, _, _, _, _, _)) @@ -833,7 +834,7 @@ /*payload_length=*/8, }; webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> object_stream = + std::unique_ptr<MoqtDataParserVisitor> object_stream = MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream); EXPECT_CALL(visitor, OnObjectFragment(_, _, _, _, _, _, _, _)) @@ -863,7 +864,7 @@ /*largest_id=*/std::nullopt, }; webtransport::test::MockStream mock_control_stream; - std::unique_ptr<MoqtParserVisitor> control_stream = + std::unique_ptr<MoqtControlParserVisitor> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream); EXPECT_CALL(mock_session_, CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation), @@ -1056,7 +1057,7 @@ &mock_session_, MoqtSessionParameters(quic::Perspective::IS_SERVER), session_callbacks_.AsSessionCallbacks()); webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> stream_input = + std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&server_session, &mock_stream); MoqtClientSetup setup = { /*supported_versions*/ {MoqtVersion::kDraft05}, @@ -1097,7 +1098,7 @@ SetupPublisher(ftn, MoqtForwardingPreference::kTrack, FullSequence(4, 2)); MoqtSessionPeer::AddSubscription(&session_, track, 0, 1, 3, 4); webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> stream_input = + std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); MoqtUnsubscribe unsubscribe = { /*subscribe_id=*/0, @@ -1188,7 +1189,7 @@ /*payload_length=*/8, }; webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> object_stream = + std::unique_ptr<MoqtDataParserVisitor> object_stream = MoqtSessionPeer::CreateIncomingDataStream(&session_, &mock_stream); EXPECT_CALL(visitor_, OnObjectFragment(_, _, _, _, _, _, _, _)).Times(1); @@ -1230,7 +1231,7 @@ /*parameters=*/MoqtSubscribeParameters(), }; webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> stream_input = + std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); // Request for track returns Protocol Violation. EXPECT_CALL(mock_session_, @@ -1244,7 +1245,7 @@ TEST_F(MoqtSessionTest, AnnounceFromSubscriber) { MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kSubscriber); webtransport::test::MockStream mock_stream; - std::unique_ptr<MoqtParserVisitor> stream_input = + std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); MoqtAnnounce announce = { /*track_namespace=*/"foo",
diff --git a/quiche/quic/moqt/test_tools/moqt_test_message.h b/quiche/quic/moqt/test_tools/moqt_test_message.h index 4661197..f653b26 100644 --- a/quiche/quic/moqt/test_tools/moqt_test_message.h +++ b/quiche/quic/moqt/test_tools/moqt_test_message.h
@@ -1083,10 +1083,6 @@ static inline std::unique_ptr<TestMessageBase> CreateTestMessage( MoqtMessageType message_type, bool is_webtrans) { switch (message_type) { - case MoqtMessageType::kObjectStream: - return std::make_unique<ObjectStreamMessage>(); - case MoqtMessageType::kObjectDatagram: - return std::make_unique<ObjectDatagramMessage>(); case MoqtMessageType::kSubscribe: return std::make_unique<SubscribeMessage>(); case MoqtMessageType::kSubscribeOk: @@ -1121,15 +1117,26 @@ return std::make_unique<ClientSetupMessage>(is_webtrans); case MoqtMessageType::kServerSetup: return std::make_unique<ServerSetupMessage>(); - case MoqtMessageType::kStreamHeaderTrack: - return std::make_unique<StreamHeaderTrackMessage>(); - case MoqtMessageType::kStreamHeaderGroup: - return std::make_unique<StreamHeaderGroupMessage>(); default: return nullptr; } } +static inline std::unique_ptr<TestMessageBase> CreateTestDataStream( + MoqtDataStreamType type) { + switch (type) { + case MoqtDataStreamType::kObjectStream: + return std::make_unique<ObjectStreamMessage>(); + case MoqtDataStreamType::kObjectDatagram: + return std::make_unique<ObjectDatagramMessage>(); + case MoqtDataStreamType::kStreamHeaderTrack: + return std::make_unique<StreamHeaderTrackMessage>(); + case MoqtDataStreamType::kStreamHeaderGroup: + return std::make_unique<StreamHeaderGroupMessage>(); + } + return nullptr; +} + } // namespace moqt::test #endif // QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_TEST_MESSAGE_H_