Add test utils to parse control messages and to extract subscribe details.
Those are useful for application tests since the subscribe ID and track alias are needed to have inputs to further stages in the test.
PiperOrigin-RevId: 702836322
diff --git a/quiche/quic/moqt/test_tools/moqt_framer_utils.cc b/quiche/quic/moqt/test_tools/moqt_framer_utils.cc
index 5d0116d..8005d4e 100644
--- a/quiche/quic/moqt/test_tools/moqt_framer_utils.cc
+++ b/quiche/quic/moqt/test_tools/moqt_framer_utils.cc
@@ -5,15 +5,25 @@
#include "quiche/quic/moqt/test_tools/moqt_framer_utils.h"
#include <string>
+#include <vector>
+#include "absl/status/status.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "quiche/quic/moqt/moqt_framer.h"
#include "quiche/quic/moqt/moqt_messages.h"
+#include "quiche/quic/moqt/moqt_parser.h"
+#include "quiche/common/platform/api/quiche_test.h"
#include "quiche/common/quiche_buffer_allocator.h"
+#include "quiche/common/quiche_stream.h"
#include "quiche/common/simple_buffer_allocator.h"
namespace moqt::test {
+namespace {
+
struct TypeVisitor {
MoqtMessageType operator()(const MoqtClientSetup&) {
return MoqtMessageType::kClientSetup;
@@ -95,10 +105,6 @@
}
};
-MoqtMessageType MessageTypeForGenericMessage(const MoqtGenericFrame& frame) {
- return absl::visit(TypeVisitor(), frame);
-}
-
struct FramingVisitor {
quiche::QuicheBuffer operator()(const MoqtClientSetup& message) {
return framer.SerializeClientSetup(message);
@@ -182,10 +188,129 @@
MoqtFramer& framer;
};
+class GenericMessageParseVisitor : public MoqtControlParserVisitor {
+ public:
+ explicit GenericMessageParseVisitor(std::vector<MoqtGenericFrame>* frames)
+ : frames_(*frames) {}
+
+ void OnClientSetupMessage(const MoqtClientSetup& message) {
+ frames_.push_back(message);
+ }
+ void OnServerSetupMessage(const MoqtServerSetup& message) {
+ frames_.push_back(message);
+ }
+ void OnSubscribeMessage(const MoqtSubscribe& message) {
+ frames_.push_back(message);
+ }
+ void OnSubscribeOkMessage(const MoqtSubscribeOk& message) {
+ frames_.push_back(message);
+ }
+ void OnSubscribeErrorMessage(const MoqtSubscribeError& message) {
+ frames_.push_back(message);
+ }
+ void OnUnsubscribeMessage(const MoqtUnsubscribe& message) {
+ frames_.push_back(message);
+ }
+ void OnSubscribeDoneMessage(const MoqtSubscribeDone& message) {
+ frames_.push_back(message);
+ }
+ void OnSubscribeUpdateMessage(const MoqtSubscribeUpdate& message) {
+ frames_.push_back(message);
+ }
+ void OnAnnounceMessage(const MoqtAnnounce& message) {
+ frames_.push_back(message);
+ }
+ void OnAnnounceOkMessage(const MoqtAnnounceOk& message) {
+ frames_.push_back(message);
+ }
+ void OnAnnounceErrorMessage(const MoqtAnnounceError& message) {
+ frames_.push_back(message);
+ }
+ void OnAnnounceCancelMessage(const MoqtAnnounceCancel& message) {
+ frames_.push_back(message);
+ }
+ void OnTrackStatusRequestMessage(const MoqtTrackStatusRequest& message) {
+ frames_.push_back(message);
+ }
+ void OnUnannounceMessage(const MoqtUnannounce& message) {
+ frames_.push_back(message);
+ }
+ void OnTrackStatusMessage(const MoqtTrackStatus& message) {
+ frames_.push_back(message);
+ }
+ void OnGoAwayMessage(const MoqtGoAway& message) {
+ frames_.push_back(message);
+ }
+ void OnSubscribeAnnouncesMessage(const MoqtSubscribeAnnounces& message) {
+ frames_.push_back(message);
+ }
+ void OnSubscribeAnnouncesOkMessage(const MoqtSubscribeAnnouncesOk& message) {
+ frames_.push_back(message);
+ }
+ void OnSubscribeAnnouncesErrorMessage(
+ const MoqtSubscribeAnnouncesError& message) {
+ frames_.push_back(message);
+ }
+ void OnUnsubscribeAnnouncesMessage(const MoqtUnsubscribeAnnounces& message) {
+ frames_.push_back(message);
+ }
+ void OnMaxSubscribeIdMessage(const MoqtMaxSubscribeId& message) {
+ frames_.push_back(message);
+ }
+ void OnFetchMessage(const MoqtFetch& message) { frames_.push_back(message); }
+ void OnFetchCancelMessage(const MoqtFetchCancel& message) {
+ frames_.push_back(message);
+ }
+ void OnFetchOkMessage(const MoqtFetchOk& message) {
+ frames_.push_back(message);
+ }
+ void OnFetchErrorMessage(const MoqtFetchError& message) {
+ frames_.push_back(message);
+ }
+ void OnObjectAckMessage(const MoqtObjectAck& message) {
+ frames_.push_back(message);
+ }
+
+ void OnParsingError(MoqtError code, absl::string_view reason) {
+ ADD_FAILURE() << "Parsing failed: " << reason;
+ }
+
+ private:
+ std::vector<MoqtGenericFrame>& frames_;
+};
+
+} // namespace
+
std::string SerializeGenericMessage(const MoqtGenericFrame& frame,
bool use_webtrans) {
MoqtFramer framer(quiche::SimpleBufferAllocator::Get(), use_webtrans);
return std::string(absl::visit(FramingVisitor{framer}, frame).AsStringView());
}
+MoqtMessageType MessageTypeForGenericMessage(const MoqtGenericFrame& frame) {
+ return absl::visit(TypeVisitor(), frame);
+}
+
+std::vector<MoqtGenericFrame> ParseGenericMessage(absl::string_view body) {
+ std::vector<MoqtGenericFrame> result;
+ GenericMessageParseVisitor visitor(&result);
+ MoqtControlParser parser(/*uses_web_transport=*/true, visitor);
+ parser.ProcessData(body, /*fin=*/true);
+ return result;
+}
+
+absl::Status StoreSubscribe::operator()(
+ absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) const {
+ std::string merged_message = absl::StrJoin(data, "");
+ std::vector<MoqtGenericFrame> frames = ParseGenericMessage(merged_message);
+ if (frames.size() != 1 ||
+ !absl::holds_alternative<MoqtSubscribe>(frames[0])) {
+ ADD_FAILURE() << "Expected one SUBSCRIBE frame in a write";
+ return absl::InternalError("Expected one SUBSCRIBE frame in a write");
+ }
+ *subscribe_ = absl::get<MoqtSubscribe>(frames[0]);
+ return absl::OkStatus();
+}
+
} // namespace moqt::test
diff --git a/quiche/quic/moqt/test_tools/moqt_framer_utils.h b/quiche/quic/moqt/test_tools/moqt_framer_utils.h
index a1ad760..efbe5ca 100644
--- a/quiche/quic/moqt/test_tools/moqt_framer_utils.h
+++ b/quiche/quic/moqt/test_tools/moqt_framer_utils.h
@@ -6,13 +6,19 @@
#define QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_FRAMER_UTILS_H_
#include <cstdint>
+#include <optional>
#include <string>
+#include <vector>
+#include "absl/status/status.h"
#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "quiche/quic/moqt/moqt_messages.h"
-#include "quiche/quic/platform/api/quic_test.h"
+#include "quiche/common/platform/api/quiche_test.h"
#include "quiche/common/quiche_data_reader.h"
+#include "quiche/common/quiche_stream.h"
namespace moqt::test {
@@ -32,6 +38,9 @@
std::string SerializeGenericMessage(const MoqtGenericFrame& frame,
bool use_webtrans = false);
+// Parses a concatenation of one or more MoQT control messages.
+std::vector<MoqtGenericFrame> ParseGenericMessage(absl::string_view body);
+
MATCHER_P(SerializedControlMessage, message,
"Matches against a specific expected MoQT message") {
std::string merged_message = absl::StrJoin(arg, "");
@@ -57,6 +66,20 @@
return true;
}
+// gmock action for extracting an SUBSCRIBE message written onto a stream.
+class StoreSubscribe {
+ public:
+ explicit StoreSubscribe(std::optional<MoqtSubscribe>* subscribe)
+ : subscribe_(subscribe) {}
+
+ // quiche::WriteStream::Writev() implementation.
+ absl::Status operator()(absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) const;
+
+ private:
+ std::optional<MoqtSubscribe>* subscribe_;
+};
+
} // namespace moqt::test
#endif // QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_FRAMER_UTILS_H_