In TlsHandshaker::AdvanceHandshake, retry SSL_do_handshake once if it "succeeded" when entering early data.

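When attempting 0-RTT, SSL_do_handshake can return 1 as soon as the
connection enters early data, possibly before BoringSSL has processed a
ServerHello that was already provided to it (for example, when certificate
verification completes asynchronously). In that case the client does not
derive its HANDSHAKE keys (b/186438140). Retrying SSL_do_handshake once lets
BoringSSL process the pending ServerHello and advance the handshake.

Protected by FLAGS_quic_reloadable_flag_quic_tls_retry_handshake_on_early_data.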

This commit is a cherry-pick of 48b1681499b66762eb2ed7d9f68ccab668b0affa.
Its parent is the commit pinned in Chrome M91's DEPS file.
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h
index 9443349..910db54 100644
--- a/quic/core/quic_flags_list.h
+++ b/quic/core/quic_flags_list.h
@@ -64,6 +64,7 @@
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_testonly_default_true, true)
+QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_tls_retry_handshake_on_early_data, true)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_tls_use_normalized_sni_for_cert_selectioon, true)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_tls_use_per_handshaker_proof_source, true)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_unified_iw_options, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_unify_stop_sending, true)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_use_encryption_level_context, true)
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index b157500..f8e3216 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -446,26 +446,23 @@
 }

 void TlsClientHandshaker::FinishHandshake() {
-  // Fill crypto_negotiated_params_:
-  const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl());
-  if (cipher) {
-    crypto_negotiated_params_->cipher_suite =
-        SSL_CIPHER_get_protocol_id(cipher);
-  }
-  crypto_negotiated_params_->key_exchange_group = SSL_get_curve_id(ssl());
-  crypto_negotiated_params_->peer_signature_algorithm =
-      SSL_get_peer_signature_algorithm(ssl());
-  if (SSL_in_early_data(ssl())) {
-    // SSL_do_handshake returns after sending the ClientHello if the session is
-    // 0-RTT-capable, which means that FinishHandshake will get called twice -
-    // the first time after sending the ClientHello, and the second time after
-    // the handshake is complete. If we're in the first time FinishHandshake is
-    // called, we can't do any end-of-handshake processing.
+  FillNegotiatedParams();

-    // If we're attempting a 0-RTT handshake, then we need to let the transport
-    // and application know what state to apply to early data.
-    PrepareZeroRttConfig(cached_state_.get());
-    return;
+  if (retry_handshake_on_early_data_) {
+    QUICHE_CHECK(!SSL_in_early_data(ssl()));
+  } else {
+    if (SSL_in_early_data(ssl())) {
+      // SSL_do_handshake returns after sending the ClientHello if the session
+      // is 0-RTT-capable, which means that FinishHandshake will get called
+      // twice - the first time after sending the ClientHello, and the second
+      // time after the handshake is complete. If we're in the first time
+      // FinishHandshake is called, we can't do any end-of-handshake processing.
+
+      // If we're attempting a 0-RTT handshake, then we need to let the
+      // transport and application know what state to apply to early data.
+      PrepareZeroRttConfig(cached_state_.get());
+      return;
+    }
   }
   QUIC_LOG(INFO) << "Client: handshake finished";

@@ -526,6 +523,30 @@
   handshaker_delegate()->OnTlsHandshakeComplete();
 }

+void TlsClientHandshaker::OnEnterEarlyData() {
+  QUICHE_DCHECK(retry_handshake_on_early_data_);
+  QUICHE_DCHECK(SSL_in_early_data(ssl()));
+
+  // TODO(wub): Filling the negotiated params here may be unnecessary, since
+  // FinishHandshake() fills them again once the handshake completes.
+  FillNegotiatedParams();
+
+  // If we're attempting a 0-RTT handshake, then we need to let the transport
+  // and application know what state to apply to early data.
+  PrepareZeroRttConfig(cached_state_.get());
+}
+
+void TlsClientHandshaker::FillNegotiatedParams() {
+  const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl());
+  if (cipher) {
+    crypto_negotiated_params_->cipher_suite =
+        SSL_CIPHER_get_protocol_id(cipher);
+  }
+  crypto_negotiated_params_->key_exchange_group = SSL_get_curve_id(ssl());
+  crypto_negotiated_params_->peer_signature_algorithm =
+      SSL_get_peer_signature_algorithm(ssl());
+}
+
 void TlsClientHandshaker::ProcessPostHandshakeMessage() {
   int rv = SSL_process_quic_post_handshake(ssl());
   if (rv != 1) {
diff --git a/quic/core/tls_client_handshaker.h b/quic/core/tls_client_handshaker.h
index 42d587f..8fb6da6 100644
--- a/quic/core/tls_client_handshaker.h
+++ b/quic/core/tls_client_handshaker.h
@@ -90,6 +90,10 @@
   }

   void FinishHandshake() override;
+  void OnEnterEarlyData() override;
+  // Records the negotiated cipher suite (if any), key exchange group and peer
+  // signature algorithm of ssl() in |crypto_negotiated_params_|.
+  void FillNegotiatedParams();
   void ProcessPostHandshakeMessage() override;
   bool ShouldCloseConnectionOnUnexpectedError(int ssl_error) override;
   QuicAsyncStatus VerifyCertChain(
diff --git a/quic/core/tls_client_handshaker_test.cc b/quic/core/tls_client_handshaker_test.cc
index 807f49e..e498acb 100644
--- a/quic/core/tls_client_handshaker_test.cc
+++ b/quic/core/tls_client_handshaker_test.cc
@@ -423,6 +423,62 @@
   EXPECT_EQ(stream()->EarlyDataReason(), ssl_early_data_accepted);
 }

+// Regression test for b/186438140.
+TEST_P(TlsClientHandshakerTest, ZeroRttResumptionWithAsyncProofVerifier) {
+  // Finish establishing the first connection, so the second connection can
+  // resume.
+  CompleteCryptoHandshake();
+
+  EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol());
+  EXPECT_TRUE(stream()->encryption_established());
+  EXPECT_TRUE(stream()->one_rtt_keys_available());
+  EXPECT_FALSE(stream()->IsResumption());
+
+  // Create a second connection.
+  CreateConnection();
+  InitializeFakeServer();
+  EXPECT_CALL(*session_, OnConfigNegotiated());
+  EXPECT_CALL(*connection_, SendCryptoData(_, _, _))
+      .Times(testing::AnyNumber());
+  // Enable TestProofVerifier to capture the call to VerifyCertChain and run it
+  // asynchronously.
+  TestProofVerifier* proof_verifier =
+      static_cast<TestProofVerifier*>(crypto_config_->proof_verifier());
+  proof_verifier->Activate();
+  // Start the second handshake.
+  stream()->CryptoConnect();
+
+  ASSERT_EQ(proof_verifier->NumPendingCallbacks(), 1u);
+
+  // Advance the handshake with the server. Since cert verification has not
+  // finished yet, client cannot derive HANDSHAKE and 1-RTT keys.
+  crypto_test_utils::AdvanceHandshake(connection_, stream(), 0,
+                                      server_connection_, server_stream(), 0);
+
+  EXPECT_FALSE(stream()->one_rtt_keys_available());
+  EXPECT_FALSE(server_stream()->one_rtt_keys_available());
+
+  // Finish cert verification after receiving packets from server.
+  proof_verifier->InvokePendingCallback(0);
+
+  QuicFramer* framer = QuicConnectionPeer::GetFramer(connection_);
+  if (!GetQuicReloadableFlag(quic_tls_retry_handshake_on_early_data)) {
+    // Client does not have HANDSHAKE key due to b/186438140.
+    EXPECT_EQ(nullptr,
+              QuicFramerPeer::GetEncrypter(framer, ENCRYPTION_HANDSHAKE));
+    return;
+  }
+
+  // Verify client has derived HANDSHAKE key.
+  EXPECT_NE(nullptr,
+            QuicFramerPeer::GetEncrypter(framer, ENCRYPTION_HANDSHAKE));
+
+  // Ideally, we should also verify that the process_undecryptable_packets_alarm
+  // is set and processing the undecryptable packets can advance the handshake
+  // to completion. Unfortunately, the test facilities used in this test do
+  // not support queuing and processing undecryptable packets.
+}
+
 TEST_P(TlsClientHandshakerTest, ZeroRttRejection) {
   // Finish establishing the first connection:
   CompleteCryptoHandshake();
diff --git a/quic/core/tls_handshaker.cc b/quic/core/tls_handshaker.cc
index 56110c7..bc8e0b4 100644
--- a/quic/core/tls_handshaker.cc
+++ b/quic/core/tls_handshaker.cc
@@ -12,9 +12,12 @@
 #include "quic/core/quic_crypto_stream.h"
 #include "quic/core/tls_client_handshaker.h"
 #include "quic/platform/api/quic_bug_tracker.h"
+#include "quic/platform/api/quic_stack_trace.h"

 namespace quic {

+#define ENDPOINT (SSL_is_server(ssl()) ? "TlsServer: " : "TlsClient: ")
+
 TlsHandshaker::ProofVerifierCallbackImpl::ProofVerifierCallbackImpl(
     TlsHandshaker* parent)
     : parent_(parent) {}
@@ -93,8 +96,38 @@
     return;
   }

-  QUIC_VLOG(1) << "TlsHandshaker: continuing handshake";
+  QUIC_VLOG(1) << ENDPOINT << "Continuing handshake";
   int rv = SSL_do_handshake(ssl());
+
+  // If SSL_do_handshake returns success (1) while we are in early data, it is
+  // possible that a ServerHello has been provided to BoringSSL but not yet
+  // processed. Retrying SSL_do_handshake once advances the handshake further
+  // in that case. If there is no unprocessed ServerHello, the retry returns a
+  // non-positive number.
+  if (retry_handshake_on_early_data_ && rv == 1 && SSL_in_early_data(ssl())) {
+    OnEnterEarlyData();
+    rv = SSL_do_handshake(ssl());
+    QUIC_VLOG(1) << ENDPOINT
+                 << "SSL_do_handshake returned when entering early data. After "
+                 << "retry, rv=" << rv
+                 << ", SSL_in_early_data=" << SSL_in_early_data(ssl());
+    // The retry should either return <= 0 if the handshake is still pending
+    // (likely still in early data), or return 1 if the handshake has actually
+    // finished (SSL_in_early_data is false). Never fall through to
+    // FinishHandshake() with rv == 1 while still in early data.
+    if (rv == 1 && SSL_in_early_data(ssl())) {
+      if (!is_connection_closed_) {
+        QUIC_BUG(quic_handshaker_stay_in_early_data)
+            << "The original and the retry of SSL_do_handshake both returned "
+               "success and in early data";
+        CloseConnection(
+            QUIC_HANDSHAKE_FAILED,
+            "TLS handshake failed: Still in early data after retry");
+      }
+      return;
+    }
+  }
+
   if (rv == 1) {
     FinishHandshake();
     return;
@@ -200,7 +233,7 @@
 void TlsHandshaker::SetWriteSecret(EncryptionLevel level,
                                    const SSL_CIPHER* cipher,
                                    const std::vector<uint8_t>& write_secret) {
-  QUIC_DVLOG(1) << "SetWriteSecret level=" << level;
+  QUIC_DVLOG(1) << ENDPOINT << "SetWriteSecret level=" << level;
   std::unique_ptr<QuicEncrypter> encrypter =
       QuicEncrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
   const EVP_MD* prf = Prf(cipher);
@@ -223,7 +256,7 @@
 bool TlsHandshaker::SetReadSecret(EncryptionLevel level,
                                   const SSL_CIPHER* cipher,
                                   const std::vector<uint8_t>& read_secret) {
-  QUIC_DVLOG(1) << "SetReadSecret level=" << level;
+  QUIC_DVLOG(1) << ENDPOINT << "SetReadSecret level=" << level;
   std::unique_ptr<QuicDecrypter> decrypter =
       QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
   const EVP_MD* prf = Prf(cipher);
diff --git a/quic/core/tls_handshaker.h b/quic/core/tls_handshaker.h
index 4f8bf47..6bf48b5 100644
--- a/quic/core/tls_handshaker.h
+++ b/quic/core/tls_handshaker.h
@@ -73,8 +73,20 @@
   // finished. Note that due to 0-RTT, the handshake may "finish" twice;
   // |SSL_in_early_data| can be used to determine whether the handshake is truly
   // done.
+  // TODO(wub): When --quic_tls_retry_handshake_on_early_data is true, this
+  // function will only be called once when the handshake actually finishes.
+  // Update comment when deprecating the flag.
   virtual void FinishHandshake() = 0;

+  // Called by |AdvanceHandshake| when |retry_handshake_on_early_data_| is true
+  // and |SSL_do_handshake| returns 1 while the connection is in early data.
+  // |AdvanceHandshake| then retries |SSL_do_handshake| once.
+  virtual void OnEnterEarlyData() {
+    // By default, do nothing but check the preconditions.
+    QUICHE_DCHECK(retry_handshake_on_early_data_);
+    QUICHE_DCHECK(SSL_in_early_data(ssl()));
+  }
+
   // Called when a handshake message is received after the handshake is
   // complete.
   virtual void ProcessPostHandshakeMessage() = 0;
@@ -155,6 +167,11 @@
   // error code corresponding to the TLS alert description |desc|.
   void SendAlert(EncryptionLevel level, uint8_t desc) override;

+  // Latched value of --quic_tls_retry_handshake_on_early_data, read once when
+  // the handshaker is constructed.
+  const bool retry_handshake_on_early_data_ =
+      GetQuicReloadableFlag(quic_tls_retry_handshake_on_early_data);
+
  private:
   // ProofVerifierCallbackImpl handles the result of an asynchronous certificate
   // verification operation.
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index cc536e7..b628db1 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -540,15 +540,21 @@
 }

 void TlsServerHandshaker::FinishHandshake() {
-  if (SSL_in_early_data(ssl())) {
-    // If the server accepts early data, SSL_do_handshake returns success twice:
-    // once after processing the ClientHello and sending the server's first
-    // flight, and then again after the handshake is complete. This results in
-    // FinishHandshake getting called twice. On the first call to
-    // FinishHandshake, we don't have any confirmation that the client is live,
-    // so all end of handshake processing is deferred until the handshake is
-    // actually complete.
-    return;
+  if (retry_handshake_on_early_data_) {
+    // AdvanceHandshake() retries SSL_do_handshake when it returns in early
+    // data, so this should only run once the handshake has actually finished.
+    QUICHE_DCHECK(!SSL_in_early_data(ssl()));
+  } else {
+    if (SSL_in_early_data(ssl())) {
+      // If the server accepts early data, SSL_do_handshake returns success
+      // twice: once after processing the ClientHello and sending the server's
+      // first flight, and then again after the handshake is complete. This
+      // results in FinishHandshake getting called twice. On the first call to
+      // FinishHandshake, we don't have any confirmation that the client is
+      // live, so all end of handshake processing is deferred until the
+      // handshake is actually complete.
+      return;
+    }
   }
   if (!valid_alpn_received_) {
     QUIC_DLOG(ERROR)