Always buffer control messages in MoQT.
Right now, if we run out of flow control and buffer too many control messages, it will result in a fatal error.
PiperOrigin-RevId: 608684846
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index 2aba2b2..443dd50 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -12,10 +12,12 @@
#include <utility>
#include <vector>
+
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "quiche/quic/core/quic_types.h"
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_subscribe_windows.h"
@@ -33,6 +35,27 @@
using ::quic::Perspective;
+MoqtSession::Stream* MoqtSession::GetControlStream() {
+ if (!control_stream_.has_value()) {
+ return nullptr;
+ }
+ webtransport::Stream* raw_stream = session_->GetStreamById(*control_stream_);
+ if (raw_stream == nullptr) {
+ return nullptr;
+ }
+ return static_cast<Stream*>(raw_stream->visitor());
+}
+
+void MoqtSession::SendControlMessage(quiche::QuicheBuffer message) {
+ Stream* control_stream = GetControlStream();
+ if (control_stream == nullptr) {
+ QUICHE_LOG(DFATAL) << "Trying to send a message on the control stream "
+ "while it does not exist";
+ return;
+ }
+ control_stream->SendOrBufferMessage(std::move(message));
+}
+
void MoqtSession::OnSessionReady() {
QUICHE_DLOG(INFO) << ENDPOINT << "Underlying session ready";
if (parameters_.perspective == Perspective::IS_SERVER) {
@@ -55,12 +78,7 @@
if (!parameters_.using_webtrans) {
setup.path = parameters_.path;
}
- quiche::QuicheBuffer serialized_setup = framer_.SerializeClientSetup(setup);
- bool success = control_stream->Write(serialized_setup.AsStringView());
- if (!success) {
- Error(MoqtError::kGenericError, "Failed to write client SETUP message");
- return;
- }
+ SendControlMessage(framer_.SerializeClientSetup(setup));
QUIC_DLOG(INFO) << ENDPOINT << "Send the SETUP message";
}
@@ -123,12 +141,7 @@
}
MoqtAnnounce message;
message.track_namespace = track_namespace;
- bool success = session_->GetStreamById(*control_stream_)
- ->Write(framer_.SerializeAnnounce(message).AsStringView());
- if (!success) {
- Error(MoqtError::kGenericError, "Failed to write ANNOUNCE message");
- return;
- }
+ SendControlMessage(framer_.SerializeAnnounce(message));
QUIC_DLOG(INFO) << ENDPOINT << "Sent ANNOUNCE message for "
<< message.track_namespace;
pending_outgoing_announces_[track_namespace] = std::move(announce_callback);
@@ -235,13 +248,7 @@
} else {
message.track_alias = next_remote_track_alias_++;
}
- bool success =
- session_->GetStreamById(*control_stream_)
- ->Write(framer_.SerializeSubscribe(message).AsStringView());
- if (!success) {
- Error(MoqtError::kGenericError, "Failed to write SUBSCRIBE message");
- return false;
- }
+ SendControlMessage(framer_.SerializeSubscribe(message));
QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE message for "
<< message.track_namespace << ":" << message.track_name;
active_subscribes_.try_emplace(message.subscribe_id, message, visitor);
@@ -457,13 +464,7 @@
MoqtServerSetup response;
response.selected_version = session_->parameters_.version;
response.role = MoqtRole::kBoth;
- bool success = stream_->Write(
- session_->framer_.SerializeServerSetup(response).AsStringView());
- if (!success) {
- session_->Error(MoqtError::kGenericError,
- "Failed to write server SETUP message");
- return;
- }
+ SendOrBufferMessage(session_->framer_.SerializeServerSetup(response));
QUIC_DLOG(INFO) << ENDPOINT << "Sent the SETUP message";
}
// TODO: handle role and path.
@@ -506,13 +507,8 @@
subscribe_error.error_code = error_code;
subscribe_error.reason_phrase = reason_phrase;
subscribe_error.track_alias = track_alias;
- bool success =
- stream_->Write(session_->framer_.SerializeSubscribeError(subscribe_error)
- .AsStringView());
- if (!success) {
- session_->Error(MoqtError::kGenericError,
- "Failed to write SUBSCRIBE_ERROR message");
- }
+ SendOrBufferMessage(
+ session_->framer_.SerializeSubscribeError(subscribe_error));
}
void MoqtSession::Stream::OnSubscribeMessage(const MoqtSubscribe& message) {
@@ -573,13 +569,7 @@
}
MoqtSubscribeOk subscribe_ok;
subscribe_ok.subscribe_id = message.subscribe_id;
- bool success = stream_->Write(
- session_->framer_.SerializeSubscribeOk(subscribe_ok).AsStringView());
- if (!success) {
- session_->Error(MoqtError::kGenericError,
- "Failed to write SUBSCRIBE_OK message");
- return;
- }
+ SendOrBufferMessage(session_->framer_.SerializeSubscribeOk(subscribe_ok));
QUIC_DLOG(INFO) << ENDPOINT << "Created subscription for "
<< message.track_namespace << ":" << message.track_name;
if (!end.has_value()) {
@@ -660,13 +650,7 @@
}
MoqtAnnounceOk ok;
ok.track_namespace = message.track_namespace;
- bool success =
- stream_->Write(session_->framer_.SerializeAnnounceOk(ok).AsStringView());
- if (!success) {
- session_->Error(MoqtError::kGenericError,
- "Failed to write ANNOUNCE_OK message");
- return;
- }
+ SendOrBufferMessage(session_->framer_.SerializeAnnounceOk(ok));
}
void MoqtSession::Stream::OnAnnounceOkMessage(const MoqtAnnounceOk& message) {
@@ -739,4 +723,17 @@
return sequence;
}
+void MoqtSession::Stream::SendOrBufferMessage(quiche::QuicheBuffer message,
+ bool fin) {
+ quiche::StreamWriteOptions options;
+ options.set_send_fin(fin);
+ options.set_buffer_unconditionally(true);
+ std::array<absl::string_view, 1> write_vector = {message.AsStringView()};
+ absl::Status success = stream_->Writev(absl::MakeSpan(write_vector), options);
+ if (!success.ok()) {
+ session_->Error(MoqtError::kGenericError,
+ "Failed to write a control message");
+ }
+}
+
} // namespace moqt
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
index 0bd0fa6..cb9163d 100644
--- a/quiche/quic/moqt/moqt_session.h
+++ b/quiche/quic/moqt/moqt_session.h
@@ -19,6 +19,7 @@
#include "quiche/quic/moqt/moqt_parser.h"
#include "quiche/quic/moqt/moqt_track.h"
#include "quiche/common/platform/api/quiche_export.h"
+#include "quiche/common/quiche_buffer_allocator.h"
#include "quiche/common/quiche_callbacks.h"
#include "quiche/common/simple_buffer_allocator.h"
#include "quiche/web_transport/web_transport.h"
@@ -120,8 +121,7 @@
bool PublishObject(const FullTrackName& full_track_name, uint64_t group_id,
uint64_t object_id, uint64_t object_send_order,
MoqtForwardingPreference forwarding_preference,
- absl::string_view payload,
- bool end_of_stream);
+ absl::string_view payload, bool end_of_stream);
// TODO: Add an API to FIN the stream for a particular track/group/object.
// TODO: Add an API to send partial objects.
@@ -174,6 +174,10 @@
webtransport::Stream* stream() const { return stream_; }
+ // Sends a control message, or buffers it if there is insufficient flow
+ // control credit.
+ void SendOrBufferMessage(quiche::QuicheBuffer message, bool fin = false);
+
private:
friend class test::MoqtSessionPeer;
void SendSubscribeError(const MoqtSubscribe& message,
@@ -191,6 +195,12 @@
std::string partial_object_;
};
+ // Returns the pointer to the control stream, or nullptr if none is present.
+ Stream* GetControlStream();
+ // Sends a message on the control stream; QUICHE_DCHECKs if no control stream
+ // is present.
+ void SendControlMessage(quiche::QuicheBuffer message);
+
// Returns false if the SUBSCRIBE isn't sent.
bool Subscribe(MoqtSubscribe& message, RemoteTrack::Visitor* visitor);
// converts two MoqtLocations into absolute sequences.
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index cc24b9e..9401a24 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -34,6 +34,7 @@
namespace {
using ::testing::_;
+using ::testing::AnyNumber;
using ::testing::Return;
using ::testing::StrictMock;
@@ -65,10 +66,13 @@
class MoqtSessionPeer {
public:
static std::unique_ptr<MoqtParserVisitor> CreateControlStream(
- MoqtSession* session, webtransport::Stream* stream) {
+ MoqtSession* session, webtransport::test::MockStream* stream) {
auto new_stream = std::make_unique<MoqtSession::Stream>(
session, stream, /*is_control_stream=*/true);
session->control_stream_ = kControlStreamId;
+ EXPECT_CALL(*stream, visitor())
+ .Times(AnyNumber())
+ .WillRepeatedly(Return(new_stream.get()));
return new_stream;
}
@@ -155,7 +159,9 @@
});
EXPECT_CALL(mock_stream, GetStreamId())
.WillOnce(Return(webtransport::StreamId(4)));
+ EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream));
bool correct_message = false;
+ EXPECT_CALL(mock_stream, visitor()).WillOnce([&] { return visitor.get(); });
EXPECT_CALL(mock_stream, Writev(_, _))
.WillOnce([&](absl::Span<const absl::string_view> data,
const quiche::StreamWriteOptions& options) {
@@ -789,7 +795,9 @@
});
EXPECT_CALL(mock_stream, GetStreamId())
.WillOnce(Return(webtransport::StreamId(4)));
+ EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream));
bool correct_message = false;
+ EXPECT_CALL(mock_stream, visitor()).WillOnce([&] { return visitor.get(); });
EXPECT_CALL(mock_stream, Writev(_, _))
.WillOnce([&](absl::Span<const absl::string_view> data,
const quiche::StreamWriteOptions& options) {