Add APIs to read until a specific point in the stream is reached.
Also fix the handling of FIN without data that came up when writing tests.
PiperOrigin-RevId: 703568334
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc
index 81ab2d8..800f9eb 100644
--- a/quiche/quic/moqt/moqt_parser.cc
+++ b/quiche/quic/moqt/moqt_parser.cc
@@ -1025,7 +1025,7 @@
return reader.PeekRemainingPayload();
}
-void MoqtDataParser::ReadAllData() {
+void MoqtDataParser::ReadDataUntil(StopCondition stop_condition) {
if (processing_) {
QUICHE_BUG(MoqtDataParser_reentry)
<< "Calling ProcessData() when ProcessData() is already in progress.";
@@ -1037,7 +1037,7 @@
State last_state = state();
for (;;) {
ParseNextItemFromStream();
- if (state() == last_state || no_more_data_) {
+ if (state() == last_state || no_more_data_ || stop_condition()) {
break;
}
last_state = state();
@@ -1048,7 +1048,7 @@
fin_read = false;
quiche::ReadStream::PeekResult peek_result = stream_.PeekNextReadableRegion();
- if (!peek_result.has_data()) {
+ if (peek_result.peeked_data.empty()) {
if (peek_result.fin_next) {
fin_read = stream_.SkipBytes(0);
QUICHE_DCHECK(fin_read);
@@ -1141,6 +1141,9 @@
}
void MoqtDataParser::ParseNextItemFromStream() {
+ if (CheckForFinWithoutData()) {
+ return;
+ }
switch (next_input_) {
case kStreamType: {
std::optional<uint64_t> value_read = ReadVarInt62NoFin();
@@ -1236,6 +1239,7 @@
return;
}
+ ++num_objects_read_;
visitor_.OnObjectMessage(metadata_, "", /*end_of_message=*/true);
AdvanceParserState();
}
@@ -1250,11 +1254,10 @@
while (payload_length_remaining_ > 0) {
quiche::ReadStream::PeekResult peek_result =
stream_.PeekNextReadableRegion();
- if (peek_result.peeked_data.empty() && !peek_result.fin_next) {
+ if (!peek_result.has_data()) {
return;
}
- if (peek_result.fin_next &&
- peek_result.peeked_data.size() < payload_length_remaining_) {
+ if (peek_result.fin_next && payload_length_remaining_ > 0) {
ParseError("FIN received at an unexpected point in the stream");
return;
}
@@ -1267,6 +1270,7 @@
metadata_, peek_result.peeked_data.substr(0, chunk_size), done);
const bool fin = stream_.SkipBytes(chunk_size);
if (done) {
+ ++num_objects_read_;
no_more_data_ |= fin;
AdvanceParserState();
}
@@ -1283,4 +1287,39 @@
}
}
+void MoqtDataParser::ReadAllData() {
+ ReadDataUntil(+[]() { return false; });
+}
+
+void MoqtDataParser::ReadStreamType() {
+ return ReadDataUntil([this]() { return type_.has_value(); });
+}
+
+void MoqtDataParser::ReadTrackAlias() {
+ return ReadDataUntil(
+ [this]() { return type_.has_value() && next_input_ != kTrackAlias; });
+}
+
+void MoqtDataParser::ReadAtMostOneObject() {
+ const size_t num_objects_read_initial = num_objects_read_;
+ return ReadDataUntil(
+ [&]() { return num_objects_read_ != num_objects_read_initial; });
+}
+
+bool MoqtDataParser::CheckForFinWithoutData() {
+ if (!stream_.PeekNextReadableRegion().fin_next) {
+ return false;
+ }
+ const bool valid_state =
+ (type_ == MoqtDataStreamType::kStreamHeaderSubgroup &&
+ next_input_ == kObjectId) ||
+ (type_ == MoqtDataStreamType::kStreamHeaderFetch &&
+ next_input_ == kGroupId);
+ if (!valid_state || num_objects_read_ == 0) {
+ ParseError("FIN received at an unexpected point in the stream");
+ return true;
+ }
+ return stream_.SkipBytes(0);
+}
+
} // namespace moqt
diff --git a/quiche/quic/moqt/moqt_parser.h b/quiche/quic/moqt/moqt_parser.h
index a96a70d..1b4de80 100644
--- a/quiche/quic/moqt/moqt_parser.h
+++ b/quiche/quic/moqt/moqt_parser.h
@@ -17,6 +17,7 @@
#include "quiche/quic/core/quic_data_reader.h"
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/common/platform/api/quiche_export.h"
+#include "quiche/common/quiche_callbacks.h"
#include "quiche/common/quiche_stream.h"
namespace moqt {
@@ -194,6 +195,10 @@
// Reads all of the available objects on the stream.
void ReadAllData();
+ void ReadStreamType();
+ void ReadTrackAlias();
+ void ReadAtMostOneObject();
+
// Returns the type of the unidirectional stream, if already known.
std::optional<MoqtDataStreamType> stream_type() const { return type_; }
@@ -214,6 +219,10 @@
kPadding,
kFailed,
};
+
+ // If a StopCondition callback returns true, parsing will terminate.
+ using StopCondition = quiche::UnretainedCallback<bool()>;
+
struct State {
NextInput next_input;
uint64_t payload_remaining;
@@ -222,6 +231,8 @@
};
State state() const { return State{next_input_, payload_length_remaining_}; }
+ void ReadDataUntil(StopCondition stop_condition);
+
// Reads a single varint from the underlying stream.
std::optional<uint64_t> ReadVarInt62(bool& fin_read);
// Reads a single varint from the underlying stream. Triggers a parse error if
@@ -235,6 +246,9 @@
void AdvanceParserState();
// Reads the next available item from the stream.
void ParseNextItemFromStream();
+ // Checks if we have encountered a FIN without data. If so, processes it and
+ // returns true.
+ bool CheckForFinWithoutData();
void ParseError(absl::string_view reason);
@@ -250,6 +264,7 @@
NextInput next_input_ = kStreamType;
MoqtObject metadata_;
size_t payload_length_remaining_ = 0;
+ size_t num_objects_read_ = 0;
bool processing_ = false; // True if currently in ProcessData(), to prevent
// re-entrancy.
diff --git a/quiche/quic/moqt/moqt_parser_test.cc b/quiche/quic/moqt/moqt_parser_test.cc
index a4b38bb..d35e9ab 100644
--- a/quiche/quic/moqt/moqt_parser_test.cc
+++ b/quiche/quic/moqt/moqt_parser_test.cc
@@ -1379,4 +1379,53 @@
}
}
+class MoqtDataParserStateMachineTest : public quic::test::QuicTest {
+ protected:
+ MoqtDataParserStateMachineTest()
+ : stream_(/*stream_id=*/0), parser_(&stream_, &visitor_) {}
+
+ webtransport::test::InMemoryStream stream_;
+ MoqtParserTestVisitor visitor_;
+ MoqtDataParser parser_;
+};
+
+TEST_F(MoqtDataParserStateMachineTest, ReadAll) {
+ stream_.Receive(StreamHeaderSubgroupMessage().PacketSample());
+ stream_.Receive(StreamMiddlerSubgroupMessage().PacketSample());
+ parser_.ReadAllData();
+ ASSERT_EQ(visitor_.messages_received_, 2);
+ EXPECT_EQ(visitor_.object_payloads_[0], "foo");
+ EXPECT_EQ(visitor_.object_payloads_[1], "bar");
+ stream_.Receive("", /*fin=*/true);
+ parser_.ReadAllData();
+ EXPECT_EQ(visitor_.parsing_error_, std::nullopt);
+}
+
+TEST_F(MoqtDataParserStateMachineTest, ReadObjects) {
+ stream_.Receive(StreamHeaderSubgroupMessage().PacketSample());
+ stream_.Receive(StreamMiddlerSubgroupMessage().PacketSample(), /*fin=*/true);
+ parser_.ReadAtMostOneObject();
+ ASSERT_EQ(visitor_.messages_received_, 1);
+ EXPECT_EQ(visitor_.object_payloads_[0], "foo");
+ parser_.ReadAtMostOneObject();
+ ASSERT_EQ(visitor_.messages_received_, 2);
+ EXPECT_EQ(visitor_.object_payloads_[1], "bar");
+ EXPECT_EQ(visitor_.parsing_error_, std::nullopt);
+}
+
+TEST_F(MoqtDataParserStateMachineTest, ReadTypeThenObjects) {
+ stream_.Receive(StreamHeaderSubgroupMessage().PacketSample());
+ stream_.Receive(StreamMiddlerSubgroupMessage().PacketSample(), /*fin=*/true);
+ parser_.ReadStreamType();
+ ASSERT_EQ(visitor_.messages_received_, 0);
+ EXPECT_EQ(parser_.stream_type(), MoqtDataStreamType::kStreamHeaderSubgroup);
+ parser_.ReadAtMostOneObject();
+ ASSERT_EQ(visitor_.messages_received_, 1);
+ EXPECT_EQ(visitor_.object_payloads_[0], "foo");
+ parser_.ReadAtMostOneObject();
+ ASSERT_EQ(visitor_.messages_received_, 2);
+ EXPECT_EQ(visitor_.object_payloads_[1], "bar");
+ EXPECT_EQ(visitor_.parsing_error_, std::nullopt);
+}
+
} // namespace moqt::test
diff --git a/quiche/web_transport/test_tools/in_memory_stream.cc b/quiche/web_transport/test_tools/in_memory_stream.cc
index 99d5354..724f4f8 100644
--- a/quiche/web_transport/test_tools/in_memory_stream.cc
+++ b/quiche/web_transport/test_tools/in_memory_stream.cc
@@ -43,9 +43,7 @@
return PeekResult{"", fin_received_, fin_received_};
}
absl::string_view next_chunk = *buffer_.Chunks().begin();
- return PeekResult{next_chunk,
- fin_received_ && next_chunk.size() == buffer_.size(),
- fin_received_};
+ return PeekResult{next_chunk, false, fin_received_};
}
bool InMemoryStream::SkipBytes(size_t bytes) {