Internal change PiperOrigin-RevId: 395525828
diff --git a/quic/core/crypto/proof_source.h b/quic/core/crypto/proof_source.h index b0a2446..c4a1f2c 100644 --- a/quic/core/crypto/proof_source.h +++ b/quic/core/crypto/proof_source.h
@@ -145,10 +145,13 @@ std::unique_ptr<Callback> callback) = 0; // Returns the certificate chain for |hostname| in leaf-first order. + // + // Sets *cert_matched_sni to true if the certificate matched the given + // hostname, false if a default cert not matching the hostname was used. virtual QuicReferenceCountedPointer<Chain> GetCertChain( const QuicSocketAddress& server_address, - const QuicSocketAddress& client_address, - const std::string& hostname) = 0; + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) = 0; // Computes a signature using the private key of the certificate for // |hostname|. The value in |in| is signed using the algorithm specified by @@ -247,13 +250,16 @@ // SSL_set_handshake_hints. // |ticket_encryption_key| (optional) encryption key to be used for minting // TLS resumption tickets. + // |cert_matched_sni| is true if the certificate matched the SNI hostname, + // false if a non-matching default cert was used. // // When called asynchronously(is_sync=false), this method will be responsible // to continue the handshake from where it left off. - virtual void OnSelectCertificateDone( - bool ok, bool is_sync, const ProofSource::Chain* chain, - absl::string_view handshake_hints, - absl::string_view ticket_encryption_key) = 0; + virtual void OnSelectCertificateDone(bool ok, bool is_sync, + const ProofSource::Chain* chain, + absl::string_view handshake_hints, + absl::string_view ticket_encryption_key, + bool cert_matched_sni) = 0; // Called when a ProofSourceHandle::ComputeSignature operation completes. virtual void OnComputeSignatureDone(
diff --git a/quic/core/crypto/proof_source_x509.cc b/quic/core/crypto/proof_source_x509.cc index c7acd00..986ee32 100644 --- a/quic/core/crypto/proof_source_x509.cc +++ b/quic/core/crypto/proof_source_x509.cc
@@ -53,7 +53,7 @@ return; } - Certificate* certificate = GetCertificate(hostname); + Certificate* certificate = GetCertificate(hostname, &proof.cert_matched_sni); proof.signature = certificate->key.Sign(absl::string_view(payload.get(), payload_size), SSL_SIGN_RSA_PSS_RSAE_SHA256); @@ -63,9 +63,9 @@ QuicReferenceCountedPointer<ProofSource::Chain> ProofSourceX509::GetCertChain( const QuicSocketAddress& /*server_address*/, - const QuicSocketAddress& /*client_address*/, - const std::string& hostname) { - return GetCertificate(hostname)->chain; + const QuicSocketAddress& /*client_address*/, const std::string& hostname, + bool* cert_matched_sni) { + return GetCertificate(hostname, cert_matched_sni)->chain; } void ProofSourceX509::ComputeTlsSignature( @@ -75,8 +75,9 @@ uint16_t signature_algorithm, absl::string_view in, std::unique_ptr<ProofSource::SignatureCallback> callback) { - std::string signature = - GetCertificate(hostname)->key.Sign(in, signature_algorithm); + bool cert_matched_sni; + std::string signature = GetCertificate(hostname, &cert_matched_sni) + ->key.Sign(in, signature_algorithm); callback->Run(/*ok=*/!signature.empty(), signature, nullptr); } @@ -125,9 +126,10 @@ } ProofSourceX509::Certificate* ProofSourceX509::GetCertificate( - const std::string& hostname) const { + const std::string& hostname, bool* cert_matched_sni) const { auto it = certificate_map_.find(hostname); if (it != certificate_map_.end()) { + *cert_matched_sni = true; return it->second; } auto dot_pos = hostname.find('.'); @@ -135,9 +137,11 @@ std::string wildcard = absl::StrCat("*", hostname.substr(dot_pos)); it = certificate_map_.find(wildcard); if (it != certificate_map_.end()) { + *cert_matched_sni = true; return it->second; } } + *cert_matched_sni = false; return default_certificate_; }
diff --git a/quic/core/crypto/proof_source_x509.h b/quic/core/crypto/proof_source_x509.h index 9ac6769..4b7ae51 100644 --- a/quic/core/crypto/proof_source_x509.h +++ b/quic/core/crypto/proof_source_x509.h
@@ -37,8 +37,8 @@ std::unique_ptr<Callback> callback) override; QuicReferenceCountedPointer<Chain> GetCertChain( const QuicSocketAddress& server_address, - const QuicSocketAddress& client_address, - const std::string& hostname) override; + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) override; void ComputeTlsSignature( const QuicSocketAddress& server_address, const QuicSocketAddress& client_address, const std::string& hostname, @@ -65,7 +65,8 @@ // Looks up certficiate for hostname, returns the default if no certificate is // found. - Certificate* GetCertificate(const std::string& hostname) const; + Certificate* GetCertificate(const std::string& hostname, + bool* cert_matched_sni) const; std::forward_list<Certificate> certificates_; Certificate* default_certificate_;
diff --git a/quic/core/crypto/proof_source_x509_test.cc b/quic/core/crypto/proof_source_x509_test.cc index 0f5f4f2..1a9462f 100644 --- a/quic/core/crypto/proof_source_x509_test.cc +++ b/quic/core/crypto/proof_source_x509_test.cc
@@ -71,40 +71,47 @@ std::move(*wildcard_key_))); // Default certificate. + bool cert_matched_sni; EXPECT_EQ(proof_source ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), - "unknown.test") + "unknown.test", &cert_matched_sni) ->certs[0], kTestCertificate); + EXPECT_FALSE(cert_matched_sni); // mail.example.org is explicitly a SubjectAltName in kTestCertificate. EXPECT_EQ(proof_source ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), - "mail.example.org") + "mail.example.org", &cert_matched_sni) ->certs[0], kTestCertificate); + EXPECT_TRUE(cert_matched_sni); // www.foo.test is in kWildcardCertificate. EXPECT_EQ(proof_source ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), - "www.foo.test") + "www.foo.test", &cert_matched_sni) ->certs[0], kWildcardCertificate); + EXPECT_TRUE(cert_matched_sni); // *.wildcard.test is in kWildcardCertificate. EXPECT_EQ(proof_source ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), - "www.wildcard.test") + "www.wildcard.test", &cert_matched_sni) ->certs[0], kWildcardCertificate); + EXPECT_TRUE(cert_matched_sni); EXPECT_EQ(proof_source ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), - "etc.wildcard.test") + "etc.wildcard.test", &cert_matched_sni) ->certs[0], kWildcardCertificate); + EXPECT_TRUE(cert_matched_sni); // wildcard.test itself is not in kWildcardCertificate. EXPECT_EQ(proof_source ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), - "wildcard.test") + "wildcard.test", &cert_matched_sni) ->certs[0], kTestCertificate); + EXPECT_FALSE(cert_matched_sni); } TEST_F(ProofSourceX509Test, TlsSignature) {
diff --git a/quic/core/crypto/quic_crypto_proof.cc b/quic/core/crypto/quic_crypto_proof.cc index 033ef32..1e5f3c4 100644 --- a/quic/core/crypto/quic_crypto_proof.cc +++ b/quic/core/crypto/quic_crypto_proof.cc
@@ -6,6 +6,7 @@ namespace quic { -QuicCryptoProof::QuicCryptoProof() : send_expect_ct_header(false) {} +QuicCryptoProof::QuicCryptoProof() + : send_expect_ct_header(false), cert_matched_sni(false) {} } // namespace quic
diff --git a/quic/core/crypto/quic_crypto_proof.h b/quic/core/crypto/quic_crypto_proof.h index 6ca87fb..53e0961 100644 --- a/quic/core/crypto/quic_crypto_proof.h +++ b/quic/core/crypto/quic_crypto_proof.h
@@ -22,6 +22,9 @@ // Should the Expect-CT header be sent on the connection where the // certificate is used. bool send_expect_ct_header; + // Did the selected leaf certificate contain a SubjectAltName that included + // the requested SNI. + bool cert_matched_sni; }; } // namespace quic
diff --git a/quic/core/quic_crypto_client_handshaker_test.cc b/quic/core/quic_crypto_client_handshaker_test.cc index ed5da6c..120e5da 100644 --- a/quic/core/quic_crypto_client_handshaker_test.cc +++ b/quic/core/quic_crypto_client_handshaker_test.cc
@@ -77,18 +77,20 @@ QuicTransportVersion /*transport_version*/, absl::string_view /*chlo_hash*/, std::unique_ptr<Callback> callback) override { - QuicReferenceCountedPointer<ProofSource::Chain> chain = - GetCertChain(server_address, client_address, hostname); + bool cert_matched_sni; + QuicReferenceCountedPointer<ProofSource::Chain> chain = GetCertChain( + server_address, client_address, hostname, &cert_matched_sni); QuicCryptoProof proof; proof.signature = "Dummy signature"; proof.leaf_cert_scts = "Dummy timestamp"; + proof.cert_matched_sni = cert_matched_sni; callback->Run(true, chain, proof, /*details=*/nullptr); } QuicReferenceCountedPointer<Chain> GetCertChain( const QuicSocketAddress& /*server_address*/, const QuicSocketAddress& /*client_address*/, - const std::string& /*hostname*/) override { + const std::string& /*hostname*/, bool* /*cert_matched_sni*/) override { std::vector<std::string> certs; certs.push_back("Dummy cert"); return QuicReferenceCountedPointer<ProofSource::Chain>(
diff --git a/quic/core/quic_crypto_server_stream.cc b/quic/core/quic_crypto_server_stream.cc index 9bd9c7b..461c58f 100644 --- a/quic/core/quic_crypto_server_stream.cc +++ b/quic/core/quic_crypto_server_stream.cc
@@ -352,6 +352,10 @@ return signed_config_->proof.send_expect_ct_header; } +bool QuicCryptoServerStream::DidCertMatchSni() const { + return signed_config_->proof.cert_matched_sni; +} + const ProofSource::Details* QuicCryptoServerStream::ProofSourceDetails() const { return proof_source_details_.get(); }
diff --git a/quic/core/quic_crypto_server_stream.h b/quic/core/quic_crypto_server_stream.h index 911ca7e..690e872 100644 --- a/quic/core/quic_crypto_server_stream.h +++ b/quic/core/quic_crypto_server_stream.h
@@ -51,6 +51,7 @@ std::string GetAddressToken() const override; bool ValidateAddressToken(absl::string_view token) const override; bool ShouldSendExpectCTHeader() const override; + bool DidCertMatchSni() const override; const ProofSource::Details* ProofSourceDetails() const override; // From QuicCryptoStream
diff --git a/quic/core/quic_crypto_server_stream_base.h b/quic/core/quic_crypto_server_stream_base.h index 3ff9c37..d2b9f15 100644 --- a/quic/core/quic_crypto_server_stream_base.h +++ b/quic/core/quic_crypto_server_stream_base.h
@@ -85,6 +85,9 @@ // configuration for the certificate used in the connection is accessible. virtual bool ShouldSendExpectCTHeader() const = 0; + // Return true if a cert was picked that matched the SNI hostname. + virtual bool DidCertMatchSni() const = 0; + // Returns the Details from the latest call to ProofSource::GetProof or // ProofSource::ComputeTlsSignature. Returns nullptr if no such call has been // made. The Details are owned by the QuicCryptoServerStreamBase and the
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc index 514c462..ba7aedb 100644 --- a/quic/core/tls_server_handshaker.cc +++ b/quic/core/tls_server_handshaker.cc
@@ -77,13 +77,15 @@ return QUIC_FAILURE; } + bool cert_matched_sni; QuicReferenceCountedPointer<ProofSource::Chain> chain = - proof_source_->GetCertChain(server_address, client_address, hostname); + proof_source_->GetCertChain(server_address, client_address, hostname, + &cert_matched_sni); handshaker_->OnSelectCertificateDone( /*ok=*/true, /*is_sync=*/true, chain.get(), /*handshake_hints=*/absl::string_view(), - /*ticket_encryption_key=*/absl::string_view()); + /*ticket_encryption_key=*/absl::string_view(), cert_matched_sni); if (!handshaker_->select_cert_status().has_value()) { QUIC_BUG(quic_bug_12423_1) << "select_cert_status() has no value after a synchronous select cert"; @@ -357,6 +359,8 @@ return false; } +bool TlsServerHandshaker::DidCertMatchSni() const { return cert_matched_sni_; } + const ProofSource::Details* TlsServerHandshaker::ProofSourceDetails() const { return proof_source_details_.get(); } @@ -956,8 +960,8 @@ void TlsServerHandshaker::OnSelectCertificateDone( bool ok, bool is_sync, const ProofSource::Chain* chain, - absl::string_view handshake_hints, - absl::string_view ticket_encryption_key) { + absl::string_view handshake_hints, absl::string_view ticket_encryption_key, + bool cert_matched_sni) { QUIC_DVLOG(1) << "OnSelectCertificateDone. ok:" << ok << ", is_sync:" << is_sync << ", len(handshake_hints):" << handshake_hints.size() @@ -979,6 +983,7 @@ } ticket_encryption_key_ = std::string(ticket_encryption_key); select_cert_status_ = QUIC_FAILURE; + cert_matched_sni_ = cert_matched_sni; if (ok) { if (chain && !chain->certs.empty()) { tls_connection_.SetCertChain(chain->ToCryptoBuffers().value);
diff --git a/quic/core/tls_server_handshaker.h b/quic/core/tls_server_handshaker.h index 948a29a..f105676 100644 --- a/quic/core/tls_server_handshaker.h +++ b/quic/core/tls_server_handshaker.h
@@ -62,6 +62,7 @@ bool ValidateAddressToken(absl::string_view token) const override; void OnNewTokenReceived(absl::string_view token) override; bool ShouldSendExpectCTHeader() const override; + bool DidCertMatchSni() const override; const ProofSource::Details* ProofSourceDetails() const override; // From QuicCryptoServerStreamBase and TlsHandshaker @@ -169,10 +170,11 @@ bool HasValidSignature(size_t max_signature_size) const; // ProofSourceHandleCallback implementation: - void OnSelectCertificateDone( - bool ok, bool is_sync, const ProofSource::Chain* chain, - absl::string_view handshake_hints, - absl::string_view ticket_encryption_key) override; + void OnSelectCertificateDone(bool ok, bool is_sync, + const ProofSource::Chain* chain, + absl::string_view handshake_hints, + absl::string_view ticket_encryption_key, + bool cert_matched_sni) override; void OnComputeSignatureDone( bool ok, @@ -364,6 +366,8 @@ const QuicCryptoServerConfig* crypto_config_; // Unowned. const bool restore_connection_context_in_callbacks_ = GetQuicReloadableFlag(quic_tls_restore_connection_context_in_callbacks); + + bool cert_matched_sni_ = false; }; } // namespace quic
diff --git a/quic/qbone/qbone_session_test.cc b/quic/qbone/qbone_session_test.cc index 9887f79..c3ae3be 100644 --- a/quic/qbone/qbone_session_test.cc +++ b/quic/qbone/qbone_session_test.cc
@@ -79,9 +79,9 @@ absl::string_view chlo_hash, std::unique_ptr<Callback> callback) override { if (!proof_source_) { - QuicReferenceCountedPointer<ProofSource::Chain> chain = - GetCertChain(server_address, client_address, hostname); QuicCryptoProof proof; + QuicReferenceCountedPointer<ProofSource::Chain> chain = GetCertChain( + server_address, client_address, hostname, &proof.cert_matched_sni); callback->Run(/*ok=*/false, chain, proof, /*details=*/nullptr); return; } @@ -92,13 +92,13 @@ QuicReferenceCountedPointer<Chain> GetCertChain( const QuicSocketAddress& server_address, - const QuicSocketAddress& client_address, - const std::string& hostname) override { + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) override { if (!proof_source_) { return QuicReferenceCountedPointer<Chain>(); } - return proof_source_->GetCertChain(server_address, client_address, - hostname); + return proof_source_->GetCertChain(server_address, client_address, hostname, + cert_matched_sni); } void ComputeTlsSignature(
diff --git a/quic/test_tools/failing_proof_source.cc b/quic/test_tools/failing_proof_source.cc index fad128d..645bef9 100644 --- a/quic/test_tools/failing_proof_source.cc +++ b/quic/test_tools/failing_proof_source.cc
@@ -22,7 +22,8 @@ QuicReferenceCountedPointer<ProofSource::Chain> FailingProofSource::GetCertChain(const QuicSocketAddress& /*server_address*/, const QuicSocketAddress& /*client_address*/, - const std::string& /*hostname*/) { + const std::string& /*hostname*/, + bool* /*cert_matched_sni*/) { return QuicReferenceCountedPointer<Chain>(); }
diff --git a/quic/test_tools/failing_proof_source.h b/quic/test_tools/failing_proof_source.h index 447b770..db5a197 100644 --- a/quic/test_tools/failing_proof_source.h +++ b/quic/test_tools/failing_proof_source.h
@@ -23,8 +23,8 @@ QuicReferenceCountedPointer<Chain> GetCertChain( const QuicSocketAddress& server_address, - const QuicSocketAddress& client_address, - const std::string& hostname) override; + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) override; void ComputeTlsSignature( const QuicSocketAddress& server_address,
diff --git a/quic/test_tools/fake_proof_source.cc b/quic/test_tools/fake_proof_source.cc index 1109d65..0f7cb19 100644 --- a/quic/test_tools/fake_proof_source.cc +++ b/quic/test_tools/fake_proof_source.cc
@@ -96,9 +96,10 @@ QuicReferenceCountedPointer<ProofSource::Chain> FakeProofSource::GetCertChain( const QuicSocketAddress& server_address, - const QuicSocketAddress& client_address, - const std::string& hostname) { - return delegate_->GetCertChain(server_address, client_address, hostname); + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) { + return delegate_->GetCertChain(server_address, client_address, hostname, + cert_matched_sni); } void FakeProofSource::ComputeTlsSignature(
diff --git a/quic/test_tools/fake_proof_source.h b/quic/test_tools/fake_proof_source.h index c088d43..b135129 100644 --- a/quic/test_tools/fake_proof_source.h +++ b/quic/test_tools/fake_proof_source.h
@@ -43,8 +43,8 @@ std::unique_ptr<ProofSource::Callback> callback) override; QuicReferenceCountedPointer<Chain> GetCertChain( const QuicSocketAddress& server_address, - const QuicSocketAddress& client_address, - const std::string& hostname) override; + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) override; void ComputeTlsSignature( const QuicSocketAddress& server_address, const QuicSocketAddress& client_address,
diff --git a/quic/test_tools/fake_proof_source_handle.cc b/quic/test_tools/fake_proof_source_handle.cc index f34247e..8c1e7f8 100644 --- a/quic/test_tools/fake_proof_source_handle.cc +++ b/quic/test_tools/fake_proof_source_handle.cc
@@ -94,19 +94,23 @@ callback()->OnSelectCertificateDone( /*ok=*/false, /*is_sync=*/true, nullptr, /*handshake_hints=*/absl::string_view(), - /*ticket_encryption_key=*/absl::string_view()); + /*ticket_encryption_key=*/absl::string_view(), + /*cert_matched_sni=*/false); return QUIC_FAILURE; } QUICHE_DCHECK(select_cert_action_ == Action::DELEGATE_SYNC); + bool cert_matched_sni; QuicReferenceCountedPointer<ProofSource::Chain> chain = - delegate_->GetCertChain(server_address, client_address, hostname); + delegate_->GetCertChain(server_address, client_address, hostname, + &cert_matched_sni); bool ok = chain && !chain->certs.empty(); callback_->OnSelectCertificateDone( ok, /*is_sync=*/true, chain.get(), /*handshake_hints=*/absl::string_view(), - /*ticket_encryption_key=*/absl::string_view()); + /*ticket_encryption_key=*/absl::string_view(), + /*cert_matched_sni=*/cert_matched_sni); return ok ? QUIC_SUCCESS : QUIC_FAILURE; } @@ -182,16 +186,19 @@ /*ok=*/false, /*is_sync=*/false, nullptr, /*handshake_hints=*/absl::string_view(), - /*ticket_encryption_key=*/absl::string_view()); + /*ticket_encryption_key=*/absl::string_view(), + /*cert_matched_sni=*/false); } else if (action_ == Action::DELEGATE_ASYNC) { + bool cert_matched_sni; QuicReferenceCountedPointer<ProofSource::Chain> chain = delegate_->GetCertChain(args_.server_address, args_.client_address, - args_.hostname); + args_.hostname, &cert_matched_sni); bool ok = chain && !chain->certs.empty(); callback_->OnSelectCertificateDone( ok, /*is_sync=*/false, chain.get(), /*handshake_hints=*/absl::string_view(), - /*ticket_encryption_key=*/absl::string_view()); + /*ticket_encryption_key=*/absl::string_view(), + /*cert_matched_sni=*/cert_matched_sni); } else { QUIC_BUG(quic_bug_10139_1) << "Unexpected action: " << static_cast<int>(action_);