Fix MSAN failures in MoQT integration tests. PiperOrigin-RevId: 642664057
diff --git a/quiche/quic/moqt/moqt_integration_test.cc b/quiche/quic/moqt/moqt_integration_test.cc index 1e2de16..9690e4d 100644 --- a/quiche/quic/moqt/moqt_integration_test.cc +++ b/quiche/quic/moqt/moqt_integration_test.cc
@@ -29,41 +29,21 @@ using ::testing::Assign; using ::testing::Return; -class ClientEndpoint : public MoqtClientEndpoint { - public: - ClientEndpoint(Simulator* simulator, const std::string& name, - const std::string& peer_name, MoqtVersion version) - : MoqtClientEndpoint(simulator, name, peer_name, version) { - session()->callbacks() = callbacks_.AsSessionCallbacks(); - } - MockSessionCallbacks& callbacks() { return callbacks_; } - - private: - MockSessionCallbacks callbacks_; -}; -class ServerEndpoint : public MoqtServerEndpoint { - public: - ServerEndpoint(Simulator* simulator, const std::string& name, - const std::string& peer_name, MoqtVersion version) - : MoqtServerEndpoint(simulator, name, peer_name, version) { - session()->callbacks() = callbacks_.AsSessionCallbacks(); - } - MockSessionCallbacks& callbacks() { return callbacks_; } - - private: - MockSessionCallbacks callbacks_; -}; - class MoqtIntegrationTest : public quiche::test::QuicheTest { public: void CreateDefaultEndpoints() { - client_ = std::make_unique<ClientEndpoint>( + client_ = std::make_unique<MoqtClientEndpoint>( &test_harness_.simulator(), "Client", "Server", MoqtVersion::kDraft04); - server_ = std::make_unique<ServerEndpoint>( + server_ = std::make_unique<MoqtServerEndpoint>( &test_harness_.simulator(), "Server", "Client", MoqtVersion::kDraft04); + SetupCallbacks(); test_harness_.set_client(client_.get()); test_harness_.set_server(server_.get()); } + void SetupCallbacks() { + client_->session()->callbacks() = client_callbacks_.AsSessionCallbacks(); + server_->session()->callbacks() = server_callbacks_.AsSessionCallbacks(); + } void WireUpEndpoints() { test_harness_.WireUpEndpoints(); } @@ -74,9 +54,9 @@ client_->quic_session()->CryptoConnect(); bool client_established = false; bool server_established = false; - EXPECT_CALL(client_->callbacks().session_established_callback, Call()) + EXPECT_CALL(client_callbacks_.session_established_callback, Call()) .WillOnce(Assign(&client_established, true)); - EXPECT_CALL(server_->callbacks().session_established_callback, Call()) + EXPECT_CALL(server_callbacks_.session_established_callback, Call()) .WillOnce(Assign(&server_established, true)); bool success = test_harness_.RunUntilWithDefaultTimeout( [&]() { return client_established && server_established; }); @@ -86,8 +66,10 @@ protected: quic::simulator::TestHarness test_harness_; - std::unique_ptr<ClientEndpoint> client_; - std::unique_ptr<ServerEndpoint> server_; + MockSessionCallbacks client_callbacks_; + MockSessionCallbacks server_callbacks_; + std::unique_ptr<MoqtClientEndpoint> client_; + std::unique_ptr<MoqtServerEndpoint> server_; }; TEST_F(MoqtIntegrationTest, Handshake) { @@ -97,9 +79,9 @@ client_->quic_session()->CryptoConnect(); bool client_established = false; bool server_established = false; - EXPECT_CALL(client_->callbacks().session_established_callback, Call()) + EXPECT_CALL(client_callbacks_.session_established_callback, Call()) .WillOnce(Assign(&client_established, true)); - EXPECT_CALL(server_->callbacks().session_established_callback, Call()) + EXPECT_CALL(server_callbacks_.session_established_callback, Call()) .WillOnce(Assign(&server_established, true)); bool success = test_harness_.RunUntilWithDefaultTimeout( [&]() { return client_established && server_established; }); @@ -107,11 +89,12 @@ } TEST_F(MoqtIntegrationTest, VersionMismatch) { - client_ = std::make_unique<ClientEndpoint>( + client_ = std::make_unique<MoqtClientEndpoint>( &test_harness_.simulator(), "Client", "Server", MoqtVersion::kUnrecognizedVersionForTests); - server_ = std::make_unique<ServerEndpoint>( + server_ = std::make_unique<MoqtServerEndpoint>( &test_harness_.simulator(), "Server", "Client", MoqtVersion::kDraft04); + SetupCallbacks(); test_harness_.set_client(client_.get()); test_harness_.set_server(server_.get()); WireUpEndpoints(); @@ -119,13 +102,11 @@ client_->quic_session()->CryptoConnect(); bool client_terminated = false; bool server_terminated = false; - EXPECT_CALL(client_->callbacks().session_established_callback, Call()) - .Times(0); - EXPECT_CALL(server_->callbacks().session_established_callback, Call()) - .Times(0); - EXPECT_CALL(client_->callbacks().session_terminated_callback, Call(_)) + EXPECT_CALL(client_callbacks_.session_established_callback, Call()).Times(0); + EXPECT_CALL(server_callbacks_.session_established_callback, Call()).Times(0); + EXPECT_CALL(client_callbacks_.session_terminated_callback, Call(_)) .WillOnce(Assign(&client_terminated, true)); - EXPECT_CALL(server_->callbacks().session_terminated_callback, Call(_)) + EXPECT_CALL(server_callbacks_.session_terminated_callback, Call(_)) .WillOnce(Assign(&server_terminated, true)); bool success = test_harness_.RunUntilWithDefaultTimeout( [&]() { return client_terminated && server_terminated; }); @@ -134,7 +115,7 @@ TEST_F(MoqtIntegrationTest, AnnounceSuccess) { EstablishSession(); - EXPECT_CALL(server_->callbacks().incoming_announce_callback, Call("foo")) + EXPECT_CALL(server_callbacks_.incoming_announce_callback, Call("foo")) .WillOnce(Return(std::nullopt)); testing::MockFunction<void( absl::string_view track_namespace, @@ -156,7 +137,7 @@ TEST_F(MoqtIntegrationTest, AnnounceSuccessSubscribeInResponse) { EstablishSession(); - EXPECT_CALL(server_->callbacks().incoming_announce_callback, Call("foo")) + EXPECT_CALL(server_callbacks_.incoming_announce_callback, Call("foo")) .WillOnce(Return(std::nullopt)); MockRemoteTrackVisitor server_visitor; testing::MockFunction<void( @@ -187,7 +168,7 @@ // Set up the server to subscribe to "data" track for the namespace announce // it receives. MockRemoteTrackVisitor server_visitor; - EXPECT_CALL(server_->callbacks().incoming_announce_callback, Call(_)) + EXPECT_CALL(server_callbacks_.incoming_announce_callback, Call(_)) .WillOnce([&](absl::string_view track_namespace) { server_->session()->SubscribeAbsolute( track_namespace, "data", /*start_group=*/0,