Add basic MoQT OBJECT sending capability. This is the result of a little bit of interop with Alan's chat client.
PiperOrigin-RevId: 594005632
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index 164a709..b17a40f 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -7,7 +7,6 @@
#include <cstdint>
#include <memory>
#include <optional>
-#include <queue>
#include <string>
#include <utility>
#include <vector>
@@ -239,6 +238,70 @@
return true;
}
+std::optional<webtransport::StreamId> MoqtSession::OpenUnidirectionalStream() {
+ if (!session_->CanOpenNextOutgoingUnidirectionalStream()) {
+ return std::nullopt;
+ }
+ webtransport::Stream* new_stream =
+ session_->OpenOutgoingUnidirectionalStream();
+ if (new_stream == nullptr) {
+ return std::nullopt;
+ }
+ new_stream->SetVisitor(std::make_unique<Stream>(this, new_stream, false));
+ return new_stream->GetStreamId();
+}
+
+// increment object_sequence or group_sequence depending on |start_new_group|
+void MoqtSession::PublishObjectToStream(webtransport::StreamId stream_id,
+ FullTrackName full_track_name,
+ bool start_new_group,
+ absl::string_view payload) {
+ // TODO: check that the peer is subscribed to the next sequence.
+ webtransport::Stream* stream = session_->GetStreamById(stream_id);
+ if (stream == nullptr) {
+ QUICHE_DLOG(ERROR) << ENDPOINT << "Sending OBJECT to nonexistent stream";
+ return;
+ }
+ auto track_it = local_tracks_.find(full_track_name);
+ if (track_it == local_tracks_.end()) {
+ QUICHE_DLOG(ERROR) << ENDPOINT << "Sending OBJECT for nonexistent track";
+ return;
+ }
+ MoqtObject object;
+ LocalTrack& track = track_it->second;
+ object.track_id = track.track_alias();
+ FullSequence& next_sequence = track.next_sequence_mutable();
+ object.group_sequence = next_sequence.group;
+ if (start_new_group) {
+ ++object.group_sequence;
+ object.object_sequence = 0;
+ } else {
+ object.object_sequence = next_sequence.object;
+ }
+ next_sequence.group = object.group_sequence;
+ next_sequence.object = object.object_sequence + 1;
+ if (!track.ShouldSend(object.group_sequence, object.object_sequence)) {
+ QUICHE_LOG(INFO) << ENDPOINT << "Not sending object "
+ << full_track_name.track_namespace << ":"
+ << full_track_name.track_name << " with sequence "
+ << object.group_sequence << ":" << object.object_sequence
+ << " because peer is not subscribed";
+ return;
+ }
+ object.object_send_order = 0;
+ object.payload_length = payload.size();
+ bool success =
+ stream->Write(framer_.SerializeObject(object, payload).AsStringView());
+ if (!success) {
+ QUICHE_DLOG(ERROR) << ENDPOINT << "Failed to write OBJECT message";
+ return;
+ }
+ QUICHE_LOG(INFO) << ENDPOINT << "Sending object "
+ << full_track_name.track_namespace << ":"
+ << full_track_name.track_name << " with sequence "
+ << object.group_sequence << ":" << object.object_sequence;
+}
+
void MoqtSession::Stream::OnCanRead() {
bool fin =
quiche::ProcessAllReadableRegions(*stream_, [&](absl::string_view chunk) {
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
index 89b3def..645372b 100644
--- a/quiche/quic/moqt/moqt_session.h
+++ b/quiche/quic/moqt/moqt_session.h
@@ -18,7 +18,6 @@
#include "quiche/quic/moqt/moqt_framer.h"
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_parser.h"
-#include "quiche/quic/moqt/moqt_subscribe_windows.h"
#include "quiche/quic/moqt/moqt_track.h"
#include "quiche/common/platform/api/quiche_export.h"
#include "quiche/common/quiche_callbacks.h"
@@ -112,6 +111,16 @@
RemoteTrack::Visitor* visitor,
absl::string_view auth_info = "");
+ // Returns the stream ID if successful, nullopt if not.
+ // TODO: Add a callback if stream creation is delayed.
+ std::optional<webtransport::StreamId> OpenUnidirectionalStream();
+ // Will automatically assign a new sequence number. If |start_new_group|,
+ // increment group_sequence and set object_sequence to 0. Otherwise,
+ // increment object_sequence.
+ void PublishObjectToStream(webtransport::StreamId stream_id,
+ FullTrackName full_track_name,
+ bool start_new_group, absl::string_view payload);
+
private:
friend class test::MoqtSessionPeer;
class QUICHE_EXPORT Stream : public webtransport::StreamVisitor,
@@ -158,6 +167,8 @@
return session_->parameters_.perspective;
}
+ webtransport::Stream* stream() const { return stream_; }
+
private:
friend class test::MoqtSessionPeer;
void SendSubscribeError(const MoqtSubscribeRequest& message,
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index 12f9c73..4e07a7e 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -18,6 +18,7 @@
#include "quiche/quic/core/quic_types.h"
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_parser.h"
+#include "quiche/quic/moqt/moqt_subscribe_windows.h"
#include "quiche/quic/moqt/moqt_track.h"
#include "quiche/quic/moqt/tools/moqt_mock_visitor.h"
#include "quiche/quic/platform/api/quic_test.h"
@@ -37,6 +38,7 @@
constexpr webtransport::StreamId kControlStreamId = 4;
constexpr webtransport::StreamId kIncomingUniStreamId = 15;
+constexpr webtransport::StreamId kOutgoingUniStreamId = 14;
constexpr MoqtSessionParameters default_parameters = {
/*version=*/MoqtVersion::kDraft01,
@@ -102,6 +104,12 @@
track.set_track_alias(track_alias);
session->tracks_by_alias_.emplace(std::make_pair(track_alias, &track));
}
+
+ static LocalTrack& GetLocalTrack(MoqtSession* session, FullTrackName& name) {
+ auto it = session->local_tracks_.find(name);
+ EXPECT_NE(it, session->local_tracks_.end());
+ return it->second;
+ }
};
class MoqtSessionTest : public quic::test::QuicTest {
@@ -581,7 +589,91 @@
control_stream->OnSubscribeOkMessage(ok);
}
-// TODO: Cover the error cases in the above
+TEST_F(MoqtSessionTest, CreateUniStreamAndSend) {
+ StrictMock<webtransport::test::MockStream> mock_stream;
+ EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream())
+ .WillOnce(Return(true));
+ EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream())
+ .WillOnce(Return(&mock_stream));
+ EXPECT_CALL(mock_stream, SetVisitor(_)).Times(1);
+ EXPECT_CALL(mock_stream, GetStreamId())
+ .WillRepeatedly(Return(kOutgoingUniStreamId));
+ std::optional<webtransport::StreamId> stream =
+ session_.OpenUnidirectionalStream();
+ EXPECT_TRUE(stream.has_value());
+ EXPECT_EQ(stream.value(), kOutgoingUniStreamId);
+
+ // Send on the stream
+ EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId))
+ .WillOnce(Return(&mock_stream));
+ FullTrackName ftn("foo", "bar");
+ MockLocalTrackVisitor track_visitor;
+ session_.AddLocalTrack(ftn, &track_visitor);
+ LocalTrack& track = MoqtSessionPeer::GetLocalTrack(&session_, ftn);
+ FullSequence& next_seq = track.next_sequence_mutable();
+ next_seq.group = 4;
+ next_seq.object = 1;
+ track.AddWindow(SubscribeWindow(5, 0));
+ // No subscription; this is a no-op except for incrementing the sequence
+ // number.
+ EXPECT_CALL(mock_stream, Writev(_, _)).Times(0);
+ session_.PublishObjectToStream(kOutgoingUniStreamId,
+ FullTrackName("foo", "bar"),
+ /*start_new_group=*/false, "deadbeef");
+ EXPECT_EQ(next_seq, FullSequence(4, 2));
+ bool correct_message = false;
+ EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId))
+ .WillOnce(Return(&mock_stream));
+ EXPECT_CALL(mock_stream, Writev(_, _))
+ .WillOnce([&](absl::Span<const absl::string_view> data,
+ const quiche::StreamWriteOptions& options) {
+ correct_message = true;
+ EXPECT_EQ(*ExtractMessageType(data[0]),
+ MoqtMessageType::kObjectWithPayloadLength);
+ return absl::OkStatus();
+ });
+ session_.PublishObjectToStream(kOutgoingUniStreamId,
+ FullTrackName("foo", "bar"),
+ /*start_new_group=*/true, "deadbeef");
+ EXPECT_TRUE(correct_message);
+ EXPECT_EQ(next_seq, FullSequence(5, 1));
+}
+
+// Error cases
+
+TEST_F(MoqtSessionTest, CannotOpenUniStream) {
+ EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream())
+ .WillOnce(Return(false));
+ std::optional<webtransport::StreamId> stream =
+ session_.OpenUnidirectionalStream();
+ EXPECT_FALSE(stream.has_value());
+
+ EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream())
+ .WillOnce(Return(true));
+ EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream())
+ .WillOnce(Return(nullptr));
+ stream = session_.OpenUnidirectionalStream();
+ EXPECT_FALSE(stream.has_value());
+}
+
+TEST_F(MoqtSessionTest, CannotPublishToStream) {
+ EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId))
+ .WillOnce(Return(nullptr));
+ FullTrackName ftn("foo", "bar");
+ MockLocalTrackVisitor track_visitor;
+ session_.AddLocalTrack(ftn, &track_visitor);
+ LocalTrack& track = MoqtSessionPeer::GetLocalTrack(&session_, ftn);
+ FullSequence& next_seq = track.next_sequence_mutable();
+ next_seq.group = 4;
+ next_seq.object = 1;
+ session_.PublishObjectToStream(kOutgoingUniStreamId, ftn,
+ /*start_new_group=*/false, "deadbeef");
+ // Object not sent; no change in sequence number.
+ EXPECT_EQ(next_seq.group, 4);
+ EXPECT_EQ(next_seq.object, 1);
+}
+
+// TODO: Cover more error cases in the above
} // namespace test
diff --git a/quiche/quic/moqt/moqt_track.h b/quiche/quic/moqt/moqt_track.h
index 48ba7bc..4dad086 100644
--- a/quiche/quic/moqt/moqt_track.h
+++ b/quiche/quic/moqt/moqt_track.h
@@ -61,6 +61,8 @@
// by one.
const FullSequence& next_sequence() const { return next_sequence_; }
+ FullSequence& next_sequence_mutable() { return next_sequence_; }
+
bool HasSubscriber() const { return !windows_.IsEmpty(); }
private:
diff --git a/quiche/quic/moqt/moqt_track_test.cc b/quiche/quic/moqt/moqt_track_test.cc
index f9fb0b4..a4d3e2a 100644
--- a/quiche/quic/moqt/moqt_track_test.cc
+++ b/quiche/quic/moqt/moqt_track_test.cc
@@ -31,6 +31,9 @@
EXPECT_EQ(track_.track_alias(), 5);
EXPECT_EQ(track_.visitor(), &visitor_);
EXPECT_EQ(track_.next_sequence(), FullSequence(4, 1));
+ FullSequence& mutable_next = track_.next_sequence_mutable();
+ mutable_next.object++;
+ EXPECT_EQ(track_.next_sequence(), FullSequence(4, 2));
EXPECT_FALSE(track_.HasSubscriber());
}