Outgoing MoQT SUBSCRIBE_ANNOUNCES life cycle. Does not include any actual ANNOUNCE messages, which are actually not related to the state of the SUBSCRIBE_ANNOUNCE. PiperOrigin-RevId: 704446656
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h index 6239299..c7bb493 100644 --- a/quiche/quic/moqt/moqt_messages.h +++ b/quiche/quic/moqt/moqt_messages.h
@@ -517,7 +517,7 @@ struct QUICHE_EXPORT MoqtSubscribeAnnouncesError { FullTrackName track_namespace; - MoqtAnnounceErrorCode error_code; + SubscribeErrorCode error_code; std::string reason_phrase; };
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc index 800f9eb..7de1149 100644 --- a/quiche/quic/moqt/moqt_parser.cc +++ b/quiche/quic/moqt/moqt_parser.cc
@@ -734,7 +734,7 @@ return 0; } subscribe_namespace_error.error_code = - static_cast<MoqtAnnounceErrorCode>(error_code); + static_cast<SubscribeErrorCode>(error_code); visitor_.OnSubscribeAnnouncesErrorMessage(subscribe_namespace_error); return reader.PreviouslyReadPayload().length(); }
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index 6a06e00..81fb1b7 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -239,6 +239,37 @@ std::move(callbacks_.session_terminated_callback)(error); } +bool MoqtSession::SubscribeAnnounces(FullTrackName track_namespace, + MoqtSubscribeAnnouncesCallback callback, + MoqtSubscribeParameters parameters) { + if (peer_role_ == MoqtRole::kSubscriber) { + std::move(callback)(track_namespace, SubscribeErrorCode::kInternalError, + "SUBSCRIBE_ANNOUNCES cannot be sent to subscriber"); + return false; + } + MoqtSubscribeAnnounces message; + message.track_namespace = track_namespace; + message.parameters = std::move(parameters); + SendControlMessage(framer_.SerializeSubscribeAnnounces(message)); + QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE_ANNOUNCES message for " + << message.track_namespace; + outgoing_subscribe_announces_[track_namespace] = std::move(callback); + return true; +} + +bool MoqtSession::UnsubscribeAnnounces(FullTrackName track_namespace) { + if (!outgoing_subscribe_announces_.contains(track_namespace)) { + return false; + } + MoqtUnsubscribeAnnounces message; + message.track_namespace = track_namespace; + SendControlMessage(framer_.SerializeUnsubscribeAnnounces(message)); + QUIC_DLOG(INFO) << ENDPOINT << "Sent UNSUBSCRIBE_ANNOUNCES message for " + << message.track_namespace; + outgoing_subscribe_announces_.erase(track_namespace); + return true; +} + // TODO: Create state that allows ANNOUNCE_OK/ERROR on spurious namespaces to // trigger session errors. void MoqtSession::Announce(FullTrackName track_namespace, @@ -983,6 +1014,39 @@ // TODO: notify the application about this. } +void MoqtSession::ControlStream::OnSubscribeAnnouncesOkMessage( + const MoqtSubscribeAnnouncesOk& message) { + auto it = + session_->outgoing_subscribe_announces_.find(message.track_namespace); + if (it == session_->outgoing_subscribe_announces_.end()) { + return; // UNSUBSCRIBE_ANNOUNCES may already have deleted the entry. + } + if (it->second == nullptr) { + session_->Error(MoqtError::kProtocolViolation, + "Two responses to SUBSCRIBE_ANNOUNCES"); + return; + } + std::move(it->second)(message.track_namespace, std::nullopt, ""); + it->second = nullptr; +} + +void MoqtSession::ControlStream::OnSubscribeAnnouncesErrorMessage( + const MoqtSubscribeAnnouncesError& message) { + auto it = + session_->outgoing_subscribe_announces_.find(message.track_namespace); + if (it == session_->outgoing_subscribe_announces_.end()) { + return; // UNSUBSCRIBE_ANNOUNCES may already have deleted the entry. + } + if (it->second == nullptr) { + session_->Error(MoqtError::kProtocolViolation, + "Two responses to SUBSCRIBE_ANNOUNCES"); + return; + } + std::move(it->second)(message.track_namespace, message.error_code, + absl::string_view(message.reason_phrase)); + session_->outgoing_subscribe_announces_.erase(it); +} + void MoqtSession::ControlStream::OnMaxSubscribeIdMessage( const MoqtMaxSubscribeId& message) { if (session_->peer_role_ == MoqtRole::kSubscriber) {
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index 7908c9e..97eb5e5 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -49,6 +49,9 @@ using MoqtIncomingAnnounceCallback = quiche::MultiUseCallback<std::optional<MoqtAnnounceErrorReason>( FullTrackName track_namespace)>; +using MoqtSubscribeAnnouncesCallback = quiche::SingleUseCallback<void( + FullTrackName track_namespace, std::optional<SubscribeErrorCode> error, + absl::string_view reason)>; inline std::optional<MoqtAnnounceErrorReason> DefaultIncomingAnnounceCallback( FullTrackName /*track_namespace*/) { @@ -106,6 +109,12 @@ quic::Perspective perspective() const { return parameters_.perspective; } + // Returns true if message was sent. + bool SubscribeAnnounces( + FullTrackName track_namespace, MoqtSubscribeAnnouncesCallback callback, + MoqtSubscribeParameters parameters = MoqtSubscribeParameters()); + bool UnsubscribeAnnounces(FullTrackName track_namespace); + // Send an ANNOUNCE message for |track_namespace|, and call // |announce_callback| when the response arrives. Will fail immediately if // there is already an unresolved ANNOUNCE for that namespace. @@ -214,9 +223,9 @@ void OnSubscribeAnnouncesMessage( const MoqtSubscribeAnnounces& message) override {} void OnSubscribeAnnouncesOkMessage( - const MoqtSubscribeAnnouncesOk& message) override {} + const MoqtSubscribeAnnouncesOk& message) override; void OnSubscribeAnnouncesErrorMessage( - const MoqtSubscribeAnnouncesError& message) override {} + const MoqtSubscribeAnnouncesError& message) override; void OnUnsubscribeAnnouncesMessage( const MoqtUnsubscribeAnnounces& message) override {} void OnMaxSubscribeIdMessage(const MoqtMaxSubscribeId& message) override; @@ -609,6 +618,12 @@ // Indexed by track namespace. absl::flat_hash_map<FullTrackName, MoqtOutgoingAnnounceCallback> pending_outgoing_announces_; + // The value is nullptr after OK or ERROR is received. The entry is deleted + // when sending UNSUBSCRIBE_ANNOUNCES, to make sure the application doesn't + // unsubscribe from something that it isn't subscribed to. ANNOUNCEs that + // result from this subscription use incoming_announce_callback. + absl::flat_hash_map<FullTrackName, MoqtSubscribeAnnouncesCallback> + outgoing_subscribe_announces_; // The role the peer advertised in its SETUP message. Initialize it to avoid // an uninitialized value if no SETUP arrives or it arrives with no Role
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 349c147..c2fb556 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -639,6 +639,68 @@ stream_input->OnAnnounceMessage(announce); } +TEST_F(MoqtSessionTest, SubscribeAnnouncesLifeCycle) { + webtransport::test::MockStream mock_stream; + std::unique_ptr<MoqtControlParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + FullTrackName track_namespace("foo", "bar"); + track_namespace.NameToNamespace(); + bool got_callback = false; + EXPECT_CALL( + mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribeAnnounces), _)); + session_.SubscribeAnnounces( + track_namespace, + [&](const FullTrackName& ftn, std::optional<SubscribeErrorCode> error, + absl::string_view reason) { + got_callback = true; + EXPECT_EQ(track_namespace, ftn); + EXPECT_FALSE(error.has_value()); + EXPECT_EQ(reason, ""); + }); + MoqtSubscribeAnnouncesOk ok = { + /*track_namespace=*/track_namespace, + }; + stream_input->OnSubscribeAnnouncesOkMessage(ok); + EXPECT_TRUE(got_callback); + EXPECT_CALL( + mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kUnsubscribeAnnounces), _)); + EXPECT_TRUE(session_.UnsubscribeAnnounces(track_namespace)); + EXPECT_FALSE(session_.UnsubscribeAnnounces(track_namespace)); +} + +TEST_F(MoqtSessionTest, SubscribeAnnouncesError) { + webtransport::test::MockStream mock_stream; + std::unique_ptr<MoqtControlParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + FullTrackName track_namespace("foo", "bar"); + track_namespace.NameToNamespace(); + bool got_callback = false; + EXPECT_CALL( + mock_stream, + Writev(ControlMessageOfType(MoqtMessageType::kSubscribeAnnounces), _)); + session_.SubscribeAnnounces( + track_namespace, + [&](const FullTrackName& ftn, std::optional<SubscribeErrorCode> error, + absl::string_view reason) { + got_callback = true; + EXPECT_EQ(track_namespace, ftn); + ASSERT_TRUE(error.has_value()); + EXPECT_EQ(*error, SubscribeErrorCode::kInvalidRange); + EXPECT_EQ(reason, "deadbeef"); + }); + MoqtSubscribeAnnouncesError error = { + /*track_namespace=*/track_namespace, + /*error_code=*/SubscribeErrorCode::kInvalidRange, + /*reason_phrase=*/"deadbeef", + }; + stream_input->OnSubscribeAnnouncesErrorMessage(error); + EXPECT_TRUE(got_callback); + // Entry is immediately gone. + EXPECT_FALSE(session_.UnsubscribeAnnounces(track_namespace)); +} + TEST_F(MoqtSessionTest, IncomingObject) { MockSubscribeRemoteTrackVisitor visitor_; FullTrackName ftn("foo", "bar");
diff --git a/quiche/quic/moqt/test_tools/moqt_test_message.h b/quiche/quic/moqt/test_tools/moqt_test_message.h index 9488f63..5015bf7 100644 --- a/quiche/quic/moqt/test_tools/moqt_test_message.h +++ b/quiche/quic/moqt/test_tools/moqt_test_message.h
@@ -1212,13 +1212,13 @@ private: uint8_t raw_packet_[12] = { 0x13, 0x0a, 0x01, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo" - 0x01, // error_code = 1 + 0x04, // error_code = 4 0x03, 0x62, 0x61, 0x72, // reason_phrase = "bar" }; MoqtSubscribeAnnouncesError subscribe_namespace_error_ = { /*track_namespace=*/FullTrackName{"foo"}, - /*error_code=*/MoqtAnnounceErrorCode::kAnnounceNotSupported, + /*error_code=*/SubscribeErrorCode::kUnauthorized, /*reason_phrase=*/"bar", }; };