Internal change

PiperOrigin-RevId: 395525828
diff --git a/quic/core/crypto/proof_source.h b/quic/core/crypto/proof_source.h
index b0a2446..c4a1f2c 100644
--- a/quic/core/crypto/proof_source.h
+++ b/quic/core/crypto/proof_source.h
@@ -145,10 +145,13 @@
                         std::unique_ptr<Callback> callback) = 0;
 
   // Returns the certificate chain for |hostname| in leaf-first order.
+  //
+  // Sets *cert_matched_sni to true if the certificate matched the given
+  // hostname, false if a default cert not matching the hostname was used.
   virtual QuicReferenceCountedPointer<Chain> GetCertChain(
       const QuicSocketAddress& server_address,
-      const QuicSocketAddress& client_address,
-      const std::string& hostname) = 0;
+      const QuicSocketAddress& client_address, const std::string& hostname,
+      bool* cert_matched_sni) = 0;
 
   // Computes a signature using the private key of the certificate for
   // |hostname|. The value in |in| is signed using the algorithm specified by
@@ -247,13 +250,16 @@
   //      SSL_set_handshake_hints.
   // |ticket_encryption_key| (optional) encryption key to be used for minting
   //      TLS resumption tickets.
+  // |cert_matched_sni| is true if the certificate matched the SNI hostname,
+  //      false if a non-matching default cert was used.
   //
   // When called asynchronously(is_sync=false), this method will be responsible
   // to continue the handshake from where it left off.
-  virtual void OnSelectCertificateDone(
-      bool ok, bool is_sync, const ProofSource::Chain* chain,
-      absl::string_view handshake_hints,
-      absl::string_view ticket_encryption_key) = 0;
+  virtual void OnSelectCertificateDone(bool ok, bool is_sync,
+                                       const ProofSource::Chain* chain,
+                                       absl::string_view handshake_hints,
+                                       absl::string_view ticket_encryption_key,
+                                       bool cert_matched_sni) = 0;
 
   // Called when a ProofSourceHandle::ComputeSignature operation completes.
   virtual void OnComputeSignatureDone(
diff --git a/quic/core/crypto/proof_source_x509.cc b/quic/core/crypto/proof_source_x509.cc
index c7acd00..986ee32 100644
--- a/quic/core/crypto/proof_source_x509.cc
+++ b/quic/core/crypto/proof_source_x509.cc
@@ -53,7 +53,7 @@
     return;
   }
 
-  Certificate* certificate = GetCertificate(hostname);
+  Certificate* certificate = GetCertificate(hostname, &proof.cert_matched_sni);
   proof.signature =
       certificate->key.Sign(absl::string_view(payload.get(), payload_size),
                             SSL_SIGN_RSA_PSS_RSAE_SHA256);
@@ -63,9 +63,9 @@
 
 QuicReferenceCountedPointer<ProofSource::Chain> ProofSourceX509::GetCertChain(
     const QuicSocketAddress& /*server_address*/,
-    const QuicSocketAddress& /*client_address*/,
-    const std::string& hostname) {
-  return GetCertificate(hostname)->chain;
+    const QuicSocketAddress& /*client_address*/, const std::string& hostname,
+    bool* cert_matched_sni) {
+  return GetCertificate(hostname, cert_matched_sni)->chain;
 }
 
 void ProofSourceX509::ComputeTlsSignature(
@@ -75,8 +75,9 @@
     uint16_t signature_algorithm,
     absl::string_view in,
     std::unique_ptr<ProofSource::SignatureCallback> callback) {
-  std::string signature =
-      GetCertificate(hostname)->key.Sign(in, signature_algorithm);
+  bool cert_matched_sni;
+  std::string signature = GetCertificate(hostname, &cert_matched_sni)
+                              ->key.Sign(in, signature_algorithm);
   callback->Run(/*ok=*/!signature.empty(), signature, nullptr);
 }
 
@@ -125,9 +126,10 @@
 }
 
 ProofSourceX509::Certificate* ProofSourceX509::GetCertificate(
-    const std::string& hostname) const {
+    const std::string& hostname, bool* cert_matched_sni) const {
   auto it = certificate_map_.find(hostname);
   if (it != certificate_map_.end()) {
+    *cert_matched_sni = true;
     return it->second;
   }
   auto dot_pos = hostname.find('.');
@@ -135,9 +137,11 @@
     std::string wildcard = absl::StrCat("*", hostname.substr(dot_pos));
     it = certificate_map_.find(wildcard);
     if (it != certificate_map_.end()) {
+      *cert_matched_sni = true;
       return it->second;
     }
   }
+  *cert_matched_sni = false;
   return default_certificate_;
 }
 
diff --git a/quic/core/crypto/proof_source_x509.h b/quic/core/crypto/proof_source_x509.h
index 9ac6769..4b7ae51 100644
--- a/quic/core/crypto/proof_source_x509.h
+++ b/quic/core/crypto/proof_source_x509.h
@@ -37,8 +37,8 @@
                 std::unique_ptr<Callback> callback) override;
   QuicReferenceCountedPointer<Chain> GetCertChain(
       const QuicSocketAddress& server_address,
-      const QuicSocketAddress& client_address,
-      const std::string& hostname) override;
+      const QuicSocketAddress& client_address, const std::string& hostname,
+      bool* cert_matched_sni) override;
   void ComputeTlsSignature(
       const QuicSocketAddress& server_address,
       const QuicSocketAddress& client_address, const std::string& hostname,
@@ -65,7 +65,8 @@
 
   // Looks up certficiate for hostname, returns the default if no certificate is
   // found.
-  Certificate* GetCertificate(const std::string& hostname) const;
+  Certificate* GetCertificate(const std::string& hostname,
+                              bool* cert_matched_sni) const;
 
   std::forward_list<Certificate> certificates_;
   Certificate* default_certificate_;
diff --git a/quic/core/crypto/proof_source_x509_test.cc b/quic/core/crypto/proof_source_x509_test.cc
index 0f5f4f2..1a9462f 100644
--- a/quic/core/crypto/proof_source_x509_test.cc
+++ b/quic/core/crypto/proof_source_x509_test.cc
@@ -71,40 +71,47 @@
                                                 std::move(*wildcard_key_)));
 
   // Default certificate.
+  bool cert_matched_sni;
   EXPECT_EQ(proof_source
                 ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(),
-                               "unknown.test")
+                               "unknown.test", &cert_matched_sni)
                 ->certs[0],
             kTestCertificate);
+  EXPECT_FALSE(cert_matched_sni);
   // mail.example.org is explicitly a SubjectAltName in kTestCertificate.
   EXPECT_EQ(proof_source
                 ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(),
-                               "mail.example.org")
+                               "mail.example.org", &cert_matched_sni)
                 ->certs[0],
             kTestCertificate);
+  EXPECT_TRUE(cert_matched_sni);
   // www.foo.test is in kWildcardCertificate.
   EXPECT_EQ(proof_source
                 ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(),
-                               "www.foo.test")
+                               "www.foo.test", &cert_matched_sni)
                 ->certs[0],
             kWildcardCertificate);
+  EXPECT_TRUE(cert_matched_sni);
   // *.wildcard.test is in kWildcardCertificate.
   EXPECT_EQ(proof_source
                 ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(),
-                               "www.wildcard.test")
+                               "www.wildcard.test", &cert_matched_sni)
                 ->certs[0],
             kWildcardCertificate);
+  EXPECT_TRUE(cert_matched_sni);
   EXPECT_EQ(proof_source
                 ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(),
-                               "etc.wildcard.test")
+                               "etc.wildcard.test", &cert_matched_sni)
                 ->certs[0],
             kWildcardCertificate);
+  EXPECT_TRUE(cert_matched_sni);
   // wildcard.test itself is not in kWildcardCertificate.
   EXPECT_EQ(proof_source
                 ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(),
-                               "wildcard.test")
+                               "wildcard.test", &cert_matched_sni)
                 ->certs[0],
             kTestCertificate);
+  EXPECT_FALSE(cert_matched_sni);
 }
 
 TEST_F(ProofSourceX509Test, TlsSignature) {
diff --git a/quic/core/crypto/quic_crypto_proof.cc b/quic/core/crypto/quic_crypto_proof.cc
index 033ef32..1e5f3c4 100644
--- a/quic/core/crypto/quic_crypto_proof.cc
+++ b/quic/core/crypto/quic_crypto_proof.cc
@@ -6,6 +6,7 @@
 
 namespace quic {
 
-QuicCryptoProof::QuicCryptoProof() : send_expect_ct_header(false) {}
+QuicCryptoProof::QuicCryptoProof()
+    : send_expect_ct_header(false), cert_matched_sni(false) {}
 
 }  // namespace quic
diff --git a/quic/core/crypto/quic_crypto_proof.h b/quic/core/crypto/quic_crypto_proof.h
index 6ca87fb..53e0961 100644
--- a/quic/core/crypto/quic_crypto_proof.h
+++ b/quic/core/crypto/quic_crypto_proof.h
@@ -22,6 +22,9 @@
   // Should the Expect-CT header be sent on the connection where the
   // certificate is used.
   bool send_expect_ct_header;
+  // Did the selected leaf certificate contain a SubjectAltName that included
+  // the requested SNI.
+  bool cert_matched_sni;
 };
 
 }  // namespace quic
diff --git a/quic/core/quic_crypto_client_handshaker_test.cc b/quic/core/quic_crypto_client_handshaker_test.cc
index ed5da6c..120e5da 100644
--- a/quic/core/quic_crypto_client_handshaker_test.cc
+++ b/quic/core/quic_crypto_client_handshaker_test.cc
@@ -77,18 +77,20 @@
                 QuicTransportVersion /*transport_version*/,
                 absl::string_view /*chlo_hash*/,
                 std::unique_ptr<Callback> callback) override {
-    QuicReferenceCountedPointer<ProofSource::Chain> chain =
-        GetCertChain(server_address, client_address, hostname);
+    bool cert_matched_sni;
+    QuicReferenceCountedPointer<ProofSource::Chain> chain = GetCertChain(
+        server_address, client_address, hostname, &cert_matched_sni);
     QuicCryptoProof proof;
     proof.signature = "Dummy signature";
     proof.leaf_cert_scts = "Dummy timestamp";
+    proof.cert_matched_sni = cert_matched_sni;
     callback->Run(true, chain, proof, /*details=*/nullptr);
   }
 
   QuicReferenceCountedPointer<Chain> GetCertChain(
       const QuicSocketAddress& /*server_address*/,
       const QuicSocketAddress& /*client_address*/,
-      const std::string& /*hostname*/) override {
+      const std::string& /*hostname*/, bool* /*cert_matched_sni*/) override {
     std::vector<std::string> certs;
     certs.push_back("Dummy cert");
     return QuicReferenceCountedPointer<ProofSource::Chain>(
diff --git a/quic/core/quic_crypto_server_stream.cc b/quic/core/quic_crypto_server_stream.cc
index 9bd9c7b..461c58f 100644
--- a/quic/core/quic_crypto_server_stream.cc
+++ b/quic/core/quic_crypto_server_stream.cc
@@ -352,6 +352,10 @@
   return signed_config_->proof.send_expect_ct_header;
 }
 
+bool QuicCryptoServerStream::DidCertMatchSni() const {
+  return signed_config_->proof.cert_matched_sni;
+}
+
 const ProofSource::Details* QuicCryptoServerStream::ProofSourceDetails() const {
   return proof_source_details_.get();
 }
diff --git a/quic/core/quic_crypto_server_stream.h b/quic/core/quic_crypto_server_stream.h
index 911ca7e..690e872 100644
--- a/quic/core/quic_crypto_server_stream.h
+++ b/quic/core/quic_crypto_server_stream.h
@@ -51,6 +51,7 @@
   std::string GetAddressToken() const override;
   bool ValidateAddressToken(absl::string_view token) const override;
   bool ShouldSendExpectCTHeader() const override;
+  bool DidCertMatchSni() const override;
   const ProofSource::Details* ProofSourceDetails() const override;
 
   // From QuicCryptoStream
diff --git a/quic/core/quic_crypto_server_stream_base.h b/quic/core/quic_crypto_server_stream_base.h
index 3ff9c37..d2b9f15 100644
--- a/quic/core/quic_crypto_server_stream_base.h
+++ b/quic/core/quic_crypto_server_stream_base.h
@@ -85,6 +85,9 @@
   // configuration for the certificate used in the connection is accessible.
   virtual bool ShouldSendExpectCTHeader() const = 0;
 
+  // Return true if a cert was picked that matched the SNI hostname.
+  virtual bool DidCertMatchSni() const = 0;
+
   // Returns the Details from the latest call to ProofSource::GetProof or
   // ProofSource::ComputeTlsSignature. Returns nullptr if no such call has been
   // made. The Details are owned by the QuicCryptoServerStreamBase and the
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index 514c462..ba7aedb 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -77,13 +77,15 @@
     return QUIC_FAILURE;
   }
 
+  bool cert_matched_sni;
   QuicReferenceCountedPointer<ProofSource::Chain> chain =
-      proof_source_->GetCertChain(server_address, client_address, hostname);
+      proof_source_->GetCertChain(server_address, client_address, hostname,
+                                  &cert_matched_sni);
 
   handshaker_->OnSelectCertificateDone(
       /*ok=*/true, /*is_sync=*/true, chain.get(),
       /*handshake_hints=*/absl::string_view(),
-      /*ticket_encryption_key=*/absl::string_view());
+      /*ticket_encryption_key=*/absl::string_view(), cert_matched_sni);
   if (!handshaker_->select_cert_status().has_value()) {
     QUIC_BUG(quic_bug_12423_1)
         << "select_cert_status() has no value after a synchronous select cert";
@@ -357,6 +359,8 @@
   return false;
 }
 
+bool TlsServerHandshaker::DidCertMatchSni() const { return cert_matched_sni_; }
+
 const ProofSource::Details* TlsServerHandshaker::ProofSourceDetails() const {
   return proof_source_details_.get();
 }
@@ -956,8 +960,8 @@
 
 void TlsServerHandshaker::OnSelectCertificateDone(
     bool ok, bool is_sync, const ProofSource::Chain* chain,
-    absl::string_view handshake_hints,
-    absl::string_view ticket_encryption_key) {
+    absl::string_view handshake_hints, absl::string_view ticket_encryption_key,
+    bool cert_matched_sni) {
   QUIC_DVLOG(1) << "OnSelectCertificateDone. ok:" << ok
                 << ", is_sync:" << is_sync
                 << ", len(handshake_hints):" << handshake_hints.size()
@@ -979,6 +983,7 @@
   }
   ticket_encryption_key_ = std::string(ticket_encryption_key);
   select_cert_status_ = QUIC_FAILURE;
+  cert_matched_sni_ = cert_matched_sni;
   if (ok) {
     if (chain && !chain->certs.empty()) {
       tls_connection_.SetCertChain(chain->ToCryptoBuffers().value);
diff --git a/quic/core/tls_server_handshaker.h b/quic/core/tls_server_handshaker.h
index 948a29a..f105676 100644
--- a/quic/core/tls_server_handshaker.h
+++ b/quic/core/tls_server_handshaker.h
@@ -62,6 +62,7 @@
   bool ValidateAddressToken(absl::string_view token) const override;
   void OnNewTokenReceived(absl::string_view token) override;
   bool ShouldSendExpectCTHeader() const override;
+  bool DidCertMatchSni() const override;
   const ProofSource::Details* ProofSourceDetails() const override;
 
   // From QuicCryptoServerStreamBase and TlsHandshaker
@@ -169,10 +170,11 @@
   bool HasValidSignature(size_t max_signature_size) const;
 
   // ProofSourceHandleCallback implementation:
-  void OnSelectCertificateDone(
-      bool ok, bool is_sync, const ProofSource::Chain* chain,
-      absl::string_view handshake_hints,
-      absl::string_view ticket_encryption_key) override;
+  void OnSelectCertificateDone(bool ok, bool is_sync,
+                               const ProofSource::Chain* chain,
+                               absl::string_view handshake_hints,
+                               absl::string_view ticket_encryption_key,
+                               bool cert_matched_sni) override;
 
   void OnComputeSignatureDone(
       bool ok,
@@ -364,6 +366,8 @@
   const QuicCryptoServerConfig* crypto_config_;  // Unowned.
   const bool restore_connection_context_in_callbacks_ =
       GetQuicReloadableFlag(quic_tls_restore_connection_context_in_callbacks);
+
+  bool cert_matched_sni_ = false;
 };
 
 }  // namespace quic
diff --git a/quic/qbone/qbone_session_test.cc b/quic/qbone/qbone_session_test.cc
index 9887f79..c3ae3be 100644
--- a/quic/qbone/qbone_session_test.cc
+++ b/quic/qbone/qbone_session_test.cc
@@ -79,9 +79,9 @@
                 absl::string_view chlo_hash,
                 std::unique_ptr<Callback> callback) override {
     if (!proof_source_) {
-      QuicReferenceCountedPointer<ProofSource::Chain> chain =
-          GetCertChain(server_address, client_address, hostname);
       QuicCryptoProof proof;
+      QuicReferenceCountedPointer<ProofSource::Chain> chain = GetCertChain(
+          server_address, client_address, hostname, &proof.cert_matched_sni);
       callback->Run(/*ok=*/false, chain, proof, /*details=*/nullptr);
       return;
     }
@@ -92,13 +92,13 @@
 
   QuicReferenceCountedPointer<Chain> GetCertChain(
       const QuicSocketAddress& server_address,
-      const QuicSocketAddress& client_address,
-      const std::string& hostname) override {
+      const QuicSocketAddress& client_address, const std::string& hostname,
+      bool* cert_matched_sni) override {
     if (!proof_source_) {
       return QuicReferenceCountedPointer<Chain>();
     }
-    return proof_source_->GetCertChain(server_address, client_address,
-                                       hostname);
+    return proof_source_->GetCertChain(server_address, client_address, hostname,
+                                       cert_matched_sni);
   }
 
   void ComputeTlsSignature(
diff --git a/quic/test_tools/failing_proof_source.cc b/quic/test_tools/failing_proof_source.cc
index fad128d..645bef9 100644
--- a/quic/test_tools/failing_proof_source.cc
+++ b/quic/test_tools/failing_proof_source.cc
@@ -22,7 +22,8 @@
 QuicReferenceCountedPointer<ProofSource::Chain>
 FailingProofSource::GetCertChain(const QuicSocketAddress& /*server_address*/,
                                  const QuicSocketAddress& /*client_address*/,
-                                 const std::string& /*hostname*/) {
+                                 const std::string& /*hostname*/,
+                                 bool* /*cert_matched_sni*/) {
   return QuicReferenceCountedPointer<Chain>();
 }
 
diff --git a/quic/test_tools/failing_proof_source.h b/quic/test_tools/failing_proof_source.h
index 447b770..db5a197 100644
--- a/quic/test_tools/failing_proof_source.h
+++ b/quic/test_tools/failing_proof_source.h
@@ -23,8 +23,8 @@
 
   QuicReferenceCountedPointer<Chain> GetCertChain(
       const QuicSocketAddress& server_address,
-      const QuicSocketAddress& client_address,
-      const std::string& hostname) override;
+      const QuicSocketAddress& client_address, const std::string& hostname,
+      bool* cert_matched_sni) override;
 
   void ComputeTlsSignature(
       const QuicSocketAddress& server_address,
diff --git a/quic/test_tools/fake_proof_source.cc b/quic/test_tools/fake_proof_source.cc
index 1109d65..0f7cb19 100644
--- a/quic/test_tools/fake_proof_source.cc
+++ b/quic/test_tools/fake_proof_source.cc
@@ -96,9 +96,10 @@
 
 QuicReferenceCountedPointer<ProofSource::Chain> FakeProofSource::GetCertChain(
     const QuicSocketAddress& server_address,
-    const QuicSocketAddress& client_address,
-    const std::string& hostname) {
-  return delegate_->GetCertChain(server_address, client_address, hostname);
+    const QuicSocketAddress& client_address, const std::string& hostname,
+    bool* cert_matched_sni) {
+  return delegate_->GetCertChain(server_address, client_address, hostname,
+                                 cert_matched_sni);
 }
 
 void FakeProofSource::ComputeTlsSignature(
diff --git a/quic/test_tools/fake_proof_source.h b/quic/test_tools/fake_proof_source.h
index c088d43..b135129 100644
--- a/quic/test_tools/fake_proof_source.h
+++ b/quic/test_tools/fake_proof_source.h
@@ -43,8 +43,8 @@
                 std::unique_ptr<ProofSource::Callback> callback) override;
   QuicReferenceCountedPointer<Chain> GetCertChain(
       const QuicSocketAddress& server_address,
-      const QuicSocketAddress& client_address,
-      const std::string& hostname) override;
+      const QuicSocketAddress& client_address, const std::string& hostname,
+      bool* cert_matched_sni) override;
   void ComputeTlsSignature(
       const QuicSocketAddress& server_address,
       const QuicSocketAddress& client_address,
diff --git a/quic/test_tools/fake_proof_source_handle.cc b/quic/test_tools/fake_proof_source_handle.cc
index f34247e..8c1e7f8 100644
--- a/quic/test_tools/fake_proof_source_handle.cc
+++ b/quic/test_tools/fake_proof_source_handle.cc
@@ -94,19 +94,23 @@
     callback()->OnSelectCertificateDone(
         /*ok=*/false,
         /*is_sync=*/true, nullptr, /*handshake_hints=*/absl::string_view(),
-        /*ticket_encryption_key=*/absl::string_view());
+        /*ticket_encryption_key=*/absl::string_view(),
+        /*cert_matched_sni=*/false);
     return QUIC_FAILURE;
   }
 
   QUICHE_DCHECK(select_cert_action_ == Action::DELEGATE_SYNC);
+  bool cert_matched_sni;
   QuicReferenceCountedPointer<ProofSource::Chain> chain =
-      delegate_->GetCertChain(server_address, client_address, hostname);
+      delegate_->GetCertChain(server_address, client_address, hostname,
+                              &cert_matched_sni);
 
   bool ok = chain && !chain->certs.empty();
   callback_->OnSelectCertificateDone(
       ok, /*is_sync=*/true, chain.get(),
       /*handshake_hints=*/absl::string_view(),
-      /*ticket_encryption_key=*/absl::string_view());
+      /*ticket_encryption_key=*/absl::string_view(),
+      /*cert_matched_sni=*/cert_matched_sni);
   return ok ? QUIC_SUCCESS : QUIC_FAILURE;
 }
 
@@ -182,16 +186,19 @@
         /*ok=*/false,
         /*is_sync=*/false, nullptr,
         /*handshake_hints=*/absl::string_view(),
-        /*ticket_encryption_key=*/absl::string_view());
+        /*ticket_encryption_key=*/absl::string_view(),
+        /*cert_matched_sni=*/false);
   } else if (action_ == Action::DELEGATE_ASYNC) {
+    bool cert_matched_sni;
     QuicReferenceCountedPointer<ProofSource::Chain> chain =
         delegate_->GetCertChain(args_.server_address, args_.client_address,
-                                args_.hostname);
+                                args_.hostname, &cert_matched_sni);
     bool ok = chain && !chain->certs.empty();
     callback_->OnSelectCertificateDone(
         ok, /*is_sync=*/false, chain.get(),
         /*handshake_hints=*/absl::string_view(),
-        /*ticket_encryption_key=*/absl::string_view());
+        /*ticket_encryption_key=*/absl::string_view(),
+        /*cert_matched_sni=*/cert_matched_sni);
   } else {
     QUIC_BUG(quic_bug_10139_1)
         << "Unexpected action: " << static_cast<int>(action_);