Actually set MoqtSession::control_stream_ on the server side.

PiperOrigin-RevId: 612530962
diff --git a/quiche/quic/moqt/moqt_integration_test.cc b/quiche/quic/moqt/moqt_integration_test.cc
index 013fcc1..5ff0f79 100644
--- a/quiche/quic/moqt/moqt_integration_test.cc
+++ b/quiche/quic/moqt/moqt_integration_test.cc
@@ -214,6 +214,33 @@
   EXPECT_TRUE(success);
 }
 
+TEST_F(MoqtIntegrationTest, AnnounceSuccessSubscribeInResponse) {
+  EstablishSession();
+  EXPECT_CALL(server_->callbacks().incoming_announce_callback, Call("foo"))
+      .WillOnce(Return(std::nullopt));
+  MockRemoteTrackVisitor server_visitor;
+  testing::MockFunction<void(
+      absl::string_view track_namespace,
+      std::optional<MoqtAnnounceErrorReason> error_message)>
+      announce_callback;
+  client_->session()->Announce("foo", announce_callback.AsStdFunction());
+  bool matches = false;
+  EXPECT_CALL(announce_callback, Call(_, _))
+      .WillOnce([&](absl::string_view track_namespace,
+                    std::optional<MoqtAnnounceErrorReason> error) {
+        EXPECT_EQ(track_namespace, "foo");
+        EXPECT_FALSE(error.has_value());
+        server_->session()->SubscribeCurrentGroup(track_namespace, "/catalog",
+                                                  &server_visitor);
+      });
+  EXPECT_CALL(server_visitor, OnReply(_, _)).WillOnce([&]() {
+    matches = true;
+  });
+  bool success =
+      test_harness_.RunUntilWithDefaultTimeout([&]() { return matches; });
+  EXPECT_TRUE(success);
+}
+
 TEST_F(MoqtIntegrationTest, AnnounceFailure) {
   EstablishSession();
   testing::MockFunction<void(
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index 81d65de..2340ab3 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -449,6 +449,7 @@
   } else {
     is_control_stream_ = true;
   }
+  session_->control_stream_ = stream_->GetStreamId();
   if (perspective() == Perspective::IS_CLIENT) {
     session_->Error(MoqtError::kProtocolViolation,
                     "Received CLIENT_SETUP from server");
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index a9a57c2..87c1909 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -211,6 +211,7 @@
         EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kServerSetup);
         return absl::OkStatus();
       });
+  EXPECT_CALL(mock_stream, GetStreamId()).WillOnce(Return(0));
   EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1);
   stream_input->OnClientSetupMessage(setup);
 }
@@ -858,6 +859,7 @@
         EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kServerSetup);
         return absl::OkStatus();
       });
+  EXPECT_CALL(mock_stream, GetStreamId()).WillOnce(Return(0));
   EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1);
   stream_input->OnClientSetupMessage(setup);