Always buffer control messages in MoQT.

Right now, if we run out of flow control and buffer too many control messages, it will result in a fatal error.

PiperOrigin-RevId: 608684846
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index 2aba2b2..443dd50 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -12,10 +12,12 @@
 #include <utility>
 #include <vector>
 
+
 #include "absl/algorithm/container.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/span.h"
 #include "quiche/quic/core/quic_types.h"
 #include "quiche/quic/moqt/moqt_messages.h"
 #include "quiche/quic/moqt/moqt_subscribe_windows.h"
@@ -33,6 +35,27 @@
 
 using ::quic::Perspective;
 
+MoqtSession::Stream* MoqtSession::GetControlStream() {
+  if (!control_stream_.has_value()) {
+    return nullptr;
+  }
+  webtransport::Stream* raw_stream = session_->GetStreamById(*control_stream_);
+  if (raw_stream == nullptr) {
+    return nullptr;
+  }
+  return static_cast<Stream*>(raw_stream->visitor());
+}
+
+void MoqtSession::SendControlMessage(quiche::QuicheBuffer message) {
+  Stream* control_stream = GetControlStream();
+  if (control_stream == nullptr) {
+    QUICHE_LOG(DFATAL) << "Trying to send a message on the control stream "
+                          "while it does not exist";
+    return;
+  }
+  control_stream->SendOrBufferMessage(std::move(message));
+}
+
 void MoqtSession::OnSessionReady() {
   QUICHE_DLOG(INFO) << ENDPOINT << "Underlying session ready";
   if (parameters_.perspective == Perspective::IS_SERVER) {
@@ -55,12 +78,7 @@
   if (!parameters_.using_webtrans) {
     setup.path = parameters_.path;
   }
-  quiche::QuicheBuffer serialized_setup = framer_.SerializeClientSetup(setup);
-  bool success = control_stream->Write(serialized_setup.AsStringView());
-  if (!success) {
-    Error(MoqtError::kGenericError, "Failed to write client SETUP message");
-    return;
-  }
+  SendControlMessage(framer_.SerializeClientSetup(setup));
   QUIC_DLOG(INFO) << ENDPOINT << "Send the SETUP message";
 }
 
@@ -123,12 +141,7 @@
   }
   MoqtAnnounce message;
   message.track_namespace = track_namespace;
-  bool success = session_->GetStreamById(*control_stream_)
-                     ->Write(framer_.SerializeAnnounce(message).AsStringView());
-  if (!success) {
-    Error(MoqtError::kGenericError, "Failed to write ANNOUNCE message");
-    return;
-  }
+  SendControlMessage(framer_.SerializeAnnounce(message));
   QUIC_DLOG(INFO) << ENDPOINT << "Sent ANNOUNCE message for "
                   << message.track_namespace;
   pending_outgoing_announces_[track_namespace] = std::move(announce_callback);
@@ -235,13 +248,7 @@
   } else {
     message.track_alias = next_remote_track_alias_++;
   }
-  bool success =
-      session_->GetStreamById(*control_stream_)
-          ->Write(framer_.SerializeSubscribe(message).AsStringView());
-  if (!success) {
-    Error(MoqtError::kGenericError, "Failed to write SUBSCRIBE message");
-    return false;
-  }
+  SendControlMessage(framer_.SerializeSubscribe(message));
   QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE message for "
                   << message.track_namespace << ":" << message.track_name;
   active_subscribes_.try_emplace(message.subscribe_id, message, visitor);
@@ -457,13 +464,7 @@
     MoqtServerSetup response;
     response.selected_version = session_->parameters_.version;
     response.role = MoqtRole::kBoth;
-    bool success = stream_->Write(
-        session_->framer_.SerializeServerSetup(response).AsStringView());
-    if (!success) {
-      session_->Error(MoqtError::kGenericError,
-                      "Failed to write server SETUP message");
-      return;
-    }
+    SendOrBufferMessage(session_->framer_.SerializeServerSetup(response));
     QUIC_DLOG(INFO) << ENDPOINT << "Sent the SETUP message";
   }
   // TODO: handle role and path.
@@ -506,13 +507,8 @@
   subscribe_error.error_code = error_code;
   subscribe_error.reason_phrase = reason_phrase;
   subscribe_error.track_alias = track_alias;
-  bool success =
-      stream_->Write(session_->framer_.SerializeSubscribeError(subscribe_error)
-                         .AsStringView());
-  if (!success) {
-    session_->Error(MoqtError::kGenericError,
-                    "Failed to write SUBSCRIBE_ERROR message");
-  }
+  SendOrBufferMessage(
+      session_->framer_.SerializeSubscribeError(subscribe_error));
 }
 
 void MoqtSession::Stream::OnSubscribeMessage(const MoqtSubscribe& message) {
@@ -573,13 +569,7 @@
   }
   MoqtSubscribeOk subscribe_ok;
   subscribe_ok.subscribe_id = message.subscribe_id;
-  bool success = stream_->Write(
-      session_->framer_.SerializeSubscribeOk(subscribe_ok).AsStringView());
-  if (!success) {
-    session_->Error(MoqtError::kGenericError,
-                    "Failed to write SUBSCRIBE_OK message");
-    return;
-  }
+  SendOrBufferMessage(session_->framer_.SerializeSubscribeOk(subscribe_ok));
   QUIC_DLOG(INFO) << ENDPOINT << "Created subscription for "
                   << message.track_namespace << ":" << message.track_name;
   if (!end.has_value()) {
@@ -660,13 +650,7 @@
   }
   MoqtAnnounceOk ok;
   ok.track_namespace = message.track_namespace;
-  bool success =
-      stream_->Write(session_->framer_.SerializeAnnounceOk(ok).AsStringView());
-  if (!success) {
-    session_->Error(MoqtError::kGenericError,
-                    "Failed to write ANNOUNCE_OK message");
-    return;
-  }
+  SendOrBufferMessage(session_->framer_.SerializeAnnounceOk(ok));
 }
 
 void MoqtSession::Stream::OnAnnounceOkMessage(const MoqtAnnounceOk& message) {
@@ -739,4 +723,17 @@
   return sequence;
 }
 
+void MoqtSession::Stream::SendOrBufferMessage(quiche::QuicheBuffer message,
+                                              bool fin) {
+  quiche::StreamWriteOptions options;
+  options.set_send_fin(fin);
+  options.set_buffer_unconditionally(true);
+  std::array<absl::string_view, 1> write_vector = {message.AsStringView()};
+  absl::Status success = stream_->Writev(absl::MakeSpan(write_vector), options);
+  if (!success.ok()) {
+    session_->Error(MoqtError::kGenericError,
+                    "Failed to write a control message");
+  }
+}
+
 }  // namespace moqt
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
index 0bd0fa6..cb9163d 100644
--- a/quiche/quic/moqt/moqt_session.h
+++ b/quiche/quic/moqt/moqt_session.h
@@ -19,6 +19,7 @@
 #include "quiche/quic/moqt/moqt_parser.h"
 #include "quiche/quic/moqt/moqt_track.h"
 #include "quiche/common/platform/api/quiche_export.h"
+#include "quiche/common/quiche_buffer_allocator.h"
 #include "quiche/common/quiche_callbacks.h"
 #include "quiche/common/simple_buffer_allocator.h"
 #include "quiche/web_transport/web_transport.h"
@@ -120,8 +121,7 @@
   bool PublishObject(const FullTrackName& full_track_name, uint64_t group_id,
                      uint64_t object_id, uint64_t object_send_order,
                      MoqtForwardingPreference forwarding_preference,
-                     absl::string_view payload,
-                     bool end_of_stream);
+                     absl::string_view payload, bool end_of_stream);
   // TODO: Add an API to FIN the stream for a particular track/group/object.
   // TODO: Add an API to send partial objects.
 
@@ -174,6 +174,10 @@
 
     webtransport::Stream* stream() const { return stream_; }
 
+    // Sends a control message, or buffers it if there is insufficient flow
+    // control credit.
+    void SendOrBufferMessage(quiche::QuicheBuffer message, bool fin = false);
+
    private:
     friend class test::MoqtSessionPeer;
     void SendSubscribeError(const MoqtSubscribe& message,
@@ -191,6 +195,12 @@
     std::string partial_object_;
   };
 
+  // Returns the pointer to the control stream, or nullptr if none is present.
+  Stream* GetControlStream();
+  // Sends a message on the control stream; QUICHE_DCHECKs if no control stream
+  // is present.
+  void SendControlMessage(quiche::QuicheBuffer message);
+
   // Returns false if the SUBSCRIBE isn't sent.
   bool Subscribe(MoqtSubscribe& message, RemoteTrack::Visitor* visitor);
   // converts two MoqtLocations into absolute sequences.
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index cc24b9e..9401a24 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -34,6 +34,7 @@
 namespace {
 
 using ::testing::_;
+using ::testing::AnyNumber;
 using ::testing::Return;
 using ::testing::StrictMock;
 
@@ -65,10 +66,13 @@
 class MoqtSessionPeer {
  public:
   static std::unique_ptr<MoqtParserVisitor> CreateControlStream(
-      MoqtSession* session, webtransport::Stream* stream) {
+      MoqtSession* session, webtransport::test::MockStream* stream) {
     auto new_stream = std::make_unique<MoqtSession::Stream>(
         session, stream, /*is_control_stream=*/true);
     session->control_stream_ = kControlStreamId;
+    EXPECT_CALL(*stream, visitor())
+        .Times(AnyNumber())
+        .WillRepeatedly(Return(new_stream.get()));
     return new_stream;
   }
 
@@ -155,7 +159,9 @@
       });
   EXPECT_CALL(mock_stream, GetStreamId())
       .WillOnce(Return(webtransport::StreamId(4)));
+  EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream));
   bool correct_message = false;
+  EXPECT_CALL(mock_stream, visitor()).WillOnce([&] { return visitor.get(); });
   EXPECT_CALL(mock_stream, Writev(_, _))
       .WillOnce([&](absl::Span<const absl::string_view> data,
                     const quiche::StreamWriteOptions& options) {
@@ -789,7 +795,9 @@
       });
   EXPECT_CALL(mock_stream, GetStreamId())
       .WillOnce(Return(webtransport::StreamId(4)));
+  EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream));
   bool correct_message = false;
+  EXPECT_CALL(mock_stream, visitor()).WillOnce([&] { return visitor.get(); });
   EXPECT_CALL(mock_stream, Writev(_, _))
       .WillOnce([&](absl::Span<const absl::string_view> data,
                     const quiche::StreamWriteOptions& options) {