// Copyright 2024 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "quiche/quic/moqt/test_tools/moqt_framer_utils.h"

#include <string>
#include <variant>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "quiche/quic/moqt/moqt_framer.h"
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_parser.h"
#include "quiche/common/platform/api/quiche_test.h"
#include "quiche/common/quiche_buffer_allocator.h"
#include "quiche/common/quiche_stream.h"
#include "quiche/common/simple_buffer_allocator.h"
#include "quiche/web_transport/test_tools/in_memory_stream.h"

namespace moqt::test {

namespace {

struct TypeVisitor {
  MoqtMessageType operator()(const MoqtClientSetup&) {
    return MoqtMessageType::kClientSetup;
  }
  MoqtMessageType operator()(const MoqtServerSetup&) {
    return MoqtMessageType::kServerSetup;
  }
  MoqtMessageType operator()(const MoqtSubscribe&) {
    return MoqtMessageType::kSubscribe;
  }
  MoqtMessageType operator()(const MoqtSubscribeOk&) {
    return MoqtMessageType::kSubscribeOk;
  }
  MoqtMessageType operator()(const MoqtSubscribeError&) {
    return MoqtMessageType::kSubscribeError;
  }
  MoqtMessageType operator()(const MoqtUnsubscribe&) {
    return MoqtMessageType::kUnsubscribe;
  }
  MoqtMessageType operator()(const MoqtSubscribeDone&) {
    return MoqtMessageType::kSubscribeDone;
  }
  MoqtMessageType operator()(const MoqtSubscribeUpdate&) {
    return MoqtMessageType::kSubscribeUpdate;
  }
  MoqtMessageType operator()(const MoqtAnnounce&) {
    return MoqtMessageType::kAnnounce;
  }
  MoqtMessageType operator()(const MoqtAnnounceOk&) {
    return MoqtMessageType::kAnnounceOk;
  }
  MoqtMessageType operator()(const MoqtAnnounceError&) {
    return MoqtMessageType::kAnnounceError;
  }
  MoqtMessageType operator()(const MoqtAnnounceCancel&) {
    return MoqtMessageType::kAnnounceCancel;
  }
  MoqtMessageType operator()(const MoqtTrackStatusRequest&) {
    return MoqtMessageType::kTrackStatusRequest;
  }
  MoqtMessageType operator()(const MoqtUnannounce&) {
    return MoqtMessageType::kUnannounce;
  }
  MoqtMessageType operator()(const MoqtTrackStatus&) {
    return MoqtMessageType::kTrackStatus;
  }
  MoqtMessageType operator()(const MoqtGoAway&) {
    return MoqtMessageType::kGoAway;
  }
  MoqtMessageType operator()(const MoqtSubscribeAnnounces&) {
    return MoqtMessageType::kSubscribeAnnounces;
  }
  MoqtMessageType operator()(const MoqtSubscribeAnnouncesOk&) {
    return MoqtMessageType::kSubscribeAnnouncesOk;
  }
  MoqtMessageType operator()(const MoqtSubscribeAnnouncesError&) {
    return MoqtMessageType::kSubscribeAnnouncesError;
  }
  MoqtMessageType operator()(const MoqtUnsubscribeAnnounces&) {
    return MoqtMessageType::kUnsubscribeAnnounces;
  }
  MoqtMessageType operator()(const MoqtMaxRequestId&) {
    return MoqtMessageType::kMaxRequestId;
  }
  MoqtMessageType operator()(const MoqtFetch&) {
    return MoqtMessageType::kFetch;
  }
  MoqtMessageType operator()(const MoqtFetchCancel&) {
    return MoqtMessageType::kFetchCancel;
  }
  MoqtMessageType operator()(const MoqtFetchOk&) {
    return MoqtMessageType::kFetchOk;
  }
  MoqtMessageType operator()(const MoqtFetchError&) {
    return MoqtMessageType::kFetchError;
  }
  MoqtMessageType operator()(const MoqtSubscribesBlocked&) {
    return MoqtMessageType::kSubscribesBlocked;
  }
  MoqtMessageType operator()(const MoqtObjectAck&) {
    return MoqtMessageType::kObjectAck;
  }
};

struct FramingVisitor {
  quiche::QuicheBuffer operator()(const MoqtClientSetup& message) {
    return framer.SerializeClientSetup(message);
  }
  quiche::QuicheBuffer operator()(const MoqtServerSetup& message) {
    return framer.SerializeServerSetup(message);
  }
  quiche::QuicheBuffer operator()(const MoqtSubscribe& message) {
    return framer.SerializeSubscribe(message);
  }
  quiche::QuicheBuffer operator()(const MoqtSubscribeOk& message) {
    return framer.SerializeSubscribeOk(message);
  }
  quiche::QuicheBuffer operator()(const MoqtSubscribeError& message) {
    return framer.SerializeSubscribeError(message);
  }
  quiche::QuicheBuffer operator()(const MoqtUnsubscribe& message) {
    return framer.SerializeUnsubscribe(message);
  }
  quiche::QuicheBuffer operator()(const MoqtSubscribeDone& message) {
    return framer.SerializeSubscribeDone(message);
  }
  quiche::QuicheBuffer operator()(const MoqtSubscribeUpdate& message) {
    return framer.SerializeSubscribeUpdate(message);
  }
  quiche::QuicheBuffer operator()(const MoqtAnnounce& message) {
    return framer.SerializeAnnounce(message);
  }
  quiche::QuicheBuffer operator()(const MoqtAnnounceOk& message) {
    return framer.SerializeAnnounceOk(message);
  }
  quiche::QuicheBuffer operator()(const MoqtAnnounceError& message) {
    return framer.SerializeAnnounceError(message);
  }
  quiche::QuicheBuffer operator()(const MoqtAnnounceCancel& message) {
    return framer.SerializeAnnounceCancel(message);
  }
  quiche::QuicheBuffer operator()(const MoqtTrackStatusRequest& message) {
    return framer.SerializeTrackStatusRequest(message);
  }
  quiche::QuicheBuffer operator()(const MoqtUnannounce& message) {
    return framer.SerializeUnannounce(message);
  }
  quiche::QuicheBuffer operator()(const MoqtTrackStatus& message) {
    return framer.SerializeTrackStatus(message);
  }
  quiche::QuicheBuffer operator()(const MoqtGoAway& message) {
    return framer.SerializeGoAway(message);
  }
  quiche::QuicheBuffer operator()(const MoqtSubscribeAnnounces& message) {
    return framer.SerializeSubscribeAnnounces(message);
  }
  quiche::QuicheBuffer operator()(const MoqtSubscribeAnnouncesOk& message) {
    return framer.SerializeSubscribeAnnouncesOk(message);
  }
  quiche::QuicheBuffer operator()(const MoqtSubscribeAnnouncesError& message) {
    return framer.SerializeSubscribeAnnouncesError(message);
  }
  quiche::QuicheBuffer operator()(const MoqtUnsubscribeAnnounces& message) {
    return framer.SerializeUnsubscribeAnnounces(message);
  }
  quiche::QuicheBuffer operator()(const MoqtMaxRequestId& message) {
    return framer.SerializeMaxRequestId(message);
  }
  quiche::QuicheBuffer operator()(const MoqtFetch& message) {
    return framer.SerializeFetch(message);
  }
  quiche::QuicheBuffer operator()(const MoqtFetchCancel& message) {
    return framer.SerializeFetchCancel(message);
  }
  quiche::QuicheBuffer operator()(const MoqtFetchOk& message) {
    return framer.SerializeFetchOk(message);
  }
  quiche::QuicheBuffer operator()(const MoqtFetchError& message) {
    return framer.SerializeFetchError(message);
  }
  quiche::QuicheBuffer operator()(const MoqtSubscribesBlocked& message) {
    return framer.SerializeSubscribesBlocked(message);
  }
  quiche::QuicheBuffer operator()(const MoqtObjectAck& message) {
    return framer.SerializeObjectAck(message);
  }

  MoqtFramer& framer;
};

class GenericMessageParseVisitor : public MoqtControlParserVisitor {
 public:
  explicit GenericMessageParseVisitor(std::vector<MoqtGenericFrame>* frames)
      : frames_(*frames) {}

  void OnClientSetupMessage(const MoqtClientSetup& message) {
    frames_.push_back(message);
  }
  void OnServerSetupMessage(const MoqtServerSetup& message) {
    frames_.push_back(message);
  }
  void OnSubscribeMessage(const MoqtSubscribe& message) {
    frames_.push_back(message);
  }
  void OnSubscribeOkMessage(const MoqtSubscribeOk& message) {
    frames_.push_back(message);
  }
  void OnSubscribeErrorMessage(const MoqtSubscribeError& message) {
    frames_.push_back(message);
  }
  void OnUnsubscribeMessage(const MoqtUnsubscribe& message) {
    frames_.push_back(message);
  }
  void OnSubscribeDoneMessage(const MoqtSubscribeDone& message) {
    frames_.push_back(message);
  }
  void OnSubscribeUpdateMessage(const MoqtSubscribeUpdate& message) {
    frames_.push_back(message);
  }
  void OnAnnounceMessage(const MoqtAnnounce& message) {
    frames_.push_back(message);
  }
  void OnAnnounceOkMessage(const MoqtAnnounceOk& message) {
    frames_.push_back(message);
  }
  void OnAnnounceErrorMessage(const MoqtAnnounceError& message) {
    frames_.push_back(message);
  }
  void OnAnnounceCancelMessage(const MoqtAnnounceCancel& message) {
    frames_.push_back(message);
  }
  void OnTrackStatusRequestMessage(const MoqtTrackStatusRequest& message) {
    frames_.push_back(message);
  }
  void OnUnannounceMessage(const MoqtUnannounce& message) {
    frames_.push_back(message);
  }
  void OnTrackStatusMessage(const MoqtTrackStatus& message) {
    frames_.push_back(message);
  }
  void OnGoAwayMessage(const MoqtGoAway& message) {
    frames_.push_back(message);
  }
  void OnSubscribeAnnouncesMessage(const MoqtSubscribeAnnounces& message) {
    frames_.push_back(message);
  }
  void OnSubscribeAnnouncesOkMessage(const MoqtSubscribeAnnouncesOk& message) {
    frames_.push_back(message);
  }
  void OnSubscribeAnnouncesErrorMessage(
      const MoqtSubscribeAnnouncesError& message) {
    frames_.push_back(message);
  }
  void OnUnsubscribeAnnouncesMessage(const MoqtUnsubscribeAnnounces& message) {
    frames_.push_back(message);
  }
  void OnMaxRequestIdMessage(const MoqtMaxRequestId& message) {
    frames_.push_back(message);
  }
  void OnFetchMessage(const MoqtFetch& message) { frames_.push_back(message); }
  void OnFetchCancelMessage(const MoqtFetchCancel& message) {
    frames_.push_back(message);
  }
  void OnFetchOkMessage(const MoqtFetchOk& message) {
    frames_.push_back(message);
  }
  void OnFetchErrorMessage(const MoqtFetchError& message) {
    frames_.push_back(message);
  }
  void OnSubscribesBlockedMessage(const MoqtSubscribesBlocked& message) {
    frames_.push_back(message);
  }
  void OnObjectAckMessage(const MoqtObjectAck& message) {
    frames_.push_back(message);
  }

  void OnParsingError(MoqtError code, absl::string_view reason) {
    ADD_FAILURE() << "Parsing failed: " << reason;
  }

 private:
  std::vector<MoqtGenericFrame>& frames_;
};

}  // namespace

std::string SerializeGenericMessage(const MoqtGenericFrame& frame,
                                    bool use_webtrans) {
  MoqtFramer framer(quiche::SimpleBufferAllocator::Get(), use_webtrans);
  return std::string(std::visit(FramingVisitor{framer}, frame).AsStringView());
}

// Returns the MoqtMessageType enumerator matching the alternative currently
// held by `frame`.
MoqtMessageType MessageTypeForGenericMessage(const MoqtGenericFrame& frame) {
  TypeVisitor type_lookup;
  return std::visit(type_lookup, frame);
}

// Parses `body` as a sequence of MoQT control messages and returns the frames
// delivered to the visitor, in delivery order. The bytes are fed through an
// in-memory stream (without FIN) to a MoqtControlParser constructed with
// uses_web_transport=true. Parsing errors surface as ADD_FAILURE() via
// GenericMessageParseVisitor rather than through the return value.
std::vector<MoqtGenericFrame> ParseGenericMessage(absl::string_view body) {
  std::vector<MoqtGenericFrame> parsed_frames;
  GenericMessageParseVisitor collector(&parsed_frames);

  webtransport::test::InMemoryStream input(/*id=*/0);
  MoqtControlParser parser(/*uses_web_transport=*/true, &input, collector);

  // Make the bytes available on the stream first, then let the parser drain
  // them.
  input.Receive(body, /*fin=*/false);
  parser.ReadAndDispatchMessages();

  return parsed_frames;
}

// Concatenates the pieces of `data`, parses the result as MoQT control
// messages, and expects exactly one SUBSCRIBE frame. On success the parsed
// message is copied into `*subscribe_` and OkStatus is returned; otherwise a
// test failure is recorded and an InternalError is returned. `options` is not
// consulted.
absl::Status StoreSubscribe::operator()(
    absl::Span<const absl::string_view> data,
    const quiche::StreamWriteOptions& options) const {
  static constexpr char kUnexpectedWrite[] =
      "Expected one SUBSCRIBE frame in a write";

  std::string wire_bytes = absl::StrJoin(data, "");
  std::vector<MoqtGenericFrame> parsed = ParseGenericMessage(wire_bytes);

  // Null unless the write held exactly one frame and that frame is a
  // SUBSCRIBE.
  const MoqtSubscribe* parsed_subscribe =
      parsed.size() == 1 ? std::get_if<MoqtSubscribe>(&parsed.front())
                         : nullptr;
  if (parsed_subscribe == nullptr) {
    ADD_FAILURE() << kUnexpectedWrite;
    return absl::InternalError(kUnexpectedWrite);
  }

  *subscribe_ = *parsed_subscribe;
  return absl::OkStatus();
}

}  // namespace moqt::test
