Incoming MoQT SubscribeAnnounces lifecycle. PiperOrigin-RevId: 705588295
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h index 0f8c26f..4b189f9 100644 --- a/quiche/quic/moqt/moqt_messages.h +++ b/quiche/quic/moqt/moqt_messages.h
@@ -158,6 +158,20 @@ kAnnounceNotSupported = 1, }; +enum class QUICHE_EXPORT SubscribeErrorCode : uint64_t { + kInternalError = 0x0, + kInvalidRange = 0x1, + kRetryTrackAlias = 0x2, + kTrackDoesNotExist = 0x3, + kUnauthorized = 0x4, + kTimeout = 0x5, +}; + +struct MoqtSubscribeErrorReason { + SubscribeErrorCode error_code; + std::string reason_phrase; +}; + struct MoqtAnnounceErrorReason { MoqtAnnounceErrorCode error_code; std::string reason_phrase; @@ -397,15 +411,6 @@ MoqtSubscribeParameters parameters; }; -enum class QUICHE_EXPORT SubscribeErrorCode : uint64_t { - kInternalError = 0x0, - kInvalidRange = 0x1, - kRetryTrackAlias = 0x2, - kTrackDoesNotExist = 0x3, - kUnauthorized = 0x4, - kTimeout = 0x5, -}; - struct QUICHE_EXPORT MoqtSubscribeError { uint64_t subscribe_id; SubscribeErrorCode error_code;
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index 538f081..9f01c32 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -239,9 +239,10 @@ std::move(callbacks_.session_terminated_callback)(error); } -bool MoqtSession::SubscribeAnnounces(FullTrackName track_namespace, - MoqtSubscribeAnnouncesCallback callback, - MoqtSubscribeParameters parameters) { +bool MoqtSession::SubscribeAnnounces( + FullTrackName track_namespace, + MoqtOutgoingSubscribeAnnouncesCallback callback, + MoqtSubscribeParameters parameters) { if (peer_role_ == MoqtRole::kSubscriber) { std::move(callback)(track_namespace, SubscribeErrorCode::kInternalError, "SUBSCRIBE_ANNOUNCES cannot be sent to subscriber"); @@ -1035,6 +1036,26 @@ session_->outgoing_announces_.erase(it); } +void MoqtSession::ControlStream::OnSubscribeAnnouncesMessage( + const MoqtSubscribeAnnounces& message) { + // TODO(martinduke): Handle authentication. + std::optional<MoqtSubscribeErrorReason> result = + session_->callbacks_.incoming_subscribe_announces_callback( + message.track_namespace, SubscribeType::kSubscribe); + if (result.has_value()) { + MoqtSubscribeAnnouncesError error; + error.track_namespace = message.track_namespace; + error.error_code = result->error_code; + error.reason_phrase = result->reason_phrase; + SendOrBufferMessage( + session_->framer_.SerializeSubscribeAnnouncesError(error)); + return; + } + MoqtSubscribeAnnouncesOk ok; + ok.track_namespace = message.track_namespace; + SendOrBufferMessage(session_->framer_.SerializeSubscribeAnnouncesOk(ok)); +} + void MoqtSession::ControlStream::OnSubscribeAnnouncesOkMessage( const MoqtSubscribeAnnouncesOk& message) { auto it = @@ -1068,6 +1089,14 @@ session_->outgoing_subscribe_announces_.erase(it); } +void MoqtSession::ControlStream::OnUnsubscribeAnnouncesMessage( + const MoqtUnsubscribeAnnounces& message) { + // MoqtSession keeps no state here, so just tell the application. + std::optional<MoqtSubscribeErrorReason> result = + session_->callbacks_.incoming_subscribe_announces_callback( + message.track_namespace, SubscribeType::kUnsubscribe); +} + void MoqtSession::ControlStream::OnMaxSubscribeIdMessage( const MoqtMaxSubscribeId& message) { if (session_->peer_role_ == MoqtRole::kSubscriber) {
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index 4bcba1c..98eb244 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -43,6 +43,8 @@ quiche::SingleUseCallback<void(absl::string_view error_message)>; using MoqtSessionDeletedCallback = quiche::SingleUseCallback<void()>; +enum class SubscribeType { kSubscribe, kUnsubscribe }; + // If |error_message| is nullopt, this is triggered by an ANNOUNCE_OK. // Otherwise, it is triggered by ANNOUNCE_ERROR or ANNOUNCE_CANCEL. For // ERROR or CANCEL, MoqtSession is deleting all ANNOUNCE state immediately @@ -54,9 +56,17 @@ using MoqtIncomingAnnounceCallback = quiche::MultiUseCallback<std::optional<MoqtAnnounceErrorReason>( FullTrackName track_namespace)>; -using MoqtSubscribeAnnouncesCallback = quiche::SingleUseCallback<void( +using MoqtOutgoingSubscribeAnnouncesCallback = quiche::SingleUseCallback<void( FullTrackName track_namespace, std::optional<SubscribeErrorCode> error, absl::string_view reason)>; +// If the return value is nullopt, the Session will respond with +// SUBSCRIBE_ANNOUNCES_OK. Otherwise, it will respond with +// SUBSCRIBE_ANNOUNCES_ERROR. +// If |subscribe_type| is kUnsubscribe, this is an UNSUBSCRIBE_ANNOUNCES message +// and the return value will be ignored. +using MoqtIncomingSubscribeAnnouncesCallback = + quiche::MultiUseCallback<std::optional<MoqtSubscribeErrorReason>( + const FullTrackName& track_namespace, SubscribeType subscribe_type)>; inline std::optional<MoqtAnnounceErrorReason> DefaultIncomingAnnounceCallback( FullTrackName /*track_namespace*/) { @@ -65,6 +75,14 @@ "This endpoint does not accept incoming ANNOUNCE messages"}); }; +inline std::optional<MoqtSubscribeErrorReason> +DefaultIncomingSubscribeAnnouncesCallback(const FullTrackName& track_namespace, + SubscribeType /*subscribe_type*/) { + return MoqtSubscribeErrorReason{ + SubscribeErrorCode::kUnauthorized, + "This endpoint does not support incoming SUBSCRIBE_ANNOUNCES messages"}; +} + // Callbacks for session-level events. struct MoqtSessionCallbacks { MoqtSessionEstablishedCallback session_established_callback = +[] {}; @@ -74,6 +92,8 @@ MoqtIncomingAnnounceCallback incoming_announce_callback = DefaultIncomingAnnounceCallback; + MoqtIncomingSubscribeAnnouncesCallback incoming_subscribe_announces_callback = + DefaultIncomingSubscribeAnnouncesCallback; }; struct SubscriptionWithQueuedStream { @@ -116,7 +136,8 @@ // Returns true if message was sent. bool SubscribeAnnounces( - FullTrackName track_namespace, MoqtSubscribeAnnouncesCallback callback, + FullTrackName track_namespace, + MoqtOutgoingSubscribeAnnouncesCallback callback, MoqtSubscribeParameters parameters = MoqtSubscribeParameters()); bool UnsubscribeAnnounces(FullTrackName track_namespace); @@ -228,13 +249,13 @@ void OnTrackStatusMessage(const MoqtTrackStatus& message) override {} void OnGoAwayMessage(const MoqtGoAway& /*message*/) override {} void OnSubscribeAnnouncesMessage( - const MoqtSubscribeAnnounces& message) override {} + const MoqtSubscribeAnnounces& message) override; void OnSubscribeAnnouncesOkMessage( const MoqtSubscribeAnnouncesOk& message) override; void OnSubscribeAnnouncesErrorMessage( const MoqtSubscribeAnnouncesError& message) override; void OnUnsubscribeAnnouncesMessage( - const MoqtUnsubscribeAnnounces& message) override {} + const MoqtUnsubscribeAnnounces& message) override; void OnMaxSubscribeIdMessage(const MoqtMaxSubscribeId& message) override; void OnFetchMessage(const MoqtFetch& message) override; void OnFetchCancelMessage(const MoqtFetchCancel& message) override {} @@ -631,7 +652,7 @@ // when sending UNSUBSCRIBE_ANNOUNCES, to make sure the application doesn't // unsubscribe from something that it isn't subscribed to. ANNOUNCEs that // result from this subscription use incoming_announce_callback. - absl::flat_hash_map<FullTrackName, MoqtSubscribeAnnouncesCallback> + absl::flat_hash_map<FullTrackName, MoqtOutgoingSubscribeAnnouncesCallback> outgoing_subscribe_announces_; // The role the peer advertised in its SETUP message. Initialize it to avoid
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 2d91180..44021ca 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -2103,6 +2103,51 @@ EXPECT_EQ(objects_received, 2); } +TEST_F(MoqtSessionTest, IncomingSubscribeAnnounces) { + FullTrackName track_namespace = FullTrackName{"foo"}; + MoqtSubscribeAnnounces announces = { + track_namespace, + /*parameters=*/MoqtSubscribeParameters(), + }; + webtransport::test::MockStream control_stream; + std::unique_ptr<MoqtControlParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &control_stream); + EXPECT_CALL(session_callbacks_.incoming_subscribe_announces_callback, + Call(_, SubscribeType::kSubscribe)) + .WillOnce(Return(std::nullopt)); + EXPECT_CALL( + control_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribeAnnouncesOk), _)); + stream_input->OnSubscribeAnnouncesMessage(announces); + MoqtUnsubscribeAnnounces unsubscribe_announces = { + /*track_namespace=*/FullTrackName{"foo"}, + }; + EXPECT_CALL(session_callbacks_.incoming_subscribe_announces_callback, + Call(track_namespace, SubscribeType::kUnsubscribe)) + .WillOnce(Return(std::nullopt)); + stream_input->OnUnsubscribeAnnouncesMessage(unsubscribe_announces); +} + +TEST_F(MoqtSessionTest, IncomingSubscribeAnnouncesWithError) { + FullTrackName track_namespace = FullTrackName{"foo"}; + MoqtSubscribeAnnounces announces = { + track_namespace, + /*parameters=*/MoqtSubscribeParameters(), + }; + webtransport::test::MockStream control_stream; + std::unique_ptr<MoqtControlParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &control_stream); + EXPECT_CALL(session_callbacks_.incoming_subscribe_announces_callback, + Call(_, SubscribeType::kSubscribe)) + .WillOnce(Return( + MoqtSubscribeErrorReason{SubscribeErrorCode::kUnauthorized, "foo"})); + EXPECT_CALL( + control_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribeAnnouncesError), + _)); + stream_input->OnSubscribeAnnouncesMessage(announces); +} + // TODO: re-enable this test once this behavior is re-implemented. #if 0 TEST_F(MoqtSessionTest, SubscribeUpdateClosesSubscription) {
diff --git a/quiche/quic/moqt/tools/moqt_mock_visitor.h b/quiche/quic/moqt/tools/moqt_mock_visitor.h index 7031696..8eb22a7 100644 --- a/quiche/quic/moqt/tools/moqt_mock_visitor.h +++ b/quiche/quic/moqt/tools/moqt_mock_visitor.h
@@ -30,17 +30,24 @@ testing::MockFunction<void()> session_deleted_callback; testing::MockFunction<std::optional<MoqtAnnounceErrorReason>(FullTrackName)> incoming_announce_callback; + testing::MockFunction<std::optional<MoqtSubscribeErrorReason>(FullTrackName, + SubscribeType)> + incoming_subscribe_announces_callback; MockSessionCallbacks() { ON_CALL(incoming_announce_callback, Call(testing::_)) .WillByDefault(DefaultIncomingAnnounceCallback); + ON_CALL(incoming_subscribe_announces_callback, Call(testing::_, testing::_)) + .WillByDefault(DefaultIncomingSubscribeAnnouncesCallback); } MoqtSessionCallbacks AsSessionCallbacks() { - return MoqtSessionCallbacks{session_established_callback.AsStdFunction(), - session_terminated_callback.AsStdFunction(), - session_deleted_callback.AsStdFunction(), - incoming_announce_callback.AsStdFunction()}; + return MoqtSessionCallbacks{ + session_established_callback.AsStdFunction(), + session_terminated_callback.AsStdFunction(), + session_deleted_callback.AsStdFunction(), + incoming_announce_callback.AsStdFunction(), + incoming_subscribe_announces_callback.AsStdFunction()}; } };