Change ClientProofSource::GetCertAndKey() to return shared_ptr<CertAndKey> instead of a raw pointer of CertAndKey. This is to future prevent the caller from accessing the return value after the proof source implementation has destroyed it. Interface change only. PiperOrigin-RevId: 615222541
diff --git a/quiche/quic/core/crypto/client_proof_source.cc b/quiche/quic/core/crypto/client_proof_source.cc index 2d477db..94c6490 100644 --- a/quiche/quic/core/crypto/client_proof_source.cc +++ b/quiche/quic/core/crypto/client_proof_source.cc
@@ -26,9 +26,9 @@ return true; } -const ClientProofSource::CertAndKey* DefaultClientProofSource::GetCertAndKey( - absl::string_view hostname) const { - if (const CertAndKey* const result = LookupExact(hostname); +std::shared_ptr<const ClientProofSource::CertAndKey> +DefaultClientProofSource::GetCertAndKey(absl::string_view hostname) const { + if (std::shared_ptr<const CertAndKey> result = LookupExact(hostname); result || hostname == "*") { return result; } @@ -39,7 +39,7 @@ auto dot_pos = hostname.find('.'); if (dot_pos != std::string::npos) { std::string wildcard = absl::StrCat("*", hostname.substr(dot_pos)); - const CertAndKey* const result = LookupExact(wildcard); + std::shared_ptr<const CertAndKey> result = LookupExact(wildcard); if (result != nullptr) { return result; } @@ -50,13 +50,13 @@ return LookupExact("*"); } -const ClientProofSource::CertAndKey* DefaultClientProofSource::LookupExact( - absl::string_view map_key) const { +std::shared_ptr<const ClientProofSource::CertAndKey> +DefaultClientProofSource::LookupExact(absl::string_view map_key) const { const auto it = cert_and_keys_.find(map_key); QUIC_DVLOG(1) << "LookupExact(" << map_key << ") found:" << (it != cert_and_keys_.end()); if (it != cert_and_keys_.end()) { - return it->second.get(); + return it->second; } return nullptr; }
diff --git a/quiche/quic/core/crypto/client_proof_source.h b/quiche/quic/core/crypto/client_proof_source.h index f2ceb12..589d1be 100644 --- a/quiche/quic/core/crypto/client_proof_source.h +++ b/quiche/quic/core/crypto/client_proof_source.h
@@ -37,7 +37,7 @@ // |server_hostname| is typically a full domain name(www.foo.com), but it // could also be a wildcard domain(*.foo.com), or a "*" which will return the // default cert. - virtual const CertAndKey* GetCertAndKey( + virtual std::shared_ptr<const CertAndKey> GetCertAndKey( absl::string_view server_hostname) const = 0; }; @@ -58,10 +58,12 @@ CertificatePrivateKey private_key); // ClientProofSource implementation - const CertAndKey* GetCertAndKey(absl::string_view hostname) const override; + std::shared_ptr<const CertAndKey> GetCertAndKey( + absl::string_view hostname) const override; private: - const CertAndKey* LookupExact(absl::string_view map_key) const; + std::shared_ptr<const CertAndKey> LookupExact( + absl::string_view map_key) const; absl::flat_hash_map<std::string, std::shared_ptr<CertAndKey>> cert_and_keys_; };
diff --git a/quiche/quic/core/crypto/client_proof_source_test.cc b/quiche/quic/core/crypto/client_proof_source_test.cc index 0104ef3..36b9dff 100644 --- a/quiche/quic/core/crypto/client_proof_source_test.cc +++ b/quiche/quic/core/crypto/client_proof_source_test.cc
@@ -56,7 +56,7 @@ #define VERIFY_CERT_AND_KEY_MATCHES(lhs, rhs) \ do { \ SCOPED_TRACE(testing::Message()); \ - VerifyCertAndKeyMatches(lhs, rhs); \ + VerifyCertAndKeyMatches(lhs.get(), rhs); \ } while (0) void VerifyCertAndKeyMatches(const ClientProofSource::CertAndKey* lhs,
diff --git a/quiche/quic/core/tls_client_handshaker.cc b/quiche/quic/core/tls_client_handshaker.cc index 30831e7..6fff29d 100644 --- a/quiche/quic/core/tls_client_handshaker.cc +++ b/quiche/quic/core/tls_client_handshaker.cc
@@ -44,7 +44,7 @@ crypto_config->tls_signature_algorithms()->c_str()); } if (crypto_config->proof_source() != nullptr) { - const ClientProofSource::CertAndKey* cert_and_key = + std::shared_ptr<const ClientProofSource::CertAndKey> cert_and_key = crypto_config->proof_source()->GetCertAndKey(server_id.host()); if (cert_and_key != nullptr) { QUIC_DVLOG(1) << "Setting client cert and key for " << server_id.host();