Change ClientProofSource::GetCertAndKey() to return shared_ptr<CertAndKey> instead of a raw pointer of CertAndKey.
This is to future prevent the caller from accessing the return value after the proof source implementation has destroyed it.
Interface change only.
PiperOrigin-RevId: 615222541
diff --git a/quiche/quic/core/crypto/client_proof_source.cc b/quiche/quic/core/crypto/client_proof_source.cc
index 2d477db..94c6490 100644
--- a/quiche/quic/core/crypto/client_proof_source.cc
+++ b/quiche/quic/core/crypto/client_proof_source.cc
@@ -26,9 +26,9 @@
return true;
}
-const ClientProofSource::CertAndKey* DefaultClientProofSource::GetCertAndKey(
- absl::string_view hostname) const {
- if (const CertAndKey* const result = LookupExact(hostname);
+std::shared_ptr<const ClientProofSource::CertAndKey>
+DefaultClientProofSource::GetCertAndKey(absl::string_view hostname) const {
+ if (std::shared_ptr<const CertAndKey> result = LookupExact(hostname);
result || hostname == "*") {
return result;
}
@@ -39,7 +39,7 @@
auto dot_pos = hostname.find('.');
if (dot_pos != std::string::npos) {
std::string wildcard = absl::StrCat("*", hostname.substr(dot_pos));
- const CertAndKey* const result = LookupExact(wildcard);
+ std::shared_ptr<const CertAndKey> result = LookupExact(wildcard);
if (result != nullptr) {
return result;
}
@@ -50,13 +50,13 @@
return LookupExact("*");
}
-const ClientProofSource::CertAndKey* DefaultClientProofSource::LookupExact(
- absl::string_view map_key) const {
+std::shared_ptr<const ClientProofSource::CertAndKey>
+DefaultClientProofSource::LookupExact(absl::string_view map_key) const {
const auto it = cert_and_keys_.find(map_key);
QUIC_DVLOG(1) << "LookupExact(" << map_key
<< ") found:" << (it != cert_and_keys_.end());
if (it != cert_and_keys_.end()) {
- return it->second.get();
+ return it->second;
}
return nullptr;
}
diff --git a/quiche/quic/core/crypto/client_proof_source.h b/quiche/quic/core/crypto/client_proof_source.h
index f2ceb12..589d1be 100644
--- a/quiche/quic/core/crypto/client_proof_source.h
+++ b/quiche/quic/core/crypto/client_proof_source.h
@@ -37,7 +37,7 @@
// |server_hostname| is typically a full domain name(www.foo.com), but it
// could also be a wildcard domain(*.foo.com), or a "*" which will return the
// default cert.
- virtual const CertAndKey* GetCertAndKey(
+ virtual std::shared_ptr<const CertAndKey> GetCertAndKey(
absl::string_view server_hostname) const = 0;
};
@@ -58,10 +58,12 @@
CertificatePrivateKey private_key);
// ClientProofSource implementation
- const CertAndKey* GetCertAndKey(absl::string_view hostname) const override;
+ std::shared_ptr<const CertAndKey> GetCertAndKey(
+ absl::string_view hostname) const override;
private:
- const CertAndKey* LookupExact(absl::string_view map_key) const;
+ std::shared_ptr<const CertAndKey> LookupExact(
+ absl::string_view map_key) const;
absl::flat_hash_map<std::string, std::shared_ptr<CertAndKey>> cert_and_keys_;
};
diff --git a/quiche/quic/core/crypto/client_proof_source_test.cc b/quiche/quic/core/crypto/client_proof_source_test.cc
index 0104ef3..36b9dff 100644
--- a/quiche/quic/core/crypto/client_proof_source_test.cc
+++ b/quiche/quic/core/crypto/client_proof_source_test.cc
@@ -56,7 +56,7 @@
#define VERIFY_CERT_AND_KEY_MATCHES(lhs, rhs) \
do { \
SCOPED_TRACE(testing::Message()); \
- VerifyCertAndKeyMatches(lhs, rhs); \
+ VerifyCertAndKeyMatches(lhs.get(), rhs); \
} while (0)
void VerifyCertAndKeyMatches(const ClientProofSource::CertAndKey* lhs,
diff --git a/quiche/quic/core/tls_client_handshaker.cc b/quiche/quic/core/tls_client_handshaker.cc
index 30831e7..6fff29d 100644
--- a/quiche/quic/core/tls_client_handshaker.cc
+++ b/quiche/quic/core/tls_client_handshaker.cc
@@ -44,7 +44,7 @@
crypto_config->tls_signature_algorithms()->c_str());
}
if (crypto_config->proof_source() != nullptr) {
- const ClientProofSource::CertAndKey* cert_and_key =
+ std::shared_ptr<const ClientProofSource::CertAndKey> cert_and_key =
crypto_config->proof_source()->GetCertAndKey(server_id.host());
if (cert_and_key != nullptr) {
QUIC_DVLOG(1) << "Setting client cert and key for " << server_id.host();