MOQT MAX_SUBSCRIBE_ID implementation Currently sets max_subscribe_id to UINT62_MAX if it's an earlier draft version. This can be removed soon, when we delete support for draft-05. PiperOrigin-RevId: 678364064
diff --git a/quiche/quic/moqt/moqt_framer.cc b/quiche/quic/moqt/moqt_framer.cc index 32d2a5f..f4ddd7a 100644 --- a/quiche/quic/moqt/moqt_framer.cc +++ b/quiche/quic/moqt/moqt_framer.cc
@@ -263,6 +263,10 @@ int_parameters.push_back( IntParameter(MoqtSetupParameter::kRole, *message.role)); } + if (message.max_subscribe_id.has_value()) { + int_parameters.push_back(IntParameter(MoqtSetupParameter::kMaxSubscribeId, + *message.max_subscribe_id)); + } if (message.supports_object_ack) { int_parameters.push_back( IntParameter(MoqtSetupParameter::kSupportObjectAcks, 1u)); @@ -287,6 +291,10 @@ int_parameters.push_back( IntParameter(MoqtSetupParameter::kRole, *message.role)); } + if (message.max_subscribe_id.has_value()) { + int_parameters.push_back(IntParameter(MoqtSetupParameter::kMaxSubscribeId, + *message.max_subscribe_id)); + } if (message.supports_object_ack) { int_parameters.push_back( IntParameter(MoqtSetupParameter::kSupportObjectAcks, 1u)); @@ -499,6 +507,12 @@ WireStringWithVarInt62Length(message.new_session_uri)); } +quiche::QuicheBuffer MoqtFramer::SerializeMaxSubscribeId( + const MoqtMaxSubscribeId& message) { + return Serialize(WireVarInt62(MoqtMessageType::kMaxSubscribeId), + WireVarInt62(message.max_subscribe_id)); +} + quiche::QuicheBuffer MoqtFramer::SerializeObjectAck( const MoqtObjectAck& message) { return Serialize(WireVarInt62(MoqtMessageType::kObjectAck),
diff --git a/quiche/quic/moqt/moqt_framer.h b/quiche/quic/moqt/moqt_framer.h index 188715a..40b4dbf 100644 --- a/quiche/quic/moqt/moqt_framer.h +++ b/quiche/quic/moqt/moqt_framer.h
@@ -55,6 +55,8 @@ quiche::QuicheBuffer SerializeUnannounce(const MoqtUnannounce& message); quiche::QuicheBuffer SerializeTrackStatus(const MoqtTrackStatus& message); quiche::QuicheBuffer SerializeGoAway(const MoqtGoAway& message); + quiche::QuicheBuffer SerializeMaxSubscribeId( + const MoqtMaxSubscribeId& message); quiche::QuicheBuffer SerializeObjectAck(const MoqtObjectAck& message); private:
diff --git a/quiche/quic/moqt/moqt_framer_test.cc b/quiche/quic/moqt/moqt_framer_test.cc index cb89551..bd19955 100644 --- a/quiche/quic/moqt/moqt_framer_test.cc +++ b/quiche/quic/moqt/moqt_framer_test.cc
@@ -48,6 +48,7 @@ MoqtMessageType::kAnnounceError, MoqtMessageType::kUnannounce, MoqtMessageType::kGoAway, + MoqtMessageType::kMaxSubscribeId, MoqtMessageType::kObjectAck, MoqtMessageType::kClientSetup, MoqtMessageType::kServerSetup, @@ -157,6 +158,10 @@ auto data = std::get<MoqtGoAway>(structured_data); return framer_.SerializeGoAway(data); } + case moqt::MoqtMessageType::kMaxSubscribeId: { + auto data = std::get<MoqtMaxSubscribeId>(structured_data); + return framer_.SerializeMaxSubscribeId(data); + } case moqt::MoqtMessageType::kObjectAck: { auto data = std::get<MoqtObjectAck>(structured_data); return framer_.SerializeObjectAck(data);
diff --git a/quiche/quic/moqt/moqt_messages.cc b/quiche/quic/moqt/moqt_messages.cc index 2ecdbf6..df357fe 100644 --- a/quiche/quic/moqt/moqt_messages.cc +++ b/quiche/quic/moqt/moqt_messages.cc
@@ -88,6 +88,8 @@ return "UNANNOUNCE"; case MoqtMessageType::kGoAway: return "GOAWAY"; + case MoqtMessageType::kMaxSubscribeId: + return "MAX_SUBSCRIBE_ID"; case MoqtMessageType::kObjectAck: return "OBJECT_ACK"; }
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h index 0a6bc55..aaf151d 100644 --- a/quiche/quic/moqt/moqt_messages.h +++ b/quiche/quic/moqt/moqt_messages.h
@@ -34,6 +34,7 @@ }; inline constexpr MoqtVersion kDefaultMoqtVersion = MoqtVersion::kDraft05; +inline constexpr uint64_t kDefaultInitialMaxSubscribeId = 100; struct QUICHE_EXPORT MoqtSessionParameters { // TODO: support multiple versions. @@ -50,6 +51,7 @@ quic::Perspective perspective; bool using_webtrans; std::string path; + uint64_t max_subscribe_id = kDefaultInitialMaxSubscribeId; bool deliver_partial_objects = false; bool support_object_acks = false; }; @@ -84,6 +86,7 @@ kTrackStatusRequest = 0x0d, kTrackStatus = 0x0e, kGoAway = 0x10, + kMaxSubscribeId = 0x15, kClientSetup = 0x40, kServerSetup = 0x41, @@ -101,6 +104,7 @@ kProtocolViolation = 0x3, kDuplicateTrackAlias = 0x4, kParameterLengthMismatch = 0x5, + kTooManySubscribes = 0x6, kGoawayTimeout = 0x10, }; @@ -121,6 +125,7 @@ enum class QUICHE_EXPORT MoqtSetupParameter : uint64_t { kRole = 0x0, kPath = 0x1, + kMaxSubscribeId = 0x2, // QUICHE-specific extensions. // Indicates support for OACK messages. @@ -220,12 +225,14 @@ std::vector<MoqtVersion> supported_versions; std::optional<MoqtRole> role; std::optional<std::string> path; + std::optional<uint64_t> max_subscribe_id; bool supports_object_ack = false; }; struct QUICHE_EXPORT MoqtServerSetup { MoqtVersion selected_version; std::optional<MoqtRole> role; + std::optional<uint64_t> max_subscribe_id; bool supports_object_ack = false; }; @@ -423,6 +430,10 @@ std::string new_session_uri; }; +struct QUICHE_EXPORT MoqtMaxSubscribeId { + uint64_t max_subscribe_id; +}; + // All of the four values in this message are encoded as varints. // `delta_from_deadline` is encoded as an absolute value, with the lowest bit // indicating the sign (0 if positive).
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc index 923013d..87c45b9 100644 --- a/quiche/quic/moqt/moqt_parser.cc +++ b/quiche/quic/moqt/moqt_parser.cc
@@ -224,6 +224,8 @@ return ProcessTrackStatus(reader); case MoqtMessageType::kGoAway: return ProcessGoAway(reader); + case MoqtMessageType::kMaxSubscribeId: + return ProcessMaxSubscribeId(reader); case moqt::MoqtMessageType::kObjectAck: return ProcessObjectAck(reader); } @@ -284,6 +286,18 @@ } setup.path = value; break; + case MoqtSetupParameter::kMaxSubscribeId: + if (setup.max_subscribe_id.has_value()) { + ParseError("MAX_SUBSCRIBE_ID parameter appears twice in SETUP"); + return 0; + } + uint64_t max_id; + if (!StringViewToVarInt(value, max_id)) { + ParseError("MAX_SUBSCRIBE_ID parameter is not a valid varint"); + return 0; + } + setup.max_subscribe_id = max_id; + break; case MoqtSetupParameter::kSupportObjectAcks: uint64_t flag; if (!StringViewToVarInt(value, flag) || flag > 1) { @@ -347,6 +361,18 @@ case MoqtSetupParameter::kPath: ParseError("PATH parameter in SERVER_SETUP"); return 0; + case MoqtSetupParameter::kMaxSubscribeId: + if (setup.max_subscribe_id.has_value()) { + ParseError("MAX_SUBSCRIBE_ID parameter appears twice in SETUP"); + return 0; + } + uint64_t max_id; + if (!StringViewToVarInt(value, max_id)) { + ParseError("MAX_SUBSCRIBE_ID parameter is not a valid varint"); + return 0; + } + setup.max_subscribe_id = max_id; + break; case MoqtSetupParameter::kSupportObjectAcks: uint64_t flag; if (!StringViewToVarInt(value, flag) || flag > 1) { @@ -723,6 +749,15 @@ return reader.PreviouslyReadPayload().length(); } +size_t MoqtControlParser::ProcessMaxSubscribeId(quic::QuicDataReader& reader) { + MoqtMaxSubscribeId max_subscribe_id; + if (!reader.ReadVarInt62(&max_subscribe_id.max_subscribe_id)) { + return 0; + } + visitor_.OnMaxSubscribeIdMessage(max_subscribe_id); + return reader.PreviouslyReadPayload().length(); +} + size_t MoqtControlParser::ProcessObjectAck(quic::QuicDataReader& reader) { MoqtObjectAck object_ack; uint64_t raw_delta;
diff --git a/quiche/quic/moqt/moqt_parser.h b/quiche/quic/moqt/moqt_parser.h index f129f72..f19a0a6 100644 --- a/quiche/quic/moqt/moqt_parser.h +++ b/quiche/quic/moqt/moqt_parser.h
@@ -43,6 +43,7 @@ virtual void OnUnannounceMessage(const MoqtUnannounce& message) = 0; virtual void OnTrackStatusMessage(const MoqtTrackStatus& message) = 0; virtual void OnGoAwayMessage(const MoqtGoAway& message) = 0; + virtual void OnMaxSubscribeIdMessage(const MoqtMaxSubscribeId& message) = 0; virtual void OnObjectAckMessage(const MoqtObjectAck& message) = 0; virtual void OnParsingError(MoqtError code, absl::string_view reason) = 0; @@ -107,6 +108,7 @@ size_t ProcessUnannounce(quic::QuicDataReader& reader); size_t ProcessTrackStatus(quic::QuicDataReader& reader); size_t ProcessGoAway(quic::QuicDataReader& reader); + size_t ProcessMaxSubscribeId(quic::QuicDataReader& reader); size_t ProcessObjectAck(quic::QuicDataReader& reader); // If |error| is not provided, assumes kProtocolViolation.
diff --git a/quiche/quic/moqt/moqt_parser_test.cc b/quiche/quic/moqt/moqt_parser_test.cc index 53bbdf8..f69ef96 100644 --- a/quiche/quic/moqt/moqt_parser_test.cc +++ b/quiche/quic/moqt/moqt_parser_test.cc
@@ -39,7 +39,7 @@ MoqtMessageType::kAnnounceOk, MoqtMessageType::kAnnounceError, MoqtMessageType::kUnannounce, MoqtMessageType::kClientSetup, MoqtMessageType::kServerSetup, MoqtMessageType::kGoAway, - MoqtMessageType::kObjectAck, + MoqtMessageType::kMaxSubscribeId, MoqtMessageType::kObjectAck, }; constexpr std::array kDataStreamTypes{MoqtDataStreamType::kObjectStream, MoqtDataStreamType::kStreamHeaderTrack, @@ -162,6 +162,9 @@ void OnGoAwayMessage(const MoqtGoAway& message) override { OnControlMessage(message); } + void OnMaxSubscribeIdMessage(const MoqtMaxSubscribeId& message) override { + OnControlMessage(message); + } void OnObjectAckMessage(const MoqtObjectAck& message) override { OnControlMessage(message); } @@ -541,6 +544,24 @@ EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); } +TEST_F(MoqtMessageSpecificTest, ClientSetupMaxSubscribeIdAppearsTwice) { + MoqtControlParser parser(kRawQuic, visitor_); + char setup[] = { + 0x40, 0x40, 0x02, 0x01, 0x02, // versions + 0x04, // 4 params + 0x00, 0x01, 0x03, // role = PubSub + 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" + 0x02, 0x01, 0x32, // max_subscribe_id = 50 + 0x02, 0x01, 0x32, // max_subscribe_id = 50 + }; + parser.ProcessData(absl::string_view(setup, sizeof(setup)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "MAX_SUBSCRIBE_ID parameter appears twice in SETUP"); + EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); +} + TEST_F(MoqtMessageSpecificTest, ServerSetupRoleIsMissing) { MoqtControlParser parser(kRawQuic, visitor_); char setup[] = { @@ -635,6 +656,24 @@ EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); } +TEST_F(MoqtMessageSpecificTest, ServerSetupMaxSubscribeIdAppearsTwice) { + MoqtControlParser parser(kRawQuic, visitor_); + char setup[] = { + 0x40, 0x40, 0x02, 0x01, 0x02, // versions = 1, 2 + 0x04, // 4 params + 0x00, 0x01, 0x03, // role = PubSub + 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" + 0x02, 0x01, 0x32, // max_subscribe_id = 50 + 0x02, 0x01, 0x32, // max_subscribe_id = 50 + }; + parser.ProcessData(absl::string_view(setup, sizeof(setup)), false); + EXPECT_EQ(visitor_.messages_received_, 0); + EXPECT_TRUE(visitor_.parsing_error_.has_value()); + EXPECT_EQ(*visitor_.parsing_error_, + "MAX_SUBSCRIBE_ID parameter appears twice in SETUP"); + EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation); +} + TEST_F(MoqtMessageSpecificTest, SubscribeAuthorizationInfoTwice) { MoqtControlParser parser(kWebTrans, visitor_); char subscribe[] = {
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index afcee57..d11f9d2 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -105,6 +105,7 @@ callbacks_(std::move(callbacks)), framer_(quiche::SimpleBufferAllocator::Get(), parameters.using_webtrans), publisher_(DefaultPublisher::GetInstance()), + local_max_subscribe_id_(parameters.max_subscribe_id), liveness_token_(std::make_shared<Empty>()) {} MoqtSession::ControlStream* MoqtSession::GetControlStream() { @@ -146,6 +147,7 @@ MoqtClientSetup setup = MoqtClientSetup{ .supported_versions = std::vector<MoqtVersion>{parameters_.version}, .role = MoqtRole::kPubSub, + .max_subscribe_id = parameters_.max_subscribe_id, .supports_object_ack = parameters_.support_object_acks, }; if (!parameters_.using_webtrans) { @@ -389,6 +391,13 @@ return false; } // TODO(martinduke): support authorization info + if (next_subscribe_id_ > peer_max_subscribe_id_) { + QUIC_DLOG(INFO) << ENDPOINT << "Tried to send SUBSCRIBE with ID " + << message.subscribe_id + << " which is greater than the maximum ID " + << peer_max_subscribe_id_; + return false; + } message.subscribe_id = next_subscribe_id_++; FullTrackName ftn(std::string(message.track_namespace), std::string(message.track_name)); @@ -492,6 +501,13 @@ } } +void MoqtSession::GrantMoreSubscribes(uint64_t num_subscribes) { + local_max_subscribe_id_ += num_subscribes; + MoqtMaxSubscribeId message; + message.max_subscribe_id = local_max_subscribe_id_; + SendControlMessage(framer_.SerializeMaxSubscribeId(message)); +} + std::pair<FullTrackName, RemoteTrack::Visitor*> MoqtSession::TrackPropertiesFromAlias(const MoqtObject& message) { auto it = remote_tracks_.find(message.track_alias); @@ -596,11 +612,18 @@ MoqtServerSetup response; response.selected_version = session_->parameters_.version; response.role = MoqtRole::kPubSub; + response.max_subscribe_id = session_->parameters_.max_subscribe_id; response.supports_object_ack = session_->parameters_.support_object_acks; SendOrBufferMessage(session_->framer_.SerializeServerSetup(response)); QUIC_DLOG(INFO) << ENDPOINT << "Sent the SETUP message"; } // TODO: handle role and path. + if (message.max_subscribe_id.has_value()) { + session_->peer_max_subscribe_id_ = *message.max_subscribe_id; + } else if (session_->parameters_.version == MoqtVersion::kDraft05) { + // TODO (martinduke): Delete this when we roll the version number. + session_->peer_max_subscribe_id_ = UINT64_MAX >> 2; + } std::move(session_->callbacks_.session_established_callback)(); session_->peer_role_ = *message.role; } @@ -622,6 +645,12 @@ session_->peer_supports_object_ack_ = message.supports_object_ack; QUIC_DLOG(INFO) << ENDPOINT << "Received the SETUP message"; // TODO: handle role and path. + if (message.max_subscribe_id.has_value()) { + session_->peer_max_subscribe_id_ = *message.max_subscribe_id; + } else if (session_->parameters_.version == MoqtVersion::kDraft05) { + // TODO (martinduke): Delete this when we roll the version number. + session_->peer_max_subscribe_id_ = UINT64_MAX >> 2; + } std::move(session_->callbacks_.session_established_callback)(); session_->peer_role_ = *message.role; } @@ -646,6 +675,12 @@ "Received SUBSCRIBE from publisher"); return; } + if (message.subscribe_id > session_->local_max_subscribe_id_) { + QUIC_DLOG(INFO) << ENDPOINT << "Received SUBSCRIBE with too large ID"; + session_->Error(MoqtError::kTooManySubscribes, + "Received SUBSCRIBE with too large ID"); + return; + } QUIC_DLOG(INFO) << ENDPOINT << "Received a SUBSCRIBE for " << message.track_namespace << ":" << message.track_name; @@ -837,6 +872,25 @@ // TODO: notify the application about this. } +void MoqtSession::ControlStream::OnMaxSubscribeIdMessage( + const MoqtMaxSubscribeId& message) { + if (session_->peer_role_ == MoqtRole::kSubscriber) { + QUIC_DLOG(INFO) << ENDPOINT << "Subscriber peer sent MAX_SUBSCRIBE_ID"; + session_->Error(MoqtError::kProtocolViolation, + "Received MAX_SUBSCRIBE_ID from Subscriber"); + return; + } + if (message.max_subscribe_id < session_->peer_max_subscribe_id_) { + QUIC_DLOG(INFO) << ENDPOINT + << "Peer sent MAX_SUBSCRIBE_ID message with " + "lower value than previous"; + session_->Error(MoqtError::kProtocolViolation, + "MAX_SUBSCRIBE_ID message has lower value than previous"); + return; + } + session_->peer_max_subscribe_id_ = message.max_subscribe_id; +} + void MoqtSession::ControlStream::OnParsingError(MoqtError error_code, absl::string_view reason) { session_->Error(error_code, absl::StrCat("Parse error: ", reason));
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index d437dc3..79104c6 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -172,6 +172,8 @@ std::optional<webtransport::SendOrder> old_send_order, std::optional<webtransport::SendOrder> new_send_order); + void GrantMoreSubscribes(uint64_t num_subscribes); + private: friend class test::MoqtSessionPeer; @@ -207,6 +209,7 @@ void OnUnannounceMessage(const MoqtUnannounce& /*message*/) override {} void OnTrackStatusMessage(const MoqtTrackStatus& message) override {} void OnGoAwayMessage(const MoqtGoAway& /*message*/) override {} + void OnMaxSubscribeIdMessage(const MoqtMaxSubscribeId& message) override; void OnObjectAckMessage(const MoqtObjectAck& message) override { auto subscription_it = session_->published_subscriptions_.find(message.subscribe_id); @@ -521,6 +524,11 @@ // parameter, and other checks have changed/been disabled. MoqtRole peer_role_ = MoqtRole::kPubSub; + // The maximum subscribe ID that the local endpoint can send. + uint64_t peer_max_subscribe_id_ = 0; + // The maximum subscribe ID sent to the peer. + uint64_t local_max_subscribe_id_ = 0; + // Must be last. Token used to make sure that the streams do not call into // the session when the session has already been destroyed. struct Empty {};
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index cd7d49c..25783a3 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -10,6 +10,7 @@ #include <optional> #include <string> #include <utility> +#include <vector> #include "absl/status/status.h" #include "absl/strings/match.h" @@ -157,15 +158,25 @@ static RemoteTrack& remote_track(MoqtSession* session, uint64_t track_alias) { return session->remote_tracks_.find(track_alias)->second; } + + static void set_next_subscribe_id(MoqtSession* session, uint64_t id) { + session->next_subscribe_id_ = id; + } + + static void set_peer_max_subscribe_id(MoqtSession* session, uint64_t id) { + session->peer_max_subscribe_id_ = id; + } }; class MoqtSessionTest : public quic::test::QuicTest { public: MoqtSessionTest() : session_(&mock_session_, - MoqtSessionParameters(quic::Perspective::IS_CLIENT), + MoqtSessionParameters(quic::Perspective::IS_CLIENT, ""), session_callbacks_.AsSessionCallbacks()) { session_.set_publisher(&publisher_); + MoqtSessionPeer::set_peer_max_subscribe_id(&session_, + kDefaultInitialMaxSubscribeId); } ~MoqtSessionTest() { EXPECT_CALL(session_callbacks_.session_deleted_callback, Call()); @@ -464,6 +475,53 @@ EXPECT_TRUE(correct_message); } +TEST_F(MoqtSessionTest, SubscribeIdTooHigh) { + // Peer subscribes to (0, 0) + MoqtSubscribe request = { + /*subscribe_id=*/kDefaultInitialMaxSubscribeId + 1, + /*track_alias=*/2, + /*track_namespace=*/"foo", + /*track_name=*/"bar", + /*subscriber_priority=*/0x80, + /*group_order=*/std::nullopt, + /*start_group=*/0, + /*start_object=*/0, + /*end_group=*/std::nullopt, + /*end_object=*/std::nullopt, + /*parameters=*/MoqtSubscribeParameters(), + }; + webtransport::test::MockStream mock_stream; + std::unique_ptr<MoqtControlParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + EXPECT_CALL(mock_session_, + CloseSession(static_cast<uint64_t>(MoqtError::kTooManySubscribes), + "Received SUBSCRIBE with too large ID")) + .Times(1); + stream_input->OnSubscribeMessage(request); +} + +TEST_F(MoqtSessionTest, TooManySubscribes) { + MoqtSessionPeer::set_next_subscribe_id(&session_, + kDefaultInitialMaxSubscribeId); + MockRemoteTrackVisitor remote_track_visitor; + webtransport::test::MockStream mock_stream; + std::unique_ptr<MoqtControlParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); + bool correct_message = true; + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + correct_message = true; + EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribe); + return absl::OkStatus(); + }); + EXPECT_TRUE( + session_.SubscribeCurrentGroup("foo", "bar", &remote_track_visitor)); + EXPECT_FALSE( + session_.SubscribeCurrentGroup("foo", "bar", &remote_track_visitor)); +} + TEST_F(MoqtSessionTest, SubscribeWithOk) { webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtControlParserVisitor> stream_input = @@ -496,6 +554,102 @@ EXPECT_TRUE(correct_message); } +TEST_F(MoqtSessionTest, MaxSubscribeIdChangesResponse) { + MoqtSessionPeer::set_next_subscribe_id(&session_, + kDefaultInitialMaxSubscribeId + 1); + MockRemoteTrackVisitor remote_track_visitor; + EXPECT_FALSE( + session_.SubscribeCurrentGroup("foo", "bar", &remote_track_visitor)); + MoqtMaxSubscribeId max_subscribe_id = { + /*max_subscribe_id=*/kDefaultInitialMaxSubscribeId + 1, + }; + webtransport::test::MockStream mock_stream; + std::unique_ptr<MoqtControlParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + stream_input->OnMaxSubscribeIdMessage(max_subscribe_id); + EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); + bool correct_message = true; + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + correct_message = true; + EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribe); + return absl::OkStatus(); + }); + EXPECT_TRUE( + session_.SubscribeCurrentGroup("foo", "bar", &remote_track_visitor)); + EXPECT_TRUE(correct_message); +} + +TEST_F(MoqtSessionTest, LowerMaxSubscribeIdIsAnError) { + MoqtMaxSubscribeId max_subscribe_id = { + /*max_subscribe_id=*/kDefaultInitialMaxSubscribeId - 1, + }; + webtransport::test::MockStream mock_stream; + std::unique_ptr<MoqtControlParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + EXPECT_CALL( + mock_session_, + CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation), + "MAX_SUBSCRIBE_ID message has lower value than previous")) + .Times(1); + stream_input->OnMaxSubscribeIdMessage(max_subscribe_id); +} + +TEST_F(MoqtSessionTest, GrantMoreSubscribes) { + webtransport::test::MockStream mock_stream; + std::unique_ptr<MoqtControlParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream)); + bool correct_message = true; + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + correct_message = true; + EXPECT_EQ(*ExtractMessageType(data[0]), + MoqtMessageType::kMaxSubscribeId); + return absl::OkStatus(); + }); + session_.GrantMoreSubscribes(1); + EXPECT_TRUE(correct_message); + // Peer subscribes to (0, 0) + MoqtSubscribe request = { + /*subscribe_id=*/kDefaultInitialMaxSubscribeId + 1, + /*track_alias=*/2, + /*track_namespace=*/"foo", + /*track_name=*/"bar", + /*subscriber_priority=*/0x80, + /*group_order=*/std::nullopt, + /*start_group=*/0, + /*start_object=*/0, + /*end_group=*/std::nullopt, + /*end_object=*/std::nullopt, + /*parameters=*/MoqtSubscribeParameters(), + }; + correct_message = false; + FullTrackName ftn("foo", "bar"); + auto track = std::make_shared<MockTrackPublisher>(ftn); + EXPECT_CALL(*track, GetTrackStatus()) + .WillRepeatedly(Return(MoqtTrackStatusCode::kInProgress)); + EXPECT_CALL(*track, GetCachedObject(_)).WillRepeatedly([] { + return std::optional<PublishedObject>(); + }); + EXPECT_CALL(*track, GetCachedObjectsInRange(_, _)) + .WillRepeatedly(Return(std::vector<FullSequence>())); + EXPECT_CALL(*track, GetLargestSequence()) + .WillRepeatedly(Return(FullSequence(10, 20))); + publisher_.Add(track); + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + correct_message = true; + EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk); + return absl::OkStatus(); + }); + stream_input->OnSubscribeMessage(request); + EXPECT_TRUE(correct_message); +} + TEST_F(MoqtSessionTest, SubscribeWithError) { webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtControlParserVisitor> stream_input =
diff --git a/quiche/quic/moqt/test_tools/moqt_test_message.h b/quiche/quic/moqt/test_tools/moqt_test_message.h index 9633ef3..813d4e8 100644 --- a/quiche/quic/moqt/test_tools/moqt_test_message.h +++ b/quiche/quic/moqt/test_tools/moqt_test_message.h
@@ -38,7 +38,7 @@ MoqtSubscribeDone, MoqtSubscribeUpdate, MoqtAnnounce, MoqtAnnounceOk, MoqtAnnounceError, MoqtAnnounceCancel, MoqtTrackStatusRequest, MoqtUnannounce, MoqtTrackStatus, - MoqtGoAway, MoqtObjectAck>; + MoqtGoAway, MoqtMaxSubscribeId, MoqtObjectAck>; // The total actual size of the message. size_t total_message_size() const { return wire_image_size_; } @@ -282,7 +282,7 @@ if (webtrans) { // Should not send PATH. client_setup_.path = std::nullopt; - raw_packet_[5] = 0x01; // only one parameter + raw_packet_[5] = 0x02; // only two parameters SetWireImage(raw_packet_, sizeof(raw_packet_) - 5); } else { SetWireImage(raw_packet_, sizeof(raw_packet_)); @@ -311,16 +311,20 @@ QUIC_LOG(INFO) << "CLIENT_SETUP path mismatch"; return false; } + if (cast.max_subscribe_id != client_setup_.max_subscribe_id) { + QUIC_LOG(INFO) << "CLIENT_SETUP max_subscribe_id mismatch"; + return false; + } return true; } void ExpandVarints() override { if (client_setup_.path.has_value()) { - ExpandVarintsImpl("--vvvvvv-vv---"); + ExpandVarintsImpl("--vvvvvv-vv-vv---"); // first two bytes are already a 2B varint. Also, don't expand parameter // varints because that messes up the parameter length field. } else { - ExpandVarintsImpl("--vvvvvv-"); + ExpandVarintsImpl("--vvvvvv-vv-"); } } @@ -329,18 +333,20 @@ } private: - uint8_t raw_packet_[14] = { - 0x40, 0x40, // type - 0x02, 0x01, 0x02, // versions - 0x02, // 2 parameters - 0x00, 0x01, 0x03, // role = PubSub - 0x01, 0x03, 0x66, 0x6f, 0x6f // path = "foo" + uint8_t raw_packet_[17] = { + 0x40, 0x40, // type + 0x02, 0x01, 0x02, // versions + 0x03, // 3 parameters + 0x00, 0x01, 0x03, // role = PubSub + 0x02, 0x01, 0x32, // max_subscribe_id = 50 + 0x01, 0x03, 0x66, 0x6f, 0x6f, // path = "foo" }; MoqtClientSetup client_setup_ = { /*supported_versions=*/std::vector<MoqtVersion>( {static_cast<MoqtVersion>(1), static_cast<MoqtVersion>(2)}), /*role=*/MoqtRole::kPubSub, /*path=*/"foo", + /*max_subscribe_id=*/50, }; }; @@ -360,11 +366,15 @@ QUIC_LOG(INFO) << "SERVER_SETUP role mismatch"; return false; } + if (cast.max_subscribe_id != server_setup_.max_subscribe_id) { + QUIC_LOG(INFO) << "SERVER_SETUP max_subscribe_id mismatch"; + return false; + } return true; } void ExpandVarints() override { - ExpandVarintsImpl("--vvvv-"); // first two bytes are already a 2b varint + ExpandVarintsImpl("--vvvv-vv-"); // first two bytes are already a 2b varint } MessageStructuredData structured_data() const override { @@ -372,14 +382,16 @@ } private: - uint8_t raw_packet_[7] = { + uint8_t raw_packet_[10] = { 0x40, 0x41, // type - 0x01, 0x01, // version, one param + 0x01, 0x02, // version, two parameters 0x00, 0x01, 0x03, // role = PubSub + 0x02, 0x01, 0x32, // max_subscribe_id = 50 }; MoqtServerSetup server_setup_ = { /*selected_version=*/static_cast<MoqtVersion>(1), /*role=*/MoqtRole::kPubSub, + /*max_subscribe_id=*/50, }; }; @@ -1031,6 +1043,38 @@ }; }; +class QUICHE_NO_EXPORT MaxSubscribeIdMessage : public TestMessageBase { + public: + MaxSubscribeIdMessage() : TestMessageBase() { + SetWireImage(raw_packet_, sizeof(raw_packet_)); + } + + bool EqualFieldValues(MessageStructuredData& values) const override { + auto cast = std::get<MoqtMaxSubscribeId>(values); + if (cast.max_subscribe_id != max_subscribe_id_.max_subscribe_id) { + QUIC_LOG(INFO) << "MAX_SUBSCRIBE_ID mismatch"; + return false; + } + return true; + } + + void ExpandVarints() override { ExpandVarintsImpl("vv"); } + + MessageStructuredData structured_data() const override { + return TestMessageBase::MessageStructuredData(max_subscribe_id_); + } + + private: + uint8_t raw_packet_[2] = { + 0x15, + 0x0b, + }; + + MoqtMaxSubscribeId max_subscribe_id_ = { + /*max_subscribe_id =*/11, + }; +}; + class QUICHE_NO_EXPORT ObjectAckMessage : public TestMessageBase { public: ObjectAckMessage() : TestMessageBase() { @@ -1111,6 +1155,8 @@ return std::make_unique<TrackStatusMessage>(); case MoqtMessageType::kGoAway: return std::make_unique<GoAwayMessage>(); + case MoqtMessageType::kMaxSubscribeId: + return std::make_unique<MaxSubscribeIdMessage>(); case MoqtMessageType::kObjectAck: return std::make_unique<ObjectAckMessage>(); case MoqtMessageType::kClientSetup: