Handle 0-RTT in TlsClientHandshaker

TLS-in-QUIC 0-RTT change, protected by the disabled flag quic_enable_zero_rtt_for_tls.

PiperOrigin-RevId: 312540775
Change-Id: I3d26ee14db86a7b81d0886f9951c41acb2d469b1
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index 150c997..b3cfc0a 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -67,6 +67,7 @@
       pre_shared_key_(crypto_config->pre_shared_key()),
       crypto_negotiated_params_(new QuicCryptoNegotiatedParameters),
       has_application_state_(has_application_state),
+      attempting_zero_rtt_(crypto_config->early_data_enabled_for_tls()),
       tls_connection_(crypto_config->ssl_ctx(), this) {}
 
 TlsClientHandshaker::~TlsClientHandshaker() {
@@ -114,16 +115,16 @@
   }
 
   // Set a session to resume, if there is one.
+  std::unique_ptr<QuicResumptionState> cached_state;
   if (session_cache_) {
-    std::unique_ptr<QuicResumptionState> cached_state =
-        session_cache_->Lookup(server_id_, SSL_get_SSL_CTX(ssl()));
-    if (cached_state) {
-      SSL_set_session(ssl(), cached_state->tls_session.get());
-      if (GetQuicReloadableFlag(quic_enable_zero_rtt_for_tls) &&
-          SSL_SESSION_early_data_capable(cached_state->tls_session.get())) {
-        if (!PrepareZeroRttConfig(cached_state.get())) {
-          return false;
-        }
+    cached_state = session_cache_->Lookup(server_id_, SSL_get_SSL_CTX(ssl()));
+  }
+  if (cached_state) {
+    SSL_set_session(ssl(), cached_state->tls_session.get());
+    if (attempting_zero_rtt_ &&
+        SSL_SESSION_early_data_capable(cached_state->tls_session.get())) {
+      if (!PrepareZeroRttConfig(cached_state.get())) {
+        return false;
       }
     }
   }
@@ -427,6 +428,10 @@
   }
   int ssl_error = SSL_get_error(ssl(), rv);
   bool should_close = true;
+  if (ssl_error == SSL_ERROR_EARLY_DATA_REJECTED) {
+    HandleZeroRttReject();
+    return;
+  }
   switch (state_) {
     // TODO(b/153726130): handle the case where the server rejects early data.
     case STATE_HANDSHAKE_RUNNING:
@@ -455,6 +460,15 @@
 }
 
 void TlsClientHandshaker::FinishHandshake() {
+  if (SSL_in_early_data(ssl())) {
+    // SSL_do_handshake returns after sending the ClientHello if the session is
+    // 0-RTT-capable, which means that FinishHandshake will get called twice -
+    // the first time after sending the ClientHello, and the second time after
+    // the handshake is complete. If we're in the first time FinishHandshake is
+    // called, we can't do any end-of-handshake processing, so we return early
+    // from this function.
+    return;
+  }
   QUIC_LOG(INFO) << "Client: handshake finished";
   state_ = STATE_HANDSHAKE_COMPLETE;
   // Fill crypto_negotiated_params_:
@@ -504,6 +518,13 @@
   handshaker_delegate()->OnOneRttKeysAvailable();
 }
 
+void TlsClientHandshaker::HandleZeroRttReject() {
+  QUIC_LOG(INFO) << "0-RTT handshake attempted but was rejected by the server";
+  handshaker_delegate()->OnZeroRttRejected();  // Tell the delegate first.
+  SSL_reset_early_data_reject(ssl());  // Discard rejected 0-RTT state.
+  AdvanceHandshake();  // Continue the handshake after the reset.
+}
+
 enum ssl_verify_result_t TlsClientHandshaker::VerifyCert(uint8_t* out_alert) {
   if (verify_result_ != ssl_verify_retry ||
       state_ == STATE_CERT_VERIFY_PENDING) {