Always buffer control messages in MoQT. Right now, if we run out of flow control and buffer too many control messages, it will result in a fatal error. PiperOrigin-RevId: 608684846
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index 2aba2b2..443dd50 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -12,10 +12,12 @@ #include <utility> #include <vector> + #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "quiche/quic/core/quic_types.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_subscribe_windows.h" @@ -33,6 +35,27 @@ using ::quic::Perspective; +MoqtSession::Stream* MoqtSession::GetControlStream() { + if (!control_stream_.has_value()) { + return nullptr; + } + webtransport::Stream* raw_stream = session_->GetStreamById(*control_stream_); + if (raw_stream == nullptr) { + return nullptr; + } + return static_cast<Stream*>(raw_stream->visitor()); +} + +void MoqtSession::SendControlMessage(quiche::QuicheBuffer message) { + Stream* control_stream = GetControlStream(); + if (control_stream == nullptr) { + QUICHE_LOG(DFATAL) << "Trying to send a message on the control stream " + "while it does not exist"; + return; + } + control_stream->SendOrBufferMessage(std::move(message)); +} + void MoqtSession::OnSessionReady() { QUICHE_DLOG(INFO) << ENDPOINT << "Underlying session ready"; if (parameters_.perspective == Perspective::IS_SERVER) { @@ -55,12 +78,7 @@ if (!parameters_.using_webtrans) { setup.path = parameters_.path; } - quiche::QuicheBuffer serialized_setup = framer_.SerializeClientSetup(setup); - bool success = control_stream->Write(serialized_setup.AsStringView()); - if (!success) { - Error(MoqtError::kGenericError, "Failed to write client SETUP message"); - return; - } + SendControlMessage(framer_.SerializeClientSetup(setup)); QUIC_DLOG(INFO) << ENDPOINT << "Send the SETUP message"; } @@ -123,12 +141,7 @@ } MoqtAnnounce message; message.track_namespace = track_namespace; - bool success = session_->GetStreamById(*control_stream_) - ->Write(framer_.SerializeAnnounce(message).AsStringView()); - if (!success) { - Error(MoqtError::kGenericError, "Failed to write ANNOUNCE message"); - return; - } + SendControlMessage(framer_.SerializeAnnounce(message)); QUIC_DLOG(INFO) << ENDPOINT << "Sent ANNOUNCE message for " << message.track_namespace; pending_outgoing_announces_[track_namespace] = std::move(announce_callback); @@ -235,13 +248,7 @@ } else { message.track_alias = next_remote_track_alias_++; } - bool success = - session_->GetStreamById(*control_stream_) - ->Write(framer_.SerializeSubscribe(message).AsStringView()); - if (!success) { - Error(MoqtError::kGenericError, "Failed to write SUBSCRIBE message"); - return false; - } + SendControlMessage(framer_.SerializeSubscribe(message)); QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE message for " << message.track_namespace << ":" << message.track_name; active_subscribes_.try_emplace(message.subscribe_id, message, visitor); @@ -457,13 +464,7 @@ MoqtServerSetup response; response.selected_version = session_->parameters_.version; response.role = MoqtRole::kBoth; - bool success = stream_->Write( - session_->framer_.SerializeServerSetup(response).AsStringView()); - if (!success) { - session_->Error(MoqtError::kGenericError, - "Failed to write server SETUP message"); - return; - } + SendOrBufferMessage(session_->framer_.SerializeServerSetup(response)); QUIC_DLOG(INFO) << ENDPOINT << "Sent the SETUP message"; } // TODO: handle role and path. @@ -506,13 +507,8 @@ subscribe_error.error_code = error_code; subscribe_error.reason_phrase = reason_phrase; subscribe_error.track_alias = track_alias; - bool success = - stream_->Write(session_->framer_.SerializeSubscribeError(subscribe_error) - .AsStringView()); - if (!success) { - session_->Error(MoqtError::kGenericError, - "Failed to write SUBSCRIBE_ERROR message"); - } + SendOrBufferMessage( + session_->framer_.SerializeSubscribeError(subscribe_error)); } void MoqtSession::Stream::OnSubscribeMessage(const MoqtSubscribe& message) { @@ -573,13 +569,7 @@ } MoqtSubscribeOk subscribe_ok; subscribe_ok.subscribe_id = message.subscribe_id; - bool success = stream_->Write( - session_->framer_.SerializeSubscribeOk(subscribe_ok).AsStringView()); - if (!success) { - session_->Error(MoqtError::kGenericError, - "Failed to write SUBSCRIBE_OK message"); - return; - } + SendOrBufferMessage(session_->framer_.SerializeSubscribeOk(subscribe_ok)); QUIC_DLOG(INFO) << ENDPOINT << "Created subscription for " << message.track_namespace << ":" << message.track_name; if (!end.has_value()) { @@ -660,13 +650,7 @@ } MoqtAnnounceOk ok; ok.track_namespace = message.track_namespace; - bool success = - stream_->Write(session_->framer_.SerializeAnnounceOk(ok).AsStringView()); - if (!success) { - session_->Error(MoqtError::kGenericError, - "Failed to write ANNOUNCE_OK message"); - return; - } + SendOrBufferMessage(session_->framer_.SerializeAnnounceOk(ok)); } void MoqtSession::Stream::OnAnnounceOkMessage(const MoqtAnnounceOk& message) { @@ -739,4 +723,17 @@ return sequence; } +void MoqtSession::Stream::SendOrBufferMessage(quiche::QuicheBuffer message, + bool fin) { + quiche::StreamWriteOptions options; + options.set_send_fin(fin); + options.set_buffer_unconditionally(true); + std::array<absl::string_view, 1> write_vector = {message.AsStringView()}; + absl::Status success = stream_->Writev(absl::MakeSpan(write_vector), options); + if (!success.ok()) { + session_->Error(MoqtError::kGenericError, + "Failed to write a control message"); + } +} + } // namespace moqt
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index 0bd0fa6..cb9163d 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -19,6 +19,7 @@ #include "quiche/quic/moqt/moqt_parser.h" #include "quiche/quic/moqt/moqt_track.h" #include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_buffer_allocator.h" #include "quiche/common/quiche_callbacks.h" #include "quiche/common/simple_buffer_allocator.h" #include "quiche/web_transport/web_transport.h" @@ -120,8 +121,7 @@ bool PublishObject(const FullTrackName& full_track_name, uint64_t group_id, uint64_t object_id, uint64_t object_send_order, MoqtForwardingPreference forwarding_preference, - absl::string_view payload, - bool end_of_stream); + absl::string_view payload, bool end_of_stream); // TODO: Add an API to FIN the stream for a particular track/group/object. // TODO: Add an API to send partial objects. @@ -174,6 +174,10 @@ webtransport::Stream* stream() const { return stream_; } + // Sends a control message, or buffers it if there is insufficient flow + // control credit. + void SendOrBufferMessage(quiche::QuicheBuffer message, bool fin = false); + private: friend class test::MoqtSessionPeer; void SendSubscribeError(const MoqtSubscribe& message, @@ -191,6 +195,12 @@ std::string partial_object_; }; + // Returns the pointer to the control stream, or nullptr if none is present. + Stream* GetControlStream(); + // Sends a message on the control stream; QUICHE_DCHECKs if no control stream + // is present. + void SendControlMessage(quiche::QuicheBuffer message); + // Returns false if the SUBSCRIBE isn't sent. bool Subscribe(MoqtSubscribe& message, RemoteTrack::Visitor* visitor); // converts two MoqtLocations into absolute sequences.
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index cc24b9e..9401a24 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -34,6 +34,7 @@ namespace { using ::testing::_; +using ::testing::AnyNumber; using ::testing::Return; using ::testing::StrictMock; @@ -65,10 +66,13 @@ class MoqtSessionPeer { public: static std::unique_ptr<MoqtParserVisitor> CreateControlStream( - MoqtSession* session, webtransport::Stream* stream) { + MoqtSession* session, webtransport::test::MockStream* stream) { auto new_stream = std::make_unique<MoqtSession::Stream>( session, stream, /*is_control_stream=*/true); session->control_stream_ = kControlStreamId; + EXPECT_CALL(*stream, visitor()) + .Times(AnyNumber()) + .WillRepeatedly(Return(new_stream.get())); return new_stream; } @@ -155,7 +159,9 @@ }); EXPECT_CALL(mock_stream, GetStreamId()) .WillOnce(Return(webtransport::StreamId(4))); + EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream)); bool correct_message = false; + EXPECT_CALL(mock_stream, visitor()).WillOnce([&] { return visitor.get(); }); EXPECT_CALL(mock_stream, Writev(_, _)) .WillOnce([&](absl::Span<const absl::string_view> data, const quiche::StreamWriteOptions& options) { @@ -789,7 +795,9 @@ }); EXPECT_CALL(mock_stream, GetStreamId()) .WillOnce(Return(webtransport::StreamId(4))); + EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream)); bool correct_message = false; + EXPECT_CALL(mock_stream, visitor()).WillOnce([&] { return visitor.get(); }); EXPECT_CALL(mock_stream, Writev(_, _)) .WillOnce([&](absl::Span<const absl::string_view> data, const quiche::StreamWriteOptions& options) {