| // Copyright 2024 The Chromium Authors. All rights reserved. | 
 | // Use of this source code is governed by a BSD-style license that can be | 
 | // found in the LICENSE file. | 
 |  | 
 | #include "quiche/quic/moqt/test_tools/moqt_framer_utils.h" | 
 |  | 
 | #include <string> | 
 | #include <vector> | 
 |  | 
 | #include "absl/status/status.h" | 
 | #include "absl/strings/str_join.h" | 
 | #include "absl/strings/string_view.h" | 
 | #include "absl/types/span.h" | 
 | #include "absl/types/variant.h" | 
 | #include "quiche/quic/moqt/moqt_framer.h" | 
 | #include "quiche/quic/moqt/moqt_messages.h" | 
 | #include "quiche/quic/moqt/moqt_parser.h" | 
 | #include "quiche/common/platform/api/quiche_test.h" | 
 | #include "quiche/common/quiche_buffer_allocator.h" | 
 | #include "quiche/common/quiche_stream.h" | 
 | #include "quiche/common/simple_buffer_allocator.h" | 
 |  | 
 | namespace moqt::test { | 
 |  | 
 | namespace { | 
 |  | 
 | struct TypeVisitor { | 
 |   MoqtMessageType operator()(const MoqtClientSetup&) { | 
 |     return MoqtMessageType::kClientSetup; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtServerSetup&) { | 
 |     return MoqtMessageType::kServerSetup; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtSubscribe&) { | 
 |     return MoqtMessageType::kSubscribe; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtSubscribeOk&) { | 
 |     return MoqtMessageType::kSubscribeOk; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtSubscribeError&) { | 
 |     return MoqtMessageType::kSubscribeError; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtUnsubscribe&) { | 
 |     return MoqtMessageType::kUnsubscribe; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtSubscribeDone&) { | 
 |     return MoqtMessageType::kSubscribeDone; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtSubscribeUpdate&) { | 
 |     return MoqtMessageType::kSubscribeUpdate; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtAnnounce&) { | 
 |     return MoqtMessageType::kAnnounce; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtAnnounceOk&) { | 
 |     return MoqtMessageType::kAnnounceOk; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtAnnounceError&) { | 
 |     return MoqtMessageType::kAnnounceError; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtAnnounceCancel&) { | 
 |     return MoqtMessageType::kAnnounceCancel; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtTrackStatusRequest&) { | 
 |     return MoqtMessageType::kTrackStatusRequest; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtUnannounce&) { | 
 |     return MoqtMessageType::kUnannounce; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtTrackStatus&) { | 
 |     return MoqtMessageType::kTrackStatus; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtGoAway&) { | 
 |     return MoqtMessageType::kGoAway; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtSubscribeAnnounces&) { | 
 |     return MoqtMessageType::kSubscribeAnnounces; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtSubscribeAnnouncesOk&) { | 
 |     return MoqtMessageType::kSubscribeAnnouncesOk; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtSubscribeAnnouncesError&) { | 
 |     return MoqtMessageType::kSubscribeAnnouncesError; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtUnsubscribeAnnounces&) { | 
 |     return MoqtMessageType::kUnsubscribeAnnounces; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtMaxSubscribeId&) { | 
 |     return MoqtMessageType::kMaxSubscribeId; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtFetch&) { | 
 |     return MoqtMessageType::kFetch; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtFetchCancel&) { | 
 |     return MoqtMessageType::kFetchCancel; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtFetchOk&) { | 
 |     return MoqtMessageType::kFetchOk; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtFetchError&) { | 
 |     return MoqtMessageType::kFetchError; | 
 |   } | 
 |   MoqtMessageType operator()(const MoqtObjectAck&) { | 
 |     return MoqtMessageType::kObjectAck; | 
 |   } | 
 | }; | 
 |  | 
 | struct FramingVisitor { | 
 |   quiche::QuicheBuffer operator()(const MoqtClientSetup& message) { | 
 |     return framer.SerializeClientSetup(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtServerSetup& message) { | 
 |     return framer.SerializeServerSetup(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtSubscribe& message) { | 
 |     return framer.SerializeSubscribe(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtSubscribeOk& message) { | 
 |     return framer.SerializeSubscribeOk(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtSubscribeError& message) { | 
 |     return framer.SerializeSubscribeError(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtUnsubscribe& message) { | 
 |     return framer.SerializeUnsubscribe(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtSubscribeDone& message) { | 
 |     return framer.SerializeSubscribeDone(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtSubscribeUpdate& message) { | 
 |     return framer.SerializeSubscribeUpdate(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtAnnounce& message) { | 
 |     return framer.SerializeAnnounce(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtAnnounceOk& message) { | 
 |     return framer.SerializeAnnounceOk(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtAnnounceError& message) { | 
 |     return framer.SerializeAnnounceError(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtAnnounceCancel& message) { | 
 |     return framer.SerializeAnnounceCancel(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtTrackStatusRequest& message) { | 
 |     return framer.SerializeTrackStatusRequest(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtUnannounce& message) { | 
 |     return framer.SerializeUnannounce(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtTrackStatus& message) { | 
 |     return framer.SerializeTrackStatus(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtGoAway& message) { | 
 |     return framer.SerializeGoAway(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtSubscribeAnnounces& message) { | 
 |     return framer.SerializeSubscribeAnnounces(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtSubscribeAnnouncesOk& message) { | 
 |     return framer.SerializeSubscribeAnnouncesOk(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtSubscribeAnnouncesError& message) { | 
 |     return framer.SerializeSubscribeAnnouncesError(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtUnsubscribeAnnounces& message) { | 
 |     return framer.SerializeUnsubscribeAnnounces(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtMaxSubscribeId& message) { | 
 |     return framer.SerializeMaxSubscribeId(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtFetch& message) { | 
 |     return framer.SerializeFetch(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtFetchCancel& message) { | 
 |     return framer.SerializeFetchCancel(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtFetchOk& message) { | 
 |     return framer.SerializeFetchOk(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtFetchError& message) { | 
 |     return framer.SerializeFetchError(message); | 
 |   } | 
 |   quiche::QuicheBuffer operator()(const MoqtObjectAck& message) { | 
 |     return framer.SerializeObjectAck(message); | 
 |   } | 
 |  | 
 |   MoqtFramer& framer; | 
 | }; | 
 |  | 
 | class GenericMessageParseVisitor : public MoqtControlParserVisitor { | 
 |  public: | 
 |   explicit GenericMessageParseVisitor(std::vector<MoqtGenericFrame>* frames) | 
 |       : frames_(*frames) {} | 
 |  | 
 |   void OnClientSetupMessage(const MoqtClientSetup& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnServerSetupMessage(const MoqtServerSetup& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnSubscribeMessage(const MoqtSubscribe& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnSubscribeOkMessage(const MoqtSubscribeOk& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnSubscribeErrorMessage(const MoqtSubscribeError& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnUnsubscribeMessage(const MoqtUnsubscribe& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnSubscribeDoneMessage(const MoqtSubscribeDone& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnSubscribeUpdateMessage(const MoqtSubscribeUpdate& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnAnnounceMessage(const MoqtAnnounce& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnAnnounceOkMessage(const MoqtAnnounceOk& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnAnnounceErrorMessage(const MoqtAnnounceError& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnAnnounceCancelMessage(const MoqtAnnounceCancel& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnTrackStatusRequestMessage(const MoqtTrackStatusRequest& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnUnannounceMessage(const MoqtUnannounce& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnTrackStatusMessage(const MoqtTrackStatus& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnGoAwayMessage(const MoqtGoAway& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnSubscribeAnnouncesMessage(const MoqtSubscribeAnnounces& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnSubscribeAnnouncesOkMessage(const MoqtSubscribeAnnouncesOk& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnSubscribeAnnouncesErrorMessage( | 
 |       const MoqtSubscribeAnnouncesError& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnUnsubscribeAnnouncesMessage(const MoqtUnsubscribeAnnounces& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnMaxSubscribeIdMessage(const MoqtMaxSubscribeId& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnFetchMessage(const MoqtFetch& message) { frames_.push_back(message); } | 
 |   void OnFetchCancelMessage(const MoqtFetchCancel& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnFetchOkMessage(const MoqtFetchOk& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnFetchErrorMessage(const MoqtFetchError& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |   void OnObjectAckMessage(const MoqtObjectAck& message) { | 
 |     frames_.push_back(message); | 
 |   } | 
 |  | 
 |   void OnParsingError(MoqtError code, absl::string_view reason) { | 
 |     ADD_FAILURE() << "Parsing failed: " << reason; | 
 |   } | 
 |  | 
 |  private: | 
 |   std::vector<MoqtGenericFrame>& frames_; | 
 | }; | 
 |  | 
 | }  // namespace | 
 |  | 
 | std::string SerializeGenericMessage(const MoqtGenericFrame& frame, | 
 |                                     bool use_webtrans) { | 
 |   MoqtFramer framer(quiche::SimpleBufferAllocator::Get(), use_webtrans); | 
 |   return std::string(absl::visit(FramingVisitor{framer}, frame).AsStringView()); | 
 | } | 
 |  | 
 | MoqtMessageType MessageTypeForGenericMessage(const MoqtGenericFrame& frame) { | 
 |   return absl::visit(TypeVisitor(), frame); | 
 | } | 
 |  | 
 | std::vector<MoqtGenericFrame> ParseGenericMessage(absl::string_view body) { | 
 |   std::vector<MoqtGenericFrame> result; | 
 |   GenericMessageParseVisitor visitor(&result); | 
 |   MoqtControlParser parser(/*uses_web_transport=*/true, visitor); | 
 |   parser.ProcessData(body, /*fin=*/true); | 
 |   return result; | 
 | } | 
 |  | 
 | absl::Status StoreSubscribe::operator()( | 
 |     absl::Span<const absl::string_view> data, | 
 |     const quiche::StreamWriteOptions& options) const { | 
 |   std::string merged_message = absl::StrJoin(data, ""); | 
 |   std::vector<MoqtGenericFrame> frames = ParseGenericMessage(merged_message); | 
 |   if (frames.size() != 1 || | 
 |       !absl::holds_alternative<MoqtSubscribe>(frames[0])) { | 
 |     ADD_FAILURE() << "Expected one SUBSCRIBE frame in a write"; | 
 |     return absl::InternalError("Expected one SUBSCRIBE frame in a write"); | 
 |   } | 
 |   *subscribe_ = absl::get<MoqtSubscribe>(frames[0]); | 
 |   return absl::OkStatus(); | 
 | } | 
 |  | 
 | }  // namespace moqt::test |