Minor tweaks to TlsServerHandshaker:
- Change TlsServerHandshaker::SelectCertStatus() to TlsServerHandshaker::select_cert_status().
- Set expected_ssl_error to SSL_ERROR_WANT_READ in OnSelectCertificateDone().
- Add a test for async select cert and async signature.

Protected by FLAGS_quic_tls_use_per_handshaker_proof_source.

PiperOrigin-RevId: 351864828
Change-Id: I650c786a0a74bba0df1063be525a028485f5d0dc
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index 8fc52f1..2b14657 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -62,7 +62,13 @@
 
   handshaker_->OnSelectCertificateDone(
       /*ok=*/true, /*is_sync=*/true, chain.get());
-  return handshaker_->SelectCertStatus();
+  if (!handshaker_->select_cert_status().has_value()) {
+    QUIC_BUG
+        << "select_cert_status() has no value after a synchronous select cert";
+    // Return success to continue the handshake.
+    return QUIC_SUCCESS;
+  }
+  return handshaker_->select_cert_status().value();
 }
 
 QuicAsyncStatus TlsServerHandshaker::DefaultProofSourceHandle::ComputeSignature(
@@ -567,6 +573,7 @@
     size_t max_out,
     uint16_t sig_alg,
     absl::string_view in) {
+  DCHECK_EQ(expected_ssl_error(), SSL_ERROR_WANT_READ);
   if (use_proof_source_handle_) {
     QUIC_RELOADABLE_FLAG_COUNT_N(quic_tls_use_per_handshaker_proof_source, 2,
                                  3);
@@ -774,7 +781,7 @@
         set_transport_params_result.quic_transport_params,
         set_transport_params_result.early_data_context);
 
-    DCHECK_EQ(status, SelectCertStatus());
+    DCHECK_EQ(status, select_cert_status().value());
 
     if (status == QUIC_PENDING) {
       set_expected_ssl_error(SSL_ERROR_PENDING_CERTIFICATE);
@@ -835,21 +842,14 @@
       QUIC_LOG(ERROR) << "No certs provided for host '" << hostname_ << "'";
     }
   }
-
+  const int last_expected_ssl_error = expected_ssl_error();
+  set_expected_ssl_error(SSL_ERROR_WANT_READ);
   if (!is_sync) {
+    DCHECK_EQ(last_expected_ssl_error, SSL_ERROR_PENDING_CERTIFICATE);
     AdvanceHandshakeFromCallback();
   }
 }
 
-QuicAsyncStatus TlsServerHandshaker::SelectCertStatus() const {
-  if (!select_cert_status_.has_value()) {
-    QUIC_BUG << "SelectCertStatus should be called after select cert started";
-    return QUIC_PENDING;
-  }
-
-  return select_cert_status_.value();
-}
-
 bool TlsServerHandshaker::ValidateHostname(const std::string& hostname) const {
   if (!QuicHostnameUtils::IsValidSNI(hostname)) {
     // TODO(b/151676147): Include this error string in the CONNECTION_CLOSE
diff --git a/quic/core/tls_server_handshaker.h b/quic/core/tls_server_handshaker.h
index 9b1cdae..2a09b50 100644
--- a/quic/core/tls_server_handshaker.h
+++ b/quic/core/tls_server_handshaker.h
@@ -144,8 +144,10 @@
                                              absl::string_view in) override;
   TlsConnection::Delegate* ConnectionDelegate() override { return this; }
 
-  // The status of cert selection. Only called after cert selection started.
-  QuicAsyncStatus SelectCertStatus() const;
+  // The status of cert selection. nullopt means it hasn't started.
+  const absl::optional<QuicAsyncStatus>& select_cert_status() const {
+    return select_cert_status_;
+  }
   // Whether |cert_verify_sig_| contains a valid signature.
   // NOTE: BoringSSL queries the result of a async signature operation using
   // PrivateKeyComplete(), a successful PrivateKeyComplete() will clear the
diff --git a/quic/core/tls_server_handshaker_test.cc b/quic/core/tls_server_handshaker_test.cc
index e650a4a..c0b1337 100644
--- a/quic/core/tls_server_handshaker_test.cc
+++ b/quic/core/tls_server_handshaker_test.cc
@@ -101,6 +101,8 @@
     return fake_proof_source_handle_;
   }
 
+  using TlsServerHandshaker::expected_ssl_error;
+
  private:
   std::unique_ptr<ProofSourceHandle> RealMaybeCreateProofSourceHandle() {
     return TlsServerHandshaker::MaybeCreateProofSourceHandle();
@@ -424,6 +426,47 @@
   EXPECT_EQ(moved_messages_counts_.second, 0u);
 }
 
+TEST_P(TlsServerHandshakerTest, HandshakeWithAsyncSelectCertAndSignature) {
+  if (!(GetQuicReloadableFlag(quic_tls_use_early_select_cert) &&
+        GetQuicReloadableFlag(quic_tls_use_per_handshaker_proof_source))) {
+    return;
+  }
+
+  InitializeServerWithFakeProofSourceHandle();
+  server_handshaker_->SetupProofSourceHandle(
+      /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_ASYNC,
+      /*compute_signature_action=*/FakeProofSourceHandle::Action::
+          DELEGATE_ASYNC);
+
+  EXPECT_CALL(*client_connection_, CloseConnection(_, _, _)).Times(0);
+  EXPECT_CALL(*server_connection_, CloseConnection(_, _, _)).Times(0);
+
+  // Start handshake.
+  AdvanceHandshakeWithFakeClient();
+
+  // A select cert operation is now pending.
+  ASSERT_TRUE(
+      server_handshaker_->fake_proof_source_handle()->HasPendingOperation());
+  EXPECT_EQ(server_handshaker_->expected_ssl_error(),
+            SSL_ERROR_PENDING_CERTIFICATE);
+
+  // Complete the pending select cert. It should advance the handshake to
+  // compute a signature, which will also be saved as a pending operation.
+  server_handshaker_->fake_proof_source_handle()->CompletePendingOperation();
+
+  // A compute signature operation is now pending.
+  ASSERT_TRUE(
+      server_handshaker_->fake_proof_source_handle()->HasPendingOperation());
+  EXPECT_EQ(server_handshaker_->expected_ssl_error(),
+            SSL_ERROR_WANT_PRIVATE_KEY_OPERATION);
+
+  server_handshaker_->fake_proof_source_handle()->CompletePendingOperation();
+
+  CompleteCryptoHandshake();
+
+  ExpectHandshakeSuccessful();
+}
+
 TEST_P(TlsServerHandshakerTest, HandshakeWithAsyncSignature) {
   EXPECT_CALL(*client_connection_, CloseConnection(_, _, _)).Times(0);
   EXPECT_CALL(*server_connection_, CloseConnection(_, _, _)).Times(0);
diff --git a/quic/test_tools/fake_proof_source_handle.cc b/quic/test_tools/fake_proof_source_handle.cc
index daa47b5..437a913 100644
--- a/quic/test_tools/fake_proof_source_handle.cc
+++ b/quic/test_tools/fake_proof_source_handle.cc
@@ -139,8 +139,10 @@
 
   if (select_cert_op_.has_value()) {
     select_cert_op_->Run();
+    select_cert_op_.reset();
   } else if (compute_signature_op_.has_value()) {
     compute_signature_op_->Run();
+    compute_signature_op_.reset();
   }
 }