Finish MoQT incoming ANNOUNCE life cycle (UNANNOUNCE). PiperOrigin-RevId: 705631696
diff --git a/quiche/quic/moqt/moqt_integration_test.cc b/quiche/quic/moqt/moqt_integration_test.cc index dc1820b..1c968ee 100644 --- a/quiche/quic/moqt/moqt_integration_test.cc +++ b/quiche/quic/moqt/moqt_integration_test.cc
@@ -120,10 +120,10 @@ EXPECT_TRUE(success); } -TEST_F(MoqtIntegrationTest, AnnounceSuccess) { +TEST_F(MoqtIntegrationTest, AnnounceSuccessThenUnannounce) { EstablishSession(); EXPECT_CALL(server_callbacks_.incoming_announce_callback, - Call(FullTrackName{"foo"})) + Call(FullTrackName{"foo"}, AnnounceEvent::kAnnounce)) .WillOnce(Return(std::nullopt)); testing::MockFunction<void( FullTrackName track_namespace, @@ -142,12 +142,62 @@ bool success = test_harness_.RunUntilWithDefaultTimeout([&]() { return matches; }); EXPECT_TRUE(success); + matches = false; + EXPECT_CALL(server_callbacks_.incoming_announce_callback, Call(_, _)) + .WillOnce([&](FullTrackName name, AnnounceEvent event) { + matches = true; + EXPECT_EQ(name, FullTrackName{"foo"}); + EXPECT_EQ(event, AnnounceEvent::kUnannounce); + return std::nullopt; + }); + client_->session()->Unannounce(FullTrackName{"foo"}); + success = test_harness_.RunUntilWithDefaultTimeout([&]() { return matches; }); + EXPECT_TRUE(success); +} + +TEST_F(MoqtIntegrationTest, AnnounceSuccessThenCancel) { + EstablishSession(); + EXPECT_CALL(server_callbacks_.incoming_announce_callback, + Call(FullTrackName{"foo"}, AnnounceEvent::kAnnounce)) + .WillOnce(Return(std::nullopt)); + testing::MockFunction<void( + FullTrackName track_namespace, + std::optional<MoqtAnnounceErrorReason> error_message)> + announce_callback; + client_->session()->Announce(FullTrackName{"foo"}, + announce_callback.AsStdFunction()); + bool matches = false; + EXPECT_CALL(announce_callback, Call(_, _)) + .WillOnce([&](FullTrackName track_namespace, + std::optional<MoqtAnnounceErrorReason> error) { + matches = true; + EXPECT_EQ(track_namespace, FullTrackName{"foo"}); + EXPECT_FALSE(error.has_value()); + }); + bool success = + test_harness_.RunUntilWithDefaultTimeout([&]() { return matches; }); + EXPECT_TRUE(success); + matches = false; + EXPECT_CALL(announce_callback, Call(_, _)) + .WillOnce([&](FullTrackName track_namespace, + std::optional<MoqtAnnounceErrorReason> error) { + matches = true; + EXPECT_EQ(track_namespace, FullTrackName{"foo"}); + ASSERT_TRUE(error.has_value()); + EXPECT_EQ(error->error_code, MoqtAnnounceErrorCode::kInternalError); + EXPECT_EQ(error->reason_phrase, "internal error"); + }); + server_->session()->CancelAnnounce(FullTrackName{"foo"}, + MoqtAnnounceErrorCode::kInternalError, + "internal error"); + success = test_harness_.RunUntilWithDefaultTimeout([&]() { return matches; }); + EXPECT_TRUE(success); } TEST_F(MoqtIntegrationTest, AnnounceSuccessSubscribeInResponse) { EstablishSession(); EXPECT_CALL(server_callbacks_.incoming_announce_callback, - Call(FullTrackName{"foo"})) + Call(FullTrackName{"foo"}, AnnounceEvent::kAnnounce)) .WillOnce(Return(std::nullopt)); MockSubscribeRemoteTrackVisitor server_visitor; testing::MockFunction<void( @@ -180,8 +230,10 @@ // Set up the server to subscribe to "data" track for the namespace announce // it receives. MockSubscribeRemoteTrackVisitor server_visitor; - EXPECT_CALL(server_callbacks_.incoming_announce_callback, Call(_)) - .WillOnce([&](FullTrackName track_namespace) { + EXPECT_CALL(server_callbacks_.incoming_announce_callback, + Call(_, AnnounceEvent::kAnnounce)) + .WillOnce([&](const FullTrackName& track_namespace, + AnnounceEvent /*announce*/) { FullTrackName track_name = track_namespace; track_name.AddElement("data"); server_->session()->SubscribeAbsolute(
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index 9f01c32..c181578 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -310,6 +310,19 @@ return true; } +void MoqtSession::CancelAnnounce(FullTrackName track_namespace, + MoqtAnnounceErrorCode code, + absl::string_view reason) { + if (peer_role_ == MoqtRole::kSubscriber) { + return; + } + MoqtAnnounceCancel message{track_namespace, code, std::string(reason)}; + + SendControlMessage(framer_.SerializeAnnounceCancel(message)); + QUIC_DLOG(INFO) << ENDPOINT << "Sent ANNOUNCE_CANCEL message for " + << message.track_namespace << " with reason " << reason; +} + bool MoqtSession::SubscribeAbsolute(const FullTrackName& name, uint64_t start_group, uint64_t start_object, SubscribeRemoteTrack::Visitor* visitor, @@ -981,7 +994,8 @@ return; } std::optional<MoqtAnnounceErrorReason> error = - session_->callbacks_.incoming_announce_callback(message.track_namespace); + session_->callbacks_.incoming_announce_callback(message.track_namespace, + AnnounceEvent::kAnnounce); if (error.has_value()) { MoqtAnnounceError reply; reply.track_namespace = message.track_namespace; @@ -1036,12 +1050,18 @@ session_->outgoing_announces_.erase(it); } +void MoqtSession::ControlStream::OnUnannounceMessage( + const MoqtUnannounce& message) { + session_->callbacks_.incoming_announce_callback(message.track_namespace, + AnnounceEvent::kUnannounce); +} + void MoqtSession::ControlStream::OnSubscribeAnnouncesMessage( const MoqtSubscribeAnnounces& message) { // TODO(martinduke): Handle authentication. std::optional<MoqtSubscribeErrorReason> result = session_->callbacks_.incoming_subscribe_announces_callback( - message.track_namespace, SubscribeType::kSubscribe); + message.track_namespace, SubscribeEvent::kSubscribe); if (result.has_value()) { MoqtSubscribeAnnouncesError error; error.track_namespace = message.track_namespace; @@ -1094,7 +1114,7 @@ // MoqtSession keeps no state here, so just tell the application. std::optional<MoqtSubscribeErrorReason> result = session_->callbacks_.incoming_subscribe_announces_callback( - message.track_namespace, SubscribeType::kUnsubscribe); + message.track_namespace, SubscribeEvent::kUnsubscribe); } void MoqtSession::ControlStream::OnMaxSubscribeIdMessage(
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index 98eb244..ca8e80c 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -43,7 +43,8 @@ quiche::SingleUseCallback<void(absl::string_view error_message)>; using MoqtSessionDeletedCallback = quiche::SingleUseCallback<void()>; -enum class SubscribeType { kSubscribe, kUnsubscribe }; +enum class SubscribeEvent { kSubscribe, kUnsubscribe }; +enum class AnnounceEvent { kAnnounce, kUnannounce }; // If |error_message| is nullopt, this is triggered by an ANNOUNCE_OK. // Otherwise, it is triggered by ANNOUNCE_ERROR or ANNOUNCE_CANCEL. For @@ -55,7 +56,7 @@ std::optional<MoqtAnnounceErrorReason> error)>; using MoqtIncomingAnnounceCallback = quiche::MultiUseCallback<std::optional<MoqtAnnounceErrorReason>( - FullTrackName track_namespace)>; + const FullTrackName& track_namespace, AnnounceEvent announce_type)>; using MoqtOutgoingSubscribeAnnouncesCallback = quiche::SingleUseCallback<void( FullTrackName track_namespace, std::optional<SubscribeErrorCode> error, absl::string_view reason)>; @@ -66,10 +67,10 @@ // and the return value will be ignored. using MoqtIncomingSubscribeAnnouncesCallback = quiche::MultiUseCallback<std::optional<MoqtSubscribeErrorReason>( - const FullTrackName& track_namespace, SubscribeType subscribe_type)>; + const FullTrackName& track_namespace, SubscribeEvent subscribe_type)>; inline std::optional<MoqtAnnounceErrorReason> DefaultIncomingAnnounceCallback( - FullTrackName /*track_namespace*/) { + const FullTrackName& /*track_namespace*/, AnnounceEvent /*announce*/) { return std::optional(MoqtAnnounceErrorReason{ MoqtAnnounceErrorCode::kAnnounceNotSupported, "This endpoint does not accept incoming ANNOUNCE messages"}); @@ -77,7 +78,7 @@ inline std::optional<MoqtSubscribeErrorReason> DefaultIncomingSubscribeAnnouncesCallback(const FullTrackName& track_namespace, - SubscribeType /*subscribe_type*/) { + SubscribeEvent /*subscribe_type*/) { return MoqtSubscribeErrorReason{ SubscribeErrorCode::kUnauthorized, "This endpoint does not support incoming SUBSCRIBE_ANNOUNCES messages"}; @@ -148,6 +149,10 @@ MoqtOutgoingAnnounceCallback announce_callback); // Returns true if message was sent, false if there is no ANNOUNCE to cancel. bool Unannounce(FullTrackName track_namespace); + // Allows the subscriber to declare it will not subscribe to |track_namespace| + // anymore. + void CancelAnnounce(FullTrackName track_namespace, MoqtAnnounceErrorCode code, + absl::string_view reason_phrase); // Returns true if SUBSCRIBE was sent. If there is already a subscription to // the track, the message will still be sent. However, the visitor will be @@ -245,7 +250,7 @@ void OnAnnounceCancelMessage(const MoqtAnnounceCancel& message) override; void OnTrackStatusRequestMessage( const MoqtTrackStatusRequest& message) override {}; - void OnUnannounceMessage(const MoqtUnannounce& /*message*/) override {} + void OnUnannounceMessage(const MoqtUnannounce& /*message*/) override; void OnTrackStatusMessage(const MoqtTrackStatus& message) override {} void OnGoAwayMessage(const MoqtGoAway& /*message*/) override {} void OnSubscribeAnnouncesMessage(
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 44021ca..1c845b6 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -674,19 +674,73 @@ EXPECT_EQ(MoqtSessionPeer::remote_track(&session_, 2), nullptr); } -TEST_F(MoqtSessionTest, ReplyToAnnounce) { +TEST_F(MoqtSessionTest, ReplyToAnnounceWithOkThenUnannounce) { + FullTrackName track_namespace{"foo"}; webtransport::test::MockStream mock_stream; std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); MoqtAnnounce announce = { - /*track_namespace=*/FullTrackName{"foo"}, + track_namespace, }; EXPECT_CALL(session_callbacks_.incoming_announce_callback, - Call(FullTrackName{"foo"})) + Call(track_namespace, AnnounceEvent::kAnnounce)) .WillOnce(Return(std::nullopt)); EXPECT_CALL( mock_stream, - Writev(SerializedControlMessage(MoqtAnnounceOk{FullTrackName{"foo"}}), + Writev(SerializedControlMessage(MoqtAnnounceOk{track_namespace}), _)); + stream_input->OnAnnounceMessage(announce); + MoqtUnannounce unannounce = { + track_namespace, + }; + EXPECT_CALL(session_callbacks_.incoming_announce_callback, + Call(track_namespace, AnnounceEvent::kUnannounce)) + .WillOnce(Return(std::nullopt)); + stream_input->OnUnannounceMessage(unannounce); +} + +TEST_F(MoqtSessionTest, ReplyToAnnounceWithOkThenAnnounceCancel) { + FullTrackName track_namespace{"foo"}; + webtransport::test::MockStream mock_stream; + std::unique_ptr<MoqtControlParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + MoqtAnnounce announce = { + track_namespace, + }; + EXPECT_CALL(session_callbacks_.incoming_announce_callback, + Call(track_namespace, AnnounceEvent::kAnnounce)) + .WillOnce(Return(std::nullopt)); + EXPECT_CALL( + mock_stream, + Writev(SerializedControlMessage(MoqtAnnounceOk{track_namespace}), _)); + stream_input->OnAnnounceMessage(announce); + EXPECT_CALL(mock_stream, + Writev(SerializedControlMessage(MoqtAnnounceCancel{ + track_namespace, MoqtAnnounceErrorCode::kInternalError, + "deadbeef"}), + _)); + session_.CancelAnnounce(track_namespace, + MoqtAnnounceErrorCode::kInternalError, "deadbeef"); +} + +TEST_F(MoqtSessionTest, ReplyToAnnounceWithError) { + FullTrackName track_namespace{"foo"}; + webtransport::test::MockStream mock_stream; + std::unique_ptr<MoqtControlParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&session_, &mock_stream); + MoqtAnnounce announce = { + track_namespace, + }; + MoqtAnnounceErrorReason error = { + MoqtAnnounceErrorCode::kAnnounceNotSupported, + "deadbeef", + }; + EXPECT_CALL(session_callbacks_.incoming_announce_callback, + Call(track_namespace, AnnounceEvent::kAnnounce)) + .WillOnce(Return(error)); + EXPECT_CALL( + mock_stream, + Writev(SerializedControlMessage(MoqtAnnounceError{ + track_namespace, error.error_code, error.reason_phrase}), _)); stream_input->OnAnnounceMessage(announce); } @@ -2113,7 +2167,7 @@ std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); EXPECT_CALL(session_callbacks_.incoming_subscribe_announces_callback, - Call(_, SubscribeType::kSubscribe)) + Call(_, SubscribeEvent::kSubscribe)) .WillOnce(Return(std::nullopt)); EXPECT_CALL( control_stream, @@ -2123,7 +2177,7 @@ /*track_namespace=*/FullTrackName{"foo"}, }; EXPECT_CALL(session_callbacks_.incoming_subscribe_announces_callback, - Call(track_namespace, SubscribeType::kUnsubscribe)) + Call(track_namespace, SubscribeEvent::kUnsubscribe)) .WillOnce(Return(std::nullopt)); stream_input->OnUnsubscribeAnnouncesMessage(unsubscribe_announces); } @@ -2138,7 +2192,7 @@ std::unique_ptr<MoqtControlParserVisitor> stream_input = MoqtSessionPeer::CreateControlStream(&session_, &control_stream); EXPECT_CALL(session_callbacks_.incoming_subscribe_announces_callback, - Call(_, SubscribeType::kSubscribe)) + Call(_, SubscribeEvent::kSubscribe)) .WillOnce(Return( MoqtSubscribeErrorReason{SubscribeErrorCode::kUnauthorized, "foo"})); EXPECT_CALL(
diff --git a/quiche/quic/moqt/tools/chat_server.cc b/quiche/quic/moqt/tools/chat_server.cc index 3f31a47..243d13b 100644 --- a/quiche/quic/moqt/tools/chat_server.cc +++ b/quiche/quic/moqt/tools/chat_server.cc
@@ -32,11 +32,15 @@ MoqtSession* session, ChatServer* server) : session_(session), server_(server) { session_->callbacks().incoming_announce_callback = - [&](FullTrackName track_namespace) { + [&](FullTrackName track_namespace, AnnounceEvent announce_type) { FullTrackName track_name = track_namespace; track_name.AddElement(""); - std::cout << "Received ANNOUNCE for " << track_namespace.ToString() - << "\n"; + if (announce_type == AnnounceEvent::kAnnounce) { + std::cout << "Received ANNOUNCE for "; + } else { + std::cout << "Received UNANNOUNCE for "; + } + std::cout << track_namespace.ToString() << "\n"; username_ = server_->strings().GetUsernameFromFullTrackName(track_name); if (username_->empty()) { std::cout << "Malformed ANNOUNCE namespace\n";
diff --git a/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc b/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc index 4b7aa2d..1f0cfc3 100644 --- a/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc +++ b/quiche/quic/moqt/tools/moqt_ingestion_server_bin.cc
@@ -115,8 +115,9 @@ absl::bind_front(&MoqtIngestionHandler::OnAnnounceReceived, this); } + // TODO(martinduke): Handle when |announce| is false (UNANNOUNCE). std::optional<MoqtAnnounceErrorReason> OnAnnounceReceived( - FullTrackName track_namespace) { + FullTrackName track_namespace, AnnounceEvent /*announce*/) { if (!IsValidTrackNamespace(track_namespace) && !quiche::GetQuicheCommandLineFlag( FLAGS_allow_invalid_track_namespaces)) {
diff --git a/quiche/quic/moqt/tools/moqt_mock_visitor.h b/quiche/quic/moqt/tools/moqt_mock_visitor.h index 8eb22a7..a685b2c 100644 --- a/quiche/quic/moqt/tools/moqt_mock_visitor.h +++ b/quiche/quic/moqt/tools/moqt_mock_visitor.h
@@ -28,14 +28,15 @@ testing::MockFunction<void()> session_established_callback; testing::MockFunction<void(absl::string_view)> session_terminated_callback; testing::MockFunction<void()> session_deleted_callback; - testing::MockFunction<std::optional<MoqtAnnounceErrorReason>(FullTrackName)> + testing::MockFunction<std::optional<MoqtAnnounceErrorReason>( + const FullTrackName&, AnnounceEvent)> incoming_announce_callback; testing::MockFunction<std::optional<MoqtSubscribeErrorReason>(FullTrackName, - SubscribeType)> + SubscribeEvent)> incoming_subscribe_announces_callback; MockSessionCallbacks() { - ON_CALL(incoming_announce_callback, Call(testing::_)) + ON_CALL(incoming_announce_callback, Call(testing::_, testing::_)) .WillByDefault(DefaultIncomingAnnounceCallback); ON_CALL(incoming_subscribe_announces_callback, Call(testing::_, testing::_)) .WillByDefault(DefaultIncomingSubscribeAnnouncesCallback);