blob: be5451d853337cc32204a378d595a46a2fe9c521 [file]
// Copyright (c) 2026 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef QUICHE_QUIC_MOQT_MOQT_BIDI_STREAM_H
#define QUICHE_QUIC_MOQT_MOQT_BIDI_STREAM_H
#include <array>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <queue>
#include <utility>
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "quiche/quic/core/quic_time.h"
#include "quiche/quic/moqt/moqt_error.h"
#include "quiche/quic/moqt/moqt_framer.h"
#include "quiche/quic/moqt/moqt_key_value_pair.h"
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_parser.h"
#include "quiche/common/platform/api/quiche_bug_tracker.h"
#include "quiche/common/quiche_buffer_allocator.h"
#include "quiche/common/quiche_callbacks.h"
#include "quiche/common/quiche_mem_slice.h"
#include "quiche/common/quiche_stream.h"
#include "quiche/web_transport/web_transport.h"
namespace moqt {
// Kinds of bidirectional MoQT stream. Values are spelled out explicitly and
// match the implicit numbering (0, 1, 2).
enum class MoqtBidiStreamType : uint8_t {
  kUnknown = 0,
  kControl = 1,
  // Not supported yet. TODO(martinduke): Support this case.
  kSubscribeNamespace = 2,
};
// Single-use callback that removes the caller's record of a bidi stream.
// Whoever supplies it owns nothing inside MoqtBidiStreamBase; invoking it only
// deletes that record.
using BidiStreamDeletedCallback = quiche::SingleUseCallback<void()>;
// Single-use callback that reports a session-level error, given as an error
// code and a reason string.
using SessionErrorCallback =
    quiche::SingleUseCallback<void(MoqtError code, absl::string_view reason)>;
// A generic parser visitor that assumes all messages are invalid. Serves a base
// class for visitors that accept a subset of messages and maintains state based
// on those messages.
class MoqtBidiStreamBase : public MoqtControlParserVisitor,
public webtransport::StreamVisitor {
public:
MoqtBidiStreamBase(MoqtFramer* absl_nonnull framer,
BidiStreamDeletedCallback stream_deleted_callback,
SessionErrorCallback session_error_callback)
: framer_(framer),
stream_deleted_callback_(std::move(stream_deleted_callback)),
session_error_callback_(std::move(session_error_callback)) {}
~MoqtBidiStreamBase() override { std::move(stream_deleted_callback_)(); }
virtual void set_stream(webtransport::Stream* absl_nonnull stream) {
set_stream(stream, std::nullopt);
}
// MoqtControlParserVisitor implementation. All control messages are protocol
// violations by default.
virtual void OnClientSetupMessage(const MoqtClientSetup& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnServerSetupMessage(const MoqtServerSetup& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnRequestOkMessage(const MoqtRequestOk& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnRequestErrorMessage(const MoqtRequestError& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnSubscribeMessage(const MoqtSubscribe& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnSubscribeOkMessage(const MoqtSubscribeOk& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnUnsubscribeMessage(const MoqtUnsubscribe& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnPublishDoneMessage(const MoqtPublishDone& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnSubscribeUpdateMessage(
const MoqtSubscribeUpdate& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnPublishNamespaceMessage(
const MoqtPublishNamespace& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnPublishNamespaceDoneMessage(
const MoqtPublishNamespaceDone& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnNamespaceMessage(const MoqtNamespace& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnNamespaceDoneMessage(
const MoqtNamespaceDone& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnPublishNamespaceCancelMessage(
const MoqtPublishNamespaceCancel& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnTrackStatusMessage(const MoqtTrackStatus& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnGoAwayMessage(const MoqtGoAway& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnSubscribeNamespaceMessage(
const MoqtSubscribeNamespace& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnMaxRequestIdMessage(const MoqtMaxRequestId& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnFetchMessage(const MoqtFetch& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnFetchCancelMessage(const MoqtFetchCancel& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnFetchOkMessage(const MoqtFetchOk& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnRequestsBlockedMessage(
const MoqtRequestsBlocked& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnPublishMessage(const MoqtPublish& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnPublishOkMessage(const MoqtPublishOk& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnObjectAckMessage(const MoqtObjectAck& message) override {
OnParsingError(wrong_message_error_, wrong_message_reason_);
}
virtual void OnParsingError(MoqtError code,
absl::string_view reason) override {
std::move(session_error_callback_)(code, reason);
}
// webtransport::StreamVisitor implementation.
void OnResetStreamReceived(webtransport::StreamErrorCode error) override {}
void OnStopSendingReceived(webtransport::StreamErrorCode error) override {}
void OnWriteSideInDataRecvdState() override {}
void OnCanRead() override {
if (parser_ == nullptr) {
QUICHE_BUG(quiche_bug_moqt_parser_is_null) << "Parser is null";
return;
}
parser_->ReadAndDispatchMessages();
}
void OnCanWrite() override {
if (pending_messages_.empty() && fin_queued_) {
if (!stream_->SendFin()) {
std::move(session_error_callback_)(MoqtError::kInternalError,
"Failed to send FIN");
}
return;
}
while (!pending_messages_.empty() && stream_->CanWrite()) {
SendMessage(std::move(pending_messages_.front()),
fin_queued_ && pending_messages_.size() == 1);
pending_messages_.pop();
}
}
bool QueueIsFull() const {
return pending_messages_.size() == kMaxPendingMessages;
}
void SendOrBufferMessage(quiche::QuicheBuffer message, bool fin = false) {
if (fin_queued_) {
return;
}
if (stream_ == nullptr || !stream_->CanWrite()) {
AddToQueue(std::move(message));
return;
}
SendMessage(std::move(message), fin);
}
void SendRequestOk(uint64_t request_id, const MessageParameters& parameters,
bool fin = false) {
SendOrBufferMessage(
framer_->SerializeRequestOk(MoqtRequestOk{request_id, parameters}),
fin);
}
void SendRequestError(uint64_t request_id, RequestErrorCode error_code,
std::optional<quic::QuicTimeDelta> retry_interval,
absl::string_view reason_phrase, bool fin = false) {
MoqtRequestError request_error;
request_error.request_id = request_id;
request_error.error_code = error_code;
request_error.retry_interval = retry_interval;
request_error.reason_phrase = reason_phrase;
SendOrBufferMessage(framer_->SerializeRequestError(request_error), fin);
}
void SendRequestError(uint64_t request_id, MoqtRequestErrorInfo info,
bool fin = false) {
SendRequestError(request_id, info.error_code, info.retry_interval,
info.reason_phrase, fin);
}
void Fin() {
fin_queued_ = true;
if (pending_messages_.empty()) {
if (stream_ != nullptr && !SendFinOnStream(*stream_).ok()) {
std::move(session_error_callback_)(MoqtError::kInternalError,
"Failed to send FIN");
}
return;
}
}
void Reset(webtransport::StreamErrorCode error) {
if (stream_ != nullptr) {
stream_->ResetWithUserCode(error);
}
}
protected:
// The caller is responsible for calling stream->SetVisitor(). Derived
// classes will wrap this with a call to stream->SetPriority().
void set_stream(webtransport::Stream* absl_nonnull stream,
std::optional<MoqtMessageType> first_message_type) {
stream_ = stream;
parser_ = std::make_unique<MoqtControlParser>(framer_->using_webtrans(),
stream_, *this);
if (first_message_type.has_value()) {
parser_->set_message_type(static_cast<uint64_t>(*first_message_type));
}
}
const size_t kMaxPendingMessages = 100;
void AddToQueue(quiche::QuicheBuffer message) {
if (pending_messages_.size() == kMaxPendingMessages) {
std::move(session_error_callback_)(
MoqtError::kInternalError,
"Not enough flow credit on the control stream");
return;
}
pending_messages_.push(std::move(message));
}
MoqtFramer* absl_nonnull framer_;
MoqtControlParser* parser() { return parser_.get(); }
void OnBidiStreamDeleted() {
if (stream_deleted_callback_ != nullptr) {
std::move(stream_deleted_callback_)();
}
}
webtransport::Stream* stream() { return stream_; }
private:
void SendMessage(quiche::QuicheBuffer message, bool fin) {
quiche::StreamWriteOptions options;
options.set_send_fin(fin);
// TODO: while we buffer unconditionally, we should still at some point tear
// down the connection if we've buffered too many control messages;
// otherwise, there is potential for memory exhaustion attacks.
options.set_buffer_unconditionally(true);
std::array write_vector = {quiche::QuicheMemSlice(std::move(message))};
absl::Status success =
stream_->Writev(absl::MakeSpan(write_vector), options);
if (!success.ok()) {
std::move(session_error_callback_)(MoqtError::kInternalError,
"Failed to write a control message");
}
}
webtransport::Stream* stream_;
std::unique_ptr<MoqtControlParser> parser_;
std::queue<quiche::QuicheBuffer> pending_messages_;
bool fin_queued_ = false;
BidiStreamDeletedCallback stream_deleted_callback_;
SessionErrorCallback session_error_callback_;
const MoqtError wrong_message_error_ = MoqtError::kProtocolViolation;
const absl::string_view wrong_message_reason_ =
"Message not allowed for this stream type";
};
} // namespace moqt
#endif // QUICHE_QUIC_MOQT_MOQT_BIDI_STREAM_H