Refactor TlsServerHandshaker to use ProofSourceHandle for cert selection and signature. Protected by FLAGS_quic_reloadable_flag_quic_tls_use_per_handshaker_proof_source. PiperOrigin-RevId: 350572752 Change-Id: Ifb33c9554c5c91fb69bbe1606a06b913c2fbbfff
diff --git a/quic/core/crypto/proof_source.h b/quic/core/crypto/proof_source.h index 4e36a32..cee94f2 100644 --- a/quic/core/crypto/proof_source.h +++ b/quic/core/crypto/proof_source.h
@@ -19,6 +19,10 @@ namespace quic { +namespace test { +class FakeProofSourceHandle; +} // namespace test + // CryptoBuffers is a RAII class to own a std::vector<CRYPTO_BUFFER*> and the // buffers the elements point to. struct QUIC_EXPORT_PRIVATE CryptoBuffers { @@ -211,6 +215,97 @@ virtual TicketCrypter* GetTicketCrypter() = 0; }; +// ProofSourceHandleCallback is an interface that contains the callbacks when +// the operations in ProofSourceHandle completes. +// TODO(wub): Consider deprecating ProofSource by moving all functionalities of +// ProofSource into ProofSourceHandle. +class QUIC_EXPORT_PRIVATE ProofSourceHandleCallback { + public: + virtual ~ProofSourceHandleCallback() = default; + + // Called when a ProofSourceHandle::SelectCertificate operation completes. + // |ok| indicates whether the operation was successful. + // |is_sync| indicates whether the operation completed synchronously, i.e. + // whether it is completed before ProofSourceHandle::SelectCertificate + // returned. + // |chain| the certificate chain in leaf-first order. + // + // When called asynchronously(is_sync=false), this method will be responsible + // to continue the handshake from where it left off. + virtual void OnSelectCertificateDone(bool ok, + bool is_sync, + const ProofSource::Chain* chain) = 0; + + // Called when a ProofSourceHandle::ComputeSignature operation completes. + virtual void OnComputeSignatureDone( + bool ok, + bool is_sync, + std::string signature, + std::unique_ptr<ProofSource::Details> details) = 0; +}; + +// ProofSourceHandle is an interface by which a TlsServerHandshaker can obtain +// certificate chains and signatures that prove its identity. +// The operations this interface supports are similar to those in ProofSource, +// the main difference is that ProofSourceHandle is per-handshaker, so +// an implementation can have states that are shared by multiple calls on the +// same handle. +// +// A handle object is owned by a TlsServerHandshaker. Since there might be an +// async operation pending when the handle destructs, an implementation must +// ensure when such operations finish, their corresponding callback method won't +// be invoked. +// +// A handle will have at most one async operation pending at a time. +class QUIC_EXPORT_PRIVATE ProofSourceHandle { + public: + virtual ~ProofSourceHandle() = default; + + // Cancel the pending operation, if any. + // Once called, any completion method on |callback()| won't be invoked. + virtual void CancelPendingOperation() = 0; + + // Starts a select certificate operation. If the operation is not cancelled + // when it completes, callback()->OnSelectCertificateDone will be invoked. + // + // If the operation is handled synchronously: + // - QUIC_SUCCESS or QUIC_FAILURE will be returned. + // - callback()->OnSelectCertificateDone should be invoked before the function + // returns. + // + // If the operation is handled asynchronously: + // - QUIC_PENDING will be returned. + // - When the operation is done, callback()->OnSelectCertificateDone should be + // invoked. + virtual QuicAsyncStatus SelectCertificate( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + absl::string_view client_hello, + const std::string& alpn, + const std::vector<uint8_t>& quic_transport_params, + const absl::optional<std::vector<uint8_t>>& early_data_context) = 0; + + // Starts a compute signature operation. If the operation is not cancelled + // when it completes, callback()->OnComputeSignatureDone will be invoked. + // + // See the comments of SelectCertificate for sync vs. async operations. + virtual QuicAsyncStatus ComputeSignature( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + uint16_t signature_algorithm, + absl::string_view in, + size_t max_signature_size) = 0; + + protected: + // Returns the object that will be notified when an operation completes. + virtual ProofSourceHandleCallback* callback() = 0; + + private: + friend class test::FakeProofSourceHandle; +}; + } // namespace quic #endif // QUICHE_QUIC_CORE_CRYPTO_PROOF_SOURCE_H_
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h index 41d2820..baf8ee9 100644 --- a/quic/core/quic_flags_list.h +++ b/quic/core/quic_flags_list.h
@@ -62,6 +62,7 @@ QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_testonly_default_false, false) QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_testonly_default_true, true) QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_tls_use_early_select_cert, false) +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_tls_use_per_handshaker_proof_source, false) QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_unified_iw_options, false) QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_use_circular_deque_for_unacked_packets_v2, false) QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_use_encryption_level_context, false)
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc index 383ca83..8fc52f1 100644 --- a/quic/core/tls_server_handshaker.cc +++ b/quic/core/tls_server_handshaker.cc
@@ -13,6 +13,8 @@ #include "third_party/boringssl/src/include/openssl/ssl.h" #include "quic/core/crypto/quic_crypto_server_config.h" #include "quic/core/crypto/transport_parameters.h" +#include "quic/core/quic_time.h" +#include "quic/core/quic_types.h" #include "quic/platform/api/quic_flag_utils.h" #include "quic/platform/api/quic_flags.h" #include "quic/platform/api/quic_hostname_utils.h" @@ -21,9 +23,88 @@ namespace quic { +TlsServerHandshaker::DefaultProofSourceHandle::DefaultProofSourceHandle( + TlsServerHandshaker* handshaker, + ProofSource* proof_source) + : handshaker_(handshaker), proof_source_(proof_source) {} + +TlsServerHandshaker::DefaultProofSourceHandle::~DefaultProofSourceHandle() { + CancelPendingOperation(); +} + +void TlsServerHandshaker::DefaultProofSourceHandle::CancelPendingOperation() { + QUIC_DVLOG(1) << "CancelPendingOperation. is_signature_pending=" + << (signature_callback_ != nullptr); + if (signature_callback_) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_tls_use_per_handshaker_proof_source, 3, + 3); + signature_callback_->Cancel(); + signature_callback_ = nullptr; + } +} + +QuicAsyncStatus +TlsServerHandshaker::DefaultProofSourceHandle::SelectCertificate( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + absl::string_view /*client_hello*/, + const std::string& /*alpn*/, + const std::vector<uint8_t>& /*quic_transport_params*/, + const absl::optional<std::vector<uint8_t>>& /*early_data_context*/) { + if (!handshaker_ || !proof_source_) { + QUIC_BUG << "SelectCertificate called on a detached handle"; + return QUIC_FAILURE; + } + + QuicReferenceCountedPointer<ProofSource::Chain> chain = + proof_source_->GetCertChain(server_address, client_address, hostname); + + handshaker_->OnSelectCertificateDone( + /*ok=*/true, /*is_sync=*/true, chain.get()); + return handshaker_->SelectCertStatus(); +} + +QuicAsyncStatus TlsServerHandshaker::DefaultProofSourceHandle::ComputeSignature( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + uint16_t signature_algorithm, + absl::string_view in, + size_t max_signature_size) { + if (!handshaker_ || !proof_source_) { + QUIC_BUG << "ComputeSignature called on a detached handle"; + return QUIC_FAILURE; + } + + if (signature_callback_) { + QUIC_BUG << "ComputeSignature called while pending"; + return QUIC_FAILURE; + } + + signature_callback_ = new DefaultSignatureCallback(this); + proof_source_->ComputeTlsSignature( + server_address, client_address, hostname, signature_algorithm, in, + std::unique_ptr<DefaultSignatureCallback>(signature_callback_)); + + if (signature_callback_) { + QUIC_DVLOG(1) << "ComputeTlsSignature is pending"; + signature_callback_->set_is_sync(false); + return QUIC_PENDING; + } + + bool success = handshaker_->HasValidSignature(max_signature_size); + QUIC_DVLOG(1) << "ComputeTlsSignature completed synchronously. success:" + << success; + // OnComputeSignatureDone should have been called by signature_callback_->Run. + return success ? QUIC_SUCCESS : QUIC_FAILURE; +} + TlsServerHandshaker::SignatureCallback::SignatureCallback( TlsServerHandshaker* handshaker) - : handshaker_(handshaker) {} + : handshaker_(handshaker) { + DCHECK(!handshaker_->use_proof_source_handle_); +} void TlsServerHandshaker::SignatureCallback::Run( bool ok, @@ -110,6 +191,9 @@ } void TlsServerHandshaker::CancelOutstandingCallbacks() { + if (use_proof_source_handle_ && proof_source_handle_) { + proof_source_handle_->CancelPendingOperation(); + } if (signature_callback_) { signature_callback_->Cancel(); signature_callback_ = nullptr; @@ -120,6 +204,12 @@ } } +std::unique_ptr<ProofSourceHandle> +TlsServerHandshaker::MaybeCreateProofSourceHandle() { + DCHECK(use_proof_source_handle_); + return std::make_unique<DefaultProofSourceHandle>(this, proof_source_); +} + bool TlsServerHandshaker::GetBase64SHA256ClientChannelID( std::string* /*output*/) const { // Channel ID is not supported when TLS is used in QUIC. @@ -477,6 +567,19 @@ size_t max_out, uint16_t sig_alg, absl::string_view in) { + if (use_proof_source_handle_) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_tls_use_per_handshaker_proof_source, 2, + 3); + QuicAsyncStatus status = proof_source_handle_->ComputeSignature( + session()->connection()->self_address(), + session()->connection()->peer_address(), hostname_, sig_alg, in, + max_out); + if (status == QUIC_PENDING) { + set_expected_ssl_error(SSL_ERROR_WANT_PRIVATE_KEY_OPERATION); + } + return PrivateKeyComplete(out, out_len, max_out); + } + signature_callback_ = new SignatureCallback(this); proof_source_->ComputeTlsSignature( session()->connection()->self_address(), @@ -496,7 +599,7 @@ if (expected_ssl_error() == SSL_ERROR_WANT_PRIVATE_KEY_OPERATION) { return ssl_private_key_retry; } - if (cert_verify_sig_.size() > max_out || cert_verify_sig_.empty()) { + if (!HasValidSignature(max_out)) { return ssl_private_key_failure; } *out_len = cert_verify_sig_.size(); @@ -506,6 +609,32 @@ return ssl_private_key_success; } +void TlsServerHandshaker::OnComputeSignatureDone( + bool ok, + bool is_sync, + std::string signature, + std::unique_ptr<ProofSource::Details> details) { + QUIC_DVLOG(1) << "OnComputeSignatureDone. ok:" << ok + << ", is_sync:" << is_sync + << ", len(signature):" << signature.size(); + DCHECK(use_proof_source_handle_); + if (ok) { + cert_verify_sig_ = std::move(signature); + proof_source_details_ = std::move(details); + } + const int last_expected_ssl_error = expected_ssl_error(); + set_expected_ssl_error(SSL_ERROR_WANT_READ); + if (!is_sync) { + DCHECK_EQ(last_expected_ssl_error, SSL_ERROR_WANT_PRIVATE_KEY_OPERATION); + AdvanceHandshakeFromCallback(); + } +} + +bool TlsServerHandshaker::HasValidSignature(size_t max_signature_size) const { + return !cert_verify_sig_.empty() && + cert_verify_sig_.size() <= max_signature_size; +} + size_t TlsServerHandshaker::SessionTicketMaxOverhead() { DCHECK(proof_source_->GetTicketCrypter()); return proof_source_->GetTicketCrypter()->MaxOverhead(); @@ -579,6 +708,27 @@ QUIC_RELOADABLE_FLAG_COUNT(quic_tls_use_early_select_cert); + // EarlySelectCertCallback can be called twice from BoringSSL: If the first + // call returns ssl_select_cert_retry, when cert selection completes, + // SSL_do_handshake will call it again. + if (use_proof_source_handle_) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_tls_use_per_handshaker_proof_source, 1, + 3); + if (select_cert_status_.has_value()) { + // This is the second call, return the result directly. + QUIC_DVLOG(1) << "EarlySelectCertCallback called to continue handshake, " + "returning directly. success:" + << (select_cert_status_.value() == QUIC_SUCCESS); + return (select_cert_status_.value() == QUIC_SUCCESS) + ? ssl_select_cert_success + : ssl_select_cert_error; + } + + // This is the first call. + select_cert_status_ = QUIC_PENDING; + proof_source_handle_ = MaybeCreateProofSourceHandle(); + } + if (!pre_shared_key_.empty()) { // TODO(b/154162689) add PSK support to QUIC+TLS. QUIC_BUG << "QUIC server pre-shared keys not yet supported with TLS"; @@ -599,6 +749,45 @@ QUIC_LOG(INFO) << "No hostname indicated in SNI"; } + if (use_proof_source_handle_) { + std::string error_details; + if (!ProcessTransportParameters(client_hello, &error_details)) { + CloseConnection(QUIC_HANDSHAKE_FAILED, error_details); + return ssl_select_cert_error; + } + OverrideQuicConfigDefaults(session()->config()); + session()->OnConfigNegotiated(); + + auto set_transport_params_result = SetTransportParameters(); + if (!set_transport_params_result.success) { + QUIC_LOG(ERROR) << "Failed to set transport parameters"; + return ssl_select_cert_error; + } + + const QuicAsyncStatus status = proof_source_handle_->SelectCertificate( + session()->connection()->self_address(), + session()->connection()->peer_address(), hostname_, + absl::string_view( + reinterpret_cast<const char*>(client_hello->client_hello), + client_hello->client_hello_len), + AlpnForVersion(session()->version()), + set_transport_params_result.quic_transport_params, + set_transport_params_result.early_data_context); + + DCHECK_EQ(status, SelectCertStatus()); + + if (status == QUIC_PENDING) { + set_expected_ssl_error(SSL_ERROR_PENDING_CERTIFICATE); + return ssl_select_cert_retry; + } + + if (status == QUIC_FAILURE) { + return ssl_select_cert_error; + } + + return ssl_select_cert_success; + } + QuicReferenceCountedPointer<ProofSource::Chain> chain = proof_source_->GetCertChain(session()->connection()->self_address(), session()->connection()->peer_address(), @@ -629,6 +818,38 @@ return ssl_select_cert_success; } +void TlsServerHandshaker::OnSelectCertificateDone( + bool ok, + bool is_sync, + const ProofSource::Chain* chain) { + QUIC_DVLOG(1) << "OnSelectCertificateDone. ok:" << ok + << ", is_sync:" << is_sync; + DCHECK(use_proof_source_handle_); + + select_cert_status_ = QUIC_FAILURE; + if (ok) { + if (chain && !chain->certs.empty()) { + tls_connection_.SetCertChain(chain->ToCryptoBuffers().value); + select_cert_status_ = QUIC_SUCCESS; + } else { + QUIC_LOG(ERROR) << "No certs provided for host '" << hostname_ << "'"; + } + } + + if (!is_sync) { + AdvanceHandshakeFromCallback(); + } +} + +QuicAsyncStatus TlsServerHandshaker::SelectCertStatus() const { + if (!select_cert_status_.has_value()) { + QUIC_BUG << "SelectCertStatus should be called after select cert started"; + return QUIC_PENDING; + } + + return select_cert_status_.value(); +} + bool TlsServerHandshaker::ValidateHostname(const std::string& hostname) const { if (!QuicHostnameUtils::IsValidSNI(hostname)) { // TODO(b/151676147): Include this error string in the CONNECTION_CLOSE
diff --git a/quic/core/tls_server_handshaker.h b/quic/core/tls_server_handshaker.h index 40db662..9b1cdae 100644 --- a/quic/core/tls_server_handshaker.h +++ b/quic/core/tls_server_handshaker.h
@@ -15,6 +15,7 @@ #include "quic/core/proto/cached_network_parameters_proto.h" #include "quic/core/quic_crypto_server_stream_base.h" #include "quic/core/quic_crypto_stream.h" +#include "quic/core/quic_types.h" #include "quic/core/tls_handshaker.h" #include "quic/platform/api/quic_export.h" @@ -25,6 +26,7 @@ class QUIC_EXPORT_PRIVATE TlsServerHandshaker : public TlsHandshaker, public TlsServerConnection::Delegate, + public ProofSourceHandleCallback, public QuicCryptoServerStreamBase { public: // |crypto_config| must outlive TlsServerHandshaker. @@ -79,6 +81,12 @@ const std::vector<uint8_t>& write_secret) override; protected: + // Creates a proof source handle for selecting cert and computing signature. + // Only called when |use_proof_source_handle_| is true. + virtual std::unique_ptr<ProofSourceHandle> MaybeCreateProofSourceHandle(); + + bool use_proof_source_handle() const { return use_proof_source_handle_; } + // Hook to allow the server to override parts of the QuicConfig based on SNI // before we generate transport parameters. virtual void OverrideQuicConfigDefaults(QuicConfig* config); @@ -136,6 +144,26 @@ absl::string_view in) override; TlsConnection::Delegate* ConnectionDelegate() override { return this; } + // The status of cert selection. Only called after cert selection started. + QuicAsyncStatus SelectCertStatus() const; + // Whether |cert_verify_sig_| contains a valid signature. + // NOTE: BoringSSL queries the result of a async signature operation using + // PrivateKeyComplete(), a successful PrivateKeyComplete() will clear the + // content of |cert_verify_sig_|, this function should not be called after + // that. + bool HasValidSignature(size_t max_signature_size) const; + + // ProofSourceHandleCallback implementation: + void OnSelectCertificateDone(bool ok, + bool is_sync, + const ProofSource::Chain* chain) override; + + void OnComputeSignatureDone( + bool ok, + bool is_sync, + std::string signature, + std::unique_ptr<ProofSource::Details> details) override; + private: class QUIC_EXPORT_PRIVATE SignatureCallback : public ProofSource::SignatureCallback { @@ -165,6 +193,81 @@ TlsServerHandshaker* handshaker_; }; + // DefaultProofSourceHandle delegates all operations to the shared proof + // source. + class QUIC_EXPORT_PRIVATE DefaultProofSourceHandle + : public ProofSourceHandle { + public: + DefaultProofSourceHandle(TlsServerHandshaker* handshaker, + ProofSource* proof_source); + + ~DefaultProofSourceHandle() override; + + // Cancel the pending signature operation, if any. + void CancelPendingOperation() override; + + // Delegates to proof_source_->GetCertChain. + // Returns QUIC_SUCCESS or QUIC_FAILURE. Never returns QUIC_PENDING. + QuicAsyncStatus SelectCertificate( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + absl::string_view client_hello, + const std::string& alpn, + const std::vector<uint8_t>& quic_transport_params, + const absl::optional<std::vector<uint8_t>>& early_data_context) + override; + + // Delegates to proof_source_->ComputeTlsSignature. + // Returns QUIC_SUCCESS, QUIC_FAILURE or QUIC_PENDING. + QuicAsyncStatus ComputeSignature(const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + uint16_t signature_algorithm, + absl::string_view in, + size_t max_signature_size) override; + + protected: + ProofSourceHandleCallback* callback() override { return handshaker_; } + + private: + class QUIC_EXPORT_PRIVATE DefaultSignatureCallback + : public ProofSource::SignatureCallback { + public: + explicit DefaultSignatureCallback(DefaultProofSourceHandle* handle) + : handle_(handle) {} + + void Run(bool ok, + std::string signature, + std::unique_ptr<ProofSource::Details> details) override { + if (handle_ == nullptr) { + // Operation has been canceled, or Run has been called. + return; + } + handle_->signature_callback_ = nullptr; + if (handle_->handshaker_ != nullptr) { + handle_->handshaker_->OnComputeSignatureDone( + ok, is_sync_, std::move(signature), std::move(details)); + } + } + + // If called, Cancel causes the pending callback to be a no-op. + void Cancel() { handle_ = nullptr; } + + void set_is_sync(bool is_sync) { is_sync_ = is_sync; } + + private: + DefaultProofSourceHandle* handle_; + // Set to false if handle_->ComputeSignature returns QUIC_PENDING. + bool is_sync_ = true; + }; + + // Not nullptr on construction. Set to nullptr when cancelled. + TlsServerHandshaker* handshaker_; // Not owned. + ProofSource* proof_source_; // Not owned. + DefaultSignatureCallback* signature_callback_ = nullptr; + }; + struct QUIC_NO_EXPORT SetTransportParametersResult { bool success = false; // Empty vector if QUIC transport params are not set successfully. @@ -178,6 +281,7 @@ bool ProcessTransportParameters(const SSL_CLIENT_HELLO* client_hello, std::string* error_details); + std::unique_ptr<ProofSourceHandle> proof_source_handle_; ProofSource* proof_source_; SignatureCallback* signature_callback_ = nullptr; @@ -195,6 +299,9 @@ // indicates that the client attempted a resumption. bool ticket_received_ = false; + // nullopt means select cert hasn't started. + absl::optional<QuicAsyncStatus> select_cert_status_; + std::string hostname_; std::string cert_verify_sig_; std::unique_ptr<ProofSource::Details> proof_source_details_; @@ -212,6 +319,9 @@ TlsServerConnection tls_connection_; const bool use_early_select_cert_ = GetQuicReloadableFlag(quic_tls_use_early_select_cert); + const bool use_proof_source_handle_ = + use_early_select_cert_ && + GetQuicReloadableFlag(quic_tls_use_per_handshaker_proof_source); const QuicCryptoServerConfig* crypto_config_; // Unowned. };
diff --git a/quic/core/tls_server_handshaker_test.cc b/quic/core/tls_server_handshaker_test.cc index 5271bb6..e650a4a 100644 --- a/quic/core/tls_server_handshaker_test.cc +++ b/quic/core/tls_server_handshaker_test.cc
@@ -15,12 +15,14 @@ #include "quic/core/quic_utils.h" #include "quic/core/quic_versions.h" #include "quic/core/tls_client_handshaker.h" +#include "quic/core/tls_server_handshaker.h" #include "quic/platform/api/quic_flags.h" #include "quic/platform/api/quic_logging.h" #include "quic/platform/api/quic_test.h" #include "quic/test_tools/crypto_test_utils.h" #include "quic/test_tools/failing_proof_source.h" #include "quic/test_tools/fake_proof_source.h" +#include "quic/test_tools/fake_proof_source_handle.h" #include "quic/test_tools/quic_test_utils.h" #include "quic/test_tools/simple_session_cache.h" #include "quic/test_tools/test_ticket_crypter.h" @@ -65,6 +67,68 @@ return params; } +class TestTlsServerHandshaker : public TlsServerHandshaker { + public: + TestTlsServerHandshaker(QuicSession* session, + const QuicCryptoServerConfig* crypto_config) + : TlsServerHandshaker(session, crypto_config), + proof_source_(crypto_config->proof_source()) { + ON_CALL(*this, MaybeCreateProofSourceHandle()) + .WillByDefault(testing::Invoke( + this, &TestTlsServerHandshaker::RealMaybeCreateProofSourceHandle)); + } + + MOCK_METHOD(std::unique_ptr<ProofSourceHandle>, + MaybeCreateProofSourceHandle, + (), + (override)); + + void SetupProofSourceHandle( + FakeProofSourceHandle::Action select_cert_action, + FakeProofSourceHandle::Action compute_signature_action) { + EXPECT_CALL(*this, MaybeCreateProofSourceHandle()) + .WillOnce(testing::Invoke( + [this, select_cert_action, compute_signature_action]() { + auto handle = std::make_unique<FakeProofSourceHandle>( + proof_source_, this, select_cert_action, + compute_signature_action); + fake_proof_source_handle_ = handle.get(); + return handle; + })); + } + + FakeProofSourceHandle* fake_proof_source_handle() { + return fake_proof_source_handle_; + } + + private: + std::unique_ptr<ProofSourceHandle> RealMaybeCreateProofSourceHandle() { + return TlsServerHandshaker::MaybeCreateProofSourceHandle(); + } + + // Owned by TlsServerHandshaker. + FakeProofSourceHandle* fake_proof_source_handle_ = nullptr; + ProofSource* proof_source_ = nullptr; +}; + +class TlsServerHandshakerTestSession : public TestQuicSpdyServerSession { + public: + using TestQuicSpdyServerSession::TestQuicSpdyServerSession; + + std::unique_ptr<QuicCryptoServerStreamBase> CreateQuicCryptoServerStream( + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* /*compressed_certs_cache*/) override { + if (connection()->version().handshake_protocol == PROTOCOL_TLS1_3) { + return std::make_unique<NiceMock<TestTlsServerHandshaker>>(this, + crypto_config); + } + + CHECK(false) << "Unsupported handshake protocol: " + << connection()->version().handshake_protocol; + return nullptr; + } +}; + class TlsServerHandshakerTest : public QuicTestWithParam<TestParams> { public: TlsServerHandshakerTest() @@ -109,6 +173,46 @@ std::make_unique<FailingProofSource>(), KeyExchangeSource::Default()); } + void CreateTlsServerHandshakerTestSession(MockQuicConnectionHelper* helper, + MockAlarmFactory* alarm_factory) { + server_connection_ = new PacketSavingConnection( + helper, alarm_factory, Perspective::IS_SERVER, + ParsedVersionOfIndex(supported_versions_, 0)); + + TlsServerHandshakerTestSession* server_session = + new TlsServerHandshakerTestSession( + server_connection_, DefaultQuicConfig(), supported_versions_, + server_crypto_config_.get(), &server_compressed_certs_cache_); + server_session->Initialize(); + + // We advance the clock initially because the default time is zero and the + // strike register worries that we've just overflowed a uint32_t time. + server_connection_->AdvanceTime(QuicTime::Delta::FromSeconds(100000)); + + CHECK(server_session); + server_session_.reset(server_session); + } + + void InitializeServerWithFakeProofSourceHandle() { + helpers_.push_back(std::make_unique<NiceMock<MockQuicConnectionHelper>>()); + alarm_factories_.push_back(std::make_unique<MockAlarmFactory>()); + CreateTlsServerHandshakerTestSession(helpers_.back().get(), + alarm_factories_.back().get()); + server_handshaker_ = static_cast<NiceMock<TestTlsServerHandshaker>*>( + server_session_->GetMutableCryptoStream()); + EXPECT_CALL(*server_session_->helper(), CanAcceptClientHello(_, _, _, _, _)) + .Times(testing::AnyNumber()); + EXPECT_CALL(*server_session_, SelectAlpn(_)) + .WillRepeatedly([this](const std::vector<absl::string_view>& alpns) { + return std::find( + alpns.cbegin(), alpns.cend(), + AlpnForVersion(server_session_->connection()->version())); + }); + crypto_test_utils::SetupCryptoServerConfigForTest( + server_connection_->clock(), server_connection_->random_generator(), + server_crypto_config_.get()); + } + // Initializes the crypto server stream state for testing. May be // called multiple times. void InitializeServer() { @@ -122,15 +226,15 @@ &server_connection_, &server_session); CHECK(server_session); server_session_.reset(server_session); + server_handshaker_ = nullptr; EXPECT_CALL(*server_session_->helper(), CanAcceptClientHello(_, _, _, _, _)) .Times(testing::AnyNumber()); EXPECT_CALL(*server_session_, SelectAlpn(_)) - .WillRepeatedly( - [this](const std::vector<absl::string_view>& alpns) { - return std::find( - alpns.cbegin(), alpns.cend(), - AlpnForVersion(server_session_->connection()->version())); - }); + .WillRepeatedly([this](const std::vector<absl::string_view>& alpns) { + return std::find( + alpns.cbegin(), alpns.cend(), + AlpnForVersion(server_session_->connection()->version())); + }); crypto_test_utils::SetupCryptoServerConfigForTest( server_connection_->clock(), server_connection_->random_generator(), server_crypto_config_.get()); @@ -231,8 +335,10 @@ // Server state. PacketSavingConnection* server_connection_; std::unique_ptr<TestQuicSpdyServerSession> server_session_; + // Only set when initialized with InitializeServerWithFakeProofSourceHandle. + NiceMock<TestTlsServerHandshaker>* server_handshaker_ = nullptr; TestTicketCrypter* ticket_crypter_; // owned by proof_source_ - FakeProofSource* proof_source_; // owned by server_crypto_config_ + FakeProofSource* proof_source_; // owned by server_crypto_config_ std::unique_ptr<QuicCryptoServerConfig> server_crypto_config_; QuicCompressedCertsCache server_compressed_certs_cache_; QuicServerId server_id_; @@ -267,7 +373,58 @@ ExpectHandshakeSuccessful(); } -TEST_P(TlsServerHandshakerTest, HandshakeWithAsyncProofSource) { +TEST_P(TlsServerHandshakerTest, HandshakeWithAsyncSelectCertSuccess) { + if (!(GetQuicReloadableFlag(quic_tls_use_early_select_cert) && + GetQuicReloadableFlag(quic_tls_use_per_handshaker_proof_source))) { + return; + } + + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_ASYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + + EXPECT_CALL(*client_connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_CALL(*server_connection_, CloseConnection(_, _, _)).Times(0); + + // Start handshake. + AdvanceHandshakeWithFakeClient(); + + ASSERT_TRUE( + server_handshaker_->fake_proof_source_handle()->HasPendingOperation()); + server_handshaker_->fake_proof_source_handle()->CompletePendingOperation(); + + CompleteCryptoHandshake(); + + ExpectHandshakeSuccessful(); +} + +TEST_P(TlsServerHandshakerTest, HandshakeWithAsyncSelectCertFailure) { + if (!(GetQuicReloadableFlag(quic_tls_use_early_select_cert) && + GetQuicReloadableFlag(quic_tls_use_per_handshaker_proof_source))) { + return; + } + + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::FAIL_ASYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + + // Start handshake. + AdvanceHandshakeWithFakeClient(); + + ASSERT_TRUE( + server_handshaker_->fake_proof_source_handle()->HasPendingOperation()); + server_handshaker_->fake_proof_source_handle()->CompletePendingOperation(); + + // Check that the server didn't send any handshake messages, because it failed + // to handshake. + EXPECT_EQ(moved_messages_counts_.second, 0u); +} + +TEST_P(TlsServerHandshakerTest, HandshakeWithAsyncSignature) { EXPECT_CALL(*client_connection_, CloseConnection(_, _, _)).Times(0); EXPECT_CALL(*server_connection_, CloseConnection(_, _, _)).Times(0); // Enable FakeProofSource to capture call to ComputeTlsSignature and run it @@ -285,7 +442,34 @@ ExpectHandshakeSuccessful(); } -TEST_P(TlsServerHandshakerTest, CancelPendingProofSource) { +TEST_P(TlsServerHandshakerTest, CancelPendingSelectCert) { + if (!(GetQuicReloadableFlag(quic_tls_use_early_select_cert) && + GetQuicReloadableFlag(quic_tls_use_per_handshaker_proof_source))) { + return; + } + + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_ASYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + + EXPECT_CALL(*client_connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_CALL(*server_connection_, CloseConnection(_, _, _)).Times(0); + + // Start handshake. + AdvanceHandshakeWithFakeClient(); + + ASSERT_TRUE( + server_handshaker_->fake_proof_source_handle()->HasPendingOperation()); + server_handshaker_->CancelOutstandingCallbacks(); + ASSERT_FALSE( + server_handshaker_->fake_proof_source_handle()->HasPendingOperation()); + // CompletePendingOperation should be noop. + server_handshaker_->fake_proof_source_handle()->CompletePendingOperation(); +} + +TEST_P(TlsServerHandshakerTest, CancelPendingSignature) { EXPECT_CALL(*client_connection_, CloseConnection(_, _, _)).Times(0); EXPECT_CALL(*server_connection_, CloseConnection(_, _, _)).Times(0); // Enable FakeProofSource to capture call to ComputeTlsSignature and run it
diff --git a/quic/test_tools/fake_proof_source_handle.cc b/quic/test_tools/fake_proof_source_handle.cc new file mode 100644 index 0000000..daa47b5 --- /dev/null +++ b/quic/test_tools/fake_proof_source_handle.cc
@@ -0,0 +1,222 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quic/test_tools/fake_proof_source_handle.h" +#include "quic/core/quic_types.h" +#include "quic/platform/api/quic_bug_tracker.h" + +namespace quic { +namespace test { +namespace { + +struct QUIC_EXPORT_PRIVATE ComputeSignatureResult { + bool ok; + std::string signature; + std::unique_ptr<ProofSource::Details> details; +}; + +class QUIC_EXPORT_PRIVATE ResultSavingSignatureCallback + : public ProofSource::SignatureCallback { + public: + explicit ResultSavingSignatureCallback( + absl::optional<ComputeSignatureResult>* result) + : result_(result) { + DCHECK(!result_->has_value()); + } + void Run(bool ok, + std::string signature, + std::unique_ptr<ProofSource::Details> details) override { + result_->emplace( + ComputeSignatureResult{ok, std::move(signature), std::move(details)}); + } + + private: + absl::optional<ComputeSignatureResult>* result_; +}; + +ComputeSignatureResult ComputeSignatureNow( + ProofSource* delegate, + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + uint16_t signature_algorithm, + absl::string_view in) { + absl::optional<ComputeSignatureResult> result; + delegate->ComputeTlsSignature( + server_address, client_address, hostname, signature_algorithm, in, + std::make_unique<ResultSavingSignatureCallback>(&result)); + CHECK(result.has_value()) << "delegate->ComputeTlsSignature must computes a " + "signature immediately"; + return std::move(result.value()); +} +} // namespace + +FakeProofSourceHandle::FakeProofSourceHandle( + ProofSource* delegate, + ProofSourceHandleCallback* callback, + Action select_cert_action, + Action compute_signature_action) + : delegate_(delegate), + callback_(callback), + select_cert_action_(select_cert_action), + compute_signature_action_(compute_signature_action) {} + +void FakeProofSourceHandle::CancelPendingOperation() { + select_cert_op_.reset(); + compute_signature_op_.reset(); +} + +QuicAsyncStatus FakeProofSourceHandle::SelectCertificate( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + absl::string_view client_hello, + const std::string& alpn, + const std::vector<uint8_t>& quic_transport_params, + const absl::optional<std::vector<uint8_t>>& early_data_context) { + if (select_cert_action_ == Action::DELEGATE_ASYNC || + select_cert_action_ == Action::FAIL_ASYNC) { + select_cert_op_.emplace(delegate_, callback_, select_cert_action_, + server_address, client_address, hostname, + client_hello, alpn, quic_transport_params, + early_data_context); + return QUIC_PENDING; + } else if (select_cert_action_ == Action::FAIL_SYNC) { + callback()->OnSelectCertificateDone(/*ok=*/false, + /*is_sync=*/true, nullptr); + return QUIC_FAILURE; + } + + DCHECK(select_cert_action_ == Action::DELEGATE_SYNC); + QuicReferenceCountedPointer<ProofSource::Chain> chain = + delegate_->GetCertChain(server_address, client_address, hostname); + + bool ok = chain && !chain->certs.empty(); + callback_->OnSelectCertificateDone(ok, /*is_sync=*/true, chain.get()); + return ok ? QUIC_SUCCESS : QUIC_FAILURE; +} + +QuicAsyncStatus FakeProofSourceHandle::ComputeSignature( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + uint16_t signature_algorithm, + absl::string_view in, + size_t max_signature_size) { + if (compute_signature_action_ == Action::DELEGATE_ASYNC || + compute_signature_action_ == Action::FAIL_ASYNC) { + compute_signature_op_.emplace( + delegate_, callback_, compute_signature_action_, server_address, + client_address, hostname, signature_algorithm, in, max_signature_size); + return QUIC_PENDING; + } else if (compute_signature_action_ == Action::FAIL_SYNC) { + callback()->OnComputeSignatureDone(/*ok=*/false, /*is_sync=*/true, + /*signature=*/"", /*details=*/nullptr); + return QUIC_FAILURE; + } + + DCHECK(compute_signature_action_ == Action::DELEGATE_SYNC); + ComputeSignatureResult result = + ComputeSignatureNow(delegate_, server_address, client_address, hostname, + signature_algorithm, in); + callback_->OnComputeSignatureDone( + result.ok, /*is_sync=*/true, result.signature, std::move(result.details)); + return result.ok ? QUIC_SUCCESS : QUIC_FAILURE; +} + +ProofSourceHandleCallback* FakeProofSourceHandle::callback() { + return callback_; +} + +bool FakeProofSourceHandle::HasPendingOperation() const { + int num_pending_operations = NumPendingOperations(); + return num_pending_operations > 0; +} + +void FakeProofSourceHandle::CompletePendingOperation() { + DCHECK_LE(NumPendingOperations(), 1); + + if (select_cert_op_.has_value()) { + select_cert_op_->Run(); + } else if (compute_signature_op_.has_value()) { + compute_signature_op_->Run(); + } +} + +int FakeProofSourceHandle::NumPendingOperations() const { + return static_cast<int>(select_cert_op_.has_value()) + + static_cast<int>(compute_signature_op_.has_value()); +} + +FakeProofSourceHandle::SelectCertOperation::SelectCertOperation( + ProofSource* delegate, + ProofSourceHandleCallback* callback, + Action action, + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + absl::string_view client_hello, + const std::string& alpn, + const std::vector<uint8_t>& quic_transport_params, + const absl::optional<std::vector<uint8_t>>& early_data_context) + : PendingOperation(delegate, callback, action), + server_address_(server_address), + client_address_(client_address), + hostname_(hostname), + client_hello_(client_hello), + alpn_(alpn), + quic_transport_params_(quic_transport_params), + early_data_context_(early_data_context) {} + +void FakeProofSourceHandle::SelectCertOperation::Run() { + if (action_ == Action::FAIL_ASYNC) { + callback_->OnSelectCertificateDone(/*ok=*/false, + /*is_sync=*/false, nullptr); + } else if (action_ == Action::DELEGATE_ASYNC) { + QuicReferenceCountedPointer<ProofSource::Chain> chain = + delegate_->GetCertChain(server_address_, client_address_, hostname_); + bool ok = chain && !chain->certs.empty(); + callback_->OnSelectCertificateDone(ok, /*is_sync=*/false, chain.get()); + } else { + QUIC_BUG << "Unexpected action: " << static_cast<int>(action_); + } +} + +FakeProofSourceHandle::ComputeSignatureOperation::ComputeSignatureOperation( + ProofSource* delegate, + ProofSourceHandleCallback* callback, + Action action, + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + uint16_t signature_algorithm, + absl::string_view in, + size_t max_signature_size) + : PendingOperation(delegate, callback, action), + server_address_(server_address), + client_address_(client_address), + hostname_(hostname), + signature_algorithm_(signature_algorithm), + in_(in), + max_signature_size_(max_signature_size) {} + +void FakeProofSourceHandle::ComputeSignatureOperation::Run() { + if (action_ == Action::FAIL_ASYNC) { + callback_->OnComputeSignatureDone( + /*ok=*/false, /*is_sync=*/false, + /*signature=*/"", /*details=*/nullptr); + } else if (action_ == Action::DELEGATE_ASYNC) { + ComputeSignatureResult result = + ComputeSignatureNow(delegate_, server_address_, client_address_, + hostname_, signature_algorithm_, in_); + callback_->OnComputeSignatureDone(result.ok, /*is_sync=*/false, + result.signature, + std::move(result.details)); + } else { + QUIC_BUG << "Unexpected action: " << static_cast<int>(action_); + } +} + +} // namespace test +} // namespace quic
diff --git a/quic/test_tools/fake_proof_source_handle.h b/quic/test_tools/fake_proof_source_handle.h new file mode 100644 index 0000000..e013b90 --- /dev/null +++ b/quic/test_tools/fake_proof_source_handle.h
@@ -0,0 +1,147 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_FAKE_PROOF_SOURCE_HANDLE_H_ +#define QUICHE_QUIC_TEST_TOOLS_FAKE_PROOF_SOURCE_HANDLE_H_ + +#include "quic/core/crypto/proof_source.h" + +namespace quic { +namespace test { + +// FakeProofSourceHandle allows its behavior to be scripted for testing. +class FakeProofSourceHandle : public ProofSourceHandle { + public: + // What would an operation return when it is called. + enum class Action { + // Delegate the operation to |delegate_| immediately. + DELEGATE_SYNC = 0, + // Handle the operation asynchronously. Delegate the operation to + // |delegate_| when the caller calls CompletePendingOperation(). + DELEGATE_ASYNC, + // Fail the operation immediately. + FAIL_SYNC, + // Handle the operation asynchronously. Fail the operation when the caller + // calls CompletePendingOperation(). + FAIL_ASYNC, + }; + + // |delegate| must do cert selection and signature synchronously. + FakeProofSourceHandle(ProofSource* delegate, + ProofSourceHandleCallback* callback, + Action select_cert_action, + Action compute_signature_action); + + ~FakeProofSourceHandle() override = default; + + void CancelPendingOperation() override; + + QuicAsyncStatus SelectCertificate( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + absl::string_view client_hello, + const std::string& alpn, + const std::vector<uint8_t>& quic_transport_params, + const absl::optional<std::vector<uint8_t>>& early_data_context) override; + + QuicAsyncStatus ComputeSignature(const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + uint16_t signature_algorithm, + absl::string_view in, + size_t max_signature_size) override; + + ProofSourceHandleCallback* callback() override; + + // Whether there's a pending operation in |this|. + bool HasPendingOperation() const; + void CompletePendingOperation(); + + private: + class PendingOperation { + public: + PendingOperation(ProofSource* delegate, + ProofSourceHandleCallback* callback, + Action action) + : delegate_(delegate), callback_(callback), action_(action) {} + virtual ~PendingOperation() = default; + virtual void Run() = 0; + + protected: + ProofSource* delegate_; + ProofSourceHandleCallback* callback_; + Action action_; + }; + + class SelectCertOperation : public PendingOperation { + public: + SelectCertOperation( + ProofSource* delegate, + ProofSourceHandleCallback* callback, + Action action, + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + absl::string_view client_hello, + const std::string& alpn, + const std::vector<uint8_t>& quic_transport_params, + const absl::optional<std::vector<uint8_t>>& early_data_context); + + ~SelectCertOperation() override = default; + + void Run() override; + + private: + QuicSocketAddress server_address_; + QuicSocketAddress client_address_; + std::string hostname_; + std::string client_hello_; + std::string alpn_; + std::vector<uint8_t> quic_transport_params_; + absl::optional<std::vector<uint8_t>> early_data_context_; + }; + + class ComputeSignatureOperation : public PendingOperation { + public: + ComputeSignatureOperation(ProofSource* delegate, + ProofSourceHandleCallback* callback, + Action action, + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + uint16_t signature_algorithm, + absl::string_view in, + size_t max_signature_size); + + ~ComputeSignatureOperation() override = default; + + void Run() override; + + private: + QuicSocketAddress server_address_; + QuicSocketAddress client_address_; + std::string hostname_; + uint16_t signature_algorithm_; + std::string in_; + size_t max_signature_size_; + }; + + private: + int NumPendingOperations() const; + + ProofSource* delegate_; + ProofSourceHandleCallback* callback_; + // Action for the next select cert operation. + Action select_cert_action_ = Action::DELEGATE_SYNC; + // Action for the next compute signature operation. + Action compute_signature_action_ = Action::DELEGATE_SYNC; + absl::optional<SelectCertOperation> select_cert_op_; + absl::optional<ComputeSignatureOperation> compute_signature_op_; +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_FAKE_PROOF_SOURCE_HANDLE_H_