Add APIs to read until a specific point in the stream is reached. Also fix the handling of FIN without data that came up when writing tests. PiperOrigin-RevId: 703568334
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc index 81ab2d8..800f9eb 100644 --- a/quiche/quic/moqt/moqt_parser.cc +++ b/quiche/quic/moqt/moqt_parser.cc
@@ -1025,7 +1025,7 @@ return reader.PeekRemainingPayload(); } -void MoqtDataParser::ReadAllData() { +void MoqtDataParser::ReadDataUntil(StopCondition stop_condition) { if (processing_) { QUICHE_BUG(MoqtDataParser_reentry) << "Calling ProcessData() when ProcessData() is already in progress."; @@ -1037,7 +1037,7 @@ State last_state = state(); for (;;) { ParseNextItemFromStream(); - if (state() == last_state || no_more_data_) { + if (state() == last_state || no_more_data_ || stop_condition()) { break; } last_state = state(); @@ -1048,7 +1048,7 @@ fin_read = false; quiche::ReadStream::PeekResult peek_result = stream_.PeekNextReadableRegion(); - if (!peek_result.has_data()) { + if (peek_result.peeked_data.empty()) { if (peek_result.fin_next) { fin_read = stream_.SkipBytes(0); QUICHE_DCHECK(fin_read); @@ -1141,6 +1141,9 @@ } void MoqtDataParser::ParseNextItemFromStream() { + if (CheckForFinWithoutData()) { + return; + } switch (next_input_) { case kStreamType: { std::optional<uint64_t> value_read = ReadVarInt62NoFin(); @@ -1236,6 +1239,7 @@ return; } + ++num_objects_read_; visitor_.OnObjectMessage(metadata_, "", /*end_of_message=*/true); AdvanceParserState(); } @@ -1250,11 +1254,10 @@ while (payload_length_remaining_ > 0) { quiche::ReadStream::PeekResult peek_result = stream_.PeekNextReadableRegion(); - if (peek_result.peeked_data.empty() && !peek_result.fin_next) { + if (!peek_result.has_data()) { return; } - if (peek_result.fin_next && - peek_result.peeked_data.size() < payload_length_remaining_) { + if (peek_result.fin_next && payload_length_remaining_ > 0) { ParseError("FIN received at an unexpected point in the stream"); return; } @@ -1267,6 +1270,7 @@ metadata_, peek_result.peeked_data.substr(0, chunk_size), done); const bool fin = stream_.SkipBytes(chunk_size); if (done) { + ++num_objects_read_; no_more_data_ |= fin; AdvanceParserState(); } @@ -1283,4 +1287,39 @@ } } +void MoqtDataParser::ReadAllData() { + ReadDataUntil(+[]() { return false; }); +} + +void MoqtDataParser::ReadStreamType() { + return ReadDataUntil([this]() { return type_.has_value(); }); +} + +void MoqtDataParser::ReadTrackAlias() { + return ReadDataUntil( + [this]() { return type_.has_value() && next_input_ != kTrackAlias; }); +} + +void MoqtDataParser::ReadAtMostOneObject() { + const size_t num_objects_read_initial = num_objects_read_; + return ReadDataUntil( + [&]() { return num_objects_read_ != num_objects_read_initial; }); +} + +bool MoqtDataParser::CheckForFinWithoutData() { + if (!stream_.PeekNextReadableRegion().fin_next) { + return false; + } + const bool valid_state = + (type_ == MoqtDataStreamType::kStreamHeaderSubgroup && + next_input_ == kObjectId) || + (type_ == MoqtDataStreamType::kStreamHeaderFetch && + next_input_ == kGroupId); + if (!valid_state || num_objects_read_ == 0) { + ParseError("FIN received at an unexpected point in the stream"); + return true; + } + return stream_.SkipBytes(0); +} + } // namespace moqt
diff --git a/quiche/quic/moqt/moqt_parser.h b/quiche/quic/moqt/moqt_parser.h index a96a70d..1b4de80 100644 --- a/quiche/quic/moqt/moqt_parser.h +++ b/quiche/quic/moqt/moqt_parser.h
@@ -17,6 +17,7 @@ #include "quiche/quic/core/quic_data_reader.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_callbacks.h" #include "quiche/common/quiche_stream.h" namespace moqt { @@ -194,6 +195,10 @@ // Reads all of the available objects on the stream. void ReadAllData(); + void ReadStreamType(); + void ReadTrackAlias(); + void ReadAtMostOneObject(); + // Returns the type of the unidirectional stream, if already known. std::optional<MoqtDataStreamType> stream_type() const { return type_; } @@ -214,6 +219,10 @@ kPadding, kFailed, }; + + // If a StopCondition callback returns true, parsing will terminate. + using StopCondition = quiche::UnretainedCallback<bool()>; + struct State { NextInput next_input; uint64_t payload_remaining; @@ -222,6 +231,8 @@ }; State state() const { return State{next_input_, payload_length_remaining_}; } + void ReadDataUntil(StopCondition stop_condition); + // Reads a single varint from the underlying stream. std::optional<uint64_t> ReadVarInt62(bool& fin_read); // Reads a single varint from the underlying stream. Triggers a parse error if @@ -235,6 +246,9 @@ void AdvanceParserState(); // Reads the next available item from the stream. void ParseNextItemFromStream(); + // Checks if we have encountered a FIN without data. If so, processes it and + // returns true. + bool CheckForFinWithoutData(); void ParseError(absl::string_view reason); @@ -250,6 +264,7 @@ NextInput next_input_ = kStreamType; MoqtObject metadata_; size_t payload_length_remaining_ = 0; + size_t num_objects_read_ = 0; bool processing_ = false; // True if currently in ProcessData(), to prevent // re-entrancy.
diff --git a/quiche/quic/moqt/moqt_parser_test.cc b/quiche/quic/moqt/moqt_parser_test.cc index a4b38bb..d35e9ab 100644 --- a/quiche/quic/moqt/moqt_parser_test.cc +++ b/quiche/quic/moqt/moqt_parser_test.cc
@@ -1379,4 +1379,53 @@ } } +class MoqtDataParserStateMachineTest : public quic::test::QuicTest { + protected: + MoqtDataParserStateMachineTest() + : stream_(/*stream_id=*/0), parser_(&stream_, &visitor_) {} + + webtransport::test::InMemoryStream stream_; + MoqtParserTestVisitor visitor_; + MoqtDataParser parser_; +}; + +TEST_F(MoqtDataParserStateMachineTest, ReadAll) { + stream_.Receive(StreamHeaderSubgroupMessage().PacketSample()); + stream_.Receive(StreamMiddlerSubgroupMessage().PacketSample()); + parser_.ReadAllData(); + ASSERT_EQ(visitor_.messages_received_, 2); + EXPECT_EQ(visitor_.object_payloads_[0], "foo"); + EXPECT_EQ(visitor_.object_payloads_[1], "bar"); + stream_.Receive("", /*fin=*/true); + parser_.ReadAllData(); + EXPECT_EQ(visitor_.parsing_error_, std::nullopt); +} + +TEST_F(MoqtDataParserStateMachineTest, ReadObjects) { + stream_.Receive(StreamHeaderSubgroupMessage().PacketSample()); + stream_.Receive(StreamMiddlerSubgroupMessage().PacketSample(), /*fin=*/true); + parser_.ReadAtMostOneObject(); + ASSERT_EQ(visitor_.messages_received_, 1); + EXPECT_EQ(visitor_.object_payloads_[0], "foo"); + parser_.ReadAtMostOneObject(); + ASSERT_EQ(visitor_.messages_received_, 2); + EXPECT_EQ(visitor_.object_payloads_[1], "bar"); + EXPECT_EQ(visitor_.parsing_error_, std::nullopt); +} + +TEST_F(MoqtDataParserStateMachineTest, ReadTypeThenObjects) { + stream_.Receive(StreamHeaderSubgroupMessage().PacketSample()); + stream_.Receive(StreamMiddlerSubgroupMessage().PacketSample(), /*fin=*/true); + parser_.ReadStreamType(); + ASSERT_EQ(visitor_.messages_received_, 0); + EXPECT_EQ(parser_.stream_type(), MoqtDataStreamType::kStreamHeaderSubgroup); + parser_.ReadAtMostOneObject(); + ASSERT_EQ(visitor_.messages_received_, 1); + EXPECT_EQ(visitor_.object_payloads_[0], "foo"); + parser_.ReadAtMostOneObject(); + ASSERT_EQ(visitor_.messages_received_, 2); + EXPECT_EQ(visitor_.object_payloads_[1], "bar"); + EXPECT_EQ(visitor_.parsing_error_, std::nullopt); +} + } // namespace moqt::test
diff --git a/quiche/web_transport/test_tools/in_memory_stream.cc b/quiche/web_transport/test_tools/in_memory_stream.cc index 99d5354..724f4f8 100644 --- a/quiche/web_transport/test_tools/in_memory_stream.cc +++ b/quiche/web_transport/test_tools/in_memory_stream.cc
@@ -43,9 +43,7 @@ return PeekResult{"", fin_received_, fin_received_}; } absl::string_view next_chunk = *buffer_.Chunks().begin(); - return PeekResult{next_chunk, - fin_received_ && next_chunk.size() == buffer_.size(), - fin_received_}; + return PeekResult{next_chunk, false, fin_received_}; } bool InMemoryStream::SkipBytes(size_t bytes) {