Use new MOQT control message parser API directly. In particular, message types supported by a stream are now inferred automatically; there is no need to manually create a stub for every message. Error handling is modified to return errors from methods instead of passing them to a callback when possible. In previous iterations, we had a lot of complications where an error caused the connection to be destroyed, deleting some of the state that was still accessed up the call stack. The new error handling pattern aims to minimize the call stack at the point where we delete the connection. PiperOrigin-RevId: 914700657
diff --git a/build/source_list.bzl b/build/source_list.bzl index d741557..b7ca694 100644 --- a/build/source_list.bzl +++ b/build/source_list.bzl
@@ -1617,6 +1617,7 @@ "quic/moqt/tools/moqt_server.h", ] moqt_srcs = [ + "quic/moqt/moqt_bidi_stream.cc", "quic/moqt/moqt_bitrate_adjuster.cc", "quic/moqt/moqt_error.cc", "quic/moqt/moqt_framer.cc",
diff --git a/build/source_list.gni b/build/source_list.gni index f07ff9b..2de8cbf 100644 --- a/build/source_list.gni +++ b/build/source_list.gni
@@ -1621,6 +1621,7 @@ "src/quiche/quic/moqt/tools/moqt_server.h", ] moqt_srcs = [ + "src/quiche/quic/moqt/moqt_bidi_stream.cc", "src/quiche/quic/moqt/moqt_bitrate_adjuster.cc", "src/quiche/quic/moqt/moqt_error.cc", "src/quiche/quic/moqt/moqt_framer.cc",
diff --git a/build/source_list.json b/build/source_list.json index dc40e1a..e159716 100644 --- a/build/source_list.json +++ b/build/source_list.json
@@ -1620,6 +1620,7 @@ "quiche/quic/moqt/tools/moqt_server.h" ], "moqt_srcs": [ + "quiche/quic/moqt/moqt_bidi_stream.cc", "quiche/quic/moqt/moqt_bitrate_adjuster.cc", "quiche/quic/moqt/moqt_error.cc", "quiche/quic/moqt/moqt_framer.cc",
diff --git a/quiche/quic/moqt/moqt_bidi_stream.cc b/quiche/quic/moqt/moqt_bidi_stream.cc new file mode 100644 index 0000000..3d6b9c6 --- /dev/null +++ b/quiche/quic/moqt/moqt_bidi_stream.cc
@@ -0,0 +1,149 @@ +// Copyright 2026 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/moqt/moqt_bidi_stream.h" + +#include <array> +#include <cstdint> +#include <optional> +#include <utility> + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/moqt/moqt_error.h" +#include "quiche/quic/moqt/moqt_key_value_pair.h" +#include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_parser.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_buffer_allocator.h" +#include "quiche/common/quiche_mem_slice.h" +#include "quiche/web_transport/stream_helpers.h" +#include "quiche/web_transport/web_transport.h" + +namespace moqt { + +void MoqtBidiStreamBase::OnCanRead() { + if (stream_parser_ == nullptr) { + QUICHE_BUG(MoqtBidiStreamBase_OnCanRead_no_stream) + << "OnCanRead() called when no stream is bound"; + return; + } + while (!stream_parser_->fin_read()) { + absl::StatusOr<MoqtRawControlMessage> message = + stream_parser_->ReadNextMessage(); + if (absl::IsUnavailable(message.status())) { + return; + } + if (!message.ok()) { + OnFatalError(message.status()); + return; + } + absl::Status status = OnRawControlMessage(*message); + if (!status.ok()) { + OnFatalError(status); + return; + } + } +} + +void MoqtBidiStreamBase::OnCanWrite() { + if (stream_parser_ == nullptr) { + QUICHE_BUG(MoqtBidiStreamBase_OnCanWrite_no_stream) + << "OnCanWrite() called when no stream is bound"; + return; + } + webtransport::Stream* stream = stream_parser_->stream(); + if (pending_messages_.empty() && fin_queued_) { + absl::Status status = webtransport::SendFinOnStream(*stream); + if (!status.ok()) { + OnFatalError(status); + } + return; + } + while (!pending_messages_.empty() && stream->CanWrite()) { + absl::Status status = + SendMessage(std::move(pending_messages_.front()), + fin_queued_ && pending_messages_.size() == 1); + pending_messages_.pop_front(); + if (!status.ok()) { + OnFatalError(status); + return; + } + } +} + +absl::Status MoqtBidiStreamBase::SendOrBufferMessage( + quiche::QuicheBuffer message, bool fin) { + if (fin_queued_) { + return absl::InternalError( + "Trying to send data when a FIN has been already queued"); + } + if (stream() == nullptr || !stream()->CanWrite()) { + fin_queued_ = fin; + return AddToQueue(std::move(message)); + } + return SendMessage(std::move(message), fin); +} + +absl::Status MoqtBidiStreamBase::SendRequestOk( + uint64_t request_id, const MessageParameters& parameters, bool fin) { + return SendOrBufferMessage( + framer_->SerializeRequestOk(MoqtRequestOk{request_id, parameters}), fin); +} + +absl::Status MoqtBidiStreamBase::SendRequestError( + uint64_t request_id, RequestErrorCode error_code, + std::optional<quic::QuicTimeDelta> retry_interval, + absl::string_view reason_phrase, bool fin) { + MoqtRequestError request_error; + request_error.request_id = request_id; + request_error.error_code = error_code; + request_error.retry_interval = retry_interval; + request_error.reason_phrase = reason_phrase; + return SendOrBufferMessage(framer_->SerializeRequestError(request_error), + fin); +} + +absl::Status MoqtBidiStreamBase::SendRequestError(uint64_t request_id, + MoqtRequestErrorInfo info, + bool fin) { + return SendRequestError(request_id, info.error_code, info.retry_interval, + info.reason_phrase, fin); +} + +void MoqtBidiStreamBase::OnFatalError(absl::Status status) { + QUICHE_DCHECK(!status.ok()); + if (session_error_callback_ == nullptr) { + return; + } + std::optional<MoqtError> error_code = GetMoqtErrorForStatus(status); + if (!error_code.has_value()) { + error_code = absl::IsInvalidArgument(status) ? MoqtError::kProtocolViolation + : MoqtError::kInternalError; + } + std::move(session_error_callback_)(*error_code, status.message()); +} + +absl::Status MoqtBidiStreamBase::AddToQueue(quiche::QuicheBuffer message) { + if (pending_messages_.size() == kMaxPendingMessages) { + return absl::ResourceExhaustedError( + "Not enough flow credit on the control stream"); + } + pending_messages_.push_back(std::move(message)); + return absl::OkStatus(); +} + +absl::Status MoqtBidiStreamBase::SendMessage(quiche::QuicheBuffer message, + bool fin) { + webtransport::StreamWriteOptions options; + options.set_send_fin(fin); + std::array write_vector = {quiche::QuicheMemSlice(std::move(message))}; + return stream()->Writev(absl::MakeSpan(write_vector), options); +} + +} // namespace moqt
diff --git a/quiche/quic/moqt/moqt_bidi_stream.h b/quiche/quic/moqt/moqt_bidi_stream.h index e2bc925..29f98f4 100644 --- a/quiche/quic/moqt/moqt_bidi_stream.h +++ b/quiche/quic/moqt/moqt_bidi_stream.h
@@ -5,38 +5,35 @@ #ifndef QUICHE_QUIC_MOQT_MOQT_BIDI_STREAM_H #define QUICHE_QUIC_MOQT_MOQT_BIDI_STREAM_H -#include <array> #include <cstddef> #include <cstdint> #include <memory> #include <optional> -#include <queue> +#include <type_traits> #include <utility> +#include "absl/base/casts.h" #include "absl/base/nullability.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "absl/types/span.h" #include "quiche/quic/core/quic_time.h" #include "quiche/quic/moqt/moqt_error.h" #include "quiche/quic/moqt/moqt_framer.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_parser.h" -#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" #include "quiche/common/quiche_buffer_allocator.h" #include "quiche/common/quiche_callbacks.h" -#include "quiche/common/quiche_mem_slice.h" -#include "quiche/web_transport/stream_helpers.h" +#include "quiche/common/quiche_circular_deque.h" #include "quiche/web_transport/web_transport.h" namespace moqt { -enum class MoqtBidiStreamType : uint8_t { - kUnknown, - kControl, - kSubscribeNamespace, // TODO(martinduke): Support this case. -}; +namespace test { +class MoqtBidiStreamTestWrapper; +} using SessionErrorCallback = quiche::SingleUseCallback<void(MoqtError, absl::string_view)>; @@ -44,243 +41,150 @@ // deletes the record. using BidiStreamDeletedCallback = quiche::SingleUseCallback<void()>; -// A generic parser visitor that assumes all messages are invalid. Serves a base -// class for visitors that accept a subset of messages and maintains state based -// on those messages. -class MoqtBidiStreamBase : public MoqtControlParserVisitor, - public webtransport::StreamVisitor { +// MoqtBidiStreamBase is the base class for bidirectional streams in MoQT. It +// contains basic methods for handling and dispatching messages. An instance of +// MoqtBidiStreamBase can be created before the underlying stream is available, +// as it might not yet exist due to flow control limits. +class MoqtBidiStreamBase : public webtransport::StreamVisitor { public: + // Maximum amount of messages buffered on top of the QUIC send buffer. + static constexpr size_t kMaxPendingMessages = 100; + MoqtBidiStreamBase(MoqtFramer* absl_nonnull framer, + const MoqtControlMessageParser& message_parser, BidiStreamDeletedCallback stream_deleted_callback, SessionErrorCallback session_error_callback) : framer_(framer), + message_parser_(message_parser), stream_deleted_callback_(std::move(stream_deleted_callback)), session_error_callback_(std::move(session_error_callback)) {} ~MoqtBidiStreamBase() override { std::move(stream_deleted_callback_)(); } - virtual void set_stream(webtransport::Stream* absl_nonnull stream) { - set_stream(stream, std::nullopt); - } - // MoqtControlParserVisitor implementation. All control messages are protocol - // violations by default. - virtual void OnClientSetupMessage(const MoqtClientSetup& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); + // Binds a WebTransport stream associated with `parser` to this object. + void BindStream( + std::unique_ptr<MoqtControlStreamParser> absl_nonnull parser) { + QUICHE_DCHECK(stream_parser_ == nullptr); + stream_parser_ = std::move(parser); + OnStreamBound(); } - virtual void OnServerSetupMessage(const MoqtServerSetup& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnRequestOkMessage(const MoqtRequestOk& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnRequestErrorMessage(const MoqtRequestError& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnSubscribeMessage(const MoqtSubscribe& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnSubscribeOkMessage(const MoqtSubscribeOk& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnUnsubscribeMessage(const MoqtUnsubscribe& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnPublishDoneMessage(const MoqtPublishDone& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnRequestUpdateMessage( - const MoqtRequestUpdate& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnPublishNamespaceMessage( - const MoqtPublishNamespace& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnPublishNamespaceDoneMessage( - const MoqtPublishNamespaceDone& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnNamespaceMessage(const MoqtNamespace& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnNamespaceDoneMessage( - const MoqtNamespaceDone& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnPublishNamespaceCancelMessage( - const MoqtPublishNamespaceCancel& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnTrackStatusMessage(const MoqtTrackStatus& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnGoAwayMessage(const MoqtGoAway& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnSubscribeNamespaceMessage( - const MoqtSubscribeNamespace& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnMaxRequestIdMessage(const MoqtMaxRequestId& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnFetchMessage(const MoqtFetch& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnFetchCancelMessage(const MoqtFetchCancel& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnFetchOkMessage(const MoqtFetchOk& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnRequestsBlockedMessage( - const MoqtRequestsBlocked& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnPublishMessage(const MoqtPublish& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnObjectAckMessage(const MoqtObjectAck& message) override { - OnParsingError(wrong_message_error_, wrong_message_reason_); - } - virtual void OnParsingError(MoqtError code, - absl::string_view reason) override { - std::move(session_error_callback_)(code, reason); + // Binds a WebTransport stream `stream` to this object. + void BindStream(webtransport::Stream* absl_nonnull stream) { + QUICHE_DCHECK(stream_parser_ == nullptr); + stream_parser_ = std::make_unique<MoqtControlStreamParser>(stream); + OnStreamBound(); } // webtransport::StreamVisitor implementation. void OnResetStreamReceived(webtransport::StreamErrorCode error) override {} void OnStopSendingReceived(webtransport::StreamErrorCode error) override {} void OnWriteSideInDataRecvdState() override {} - void OnCanRead() override { - if (parser_ == nullptr) { - QUICHE_BUG(quiche_bug_moqt_parser_is_null) << "Parser is null"; - return; - } - parser_->ReadAndDispatchMessages(); - } - void OnCanWrite() override { - if (pending_messages_.empty() && fin_queued_) { - if (!stream_->SendFin()) { - std::move(session_error_callback_)(MoqtError::kInternalError, - "Failed to send FIN"); - } - return; - } - while (!pending_messages_.empty() && stream_->CanWrite()) { - SendMessage(std::move(pending_messages_.front()), - fin_queued_ && pending_messages_.size() == 1); - pending_messages_.pop(); - } - } + void OnCanRead() override; + void OnCanWrite() override; bool QueueIsFull() const { return pending_messages_.size() == kMaxPendingMessages; } - void SendOrBufferMessage(quiche::QuicheBuffer message, bool fin = false) { - if (fin_queued_) { - return; - } - if (stream_ == nullptr || !stream_->CanWrite()) { - AddToQueue(std::move(message)); - return; - } - SendMessage(std::move(message), fin); + absl::Status SendOrBufferMessage(quiche::QuicheBuffer message, + bool fin = false); + void SendOrBufferMessageOrFatal(quiche::QuicheBuffer message, + bool fin = false) { + CheckStatus(SendOrBufferMessage(std::move(message), fin)); } - void SendRequestOk(uint64_t request_id, const MessageParameters& parameters, - bool fin = false) { - SendOrBufferMessage( - framer_->SerializeRequestOk(MoqtRequestOk{request_id, parameters}), - fin); - } - void SendRequestError(uint64_t request_id, RequestErrorCode error_code, - std::optional<quic::QuicTimeDelta> retry_interval, - absl::string_view reason_phrase, bool fin = false) { - MoqtRequestError request_error; - request_error.request_id = request_id; - request_error.error_code = error_code; - request_error.retry_interval = retry_interval; - request_error.reason_phrase = reason_phrase; - SendOrBufferMessage(framer_->SerializeRequestError(request_error), fin); - } - void SendRequestError(uint64_t request_id, MoqtRequestErrorInfo info, - bool fin = false) { - SendRequestError(request_id, info.error_code, info.retry_interval, - info.reason_phrase, fin); - } + + absl::Status SendRequestOk(uint64_t request_id, + const MessageParameters& parameters, + bool fin = false); + absl::Status SendRequestError( + uint64_t request_id, RequestErrorCode error_code, + std::optional<quic::QuicTimeDelta> retry_interval, + absl::string_view reason_phrase, bool fin = false); + absl::Status SendRequestError(uint64_t request_id, MoqtRequestErrorInfo info, + bool fin = false); + void Fin() { fin_queued_ = true; - if (pending_messages_.empty()) { - if (stream_ != nullptr && !webtransport::SendFinOnStream(*stream_).ok()) { - std::move(session_error_callback_)(MoqtError::kInternalError, - "Failed to send FIN"); - } - return; - } + OnCanWrite(); } void Reset(webtransport::StreamErrorCode error) { - if (stream_ != nullptr) { - stream_->ResetWithUserCode(error); + webtransport::Stream* stream = stream_parser_->stream(); + if (stream != nullptr) { + stream->ResetWithUserCode(error); + } + } + + // If `status` is not OK, terminates the connection with a fatal error. + void CheckStatus(absl::Status status) { + if (!status.ok()) { + OnFatalError(status); } } protected: - // The caller is responsible for calling stream->SetVisitor(). Derived - // classes will wrap this with a call to stream->SetPriority(). - void set_stream(webtransport::Stream* absl_nonnull stream, - std::optional<MoqtMessageType> first_message_type) { - stream_ = stream; - parser_ = std::make_unique<MoqtControlParser>(framer_->using_webtrans(), - stream_, *this); - if (first_message_type.has_value()) { - parser_->set_message_type(static_cast<uint64_t>(*first_message_type)); - } + // Called when a WebTransport stream has been associated with the object. + // Should be used to set the priority for the stream. + virtual void OnStreamBound() = 0; + + // Called when a control message has been received. The subclass should use + // DispatchControlMessage to process it. + virtual absl::Status OnRawControlMessage( + const MoqtRawControlMessage& message) = 0; + + // Terminates the MoQT session due to a fatal error encountered. + void OnFatalError(absl::Status status); + + MoqtControlStreamParser* stream_parser() { return stream_parser_.get(); } + MoqtFramer* framer() const { return framer_; } + webtransport::Stream* stream() const { + return stream_parser_ != nullptr ? stream_parser_->stream() : nullptr; } - const size_t kMaxPendingMessages = 100; - void AddToQueue(quiche::QuicheBuffer message) { - if (pending_messages_.size() == kMaxPendingMessages) { - std::move(session_error_callback_)( - MoqtError::kInternalError, - "Not enough flow credit on the control stream"); - return; - } - pending_messages_.push(std::move(message)); + + // Parses the supplied control message. If the message is well-formed, and the + // class defines an `OnControlMessage` method that accepts it, it is passed to + // that method. Otherwise, an appropriate error message is returned; + // `stream_type` is used to format that message. + template <typename Subclass> + absl::Status DispatchControlMessage(const MoqtRawControlMessage& message, + absl::string_view stream_type) { + static_assert(!std::is_same_v<Subclass, MoqtBidiStreamBase>); + return message_parser_.ParseMessage(message, [&](const auto& + parsed_message) { + if constexpr (CanDispatch<Subclass, decltype(parsed_message)>::value) { + return absl::down_cast<Subclass*>(this)->OnControlMessage( + parsed_message); + } else { + return absl::InvalidArgumentError( + absl::StrCat("Received an unexpected message of type ", + MoqtMessageTypeToString(message.type), " on a ", + stream_type, " stream")); + } + }); } - MoqtFramer* absl_nonnull framer_; - MoqtControlParser* parser() { return parser_.get(); } - void OnBidiStreamDeleted() { - if (stream_deleted_callback_ != nullptr) { - std::move(stream_deleted_callback_)(); - } - } - webtransport::Stream* stream() { return stream_; } private: - void SendMessage(quiche::QuicheBuffer message, bool fin) { - webtransport::StreamWriteOptions options; - options.set_send_fin(fin); - // TODO: while we buffer unconditionally, we should still at some point tear - // down the connection if we've buffered too many control messages; - // otherwise, there is potential for memory exhaustion attacks. - options.set_buffer_unconditionally(true); - std::array write_vector = {quiche::QuicheMemSlice(std::move(message))}; - absl::Status success = - stream_->Writev(absl::MakeSpan(write_vector), options); - if (!success.ok()) { - std::move(session_error_callback_)(MoqtError::kInternalError, - "Failed to write a control message"); - } - } + friend class test::MoqtBidiStreamTestWrapper; - webtransport::Stream* stream_; - std::unique_ptr<MoqtControlParser> parser_; - std::queue<quiche::QuicheBuffer> pending_messages_; + absl::Status AddToQueue(quiche::QuicheBuffer message); + absl::Status SendMessage(quiche::QuicheBuffer message, bool fin); + + // CanDispatch<S, M> indicates whether `S` has a method with signature + // absl::Status OnControlMessage(const M&); + template <typename Subclass, typename Message, typename = void> + struct CanDispatch : std::false_type {}; + template <typename Subclass, typename Message> + struct CanDispatch<Subclass, Message, + std::enable_if_t<std::is_same_v< + decltype(std::declval<Subclass>().OnControlMessage( + std::declval<Message>())), + absl::Status>>> : std::true_type {}; + + MoqtFramer* absl_nonnull framer_; + std::unique_ptr<MoqtControlStreamParser> absl_nullable stream_parser_; + MoqtControlMessageParser message_parser_; + quiche::QuicheCircularDeque<quiche::QuicheBuffer> pending_messages_; bool fin_queued_ = false; BidiStreamDeletedCallback stream_deleted_callback_; SessionErrorCallback session_error_callback_; - const MoqtError wrong_message_error_ = MoqtError::kProtocolViolation; - const absl::string_view wrong_message_reason_ = - "Message not allowed for this stream type"; }; } // namespace moqt
diff --git a/quiche/quic/moqt/moqt_bidi_stream_test.cc b/quiche/quic/moqt/moqt_bidi_stream_test.cc index bc77125..a2cee5b 100644 --- a/quiche/quic/moqt/moqt_bidi_stream_test.cc +++ b/quiche/quic/moqt/moqt_bidi_stream_test.cc
@@ -5,6 +5,7 @@ #include "quiche/quic/moqt/moqt_bidi_stream.h" #include <memory> +#include <optional> #include "absl/status/status.h" #include "absl/strings/string_view.h" @@ -13,9 +14,13 @@ #include "quiche/quic/moqt/moqt_framer.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_parser.h" +#include "quiche/quic/moqt/moqt_session_interface.h" #include "quiche/quic/moqt/test_tools/moqt_framer_utils.h" #include "quiche/common/platform/api/quiche_test.h" #include "quiche/common/quiche_mem_slice.h" +#include "quiche/common/test_tools/quiche_test_utils.h" +#include "quiche/web_transport/test_tools/in_memory_stream.h" #include "quiche/web_transport/test_tools/mock_web_transport.h" #include "quiche/web_transport/web_transport.h" @@ -24,185 +29,54 @@ namespace moqt::test { +class TestMoqtBidiStream : public MoqtBidiStreamBase { + public: + using MoqtBidiStreamBase::MoqtBidiStreamBase; + + absl::Status OnControlMessage(const MoqtRequestOk& message) { + ++ok_received_; + return absl::OkStatus(); + } + + void OnStreamBound() override {} + absl::Status OnRawControlMessage( + const MoqtRawControlMessage& message) override { + return DispatchControlMessage<TestMoqtBidiStream>(message, "test"); + } + int ok_received() const { return ok_received_; } + + private: + int ok_received_ = 0; +}; + class MoqtBidiStreamTest : public quiche::test::QuicheTest { public: MoqtBidiStreamTest() : framer_(true), - stream_(std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), + stream_(std::make_unique<TestMoqtBidiStream>( + &framer_, + MoqtControlMessageParser(kDefaultMoqtVersion, + /*webtransport=*/true), + deleted_callback_.AsStdFunction(), error_callback_.AsStdFunction())) {} MoqtFramer framer_; testing::MockFunction<void()> deleted_callback_; - testing::MockFunction<void(MoqtError, absl::string_view)> error_callback_; - std::unique_ptr<MoqtBidiStreamBase> stream_; + testing::StrictMock<testing::MockFunction<void(MoqtError, absl::string_view)>> + error_callback_; + std::unique_ptr<TestMoqtBidiStream> stream_; webtransport::test::MockStream mock_stream_; }; -TEST_F(MoqtBidiStreamTest, AllMessagesRejected) { - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnClientSetupMessage(MoqtClientSetup{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnServerSetupMessage(MoqtServerSetup{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnRequestOkMessage(MoqtRequestOk{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnRequestErrorMessage(MoqtRequestError{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnSubscribeMessage(MoqtSubscribe{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnSubscribeOkMessage(MoqtSubscribeOk{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnUnsubscribeMessage(MoqtUnsubscribe{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnPublishDoneMessage(MoqtPublishDone{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnRequestUpdateMessage(MoqtRequestUpdate{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnPublishNamespaceMessage(MoqtPublishNamespace{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnPublishNamespaceDoneMessage(MoqtPublishNamespaceDone{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnPublishNamespaceCancelMessage(MoqtPublishNamespaceCancel{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnTrackStatusMessage(MoqtTrackStatus{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnGoAwayMessage(MoqtGoAway{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnSubscribeNamespaceMessage(MoqtSubscribeNamespace{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnMaxRequestIdMessage(MoqtMaxRequestId{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnFetchMessage(MoqtFetch{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnFetchCancelMessage(MoqtFetchCancel{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnFetchOkMessage(MoqtFetchOk{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnRequestsBlockedMessage(MoqtRequestsBlocked{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnPublishMessage(MoqtPublish{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); - EXPECT_CALL(error_callback_, - Call(MoqtError::kProtocolViolation, - "Message not allowed for this stream type")); - stream_->OnObjectAckMessage(MoqtObjectAck{}); - stream_ = std::make_unique<MoqtBidiStreamBase>( - &framer_, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction()); -} - TEST_F(MoqtBidiStreamTest, MessageBufferedThenSent) { - stream_->set_stream(&mock_stream_); + stream_->BindStream(&mock_stream_); EXPECT_CALL(mock_stream_, CanWrite).WillRepeatedly(Return(false)); EXPECT_CALL(mock_stream_, Writev).Times(0); - stream_->SendRequestOk(0, MessageParameters()); - stream_->SendRequestError(2, RequestErrorCode::kUnauthorized, std::nullopt, - "bad request"); + QUICHE_EXPECT_OK(stream_->SendRequestOk(0, MessageParameters())); + QUICHE_EXPECT_OK(stream_->SendRequestError(2, RequestErrorCode::kUnauthorized, + std::nullopt, + + "bad request")); stream_->Fin(); { testing::InSequence seq; @@ -222,7 +96,7 @@ } TEST_F(MoqtBidiStreamTest, FinSentWhenDrained) { - stream_->set_stream(&mock_stream_); + stream_->BindStream(&mock_stream_); EXPECT_CALL(mock_stream_, Writev) .WillOnce([](absl::Span<quiche::QuicheMemSlice>, const webtransport::StreamWriteOptions& options) { @@ -233,7 +107,7 @@ } TEST_F(MoqtBidiStreamTest, Reset) { - stream_->set_stream(&mock_stream_); + stream_->BindStream(&mock_stream_); EXPECT_CALL(mock_stream_, ResetWithUserCode(1234)); stream_->Reset(1234); } @@ -244,18 +118,38 @@ } TEST_F(MoqtBidiStreamTest, PendingQueueFull) { - stream_->set_stream(&mock_stream_); + stream_->BindStream(&mock_stream_); EXPECT_CALL(mock_stream_, CanWrite).WillRepeatedly(Return(false)); - for (int i = 0; i < 100; ++i) { // kMaxPendingMessages = 100. + for (int i = 0; i < MoqtBidiStreamBase::kMaxPendingMessages; ++i) { EXPECT_FALSE(stream_->QueueIsFull()); - stream_->SendOrBufferMessage( - framer_.SerializeRequestUpdate(MoqtRequestUpdate{})); + QUICHE_EXPECT_OK(stream_->SendOrBufferMessage( + framer_.SerializeRequestUpdate(MoqtRequestUpdate{}))); } EXPECT_TRUE(stream_->QueueIsFull()); - EXPECT_CALL(error_callback_, Call(MoqtError::kInternalError, _)); - stream_->SendOrBufferMessage( - framer_.SerializeRequestUpdate(MoqtRequestUpdate{})); - EXPECT_TRUE(stream_->QueueIsFull()); + EXPECT_EQ(stream_ + ->SendOrBufferMessage( + framer_.SerializeRequestUpdate(MoqtRequestUpdate{})) + .code(), + absl::StatusCode::kResourceExhausted); +} + +TEST_F(MoqtBidiStreamTest, DispatchControlMessage) { + webtransport::test::InMemoryStream stream(0); + stream_->BindStream(&stream); + MoqtFramer framer(/*using_webtrans=*/true); + stream.Receive(framer.SerializeRequestOk(MoqtRequestOk()).AsStringView()); + stream_->OnCanRead(); + EXPECT_EQ(stream_->ok_received(), 1u); + + stream.Receive(framer.SerializeGoAway(MoqtGoAway()).AsStringView()); + EXPECT_CALL(error_callback_, Call) + .WillOnce([](MoqtError error, absl::string_view message) { + EXPECT_EQ(error, MoqtError::kProtocolViolation); + EXPECT_EQ( + message, + "Received an unexpected message of type GOAWAY on a test stream"); + }); + stream_->OnCanRead(); } } // namespace moqt::test
diff --git a/quiche/quic/moqt/moqt_namespace_stream.cc b/quiche/quic/moqt/moqt_namespace_stream.cc index f43967b..69b8732 100644 --- a/quiche/quic/moqt/moqt_namespace_stream.cc +++ b/quiche/quic/moqt/moqt_namespace_stream.cc
@@ -10,6 +10,7 @@ #include <utility> #include "absl/base/nullability.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "quiche/quic/moqt/moqt_bidi_stream.h" #include "quiche/quic/moqt/moqt_error.h" @@ -18,9 +19,11 @@ #include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" +#include "quiche/quic/moqt/moqt_parser.h" #include "quiche/quic/moqt/moqt_session_callbacks.h" #include "quiche/quic/moqt/session_namespace_tree.h" #include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/web_transport/stream_helpers.h" #include "quiche/web_transport/web_transport.h" namespace moqt { @@ -31,126 +34,120 @@ task->DeclareEof(); } } - -void MoqtNamespaceSubscriberStream::set_stream( - webtransport::Stream* absl_nonnull stream) { - // TODO(martinduke): Set the priority for this stream. - MoqtBidiStreamBase::set_stream(stream); +absl::Status MoqtNamespaceSubscriberStream::OnRawControlMessage( + const MoqtRawControlMessage& message) { + return DispatchControlMessage<MoqtNamespaceSubscriberStream>( + message, "namespace subscriber"); } -void MoqtNamespaceSubscriberStream::OnRequestOkMessage( +void MoqtNamespaceSubscriberStream::OnStreamBound() { + // TODO(martinduke): Set the priority for this stream. +} + +absl::Status MoqtNamespaceSubscriberStream::OnControlMessage( const MoqtRequestOk& message) { if (message.request_id == request_id_) { // Response to the initial SUBSCRIBE_NAMESPACE. if (response_callback_ == nullptr) { - OnParsingError(MoqtError::kProtocolViolation, "Two responses"); - return; + return absl::InvalidArgumentError("Two responses"); } std::move(response_callback_)(std::nullopt); response_callback_ = nullptr; - return; + return absl::OkStatus(); } NamespaceTask* task = task_.GetIfAvailable(); if (task == nullptr) { // The application has already unsubscribed, and the stream has been reset. // This is irrelevant. - return; + return absl::OkStatus(); } MoqtResponseCallback callback = task->GetResponseCallback(message.request_id); if (callback == nullptr) { - OnParsingError(MoqtError::kProtocolViolation, - "Unexpected request ID in response"); - return; + return absl::InvalidArgumentError("Unexpected request ID in response"); } std::move(callback)(std::nullopt); + return absl::OkStatus(); } -void MoqtNamespaceSubscriberStream::OnRequestErrorMessage( +absl::Status MoqtNamespaceSubscriberStream::OnControlMessage( const MoqtRequestError& message) { if (message.request_id == request_id_) { if (response_callback_ == nullptr) { - OnParsingError(MoqtError::kProtocolViolation, "Two responses"); - return; + return absl::InvalidArgumentError("Two responses"); } std::move(response_callback_)(MoqtRequestErrorInfo{ message.error_code, message.retry_interval, message.reason_phrase}); response_callback_ = nullptr; - return; + return absl::OkStatus(); } NamespaceTask* task = task_.GetIfAvailable(); if (task == nullptr) { // The application has already unsubscribed, and the stream has been reset. // This is irrelevant. - return; + return absl::OkStatus(); } MoqtResponseCallback callback = task->GetResponseCallback(message.request_id); if (callback == nullptr) { - OnParsingError(MoqtError::kProtocolViolation, - "Unexpected request ID in response"); - return; + return absl::InvalidArgumentError("Unexpected request ID in response"); } std::move(callback)(MoqtRequestErrorInfo{ message.error_code, message.retry_interval, message.reason_phrase}); + return absl::OkStatus(); } -void MoqtNamespaceSubscriberStream::OnNamespaceMessage( +absl::Status MoqtNamespaceSubscriberStream::OnControlMessage( const MoqtNamespace& message) { if (response_callback_ != nullptr) { - OnParsingError(MoqtError::kProtocolViolation, - "First message must be REQUEST_OK or REQUEST_ERROR"); - return; + return absl::InvalidArgumentError( + "First message must be REQUEST_OK or REQUEST_ERROR"); } NamespaceTask* task = task_.GetIfAvailable(); if (task == nullptr) { // The application has already unsubscribed, and the stream has been reset. // This is irrelevant. - return; + return absl::OkStatus(); } if (task->prefix().number_of_elements() + message.track_namespace_suffix.number_of_elements() > kMaxNamespaceElements) { - OnParsingError(MoqtError::kProtocolViolation, - "Too many namespace elements"); - return; + return absl::InvalidArgumentError("Too many namespace elements"); } if (task->prefix().total_length() + message.track_namespace_suffix.total_length() > kMaxFullTrackNameSize) { - OnParsingError(MoqtError::kProtocolViolation, "Namespace too large"); - return; + return absl::InvalidArgumentError("Namespace too large"); } auto [it, inserted] = published_suffixes_.insert(message.track_namespace_suffix); if (!inserted) { - OnParsingError(MoqtError::kProtocolViolation, - "Two NAMESPACE messages for the same track namespace"); - return; + return absl::InvalidArgumentError( + "Two NAMESPACE messages for the same track namespace"); } QUIC_DLOG(INFO) << "Received NAMESPACE message for " << message.track_namespace_suffix; task->AddPendingSuffix(message.track_namespace_suffix, TransactionType::kAdd); + return absl::OkStatus(); } -void MoqtNamespaceSubscriberStream::OnNamespaceDoneMessage( +absl::Status MoqtNamespaceSubscriberStream::OnControlMessage( const MoqtNamespaceDone& message) { if (response_callback_ != nullptr) { - OnParsingError(MoqtError::kProtocolViolation, - "First message must be REQUEST_OK or REQUEST_ERROR"); - return; + return absl::InvalidArgumentError( + "First message must be REQUEST_OK or REQUEST_ERROR"); } NamespaceTask* task = task_.GetIfAvailable(); if (task == nullptr) { - return; + return absl::OkStatus(); } if (published_suffixes_.erase(message.track_namespace_suffix) == 0) { - OnParsingError(MoqtError::kProtocolViolation, - "NAMESPACE_DONE with no active namespace"); - return; + return absl::InvalidArgumentError( + "NAMESPACE_DONE with no active namespace"); } QUIC_DLOG(INFO) << "Received NAMESPACE_DONE message for " << message.track_namespace_suffix; task->AddPendingSuffix(message.track_namespace_suffix, TransactionType::kDelete); + return absl::OkStatus(); } std::unique_ptr<MoqtNamespaceTask> MoqtNamespaceSubscriberStream::CreateTask( @@ -187,7 +184,8 @@ } MoqtRequestUpdate message{next_request_id_, state_->request_id_, parameters}; pending_updates_[message.request_id] = std::move(response_callback); - state_->SendOrBufferMessage(state_->framer_->SerializeRequestUpdate(message)); + state_->SendOrBufferMessageOrFatal( + state_->framer()->SerializeRequestUpdate(message)); next_request_id_ += 2; } @@ -247,18 +245,15 @@ } MoqtNamespacePublisherStream::MoqtNamespacePublisherStream( - MoqtFramer* framer, webtransport::Stream* stream, + MoqtFramer* framer, const MoqtControlMessageParser& message_parser, SessionErrorCallback session_error_callback, SessionNamespaceTree* absl_nonnull tree, MoqtIncomingSubscribeNamespaceCallback& application) // No stream_deleted_callback because there's no state yet. : MoqtBidiStreamBase( - framer, []() {}, std::move(session_error_callback)), + framer, message_parser, []() {}, std::move(session_error_callback)), tree_(tree->GetWeakPtr()), - application_(application) { - // TODO(martinduke): Set the priority for this stream. - MoqtBidiStreamBase::set_stream(stream, MoqtMessageType::kSubscribeNamespace); -} + application_(application) {} MoqtNamespacePublisherStream::~MoqtNamespacePublisherStream() { if (task_ == nullptr) { @@ -271,51 +266,42 @@ } } -void MoqtNamespacePublisherStream::OnSubscribeNamespaceMessage( +absl::Status MoqtNamespacePublisherStream::OnRawControlMessage( + const MoqtRawControlMessage& message) { + return DispatchControlMessage<MoqtNamespacePublisherStream>( + message, "namespace publisher"); +} + +absl::Status MoqtNamespacePublisherStream::OnControlMessage( const MoqtSubscribeNamespace& message) { request_id_ = message.request_id; SessionNamespaceTree* tree = tree_.GetIfAvailable(); if (tree == nullptr) { - SendRequestError(request_id_, RequestErrorCode::kInternalError, - std::nullopt, "Session is gone", /*fin=*/true); - return; + return SendRequestError(request_id_, RequestErrorCode::kInternalError, + std::nullopt, "Session is gone", /*fin=*/true); } if (!tree->SubscribeNamespace(message.track_namespace_prefix)) { - SendRequestError(request_id_, RequestErrorCode::kPrefixOverlap, - std::nullopt, "", /*fin=*/true); - return; + return SendRequestError(request_id_, RequestErrorCode::kPrefixOverlap, + std::nullopt, "", /*fin=*/true); } QUICHE_DCHECK(task_ == nullptr); - task_ = application_(message.track_namespace_prefix, - message.subscribe_options, message.parameters, - // Response callback - [this](std::optional<MoqtRequestErrorInfo> error) { - if (error.has_value()) { - SendRequestError(request_id_, *error, /*fin=*/true); - } else { - SendRequestOk(request_id_, MessageParameters()); - } - }); + task_ = + application_(message.track_namespace_prefix, message.subscribe_options, + message.parameters, ResponseCallback(request_id_)); if (task_ != nullptr) { task_->SetObjectsAvailableCallback([this]() { ProcessNamespaces(); }); } + return absl::OkStatus(); } -void MoqtNamespacePublisherStream::OnRequestUpdateMessage( +absl::Status MoqtNamespacePublisherStream::OnControlMessage( const MoqtRequestUpdate& message) { if (task_ == nullptr) { // This stream is dying. - return; + return absl::OkStatus(); } - task_->Update(message.parameters, - [this, request_id = message.request_id]( - std::optional<MoqtRequestErrorInfo> error) { - if (error.has_value()) { - SendRequestError(request_id, *error, /*fin=*/false); - } else { - SendRequestOk(request_id, MessageParameters()); - } - }); + task_->Update(message.parameters, ResponseCallback(request_id_)); + return absl::OkStatus(); } void MoqtNamespacePublisherStream::ProcessNamespaces() { @@ -330,14 +316,16 @@ case kPending: return; case kEof: - if (!SendFinOnStream(*stream()).ok()) { - OnParsingError(MoqtError::kInternalError, "Failed to send FIN"); + if (absl::Status status = webtransport::SendFinOnStream(*stream()); + !status.ok()) { + OnFatalError(status); }; return; case kError: Reset(kResetCodeCancelled); return; case kSuccess: { + absl::Status write_status; switch (type) { case TransactionType::kAdd: { auto [it, inserted] = published_suffixes_.insert(suffix); @@ -346,8 +334,8 @@ // cause a protocol violation. return; } - SendOrBufferMessage( - framer_->SerializeNamespace(MoqtNamespace{suffix})); + write_status = SendOrBufferMessage( + framer()->SerializeNamespace(MoqtNamespace{suffix})); break; } case TransactionType::kDelete: { @@ -356,14 +344,37 @@ // cause a protocol violation. return; } - SendOrBufferMessage( - framer_->SerializeNamespaceDone(MoqtNamespaceDone{suffix})); + write_status = SendOrBufferMessage( + framer()->SerializeNamespaceDone(MoqtNamespaceDone{suffix})); break; } } + if (!write_status.ok()) { + if (absl::IsResourceExhausted(write_status)) { + // The peer is not reading data fast enough, and the sender has + // reached its buffer limit; reset the stream. + Reset(kResetCodeTooFarBehind); + return; + } + // All other write errors are fatal. + OnFatalError(write_status); + return; + } + break; } } } } +MoqtResponseCallback MoqtNamespacePublisherStream::ResponseCallback( + uint64_t request_id) { + return [this, request_id](std::optional<MoqtRequestErrorInfo> error) { + if (error.has_value()) { + CheckStatus(SendRequestError(request_id, *error, /*fin=*/true)); + } else { + CheckStatus(SendRequestOk(request_id, MessageParameters())); + } + }; +} + } // namespace moqt
diff --git a/quiche/quic/moqt/moqt_namespace_stream.h b/quiche/quic/moqt/moqt_namespace_stream.h index ddc63e2..7374260 100644 --- a/quiche/quic/moqt/moqt_namespace_stream.h +++ b/quiche/quic/moqt/moqt_namespace_stream.h
@@ -14,12 +14,14 @@ #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "quiche/quic/moqt/moqt_bidi_stream.h" #include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_framer.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" +#include "quiche/quic/moqt/moqt_parser.h" #include "quiche/quic/moqt/moqt_session_callbacks.h" #include "quiche/quic/moqt/session_namespace_tree.h" #include "quiche/common/quiche_circular_deque.h" @@ -33,22 +35,25 @@ public: // Assumes the caller will send or queue the SUBSCRIBE_NAMESPACE. MoqtNamespaceSubscriberStream( - MoqtFramer* framer, uint64_t request_id, - BidiStreamDeletedCallback stream_deleted_callback, + MoqtFramer* framer, const MoqtControlMessageParser& message_parser, + uint64_t request_id, BidiStreamDeletedCallback stream_deleted_callback, SessionErrorCallback session_error_callback, MoqtResponseCallback response_callback) - : MoqtBidiStreamBase(framer, std::move(stream_deleted_callback), + : MoqtBidiStreamBase(framer, message_parser, + std::move(stream_deleted_callback), std::move(session_error_callback)), request_id_(request_id), response_callback_(std::move(response_callback)) {} ~MoqtNamespaceSubscriberStream() override; // MoqtBidiStreamBase overrides. - void set_stream(webtransport::Stream* absl_nonnull stream) override; - void OnRequestOkMessage(const MoqtRequestOk& message) override; - void OnRequestErrorMessage(const MoqtRequestError& message) override; - void OnNamespaceMessage(const MoqtNamespace& message) override; - void OnNamespaceDoneMessage(const MoqtNamespaceDone& message) override; + void OnStreamBound() override; + absl::Status OnRawControlMessage( + const MoqtRawControlMessage& message) override; + absl::Status OnControlMessage(const MoqtRequestOk& message); + absl::Status OnControlMessage(const MoqtRequestError& message); + absl::Status OnControlMessage(const MoqtNamespace& message); + absl::Status OnControlMessage(const MoqtNamespaceDone& message); // Send the prefix now so it is only stored in one place (the task). std::unique_ptr<MoqtNamespaceTask> CreateTask(const TrackNamespace& prefix); @@ -122,19 +127,23 @@ public: // Constructor for the publisher side. MoqtNamespacePublisherStream( - MoqtFramer* framer, webtransport::Stream* stream, + MoqtFramer* framer, const MoqtControlMessageParser& message_parser, SessionErrorCallback session_error_callback, SessionNamespaceTree* absl_nonnull tree, MoqtIncomingSubscribeNamespaceCallback& application); ~MoqtNamespacePublisherStream() override; - // MoqtBidiStreamBase overrides. - void OnSubscribeNamespaceMessage( - const MoqtSubscribeNamespace& message) override; - void OnRequestUpdateMessage(const MoqtRequestUpdate&) override; + void OnStreamBound() override { + // TODO(martinduke): Set the priority for this stream. + } + absl::Status OnRawControlMessage( + const MoqtRawControlMessage& message) override; + absl::Status OnControlMessage(const MoqtSubscribeNamespace& message); + absl::Status OnControlMessage(const MoqtRequestUpdate& message); private: void ProcessNamespaces(); + MoqtResponseCallback ResponseCallback(uint64_t request_id); uint64_t request_id_; quiche::QuicheWeakPtr<SessionNamespaceTree> tree_;
diff --git a/quiche/quic/moqt/moqt_namespace_stream_test.cc b/quiche/quic/moqt/moqt_namespace_stream_test.cc index 9b4b2be..de1f764 100644 --- a/quiche/quic/moqt/moqt_namespace_stream_test.cc +++ b/quiche/quic/moqt/moqt_namespace_stream_test.cc
@@ -19,12 +19,16 @@ #include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" +#include "quiche/quic/moqt/moqt_parser.h" #include "quiche/quic/moqt/moqt_session_callbacks.h" +#include "quiche/quic/moqt/moqt_session_interface.h" +#include "quiche/quic/moqt/moqt_types.h" #include "quiche/quic/moqt/session_namespace_tree.h" #include "quiche/quic/moqt/test_tools/moqt_framer_utils.h" #include "quiche/quic/moqt/test_tools/moqt_mock_visitor.h" #include "quiche/common/platform/api/quiche_test.h" #include "quiche/common/quiche_mem_slice.h" +#include "quiche/common/test_tools/quiche_test_utils.h" #include "quiche/web_transport/test_tools/mock_web_transport.h" #include "quiche/web_transport/web_transport.h" @@ -38,16 +42,21 @@ constexpr uint64_t kRequestId = 3; const TrackNamespace kPrefix({"foo"}); +MoqtControlMessageParser ControlMessageParser() { + return MoqtControlMessageParser(kDefaultMoqtVersion, true); +} + class MoqtNamespaceSubscriberStreamTest : public quiche::test::QuicheTest { public: MoqtNamespaceSubscriberStreamTest() : framer_(true), - stream_(&framer_, kRequestId, deleted_callback_.AsStdFunction(), + stream_(&framer_, ControlMessageParser(), kRequestId, + deleted_callback_.AsStdFunction(), error_callback_.AsStdFunction(), response_callback_.AsStdFunction()), task_(stream_.CreateTask(kPrefix)) { task_->SetObjectsAvailableCallback([this]() { ++objects_available_; }); - stream_.set_stream(&mock_stream_); + stream_.BindStream(&mock_stream_); ON_CALL(mock_stream_, CanWrite()).WillByDefault(Return(true)); } @@ -55,6 +64,11 @@ EXPECT_EQ(objects_available_, expected_count); } + template <typename M> + void ReceiveControlMessage(const M& message) { + stream_.CheckStatus(stream_.OnControlMessage(message)); + } + MoqtFramer framer_; testing::MockFunction<void()> deleted_callback_; testing::MockFunction<void(MoqtError, absl::string_view)> error_callback_; @@ -68,48 +82,48 @@ TEST_F(MoqtNamespaceSubscriberStreamTest, RequestOk) { EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); - stream_.OnRequestOkMessage({kRequestId}); + ReceiveControlMessage(MoqtRequestOk{kRequestId}); } TEST_F(MoqtNamespaceSubscriberStreamTest, RequestOkWrongId) { EXPECT_CALL(error_callback_, Call(MoqtError::kProtocolViolation, "Unexpected request ID in response")); - stream_.OnRequestOkMessage({kRequestId + 1}); + ReceiveControlMessage(MoqtRequestOk{kRequestId + 1}); } TEST_F(MoqtNamespaceSubscriberStreamTest, RequestError) { EXPECT_CALL(response_callback_, Call); - stream_.OnRequestErrorMessage({kRequestId, RequestErrorCode::kInternalError, - quic::QuicTimeDelta::FromMilliseconds(100), - "bar"}); + ReceiveControlMessage( + MoqtRequestError{kRequestId, RequestErrorCode::kInternalError, + quic::QuicTimeDelta::FromMilliseconds(100), "bar"}); } TEST_F(MoqtNamespaceSubscriberStreamTest, RequestErrorWrongId) { EXPECT_CALL(error_callback_, Call(MoqtError::kProtocolViolation, "Unexpected request ID in response")); - stream_.OnRequestErrorMessage( - {kRequestId + 1, RequestErrorCode::kInternalError, - quic::QuicTimeDelta::FromMilliseconds(100), "bar"}); + ReceiveControlMessage( + MoqtRequestError{kRequestId + 1, RequestErrorCode::kInternalError, + quic::QuicTimeDelta::FromMilliseconds(100), "bar"}); } TEST_F(MoqtNamespaceSubscriberStreamTest, NamespaceBeforeResponse) { EXPECT_CALL(error_callback_, Call(MoqtError::kProtocolViolation, "First message must be REQUEST_OK or REQUEST_ERROR")); - stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + ReceiveControlMessage(MoqtNamespace{TrackNamespace({"bar"})}); } TEST_F(MoqtNamespaceSubscriberStreamTest, NamespaceDoneBeforeResponse) { EXPECT_CALL(error_callback_, Call(MoqtError::kProtocolViolation, "First message must be REQUEST_OK or REQUEST_ERROR")); - stream_.OnNamespaceDoneMessage({TrackNamespace({"bar"})}); + ReceiveControlMessage(MoqtNamespaceDone{TrackNamespace({"bar"})}); } TEST_F(MoqtNamespaceSubscriberStreamTest, NamespaceAfterResponse) { EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); - stream_.OnRequestOkMessage({kRequestId}); - stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + ReceiveControlMessage(MoqtRequestOk{kRequestId}); + ReceiveControlMessage(MoqtNamespace{TrackNamespace({"bar"})}); CheckNumberOfObjectsAvailable(1); TrackNamespace received_namespace; TransactionType type; @@ -121,10 +135,10 @@ TEST_F(MoqtNamespaceSubscriberStreamTest, NamespaceDoneAfterResponse) { EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); - stream_.OnRequestOkMessage({kRequestId}); - stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + ReceiveControlMessage(MoqtRequestOk{kRequestId}); + ReceiveControlMessage(MoqtNamespace{TrackNamespace({"bar"})}); CheckNumberOfObjectsAvailable(1); - stream_.OnNamespaceDoneMessage({TrackNamespace({"bar"})}); + ReceiveControlMessage(MoqtNamespaceDone{TrackNamespace({"bar"})}); CheckNumberOfObjectsAvailable(2); TrackNamespace received_namespace; TransactionType type; @@ -139,43 +153,43 @@ TEST_F(MoqtNamespaceSubscriberStreamTest, DuplicateNamespace) { EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); - stream_.OnRequestOkMessage({kRequestId}); - stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + ReceiveControlMessage(MoqtRequestOk{kRequestId}); + ReceiveControlMessage(MoqtNamespace{TrackNamespace({"bar"})}); CheckNumberOfObjectsAvailable(1); EXPECT_CALL(error_callback_, Call(MoqtError::kProtocolViolation, "Two NAMESPACE messages for the same track namespace")); - stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + ReceiveControlMessage(MoqtNamespace{TrackNamespace({"bar"})}); } TEST_F(MoqtNamespaceSubscriberStreamTest, NamespaceDoneWithoutNamespace) { EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); - stream_.OnRequestOkMessage({kRequestId}); + ReceiveControlMessage(MoqtRequestOk{kRequestId}); EXPECT_CALL(error_callback_, Call(MoqtError::kProtocolViolation, "NAMESPACE_DONE with no active namespace")); - stream_.OnNamespaceDoneMessage({TrackNamespace({"bar"})}); + ReceiveControlMessage(MoqtNamespaceDone{TrackNamespace({"bar"})}); } TEST_F(MoqtNamespaceSubscriberStreamTest, NamespaceDoneThenNamespace) { EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); - stream_.OnRequestOkMessage({kRequestId}); + ReceiveControlMessage(MoqtRequestOk{kRequestId}); EXPECT_CALL(error_callback_, Call).Times(0); - stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + ReceiveControlMessage(MoqtNamespace{TrackNamespace({"bar"})}); CheckNumberOfObjectsAvailable(1); - stream_.OnNamespaceDoneMessage({TrackNamespace({"bar"})}); + ReceiveControlMessage(MoqtNamespaceDone{TrackNamespace({"bar"})}); CheckNumberOfObjectsAvailable(2); - stream_.OnNamespaceMessage({TrackNamespace({"buzz"})}); + ReceiveControlMessage(MoqtNamespace{TrackNamespace({"buzz"})}); CheckNumberOfObjectsAvailable(3); } TEST_F(MoqtNamespaceSubscriberStreamTest, TaskGetNextSuffix) { EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); - stream_.OnRequestOkMessage({kRequestId}); - stream_.OnNamespaceMessage({TrackNamespace({"bar"})}); + ReceiveControlMessage(MoqtRequestOk{kRequestId}); + ReceiveControlMessage(MoqtNamespace{TrackNamespace({"bar"})}); CheckNumberOfObjectsAvailable(1); - stream_.OnNamespaceMessage({TrackNamespace({"buzz"})}); + ReceiveControlMessage(MoqtNamespace{TrackNamespace({"buzz"})}); CheckNumberOfObjectsAvailable(2); - stream_.OnNamespaceDoneMessage({TrackNamespace({"bar"})}); + ReceiveControlMessage(MoqtNamespaceDone{TrackNamespace({"bar"})}); CheckNumberOfObjectsAvailable(3); TrackNamespace received_namespace; TransactionType type; @@ -189,7 +203,7 @@ EXPECT_EQ(received_namespace, TrackNamespace({"bar"})); EXPECT_EQ(type, TransactionType::kDelete); EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kPending); - stream_.OnNamespaceMessage({TrackNamespace({"another"})}); + ReceiveControlMessage(MoqtNamespace{TrackNamespace({"another"})}); CheckNumberOfObjectsAvailable(4); EXPECT_EQ(task_->GetNextSuffix(received_namespace, type), kSuccess); EXPECT_EQ(received_namespace, TrackNamespace({"another"})); @@ -199,14 +213,16 @@ TEST_F(MoqtNamespaceSubscriberStreamTest, DeclareEof) { auto stream = std::make_unique<MoqtNamespaceSubscriberStream>( - &framer_, kRequestId, deleted_callback_.AsStdFunction(), - error_callback_.AsStdFunction(), response_callback_.AsStdFunction()); + &framer_, ControlMessageParser(), kRequestId, + deleted_callback_.AsStdFunction(), error_callback_.AsStdFunction(), + response_callback_.AsStdFunction()); std::unique_ptr<MoqtNamespaceTask> task = stream->CreateTask(kPrefix); ASSERT_TRUE(task != nullptr); task->SetObjectsAvailableCallback([this]() { ++objects_available_; }); EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); - stream->OnRequestOkMessage({kRequestId}); - stream->OnNamespaceMessage({TrackNamespace({"bar"})}); + QUICHE_EXPECT_OK(stream->OnControlMessage(MoqtRequestOk{kRequestId})); + QUICHE_EXPECT_OK( + stream->OnControlMessage(MoqtNamespace{TrackNamespace({"bar"})})); CheckNumberOfObjectsAvailable(1); stream.reset(); CheckNumberOfObjectsAvailable(2); @@ -220,7 +236,7 @@ TEST_F(MoqtNamespaceSubscriberStreamTest, UpdateAndRequestOk) { EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); - stream_.OnRequestOkMessage({kRequestId}); + ReceiveControlMessage(MoqtRequestOk{kRequestId}); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestUpdate), _)); MessageParameters update_params; @@ -229,12 +245,12 @@ update_response_callback; task_->Update(update_params, update_response_callback.AsStdFunction()); EXPECT_CALL(update_response_callback, Call(Eq(std::nullopt))); - stream_.OnRequestOkMessage({kRequestId + 2}); + ReceiveControlMessage(MoqtRequestOk{kRequestId + 2}); } TEST_F(MoqtNamespaceSubscriberStreamTest, UpdateAndRequestError) { EXPECT_CALL(response_callback_, Call(Eq(std::nullopt))); - stream_.OnRequestOkMessage({kRequestId}); + ReceiveControlMessage(MoqtRequestOk{kRequestId}); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestUpdate), _)); MessageParameters update_params; @@ -243,9 +259,9 @@ update_response_callback; task_->Update(update_params, update_response_callback.AsStdFunction()); EXPECT_CALL(update_response_callback, Call(_)); - stream_.OnRequestErrorMessage( - {kRequestId + 2, RequestErrorCode::kInternalError, - quic::QuicTimeDelta::FromMilliseconds(100), "bar"}); + ReceiveControlMessage( + MoqtRequestError{kRequestId + 2, RequestErrorCode::kInternalError, + quic::QuicTimeDelta::FromMilliseconds(100), "bar"}); } class MoqtNamespacePublisherStreamTest : public quiche::test::QuicheTest { @@ -254,11 +270,18 @@ : framer_(false), tree_(), application_callback_(mock_application_.AsStdFunction()), - stream_(&framer_, &mock_stream_, error_callback_.AsStdFunction(), - &tree_, application_callback_) { + stream_(&framer_, ControlMessageParser(), + error_callback_.AsStdFunction(), &tree_, + application_callback_) { + stream_.BindStream(&mock_stream_); EXPECT_CALL(mock_stream_, CanWrite()).WillRepeatedly(Return(true)); } + template <typename M> + void ReceiveControlMessage(const M& message) { + stream_.CheckStatus(stream_.OnControlMessage(message)); + } + MoqtFramer framer_; testing::MockFunction<void(MoqtError, absl::string_view)> error_callback_; webtransport::test::MockStream mock_stream_; @@ -292,7 +315,7 @@ }); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _)); - stream_.OnSubscribeNamespaceMessage(message); + ReceiveControlMessage(message); ASSERT_TRUE(task_ptr != nullptr); EXPECT_EQ(task_ptr->prefix(), message.track_namespace_prefix); @@ -358,7 +381,7 @@ }); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); - stream_.OnSubscribeNamespaceMessage(message); + ReceiveControlMessage(message); } TEST_F(MoqtNamespacePublisherStreamTest, RequestUpdateOk) { @@ -381,7 +404,7 @@ }); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _)); - stream_.OnSubscribeNamespaceMessage(message); + ReceiveControlMessage(message); ASSERT_TRUE(task_ptr != nullptr); // Now send RequestUpdate @@ -398,7 +421,7 @@ }); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _)); - stream_.OnRequestUpdateMessage(update_message); + ReceiveControlMessage(update_message); } TEST_F(MoqtNamespacePublisherStreamTest, RequestUpdateError) { @@ -421,7 +444,7 @@ }); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _)); - stream_.OnSubscribeNamespaceMessage(message); + ReceiveControlMessage(message); ASSERT_TRUE(task_ptr != nullptr); // Now send RequestUpdate @@ -440,7 +463,7 @@ }); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); - stream_.OnRequestUpdateMessage(update_message); + ReceiveControlMessage(update_message); } TEST_F(MoqtNamespacePublisherStreamTest, SubscribePrefixOverlap) { @@ -454,13 +477,13 @@ tree_.SubscribeNamespace(TrackNamespace({"foo", "bar"})); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); - stream_.OnSubscribeNamespaceMessage(message); + ReceiveControlMessage(message); // Try to subscribe to the parent. Also not allowed. message.track_namespace_prefix.PopElement(); message.track_namespace_prefix.PopElement(); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); - stream_.OnSubscribeNamespaceMessage(message); + ReceiveControlMessage(message); } } // namespace
diff --git a/quiche/quic/moqt/moqt_parser.h b/quiche/quic/moqt/moqt_parser.h index 35e3c2f..893f9d6 100644 --- a/quiche/quic/moqt/moqt_parser.h +++ b/quiche/quic/moqt/moqt_parser.h
@@ -15,7 +15,6 @@ #include "absl/base/nullability.h" #include "absl/cleanup/cleanup.h" -#include "absl/functional/overload.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -27,7 +26,6 @@ #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_priority.h" -#include "quiche/quic/moqt/moqt_session_interface.h" #include "quiche/common/platform/api/quiche_export.h" #include "quiche/common/quiche_callbacks.h" #include "quiche/common/quiche_status_utils.h" @@ -39,45 +37,6 @@ class MoqtDataParserPeer; } -// TODO(vasilvv): remove once all uses are switched to a new parser. -class QUICHE_EXPORT MoqtControlParserVisitor { - public: - virtual ~MoqtControlParserVisitor() = default; - - // All of these are called only when the entire message has arrived. The - // parser retains ownership of the memory. - virtual void OnClientSetupMessage(const MoqtClientSetup& message) = 0; - virtual void OnServerSetupMessage(const MoqtServerSetup& message) = 0; - virtual void OnRequestOkMessage(const MoqtRequestOk& message) = 0; - virtual void OnRequestErrorMessage(const MoqtRequestError& message) = 0; - virtual void OnSubscribeMessage(const MoqtSubscribe& message) = 0; - virtual void OnSubscribeOkMessage(const MoqtSubscribeOk& message) = 0; - virtual void OnUnsubscribeMessage(const MoqtUnsubscribe& message) = 0; - virtual void OnPublishDoneMessage(const MoqtPublishDone& message) = 0; - virtual void OnRequestUpdateMessage(const MoqtRequestUpdate& message) = 0; - virtual void OnPublishNamespaceMessage( - const MoqtPublishNamespace& message) = 0; - virtual void OnPublishNamespaceDoneMessage( - const MoqtPublishNamespaceDone& message) = 0; - virtual void OnNamespaceMessage(const MoqtNamespace& message) = 0; - virtual void OnNamespaceDoneMessage(const MoqtNamespaceDone& message) = 0; - virtual void OnPublishNamespaceCancelMessage( - const MoqtPublishNamespaceCancel& message) = 0; - virtual void OnTrackStatusMessage(const MoqtTrackStatus& message) = 0; - virtual void OnGoAwayMessage(const MoqtGoAway& message) = 0; - virtual void OnSubscribeNamespaceMessage( - const MoqtSubscribeNamespace& message) = 0; - virtual void OnMaxRequestIdMessage(const MoqtMaxRequestId& message) = 0; - virtual void OnFetchMessage(const MoqtFetch& message) = 0; - virtual void OnFetchCancelMessage(const MoqtFetchCancel& message) = 0; - virtual void OnFetchOkMessage(const MoqtFetchOk& message) = 0; - virtual void OnRequestsBlockedMessage(const MoqtRequestsBlocked& message) = 0; - virtual void OnPublishMessage(const MoqtPublish& message) = 0; - virtual void OnObjectAckMessage(const MoqtObjectAck& message) = 0; - - virtual void OnParsingError(MoqtError code, absl::string_view reason) = 0; -}; - // MoqtRawControlMessage represents an MOQT control message that has been // unframed from the control stream, but not parsed yet. struct MoqtRawControlMessage { @@ -117,11 +76,6 @@ MoqtControlStreamParser& operator=(const MoqtControlStreamParser&) = delete; MoqtControlStreamParser& operator=(MoqtControlStreamParser&&) = delete; - // TODO(vasilvv): remove once nothing calls this. - void set_message_type(uint64_t message_type) { - current_message_type_ = message_type; - } - // Reads the next available message on the stream. Returns kUnavailable // status if no complete message can be read; if FIN is read, `fin_read` will // be set to true. @@ -297,150 +251,6 @@ bool uses_web_transport_; }; -class QUICHE_EXPORT MoqtControlParser { - public: - MoqtControlParser(bool uses_web_transport, webtransport::Stream* stream, - MoqtControlParserVisitor& visitor) - : visitor_(visitor), - stream_parser_(stream), - message_parser_(kDefaultMoqtVersion, uses_web_transport) {} - ~MoqtControlParser() = default; - void set_message_type(uint64_t message_type) { - stream_parser_.set_message_type(message_type); - } - - void ReadAndDispatchMessages() { - if (processing_) { - return; - } - processing_ = true; - auto cleanup = absl::MakeCleanup([this] { processing_ = false; }); - while (true) { - absl::StatusOr<MoqtRawControlMessage> raw_message = - stream_parser_.ReadNextMessage(); - if (absl::IsUnavailable(raw_message.status())) { - return; - } - if (!raw_message.ok()) { - visitor_.OnParsingError(GetMoqtErrorForStatus(raw_message.status()) - .value_or(MoqtError::kProtocolViolation), - raw_message.status().message()); - return; - } - absl::Status status = message_parser_.ParseMessage( - *raw_message, - absl::Overload{[&](const MoqtClientSetup& message) { - visitor_.OnClientSetupMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtServerSetup& message) { - visitor_.OnServerSetupMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtRequestOk& message) { - visitor_.OnRequestOkMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtRequestError& message) { - visitor_.OnRequestErrorMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtSubscribe& message) { - visitor_.OnSubscribeMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtSubscribeOk& message) { - visitor_.OnSubscribeOkMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtUnsubscribe& message) { - visitor_.OnUnsubscribeMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtPublishDone& message) { - visitor_.OnPublishDoneMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtRequestUpdate& message) { - visitor_.OnRequestUpdateMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtPublishNamespace& message) { - visitor_.OnPublishNamespaceMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtPublishNamespaceDone& message) { - visitor_.OnPublishNamespaceDoneMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtNamespace& message) { - visitor_.OnNamespaceMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtNamespaceDone& message) { - visitor_.OnNamespaceDoneMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtPublishNamespaceCancel& message) { - visitor_.OnPublishNamespaceCancelMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtTrackStatus& message) { - visitor_.OnTrackStatusMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtGoAway& message) { - visitor_.OnGoAwayMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtSubscribeNamespace& message) { - visitor_.OnSubscribeNamespaceMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtMaxRequestId& message) { - visitor_.OnMaxRequestIdMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtFetch& message) { - visitor_.OnFetchMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtFetchCancel& message) { - visitor_.OnFetchCancelMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtFetchOk& message) { - visitor_.OnFetchOkMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtRequestsBlocked& message) { - visitor_.OnRequestsBlockedMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtPublish& message) { - visitor_.OnPublishMessage(message); - return absl::OkStatus(); - }, - [&](const MoqtObjectAck& message) { - visitor_.OnObjectAckMessage(message); - return absl::OkStatus(); - }}); - if (!status.ok()) { - visitor_.OnParsingError(GetMoqtErrorForStatus(status).value_or( - MoqtError::kProtocolViolation), - status.message()); - return; - } - } - } - - private: - MoqtControlParserVisitor& visitor_; - MoqtControlStreamParser stream_parser_; - MoqtControlMessageParser message_parser_; - bool processing_ = false; -}; - // Parses an MoQT datagram. Returns the payload bytes, or std::nullopt on error. // The caller provides the whole datagram in `data`. The function puts the // object metadata in `object_metadata`.
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index f48d4c1..35d6555 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -51,9 +51,12 @@ #include "quiche/quic/platform/api/quic_logging.h" #include "quiche/common/platform/api/quiche_bug_tracker.h" #include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_stack_trace.h" #include "quiche/common/quiche_buffer_allocator.h" #include "quiche/common/quiche_mem_slice.h" +#include "quiche/common/quiche_status_utils.h" #include "quiche/common/quiche_weak_ptr.h" +#include "quiche/web_transport/stream_helpers.h" #include "quiche/web_transport/web_transport.h" #define ENDPOINT \ @@ -122,7 +125,7 @@ "while it does not exist"; return; } - control_stream->SendOrBufferMessage(std::move(message)); + control_stream->SendOrBufferMessageOrFatal(std::move(message)); } void MoqtSession::OnSessionReady() { @@ -147,7 +150,7 @@ return; } control_stream_ = control_stream->GetWeakPtr(); - control_stream->set_stream(stream); + control_stream->BindStream(stream); trace_recorder_.RecordControlStreamCreated(stream->GetStreamId()); stream->SetVisitor(std::move(control_stream)); MoqtClientSetup setup; @@ -235,7 +238,7 @@ while (!pending_bidi_streams_.empty() && session_->CanOpenNextOutgoingBidirectionalStream()) { webtransport::Stream* stream = session_->OpenOutgoingBidirectionalStream(); - pending_bidi_streams_.front()->set_stream(stream); + pending_bidi_streams_.front()->BindStream(stream); // TODO(vasilvv): Distinguish between control and and non-control bidi // streams in trace_recorder_. trace_recorder_.RecordControlStreamCreated(stream->GetStreamId()); @@ -301,7 +304,7 @@ } std::unique_ptr<MoqtNamespaceSubscriberStream> state = std::make_unique<MoqtNamespaceSubscriberStream>( - &framer_, next_request_id_, + &framer_, ControlMessageParser(), next_request_id_, [session_weak_ptr = GetWeakPtr(), this, pref = prefix]() { if (!session_weak_ptr.IsValid() || is_closing_) { return; @@ -319,7 +322,7 @@ MoqtNamespaceSubscriberStream* state_ptr = state.get(); if (session_->CanOpenNextOutgoingBidirectionalStream()) { webtransport::Stream* stream = session_->OpenOutgoingBidirectionalStream(); - state->set_stream(stream); + state->BindStream(stream); stream->SetVisitor(std::move(state)); } else { pending_bidi_streams_.push_back(std::move(state)); @@ -329,7 +332,8 @@ message.track_namespace_prefix = prefix; message.subscribe_options = SubscribeNamespaceOption::kNamespace; message.parameters = parameters; - state_ptr->SendOrBufferMessage(framer_.SerializeSubscribeNamespace(message)); + state_ptr->SendOrBufferMessageOrFatal( + framer_.SerializeSubscribeNamespace(message)); QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE_NAMESPACE message for " << message.track_namespace_prefix; return state_ptr->CreateTask(prefix); @@ -955,11 +959,11 @@ return; } auto control_stream = std::make_unique<ControlStream>(session_); + control_stream->BindStream(std::move(parser_)); // Store a reference to the stream context when the current context is // destroyed below. ControlStream* temp_stream = control_stream.get(); session_->control_stream_ = temp_stream->GetWeakPtr(); - control_stream->set_stream(stream_); // Deletes the UnknownBidiStream object; no class access after this // point. stream_->SetVisitor(std::move(control_stream)); @@ -968,12 +972,13 @@ } case MoqtMessageType::kSubscribeNamespace: { auto namespace_stream = std::make_unique<MoqtNamespacePublisherStream>( - &session_->framer_, stream_, + &session_->framer_, session_->ControlMessageParser(), [session = session_](MoqtError code, absl::string_view reason) { session->Error(code, reason); }, &session_->incoming_subscribe_namespace_, session_->callbacks_.incoming_subscribe_namespace_callback); + namespace_stream->BindStream(std::move(parser_)); MoqtNamespacePublisherStream* temp_stream = namespace_stream.get(); stream_->SetVisitor(std::move(namespace_stream)); // The UnknownBidiStream object is deleted; no class access after this @@ -988,24 +993,21 @@ } } -void MoqtSession::ControlStream::set_stream( - webtransport::Stream* absl_nonnull stream) { - stream->SetPriority( +void MoqtSession::ControlStream::OnStreamBound() { + stream()->SetPriority( webtransport::StreamPriority{/*send_group_id=*/kMoqtSendGroupId, /*send_order=*/kMoqtControlStreamSendOrder}); - if (session_->perspective() == Perspective::IS_SERVER) { - MoqtBidiStreamBase::set_stream(stream, MoqtMessageType::kClientSetup); - } else { - MoqtBidiStreamBase::set_stream(stream, std::nullopt); - } } -void MoqtSession::ControlStream::OnClientSetupMessage( +absl::Status MoqtSession::ControlStream::OnRawControlMessage( + const MoqtRawControlMessage& message) { + return DispatchControlMessage<ControlStream>(message, "control"); +} + +absl::Status MoqtSession::ControlStream::OnControlMessage( const MoqtClientSetup& message) { if (session_->perspective() == Perspective::IS_CLIENT) { - session_->Error(MoqtError::kProtocolViolation, - "Received CLIENT_SETUP from server"); - return; + return absl::InvalidArgumentError("Received CLIENT_SETUP from server"); } session_->peer_supports_object_ack_ = message.parameters.support_object_acks.value_or( @@ -1015,18 +1017,18 @@ QUICHE_DLOG(INFO) << "Received CLIENT_SETUP"; MoqtServerSetup response; session_->parameters_.ToSetupParameters(response.parameters); - SendOrBufferMessage(session_->framer_.SerializeServerSetup(response)); + QUICHE_RETURN_IF_ERROR( + SendOrBufferMessage(session_->framer_.SerializeServerSetup(response))); QUICHE_DLOG(INFO) << "Sent SERVER_SETUP"; // TODO: handle path. std::move(session_->callbacks_.session_established_callback)(); + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnServerSetupMessage( +absl::Status MoqtSession::ControlStream::OnControlMessage( const MoqtServerSetup& message) { if (perspective() == Perspective::IS_SERVER) { - session_->Error(MoqtError::kProtocolViolation, - "Received SERVER_SETUP from client"); - return; + return absl::InvalidArgumentError("Received SERVER_SETUP from client"); } session_->peer_supports_object_ack_ = message.parameters.support_object_acks.value_or( @@ -1036,26 +1038,25 @@ session_->peer_max_request_id_ = message.parameters.max_request_id.value_or(kDefaultMaxRequestId); std::move(session_->callbacks_.session_established_callback)(); + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnSubscribeMessage( +absl::Status MoqtSession::ControlStream::OnControlMessage( const MoqtSubscribe& message) { if (!session_->ValidateRequestId(message.request_id)) { - return; + return absl::OkStatus(); } QUIC_DLOG(INFO) << ENDPOINT << "Received a SUBSCRIBE for " << message.full_track_name; if (session_->sent_goaway_) { QUIC_DLOG(INFO) << ENDPOINT << "Received a SUBSCRIBE after GOAWAY"; - SendRequestError(message.request_id, RequestErrorCode::kUnauthorized, - std::nullopt, "SUBSCRIBE after GOAWAY"); - return; + return SendRequestError(message.request_id, RequestErrorCode::kUnauthorized, + std::nullopt, "SUBSCRIBE after GOAWAY"); } if (session_->subscribed_track_names_.contains(message.full_track_name)) { - SendRequestError(message.request_id, - RequestErrorCode::kDuplicateSubscription, std::nullopt, - ""); - return; + return SendRequestError(message.request_id, + RequestErrorCode::kDuplicateSubscription, + std::nullopt, ""); } const FullTrackName& track_name = message.full_track_name; std::shared_ptr<MoqtTrackPublisher> track_publisher = @@ -1063,9 +1064,8 @@ if (track_publisher == nullptr) { QUIC_DLOG(INFO) << ENDPOINT << "SUBSCRIBE for " << track_name << " rejected by the application: does not exist"; - SendRequestError(message.request_id, RequestErrorCode::kDoesNotExist, - std::nullopt, "not found"); - return; + return SendRequestError(message.request_id, RequestErrorCode::kDoesNotExist, + std::nullopt, "not found"); } MoqtPublishingMonitorInterface* monitoring = nullptr; @@ -1088,9 +1088,10 @@ QUICHE_NOTREACHED(); // ValidateRequestId() should have caught this. } track_publisher_ptr->AddObjectListener(subscription_ptr); + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnSubscribeOkMessage( +absl::Status MoqtSession::ControlStream::OnControlMessage( const MoqtSubscribeOk& message) { RemoteTrack* track = session_->RemoteTrackById(message.request_id); if (track == nullptr) { @@ -1098,12 +1099,10 @@ << "request_id = " << message.request_id << " but no track exists"; // Subscription state might have been destroyed for internal reasons. - return; + return absl::OkStatus(); } if (track->is_fetch()) { - session_->Error(MoqtError::kProtocolViolation, - "Received SUBSCRIBE_OK for a FETCH"); - return; + return absl::InvalidArgumentError("Received SUBSCRIBE_OK for a FETCH"); } if (message.parameters.largest_object.has_value()) { QUIC_DLOG(INFO) << ENDPOINT << "Received the SUBSCRIBE_OK for " @@ -1122,7 +1121,7 @@ session_->subscribe_by_alias_.try_emplace(message.track_alias, subscribe); if (!success) { session_->Error(MoqtError::kDuplicateTrackAlias, ""); - return; + return absl::OkStatus(); } subscribe->set_track_alias(message.track_alias); std::optional<SubscriptionFilter> filter = @@ -1146,14 +1145,14 @@ track->full_track_name(), SubscribeOkData{message.parameters, message.extensions}); } + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnRequestOkMessage( +absl::Status MoqtSession::ControlStream::OnControlMessage( const MoqtRequestOk& message) { if (session_->upstream_by_id_.contains(message.request_id)) { - session_->Error(MoqtError::kProtocolViolation, - "Received REQUEST_OK for SUBSCRIBE, FETCH, or PUBLISH"); - return; + return absl::InvalidArgumentError( + "Received REQUEST_OK for SUBSCRIBE, FETCH, or PUBLISH"); } // Response to REQUEST_UPDATE for a subscribe. auto ru_it = session_->pending_subscribe_updates_.find(message.request_id); @@ -1164,23 +1163,22 @@ MoqtRequestErrorInfo{RequestErrorCode::kDoesNotExist, std::nullopt, "subscription does not exist anymore"}); session_->pending_subscribe_updates_.erase(ru_it); - return; + return absl::OkStatus(); } sub_it->second->parameters().Update(ru_it->second.parameters); std::move(ru_it->second.response_callback)(std::nullopt); session_->pending_subscribe_updates_.erase(ru_it); - return; + return absl::OkStatus(); } // Response to PUBLISH_NAMESPACE. auto pn_it = session_->publish_namespace_by_id_.find(message.request_id); if (pn_it != session_->publish_namespace_by_id_.end()) { if (pn_it->second.response_callback == nullptr) { - session_->Error(MoqtError::kProtocolViolation, - "Multiple responses for PUBLISH_NAMESPACE"); - return; + return absl::InvalidArgumentError( + "Multiple responses for PUBLISH_NAMESPACE"); } std::move(pn_it->second.response_callback)(std::nullopt); - return; + return absl::OkStatus(); } // Response to SUBSCRIBE_NAMESPACE is handled in the NamespaceStream. // TRACK_STATUS response would go here, but we don't support upstream @@ -1188,9 +1186,10 @@ // If it doesn't match any state, it might be because the local application // cancelled the request. Do nothing. // TODO(martinduke): Do something with parameters. + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnRequestErrorMessage( +absl::Status MoqtSession::ControlStream::OnControlMessage( const MoqtRequestError& message) { MoqtRequestErrorInfo error_info{message.error_code, message.retry_interval, message.reason_phrase}; @@ -1199,9 +1198,8 @@ if (track != nullptr) { // It's in response to SUBSCRIBE or FETCH. if (!track->ErrorIsAllowed()) { - session_->Error(MoqtError::kProtocolViolation, - "Received REQUEST_ERROR after REQUEST_OK or objects"); - return; + return absl::InvalidArgumentError( + "Received REQUEST_ERROR after REQUEST_OK or objects"); } QUIC_DLOG(INFO) << ENDPOINT << "Received the REQUEST_ERROR for " << "request_id = " << message.request_id << " (" @@ -1228,52 +1226,53 @@ // The visitor might have closed the session. session_->upstream_by_id_.erase(message.request_id); } - return; + return absl::OkStatus(); } // Response to REQUEST_UPDATE for a subscribe. auto ru_it = session_->pending_subscribe_updates_.find(message.request_id); if (ru_it != session_->pending_subscribe_updates_.end()) { std::move(ru_it->second.response_callback)(error_info); session_->pending_subscribe_updates_.erase(ru_it); - return; + return absl::OkStatus(); } // Response to PUBLISH_NAMESPACE. auto pn_it = session_->publish_namespace_by_id_.find(message.request_id); if (pn_it != session_->publish_namespace_by_id_.end()) { if (pn_it->second.response_callback == nullptr) { - session_->Error(MoqtError::kProtocolViolation, - "Multiple responses for PUBLISH_NAMESPACE"); - return; + return absl::InvalidArgumentError( + "Multiple responses for PUBLISH_NAMESPACE"); } std::move(pn_it->second.response_callback)(error_info); session_->publish_namespace_by_namespace_.erase( pn_it->second.track_namespace); session_->publish_namespace_by_id_.erase(pn_it); - return; + return absl::OkStatus(); } // Response to SUBSCRIBE_NAMESPACE is handled in the NamespaceStream. // TRACK_STATUS response would go here, but we don't support upstream // TRACK_STATUS. // If it doesn't match any state, it might be because the local application // cancelled the request. Do nothing. + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnUnsubscribeMessage( +absl::Status MoqtSession::ControlStream::OnControlMessage( const MoqtUnsubscribe& message) { auto it = session_->published_subscriptions_.find(message.request_id); if (it == session_->published_subscriptions_.end()) { - return; + return absl::OkStatus(); } QUIC_DLOG(INFO) << ENDPOINT << "Received an UNSUBSCRIBE for " << it->second->publisher().GetTrackName(); session_->published_subscriptions_.erase(it); + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnPublishDoneMessage( +absl::Status MoqtSession::ControlStream::OnControlMessage( const MoqtPublishDone& message) { auto it = session_->upstream_by_id_.find(message.request_id); if (it == session_->upstream_by_id_.end()) { - return; + return absl::OkStatus(); } auto* subscribe = absl::down_cast<SubscribeRemoteTrack*>(it->second.get()); QUIC_DLOG(INFO) << ENDPOINT << "Received a PUBLISH_DONE for " @@ -1283,17 +1282,17 @@ absl::WrapUnique(session_->alarm_factory_->CreateAlarm( new PublishDoneDelegate(session_, subscribe)))); session_->MaybeDestroySubscription(subscribe); + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnRequestUpdateMessage( +absl::Status MoqtSession::ControlStream::OnControlMessage( const MoqtRequestUpdate& message) { auto it = session_->published_subscriptions_.find(message.existing_request_id); if (it != session_->published_subscriptions_.end()) { // It's updating SUBSCRIBE. it->second->Update(message.parameters); - SendRequestOk(message.request_id, MessageParameters()); - return; + return SendRequestOk(message.request_id, MessageParameters()); } auto pn_it = session_->publish_namespace_by_id_.find(message.existing_request_id); @@ -1311,33 +1310,32 @@ return; } if (error.has_value()) { - SendRequestError(message.request_id, *error); + CheckStatus(SendRequestError(message.request_id, *error)); session->incoming_publish_namespaces_by_id_.erase( message.request_id); session->incoming_publish_namespaces_by_namespace_.erase( track_namespace); } else { - SendRequestOk(message.request_id, MessageParameters()); + CheckStatus(SendRequestOk(message.request_id, MessageParameters())); } }); - return; + return absl::OkStatus(); } // TODO(martinduke): Check all the request types. // Does not match any known request. - SendRequestError(message.request_id, RequestErrorCode::kNotSupported, - std::nullopt, "No support for update of this type"); + return SendRequestError(message.request_id, RequestErrorCode::kNotSupported, + std::nullopt, "No support for update of this type"); } -void MoqtSession::ControlStream::OnPublishNamespaceMessage( +absl::Status MoqtSession::ControlStream::OnControlMessage( const MoqtPublishNamespace& message) { if (!session_->ValidateRequestId(message.request_id)) { - return; + return absl::OkStatus(); } if (session_->sent_goaway_) { QUIC_DLOG(INFO) << ENDPOINT << "Received a PUBLISH_NAMESPACE after GOAWAY"; - SendRequestError(message.request_id, RequestErrorCode::kUnauthorized, - std::nullopt, "PUBLISH_NAMESPACE after GOAWAY"); - return; + return SendRequestError(message.request_id, RequestErrorCode::kUnauthorized, + std::nullopt, "PUBLISH_NAMESPACE after GOAWAY"); } QUIC_DLOG(INFO) << ENDPOINT << "Received a PUBLISH_NAMESPACE for " << message.track_namespace; @@ -1345,10 +1343,9 @@ session_->incoming_publish_namespaces_by_namespace_.emplace( message.track_namespace, message.request_id); if (!inserted) { - SendRequestError(message.request_id, - RequestErrorCode::kDuplicateSubscription, std::nullopt, - "Duplicate PUBLISH_NAMESPACE"); - return; + return SendRequestError(message.request_id, + RequestErrorCode::kDuplicateSubscription, + std::nullopt, "Duplicate PUBLISH_NAMESPACE"); } quiche::QuicheWeakPtr<MoqtSessionInterface> session_weakptr = session_->GetWeakPtr(); @@ -1363,108 +1360,110 @@ return; } if (error.has_value()) { - SendRequestError(message.request_id, *error); + CheckStatus(SendRequestError(message.request_id, *error)); session->incoming_publish_namespaces_by_id_.erase(message.request_id); session->incoming_publish_namespaces_by_namespace_.erase( message.track_namespace); } else { - SendRequestOk(message.request_id, MessageParameters()); + CheckStatus(SendRequestOk(message.request_id, MessageParameters())); } }); + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnPublishNamespaceDoneMessage( +absl::Status MoqtSession::ControlStream::OnControlMessage( const MoqtPublishNamespaceDone& message) { auto it = session_->incoming_publish_namespaces_by_id_.find(message.request_id); if (it == session_->incoming_publish_namespaces_by_id_.end()) { - return; + return absl::OkStatus(); } session_->callbacks_.incoming_publish_namespace_callback( it->second, std::nullopt, nullptr); session_->incoming_publish_namespaces_by_namespace_.erase(it->second); session_->incoming_publish_namespaces_by_id_.erase(it); + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnPublishNamespaceCancelMessage( +absl::Status MoqtSession::ControlStream::OnControlMessage( const MoqtPublishNamespaceCancel& message) { auto it = session_->publish_namespace_by_id_.find(message.request_id); if (it == session_->publish_namespace_by_id_.end()) { - return; // State might have been destroyed due to PUBLISH_NAMESPACE_DONE. + return absl::OkStatus(); // State might have been destroyed due to + // PUBLISH_NAMESPACE_DONE. } std::move(it->second.cancel_callback)(MoqtRequestErrorInfo{ message.error_code, std::nullopt, std::string(message.error_reason)}); session_->publish_namespace_by_namespace_.erase(it->second.track_namespace); session_->publish_namespace_by_id_.erase(it); + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnTrackStatusMessage( +absl::Status MoqtSession::ControlStream::OnControlMessage( const MoqtTrackStatus& message) { if (!session_->ValidateRequestId(message.request_id)) { - return; + return absl::OkStatus(); } if (session_->sent_goaway_) { QUIC_DLOG(INFO) << ENDPOINT << "Received a TRACK_STATUS_REQUEST after GOAWAY"; - SendRequestError(message.request_id, RequestErrorCode::kUnauthorized, - std::nullopt, "TRACK_STATUS_REQUEST after GOAWAY"); - return; + return SendRequestError(message.request_id, RequestErrorCode::kUnauthorized, + std::nullopt, "TRACK_STATUS_REQUEST after GOAWAY"); } // TODO(martinduke): Handle authentication. std::shared_ptr<MoqtTrackPublisher> track = session_->publisher_->GetTrack(message.full_track_name); if (track == nullptr) { - SendRequestError(message.request_id, RequestErrorCode::kDoesNotExist, - std::nullopt, "Track does not exist"); - return; + return SendRequestError(message.request_id, RequestErrorCode::kDoesNotExist, + std::nullopt, "Track does not exist"); } auto [it, inserted] = session_->incoming_track_status_.emplace( message.request_id, std::make_unique<DownstreamTrackStatus>( message.request_id, session_, track.get())); track->AddObjectListener(it->second.get()); + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnGoAwayMessage(const MoqtGoAway& message) { +absl::Status MoqtSession::ControlStream::OnControlMessage( + const MoqtGoAway& message) { if (!message.new_session_uri.empty() && perspective() == quic::Perspective::IS_SERVER) { - session_->Error(MoqtError::kProtocolViolation, - "Received GOAWAY with new_session_uri on the server"); - return; + return absl::InvalidArgumentError( + "Received GOAWAY with new_session_uri on the server"); } if (session_->received_goaway_) { - session_->Error(MoqtError::kProtocolViolation, - "Received multiple GOAWAY messages"); - return; + return absl::InvalidArgumentError("Received multiple GOAWAY messages"); } session_->received_goaway_ = true; if (session_->callbacks_.goaway_received_callback != nullptr) { std::move(session_->callbacks_.goaway_received_callback)( message.new_session_uri); } + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnMaxRequestIdMessage( +absl::Status MoqtSession::ControlStream::OnControlMessage( const MoqtMaxRequestId& message) { if (message.max_request_id < session_->peer_max_request_id_) { QUIC_DLOG(INFO) << ENDPOINT << "Peer sent MAX_REQUEST_ID message with " "lower value than previous"; - session_->Error(MoqtError::kProtocolViolation, - "MAX_REQUEST_ID has lower value than previous"); - return; + return absl::InvalidArgumentError( + "MAX_REQUEST_ID has lower value than previous"); } session_->peer_max_request_id_ = message.max_request_id; + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnFetchMessage(const MoqtFetch& message) { +absl::Status MoqtSession::ControlStream::OnControlMessage( + const MoqtFetch& message) { if (!session_->ValidateRequestId(message.request_id)) { - return; + return absl::OkStatus(); } if (session_->sent_goaway_) { QUIC_DLOG(INFO) << ENDPOINT << "Received a FETCH after GOAWAY"; - SendRequestError(message.request_id, RequestErrorCode::kUnauthorized, - std::nullopt, "FETCH after GOAWAY"); - return; + return SendRequestError(message.request_id, RequestErrorCode::kUnauthorized, + std::nullopt, "FETCH after GOAWAY"); } std::unique_ptr<MoqtFetchTask> fetch; FullTrackName track_name; @@ -1477,9 +1476,9 @@ if (track_publisher == nullptr) { QUIC_DLOG(INFO) << ENDPOINT << "FETCH for " << track_name << " rejected by the application: not found"; - SendRequestError(message.request_id, RequestErrorCode::kDoesNotExist, - std::nullopt, "not found"); - return; + return SendRequestError(message.request_id, + RequestErrorCode::kDoesNotExist, std::nullopt, + "not found"); } QUIC_DLOG(INFO) << ENDPOINT << "Received a StandaloneFETCH for " << track_name; @@ -1500,26 +1499,24 @@ QUIC_DLOG(INFO) << ENDPOINT << "Received a JOINING_FETCH for " << "request_id " << joining_request_id << " that does not exist"; - SendRequestError(message.request_id, - RequestErrorCode::kInvalidJoiningRequestId, std::nullopt, - "Joining Fetch for non-existent request"); - return; + return SendRequestError( + message.request_id, RequestErrorCode::kInvalidJoiningRequestId, + std::nullopt, "Joining Fetch for non-existent request"); } if (!it->second->can_have_joining_fetch()) { QUIC_DLOG(INFO) << ENDPOINT << "Received a JOINING_FETCH for " << "joining_request_id " << joining_request_id << " that is not forwarding"; - session_->Error(MoqtError::kProtocolViolation, - "Joining Fetch for non-forwarding subscribe"); - return; + return absl::InvalidArgumentError( + "Joining Fetch for non-forwarding subscribe"); } track_name = it->second->publisher().GetTrackName(); if (it->second->established()) { if (!it->second->parameters().largest_object.has_value()) { // Nothing to Fetch. - SendRequestError(message.request_id, RequestErrorCode::kDoesNotExist, - std::nullopt, "not found"); - return; + return SendRequestError(message.request_id, + RequestErrorCode::kDoesNotExist, std::nullopt, + "not found"); } const Location largest_location = *it->second->parameters().largest_object; @@ -1536,9 +1533,9 @@ std::get<JoiningFetchAbsolute>(message.fetch); start_group = absolute_fetch.joining_start; if (start_group > largest_location.group) { - SendRequestError(message.request_id, RequestErrorCode::kInvalidRange, - std::nullopt, "invalid range"); - return; + return SendRequestError(message.request_id, + RequestErrorCode::kInvalidRange, std::nullopt, + "invalid range"); } } fetch = it->second->publisher().StandaloneFetch( @@ -1563,9 +1560,8 @@ if (!fetch->GetStatus().ok()) { QUIC_DLOG(INFO) << ENDPOINT << "FETCH for " << track_name << " could not initialize the task"; - SendRequestError(message.request_id, RequestErrorCode::kInvalidRange, - std::nullopt, fetch->GetStatus().message()); - return; + return SendRequestError(message.request_id, RequestErrorCode::kInvalidRange, + std::nullopt, fetch->GetStatus().message()); } auto published_fetch = std::make_unique<PublishedFetch>( message.request_id, session_, std::move(fetch)); @@ -1574,8 +1570,9 @@ if (!result.second) { // Emplace failed. QUIC_DLOG(INFO) << ENDPOINT << "FETCH for " << track_name << " could not be added to the session"; - SendRequestError(message.request_id, RequestErrorCode::kInternalError, - std::nullopt, "Could not initialize FETCH state"); + return SendRequestError(message.request_id, + RequestErrorCode::kInternalError, std::nullopt, + "Could not initialize FETCH state"); } MoqtFetchTask* fetch_task = result.first->second->fetch_task(); fetch_task->SetFetchResponseCallback( @@ -1587,13 +1584,14 @@ if (std::holds_alternative<MoqtFetchOk>(message)) { MoqtFetchOk& fetch_ok = std::get<MoqtFetchOk>(message); fetch_ok.request_id = request_id; - SendOrBufferMessage(session_->framer_.SerializeFetchOk(fetch_ok)); + CheckStatus(SendOrBufferMessage( + session_->framer_.SerializeFetchOk(fetch_ok))); return; } - SendRequestError(request_id, - std::get<MoqtRequestError>(message).error_code, - std::get<MoqtRequestError>(message).retry_interval, - std::get<MoqtRequestError>(message).reason_phrase); + CheckStatus(SendRequestError( + request_id, std::get<MoqtRequestError>(message).error_code, + std::get<MoqtRequestError>(message).retry_interval, + std::get<MoqtRequestError>(message).reason_phrase)); }); // Set a temporary new-object callback that creates a data stream. When // created, the stream visitor will replace this callback. @@ -1618,21 +1616,21 @@ } } }); + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnFetchOkMessage(const MoqtFetchOk& message) { +absl::Status MoqtSession::ControlStream::OnControlMessage( + const MoqtFetchOk& message) { RemoteTrack* track = session_->RemoteTrackById(message.request_id); if (track == nullptr) { QUIC_DLOG(INFO) << ENDPOINT << "Received the FETCH_OK for " << "request_id = " << message.request_id << " but no track exists"; // Subscription state might have been destroyed for internal reasons. - return; + return absl::OkStatus(); } if (!track->is_fetch()) { - session_->Error(MoqtError::kProtocolViolation, - "Received FETCH_OK for a SUBSCRIBE"); - return; + return absl::InvalidArgumentError("Received FETCH_OK for a SUBSCRIBE"); } QUIC_DLOG(INFO) << ENDPOINT << "Received the FETCH_OK for request_id = " << message.request_id << " " << track->full_track_name(); @@ -1640,16 +1638,19 @@ fetch->OnFetchResult( message.end_location, absl::OkStatus(), [=, session = session_]() { session->CancelFetch(message.request_id); }); + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnRequestsBlockedMessage( +absl::Status MoqtSession::ControlStream::OnControlMessage( const MoqtRequestsBlocked& message) { // TODO(martinduke): Derive logic for granting more subscribes. + return absl::OkStatus(); } -void MoqtSession::ControlStream::OnPublishMessage(const MoqtPublish& message) { +absl::Status MoqtSession::ControlStream::OnControlMessage( + const MoqtPublish& message) { if (!session_->ValidateRequestId(message.request_id)) { - return; + return absl::OkStatus(); } RequestErrorCode error_code = session_->sent_goaway_ ? RequestErrorCode::kUnauthorized @@ -1657,7 +1658,8 @@ absl::string_view error_reason = session_->sent_goaway_ ? "Received a PUBLISH after GOAWAY" : "PUBLISH is not supported"; - SendRequestError(message.request_id, error_code, std::nullopt, error_reason); + return SendRequestError(message.request_id, error_code, std::nullopt, + error_reason); } void MoqtSession::IncomingDataStream::OnObjectMessage(const MoqtObject& message, @@ -2015,7 +2017,7 @@ // publisher. default_publisher_priority_ = subscribe_ok.extensions.default_publisher_priority(); - stream->SendOrBufferMessage( + stream->SendOrBufferMessageOrFatal( session_->framer_.SerializeSubscribeOk(subscribe_ok)); // TODO(martinduke): If we buffer objects that arrived previously, the arrival // of the track alias disambiguates what subscription they belong to. Send @@ -2024,7 +2026,9 @@ void MoqtSession::PublishedSubscription::OnSubscribeRejected( MoqtRequestErrorInfo info) { - session_->GetControlStream()->SendRequestError(request_id_, info); + ControlStream* control_stream = session_->GetControlStream(); + control_stream->CheckStatus( + control_stream->SendRequestError(request_id_, info)); session_->published_subscriptions_.erase(request_id_); // No class access below this line! } @@ -2601,7 +2605,7 @@ } MoqtFetchCancel message; message.request_id = request_id; - stream->SendOrBufferMessage(framer_.SerializeFetchCancel(message)); + stream->SendOrBufferMessageOrFatal(framer_.SerializeFetchCancel(message)); // The FETCH_CANCEL will cause a RESET_STREAM to return, which would be the // same as a STOP_SENDING. However, a FETCH_CANCEL works even if the stream // hasn't opened yet.
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index 7c8d30e..cb00379 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -18,6 +18,7 @@ #include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "quiche/quic/core/quic_alarm.h" #include "quiche/quic/core/quic_alarm_factory.h" @@ -242,7 +243,7 @@ public: explicit ControlStream(MoqtSession* session) : MoqtBidiStreamBase( - &session->framer_, + &session->framer_, session->ControlMessageParser(), // Do nothing on deletion. It threw an error on RESET_STREAM or // FIN, and we're here because the session is being destroyed. []() {}, @@ -255,39 +256,42 @@ }), session_(session), weak_ptr_factory_(this) {} - void set_stream(webtransport::Stream* absl_nonnull stream) override; + + void OnStreamBound() override; + absl::Status OnRawControlMessage( + const MoqtRawControlMessage& message) override; // MoqtControlParserVisitor implementation. - void OnClientSetupMessage(const MoqtClientSetup& message) override; - void OnServerSetupMessage(const MoqtServerSetup& message) override; - void OnRequestOkMessage(const MoqtRequestOk& message) override; - void OnRequestErrorMessage(const MoqtRequestError& message) override; - void OnSubscribeMessage(const MoqtSubscribe& message) override; - void OnSubscribeOkMessage(const MoqtSubscribeOk& message) override; - void OnUnsubscribeMessage(const MoqtUnsubscribe& message) override; - void OnPublishDoneMessage(const MoqtPublishDone& /*message*/) override; - void OnRequestUpdateMessage(const MoqtRequestUpdate& message) override; - void OnPublishNamespaceMessage( - const MoqtPublishNamespace& message) override; - void OnPublishNamespaceDoneMessage( - const MoqtPublishNamespaceDone& /*message*/) override; - void OnPublishNamespaceCancelMessage( - const MoqtPublishNamespaceCancel& message) override; - void OnTrackStatusMessage(const MoqtTrackStatus& message) override; - void OnGoAwayMessage(const MoqtGoAway& /*message*/) override; - void OnMaxRequestIdMessage(const MoqtMaxRequestId& message) override; - void OnFetchMessage(const MoqtFetch& message) override; - void OnFetchCancelMessage(const MoqtFetchCancel& /*message*/) override {} - void OnFetchOkMessage(const MoqtFetchOk& message) override; - void OnRequestsBlockedMessage(const MoqtRequestsBlocked& message) override; - void OnPublishMessage(const MoqtPublish& message) override; - void OnObjectAckMessage(const MoqtObjectAck& message) override { + absl::Status OnControlMessage(const MoqtClientSetup& message); + absl::Status OnControlMessage(const MoqtServerSetup& message); + absl::Status OnControlMessage(const MoqtRequestOk& message); + absl::Status OnControlMessage(const MoqtRequestError& message); + absl::Status OnControlMessage(const MoqtSubscribe& message); + absl::Status OnControlMessage(const MoqtSubscribeOk& message); + absl::Status OnControlMessage(const MoqtUnsubscribe& message); + absl::Status OnControlMessage(const MoqtPublishDone& /*message*/); + absl::Status OnControlMessage(const MoqtRequestUpdate& message); + absl::Status OnControlMessage(const MoqtPublishNamespace& message); + absl::Status OnControlMessage(const MoqtPublishNamespaceDone& /*message*/); + absl::Status OnControlMessage(const MoqtPublishNamespaceCancel& message); + absl::Status OnControlMessage(const MoqtTrackStatus& message); + absl::Status OnControlMessage(const MoqtGoAway& /*message*/); + absl::Status OnControlMessage(const MoqtMaxRequestId& message); + absl::Status OnControlMessage(const MoqtFetch& message); + absl::Status OnControlMessage(const MoqtFetchCancel& /*message*/) { + return absl::OkStatus(); + } + absl::Status OnControlMessage(const MoqtFetchOk& message); + absl::Status OnControlMessage(const MoqtRequestsBlocked& message); + absl::Status OnControlMessage(const MoqtPublish& message); + absl::Status OnControlMessage(const MoqtObjectAck& message) { auto subscription_it = session_->published_subscriptions_.find(message.subscribe_id); if (subscription_it == session_->published_subscriptions_.end()) { - return; + return absl::OkStatus(); } subscription_it->second->ProcessObjectAck(message); + return absl::OkStatus(); } // webtransport::StreamVisitor overrides @@ -793,6 +797,11 @@ // underlying WebTransport session to be destroyed. void CleanUpState(); + MoqtControlMessageParser ControlMessageParser() const { + return MoqtControlMessageParser(parameters_.version, + parameters_.using_webtrans); + } + bool is_closing_ = false; webtransport::Session* session_; MoqtSessionParameters parameters_;
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index e5336cc..8888581 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -9,7 +9,6 @@ #include <cstring> #include <memory> #include <optional> -#include <queue> #include <string> #include <utility> #include <variant> @@ -185,7 +184,7 @@ // supports. MoqtObjectListener* ReceiveSubscribeSynchronousOk( MockTrackPublisher* publisher, MoqtSubscribe& subscribe, - MoqtControlParserVisitor* control_parser, uint64_t track_alias = 0, + MoqtBidiStreamTestWrapper* control_parser, uint64_t track_alias = 0, TrackExtensions extensions = TrackExtensions()) { MoqtObjectListener* listener_ptr = nullptr; EXPECT_CALL(*publisher, AddObjectListener) @@ -203,7 +202,7 @@ extensions, }; EXPECT_CALL(mock_stream_, Writev(SerializedControlMessage(expected_ok), _)); - control_parser->OnSubscribeMessage(subscribe); + control_parser->ReceiveMessage(subscribe); return listener_ptr; } @@ -304,13 +303,13 @@ session_.OnSessionReady(); // Receive SERVER_SETUP - MoqtControlParserVisitor* stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::FetchParserVisitorFromWebtransportStreamVisitor( - &session_, visitor.get()); + std::move(visitor)); // Handle the server setup MoqtServerSetup setup; // No fields are set. EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1); - stream_input->OnServerSetupMessage(setup); + stream_input->ReceiveMessage(setup); } TEST_F(MoqtSessionTest, OnSessionReadyNoControlStream) { @@ -420,12 +419,12 @@ TEST_F(MoqtSessionTest, AddLocalTrack) { MoqtSubscribe request = DefaultSubscribe(); - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); // Request for track returns REQUEST_ERROR. EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); - stream_input->OnSubscribeMessage(request); + stream_input->ReceiveMessage(request); // Add the track. Now Subscribe should succeed. MockTrackPublisher* track = CreateTrackPublisher(); @@ -442,18 +441,18 @@ .parameters = MessageParameters(), }; publish.parameters.largest_object = Location(4, 5); - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); // Request for track returns REQUEST_ERROR. EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); - stream_input->OnPublishMessage(publish); + stream_input->ReceiveMessage(publish); } TEST_F(MoqtSessionTest, PublishNamespaceWithOkAndCancel) { testing::MockFunction<void(std::optional<MoqtRequestErrorInfo> error_message)> publish_namespace_response_callback; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL( mock_stream_, @@ -469,14 +468,14 @@ .WillOnce([&](std::optional<MoqtRequestErrorInfo> error) { EXPECT_FALSE(error.has_value()); }); - stream_input->OnRequestOkMessage(ok); + stream_input->ReceiveMessage(ok); MoqtPublishNamespaceCancel cancel = { /*request_id=*/0, RequestErrorCode::kInternalError, /*error_reason=*/"Test error", }; - stream_input->OnPublishNamespaceCancelMessage(cancel); + stream_input->ReceiveMessage(cancel); EXPECT_EQ(cancel_error_info.error_code, RequestErrorCode::kInternalError); EXPECT_EQ(cancel_error_info.reason_phrase, "Test error"); // State is gone. @@ -486,7 +485,7 @@ TEST_F(MoqtSessionTest, PublishNamespaceWithOkAndPublishNamespaceDone) { testing::MockFunction<void(std::optional<MoqtRequestErrorInfo> error_message)> publish_namespace_resolved_callback; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL( mock_stream_, @@ -500,7 +499,7 @@ .WillOnce([&](std::optional<MoqtRequestErrorInfo> error) { EXPECT_FALSE(error.has_value()); }); - stream_input->OnRequestOkMessage(ok); + stream_input->ReceiveMessage(ok); EXPECT_CALL( mock_stream_, @@ -513,7 +512,7 @@ TEST_F(MoqtSessionTest, PublishNamespaceWithError) { testing::MockFunction<void(std::optional<MoqtRequestErrorInfo> error_message)> publish_namespace_resolved_callback; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL( mock_stream_, @@ -530,13 +529,13 @@ EXPECT_EQ(error->error_code, RequestErrorCode::kInternalError); EXPECT_EQ(error->reason_phrase, "Test error"); }); - stream_input->OnRequestErrorMessage(error); + stream_input->ReceiveMessage(error); // State is gone. EXPECT_FALSE(session_.PublishNamespaceDone(TrackNamespace{"foo"})); } TEST_F(MoqtSessionTest, AsynchronousSubscribeReturnsOk) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MoqtSubscribe request = DefaultSubscribe(); MockTrackPublisher* track = CreateTrackPublisher(); @@ -544,7 +543,7 @@ EXPECT_CALL(*track, AddObjectListener) .WillOnce( [&](MoqtObjectListener* listener_ptr) { listener = listener_ptr; }); - stream_input->OnSubscribeMessage(request); + stream_input->ReceiveMessage(request); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kSubscribeOk), _)); @@ -554,7 +553,7 @@ } TEST_F(MoqtSessionTest, AsynchronousSubscribeReturnsError) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MoqtSubscribe request = DefaultSubscribe(); MockTrackPublisher* track = CreateTrackPublisher(); @@ -562,7 +561,7 @@ EXPECT_CALL(*track, AddObjectListener) .WillOnce( [&](MoqtObjectListener* listener_ptr) { listener = listener_ptr; }); - stream_input->OnSubscribeMessage(request); + stream_input->ReceiveMessage(request); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); listener->OnSubscribeRejected(MoqtRequestErrorInfo( @@ -572,7 +571,7 @@ } TEST_F(MoqtSessionTest, SynchronousSubscribeReturnsError) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MoqtSubscribe request = DefaultSubscribe(); MockTrackPublisher* track = CreateTrackPublisher(); @@ -585,13 +584,13 @@ listener->OnSubscribeRejected(MoqtRequestErrorInfo( RequestErrorCode::kInternalError, std::nullopt, "Test error")); }); - stream_input->OnSubscribeMessage(request); + stream_input->ReceiveMessage(request); EXPECT_EQ(MoqtSessionPeer::GetSubscription(&session_, kDefaultPeerRequestId), nullptr); } TEST_F(MoqtSessionTest, SubscribeForPast) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MockTrackPublisher* track = CreateTrackPublisher(); SetLargestId(track, Location(10, 20)); @@ -600,7 +599,7 @@ } TEST_F(MoqtSessionTest, SubscribeDoNotForward) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MockTrackPublisher* track = CreateTrackPublisher(); MoqtSubscribe request = DefaultSubscribe(); @@ -616,7 +615,7 @@ } TEST_F(MoqtSessionTest, SubscribeAbsoluteStartNoDataYet) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MockTrackPublisher* track = CreateTrackPublisher(); MoqtSubscribe request = DefaultSubscribe(); @@ -630,7 +629,7 @@ } TEST_F(MoqtSessionTest, SubscribeNextGroup) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MockTrackPublisher* track = CreateTrackPublisher(); MoqtSubscribe request = DefaultSubscribe(); @@ -651,7 +650,7 @@ } TEST_F(MoqtSessionTest, TwoSubscribesForTrack) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MockTrackPublisher* track = CreateTrackPublisher(); MoqtSubscribe request = DefaultSubscribe(); @@ -661,11 +660,11 @@ request.parameters.subscription_filter.emplace(Location(12, 0)); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); - stream_input->OnSubscribeMessage(request); + stream_input->ReceiveMessage(request); } TEST_F(MoqtSessionTest, UnsubscribeAllowsSecondSubscribe) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MockTrackPublisher* track = CreateTrackPublisher(); MoqtSubscribe request = DefaultSubscribe(); @@ -675,7 +674,7 @@ MoqtUnsubscribe unsubscribe = { kDefaultPeerRequestId, }; - stream_input->OnUnsubscribeMessage(unsubscribe); + stream_input->ReceiveMessage(unsubscribe); EXPECT_EQ(MoqtSessionPeer::GetSubscription(&session_, 1), nullptr); // Subscribe again, succeeds. @@ -690,12 +689,12 @@ MoqtSubscribe request = DefaultSubscribe(); request.request_id = kDefaultInitialMaxRequestId + 1; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL(mock_session_, CloseSession(static_cast<uint64_t>(MoqtError::kTooManyRequests), "Received request with too large ID")); - stream_input->OnSubscribeMessage(request); + stream_input->ReceiveMessage(request); } TEST_F(MoqtSessionTest, RequestIdWrongLsb) { @@ -704,24 +703,24 @@ TEST_F(MoqtSessionTest, SubscribeIdNotIncreasing) { MoqtSubscribe request = DefaultSubscribe(); - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MockTrackPublisher* track = CreateTrackPublisher(); EXPECT_CALL(*track, AddObjectListener); - stream_input->OnSubscribeMessage(request); + stream_input->ReceiveMessage(request); // Second request is a protocol violation. request.full_track_name = FullTrackName({"dead", "beef"}); EXPECT_CALL(mock_session_, CloseSession(static_cast<uint64_t>(MoqtError::kInvalidRequestId), "Duplicate request ID")); - stream_input->OnSubscribeMessage(request); + stream_input->ReceiveMessage(request); } TEST_F(MoqtSessionTest, TooManySubscribes) { MoqtSessionPeer::set_next_request_id(&session_, kDefaultInitialMaxRequestId - 1); - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL(mock_session_, GetStreamById(_)) .WillRepeatedly(Return(&mock_stream_)); @@ -743,7 +742,7 @@ } TEST_F(MoqtSessionTest, SubscribeDuplicateTrackName) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL(mock_session_, GetStreamById(_)) .WillRepeatedly(Return(&mock_stream_)); @@ -757,7 +756,7 @@ } TEST_F(MoqtSessionTest, SubscribeWithOk) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _)); @@ -778,11 +777,11 @@ EXPECT_EQ(ftn, FullTrackName("foo", "bar")); EXPECT_TRUE(std::holds_alternative<SubscribeOkData>(response)); }); - stream_input->OnSubscribeOkMessage(ok); + stream_input->ReceiveMessage(ok); } TEST_F(MoqtSessionTest, SubscribeNextGroupWithOk) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MoqtSubscribe subscribe = DefaultLocalSubscribe(); subscribe.parameters.subscription_filter.emplace( @@ -804,11 +803,11 @@ EXPECT_EQ(ftn, FullTrackName("foo", "bar")); EXPECT_TRUE(std::holds_alternative<SubscribeOkData>(response)); }); - stream_input->OnSubscribeOkMessage(ok); + stream_input->ReceiveMessage(ok); } TEST_F(MoqtSessionTest, OutgoingSubscribeUpdate) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL(mock_session_, GetStreamById) .WillRepeatedly(Return(&mock_stream_)); @@ -825,7 +824,7 @@ TrackExtensions(), }; EXPECT_CALL(remote_track_visitor_, OnReply); - stream_input->OnSubscribeOkMessage(ok); + stream_input->ReceiveMessage(ok); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestUpdate), _)); MessageParameters update_parameters; @@ -836,7 +835,7 @@ EXPECT_TRUE(session_.SubscribeUpdate( FullTrackName("foo", "bar"), update_parameters, [&](std::optional<MoqtRequestErrorInfo> info) { response = info; })); - stream_input->OnRequestOkMessage(MoqtRequestOk{ + stream_input->ReceiveMessage(MoqtRequestOk{ /*request_id=*/2, MessageParameters(), }); @@ -857,7 +856,7 @@ TEST_F(MoqtSessionTest, MaxRequestIdChangesResponse) { MoqtSessionPeer::set_next_request_id(&session_, kDefaultInitialMaxRequestId); - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL(mock_session_, GetStreamById(_)) .WillRepeatedly(Return(&mock_stream_)); @@ -871,7 +870,7 @@ MoqtMaxRequestId max_request_id = { /*max_request_id=*/kDefaultInitialMaxRequestId + 1, }; - stream_input->OnMaxRequestIdMessage(max_request_id); + stream_input->ReceiveMessage(max_request_id); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _)); @@ -883,17 +882,17 @@ MoqtMaxRequestId max_request_id = { /*max_request_id=*/kDefaultInitialMaxRequestId - 1, }; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL(mock_session_, CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation), "MAX_REQUEST_ID has lower value than previous")) .Times(1); - stream_input->OnMaxRequestIdMessage(max_request_id); + stream_input->ReceiveMessage(max_request_id); } TEST_F(MoqtSessionTest, GrantMoreRequests) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kMaxRequestId), _)); @@ -906,7 +905,7 @@ } TEST_F(MoqtSessionTest, SubscribeWithError) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _)); @@ -931,11 +930,11 @@ std::get<MoqtRequestErrorInfo>(response).reason_phrase == "deadbeef"); }); - stream_input->OnRequestErrorMessage(error); + stream_input->ReceiveMessage(error); } TEST_F(MoqtSessionTest, Unsubscribe) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), /*track_alias=*/2, &remote_track_visitor_); @@ -949,7 +948,7 @@ TEST_F(MoqtSessionTest, ReplyToPublishNamespaceWithOkThenPublishNamespaceDone) { TrackNamespace track_namespace{"foo"}; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MessageParameters parameters; parameters.authorization_tokens.emplace_back(AuthTokenType::kOutOfBand, @@ -970,7 +969,7 @@ Writev(SerializedControlMessage(MoqtRequestOk{ kDefaultPeerRequestId, MessageParameters()}), _)); - stream_input->OnPublishNamespaceMessage(publish_namespace); + stream_input->ReceiveMessage(publish_namespace); MoqtPublishNamespaceDone publish_namespace_done = { /*request_id=*/0, }; @@ -979,14 +978,14 @@ .WillOnce( [](const TrackNamespace&, const std::optional<MessageParameters>&, MoqtResponseCallback callback) { EXPECT_EQ(callback, nullptr); }); - stream_input->OnPublishNamespaceDoneMessage(publish_namespace_done); + stream_input->ReceiveMessage(publish_namespace_done); } TEST_F(MoqtSessionTest, ReplyToPublishNamespaceWithOkThenPublishNamespaceCancel) { TrackNamespace track_namespace{"foo"}; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MessageParameters parameters; parameters.authorization_tokens.emplace_back(AuthTokenType::kOutOfBand, @@ -1007,7 +1006,7 @@ Writev(SerializedControlMessage(MoqtRequestOk{ kDefaultPeerRequestId, MessageParameters()}), _)); - stream_input->OnPublishNamespaceMessage(publish_namespace); + stream_input->ReceiveMessage(publish_namespace); EXPECT_CALL(mock_stream_, Writev(SerializedControlMessage(MoqtPublishNamespaceCancel{ kDefaultPeerRequestId, @@ -1020,7 +1019,7 @@ TEST_F(MoqtSessionTest, ReplyToPublishNamespaceWithError) { TrackNamespace track_namespace{"foo"}; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MessageParameters parameters; parameters.authorization_tokens.emplace_back(AuthTokenType::kOutOfBand, @@ -1045,7 +1044,7 @@ kDefaultPeerRequestId, error.error_code, error.retry_interval, error.reason_phrase}), _)); - stream_input->OnPublishNamespaceMessage(publish_namespace); + stream_input->ReceiveMessage(publish_namespace); } TEST_F(MoqtSessionTest, SubscribeNamespaceLifeCycle) { @@ -1073,7 +1072,7 @@ EXPECT_FALSE(error.has_value()); }); MoqtRequestOk ok = {kDefaultLocalRequestId, MessageParameters()}; - stream_input->OnRequestOkMessage(ok); + QUICHE_ASSERT_OK(stream_input->OnControlMessage(ok)); EXPECT_TRUE(got_callback); EXPECT_CALL(mock_stream_, ResetWithUserCode); } @@ -1107,7 +1106,7 @@ MoqtRequestError error = {kDefaultLocalRequestId, RequestErrorCode::kInvalidRange, std::nullopt, "deadbeef"}; - stream_input->OnRequestErrorMessage(error); + QUICHE_ASSERT_OK(stream_input->OnControlMessage(error)); EXPECT_TRUE(got_callback); } @@ -1269,10 +1268,10 @@ TrackExtensions(), }; webtransport::test::MockStream mock_control_stream; - std::unique_ptr<MoqtControlParserVisitor> control_stream = + std::unique_ptr<MoqtBidiStreamTestWrapper> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream); EXPECT_CALL(remote_track_visitor_, OnReply).Times(1); - control_stream->OnSubscribeOkMessage(ok); + control_stream->ReceiveMessage(ok); } TEST_F(MoqtSessionTest, SubscribeOkWithBadTrackAlias) { @@ -1294,12 +1293,12 @@ TrackExtensions(), }; webtransport::test::MockStream mock_control_stream; - std::unique_ptr<MoqtControlParserVisitor> control_stream = + std::unique_ptr<MoqtBidiStreamTestWrapper> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream); EXPECT_CALL( mock_session_, CloseSession(static_cast<uint64_t>(MoqtError::kDuplicateTrackAlias), "")); - control_stream->OnSubscribeOkMessage(subscribe_ok); + control_stream->ReceiveMessage(subscribe_ok); } TEST_F(MoqtSessionTest, CreateOutgoingDataStreamAndSend) { @@ -1545,7 +1544,7 @@ }; EXPECT_CALL(mock_stream_, ResetWithUserCode(kResetCodeCancelled)); webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); EXPECT_CALL(control_stream, Writev(SerializedControlMessage(expected_publish_done), _)); @@ -1613,7 +1612,7 @@ }; EXPECT_CALL(mock_stream_, ResetWithUserCode(kResetCodeCancelled)); webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); EXPECT_CALL(control_stream, Writev(SerializedControlMessage(expected_publish_done), _)); @@ -2038,12 +2037,12 @@ auto track = SetupPublisher(ftn, MoqtForwardingPreference::kSubgroup, Location(4, 2)); MoqtSessionPeer::AddSubscription(&session_, track, 0, 1, 3, 4); - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MoqtUnsubscribe unsubscribe = { /*request_id=*/0, }; - stream_input->OnUnsubscribeMessage(unsubscribe); + stream_input->ReceiveMessage(unsubscribe); EXPECT_EQ(MoqtSessionPeer::GetSubscription(&session_, 0), nullptr); } @@ -2119,7 +2118,7 @@ TEST_F(MoqtSessionTest, UsePeerDefaultPriority) { FullTrackName ftn("foo", "bar"); const MoqtPriority kPeerDefaultPriority = 0x20; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kSubscribe), _)); @@ -2131,7 +2130,7 @@ TrackExtensions(std::nullopt, std::nullopt, kPeerDefaultPriority, std::nullopt, std::nullopt, std::nullopt); EXPECT_CALL(remote_track_visitor_, OnReply); - stream_input->OnSubscribeOkMessage(ok); + stream_input->ReceiveMessage(ok); // Omit priority from a datagram. char datagram[] = {0x0c, 0x02, 0x05, 0x64, 0x65, 0x61, 0x64, 0x62, 0x65, 0x65, 0x66}; @@ -2166,7 +2165,7 @@ TEST_F(MoqtSessionTest, OmitPublisherPriority) { MoqtSubscribe request = DefaultSubscribe(); const MoqtPriority kLocalDefaultPriority = 0x20; - std::unique_ptr<MoqtControlParserVisitor> control_stream = + std::unique_ptr<MoqtBidiStreamTestWrapper> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); // Create the publisher and the SUBSCRIBE with kLocalDefaultPriority. MockTrackPublisher* track = CreateTrackPublisher(); @@ -2580,7 +2579,7 @@ // All callbacks are called asynchronously. TEST_F(MoqtSessionTest, ProcessFetchGetEverythingFromUpstream) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MoqtFetch fetch = DefaultFetch(); MockTrackPublisher* track = CreateTrackPublisher(); @@ -2590,7 +2589,7 @@ MockFetchTask* fetch_task = fetch_task_ptr.get(); EXPECT_CALL(*track, StandaloneFetch) .WillOnce(Return(std::move(fetch_task_ptr))); - stream_input->OnFetchMessage(fetch); + stream_input->ReceiveMessage(fetch); // Compose and send the FETCH_OK. MoqtFetchOk expected_ok; @@ -2612,7 +2611,7 @@ // All callbacks are called synchronously. All relevant data is cached (or this // is the original publisher). TEST_F(MoqtSessionTest, ProcessFetchWholeRangeIsPresent) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MoqtFetch fetch = DefaultFetch(); MockTrackPublisher* track = CreateTrackPublisher(); @@ -2635,12 +2634,12 @@ MoqtFetchTask::GetNextObjectResult::kPending); // Everything spins upon message receipt. FetchTask is generating the // necessary callbacks. - stream_input->OnFetchMessage(fetch); + stream_input->ReceiveMessage(fetch); } TEST_F(MoqtSessionTest, SendFragmentedFetchObject) { using ::testing::ByMove; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MoqtFetch fetch = DefaultFetch(); fetch.request_id = 3; // Use an odd ID for peer request in client session. @@ -2654,7 +2653,7 @@ .WillOnce(Return(ByMove(std::move(fetch_task_ptr)))); // Receive FETCH, send FETCH_OK. - stream_input->OnFetchMessage(fetch); + stream_input->ReceiveMessage(fetch); // FETCH_OK responding to the request. MoqtFetchOk expected_ok; expected_ok.request_id = fetch.request_id; @@ -2717,7 +2716,7 @@ // The publisher has the first object locally, but has to go upstream to get // the rest. TEST_F(MoqtSessionTest, FetchReturnsObjectBeforeOk) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MoqtFetch fetch = DefaultFetch(); MockTrackPublisher* track = CreateTrackPublisher(); @@ -2734,7 +2733,7 @@ ExpectSendObject(fetch_task, data_stream, MoqtObjectStatus::kNormal, Location(0, 0), "foo", MoqtFetchTask::GetNextObjectResult::kPending); - stream_input->OnFetchMessage(fetch); + stream_input->ReceiveMessage(fetch); MoqtFetchOk expected_ok; expected_ok.request_id = fetch.request_id; @@ -2745,7 +2744,7 @@ } TEST_F(MoqtSessionTest, FetchReturnsObjectBeforeError) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MoqtFetch fetch = DefaultFetch(); MockTrackPublisher* track = CreateTrackPublisher(); @@ -2761,7 +2760,7 @@ ExpectSendObject(fetch_task, data_stream, MoqtObjectStatus::kNormal, Location(0, 0), "foo", MoqtFetchTask::GetNextObjectResult::kPending); - stream_input->OnFetchMessage(fetch); + stream_input->ReceiveMessage(fetch); MoqtRequestError expected_error{ fetch.request_id, RequestErrorCode::kDoesNotExist, std::nullopt, "foo"}; @@ -2772,22 +2771,22 @@ TEST_F(MoqtSessionTest, InvalidFetch) { webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); MockTrackPublisher* track = CreateTrackPublisher(); MoqtFetch fetch = DefaultFetch(); EXPECT_CALL(*track, StandaloneFetch) .WillOnce(Return(std::make_unique<MockFetchTask>())); - stream_input->OnFetchMessage(fetch); + stream_input->ReceiveMessage(fetch); EXPECT_CALL(mock_session_, CloseSession(static_cast<uint64_t>(MoqtError::kInvalidRequestId), "Duplicate request ID")) .Times(1); - stream_input->OnFetchMessage(fetch); + stream_input->ReceiveMessage(fetch); } TEST_F(MoqtSessionTest, FetchFails) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MoqtFetch fetch = DefaultFetch(); MockTrackPublisher* track = CreateTrackPublisher(); @@ -2800,11 +2799,11 @@ .WillRepeatedly(Return(absl::Status(absl::StatusCode::kInternal, "foo"))); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); - stream_input->OnFetchMessage(fetch); + stream_input->ReceiveMessage(fetch); } TEST_F(MoqtSessionTest, FullFetchDeliveryWithFlowControl) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MoqtFetch fetch = DefaultFetch(); MockTrackPublisher* track = CreateTrackPublisher(); @@ -2815,7 +2814,7 @@ EXPECT_CALL(*track, StandaloneFetch) .WillOnce(Return(std::move(fetch_task_ptr))); - stream_input->OnFetchMessage(fetch); + stream_input->ReceiveMessage(fetch); EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream()) .WillOnce(Return(false)); fetch_task->CallObjectsAvailableCallback(); @@ -2838,7 +2837,7 @@ // Give it the latest object filter. subscribe.parameters.subscription_filter.emplace( MoqtFilterType::kLargestObject); - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MockTrackPublisher* track = CreateTrackPublisher(); SetLargestId(track, Location(4, 10)); @@ -2857,7 +2856,7 @@ fetch.fetch = JoiningFetchRelative(1, 2); EXPECT_CALL(*track, StandaloneFetch(Location(2, 0), Location(4, 10), _)) .WillOnce(Return(std::make_unique<MockFetchTask>())); - stream_input->OnFetchMessage(fetch); + stream_input->ReceiveMessage(fetch); } TEST_F(MoqtSessionTest, IncomingAbsoluteJoiningFetch) { @@ -2865,7 +2864,7 @@ // Give it the latest object filter. subscribe.parameters.subscription_filter.emplace( MoqtFilterType::kLargestObject); - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MockTrackPublisher* track = CreateTrackPublisher(); SetLargestId(track, Location(4, 10)); @@ -2884,11 +2883,11 @@ fetch.fetch = JoiningFetchAbsolute(1, 2); EXPECT_CALL(*track, StandaloneFetch(Location(2, 0), Location(4, 10), _)) .WillOnce(Return(std::make_unique<MockFetchTask>())); - stream_input->OnFetchMessage(fetch); + stream_input->ReceiveMessage(fetch); } TEST_F(MoqtSessionTest, IncomingJoiningFetchBadRequestId) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MoqtFetch fetch = DefaultFetch(); fetch.fetch = JoiningFetchRelative(1, 2); @@ -2900,13 +2899,13 @@ }; EXPECT_CALL(mock_stream_, Writev(SerializedControlMessage(expected_error), _)); - stream_input->OnFetchMessage(fetch); + stream_input->ReceiveMessage(fetch); } TEST_F(MoqtSessionTest, IncomingJoiningFetchForwardZero) { MoqtSubscribe subscribe = DefaultSubscribe(); subscribe.parameters.set_forward(false); - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MockTrackPublisher* track = CreateTrackPublisher(); SetLargestId(track, Location(2, 10)); @@ -2919,11 +2918,11 @@ CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation), "Joining Fetch for non-forwarding subscribe")) .Times(1); - stream_input->OnFetchMessage(fetch); + stream_input->ReceiveMessage(fetch); } TEST_F(MoqtSessionTest, SendJoiningFetch) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL(mock_session_, GetStreamById(_)) .WillRepeatedly(Return(&mock_stream_)); @@ -2945,7 +2944,7 @@ } TEST_F(MoqtSessionTest, SendJoiningFetchNoFlowControl) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL(mock_session_, GetStreamById(_)) .WillRepeatedly(Return(&mock_stream_)); @@ -2960,9 +2959,9 @@ EXPECT_CALL(remote_track_visitor_, OnReply).Times(1); MessageParameters parameters; parameters.largest_object = Location(2, 0); - stream_input->OnSubscribeOkMessage( + stream_input->ReceiveMessage( MoqtSubscribeOk(0, 2, parameters, TrackExtensions())); - stream_input->OnFetchOkMessage(MoqtFetchOk( + stream_input->ReceiveMessage(MoqtFetchOk( 2, false, Location(2, 0), MessageParameters(), TrackExtensions())); // Packet arrives on FETCH stream. MoqtObject object = { @@ -3116,7 +3115,7 @@ } TEST_F(MoqtSessionTest, FetchThenOkThenCancel) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); std::unique_ptr<MoqtFetchTask> fetch_task; session_.Fetch( @@ -3130,7 +3129,7 @@ /*end_of_track=*/false, Location(3, 25), MessageParameters(), TrackExtensions(), }; - stream_input->OnFetchOkMessage(ok); + stream_input->ReceiveMessage(ok); ASSERT_NE(fetch_task, nullptr); EXPECT_TRUE(fetch_task->GetStatus().ok()); PublishedObject object; @@ -3143,7 +3142,7 @@ } TEST_F(MoqtSessionTest, FetchThenError) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); std::unique_ptr<MoqtFetchTask> fetch_task; session_.Fetch( @@ -3158,7 +3157,7 @@ /*retry_interval=*/std::nullopt, "No username provided", }; - stream_input->OnRequestErrorMessage(error); + stream_input->ReceiveMessage(error); ASSERT_NE(fetch_task, nullptr); EXPECT_TRUE(absl::IsPermissionDenied(fetch_task->GetStatus())); EXPECT_EQ(fetch_task->GetStatus().message(), "No username provided"); @@ -3166,7 +3165,7 @@ // The application takes objects as they arrive. TEST_F(MoqtSessionTest, IncomingFetchObjectsGreedyApp) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); std::unique_ptr<MoqtFetchTask> fetch_task; uint64_t expected_object_id = 0; @@ -3235,7 +3234,7 @@ MessageParameters(), TrackExtensions(), }; - stream_input->OnFetchOkMessage(ok); + stream_input->ReceiveMessage(ok); ASSERT_NE(fetch_task, nullptr); EXPECT_EQ(expected_object_id, 2); @@ -3250,7 +3249,7 @@ } TEST_F(MoqtSessionTest, IncomingFetchObjectsSlowApp) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); std::unique_ptr<MoqtFetchTask> fetch_task; uint64_t expected_object_id = 0; @@ -3306,7 +3305,7 @@ /*end_of_track=*/false, Location(3, 25), MessageParameters(), TrackExtensions(), }; - stream_input->OnFetchOkMessage(ok); + stream_input->ReceiveMessage(ok); ASSERT_NE(fetch_task, nullptr); EXPECT_TRUE(objects_available); @@ -3391,7 +3390,7 @@ TEST_F(MoqtSessionTest, DeliveryTimeoutParameter) { MoqtSubscribe request = DefaultSubscribe(); request.parameters.delivery_timeout = quic::QuicTimeDelta::FromSeconds(1); - std::unique_ptr<MoqtControlParserVisitor> control_stream = + std::unique_ptr<MoqtBidiStreamTestWrapper> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MockTrackPublisher* track = CreateTrackPublisher(); ReceiveSubscribeSynchronousOk(track, request, control_stream.get()); @@ -3649,10 +3648,10 @@ } TEST_F(MoqtSessionTest, ReceiveGoAwayEnforcement) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL(session_callbacks_.goaway_received_callback, Call("foo")); - stream_input->OnGoAwayMessage(MoqtGoAway("foo")); + stream_input->ReceiveMessage(MoqtGoAway("foo")); // New requests not allowed. EXPECT_CALL(mock_stream_, Writev).Times(0); MessageParameters parameters = SubscribeForTest(); @@ -3684,11 +3683,11 @@ reported_error = true; EXPECT_EQ(error_message, "Received multiple GOAWAY messages"); }); - stream_input->OnGoAwayMessage(MoqtGoAway("foo")); + stream_input->ReceiveMessage(MoqtGoAway("foo")); } TEST_F(MoqtSessionTest, SendGoAwayEnforcement) { - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); CreateTrackPublisher(); EXPECT_CALL(mock_stream_, @@ -3696,31 +3695,34 @@ session_.GoAway(""); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); - stream_input->OnSubscribeMessage(DefaultSubscribe()); + stream_input->ReceiveMessage(DefaultSubscribe()); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); - stream_input->OnPublishNamespaceMessage( + stream_input->ReceiveMessage( MoqtPublishNamespace(3, TrackNamespace({"foo"}), MessageParameters())); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); MoqtFetch fetch = DefaultFetch(); fetch.request_id = 5; - stream_input->OnFetchMessage(fetch); + stream_input->ReceiveMessage(fetch); MoqtFramer framer(true); SessionNamespaceTree tree; MoqtIncomingSubscribeNamespaceCallback callback = DefaultIncomingSubscribeNamespaceCallback; - MoqtNamespacePublisherStream namespace_stream(&framer, &mock_stream_, nullptr, - &tree, callback); + MoqtNamespacePublisherStream namespace_stream( + &framer, MoqtControlMessageParser(kDefaultMoqtVersion, true), nullptr, + &tree, callback); + namespace_stream.BindStream(&mock_stream_); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); - namespace_stream.OnSubscribeNamespaceMessage(MoqtSubscribeNamespace(7)); + QUICHE_ASSERT_OK( + namespace_stream.OnControlMessage(MoqtSubscribeNamespace(7))); MoqtTrackStatus track_status = DefaultSubscribe(); track_status.request_id = 7; EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); - stream_input->OnTrackStatusMessage(track_status); + stream_input->ReceiveMessage(track_status); // Block all outgoing SUBSCRIBE, PUBLISH_NAMESPACE, GOAWAY,etc. EXPECT_CALL(mock_stream_, Writev).Times(0); MessageParameters parameters = SubscribeForTest(); @@ -3755,7 +3757,7 @@ TEST_F(MoqtSessionTest, ClientCannotSendNewSessionUri) { // session_ is a client session. - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); // Client GOAWAY not sent. EXPECT_CALL(mock_stream_, Writev).Times(0); @@ -3768,7 +3770,7 @@ MoqtSessionParameters(quic::Perspective::IS_SERVER), std::make_unique<quic::test::TestAlarmFactory>(), session_callbacks_.AsSessionCallbacks()); - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session, &mock_stream_); EXPECT_CALL( mock_session, @@ -3782,13 +3784,13 @@ EXPECT_EQ(error_message, "Received GOAWAY with new_session_uri on the server"); }); - stream_input->OnGoAwayMessage(MoqtGoAway("foo")); + stream_input->ReceiveMessage(MoqtGoAway("foo")); EXPECT_TRUE(reported_error); } TEST_F(MoqtSessionTest, ReceivePublishDoneWithOpenStreams) { webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); EXPECT_CALL(mock_session_, GetStreamById(_)) .WillRepeatedly(Return(&control_stream)); @@ -3805,7 +3807,7 @@ parameters, TrackExtensions(), }; - stream_input->OnSubscribeOkMessage(ok); + stream_input->ReceiveMessage(ok); constexpr uint64_t kNumStreams = 3; webtransport::test::MockStream data[kNumStreams]; std::unique_ptr<webtransport::StreamVisitor> data_streams[kNumStreams]; @@ -3832,7 +3834,7 @@ SubscribeRemoteTrack* track = MoqtSessionPeer::remote_track(&session_, 0); ASSERT_NE(track, nullptr); EXPECT_FALSE(track->all_streams_closed()); - stream_input->OnPublishDoneMessage( + stream_input->ReceiveMessage( MoqtPublishDone(0, PublishDoneCode::kTrackEnded, kNumStreams, "foo")); track = MoqtSessionPeer::remote_track(&session_, 0); ASSERT_NE(track, nullptr); @@ -3846,7 +3848,7 @@ TEST_F(MoqtSessionTest, ReceivePublishDoneWithClosedStreams) { webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); EXPECT_CALL(mock_session_, GetStreamById(_)) .WillRepeatedly(Return(&control_stream)); @@ -3863,7 +3865,7 @@ parameters, TrackExtensions(), }; - stream_input->OnSubscribeOkMessage(ok); + stream_input->ReceiveMessage(ok); constexpr uint64_t kNumStreams = 3; webtransport::test::MockStream data[kNumStreams]; std::unique_ptr<webtransport::StreamVisitor> data_streams[kNumStreams]; @@ -3894,14 +3896,14 @@ ASSERT_NE(track, nullptr); EXPECT_FALSE(track->all_streams_closed()); EXPECT_CALL(remote_track_visitor_, OnPublishDone(_)); - stream_input->OnPublishDoneMessage( + stream_input->ReceiveMessage( MoqtPublishDone(0, PublishDoneCode::kTrackEnded, kNumStreams, "foo")); EXPECT_EQ(MoqtSessionPeer::remote_track(&session_, 0), nullptr); } TEST_F(MoqtSessionTest, PublishDoneTimeout) { webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); EXPECT_CALL(mock_session_, GetStreamById(_)) .WillRepeatedly(Return(&control_stream)); @@ -3918,7 +3920,7 @@ parameters, TrackExtensions(), }; - stream_input->OnSubscribeOkMessage(ok); + stream_input->ReceiveMessage(ok); constexpr uint64_t kNumStreams = 3; webtransport::test::MockStream data[kNumStreams]; std::unique_ptr<webtransport::StreamVisitor> data_streams[kNumStreams]; @@ -3949,7 +3951,7 @@ ASSERT_NE(track, nullptr); EXPECT_FALSE(track->all_streams_closed()); // stream_count includes a stream that was never sent. - stream_input->OnPublishDoneMessage( + stream_input->ReceiveMessage( MoqtPublishDone(0, PublishDoneCode::kTrackEnded, kNumStreams + 1, "foo")); EXPECT_FALSE(track->all_streams_closed()); auto* publish_done_alarm = @@ -3965,7 +3967,7 @@ MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), /*track_alias=*/2, &remote_track_visitor_); webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); std::unique_ptr<MoqtDataParserVisitor> object_stream = MoqtSessionPeer::CreateIncomingDataStream( @@ -3997,7 +3999,7 @@ MoqtSessionPeer::CreateRemoteTrack(&session_, DefaultSubscribe(), /*track_alias=*/2, &remote_track_visitor_); webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); std::unique_ptr<MoqtDataParserVisitor> object_stream = MoqtSessionPeer::CreateIncomingDataStream( @@ -4046,7 +4048,7 @@ "bar", true); EXPECT_FALSE(IsInvalidArgument(task->GetStatus())); webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); EXPECT_CALL(control_stream, Writev(ControlMessageOfType(MoqtMessageType::kFetchCancel), _)); @@ -4061,7 +4063,7 @@ TEST_F(MoqtSessionTest, IncomingTrackStatusThenSynchronousOk) { webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); auto* track = CreateTrackPublisher(); @@ -4083,12 +4085,12 @@ EXPECT_CALL(*track, RemoveObjectListener); listener->OnSubscribeAccepted(); }); - stream_input->OnTrackStatusMessage(track_status); + stream_input->ReceiveMessage(track_status); } TEST_F(MoqtSessionTest, IncomingTrackStatusThenAsynchronousOk) { webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); auto* track = CreateTrackPublisher(); @@ -4096,7 +4098,7 @@ MoqtObjectListener* listener = nullptr; EXPECT_CALL(*track, AddObjectListener) .WillOnce(testing::SaveArg<0>(&listener)); - stream_input->OnTrackStatusMessage(track_status); + stream_input->ReceiveMessage(track_status); ASSERT_NE(listener, nullptr); EXPECT_CALL(*track, expiration) .WillRepeatedly(Return(quic::QuicTimeDelta::FromMilliseconds(10000))); @@ -4112,7 +4114,7 @@ TEST_F(MoqtSessionTest, IncomingTrackStatusThenSynchronousError) { webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); auto* track = CreateTrackPublisher(); @@ -4128,13 +4130,13 @@ RequestErrorCode::kInternalError, std::nullopt, "Test error")); executed_AddObjectListener = true; }); - stream_input->OnTrackStatusMessage(track_status); + stream_input->ReceiveMessage(track_status); EXPECT_TRUE(executed_AddObjectListener); } TEST_F(MoqtSessionTest, IncomingTrackStatusThenAsynchronousError) { webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); auto* track = CreateTrackPublisher(); @@ -4142,7 +4144,7 @@ MoqtObjectListener* listener; EXPECT_CALL(*track, AddObjectListener) .WillOnce(testing::SaveArg<0>(&listener)); - stream_input->OnTrackStatusMessage(track_status); + stream_input->ReceiveMessage(track_status); ASSERT_NE(listener, nullptr); EXPECT_CALL(control_stream, Writev(ControlMessageOfType(MoqtMessageType::kRequestError), _)); @@ -4152,7 +4154,7 @@ } TEST_F(MoqtSessionTest, FinReportedToVisitor) { - std::unique_ptr<MoqtControlParserVisitor> control_stream = + std::unique_ptr<MoqtBidiStreamTestWrapper> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &control_stream_); EXPECT_CALL(mock_session_, GetStreamById) .WillRepeatedly(Return(&control_stream_)); @@ -4171,7 +4173,7 @@ EXPECT_EQ(ftn, FullTrackName("foo", "bar")); EXPECT_TRUE(std::holds_alternative<SubscribeOkData>(response)); }); - control_stream->OnSubscribeOkMessage(ok); + control_stream->ReceiveMessage(ok); MoqtObject object = { /*track_alias=*/2, /*group_id=*/0, @@ -4196,7 +4198,7 @@ } TEST_F(MoqtSessionTest, ResetReportedToVisitor) { - std::unique_ptr<MoqtControlParserVisitor> control_stream = + std::unique_ptr<MoqtBidiStreamTestWrapper> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &control_stream_); EXPECT_CALL(mock_session_, GetStreamById) .WillRepeatedly(Return(&control_stream_)); @@ -4215,7 +4217,7 @@ EXPECT_EQ(ftn, FullTrackName("foo", "bar")); EXPECT_TRUE(std::holds_alternative<SubscribeOkData>(response)); }); - control_stream->OnSubscribeOkMessage(ok); + control_stream->ReceiveMessage(ok); MoqtObject object = { /*track_alias=*/2, /*group_id=*/0, @@ -4242,7 +4244,7 @@ TEST_F(MoqtSessionTest, IncomingPublishNamespaceCleanup) { webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); // Register two incoming PUBLISH_NAMESPACE. MoqtPublishNamespace publish_namespace{ @@ -4256,7 +4258,7 @@ }); EXPECT_CALL(control_stream, Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _)); - stream_input->OnPublishNamespaceMessage(publish_namespace); + stream_input->ReceiveMessage(publish_namespace); publish_namespace = MoqtPublishNamespace( /*request_id=*/3, TrackNamespace{"bar"}, MessageParameters()); @@ -4269,7 +4271,7 @@ }); EXPECT_CALL(control_stream, Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _)); - stream_input->OnPublishNamespaceMessage(publish_namespace); + stream_input->ReceiveMessage(publish_namespace); // Revoke "bar" MoqtPublishNamespaceDone done{/*request_id=*/3}; @@ -4279,7 +4281,7 @@ .WillOnce( [](const TrackNamespace&, const std::optional<MessageParameters>&, MoqtResponseCallback callback) { EXPECT_EQ(callback, nullptr); }); - stream_input->OnPublishNamespaceDoneMessage(done); + stream_input->ReceiveMessage(done); // Destroying the session should revoke "foo". EXPECT_CALL( @@ -4310,7 +4312,7 @@ TEST_F(MoqtSessionTest, SubscribeThenRequestOk) { webtransport::test::MockStream control_stream; - std::unique_ptr<MoqtControlParserVisitor> stream_input = + std::unique_ptr<MoqtBidiStreamTestWrapper> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); MessageParameters parameters = SubscribeForTest(); parameters.subscription_filter.emplace(MoqtFilterType::kLargestObject); @@ -4318,33 +4320,34 @@ parameters); EXPECT_CALL(mock_session_, CloseSession); EXPECT_CALL(session_callbacks_.session_terminated_callback, Call); - stream_input->OnRequestOkMessage(MoqtRequestOk{0, MessageParameters()}); + stream_input->ReceiveMessage(MoqtRequestOk{0, MessageParameters()}); } TEST_F(MoqtSessionTest, ClientSetupNotAllowedOnControlStream) { // While technically on the Control stream, when it arrives, it's an // UnknownBidiStream - std::unique_ptr<MoqtControlParserVisitor> control_stream = + std::unique_ptr<MoqtBidiStreamTestWrapper> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL(mock_session_, CloseSession); EXPECT_CALL(session_callbacks_.session_terminated_callback, Call); - control_stream->OnClientSetupMessage(MoqtClientSetup()); + control_stream->ReceiveMessage( + MoqtClientSetup(SetupParameters("/", "example.com", 0))); } TEST_F(MoqtSessionTest, NamespaceNotAllowedOnControlStream) { - std::unique_ptr<MoqtControlParserVisitor> control_stream = + std::unique_ptr<MoqtBidiStreamTestWrapper> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL(mock_session_, CloseSession); EXPECT_CALL(session_callbacks_.session_terminated_callback, Call); - control_stream->OnNamespaceMessage(MoqtNamespace()); + control_stream->ReceiveMessage(MoqtNamespace()); } TEST_F(MoqtSessionTest, NamespaceDoneNotAllowedOnControlStream) { - std::unique_ptr<MoqtControlParserVisitor> control_stream = + std::unique_ptr<MoqtBidiStreamTestWrapper> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); EXPECT_CALL(mock_session_, CloseSession); EXPECT_CALL(session_callbacks_.session_terminated_callback, Call); - control_stream->OnNamespaceDoneMessage(MoqtNamespaceDone()); + control_stream->ReceiveMessage(MoqtNamespaceDone()); } TEST_F(MoqtSessionTest, IncomingRequestUpdateTruncatesSubscription) { @@ -4374,7 +4377,7 @@ listener->OnNewObjectAvailable(Location(8, 0), std::nullopt, 0x80); - std::unique_ptr<MoqtControlParserVisitor> control_stream = + std::unique_ptr<MoqtBidiStreamTestWrapper> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); // Update the filter to exclude the live edge. The next object is out of // window. @@ -4382,7 +4385,7 @@ parameters.subscription_filter.emplace(Location(4, 0), 7); EXPECT_CALL(mock_stream_, Writev(ControlMessageOfType(MoqtMessageType::kRequestOk), _)); - control_stream->OnRequestUpdateMessage(MoqtRequestUpdate{3, 1, parameters}); + control_stream->ReceiveMessage(MoqtRequestUpdate{3, 1, parameters}); EXPECT_CALL(*mock_publisher, GetCachedObject).Times(0); EXPECT_CALL(mock_session_, SendOrQueueDatagram).Times(0); listener->OnNewObjectAvailable(Location(8, 1), 0, 0x80); @@ -4391,7 +4394,7 @@ TEST_F(MoqtSessionTest, StopSendingBlocksSubgroup) { MoqtSubscribe subscribe = DefaultSubscribe(); MockTrackPublisher* track = CreateTrackPublisher(); - std::unique_ptr<MoqtControlParserVisitor> control_stream = + std::unique_ptr<MoqtBidiStreamTestWrapper> control_stream = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream_); MoqtObjectListener* listener = ReceiveSubscribeSynchronousOk(track, subscribe, control_stream.get(), 0);
diff --git a/quiche/quic/moqt/test_tools/moqt_session_peer.h b/quiche/quic/moqt/test_tools/moqt_session_peer.h index 805bba5..3f3a740 100644 --- a/quiche/quic/moqt/test_tools/moqt_session_peer.h +++ b/quiche/quic/moqt/test_tools/moqt_session_peer.h
@@ -8,23 +8,34 @@ #include <cstdint> #include <memory> #include <optional> +#include <string> #include <utility> #include "absl/base/casts.h" +#include "absl/base/nullability.h" +#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "quiche/quic/core/quic_alarm.h" #include "quiche/quic/core/quic_alarm_factory.h" #include "quiche/quic/core/quic_time.h" +#include "quiche/quic/moqt/moqt_bidi_stream.h" #include "quiche/quic/moqt/moqt_fetch_task.h" #include "quiche/quic/moqt/moqt_key_value_pair.h" #include "quiche/quic/moqt/moqt_messages.h" +#include "quiche/quic/moqt/moqt_names.h" #include "quiche/quic/moqt/moqt_parser.h" #include "quiche/quic/moqt/moqt_priority.h" #include "quiche/quic/moqt/moqt_publisher.h" #include "quiche/quic/moqt/moqt_session.h" +#include "quiche/quic/moqt/moqt_session_interface.h" #include "quiche/quic/moqt/moqt_track.h" #include "quiche/quic/moqt/moqt_types.h" +#include "quiche/quic/moqt/test_tools/moqt_framer_utils.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_data_reader.h" +#include "quiche/common/test_tools/quiche_test_utils.h" #include "quiche/web_transport/test_tools/mock_web_transport.h" #include "quiche/web_transport/web_transport.h" @@ -38,19 +49,45 @@ } }; +// Helper class to interact with MOQT bidi streams in tests. +class MoqtBidiStreamTestWrapper { + public: + explicit MoqtBidiStreamTestWrapper( + std::unique_ptr<MoqtBidiStreamBase> absl_nonnull stream) + : stream_(std::move(stream)) {} + + MoqtBidiStreamBase& stream() { return *stream_; } + + // Simulates receiving the specified control message on the bidi stream. + void ReceiveMessage(const AnyMoqtControlMessage& message) { + std::string serialized = SerializeGenericMessage(message); + quiche::QuicheDataReader reader(serialized); + uint64_t raw_type; + ASSERT_TRUE(reader.ReadVarInt62(&raw_type)); + ASSERT_TRUE(reader.Seek(2)); + absl::Status status = stream_->OnRawControlMessage(MoqtRawControlMessage{ + .type = static_cast<MoqtMessageType>(raw_type), + .payload = std::string(reader.ReadRemainingPayload())}); + stream_->CheckStatus(status); + } + + private: + std::unique_ptr<MoqtBidiStreamBase> absl_nonnull stream_; +}; + class MoqtSessionPeer { public: static constexpr webtransport::StreamId kControlStreamId = 4; - static std::unique_ptr<MoqtControlParserVisitor> CreateControlStream( + static std::unique_ptr<MoqtBidiStreamTestWrapper> CreateControlStream( MoqtSession* session, webtransport::test::MockStream* stream) { auto new_stream = std::make_unique<MoqtSession::ControlStream>(session); session->control_stream_ = new_stream->GetWeakPtr(); - new_stream->set_stream(stream); + new_stream->BindStream(stream); ON_CALL(*stream, visitor()) .WillByDefault(::testing::Return(new_stream.get())); ON_CALL(*stream, CanWrite).WillByDefault(::testing::Return(true)); - return new_stream; + return std::make_unique<MoqtBidiStreamTestWrapper>(std::move(new_stream)); } static std::unique_ptr<MoqtDataParserVisitor> CreateIncomingDataStream( @@ -77,10 +114,11 @@ // can inject packets into that stream. // This function is useful for any test that wants to inject packets on a // stream created by the MoqtSession. - static MoqtControlParserVisitor* + static std::unique_ptr<MoqtBidiStreamTestWrapper> FetchParserVisitorFromWebtransportStreamVisitor( - MoqtSession* session, webtransport::StreamVisitor* visitor) { - return static_cast<MoqtSession::ControlStream*>(visitor); + std::unique_ptr<webtransport::StreamVisitor> visitor) { + return std::make_unique<MoqtBidiStreamTestWrapper>(absl::WrapUnique( + absl::down_cast<MoqtSession::ControlStream*>(visitor.release()))); } static void CreateRemoteTrack(MoqtSession* session,