Call GetCertChains() in TlsServerHandshaker
Protected by quic_reloadable_flag_quic_use_proof_source_get_cert_chains.
PiperOrigin-RevId: 823650043
diff --git a/quiche/common/quiche_feature_flags_list.h b/quiche/common/quiche_feature_flags_list.h
index de45c4a..54c174d 100755
--- a/quiche/common/quiche_feature_flags_list.h
+++ b/quiche/common/quiche_feature_flags_list.h
@@ -60,6 +60,7 @@
QUICHE_FLAG(bool, quiche_reloadable_flag_quic_testonly_default_false, false, false, "A testonly reloadable flag that will always default to false.")
QUICHE_FLAG(bool, quiche_reloadable_flag_quic_testonly_default_true, true, true, "A testonly reloadable flag that will always default to true.")
QUICHE_FLAG(bool, quiche_reloadable_flag_quic_use_inlining_send_buffer2, true, true, "Uses an inlining version of QuicSendStreamBuffer.")
+QUICHE_FLAG(bool, quiche_reloadable_flag_quic_use_proof_source_get_cert_chains, false, false, "When true, quic::TlsServerHandshaker will use ProofSource::GetCertChains() instead of ProofSource::GetCertChain()")
QUICHE_FLAG(bool, quiche_reloadable_flag_quic_use_received_client_addresses_cache, true, true, "If true, use a LRU cache to record client addresses of packets received on server's original address.")
QUICHE_FLAG(bool, quiche_restart_flag_quic_support_release_time_for_gso, false, false, "If true, QuicGsoBatchWriter will support release time if it is available and the process has the permission to do so.")
QUICHE_FLAG(bool, quiche_restart_flag_quic_testonly_default_false, false, false, "A testonly restart flag that will always default to false.")
diff --git a/quiche/quic/core/crypto/proof_source.h b/quiche/quic/core/crypto/proof_source.h
index 17d2f59..ac4db37 100644
--- a/quiche/quic/core/crypto/proof_source.h
+++ b/quiche/quic/core/crypto/proof_source.h
@@ -10,12 +10,14 @@
#include <memory>
#include <optional>
#include <string>
+#include <utility>
#include <variant>
#include <vector>
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "openssl/base.h"
#include "openssl/pool.h"
#include "openssl/ssl.h"
@@ -291,11 +293,22 @@
// Configuration to use for configuring the SSL object when handshaking
// locally.
struct LocalSSLConfig {
+ using ReferencedCountedChain =
+ quiche::QuicheReferenceCountedPointer<ProofSource::Chain>;
+
+ // TODO: b/451645567 - Remove this constructor.
LocalSSLConfig(const ProofSource::Chain* absl_nullable chain,
QuicDelayedSSLConfig delayed_ssl_config)
: chain(chain), delayed_ssl_config(delayed_ssl_config) {}
+ LocalSSLConfig(absl::Span<const ReferencedCountedChain> chains,
+ QuicDelayedSSLConfig delayed_ssl_config)
+ : chains(std::move(chains)), delayed_ssl_config(delayed_ssl_config) {}
+
+ // TODO: b/451645567 - Once we remove `ProofSource::GetCertChain()`, we can
+ // delete the `chain` field.
const ProofSource::Chain* absl_nullable chain = nullptr;
+ absl::Span<const ReferencedCountedChain absl_nonnull> chains;
QuicDelayedSSLConfig delayed_ssl_config;
};
@@ -331,11 +344,20 @@
//
// When called asynchronously(is_sync=false), this method will be responsible
// to continue the handshake from where it left off.
+ //
+ // Callers that pass a `LocalSSLConfig` in `ssl_config` must use the result of
+ // `DoesOnSelectCertificateDoneExpectChains()` to decide which fields to
+ // populate.
virtual void OnSelectCertificateDone(bool ok, bool is_sync,
SSLConfig ssl_config,
absl::string_view ticket_encryption_key,
bool cert_matched_sni) = 0;
+ // Returns true when `OnSelectCertificateDone()` reads the
+ // `LocalSSLConfig::chains` field. Otherwise, it may read
+ // `LocalSSLConfig::chain`.
+ virtual bool DoesOnSelectCertificateDoneExpectChains() const = 0;
+
// Called when a ProofSourceHandle::ComputeSignature operation completes.
virtual void OnComputeSignatureDone(
bool ok, bool is_sync, std::string signature,
diff --git a/quiche/quic/core/crypto/proof_source_x509.h b/quiche/quic/core/crypto/proof_source_x509.h
index 37b3246..4f6a706 100644
--- a/quiche/quic/core/crypto/proof_source_x509.h
+++ b/quiche/quic/core/crypto/proof_source_x509.h
@@ -34,6 +34,7 @@
QuicTransportVersion transport_version,
absl::string_view chlo_hash,
std::unique_ptr<Callback> callback) override;
+ // TODO: b/451645567 - Define `GetCertChains()` instead of `GetCertChain()`.
quiche::QuicheReferenceCountedPointer<Chain> GetCertChain(
const QuicSocketAddress& server_address,
const QuicSocketAddress& client_address, const std::string& hostname,
diff --git a/quiche/quic/core/quic_crypto_client_handshaker_test.cc b/quiche/quic/core/quic_crypto_client_handshaker_test.cc
index 69530ac..fc42e01 100644
--- a/quiche/quic/core/quic_crypto_client_handshaker_test.cc
+++ b/quiche/quic/core/quic_crypto_client_handshaker_test.cc
@@ -84,6 +84,7 @@
callback->Run(true, chain, proof, /*details=*/nullptr);
}
+ // TODO: b/451645567 - Define `GetCertChains()` instead of `GetCertChain()`.
quiche::QuicheReferenceCountedPointer<Chain> GetCertChain(
const QuicSocketAddress& /*server_address*/,
const QuicSocketAddress& /*client_address*/,
diff --git a/quiche/quic/core/quic_types.h b/quiche/quic/core/quic_types.h
index f621307..8b554d0 100644
--- a/quiche/quic/core/quic_types.h
+++ b/quiche/quic/core/quic_types.h
@@ -898,6 +898,8 @@
std::optional<ClientCertMode> client_cert_mode;
// QUIC transport parameters as serialized by ProofSourceHandle.
std::optional<std::vector<uint8_t>> quic_transport_parameters;
+
+ bool operator==(const QuicDelayedSSLConfig& other) const = default;
};
// ParsedClientHello contains client hello information extracted from a fully
diff --git a/quiche/quic/core/tls_server_handshaker.cc b/quiche/quic/core/tls_server_handshaker.cc
index e8ebdde..fc9cc4c 100644
--- a/quiche/quic/core/tls_server_handshaker.cc
+++ b/quiche/quic/core/tls_server_handshaker.cc
@@ -14,6 +14,7 @@
#include <variant>
#include <vector>
+#include "absl/algorithm/container.h"
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
@@ -113,6 +114,26 @@
return QUIC_FAILURE;
}
+ if (handshaker_->DoesOnSelectCertificateDoneExpectChains()) {
+ QUIC_RELOADABLE_FLAG_COUNT_N(quic_use_proof_source_get_cert_chains, 1, 2);
+
+ ProofSource::CertChainsResult cert_chains_result =
+ proof_source_->GetCertChains(server_address, client_address, hostname);
+
+ handshaker_->OnSelectCertificateDone(
+ /*ok=*/true, /*is_sync=*/true,
+ ProofSourceHandleCallback::LocalSSLConfig(cert_chains_result.chains,
+ QuicDelayedSSLConfig()),
+ /*ticket_encryption_key=*/absl::string_view(),
+ /*cert_matched_sni=*/cert_chains_result.chains_match_sni);
+ if (!handshaker_->select_cert_status().has_value()) {
+ QUIC_BUG(select_cert_status_valueless_after_sync_select_cert);
+ // Return success to continue the handshake.
+ return QUIC_SUCCESS;
+ }
+ return *handshaker_->select_cert_status();
+ }
+
bool cert_matched_sni;
quiche::QuicheReferenceCountedPointer<ProofSource::Chain> chain =
proof_source_->GetCertChain(server_address, client_address, hostname,
@@ -221,6 +242,8 @@
QuicCryptoServerStreamBase(session),
proof_source_(crypto_config->proof_source()),
proof_verifier_(crypto_config->proof_verifier()),
+ use_proof_source_get_cert_chains_(
+ GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)),
pre_shared_key_(crypto_config->pre_shared_key()),
crypto_negotiated_params_(new QuicCryptoNegotiatedParameters),
tls_connection_(crypto_config->ssl_ctx(), this, session->GetSSLConfig()),
@@ -1119,11 +1142,31 @@
if (ok) {
if (auto* local_config = std::get_if<LocalSSLConfig>(&ssl_config);
local_config != nullptr) {
- if (local_config->chain && !local_config->chain->certs.empty()) {
+ if (!use_proof_source_get_cert_chains_ && local_config->chain &&
+ !local_config->chain->certs.empty()) {
tls_connection_.AddCertChain(
local_config->chain->ToCryptoBuffers().value,
local_config->chain->trust_anchor_id);
select_cert_status_ = QUIC_SUCCESS;
+ } else if (use_proof_source_get_cert_chains_ &&
+ !local_config->chains.empty() &&
+ // Cert selection fails when there are no chains with certs.
+ absl::c_any_of(local_config->chains,
+ [](const quiche::QuicheReferenceCountedPointer<
+ ProofSource::Chain> absl_nonnull& chain) {
+ return !chain->certs.empty();
+ })) {
+ QUIC_RELOADABLE_FLAG_COUNT_N(quic_use_proof_source_get_cert_chains, 2,
+ 2);
+ for (const quiche::QuicheReferenceCountedPointer<
+ ProofSource::Chain> absl_nonnull& chain :
+ local_config->chains) {
+ if (!chain->certs.empty()) {
+ tls_connection_.AddCertChain(chain->ToCryptoBuffers().value,
+ chain->trust_anchor_id);
+ }
+ }
+ select_cert_status_ = QUIC_SUCCESS;
} else {
QUIC_DLOG(ERROR) << "No certs provided for host '"
<< crypto_negotiated_params_->sni
diff --git a/quiche/quic/core/tls_server_handshaker.h b/quiche/quic/core/tls_server_handshaker.h
index 806fd20..f26be04 100644
--- a/quiche/quic/core/tls_server_handshaker.h
+++ b/quiche/quic/core/tls_server_handshaker.h
@@ -205,6 +205,9 @@
void OnSelectCertificateDone(bool ok, bool is_sync, SSLConfig ssl_config,
absl::string_view ticket_encryption_key,
bool cert_matched_sni) override;
+ bool DoesOnSelectCertificateDoneExpectChains() const override {
+ return use_proof_source_get_cert_chains_;
+ }
void OnComputeSignatureDone(
bool ok, bool is_sync, std::string signature,
@@ -250,8 +253,11 @@
// Close the handle. Cancel the pending signature operation, if any.
void CloseHandle() override;
- // Delegates to proof_source_->GetCertChain.
- // Returns QUIC_SUCCESS or QUIC_FAILURE. Never returns QUIC_PENDING.
+ // Delegates to `proof_source_->GetCertChains()` when
+ // `handshaker_->use_proof_source_get_cert_chains()` is true. Otherwise,
+ // delegates to `proof_source_->GetCertChain()`.
+ //
+ // Returns `QUIC_SUCCESS` or `QUIC_FAILURE`. Never returns `QUIC_PENDING`.
QuicAsyncStatus SelectCertificate(
const QuicSocketAddress& server_address,
const QuicSocketAddress& client_address,
@@ -383,6 +389,10 @@
// True if new ALPS codepoint in the ClientHello.
bool alps_new_codepoint_received_ = false;
+ // The value of the reloadable flag `quic_use_proof_source_get_cert_chains` at
+ // the time of construction.
+ const bool use_proof_source_get_cert_chains_;
+
// nullopt means select cert hasn't started.
std::optional<QuicAsyncStatus> select_cert_status_;
diff --git a/quiche/quic/core/tls_server_handshaker_test.cc b/quiche/quic/core/tls_server_handshaker_test.cc
index be182a8..4a22af7 100644
--- a/quiche/quic/core/tls_server_handshaker_test.cc
+++ b/quiche/quic/core/tls_server_handshaker_test.cc
@@ -43,6 +43,7 @@
#include "quiche/quic/test_tools/quic_config_peer.h"
#include "quiche/quic/test_tools/quic_test_utils.h"
#include "quiche/quic/test_tools/simple_session_cache.h"
+#include "quiche/quic/test_tools/test_certificates.h"
#include "quiche/quic/test_tools/test_ticket_crypter.h"
#include "quiche/common/platform/api/quiche_logging.h"
#include "quiche/common/platform/api/quiche_reference_counted.h"
@@ -53,9 +54,15 @@
} // namespace quic
using testing::_;
+using testing::AllOf;
+using testing::ElementsAre;
+using testing::Eq;
+using testing::Field;
using testing::HasSubstr;
+using testing::IsNull;
using testing::NiceMock;
using testing::Return;
+using testing::VariantWith;
namespace quic {
namespace test {
@@ -68,13 +75,15 @@
struct TestParams {
ParsedQuicVersion version;
bool disable_resumption;
+ bool enable_get_cert_chains;
};
ABSL_ATTRIBUTE_UNUSED // Used by ::testing::PrintToStringParamName().
std::string PrintToString(const TestParams& p) {
- return absl::StrCat(
- ParsedQuicVersionToString(p.version), "_",
- (p.disable_resumption ? "ResumptionDisabled" : "ResumptionEnabled"));
+ return absl::StrCat(ParsedQuicVersionToString(p.version), "AndResumption",
+ (p.disable_resumption ? "Disabled" : "Enabled"),
+ "AndGetCertChains",
+ (p.enable_get_cert_chains ? "Enabled" : "Disabled"));
}
// Constructs test permutations.
@@ -82,7 +91,10 @@
std::vector<TestParams> params;
for (const auto& version : AllSupportedVersionsWithTls()) {
for (bool disable_resumption : {false, true}) {
- params.push_back(TestParams{version, disable_resumption});
+ for (bool enable_get_cert_chains : {false, true}) {
+ params.push_back(
+ TestParams{version, disable_resumption, enable_get_cert_chains});
+ }
}
}
return params;
@@ -131,6 +143,10 @@
ON_CALL(*this, OverrideQuicConfigDefaults(_))
.WillByDefault(testing::Invoke(
this, &TestTlsServerHandshaker::RealOverrideQuicConfigDefaults));
+
+ ON_CALL(*this, OnSelectCertificateDone)
+ .WillByDefault(testing::Invoke(
+ this, &TestTlsServerHandshaker::RealOnSelectCertificateDone));
}
MOCK_METHOD(std::unique_ptr<ProofSourceHandle>, MaybeCreateProofSourceHandle,
@@ -139,6 +155,11 @@
MOCK_METHOD(void, OverrideQuicConfigDefaults, (QuicConfig * config),
(override));
+ MOCK_METHOD(void, OnSelectCertificateDone,
+ (bool, bool, SSLConfig, absl::string_view, bool), (override));
+
+ // Makes the next call to `MaybeCreateProofSourceHandle()` return a
+ // `FakeProofSourceHandle` instead of a real `ProofSourceHandle`.
void SetupProofSourceHandle(
FakeProofSourceHandle::Action select_cert_action,
FakeProofSourceHandle::Action compute_signature_action,
@@ -192,6 +213,14 @@
return TlsServerHandshaker::OverrideQuicConfigDefaults(config);
}
+ void RealOnSelectCertificateDone(bool ok, bool is_sync, SSLConfig ssl_config,
+ absl::string_view ticket_encryption_key,
+ bool cert_matched_sni) {
+ return TlsServerHandshaker::OnSelectCertificateDone(
+ ok, is_sync, std::move(ssl_config), ticket_encryption_key,
+ cert_matched_sni);
+ }
+
// Owned by TlsServerHandshaker.
FakeProofSourceHandle* fake_proof_source_handle_ = nullptr;
ProofSource* proof_source_ = nullptr;
@@ -226,6 +255,9 @@
supported_versions_({GetParam().version}) {
SetQuicFlag(quic_disable_server_tls_resumption,
GetParam().disable_resumption);
+ SetQuicReloadableFlag(quic_use_proof_source_get_cert_chains,
+ GetParam().enable_get_cert_chains);
+
client_crypto_config_ = std::make_unique<QuicCryptoClientConfig>(
crypto_test_utils::ProofVerifierForTesting(),
std::make_unique<test::SimpleSessionCache>());
@@ -725,6 +757,307 @@
EXPECT_FALSE(last_select_cert_args().ssl_config.early_data_enabled);
}
+TEST_P(TlsServerHandshakerTest, SelectCertificateCallsGetCertChains) {
+ if (GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)) {
+ EXPECT_CALL(*proof_source_, GetCertChains(_, _, kServerHostname));
+ // Although the default implementation of `ProofSource::GetCertChains()`
+ // uses `ProofSource::GetCertChain()`, we're using a `FakeProofSource` that
+ // delegates to a `TestProofSource`, which does not have this behavior.
+ EXPECT_CALL(*proof_source_, GetCertChain).Times(0);
+ } else {
+ EXPECT_CALL(*proof_source_, GetCertChains).Times(0);
+ EXPECT_CALL(*proof_source_, GetCertChain(_, _, kServerHostname, _));
+ }
+
+ CompleteCryptoHandshake();
+ ExpectHandshakeSuccessful();
+}
+
+TEST_P(TlsServerHandshakerTest, ZeroCertChainsFailsHandshake) {
+ if (GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)) {
+ EXPECT_CALL(*proof_source_, GetCertChains(_, _, kServerHostname))
+ .WillOnce(Return(ProofSource::CertChainsResult{
+ .chains_match_sni = false,
+ .chains = {},
+ }));
+ EXPECT_CALL(*proof_source_, GetCertChain).Times(0);
+ } else {
+ EXPECT_CALL(*proof_source_, GetCertChain)
+ .WillOnce([&](const QuicSocketAddress&, const QuicSocketAddress&,
+ const std::string&, bool* cert_matched_sni)
+ -> quiche::QuicheReferenceCountedPointer<
+ ProofSource::Chain> {
+ *cert_matched_sni = false;
+ return // quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+ nullptr;
+ });
+ EXPECT_CALL(*proof_source_, GetCertChains).Times(0);
+ }
+
+ AdvanceHandshakeWithFakeClient();
+
+ EXPECT_FALSE(client_stream()->one_rtt_keys_available());
+ EXPECT_FALSE(client_stream()->encryption_established());
+ EXPECT_FALSE(server_stream()->one_rtt_keys_available());
+ EXPECT_FALSE(server_stream()->encryption_established());
+}
+
+// Test that `DefaultProofSourceHandle::SelectCertificate()` uses chains that do
+// not match the SNI when that's all that's available.
+TEST_P(TlsServerHandshakerTest, MultipleCertChainsNotMatchingSni) {
+ InitializeServerWithFakeProofSourceHandle();
+ ASSERT_TRUE(server_handshaker_ != nullptr);
+
+ // Note that if we called `server_handshaker_->SetupProofSourceHandle()` here,
+ // this test would incorrectly exercise `FakeProofSourceHandle` instead of
+ // `DefaultProofSourceHandle`.
+
+ if (GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)) {
+ // Return chains that claim not to match SNI. (Whether the certificate in
+ // `quic::test::kTestCertificate` actually matches `kServerHostname` is
+ // irrelevant.)
+ EXPECT_CALL(*proof_source_, GetCertChains(_, _, kServerHostname))
+ .WillOnce(Return(ProofSource::CertChainsResult{
+ .chains_match_sni = false,
+ .chains =
+ std::vector<
+ quiche::QuicheReferenceCountedPointer<ProofSource::Chain>>{
+ quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+ new ProofSource::Chain(
+ /*certs=*/{std::string(
+ quic::test::kTestCertificate)},
+ /*trust_anchor_id=*/"")),
+ quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+ new ProofSource::Chain(
+ /*certs=*/{std::string(
+ quic::test::kTestCertificate)},
+ /*trust_anchor_id=*/"")),
+ },
+ }));
+
+ EXPECT_CALL(*proof_source_, GetCertChain).Times(0);
+
+ EXPECT_CALL(*server_handshaker_, OnSelectCertificateDone(
+ /*ok=*/true, /*is_sync=*/_,
+ /*ssl_config=*/_,
+ /*ticket_encryption_key=*/_,
+ /*cert_matched_sni=*/false));
+ } else {
+ EXPECT_CALL(*proof_source_, GetCertChains).Times(0);
+ EXPECT_CALL(*proof_source_, GetCertChain(_, _, kServerHostname, _));
+ EXPECT_CALL(*server_handshaker_, OnSelectCertificateDone);
+ }
+
+ CompleteCryptoHandshake();
+}
+
+// Test that `DefaultProofSourceHandle::SelectCertificate()` passes along chains
+// with and without Trust Anchor IDs .
+TEST_P(TlsServerHandshakerTest, SelectCertificateChainMatchesSni) {
+ // Set the client-advertised Trust Anchor IDs. Each ID is preceded by an 8-bit
+ // length prefix.
+ client_crypto_config_->ssl_config().trust_anchor_ids = "\x03\x11\x22\x33";
+ InitializeFakeClient();
+
+ InitializeServerWithFakeProofSourceHandle();
+ ASSERT_TRUE(server_handshaker_ != nullptr);
+
+ // Note that if we called `server_handshaker_->SetupProofSourceHandle()` here,
+ // this test would incorrectly exercise `FakeProofSourceHandle` instead of
+ // `DefaultProofSourceHandle`.
+
+ if (GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)) {
+ using LocalSSLConfig = ProofSourceHandleCallback::LocalSSLConfig;
+
+ std::vector<quiche::QuicheReferenceCountedPointer<ProofSource::Chain>>
+ chains = {
+ quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+ new ProofSource::Chain(
+ /*certs=*/{std::string(quic::test::kTestCertificate)},
+ /*trust_anchor_id=*/"\x11\x22\x33")),
+ quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+ new ProofSource::Chain(
+ /*certs=*/{std::string(quic::test::kTestCertificate)},
+ /*trust_anchor_id=*/"")),
+ };
+ EXPECT_CALL(*proof_source_, GetCertChains(_, _, kServerHostname))
+ .WillOnce(Return(ProofSource::CertChainsResult{
+ .chains_match_sni = true,
+ .chains = chains,
+ }));
+ EXPECT_CALL(*proof_source_, GetCertChain).Times(0);
+ // `DefaultProofSourceHandle::SelectCertificate()` should pick only the
+ // chains that claim to match SNI.
+ EXPECT_CALL(*server_handshaker_,
+ OnSelectCertificateDone(
+ /*ok=*/true, /*is_sync=*/_,
+ VariantWith<LocalSSLConfig>(
+ AllOf(Field(&LocalSSLConfig::chain, IsNull()),
+ Field(&LocalSSLConfig::chains,
+ ElementsAre(chains[0], chains[1])),
+ Field(&LocalSSLConfig::delayed_ssl_config,
+ Eq(QuicDelayedSSLConfig())))),
+ /*ticket_encryption_key=*/_,
+ /*cert_matched_sni=*/true));
+ } else {
+ EXPECT_CALL(*proof_source_, GetCertChains).Times(0);
+ EXPECT_CALL(*proof_source_, GetCertChain(_, _, kServerHostname, _));
+ EXPECT_CALL(*server_handshaker_, OnSelectCertificateDone(
+ /*ok=*/true, /*is_sync=*/_,
+ /*ssl_config=*/_,
+ /*ticket_encryption_key=*/_,
+ /*cert_matched_sni=*/false));
+ }
+
+ CompleteCryptoHandshake();
+ ExpectHandshakeSuccessful();
+
+ EXPECT_EQ(client_stream()->MatchedTrustAnchorIdForTesting(),
+ GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains));
+}
+
+// Test that `DefaultProofSourceHandle::SelectCertificate()` passes through
+// chains with Trust Anchor IDs when that's all that's available. The handshake
+// will succeed because one of the chains matches the client-advertised Trust
+// Anchor IDs.
+TEST_P(TlsServerHandshakerTest,
+ SelectCertificateNotFirstChainMatchesSniAndTrustAnchorId) {
+ // Set the client-advertised Trust Anchor IDs. Each ID is preceded by an 8-bit
+ // length prefix.
+ client_crypto_config_->ssl_config().trust_anchor_ids = "\x03\x11\x22\x33";
+ InitializeFakeClient();
+
+ InitializeServerWithFakeProofSourceHandle();
+ ASSERT_TRUE(server_handshaker_ != nullptr);
+
+ // Note that if we called `server_handshaker_->SetupProofSourceHandle()` here,
+ // this test would incorrectly exercise `FakeProofSourceHandle` instead of
+ // `DefaultProofSourceHandle`.
+
+ if (GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)) {
+ using LocalSSLConfig = ProofSourceHandleCallback::LocalSSLConfig;
+
+ std::vector<quiche::QuicheReferenceCountedPointer<ProofSource::Chain>>
+ chains = {
+ quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+ new ProofSource::Chain(
+ /*certs=*/{std::string(quic::test::kTestCertificate)},
+ /*trust_anchor_id=*/"\x07\x08")),
+ quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+ new ProofSource::Chain(
+ /*certs=*/{std::string(quic::test::kTestCertificate)},
+ /*trust_anchor_id=*/"\x11\x22\x33")),
+ };
+ EXPECT_CALL(*proof_source_, GetCertChains(_, _, kServerHostname))
+ .WillOnce(Return(ProofSource::CertChainsResult{
+ .chains_match_sni = true,
+ .chains = chains,
+ }));
+ EXPECT_CALL(*proof_source_, GetCertChain).Times(0);
+ // `DefaultProofSourceHandle::SelectCertificate()` should pick only the
+ // chains that claim to match SNI.
+ EXPECT_CALL(*server_handshaker_,
+ OnSelectCertificateDone(
+ /*ok=*/true, /*is_sync=*/_,
+ VariantWith<LocalSSLConfig>(
+ AllOf(Field(&LocalSSLConfig::chain, IsNull()),
+ Field(&LocalSSLConfig::chains,
+ ElementsAre(chains[0], chains[1])),
+ Field(&LocalSSLConfig::delayed_ssl_config,
+ Eq(QuicDelayedSSLConfig())))),
+ /*ticket_encryption_key=*/_,
+ /*cert_matched_sni=*/true));
+ } else {
+ EXPECT_CALL(*proof_source_, GetCertChains).Times(0);
+ EXPECT_CALL(*proof_source_, GetCertChain(_, _, kServerHostname, _));
+ EXPECT_CALL(*server_handshaker_, OnSelectCertificateDone(
+ /*ok=*/true, /*is_sync=*/_,
+ /*ssl_config=*/_,
+ /*ticket_encryption_key=*/_,
+ /*cert_matched_sni=*/false));
+ }
+
+ CompleteCryptoHandshake();
+ ExpectHandshakeSuccessful();
+
+ EXPECT_EQ(client_stream()->MatchedTrustAnchorIdForTesting(),
+ GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains));
+}
+
+// Test that the handshake fails when all chains returned by
+// `ProofSource::GetCertChains()` have Trust Anchor IDs, but none match the
+// client-advertised IDs. This pitfall can be avoided by returning at least one
+// (complete) certificate chain that is not associated with a Trust Anchor ID.
+TEST_P(TlsServerHandshakerTest,
+ SelectCertificateFailsWhenAllChainsHaveMismatchedTrustAnchorIds) {
+ // Set the client-advertised Trust Anchor IDs. Each ID is preceded by an 8-bit
+ // length prefix.
+ client_crypto_config_->ssl_config().trust_anchor_ids = "\x03\x11\x22\x33";
+ InitializeFakeClient();
+
+ InitializeServerWithFakeProofSourceHandle();
+ ASSERT_TRUE(server_handshaker_ != nullptr);
+
+ // Note that if we called `server_handshaker_->SetupProofSourceHandle()` here,
+ // this test would incorrectly exercise `FakeProofSourceHandle` instead of
+ // `DefaultProofSourceHandle`.
+
+ if (GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)) {
+ using LocalSSLConfig = ProofSourceHandleCallback::LocalSSLConfig;
+
+ std::vector<quiche::QuicheReferenceCountedPointer<ProofSource::Chain>>
+ chains = {
+ quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+ new ProofSource::Chain(
+ /*certs=*/{std::string(quic::test::kTestCertificate)},
+ /*trust_anchor_id=*/"\x07\x08")),
+ quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+ new ProofSource::Chain(
+ /*certs=*/{std::string(quic::test::kTestCertificate)},
+ /*trust_anchor_id=*/"\x42")),
+ };
+ EXPECT_CALL(*proof_source_, GetCertChains(_, _, kServerHostname))
+ .WillOnce(Return(ProofSource::CertChainsResult{
+ .chains_match_sni = true,
+ .chains = chains,
+ }));
+ EXPECT_CALL(*proof_source_, GetCertChain).Times(0);
+ // `DefaultProofSourceHandle::SelectCertificate()` should pick only the
+ // chains that claim to match SNI.
+ EXPECT_CALL(*server_handshaker_,
+ OnSelectCertificateDone(
+ /*ok=*/true, /*is_sync=*/_,
+ VariantWith<LocalSSLConfig>(
+ AllOf(Field(&LocalSSLConfig::chain, IsNull()),
+ Field(&LocalSSLConfig::chains,
+ ElementsAre(chains[0], chains[1])),
+ Field(&LocalSSLConfig::delayed_ssl_config,
+ Eq(QuicDelayedSSLConfig())))),
+ /*ticket_encryption_key=*/_,
+ /*cert_matched_sni=*/true));
+ AdvanceHandshakeWithFakeClient();
+
+ EXPECT_FALSE(client_stream()->one_rtt_keys_available());
+ EXPECT_FALSE(client_stream()->encryption_established());
+ EXPECT_FALSE(server_stream()->one_rtt_keys_available());
+ EXPECT_FALSE(server_stream()->encryption_established());
+
+ } else {
+ EXPECT_CALL(*proof_source_, GetCertChains).Times(0);
+ EXPECT_CALL(*proof_source_, GetCertChain(_, _, kServerHostname, _));
+ EXPECT_CALL(*server_handshaker_, OnSelectCertificateDone(
+ /*ok=*/true, /*is_sync=*/_,
+ /*ssl_config=*/_,
+ /*ticket_encryption_key=*/_,
+ /*cert_matched_sni=*/false));
+
+ CompleteCryptoHandshake();
+ ExpectHandshakeSuccessful();
+ }
+
+ EXPECT_FALSE(client_stream()->MatchedTrustAnchorIdForTesting());
+}
+
TEST_P(TlsServerHandshakerTest, ConnectionClosedOnTlsError) {
EXPECT_CALL(*server_connection_,
CloseConnection(QUIC_HANDSHAKE_FAILED, _, _, _));
diff --git a/quiche/quic/test_tools/fake_proof_source.cc b/quiche/quic/test_tools/fake_proof_source.cc
index 5ee9032..22262ba 100644
--- a/quiche/quic/test_tools/fake_proof_source.cc
+++ b/quiche/quic/test_tools/fake_proof_source.cc
@@ -4,19 +4,32 @@
#include "quiche/quic/test_tools/fake_proof_source.h"
+#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include "absl/strings/string_view.h"
+#include "quiche/quic/core/quic_types.h"
+#include "quiche/quic/core/quic_versions.h"
#include "quiche/quic/platform/api/quic_logging.h"
+#include "quiche/quic/platform/api/quic_socket_address.h"
+#include "quiche/quic/platform/api/quic_test.h"
#include "quiche/quic/test_tools/crypto_test_utils.h"
+#include "quiche/common/platform/api/quiche_logging.h"
namespace quic {
namespace test {
-FakeProofSource::FakeProofSource()
- : delegate_(crypto_test_utils::ProofSourceForTesting()) {}
+FakeProofSource::FakeProofSource(const std::string& trust_anchor_id)
+ : delegate_(crypto_test_utils::ProofSourceForTesting(trust_anchor_id)) {
+ ON_CALL(*this, GetCertChain)
+ .WillByDefault(
+ testing::Invoke(delegate_.get(), &ProofSource::GetCertChain));
+ ON_CALL(*this, GetCertChains)
+ .WillByDefault(
+ testing::Invoke(delegate_.get(), &ProofSource::GetCertChains));
+}
FakeProofSource::~FakeProofSource() {}
@@ -87,15 +100,6 @@
delegate_.get()));
}
-quiche::QuicheReferenceCountedPointer<ProofSource::Chain>
-FakeProofSource::GetCertChain(const QuicSocketAddress& server_address,
- const QuicSocketAddress& client_address,
- const std::string& hostname,
- bool* cert_matched_sni) {
- return delegate_->GetCertChain(server_address, client_address, hostname,
- cert_matched_sni);
-}
-
void FakeProofSource::ComputeTlsSignature(
const QuicSocketAddress& server_address,
const QuicSocketAddress& client_address, const std::string& hostname,
diff --git a/quiche/quic/test_tools/fake_proof_source.h b/quiche/quic/test_tools/fake_proof_source.h
index d5e9837..de0b79e 100644
--- a/quiche/quic/test_tools/fake_proof_source.h
+++ b/quiche/quic/test_tools/fake_proof_source.h
@@ -5,12 +5,18 @@
#ifndef QUICHE_QUIC_TEST_TOOLS_FAKE_PROOF_SOURCE_H_
#define QUICHE_QUIC_TEST_TOOLS_FAKE_PROOF_SOURCE_H_
+#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/string_view.h"
#include "quiche/quic/core/crypto/proof_source.h"
+#include "quiche/quic/core/quic_types.h"
+#include "quiche/quic/core/quic_versions.h"
+#include "quiche/quic/platform/api/quic_socket_address.h"
+#include "quiche/quic/platform/api/quic_test.h"
+#include "quiche/common/platform/api/quiche_reference_counted.h"
namespace quic {
namespace test {
@@ -24,7 +30,7 @@
// FakeProofSource::GetTicketCrypter.
class FakeProofSource : public ProofSource {
public:
- FakeProofSource();
+ explicit FakeProofSource(const std::string& trust_anchor_id = "");
~FakeProofSource() override;
// Before this object is "active", all calls to GetProof will be delegated
@@ -40,10 +46,16 @@
QuicTransportVersion transport_version,
absl::string_view chlo_hash,
std::unique_ptr<ProofSource::Callback> callback) override;
- quiche::QuicheReferenceCountedPointer<Chain> GetCertChain(
- const QuicSocketAddress& server_address,
- const QuicSocketAddress& client_address, const std::string& hostname,
- bool* cert_matched_sni) override;
+ MOCK_METHOD(quiche::QuicheReferenceCountedPointer<Chain>, GetCertChain,
+ (const QuicSocketAddress& server_address,
+ const QuicSocketAddress& client_address,
+ const std::string& hostname, bool* cert_matched_sni),
+ (override));
+ MOCK_METHOD(ProofSource::CertChainsResult, GetCertChains,
+ (const QuicSocketAddress& server_address,
+ const QuicSocketAddress& client_address,
+ const std::string& hostname),
+ (override));
void ComputeTlsSignature(
const QuicSocketAddress& server_address,
const QuicSocketAddress& client_address, const std::string& hostname,
diff --git a/quiche/quic/test_tools/fake_proof_source_handle.cc b/quiche/quic/test_tools/fake_proof_source_handle.cc
index 3ae09f5..0d99eaf 100644
--- a/quiche/quic/test_tools/fake_proof_source_handle.cc
+++ b/quiche/quic/test_tools/fake_proof_source_handle.cc
@@ -14,11 +14,13 @@
#include "absl/base/nullability.h"
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "openssl/base.h"
#include "quiche/quic/core/crypto/proof_source.h"
#include "quiche/quic/core/quic_connection_id.h"
#include "quiche/quic/core/quic_types.h"
#include "quiche/quic/platform/api/quic_bug_tracker.h"
+#include "quiche/quic/platform/api/quic_flags.h"
#include "quiche/quic/platform/api/quic_socket_address.h"
#include "quiche/common/platform/api/quiche_logging.h"
#include "quiche/common/platform/api/quiche_reference_counted.h"
@@ -119,6 +121,19 @@
}
QUICHE_DCHECK(select_cert_action_ == Action::DELEGATE_SYNC);
+ if (GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)) {
+ ProofSource::CertChainsResult chains_result =
+ delegate_->GetCertChains(server_address, client_address, hostname);
+ const bool ok = !chains_result.chains.empty();
+ callback_->OnSelectCertificateDone(
+ ok, /*is_sync=*/true,
+ ProofSourceHandleCallback::LocalSSLConfig(
+ absl::MakeConstSpan(chains_result.chains), delayed_ssl_config_),
+ /*ticket_encryption_key=*/absl::string_view(),
+ /*cert_matched_sni=*/chains_result.chains_match_sni);
+ return ok ? QUIC_SUCCESS : QUIC_FAILURE;
+ }
+
bool cert_matched_sni;
quiche::QuicheReferenceCountedPointer<ProofSource::Chain> chain =
delegate_->GetCertChain(server_address, client_address, hostname,
@@ -207,21 +222,37 @@
callback_->OnSelectCertificateDone(
/*ok=*/false,
/*is_sync=*/false,
- ProofSourceHandleCallback::LocalSSLConfig{nullptr, delayed_ssl_config_},
+ callback_->DoesOnSelectCertificateDoneExpectChains()
+ ? ProofSourceHandleCallback::LocalSSLConfig(
+ /*chains=*/{}, delayed_ssl_config_)
+ : ProofSourceHandleCallback::LocalSSLConfig(/*chain=*/nullptr,
+ delayed_ssl_config_),
/*ticket_encryption_key=*/absl::string_view(),
/*cert_matched_sni=*/false);
} else if (action_ == Action::DELEGATE_ASYNC) {
- bool cert_matched_sni;
- quiche::QuicheReferenceCountedPointer<ProofSource::Chain> chain =
- delegate_->GetCertChain(args_.server_address, args_.client_address,
- args_.hostname, &cert_matched_sni);
- bool ok = chain && !chain->certs.empty();
- callback_->OnSelectCertificateDone(
- ok, /*is_sync=*/false,
- ProofSourceHandleCallback::LocalSSLConfig{chain.get(),
- delayed_ssl_config_},
- /*ticket_encryption_key=*/absl::string_view(),
- /*cert_matched_sni=*/cert_matched_sni);
+ if (GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)) {
+ ProofSource::CertChainsResult chains_result = delegate_->GetCertChains(
+ args_.server_address, args_.client_address, args_.hostname);
+ const bool ok = !chains_result.chains.empty();
+ callback_->OnSelectCertificateDone(
+ ok, /*is_sync=*/false,
+ ProofSourceHandleCallback::LocalSSLConfig(
+ absl::MakeConstSpan(chains_result.chains), delayed_ssl_config_),
+ /*ticket_encryption_key=*/absl::string_view(),
+ /*cert_matched_sni=*/chains_result.chains_match_sni);
+ } else {
+ bool cert_matched_sni;
+ quiche::QuicheReferenceCountedPointer<ProofSource::Chain> chain =
+ delegate_->GetCertChain(args_.server_address, args_.client_address,
+ args_.hostname, &cert_matched_sni);
+ bool ok = chain && !chain->certs.empty();
+ callback_->OnSelectCertificateDone(
+ ok, /*is_sync=*/false,
+ ProofSourceHandleCallback::LocalSSLConfig{chain.get(),
+ delayed_ssl_config_},
+ /*ticket_encryption_key=*/absl::string_view(),
+ /*cert_matched_sni=*/cert_matched_sni);
+ }
} else {
QUIC_BUG(quic_bug_10139_1)
<< "Unexpected action: " << static_cast<int>(action_);
diff --git a/quiche/quic/test_tools/quic_test_utils.cc b/quiche/quic/test_tools/quic_test_utils.cc
index 3d1eefe..d47c05e 100644
--- a/quiche/quic/test_tools/quic_test_utils.cc
+++ b/quiche/quic/test_tools/quic_test_utils.cc
@@ -772,7 +772,7 @@
std::optional<QuicSSLConfig> ssl_config)
: QuicSpdyClientSessionBase(connection, nullptr, config,
supported_versions),
- ssl_config_(std::move(ssl_config)) {
+ ssl_config_(std::move(ssl_config).value_or(crypto_config->ssl_config())) {
// TODO(b/153726130): Consider adding SetServerApplicationStateForResumption
// calls in tests and set |has_application_state| to true.
crypto_stream_ = std::make_unique<QuicCryptoClientStream>(