Call GetCertChains() in TlsServerHandshaker

Protected by quic_reloadable_flag_quic_use_proof_source_get_cert_chains.

PiperOrigin-RevId: 823650043
diff --git a/quiche/common/quiche_feature_flags_list.h b/quiche/common/quiche_feature_flags_list.h
index de45c4a..54c174d 100755
--- a/quiche/common/quiche_feature_flags_list.h
+++ b/quiche/common/quiche_feature_flags_list.h
@@ -60,6 +60,7 @@
 QUICHE_FLAG(bool, quiche_reloadable_flag_quic_testonly_default_false, false, false, "A testonly reloadable flag that will always default to false.")
 QUICHE_FLAG(bool, quiche_reloadable_flag_quic_testonly_default_true, true, true, "A testonly reloadable flag that will always default to true.")
 QUICHE_FLAG(bool, quiche_reloadable_flag_quic_use_inlining_send_buffer2, true, true, "Uses an inlining version of QuicSendStreamBuffer.")
+QUICHE_FLAG(bool, quiche_reloadable_flag_quic_use_proof_source_get_cert_chains, false, false, "When true, quic::TlsServerHandshaker will use ProofSource::GetCertChains() instead of ProofSource::GetCertChain()")
 QUICHE_FLAG(bool, quiche_reloadable_flag_quic_use_received_client_addresses_cache, true, true, "If true, use a LRU cache to record client addresses of packets received on server's original address.")
 QUICHE_FLAG(bool, quiche_restart_flag_quic_support_release_time_for_gso, false, false, "If true, QuicGsoBatchWriter will support release time if it is available and the process has the permission to do so.")
 QUICHE_FLAG(bool, quiche_restart_flag_quic_testonly_default_false, false, false, "A testonly restart flag that will always default to false.")
diff --git a/quiche/quic/core/crypto/proof_source.h b/quiche/quic/core/crypto/proof_source.h
index 17d2f59..ac4db37 100644
--- a/quiche/quic/core/crypto/proof_source.h
+++ b/quiche/quic/core/crypto/proof_source.h
@@ -10,12 +10,14 @@
 #include <memory>
 #include <optional>
 #include <string>
+#include <utility>
 #include <variant>
 #include <vector>
 
 #include "absl/base/nullability.h"
 #include "absl/status/status.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/span.h"
 #include "openssl/base.h"
 #include "openssl/pool.h"
 #include "openssl/ssl.h"
@@ -291,11 +293,22 @@
   // Configuration to use for configuring the SSL object when handshaking
   // locally.
   struct LocalSSLConfig {
+    using ReferencedCountedChain =
+        quiche::QuicheReferenceCountedPointer<ProofSource::Chain>;
+
+    // TODO: b/451645567 - Remove this constructor.
     LocalSSLConfig(const ProofSource::Chain* absl_nullable chain,
                    QuicDelayedSSLConfig delayed_ssl_config)
         : chain(chain), delayed_ssl_config(delayed_ssl_config) {}
 
+    LocalSSLConfig(absl::Span<const ReferencedCountedChain> chains,
+                   QuicDelayedSSLConfig delayed_ssl_config)
+        : chains(std::move(chains)), delayed_ssl_config(delayed_ssl_config) {}
+
+    // TODO: b/451645567 - Once we remove `ProofSource::GetCertChain()`, we can
+    // delete the `chain` field.
     const ProofSource::Chain* absl_nullable chain = nullptr;
+    absl::Span<const ReferencedCountedChain absl_nonnull> chains;
     QuicDelayedSSLConfig delayed_ssl_config;
   };
 
@@ -331,11 +344,20 @@
   //
   // When called asynchronously(is_sync=false), this method will be responsible
   // to continue the handshake from where it left off.
+  //
+  // Callers that pass a `LocalSSLConfig` in `ssl_config` must use the result of
+  // `DoesOnSelectCertificateDoneExpectChains()` to decide which fields to
+  // populate.
   virtual void OnSelectCertificateDone(bool ok, bool is_sync,
                                        SSLConfig ssl_config,
                                        absl::string_view ticket_encryption_key,
                                        bool cert_matched_sni) = 0;
 
+  // Returns true when `OnSelectCertificateDone()` reads the
+  // `LocalSSLConfig::chains` field. Otherwise, it may read
+  // `LocalSSLConfig::chain`.
+  virtual bool DoesOnSelectCertificateDoneExpectChains() const = 0;
+
   // Called when a ProofSourceHandle::ComputeSignature operation completes.
   virtual void OnComputeSignatureDone(
       bool ok, bool is_sync, std::string signature,
diff --git a/quiche/quic/core/crypto/proof_source_x509.h b/quiche/quic/core/crypto/proof_source_x509.h
index 37b3246..4f6a706 100644
--- a/quiche/quic/core/crypto/proof_source_x509.h
+++ b/quiche/quic/core/crypto/proof_source_x509.h
@@ -34,6 +34,7 @@
                 QuicTransportVersion transport_version,
                 absl::string_view chlo_hash,
                 std::unique_ptr<Callback> callback) override;
+  // TODO: b/451645567 - Define `GetCertChains()` instead of `GetCertChain()`.
   quiche::QuicheReferenceCountedPointer<Chain> GetCertChain(
       const QuicSocketAddress& server_address,
       const QuicSocketAddress& client_address, const std::string& hostname,
diff --git a/quiche/quic/core/quic_crypto_client_handshaker_test.cc b/quiche/quic/core/quic_crypto_client_handshaker_test.cc
index 69530ac..fc42e01 100644
--- a/quiche/quic/core/quic_crypto_client_handshaker_test.cc
+++ b/quiche/quic/core/quic_crypto_client_handshaker_test.cc
@@ -84,6 +84,7 @@
     callback->Run(true, chain, proof, /*details=*/nullptr);
   }
 
+  // TODO: b/451645567 - Define `GetCertChains()` instead of `GetCertChain()`.
   quiche::QuicheReferenceCountedPointer<Chain> GetCertChain(
       const QuicSocketAddress& /*server_address*/,
       const QuicSocketAddress& /*client_address*/,
diff --git a/quiche/quic/core/quic_types.h b/quiche/quic/core/quic_types.h
index f621307..8b554d0 100644
--- a/quiche/quic/core/quic_types.h
+++ b/quiche/quic/core/quic_types.h
@@ -898,6 +898,8 @@
   std::optional<ClientCertMode> client_cert_mode;
   // QUIC transport parameters as serialized by ProofSourceHandle.
   std::optional<std::vector<uint8_t>> quic_transport_parameters;
+
+  bool operator==(const QuicDelayedSSLConfig& other) const = default;
 };
 
 // ParsedClientHello contains client hello information extracted from a fully
diff --git a/quiche/quic/core/tls_server_handshaker.cc b/quiche/quic/core/tls_server_handshaker.cc
index e8ebdde..fc9cc4c 100644
--- a/quiche/quic/core/tls_server_handshaker.cc
+++ b/quiche/quic/core/tls_server_handshaker.cc
@@ -14,6 +14,7 @@
 #include <variant>
 #include <vector>
 
+#include "absl/algorithm/container.h"
 #include "absl/base/nullability.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
@@ -113,6 +114,26 @@
     return QUIC_FAILURE;
   }
 
+  if (handshaker_->DoesOnSelectCertificateDoneExpectChains()) {
+    QUIC_RELOADABLE_FLAG_COUNT_N(quic_use_proof_source_get_cert_chains, 1, 2);
+
+    ProofSource::CertChainsResult cert_chains_result =
+        proof_source_->GetCertChains(server_address, client_address, hostname);
+
+    handshaker_->OnSelectCertificateDone(
+        /*ok=*/true, /*is_sync=*/true,
+        ProofSourceHandleCallback::LocalSSLConfig(cert_chains_result.chains,
+                                                  QuicDelayedSSLConfig()),
+        /*ticket_encryption_key=*/absl::string_view(),
+        /*cert_matched_sni=*/cert_chains_result.chains_match_sni);
+    if (!handshaker_->select_cert_status().has_value()) {
+      QUIC_BUG(select_cert_status_valueless_after_sync_select_cert);
+      // Return success to continue the handshake.
+      return QUIC_SUCCESS;
+    }
+    return *handshaker_->select_cert_status();
+  }
+
   bool cert_matched_sni;
   quiche::QuicheReferenceCountedPointer<ProofSource::Chain> chain =
       proof_source_->GetCertChain(server_address, client_address, hostname,
@@ -221,6 +242,8 @@
       QuicCryptoServerStreamBase(session),
       proof_source_(crypto_config->proof_source()),
       proof_verifier_(crypto_config->proof_verifier()),
+      use_proof_source_get_cert_chains_(
+          GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)),
       pre_shared_key_(crypto_config->pre_shared_key()),
       crypto_negotiated_params_(new QuicCryptoNegotiatedParameters),
       tls_connection_(crypto_config->ssl_ctx(), this, session->GetSSLConfig()),
@@ -1119,11 +1142,31 @@
   if (ok) {
     if (auto* local_config = std::get_if<LocalSSLConfig>(&ssl_config);
         local_config != nullptr) {
-      if (local_config->chain && !local_config->chain->certs.empty()) {
+      if (!use_proof_source_get_cert_chains_ && local_config->chain &&
+          !local_config->chain->certs.empty()) {
         tls_connection_.AddCertChain(
             local_config->chain->ToCryptoBuffers().value,
             local_config->chain->trust_anchor_id);
         select_cert_status_ = QUIC_SUCCESS;
+      } else if (use_proof_source_get_cert_chains_ &&
+                 !local_config->chains.empty() &&
+                 // Cert selection fails when there are no chains with certs.
+                 absl::c_any_of(local_config->chains,
+                                [](const quiche::QuicheReferenceCountedPointer<
+                                    ProofSource::Chain> absl_nonnull& chain) {
+                                  return !chain->certs.empty();
+                                })) {
+        QUIC_RELOADABLE_FLAG_COUNT_N(quic_use_proof_source_get_cert_chains, 2,
+                                     2);
+        for (const quiche::QuicheReferenceCountedPointer<
+                 ProofSource::Chain> absl_nonnull& chain :
+             local_config->chains) {
+          if (!chain->certs.empty()) {
+            tls_connection_.AddCertChain(chain->ToCryptoBuffers().value,
+                                         chain->trust_anchor_id);
+          }
+        }
+        select_cert_status_ = QUIC_SUCCESS;
       } else {
         QUIC_DLOG(ERROR) << "No certs provided for host '"
                          << crypto_negotiated_params_->sni
diff --git a/quiche/quic/core/tls_server_handshaker.h b/quiche/quic/core/tls_server_handshaker.h
index 806fd20..f26be04 100644
--- a/quiche/quic/core/tls_server_handshaker.h
+++ b/quiche/quic/core/tls_server_handshaker.h
@@ -205,6 +205,9 @@
   void OnSelectCertificateDone(bool ok, bool is_sync, SSLConfig ssl_config,
                                absl::string_view ticket_encryption_key,
                                bool cert_matched_sni) override;
+  bool DoesOnSelectCertificateDoneExpectChains() const override {
+    return use_proof_source_get_cert_chains_;
+  }
 
   void OnComputeSignatureDone(
       bool ok, bool is_sync, std::string signature,
@@ -250,8 +253,11 @@
     // Close the handle. Cancel the pending signature operation, if any.
     void CloseHandle() override;
 
-    // Delegates to proof_source_->GetCertChain.
-    // Returns QUIC_SUCCESS or QUIC_FAILURE. Never returns QUIC_PENDING.
+    // Delegates to `proof_source_->GetCertChains()` when
+    // `handshaker_->use_proof_source_get_cert_chains()` is true. Otherwise,
+    // delegates to `proof_source_->GetCertChain()`.
+    //
+    // Returns `QUIC_SUCCESS` or `QUIC_FAILURE`. Never returns `QUIC_PENDING`.
     QuicAsyncStatus SelectCertificate(
         const QuicSocketAddress& server_address,
         const QuicSocketAddress& client_address,
@@ -383,6 +389,10 @@
   // True if new ALPS codepoint in the ClientHello.
   bool alps_new_codepoint_received_ = false;
 
+  // The value of the reloadable flag `quic_use_proof_source_get_cert_chains` at
+  // the time of construction.
+  const bool use_proof_source_get_cert_chains_;
+
   // nullopt means select cert hasn't started.
   std::optional<QuicAsyncStatus> select_cert_status_;
 
diff --git a/quiche/quic/core/tls_server_handshaker_test.cc b/quiche/quic/core/tls_server_handshaker_test.cc
index be182a8..4a22af7 100644
--- a/quiche/quic/core/tls_server_handshaker_test.cc
+++ b/quiche/quic/core/tls_server_handshaker_test.cc
@@ -43,6 +43,7 @@
 #include "quiche/quic/test_tools/quic_config_peer.h"
 #include "quiche/quic/test_tools/quic_test_utils.h"
 #include "quiche/quic/test_tools/simple_session_cache.h"
+#include "quiche/quic/test_tools/test_certificates.h"
 #include "quiche/quic/test_tools/test_ticket_crypter.h"
 #include "quiche/common/platform/api/quiche_logging.h"
 #include "quiche/common/platform/api/quiche_reference_counted.h"
@@ -53,9 +54,15 @@
 }  // namespace quic
 
 using testing::_;
+using testing::AllOf;
+using testing::ElementsAre;
+using testing::Eq;
+using testing::Field;
 using testing::HasSubstr;
+using testing::IsNull;
 using testing::NiceMock;
 using testing::Return;
+using testing::VariantWith;
 
 namespace quic {
 namespace test {
@@ -68,13 +75,15 @@
 struct TestParams {
   ParsedQuicVersion version;
   bool disable_resumption;
+  bool enable_get_cert_chains;
 };
 
 ABSL_ATTRIBUTE_UNUSED  // Used by ::testing::PrintToStringParamName().
     std::string PrintToString(const TestParams& p) {
-  return absl::StrCat(
-      ParsedQuicVersionToString(p.version), "_",
-      (p.disable_resumption ? "ResumptionDisabled" : "ResumptionEnabled"));
+  return absl::StrCat(ParsedQuicVersionToString(p.version), "AndResumption",
+                      (p.disable_resumption ? "Disabled" : "Enabled"),
+                      "AndGetCertChains",
+                      (p.enable_get_cert_chains ? "Enabled" : "Disabled"));
 }
 
 // Constructs test permutations.
@@ -82,7 +91,10 @@
   std::vector<TestParams> params;
   for (const auto& version : AllSupportedVersionsWithTls()) {
     for (bool disable_resumption : {false, true}) {
-      params.push_back(TestParams{version, disable_resumption});
+      for (bool enable_get_cert_chains : {false, true}) {
+        params.push_back(
+            TestParams{version, disable_resumption, enable_get_cert_chains});
+      }
     }
   }
   return params;
@@ -131,6 +143,10 @@
     ON_CALL(*this, OverrideQuicConfigDefaults(_))
         .WillByDefault(testing::Invoke(
             this, &TestTlsServerHandshaker::RealOverrideQuicConfigDefaults));
+
+    ON_CALL(*this, OnSelectCertificateDone)
+        .WillByDefault(testing::Invoke(
+            this, &TestTlsServerHandshaker::RealOnSelectCertificateDone));
   }
 
   MOCK_METHOD(std::unique_ptr<ProofSourceHandle>, MaybeCreateProofSourceHandle,
@@ -139,6 +155,11 @@
   MOCK_METHOD(void, OverrideQuicConfigDefaults, (QuicConfig * config),
               (override));
 
+  MOCK_METHOD(void, OnSelectCertificateDone,
+              (bool, bool, SSLConfig, absl::string_view, bool), (override));
+
+  // Makes the next call to `MaybeCreateProofSourceHandle()` return a
+  // `FakeProofSourceHandle` instead of a real `ProofSourceHandle`.
   void SetupProofSourceHandle(
       FakeProofSourceHandle::Action select_cert_action,
       FakeProofSourceHandle::Action compute_signature_action,
@@ -192,6 +213,14 @@
     return TlsServerHandshaker::OverrideQuicConfigDefaults(config);
   }
 
+  void RealOnSelectCertificateDone(bool ok, bool is_sync, SSLConfig ssl_config,
+                                   absl::string_view ticket_encryption_key,
+                                   bool cert_matched_sni) {
+    return TlsServerHandshaker::OnSelectCertificateDone(
+        ok, is_sync, std::move(ssl_config), ticket_encryption_key,
+        cert_matched_sni);
+  }
+
   // Owned by TlsServerHandshaker.
   FakeProofSourceHandle* fake_proof_source_handle_ = nullptr;
   ProofSource* proof_source_ = nullptr;
@@ -226,6 +255,9 @@
         supported_versions_({GetParam().version}) {
     SetQuicFlag(quic_disable_server_tls_resumption,
                 GetParam().disable_resumption);
+    SetQuicReloadableFlag(quic_use_proof_source_get_cert_chains,
+                          GetParam().enable_get_cert_chains);
+
     client_crypto_config_ = std::make_unique<QuicCryptoClientConfig>(
         crypto_test_utils::ProofVerifierForTesting(),
         std::make_unique<test::SimpleSessionCache>());
@@ -725,6 +757,307 @@
   EXPECT_FALSE(last_select_cert_args().ssl_config.early_data_enabled);
 }
 
+TEST_P(TlsServerHandshakerTest, SelectCertificateCallsGetCertChains) {
+  if (GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)) {
+    EXPECT_CALL(*proof_source_, GetCertChains(_, _, kServerHostname));
+    // Although the default implementation of `ProofSource::GetCertChains()`
+    // uses `ProofSource::GetCertChain()`, we're using a `FakeProofSource` that
+    // delegates to a `TestProofSource`, which does not have this behavior.
+    EXPECT_CALL(*proof_source_, GetCertChain).Times(0);
+  } else {
+    EXPECT_CALL(*proof_source_, GetCertChains).Times(0);
+    EXPECT_CALL(*proof_source_, GetCertChain(_, _, kServerHostname, _));
+  }
+
+  CompleteCryptoHandshake();
+  ExpectHandshakeSuccessful();
+}
+
+TEST_P(TlsServerHandshakerTest, ZeroCertChainsFailsHandshake) {
+  if (GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)) {
+    EXPECT_CALL(*proof_source_, GetCertChains(_, _, kServerHostname))
+        .WillOnce(Return(ProofSource::CertChainsResult{
+            .chains_match_sni = false,
+            .chains = {},
+        }));
+    EXPECT_CALL(*proof_source_, GetCertChain).Times(0);
+  } else {
+    EXPECT_CALL(*proof_source_, GetCertChain)
+        .WillOnce([&](const QuicSocketAddress&, const QuicSocketAddress&,
+                      const std::string&, bool* cert_matched_sni)
+                      -> quiche::QuicheReferenceCountedPointer<
+                          ProofSource::Chain> {
+          *cert_matched_sni = false;
+          return  // quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+              nullptr;
+        });
+    EXPECT_CALL(*proof_source_, GetCertChains).Times(0);
+  }
+
+  AdvanceHandshakeWithFakeClient();
+
+  EXPECT_FALSE(client_stream()->one_rtt_keys_available());
+  EXPECT_FALSE(client_stream()->encryption_established());
+  EXPECT_FALSE(server_stream()->one_rtt_keys_available());
+  EXPECT_FALSE(server_stream()->encryption_established());
+}
+
+// Test that `DefaultProofSourceHandle::SelectCertificate()` uses chains that do
+// not match the SNI when that's all that's available.
+TEST_P(TlsServerHandshakerTest, MultipleCertChainsNotMatchingSni) {
+  InitializeServerWithFakeProofSourceHandle();
+  ASSERT_TRUE(server_handshaker_ != nullptr);
+
+  // Note that if we called `server_handshaker_->SetupProofSourceHandle()` here,
+  // this test would incorrectly exercise `FakeProofSourceHandle` instead of
+  // `DefaultProofSourceHandle`.
+
+  if (GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)) {
+    // Return chains that claim not to match SNI. (Whether the certificate in
+    // `quic::test::kTestCertificate` actually matches `kServerHostname` is
+    // irrelevant.)
+    EXPECT_CALL(*proof_source_, GetCertChains(_, _, kServerHostname))
+        .WillOnce(Return(ProofSource::CertChainsResult{
+            .chains_match_sni = false,
+            .chains =
+                std::vector<
+                    quiche::QuicheReferenceCountedPointer<ProofSource::Chain>>{
+                    quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+                        new ProofSource::Chain(
+                            /*certs=*/{std::string(
+                                quic::test::kTestCertificate)},
+                            /*trust_anchor_id=*/"")),
+                    quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+                        new ProofSource::Chain(
+                            /*certs=*/{std::string(
+                                quic::test::kTestCertificate)},
+                            /*trust_anchor_id=*/"")),
+                },
+        }));
+
+    EXPECT_CALL(*proof_source_, GetCertChain).Times(0);
+
+    EXPECT_CALL(*server_handshaker_, OnSelectCertificateDone(
+                                         /*ok=*/true, /*is_sync=*/_,
+                                         /*ssl_config=*/_,
+                                         /*ticket_encryption_key=*/_,
+                                         /*cert_matched_sni=*/false));
+  } else {
+    EXPECT_CALL(*proof_source_, GetCertChains).Times(0);
+    EXPECT_CALL(*proof_source_, GetCertChain(_, _, kServerHostname, _));
+    EXPECT_CALL(*server_handshaker_, OnSelectCertificateDone);
+  }
+
+  CompleteCryptoHandshake();
+}
+
+// Test that `DefaultProofSourceHandle::SelectCertificate()` passes along chains
+// with and without Trust Anchor IDs .
+TEST_P(TlsServerHandshakerTest, SelectCertificateChainMatchesSni) {
+  // Set the client-advertised Trust Anchor IDs. Each ID is preceded by an 8-bit
+  // length prefix.
+  client_crypto_config_->ssl_config().trust_anchor_ids = "\x03\x11\x22\x33";
+  InitializeFakeClient();
+
+  InitializeServerWithFakeProofSourceHandle();
+  ASSERT_TRUE(server_handshaker_ != nullptr);
+
+  // Note that if we called `server_handshaker_->SetupProofSourceHandle()` here,
+  // this test would incorrectly exercise `FakeProofSourceHandle` instead of
+  // `DefaultProofSourceHandle`.
+
+  if (GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)) {
+    using LocalSSLConfig = ProofSourceHandleCallback::LocalSSLConfig;
+
+    std::vector<quiche::QuicheReferenceCountedPointer<ProofSource::Chain>>
+        chains = {
+            quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+                new ProofSource::Chain(
+                    /*certs=*/{std::string(quic::test::kTestCertificate)},
+                    /*trust_anchor_id=*/"\x11\x22\x33")),
+            quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+                new ProofSource::Chain(
+                    /*certs=*/{std::string(quic::test::kTestCertificate)},
+                    /*trust_anchor_id=*/"")),
+        };
+    EXPECT_CALL(*proof_source_, GetCertChains(_, _, kServerHostname))
+        .WillOnce(Return(ProofSource::CertChainsResult{
+            .chains_match_sni = true,
+            .chains = chains,
+        }));
+    EXPECT_CALL(*proof_source_, GetCertChain).Times(0);
+    // `DefaultProofSourceHandle::SelectCertificate()` should pick only the
+    // chains that claim to match SNI.
+    EXPECT_CALL(*server_handshaker_,
+                OnSelectCertificateDone(
+                    /*ok=*/true, /*is_sync=*/_,
+                    VariantWith<LocalSSLConfig>(
+                        AllOf(Field(&LocalSSLConfig::chain, IsNull()),
+                              Field(&LocalSSLConfig::chains,
+                                    ElementsAre(chains[0], chains[1])),
+                              Field(&LocalSSLConfig::delayed_ssl_config,
+                                    Eq(QuicDelayedSSLConfig())))),
+                    /*ticket_encryption_key=*/_,
+                    /*cert_matched_sni=*/true));
+  } else {
+    EXPECT_CALL(*proof_source_, GetCertChains).Times(0);
+    EXPECT_CALL(*proof_source_, GetCertChain(_, _, kServerHostname, _));
+    EXPECT_CALL(*server_handshaker_, OnSelectCertificateDone(
+                                         /*ok=*/true, /*is_sync=*/_,
+                                         /*ssl_config=*/_,
+                                         /*ticket_encryption_key=*/_,
+                                         /*cert_matched_sni=*/false));
+  }
+
+  CompleteCryptoHandshake();
+  ExpectHandshakeSuccessful();
+
+  EXPECT_EQ(client_stream()->MatchedTrustAnchorIdForTesting(),
+            GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains));
+}
+
+// Test that `DefaultProofSourceHandle::SelectCertificate()` passes through
+// chains with Trust Anchor IDs when that's all that's available. The handshake
+// will succeed because one of the chains matches the client-advertised Trust
+// Anchor IDs.
+TEST_P(TlsServerHandshakerTest,
+       SelectCertificateNotFirstChainMatchesSniAndTrustAnchorId) {
+  // Set the client-advertised Trust Anchor IDs. Each ID is preceded by an 8-bit
+  // length prefix.
+  client_crypto_config_->ssl_config().trust_anchor_ids = "\x03\x11\x22\x33";
+  InitializeFakeClient();
+
+  InitializeServerWithFakeProofSourceHandle();
+  ASSERT_TRUE(server_handshaker_ != nullptr);
+
+  // Note that if we called `server_handshaker_->SetupProofSourceHandle()` here,
+  // this test would incorrectly exercise `FakeProofSourceHandle` instead of
+  // `DefaultProofSourceHandle`.
+
+  if (GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)) {
+    using LocalSSLConfig = ProofSourceHandleCallback::LocalSSLConfig;
+
+    std::vector<quiche::QuicheReferenceCountedPointer<ProofSource::Chain>>
+        chains = {
+            quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+                new ProofSource::Chain(
+                    /*certs=*/{std::string(quic::test::kTestCertificate)},
+                    /*trust_anchor_id=*/"\x07\x08")),
+            quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+                new ProofSource::Chain(
+                    /*certs=*/{std::string(quic::test::kTestCertificate)},
+                    /*trust_anchor_id=*/"\x11\x22\x33")),
+        };
+    EXPECT_CALL(*proof_source_, GetCertChains(_, _, kServerHostname))
+        .WillOnce(Return(ProofSource::CertChainsResult{
+            .chains_match_sni = true,
+            .chains = chains,
+        }));
+    EXPECT_CALL(*proof_source_, GetCertChain).Times(0);
+    // `DefaultProofSourceHandle::SelectCertificate()` should pick only the
+    // chains that claim to match SNI.
+    EXPECT_CALL(*server_handshaker_,
+                OnSelectCertificateDone(
+                    /*ok=*/true, /*is_sync=*/_,
+                    VariantWith<LocalSSLConfig>(
+                        AllOf(Field(&LocalSSLConfig::chain, IsNull()),
+                              Field(&LocalSSLConfig::chains,
+                                    ElementsAre(chains[0], chains[1])),
+                              Field(&LocalSSLConfig::delayed_ssl_config,
+                                    Eq(QuicDelayedSSLConfig())))),
+                    /*ticket_encryption_key=*/_,
+                    /*cert_matched_sni=*/true));
+  } else {
+    EXPECT_CALL(*proof_source_, GetCertChains).Times(0);
+    EXPECT_CALL(*proof_source_, GetCertChain(_, _, kServerHostname, _));
+    EXPECT_CALL(*server_handshaker_, OnSelectCertificateDone(
+                                         /*ok=*/true, /*is_sync=*/_,
+                                         /*ssl_config=*/_,
+                                         /*ticket_encryption_key=*/_,
+                                         /*cert_matched_sni=*/false));
+  }
+
+  CompleteCryptoHandshake();
+  ExpectHandshakeSuccessful();
+
+  EXPECT_EQ(client_stream()->MatchedTrustAnchorIdForTesting(),
+            GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains));
+}
+
+// Test that the handshake fails when all chains returned by
+// `ProofSource::GetCertChains()` have Trust Anchor IDs, but none match the
+// client-advertised IDs. This pitfall can be avoided by returning at least one
+// (complete) certificate chain that is not associated with a Trust Anchor ID.
+TEST_P(TlsServerHandshakerTest,
+       SelectCertificateFailsWhenAllChainsHaveMismatchedTrustAnchorIds) {
+  // Set the client-advertised Trust Anchor IDs. Each ID is preceded by an 8-bit
+  // length prefix.
+  client_crypto_config_->ssl_config().trust_anchor_ids = "\x03\x11\x22\x33";
+  InitializeFakeClient();
+
+  InitializeServerWithFakeProofSourceHandle();
+  ASSERT_TRUE(server_handshaker_ != nullptr);
+
+  // Note that if we called `server_handshaker_->SetupProofSourceHandle()` here,
+  // this test would incorrectly exercise `FakeProofSourceHandle` instead of
+  // `DefaultProofSourceHandle`.
+
+  if (GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)) {
+    using LocalSSLConfig = ProofSourceHandleCallback::LocalSSLConfig;
+
+    std::vector<quiche::QuicheReferenceCountedPointer<ProofSource::Chain>>
+        chains = {
+            quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+                new ProofSource::Chain(
+                    /*certs=*/{std::string(quic::test::kTestCertificate)},
+                    /*trust_anchor_id=*/"\x07\x08")),
+            quiche::QuicheReferenceCountedPointer<ProofSource::Chain>(
+                new ProofSource::Chain(
+                    /*certs=*/{std::string(quic::test::kTestCertificate)},
+                    /*trust_anchor_id=*/"\x42")),
+        };
+    EXPECT_CALL(*proof_source_, GetCertChains(_, _, kServerHostname))
+        .WillOnce(Return(ProofSource::CertChainsResult{
+            .chains_match_sni = true,
+            .chains = chains,
+        }));
+    EXPECT_CALL(*proof_source_, GetCertChain).Times(0);
+    // `DefaultProofSourceHandle::SelectCertificate()` should pick only the
+    // chains that claim to match SNI.
+    EXPECT_CALL(*server_handshaker_,
+                OnSelectCertificateDone(
+                    /*ok=*/true, /*is_sync=*/_,
+                    VariantWith<LocalSSLConfig>(
+                        AllOf(Field(&LocalSSLConfig::chain, IsNull()),
+                              Field(&LocalSSLConfig::chains,
+                                    ElementsAre(chains[0], chains[1])),
+                              Field(&LocalSSLConfig::delayed_ssl_config,
+                                    Eq(QuicDelayedSSLConfig())))),
+                    /*ticket_encryption_key=*/_,
+                    /*cert_matched_sni=*/true));
+    AdvanceHandshakeWithFakeClient();
+
+    EXPECT_FALSE(client_stream()->one_rtt_keys_available());
+    EXPECT_FALSE(client_stream()->encryption_established());
+    EXPECT_FALSE(server_stream()->one_rtt_keys_available());
+    EXPECT_FALSE(server_stream()->encryption_established());
+
+  } else {
+    EXPECT_CALL(*proof_source_, GetCertChains).Times(0);
+    EXPECT_CALL(*proof_source_, GetCertChain(_, _, kServerHostname, _));
+    EXPECT_CALL(*server_handshaker_, OnSelectCertificateDone(
+                                         /*ok=*/true, /*is_sync=*/_,
+                                         /*ssl_config=*/_,
+                                         /*ticket_encryption_key=*/_,
+                                         /*cert_matched_sni=*/false));
+
+    CompleteCryptoHandshake();
+    ExpectHandshakeSuccessful();
+  }
+
+  EXPECT_FALSE(client_stream()->MatchedTrustAnchorIdForTesting());
+}
+
 TEST_P(TlsServerHandshakerTest, ConnectionClosedOnTlsError) {
   EXPECT_CALL(*server_connection_,
               CloseConnection(QUIC_HANDSHAKE_FAILED, _, _, _));
diff --git a/quiche/quic/test_tools/fake_proof_source.cc b/quiche/quic/test_tools/fake_proof_source.cc
index 5ee9032..22262ba 100644
--- a/quiche/quic/test_tools/fake_proof_source.cc
+++ b/quiche/quic/test_tools/fake_proof_source.cc
@@ -4,19 +4,32 @@
 
 #include "quiche/quic/test_tools/fake_proof_source.h"
 
+#include <cstdint>
 #include <memory>
 #include <string>
 #include <utility>
 
 #include "absl/strings/string_view.h"
+#include "quiche/quic/core/quic_types.h"
+#include "quiche/quic/core/quic_versions.h"
 #include "quiche/quic/platform/api/quic_logging.h"
+#include "quiche/quic/platform/api/quic_socket_address.h"
+#include "quiche/quic/platform/api/quic_test.h"
 #include "quiche/quic/test_tools/crypto_test_utils.h"
+#include "quiche/common/platform/api/quiche_logging.h"
 
 namespace quic {
 namespace test {
 
-FakeProofSource::FakeProofSource()
-    : delegate_(crypto_test_utils::ProofSourceForTesting()) {}
+FakeProofSource::FakeProofSource(const std::string& trust_anchor_id)
+    : delegate_(crypto_test_utils::ProofSourceForTesting(trust_anchor_id)) {
+  ON_CALL(*this, GetCertChain)
+      .WillByDefault(
+          testing::Invoke(delegate_.get(), &ProofSource::GetCertChain));
+  ON_CALL(*this, GetCertChains)
+      .WillByDefault(
+          testing::Invoke(delegate_.get(), &ProofSource::GetCertChains));
+}
 
 FakeProofSource::~FakeProofSource() {}
 
@@ -87,15 +100,6 @@
       delegate_.get()));
 }
 
-quiche::QuicheReferenceCountedPointer<ProofSource::Chain>
-FakeProofSource::GetCertChain(const QuicSocketAddress& server_address,
-                              const QuicSocketAddress& client_address,
-                              const std::string& hostname,
-                              bool* cert_matched_sni) {
-  return delegate_->GetCertChain(server_address, client_address, hostname,
-                                 cert_matched_sni);
-}
-
 void FakeProofSource::ComputeTlsSignature(
     const QuicSocketAddress& server_address,
     const QuicSocketAddress& client_address, const std::string& hostname,
diff --git a/quiche/quic/test_tools/fake_proof_source.h b/quiche/quic/test_tools/fake_proof_source.h
index d5e9837..de0b79e 100644
--- a/quiche/quic/test_tools/fake_proof_source.h
+++ b/quiche/quic/test_tools/fake_proof_source.h
@@ -5,12 +5,18 @@
 #ifndef QUICHE_QUIC_TEST_TOOLS_FAKE_PROOF_SOURCE_H_
 #define QUICHE_QUIC_TEST_TOOLS_FAKE_PROOF_SOURCE_H_
 
+#include <cstdint>
 #include <memory>
 #include <string>
 #include <vector>
 
 #include "absl/strings/string_view.h"
 #include "quiche/quic/core/crypto/proof_source.h"
+#include "quiche/quic/core/quic_types.h"
+#include "quiche/quic/core/quic_versions.h"
+#include "quiche/quic/platform/api/quic_socket_address.h"
+#include "quiche/quic/platform/api/quic_test.h"
+#include "quiche/common/platform/api/quiche_reference_counted.h"
 
 namespace quic {
 namespace test {
@@ -24,7 +30,7 @@
 // FakeProofSource::GetTicketCrypter.
 class FakeProofSource : public ProofSource {
  public:
-  FakeProofSource();
+  explicit FakeProofSource(const std::string& trust_anchor_id = "");
   ~FakeProofSource() override;
 
   // Before this object is "active", all calls to GetProof will be delegated
@@ -40,10 +46,16 @@
                 QuicTransportVersion transport_version,
                 absl::string_view chlo_hash,
                 std::unique_ptr<ProofSource::Callback> callback) override;
-  quiche::QuicheReferenceCountedPointer<Chain> GetCertChain(
-      const QuicSocketAddress& server_address,
-      const QuicSocketAddress& client_address, const std::string& hostname,
-      bool* cert_matched_sni) override;
+  MOCK_METHOD(quiche::QuicheReferenceCountedPointer<Chain>, GetCertChain,
+              (const QuicSocketAddress& server_address,
+               const QuicSocketAddress& client_address,
+               const std::string& hostname, bool* cert_matched_sni),
+              (override));
+  MOCK_METHOD(ProofSource::CertChainsResult, GetCertChains,
+              (const QuicSocketAddress& server_address,
+               const QuicSocketAddress& client_address,
+               const std::string& hostname),
+              (override));
   void ComputeTlsSignature(
       const QuicSocketAddress& server_address,
       const QuicSocketAddress& client_address, const std::string& hostname,
diff --git a/quiche/quic/test_tools/fake_proof_source_handle.cc b/quiche/quic/test_tools/fake_proof_source_handle.cc
index 3ae09f5..0d99eaf 100644
--- a/quiche/quic/test_tools/fake_proof_source_handle.cc
+++ b/quiche/quic/test_tools/fake_proof_source_handle.cc
@@ -14,11 +14,13 @@
 
 #include "absl/base/nullability.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/span.h"
 #include "openssl/base.h"
 #include "quiche/quic/core/crypto/proof_source.h"
 #include "quiche/quic/core/quic_connection_id.h"
 #include "quiche/quic/core/quic_types.h"
 #include "quiche/quic/platform/api/quic_bug_tracker.h"
+#include "quiche/quic/platform/api/quic_flags.h"
 #include "quiche/quic/platform/api/quic_socket_address.h"
 #include "quiche/common/platform/api/quiche_logging.h"
 #include "quiche/common/platform/api/quiche_reference_counted.h"
@@ -119,6 +121,19 @@
   }
 
   QUICHE_DCHECK(select_cert_action_ == Action::DELEGATE_SYNC);
+  if (GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)) {
+    ProofSource::CertChainsResult chains_result =
+        delegate_->GetCertChains(server_address, client_address, hostname);
+    const bool ok = !chains_result.chains.empty();
+    callback_->OnSelectCertificateDone(
+        ok, /*is_sync=*/true,
+        ProofSourceHandleCallback::LocalSSLConfig(
+            absl::MakeConstSpan(chains_result.chains), delayed_ssl_config_),
+        /*ticket_encryption_key=*/absl::string_view(),
+        /*cert_matched_sni=*/chains_result.chains_match_sni);
+    return ok ? QUIC_SUCCESS : QUIC_FAILURE;
+  }
+
   bool cert_matched_sni;
   quiche::QuicheReferenceCountedPointer<ProofSource::Chain> chain =
       delegate_->GetCertChain(server_address, client_address, hostname,
@@ -207,21 +222,37 @@
     callback_->OnSelectCertificateDone(
         /*ok=*/false,
         /*is_sync=*/false,
-        ProofSourceHandleCallback::LocalSSLConfig{nullptr, delayed_ssl_config_},
+        callback_->DoesOnSelectCertificateDoneExpectChains()
+            ? ProofSourceHandleCallback::LocalSSLConfig(
+                  /*chains=*/{}, delayed_ssl_config_)
+            : ProofSourceHandleCallback::LocalSSLConfig(/*chain=*/nullptr,
+                                                        delayed_ssl_config_),
         /*ticket_encryption_key=*/absl::string_view(),
         /*cert_matched_sni=*/false);
   } else if (action_ == Action::DELEGATE_ASYNC) {
-    bool cert_matched_sni;
-    quiche::QuicheReferenceCountedPointer<ProofSource::Chain> chain =
-        delegate_->GetCertChain(args_.server_address, args_.client_address,
-                                args_.hostname, &cert_matched_sni);
-    bool ok = chain && !chain->certs.empty();
-    callback_->OnSelectCertificateDone(
-        ok, /*is_sync=*/false,
-        ProofSourceHandleCallback::LocalSSLConfig{chain.get(),
-                                                  delayed_ssl_config_},
-        /*ticket_encryption_key=*/absl::string_view(),
-        /*cert_matched_sni=*/cert_matched_sni);
+    if (GetQuicReloadableFlag(quic_use_proof_source_get_cert_chains)) {
+      ProofSource::CertChainsResult chains_result = delegate_->GetCertChains(
+          args_.server_address, args_.client_address, args_.hostname);
+      const bool ok = !chains_result.chains.empty();
+      callback_->OnSelectCertificateDone(
+          ok, /*is_sync=*/false,
+          ProofSourceHandleCallback::LocalSSLConfig(
+              absl::MakeConstSpan(chains_result.chains), delayed_ssl_config_),
+          /*ticket_encryption_key=*/absl::string_view(),
+          /*cert_matched_sni=*/chains_result.chains_match_sni);
+    } else {
+      bool cert_matched_sni;
+      quiche::QuicheReferenceCountedPointer<ProofSource::Chain> chain =
+          delegate_->GetCertChain(args_.server_address, args_.client_address,
+                                  args_.hostname, &cert_matched_sni);
+      bool ok = chain && !chain->certs.empty();
+      callback_->OnSelectCertificateDone(
+          ok, /*is_sync=*/false,
+          ProofSourceHandleCallback::LocalSSLConfig{chain.get(),
+                                                    delayed_ssl_config_},
+          /*ticket_encryption_key=*/absl::string_view(),
+          /*cert_matched_sni=*/cert_matched_sni);
+    }
   } else {
     QUIC_BUG(quic_bug_10139_1)
         << "Unexpected action: " << static_cast<int>(action_);
diff --git a/quiche/quic/test_tools/quic_test_utils.cc b/quiche/quic/test_tools/quic_test_utils.cc
index 3d1eefe..d47c05e 100644
--- a/quiche/quic/test_tools/quic_test_utils.cc
+++ b/quiche/quic/test_tools/quic_test_utils.cc
@@ -772,7 +772,7 @@
     std::optional<QuicSSLConfig> ssl_config)
     : QuicSpdyClientSessionBase(connection, nullptr, config,
                                 supported_versions),
-      ssl_config_(std::move(ssl_config)) {
+      ssl_config_(std::move(ssl_config).value_or(crypto_config->ssl_config())) {
   // TODO(b/153726130): Consider adding SetServerApplicationStateForResumption
   // calls in tests and set |has_application_state| to true.
   crypto_stream_ = std::make_unique<QuicCryptoClientStream>(