Add basic MoQT OBJECT sending capability. This is the result of a little bit of interop with Alan's chat client. PiperOrigin-RevId: 594005632
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index 164a709..b17a40f 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -7,7 +7,6 @@ #include <cstdint> #include <memory> #include <optional> -#include <queue> #include <string> #include <utility> #include <vector> @@ -239,6 +238,70 @@ return true; } +std::optional<webtransport::StreamId> MoqtSession::OpenUnidirectionalStream() { + if (!session_->CanOpenNextOutgoingUnidirectionalStream()) { + return std::nullopt; + } + webtransport::Stream* new_stream = + session_->OpenOutgoingUnidirectionalStream(); + if (new_stream == nullptr) { + return std::nullopt; + } + new_stream->SetVisitor(std::make_unique<Stream>(this, new_stream, false)); + return new_stream->GetStreamId(); +} + +// increment object_sequence or group_sequence depending on |start_new_group| +void MoqtSession::PublishObjectToStream(webtransport::StreamId stream_id, + FullTrackName full_track_name, + bool start_new_group, + absl::string_view payload) { + // TODO: check that the peer is subscribed to the next sequence. + webtransport::Stream* stream = session_->GetStreamById(stream_id); + if (stream == nullptr) { + QUICHE_DLOG(ERROR) << ENDPOINT << "Sending OBJECT to nonexistent stream"; + return; + } + auto track_it = local_tracks_.find(full_track_name); + if (track_it == local_tracks_.end()) { + QUICHE_DLOG(ERROR) << ENDPOINT << "Sending OBJECT for nonexistent track"; + return; + } + MoqtObject object; + LocalTrack& track = track_it->second; + object.track_id = track.track_alias(); + FullSequence& next_sequence = track.next_sequence_mutable(); + object.group_sequence = next_sequence.group; + if (start_new_group) { + ++object.group_sequence; + object.object_sequence = 0; + } else { + object.object_sequence = next_sequence.object; + } + next_sequence.group = object.group_sequence; + next_sequence.object = object.object_sequence + 1; + if (!track.ShouldSend(object.group_sequence, object.object_sequence)) { + QUICHE_LOG(INFO) << ENDPOINT << "Not sending object " + << full_track_name.track_namespace << ":" + << full_track_name.track_name << " with sequence " + << object.group_sequence << ":" << object.object_sequence + << " because peer is not subscribed"; + return; + } + object.object_send_order = 0; + object.payload_length = payload.size(); + bool success = + stream->Write(framer_.SerializeObject(object, payload).AsStringView()); + if (!success) { + QUICHE_DLOG(ERROR) << ENDPOINT << "Failed to write OBJECT message"; + return; + } + QUICHE_LOG(INFO) << ENDPOINT << "Sending object " + << full_track_name.track_namespace << ":" + << full_track_name.track_name << " with sequence " + << object.group_sequence << ":" << object.object_sequence; +} + void MoqtSession::Stream::OnCanRead() { bool fin = quiche::ProcessAllReadableRegions(*stream_, [&](absl::string_view chunk) {
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index 89b3def..645372b 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -18,7 +18,6 @@ #include "quiche/quic/moqt/moqt_framer.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_parser.h" -#include "quiche/quic/moqt/moqt_subscribe_windows.h" #include "quiche/quic/moqt/moqt_track.h" #include "quiche/common/platform/api/quiche_export.h" #include "quiche/common/quiche_callbacks.h" @@ -112,6 +111,16 @@ RemoteTrack::Visitor* visitor, absl::string_view auth_info = ""); + // Returns the stream ID if successful, nullopt if not. + // TODO: Add a callback if stream creation is delayed. + std::optional<webtransport::StreamId> OpenUnidirectionalStream(); + // Will automatically assign a new sequence number. If |start_new_group|, + // increment group_sequence and set object_sequence to 0. Otherwise, + // increment object_sequence. + void PublishObjectToStream(webtransport::StreamId stream_id, + FullTrackName full_track_name, + bool start_new_group, absl::string_view payload); + private: friend class test::MoqtSessionPeer; class QUICHE_EXPORT Stream : public webtransport::StreamVisitor, @@ -158,6 +167,8 @@ return session_->parameters_.perspective; } + webtransport::Stream* stream() const { return stream_; } + private: friend class test::MoqtSessionPeer; void SendSubscribeError(const MoqtSubscribeRequest& message,
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 12f9c73..4e07a7e 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -18,6 +18,7 @@ #include "quiche/quic/core/quic_types.h" #include "quiche/quic/moqt/moqt_messages.h" #include "quiche/quic/moqt/moqt_parser.h" +#include "quiche/quic/moqt/moqt_subscribe_windows.h" #include "quiche/quic/moqt/moqt_track.h" #include "quiche/quic/moqt/tools/moqt_mock_visitor.h" #include "quiche/quic/platform/api/quic_test.h" @@ -37,6 +38,7 @@ constexpr webtransport::StreamId kControlStreamId = 4; constexpr webtransport::StreamId kIncomingUniStreamId = 15; +constexpr webtransport::StreamId kOutgoingUniStreamId = 14; constexpr MoqtSessionParameters default_parameters = { /*version=*/MoqtVersion::kDraft01, @@ -102,6 +104,12 @@ track.set_track_alias(track_alias); session->tracks_by_alias_.emplace(std::make_pair(track_alias, &track)); } + + static LocalTrack& GetLocalTrack(MoqtSession* session, FullTrackName& name) { + auto it = session->local_tracks_.find(name); + EXPECT_NE(it, session->local_tracks_.end()); + return it->second; + } }; class MoqtSessionTest : public quic::test::QuicTest { @@ -581,7 +589,91 @@ control_stream->OnSubscribeOkMessage(ok); } -// TODO: Cover the error cases in the above +TEST_F(MoqtSessionTest, CreateUniStreamAndSend) { + StrictMock<webtransport::test::MockStream> mock_stream; + EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream()) + .WillOnce(Return(true)); + EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream()) + .WillOnce(Return(&mock_stream)); + EXPECT_CALL(mock_stream, SetVisitor(_)).Times(1); + EXPECT_CALL(mock_stream, GetStreamId()) + .WillRepeatedly(Return(kOutgoingUniStreamId)); + std::optional<webtransport::StreamId> stream = + session_.OpenUnidirectionalStream(); + EXPECT_TRUE(stream.has_value()); + EXPECT_EQ(stream.value(), kOutgoingUniStreamId); + + // Send on the stream + EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId)) + .WillOnce(Return(&mock_stream)); + FullTrackName ftn("foo", "bar"); + MockLocalTrackVisitor track_visitor; + session_.AddLocalTrack(ftn, &track_visitor); + LocalTrack& track = MoqtSessionPeer::GetLocalTrack(&session_, ftn); + FullSequence& next_seq = track.next_sequence_mutable(); + next_seq.group = 4; + next_seq.object = 1; + track.AddWindow(SubscribeWindow(5, 0)); + // No subscription; this is a no-op except for incrementing the sequence + // number. + EXPECT_CALL(mock_stream, Writev(_, _)).Times(0); + session_.PublishObjectToStream(kOutgoingUniStreamId, + FullTrackName("foo", "bar"), + /*start_new_group=*/false, "deadbeef"); + EXPECT_EQ(next_seq, FullSequence(4, 2)); + bool correct_message = false; + EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId)) + .WillOnce(Return(&mock_stream)); + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + correct_message = true; + EXPECT_EQ(*ExtractMessageType(data[0]), + MoqtMessageType::kObjectWithPayloadLength); + return absl::OkStatus(); + }); + session_.PublishObjectToStream(kOutgoingUniStreamId, + FullTrackName("foo", "bar"), + /*start_new_group=*/true, "deadbeef"); + EXPECT_TRUE(correct_message); + EXPECT_EQ(next_seq, FullSequence(5, 1)); +} + +// Error cases + +TEST_F(MoqtSessionTest, CannotOpenUniStream) { + EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream()) + .WillOnce(Return(false)); + std::optional<webtransport::StreamId> stream = + session_.OpenUnidirectionalStream(); + EXPECT_FALSE(stream.has_value()); + + EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream()) + .WillOnce(Return(true)); + EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream()) + .WillOnce(Return(nullptr)); + stream = session_.OpenUnidirectionalStream(); + EXPECT_FALSE(stream.has_value()); +} + +TEST_F(MoqtSessionTest, CannotPublishToStream) { + EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId)) + .WillOnce(Return(nullptr)); + FullTrackName ftn("foo", "bar"); + MockLocalTrackVisitor track_visitor; + session_.AddLocalTrack(ftn, &track_visitor); + LocalTrack& track = MoqtSessionPeer::GetLocalTrack(&session_, ftn); + FullSequence& next_seq = track.next_sequence_mutable(); + next_seq.group = 4; + next_seq.object = 1; + session_.PublishObjectToStream(kOutgoingUniStreamId, ftn, + /*start_new_group=*/false, "deadbeef"); + // Object not sent; no change in sequence number. + EXPECT_EQ(next_seq.group, 4); + EXPECT_EQ(next_seq.object, 1); +} + +// TODO: Cover more error cases in the above } // namespace test
diff --git a/quiche/quic/moqt/moqt_track.h b/quiche/quic/moqt/moqt_track.h index 48ba7bc..4dad086 100644 --- a/quiche/quic/moqt/moqt_track.h +++ b/quiche/quic/moqt/moqt_track.h
@@ -61,6 +61,8 @@ // by one. const FullSequence& next_sequence() const { return next_sequence_; } + FullSequence& next_sequence_mutable() { return next_sequence_; } + bool HasSubscriber() const { return !windows_.IsEmpty(); } private:
diff --git a/quiche/quic/moqt/moqt_track_test.cc b/quiche/quic/moqt/moqt_track_test.cc index f9fb0b4..a4d3e2a 100644 --- a/quiche/quic/moqt/moqt_track_test.cc +++ b/quiche/quic/moqt/moqt_track_test.cc
@@ -31,6 +31,9 @@ EXPECT_EQ(track_.track_alias(), 5); EXPECT_EQ(track_.visitor(), &visitor_); EXPECT_EQ(track_.next_sequence(), FullSequence(4, 1)); + FullSequence& mutable_next = track_.next_sequence_mutable(); + mutable_next.object++; + EXPECT_EQ(track_.next_sequence(), FullSequence(4, 2)); EXPECT_FALSE(track_.HasSubscriber()); }