Actually set MoqtSession::control_stream_ on the server side. PiperOrigin-RevId: 612530962
diff --git a/quiche/quic/moqt/moqt_integration_test.cc b/quiche/quic/moqt/moqt_integration_test.cc index 013fcc1..5ff0f79 100644 --- a/quiche/quic/moqt/moqt_integration_test.cc +++ b/quiche/quic/moqt/moqt_integration_test.cc
@@ -214,6 +214,33 @@ EXPECT_TRUE(success); } +TEST_F(MoqtIntegrationTest, AnnounceSuccessSubscribeInResponse) { + EstablishSession(); + EXPECT_CALL(server_->callbacks().incoming_announce_callback, Call("foo")) + .WillOnce(Return(std::nullopt)); + MockRemoteTrackVisitor server_visitor; + testing::MockFunction<void( + absl::string_view track_namespace, + std::optional<MoqtAnnounceErrorReason> error_message)> + announce_callback; + client_->session()->Announce("foo", announce_callback.AsStdFunction()); + bool matches = false; + EXPECT_CALL(announce_callback, Call(_, _)) + .WillOnce([&](absl::string_view track_namespace, + std::optional<MoqtAnnounceErrorReason> error) { + EXPECT_EQ(track_namespace, "foo"); + EXPECT_FALSE(error.has_value()); + server_->session()->SubscribeCurrentGroup(track_namespace, "/catalog", + &server_visitor); + }); + EXPECT_CALL(server_visitor, OnReply(_, _)).WillOnce([&]() { + matches = true; + }); + bool success = + test_harness_.RunUntilWithDefaultTimeout([&]() { return matches; }); + EXPECT_TRUE(success); +} + TEST_F(MoqtIntegrationTest, AnnounceFailure) { EstablishSession(); testing::MockFunction<void(
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index 81d65de..2340ab3 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -449,6 +449,7 @@ } else { is_control_stream_ = true; } + session_->control_stream_ = stream_->GetStreamId(); if (perspective() == Perspective::IS_CLIENT) { session_->Error(MoqtError::kProtocolViolation, "Received CLIENT_SETUP from server");
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index a9a57c2..87c1909 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -211,6 +211,7 @@ EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kServerSetup); return absl::OkStatus(); }); + EXPECT_CALL(mock_stream, GetStreamId()).WillOnce(Return(0)); EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1); stream_input->OnClientSetupMessage(setup); } @@ -858,6 +859,7 @@ EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kServerSetup); return absl::OkStatus(); }); + EXPECT_CALL(mock_stream, GetStreamId()).WillOnce(Return(0)); EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1); stream_input->OnClientSetupMessage(setup);