Rewrite MoqtDataParser to only consume one thing at a time.
Also fix the way we handle errors when parsing datagrams.
PiperOrigin-RevId: 703205925
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc
index a5e0da0..81ab2d8 100644
--- a/quiche/quic/moqt/moqt_parser.cc
+++ b/quiche/quic/moqt/moqt_parser.cc
@@ -4,22 +4,28 @@
#include "quiche/quic/moqt/moqt_parser.h"
+#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <optional>
#include <string>
+#include <tuple>
+#include "absl/base/casts.h"
#include "absl/cleanup/cleanup.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "quiche/quic/core/quic_data_reader.h"
#include "quiche/quic/core/quic_time.h"
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_priority.h"
#include "quiche/common/platform/api/quiche_bug_tracker.h"
#include "quiche/common/platform/api/quiche_logging.h"
+#include "quiche/common/quiche_data_reader.h"
+#include "quiche/common/quiche_stream.h"
namespace moqt {
@@ -61,81 +67,6 @@
return false;
}
-size_t ParseObjectHeader(quic::QuicDataReader& reader, MoqtObject& object,
- MoqtDataStreamType type) {
- if (!reader.ReadVarInt62(&object.track_alias)) {
- return 0;
- }
- if (type != MoqtDataStreamType::kStreamHeaderFetch &&
- !reader.ReadVarInt62(&object.group_id)) {
- return 0;
- }
- if (type == MoqtDataStreamType::kStreamHeaderSubgroup) {
- uint64_t subgroup_id;
- if (!reader.ReadVarInt62(&subgroup_id)) {
- return 0;
- }
- object.subgroup_id = subgroup_id;
- }
- if (type == MoqtDataStreamType::kObjectDatagram &&
- !reader.ReadVarInt62(&object.object_id)) {
- return 0;
- }
- if (type != MoqtDataStreamType::kStreamHeaderFetch &&
- !reader.ReadUInt8(&object.publisher_priority)) {
- return 0;
- }
- uint64_t status = static_cast<uint64_t>(MoqtObjectStatus::kNormal);
- if (type == MoqtDataStreamType::kObjectDatagram &&
- (!reader.ReadVarInt62(&object.payload_length) ||
- (object.payload_length == 0 && !reader.ReadVarInt62(&status)))) {
- return 0;
- }
- object.object_status = IntegerToObjectStatus(status);
- return reader.PreviouslyReadPayload().size();
-}
-
-size_t ParseObjectSubheader(quic::QuicDataReader& reader, MoqtObject& object,
- MoqtDataStreamType type) {
- switch (type) {
- case MoqtDataStreamType::kStreamHeaderFetch:
- if (!reader.ReadVarInt62(&object.group_id)) {
- return 0;
- }
- if (type == MoqtDataStreamType::kStreamHeaderFetch) {
- uint64_t value;
- if (!reader.ReadVarInt62(&value)) {
- return 0;
- }
- object.subgroup_id = value;
- }
- [[fallthrough]];
-
- case MoqtDataStreamType::kStreamHeaderSubgroup: {
- if (!reader.ReadVarInt62(&object.object_id)) {
- return 0;
- }
- if (type == MoqtDataStreamType::kStreamHeaderFetch &&
- !reader.ReadUInt8(&object.publisher_priority)) {
- return 0;
- }
- if (!reader.ReadVarInt62(&object.payload_length)) {
- return 0;
- }
- uint64_t status = static_cast<uint64_t>(MoqtObjectStatus::kNormal);
- if (object.payload_length == 0 && !reader.ReadVarInt62(&status)) {
- return 0;
- }
- object.object_status = IntegerToObjectStatus(status);
- return reader.PreviouslyReadPayload().size();
- }
-
- default:
- QUICHE_NOTREACHED();
- return 0;
- }
-}
-
} // namespace
// The buffering philosophy is complicated, to minimize copying. Here is an
@@ -1057,31 +988,44 @@
if (parsing_error_) {
return; // Don't send multiple parse errors.
}
+ next_input_ = kFailed;
no_more_data_ = true;
parsing_error_ = true;
visitor_.OnParsingError(MoqtError::kProtocolViolation, reason);
}
-absl::string_view ParseDatagram(absl::string_view data,
- MoqtObject& object_metadata) {
- uint64_t value;
+std::optional<absl::string_view> ParseDatagram(absl::string_view data,
+ MoqtObject& object_metadata) {
+ uint64_t type_raw, object_status_raw;
quic::QuicDataReader reader(data);
- if (!reader.ReadVarInt62(&value)) {
- return absl::string_view();
+ if (!reader.ReadVarInt62(&type_raw) ||
+ type_raw != static_cast<uint64_t>(MoqtDataStreamType::kObjectDatagram) ||
+ !reader.ReadVarInt62(&object_metadata.track_alias) ||
+ !reader.ReadVarInt62(&object_metadata.group_id) ||
+ !reader.ReadVarInt62(&object_metadata.object_id) ||
+ !reader.ReadUInt8(&object_metadata.publisher_priority) ||
+ !reader.ReadVarInt62(&object_metadata.payload_length)) {
+ return std::nullopt;
}
- if (static_cast<MoqtDataStreamType>(value) !=
- MoqtDataStreamType::kObjectDatagram) {
- return absl::string_view();
+ if (object_metadata.payload_length > 0) {
+ object_metadata.object_status = MoqtObjectStatus::kNormal;
+ } else {
+ if (!reader.ReadVarInt62(&object_status_raw)) {
+ return std::nullopt;
+ }
+ object_metadata.object_status = IntegerToObjectStatus(object_status_raw);
+ if (object_metadata.object_status ==
+ MoqtObjectStatus::kInvalidObjectStatus) {
+ return std::nullopt;
+ }
}
- size_t processed_data = ParseObjectHeader(
- reader, object_metadata, MoqtDataStreamType::kObjectDatagram);
- if (processed_data == 0) { // Incomplete header
- return absl::string_view();
+ if (reader.PeekRemainingPayload().size() != object_metadata.payload_length) {
+ return std::nullopt;
}
return reader.PeekRemainingPayload();
}
-void MoqtDataParser::ProcessData(absl::string_view data, bool fin) {
+void MoqtDataParser::ReadAllData() {
if (processing_) {
QUICHE_BUG(MoqtDataParser_reentry)
<< "Calling ProcessData() when ProcessData() is already in progress.";
@@ -1090,104 +1034,253 @@
processing_ = true;
auto on_return = absl::MakeCleanup([&] { processing_ = false; });
- if (no_more_data_) {
- ParseError("Data after end of stream");
- return;
- }
-
- // Sad path: there is already data buffered. Attempt to transfer a small
- // chunk from `data` into the buffer, in hope that it will make the contents
- // of the buffer parsable without any leftover data. This is a reasonable
- // expectation, since object headers are small, and are often followed by
- // large blobs of data.
- while (!buffered_message_.empty() && !data.empty()) {
- absl::string_view chunk = data.substr(0, chunk_size_);
- absl::StrAppend(&buffered_message_, chunk);
- absl::string_view unprocessed = ProcessDataInner(buffered_message_);
- if (unprocessed.size() >= chunk.size()) {
- // chunk didn't allow any processing at all.
- data.remove_prefix(chunk.size());
- } else {
- buffered_message_.clear();
- data.remove_prefix(chunk.size() - unprocessed.size());
+ State last_state = state();
+ for (;;) {
+ ParseNextItemFromStream();
+ if (state() == last_state || no_more_data_) {
+ break;
}
- }
-
- // Happy path: there is no buffered data.
- if (buffered_message_.empty() && !data.empty()) {
- buffered_message_.assign(ProcessDataInner(data));
- }
-
- if (fin) {
- if (!buffered_message_.empty() || !metadata_.has_value() ||
- payload_length_remaining_ > 0) {
- ParseError("FIN received at an unexpected point in the stream");
- return;
- }
- no_more_data_ = true;
+ last_state = state();
}
}
-absl::string_view MoqtDataParser::ProcessDataInner(absl::string_view data) {
- quic::QuicDataReader reader(data);
- while (!reader.IsDoneReading()) {
- absl::string_view remainder = reader.PeekRemainingPayload();
- switch (GetNextInput()) {
- case kStreamType: {
- uint64_t value;
- if (!reader.ReadVarInt62(&value)) {
- return remainder;
- }
- if (!IsAllowedStreamType(value)) {
- ParseError(absl::StrCat("Unknown stream type: ", value));
- return "";
- }
- type_ = static_cast<MoqtDataStreamType>(value);
- continue;
- }
+std::optional<uint64_t> MoqtDataParser::ReadVarInt62(bool& fin_read) {
+ fin_read = false;
- case kHeader: {
- MoqtObject header;
- size_t bytes_read = ParseObjectHeader(reader, header, *type_);
- if (bytes_read == 0) {
- return remainder;
- }
- metadata_ = header;
- continue;
- }
-
- case kSubheader: {
- size_t bytes_read = ParseObjectSubheader(reader, *metadata_, *type_);
- if (bytes_read == 0) {
- return remainder;
- }
- if (metadata_->object_status ==
- MoqtObjectStatus::kInvalidObjectStatus) {
- ParseError("Invalid object status provided");
- return "";
- }
- payload_length_remaining_ = metadata_->payload_length;
- if (payload_length_remaining_ == 0) {
- visitor_.OnObjectMessage(*metadata_, "", true);
- }
- continue;
- }
-
- case kData: {
- absl::string_view payload =
- reader.ReadAtMost(payload_length_remaining_);
- visitor_.OnObjectMessage(*metadata_, payload,
- payload.size() == payload_length_remaining_);
- payload_length_remaining_ -= payload.size();
-
- continue;
- }
-
- case kPadding:
- return "";
+ quiche::ReadStream::PeekResult peek_result = stream_.PeekNextReadableRegion();
+ if (!peek_result.has_data()) {
+ if (peek_result.fin_next) {
+ fin_read = stream_.SkipBytes(0);
+ QUICHE_DCHECK(fin_read);
}
+ return std::nullopt;
}
- return "";
+ char first_byte = peek_result.peeked_data[0];
+ size_t varint_size =
+ 1 << ((absl::bit_cast<uint8_t>(first_byte) & 0b11000000) >> 6);
+ if (stream_.ReadableBytes() < varint_size) {
+ return std::nullopt;
+ }
+
+ char buffer[8];
+ absl::Span<char> bytes_to_read =
+ absl::MakeSpan(buffer).subspan(0, varint_size);
+ quiche::ReadStream::ReadResult read_result = stream_.Read(bytes_to_read);
+ QUICHE_DCHECK_EQ(read_result.bytes_read, varint_size);
+ fin_read = read_result.fin;
+
+ quiche::QuicheDataReader reader(buffer, read_result.bytes_read);
+ uint64_t result;
+ bool success = reader.ReadVarInt62(&result);
+ QUICHE_DCHECK(success);
+ QUICHE_DCHECK(reader.IsDoneReading());
+ return result;
+}
+
+std::optional<uint64_t> MoqtDataParser::ReadVarInt62NoFin() {
+ bool fin_read = false;
+ std::optional<uint64_t> result = ReadVarInt62(fin_read);
+ if (fin_read) {
+ ParseError("Unexpected FIN received in the middle of a header");
+ return std::nullopt;
+ }
+ return result;
+}
+
+std::optional<uint8_t> MoqtDataParser::ReadUint8NoFin() {
+ char buffer[1];
+ quiche::ReadStream::ReadResult read_result =
+ stream_.Read(absl::MakeSpan(buffer));
+ if (read_result.fin) {
+ ParseError("Unexpected FIN received in the middle of a header");
+ return std::nullopt;
+ }
+ if (read_result.bytes_read == 0) {
+ return std::nullopt;
+ }
+ return absl::bit_cast<uint8_t>(buffer[0]);
+}
+
+void MoqtDataParser::AdvanceParserState() {
+ QUICHE_DCHECK(type_ == MoqtDataStreamType::kStreamHeaderSubgroup ||
+ type_ == MoqtDataStreamType::kStreamHeaderFetch);
+ const bool is_fetch = type_ == MoqtDataStreamType::kStreamHeaderFetch;
+ switch (next_input_) {
+ // The state table is factored into a separate function (rather than
+ // inlined) in order to separate the order of elements from the way they are
+ // parsed.
+ case kStreamType:
+ next_input_ = kTrackAlias;
+ break;
+ case kTrackAlias:
+ next_input_ = kGroupId;
+ break;
+ case kGroupId:
+ next_input_ = kSubgroupId;
+ break;
+ case kSubgroupId:
+ next_input_ = is_fetch ? kObjectId : kPublisherPriority;
+ break;
+ case kPublisherPriority:
+ next_input_ = is_fetch ? kObjectPayloadLength : kObjectId;
+ break;
+ case kObjectId:
+ next_input_ = is_fetch ? kPublisherPriority : kObjectPayloadLength;
+ break;
+ case kStatus:
+ case kData:
+ next_input_ = is_fetch ? kGroupId : kObjectId;
+ break;
+
+ case kObjectPayloadLength: // Either kStatus or kData depending on length.
+ case kPadding: // Handled separately.
+ case kFailed: // Should cause parsing to cease.
+ QUICHE_NOTREACHED();
+ break;
+ }
+}
+
+void MoqtDataParser::ParseNextItemFromStream() {
+ switch (next_input_) {
+ case kStreamType: {
+ std::optional<uint64_t> value_read = ReadVarInt62NoFin();
+ if (value_read.has_value()) {
+ if (!IsAllowedStreamType(*value_read)) {
+ ParseError("Invalid stream type supplied");
+ return;
+ }
+ type_ = static_cast<MoqtDataStreamType>(*value_read);
+ switch (*type_) {
+ case MoqtDataStreamType::kStreamHeaderSubgroup:
+ case MoqtDataStreamType::kStreamHeaderFetch:
+ AdvanceParserState();
+ break;
+ case MoqtDataStreamType::kPadding:
+ next_input_ = kPadding;
+ break;
+ case MoqtDataStreamType::kObjectDatagram:
+ QUICHE_BUG(ParseDataFromStream_kStreamType_unexpected);
+ return;
+ }
+ }
+ return;
+ }
+
+ case kTrackAlias: {
+ std::optional<uint64_t> value_read = ReadVarInt62NoFin();
+ if (value_read.has_value()) {
+ metadata_.track_alias = *value_read;
+ AdvanceParserState();
+ }
+ return;
+ }
+
+ case kGroupId: {
+ std::optional<uint64_t> value_read = ReadVarInt62NoFin();
+ if (value_read.has_value()) {
+ metadata_.group_id = *value_read;
+ AdvanceParserState();
+ }
+ return;
+ }
+
+ case kSubgroupId: {
+ std::optional<uint64_t> value_read = ReadVarInt62NoFin();
+ if (value_read.has_value()) {
+ metadata_.subgroup_id = *value_read;
+ AdvanceParserState();
+ }
+ return;
+ }
+
+ case kPublisherPriority: {
+ std::optional<uint8_t> value_read = ReadUint8NoFin();
+ if (value_read.has_value()) {
+ metadata_.publisher_priority = *value_read;
+ AdvanceParserState();
+ }
+ return;
+ }
+
+ case kObjectId: {
+ std::optional<uint64_t> value_read = ReadVarInt62NoFin();
+ if (value_read.has_value()) {
+ metadata_.object_id = *value_read;
+ AdvanceParserState();
+ }
+ return;
+ }
+
+ case kObjectPayloadLength: {
+ std::optional<uint64_t> value_read = ReadVarInt62NoFin();
+ if (value_read.has_value()) {
+ metadata_.payload_length = *value_read;
+ payload_length_remaining_ = *value_read;
+ if (metadata_.payload_length > 0) {
+ metadata_.object_status = MoqtObjectStatus::kNormal;
+ next_input_ = kData;
+ } else {
+ next_input_ = kStatus;
+ }
+ }
+ return;
+ }
+
+ case kStatus: {
+ bool fin_read = false;
+ std::optional<uint64_t> value_read = ReadVarInt62(fin_read);
+ if (value_read.has_value()) {
+ metadata_.object_status = IntegerToObjectStatus(*value_read);
+ if (metadata_.object_status == MoqtObjectStatus::kInvalidObjectStatus) {
+ ParseError("Invalid object status provided");
+ return;
+ }
+
+ visitor_.OnObjectMessage(metadata_, "", /*end_of_message=*/true);
+ AdvanceParserState();
+ }
+ if (fin_read) {
+ no_more_data_ = true;
+ return;
+ }
+ return;
+ }
+
+ case kData: {
+ while (payload_length_remaining_ > 0) {
+ quiche::ReadStream::PeekResult peek_result =
+ stream_.PeekNextReadableRegion();
+ if (peek_result.peeked_data.empty() && !peek_result.fin_next) {
+ return;
+ }
+ if (peek_result.fin_next &&
+ peek_result.peeked_data.size() < payload_length_remaining_) {
+ ParseError("FIN received at an unexpected point in the stream");
+ return;
+ }
+
+ size_t chunk_size =
+ std::min(payload_length_remaining_, peek_result.peeked_data.size());
+ payload_length_remaining_ -= chunk_size;
+ bool done = payload_length_remaining_ == 0;
+ visitor_.OnObjectMessage(
+ metadata_, peek_result.peeked_data.substr(0, chunk_size), done);
+ const bool fin = stream_.SkipBytes(chunk_size);
+ if (done) {
+ no_more_data_ |= fin;
+ AdvanceParserState();
+ }
+ }
+ return;
+ }
+
+ case kPadding:
+ no_more_data_ |= stream_.SkipBytes(stream_.ReadableBytes());
+ return;
+
+ case kFailed:
+ return;
+ }
}
} // namespace moqt
diff --git a/quiche/quic/moqt/moqt_parser.h b/quiche/quic/moqt/moqt_parser.h
index 34953e4..a96a70d 100644
--- a/quiche/quic/moqt/moqt_parser.h
+++ b/quiche/quic/moqt/moqt_parser.h
@@ -17,6 +17,7 @@
#include "quiche/quic/core/quic_data_reader.h"
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/common/platform/api/quiche_export.h"
+#include "quiche/common/quiche_stream.h"
namespace moqt {
@@ -174,90 +175,80 @@
// re-entrancy.
};
-// Parses an MoQT datagram. Returns the payload bytes, or empty string_view on
-// error. The caller provides the whole datagram in `data`. The function puts
-// the object metadata in `object_metadata`.
-absl::string_view ParseDatagram(absl::string_view data,
- MoqtObject& object_metadata);
+// Parses an MoQT datagram. Returns the payload bytes, or std::nullopt on error.
+// The caller provides the whole datagram in `data`. The function puts the
+// object metadata in `object_metadata`.
+std::optional<absl::string_view> ParseDatagram(absl::string_view data,
+ MoqtObject& object_metadata);
// Parser for MoQT unidirectional data stream.
class QUICHE_EXPORT MoqtDataParser {
public:
- explicit MoqtDataParser(MoqtDataParserVisitor* visitor)
- : visitor_(*visitor) {}
- ~MoqtDataParser() = default;
+ // `stream` must outlive the parser. The parser does not configure itself as
+ // a listener for the read events of the stream; it is responsibility of the
+ // caller to do so via one of the read methods below.
+ explicit MoqtDataParser(quiche::ReadStream* stream,
+ MoqtDataParserVisitor* visitor)
+ : stream_(*stream), visitor_(*visitor) {}
- // Take a buffer from the transport in |data|. Parse each complete message and
- // call the appropriate visitor function. If |fin| is true, there
- // is no more data arriving on the stream, so the parser will deliver any
- // message encoded as to run to the end of the stream.
- // All bytes can be freed. Calls OnParsingError() when there is a parsing
- // error.
- void ProcessData(absl::string_view data, bool fin);
+ // Reads all of the available objects on the stream.
+ void ReadAllData();
- // Alters `chunk_size_` value (see discussion below). Primarily intended to
- // be used for testing.
- void set_chunk_size(size_t size) { chunk_size_ = size; }
-
+ // Returns the type of the unidirectional stream, if already known.
std::optional<MoqtDataStreamType> stream_type() const { return type_; }
private:
friend class test::MoqtDataParserPeer;
- // If there is buffered data from the previous attempt at parsing it, new data
- // will be added in `chunk_size_`-sized chunks.
- constexpr static size_t kDefaultChunkSize = 64;
-
// Current state of the parser.
enum NextInput {
- // Nothing has been read yet; the next thing to be read is the stream type
- // varint.
kStreamType,
- // The next thing to be read is the stream header.
- kHeader,
- // The next thing to be read is the stream subheader for the given object.
- kSubheader,
- // The next thing to be read is the object payload.
+ kTrackAlias,
+ kGroupId,
+ kSubgroupId,
+ kPublisherPriority,
+ kObjectId,
+ kObjectPayloadLength,
+ kStatus,
kData,
- // The next thing to be read (and ignored) is padding.
kPadding,
+ kFailed,
};
+ struct State {
+ NextInput next_input;
+ uint64_t payload_remaining;
- // Infers the current state of the parser.
- NextInput GetNextInput() const {
- if (!type_.has_value()) {
- return kStreamType;
- }
- if (type_ == MoqtDataStreamType::kPadding) {
- return kPadding;
- }
- if (!metadata_.has_value()) {
- return kHeader;
- }
- if (payload_length_remaining_ > 0) {
- return kData;
- }
- return kSubheader;
- }
+ bool operator==(const State&) const = default;
+ };
+ State state() const { return State{next_input_, payload_length_remaining_}; }
- // Processes all that can be entirely processed, and returns the view for the
- // data that needs to be buffered.
- absl::string_view ProcessDataInner(absl::string_view data);
+ // Reads a single varint from the underlying stream.
+ std::optional<uint64_t> ReadVarInt62(bool& fin_read);
+ // Reads a single varint from the underlying stream. Triggers a parse error if
+ // a FIN has been encountered.
+ std::optional<uint64_t> ReadVarInt62NoFin();
+ // Reads a single uint8 from the underlying stream. Triggers a parse error if
+ // a FIN has been encountered.
+ std::optional<uint8_t> ReadUint8NoFin();
+
+ // Advances the state machine of the parser to the next expected state.
+ void AdvanceParserState();
+ // Reads the next available item from the stream.
+ void ParseNextItemFromStream();
void ParseError(absl::string_view reason);
+ quiche::ReadStream& stream_;
MoqtDataParserVisitor& visitor_;
- size_t chunk_size_ = kDefaultChunkSize;
bool no_more_data_ = false; // Fatal error or fin. No more parsing.
bool parsing_error_ = false;
std::string buffered_message_;
- // The three variables below implicitly drive the state machine; see
- // `GetNextInput()` for how the state is derived.
std::optional<MoqtDataStreamType> type_ = std::nullopt;
- std::optional<MoqtObject> metadata_ = std::nullopt;
+ NextInput next_input_ = kStreamType;
+ MoqtObject metadata_;
size_t payload_length_remaining_ = 0;
bool processing_ = false; // True if currently in ProcessData(), to prevent
diff --git a/quiche/quic/moqt/moqt_parser_test.cc b/quiche/quic/moqt/moqt_parser_test.cc
index 1c77875..a4b38bb 100644
--- a/quiche/quic/moqt/moqt_parser_test.cc
+++ b/quiche/quic/moqt/moqt_parser_test.cc
@@ -21,6 +21,7 @@
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/test_tools/moqt_test_message.h"
#include "quiche/quic/platform/api/quic_test.h"
+#include "quiche/web_transport/test_tools/in_memory_stream.h"
namespace moqt::test {
@@ -237,7 +238,8 @@
: message_type_(GetParam().message_type),
webtrans_(GetParam().uses_web_transport),
control_parser_(GetParam().uses_web_transport, visitor_),
- data_parser_(&visitor_) {}
+ data_stream_(/*stream_id=*/0),
+ data_parser_(&data_stream_, &visitor_) {}
bool IsDataStream() {
return absl::holds_alternative<MoqtDataStreamType>(message_type_);
@@ -254,7 +256,8 @@
void ProcessData(absl::string_view data, bool fin) {
if (IsDataStream()) {
- data_parser_.ProcessData(data, fin);
+ data_stream_.Receive(data, fin);
+ data_parser_.ReadAllData();
} else {
control_parser_.ProcessData(data, fin);
}
@@ -265,6 +268,7 @@
GeneralizedMessageType message_type_;
bool webtrans_;
MoqtControlParser control_parser_;
+ webtransport::test::InMemoryStream data_stream_;
MoqtDataParser data_parser_;
};
@@ -354,7 +358,6 @@
TEST_P(MoqtParserTest, TwoBytesAtATime) {
std::unique_ptr<TestMessageBase> message = MakeMessage();
- data_parser_.set_chunk_size(1);
for (size_t i = 0; i < message->total_message_size(); i += 3) {
EXPECT_EQ(visitor_.messages_received_, 0);
EXPECT_FALSE(visitor_.end_of_message_);
@@ -432,24 +435,28 @@
// Send the header + some payload, pure payload, then pure payload to end the
// message.
TEST_F(MoqtMessageSpecificTest, ThreePartObject) {
- MoqtDataParser parser(&visitor_);
+ webtransport::test::InMemoryStream stream(/*stream_id=*/0);
+ MoqtDataParser parser(&stream, &visitor_);
auto message = std::make_unique<StreamHeaderSubgroupMessage>();
EXPECT_TRUE(message->SetPayloadLength(14));
- parser.ProcessData(message->PacketSample(), false);
+ stream.Receive(message->PacketSample(), false);
+ parser.ReadAllData();
EXPECT_EQ(visitor_.messages_received_, 0);
EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_));
EXPECT_FALSE(visitor_.end_of_message_);
EXPECT_EQ(visitor_.object_payload(), "foo");
// second part
- parser.ProcessData("bar", false);
+ stream.Receive("bar", false);
+ parser.ReadAllData();
EXPECT_EQ(visitor_.messages_received_, 0);
EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_));
EXPECT_FALSE(visitor_.end_of_message_);
EXPECT_EQ(visitor_.object_payload(), "foobar");
// third part includes FIN
- parser.ProcessData("deadbeef", true);
+ stream.Receive("deadbeef", true);
+ parser.ReadAllData();
EXPECT_EQ(visitor_.messages_received_, 1);
EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_));
EXPECT_TRUE(visitor_.end_of_message_);
@@ -459,19 +466,22 @@
// Send the part of header, rest of header + payload, plus payload.
TEST_F(MoqtMessageSpecificTest, ThreePartObjectFirstIncomplete) {
- MoqtDataParser parser(&visitor_);
+ webtransport::test::InMemoryStream stream(/*stream_id=*/0);
+ MoqtDataParser parser(&stream, &visitor_);
auto message = std::make_unique<StreamHeaderSubgroupMessage>();
EXPECT_TRUE(message->SetPayloadLength(51));
// first part
- parser.ProcessData(message->PacketSample().substr(0, 4), false);
+ stream.Receive(message->PacketSample().substr(0, 4), false);
+ parser.ReadAllData();
EXPECT_EQ(visitor_.messages_received_, 0);
// second part. Add padding to it.
message->set_wire_image_size(55);
- parser.ProcessData(
+ stream.Receive(
message->PacketSample().substr(4, message->total_message_size() - 4),
false);
+ parser.ReadAllData();
EXPECT_EQ(visitor_.messages_received_, 0);
EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_));
EXPECT_FALSE(visitor_.end_of_message_);
@@ -480,7 +490,8 @@
EXPECT_EQ(visitor_.object_payload().length(), 48);
// third part includes FIN
- parser.ProcessData("bar", true);
+ stream.Receive("bar", true);
+ parser.ReadAllData();
EXPECT_EQ(visitor_.messages_received_, 1);
EXPECT_TRUE(message->EqualFieldValues(*visitor_.last_message_));
EXPECT_TRUE(visitor_.end_of_message_);
@@ -489,10 +500,12 @@
}
TEST_F(MoqtMessageSpecificTest, StreamHeaderSubgroupFollowOn) {
- MoqtDataParser parser(&visitor_);
+ webtransport::test::InMemoryStream stream(/*stream_id=*/0);
+ MoqtDataParser parser(&stream, &visitor_);
// first part
auto message1 = std::make_unique<StreamHeaderSubgroupMessage>();
- parser.ProcessData(message1->PacketSample(), false);
+ stream.Receive(message1->PacketSample(), false);
+ parser.ReadAllData();
EXPECT_EQ(visitor_.messages_received_, 1);
EXPECT_TRUE(message1->EqualFieldValues(*visitor_.last_message_));
EXPECT_TRUE(visitor_.end_of_message_);
@@ -501,7 +514,8 @@
// second part
visitor_.object_payloads_.clear();
auto message2 = std::make_unique<StreamMiddlerSubgroupMessage>();
- parser.ProcessData(message2->PacketSample(), false);
+ stream.Receive(message2->PacketSample(), false);
+ parser.ReadAllData();
EXPECT_EQ(visitor_.messages_received_, 2);
EXPECT_TRUE(message2->EqualFieldValues(*visitor_.last_message_));
EXPECT_TRUE(visitor_.end_of_message_);
@@ -857,11 +871,13 @@
}
TEST_F(MoqtMessageSpecificTest, FinMidPayload) {
- MoqtDataParser parser(&visitor_);
+ webtransport::test::InMemoryStream stream(/*stream_id=*/0);
+ MoqtDataParser parser(&stream, &visitor_);
auto message = std::make_unique<StreamHeaderSubgroupMessage>();
- parser.ProcessData(
+ stream.Receive(
message->PacketSample().substr(0, message->total_message_size() - 1),
true);
+ parser.ReadAllData();
EXPECT_EQ(visitor_.messages_received_, 0);
EXPECT_EQ(visitor_.parsing_error_,
"FIN received at an unexpected point in the stream");
@@ -869,12 +885,15 @@
}
TEST_F(MoqtMessageSpecificTest, PartialPayloadThenFin) {
- MoqtDataParser parser(&visitor_);
+ webtransport::test::InMemoryStream stream(/*stream_id=*/0);
+ MoqtDataParser parser(&stream, &visitor_);
auto message = std::make_unique<StreamHeaderSubgroupMessage>();
- parser.ProcessData(
+ stream.Receive(
message->PacketSample().substr(0, message->total_message_size() - 1),
false);
- parser.ProcessData(absl::string_view(), true);
+ parser.ReadAllData();
+ stream.Receive(absl::string_view(), true);
+ parser.ReadAllData();
EXPECT_EQ(visitor_.messages_received_, 0);
EXPECT_EQ(visitor_.parsing_error_,
"FIN received at an unexpected point in the stream");
@@ -890,16 +909,18 @@
}
TEST_F(MoqtMessageSpecificTest, InvalidObjectStatus) {
- MoqtDataParser parser(&visitor_);
+ webtransport::test::InMemoryStream stream(/*stream_id=*/0);
+ MoqtDataParser parser(&stream, &visitor_);
char stream_header_subgroup[] = {
0x04, // type field
0x04, 0x05, 0x08, // varints
0x07, // publisher priority
0x06, 0x00, 0x0f, // object middler; status = 0x0f
};
- parser.ProcessData(
+ stream.Receive(
absl::string_view(stream_header_subgroup, sizeof(stream_header_subgroup)),
false);
+ parser.ReadAllData();
EXPECT_EQ(visitor_.parsing_error_, "Invalid object status provided");
EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation);
}
@@ -1243,7 +1264,9 @@
TEST_F(MoqtMessageSpecificTest, DatagramSuccessful) {
ObjectDatagramMessage message;
MoqtObject object;
- absl::string_view payload = ParseDatagram(message.PacketSample(), object);
+ std::optional<absl::string_view> payload =
+ ParseDatagram(message.PacketSample(), object);
+ ASSERT_TRUE(payload.has_value());
TestMessageBase::MessageStructuredData object_metadata =
TestMessageBase::MessageStructuredData(object);
EXPECT_TRUE(message.EqualFieldValues(object_metadata));
@@ -1253,24 +1276,26 @@
TEST_F(MoqtMessageSpecificTest, WrongMessageInDatagram) {
StreamHeaderSubgroupMessage message;
MoqtObject object;
- absl::string_view payload = ParseDatagram(message.PacketSample(), object);
- EXPECT_TRUE(payload.empty());
+ std::optional<absl::string_view> payload =
+ ParseDatagram(message.PacketSample(), object);
+ EXPECT_EQ(payload, std::nullopt);
}
TEST_F(MoqtMessageSpecificTest, TruncatedDatagram) {
ObjectDatagramMessage message;
message.set_wire_image_size(4);
MoqtObject object;
- absl::string_view payload = ParseDatagram(message.PacketSample(), object);
- EXPECT_TRUE(payload.empty());
+ std::optional<absl::string_view> payload =
+ ParseDatagram(message.PacketSample(), object);
+ EXPECT_EQ(payload, std::nullopt);
}
TEST_F(MoqtMessageSpecificTest, VeryTruncatedDatagram) {
char message = 0x40;
MoqtObject object;
- absl::string_view payload =
+ std::optional<absl::string_view> payload =
ParseDatagram(absl::string_view(&message, sizeof(message)), object);
- EXPECT_TRUE(payload.empty());
+ EXPECT_EQ(payload, std::nullopt);
}
TEST_F(MoqtMessageSpecificTest, SubscribeOkInvalidContentExists) {
@@ -1340,13 +1365,15 @@
}
TEST_F(MoqtMessageSpecificTest, PaddingStream) {
- MoqtDataParser parser(&visitor_);
+ webtransport::test::InMemoryStream stream(/*stream_id=*/0);
+ MoqtDataParser parser(&stream, &visitor_);
std::string buffer(32, '\0');
quic::QuicDataWriter writer(buffer.size(), buffer.data());
ASSERT_TRUE(writer.WriteVarInt62(
static_cast<uint64_t>(MoqtDataStreamType::kPadding)));
for (int i = 0; i < 100; ++i) {
- parser.ProcessData(buffer, false);
+ stream.Receive(buffer, false);
+ parser.ReadAllData();
ASSERT_EQ(visitor_.messages_received_, 0);
ASSERT_EQ(visitor_.parsing_error_, std::nullopt);
}
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index 772dbff..6a06e00 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -190,14 +190,18 @@
void MoqtSession::OnDatagramReceived(absl::string_view datagram) {
MoqtObject message;
- absl::string_view payload = ParseDatagram(datagram, message);
+ std::optional<absl::string_view> payload = ParseDatagram(datagram, message);
+ if (!payload.has_value()) {
+ Error(MoqtError::kProtocolViolation, "Malformed datagram received");
+ return;
+ }
QUICHE_DLOG(INFO) << ENDPOINT
<< "Received OBJECT message in datagram for subscribe_id "
<< " for track alias " << message.track_alias
<< " with sequence " << message.group_id << ":"
<< message.object_id << " priority "
<< message.publisher_priority << " length "
- << payload.size();
+ << payload->size();
SubscribeRemoteTrack* track = RemoteTrackByAlias(message.track_alias);
if (track == nullptr) {
return;
@@ -219,7 +223,7 @@
visitor->OnObjectFragment(
track->full_track_name(),
FullSequence{message.group_id, 0, message.object_id},
- message.publisher_priority, message.object_status, payload, true);
+ message.publisher_priority, message.object_status, *payload, true);
}
}
@@ -1086,9 +1090,9 @@
<< stream_->GetStreamId() << " for track alias "
<< message.track_alias << " with sequence "
<< message.group_id << ":" << message.object_id
- << " priority " << message.publisher_priority
- << " length " << payload.size() << " length "
- << message.payload_length << (end_of_message ? "F" : "");
+ << " priority " << message.publisher_priority << " length "
+ << payload.size() << " length " << message.payload_length
+ << (end_of_message ? "F" : "");
if (!session_->parameters_.deliver_partial_objects) {
if (!end_of_message) { // Buffer partial object.
if (partial_object_.empty()) {
@@ -1143,9 +1147,7 @@
partial_object_.clear();
}
-void MoqtSession::IncomingDataStream::OnCanRead() {
- ForwardStreamDataToParser(*stream_, parser_);
-}
+void MoqtSession::IncomingDataStream::OnCanRead() { parser_.ReadAllData(); }
void MoqtSession::IncomingDataStream::OnControlMessageReceived() {
session_->Error(MoqtError::kProtocolViolation,
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
index 33063ba..7908c9e 100644
--- a/quiche/quic/moqt/moqt_session.h
+++ b/quiche/quic/moqt/moqt_session.h
@@ -262,7 +262,7 @@
public MoqtDataParserVisitor {
public:
IncomingDataStream(MoqtSession* session, webtransport::Stream* stream)
- : session_(session), stream_(stream), parser_(this) {}
+ : session_(session), stream_(stream), parser_(stream, this) {}
// webtransport::StreamVisitor implementation.
void OnCanRead() override;