Handle 0-RTT in TlsClientHandshaker

Tls-in-quic 0-rtt change, protected by disabled flag quic_enable_zero_rtt_for_tls

PiperOrigin-RevId: 312540775
Change-Id: I3d26ee14db86a7b81d0886f9951c41acb2d469b1
diff --git a/quic/core/crypto/quic_crypto_client_config.cc b/quic/core/crypto/quic_crypto_client_config.cc
index 7f8edc2..fd1d41c 100644
--- a/quic/core/crypto/quic_crypto_client_config.cc
+++ b/quic/core/crypto/quic_crypto_client_config.cc
@@ -67,7 +67,9 @@
     std::unique_ptr<SessionCache> session_cache)
     : proof_verifier_(std::move(proof_verifier)),
       session_cache_(std::move(session_cache)),
-      ssl_ctx_(TlsClientConnection::CreateSslCtx()) {
+      enable_zero_rtt_for_tls_(
+          GetQuicReloadableFlag(quic_enable_zero_rtt_for_tls)),
+      ssl_ctx_(TlsClientConnection::CreateSslCtx(enable_zero_rtt_for_tls_)) {
   DCHECK(proof_verifier_.get());
   SetDefaults();
 }
diff --git a/quic/core/crypto/quic_crypto_client_config.h b/quic/core/crypto/quic_crypto_client_config.h
index e3867c8..9a87556 100644
--- a/quic/core/crypto/quic_crypto_client_config.h
+++ b/quic/core/crypto/quic_crypto_client_config.h
@@ -368,6 +368,8 @@
   void set_proof_source(std::unique_ptr<ProofSource> proof_source);
   SSL_CTX* ssl_ctx() const;
 
+  bool early_data_enabled_for_tls() const { return enable_zero_rtt_for_tls_; }
+
   // Initialize the CachedState from |canonical_crypto_config| for the
   // |canonical_server_id| as the initial CachedState for |server_id|. We will
   // copy config data only if |canonical_crypto_config| has valid proof.
@@ -450,6 +452,9 @@
   std::unique_ptr<ProofVerifier> proof_verifier_;
   std::unique_ptr<SessionCache> session_cache_;
   std::unique_ptr<ProofSource> proof_source_;
+
+  // Latched value of reloadable flag quic_enable_zero_rtt_for_tls
+  bool enable_zero_rtt_for_tls_;
   bssl::UniquePtr<SSL_CTX> ssl_ctx_;
 
   // The |user_agent_id_| passed in QUIC's CHLO message.
diff --git a/quic/core/crypto/tls_client_connection.cc b/quic/core/crypto/tls_client_connection.cc
index 7d11224..7908847 100644
--- a/quic/core/crypto/tls_client_connection.cc
+++ b/quic/core/crypto/tls_client_connection.cc
@@ -11,7 +11,8 @@
       delegate_(delegate) {}
 
 // static
-bssl::UniquePtr<SSL_CTX> TlsClientConnection::CreateSslCtx() {
+bssl::UniquePtr<SSL_CTX> TlsClientConnection::CreateSslCtx(
+    bool enable_early_data) {
   bssl::UniquePtr<SSL_CTX> ssl_ctx = TlsConnection::CreateSslCtx();
   // Configure certificate verification.
   SSL_CTX_set_custom_verify(ssl_ctx.get(), SSL_VERIFY_PEER, &VerifyCallback);
@@ -22,6 +23,8 @@
   SSL_CTX_set_session_cache_mode(
       ssl_ctx.get(), SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_NO_INTERNAL);
   SSL_CTX_sess_set_new_cb(ssl_ctx.get(), NewSessionCallback);
+
+  SSL_CTX_set_early_data_enabled(ssl_ctx.get(), enable_early_data);
   return ssl_ctx;
 }
 
diff --git a/quic/core/crypto/tls_client_connection.h b/quic/core/crypto/tls_client_connection.h
index 035f420..a7ef209 100644
--- a/quic/core/crypto/tls_client_connection.h
+++ b/quic/core/crypto/tls_client_connection.h
@@ -39,7 +39,7 @@
   TlsClientConnection(SSL_CTX* ssl_ctx, Delegate* delegate);
 
   // Creates and configures an SSL_CTX that is appropriate for clients to use.
-  static bssl::UniquePtr<SSL_CTX> CreateSslCtx();
+  static bssl::UniquePtr<SSL_CTX> CreateSslCtx(bool enable_early_data);
 
  private:
   // Registered as the callback for SSL_CTX_set_custom_verify. The
diff --git a/quic/core/handshaker_delegate_interface.h b/quic/core/handshaker_delegate_interface.h
index 0e8e045..c157e08 100644
--- a/quic/core/handshaker_delegate_interface.h
+++ b/quic/core/handshaker_delegate_interface.h
@@ -54,6 +54,10 @@
   // encryption level and 2) a server successfully processes a forward secure
   // packet.
   virtual void NeuterHandshakeData() = 0;
+
+  // Called when 0-RTT data is rejected by the server. This is only called in
+  // TLS handshakes and only called on clients.
+  virtual void OnZeroRttRejected() = 0;
 };
 
 }  // namespace quic
diff --git a/quic/core/http/quic_spdy_client_session_test.cc b/quic/core/http/quic_spdy_client_session_test.cc
index 7d4ae81..be80a13 100644
--- a/quic/core/http/quic_spdy_client_session_test.cc
+++ b/quic/core/http/quic_spdy_client_session_test.cc
@@ -93,6 +93,8 @@
             QuicUtils::GetInvalidStreamId(GetParam().transport_version)) {
     auto client_cache = std::make_unique<test::SimpleSessionCache>();
     client_session_cache_ = client_cache.get();
+    SetQuicReloadableFlag(quic_enable_tls_resumption, true);
+    SetQuicReloadableFlag(quic_enable_zero_rtt_for_tls, true);
     crypto_config_ = std::make_unique<QuicCryptoClientConfig>(
         crypto_test_utils::ProofVerifierForTesting(), std::move(client_cache));
     Initialize();
@@ -171,8 +173,6 @@
     } else {
       config.SetMaxBidirectionalStreamsToSend(server_max_incoming_streams);
     }
-    SetQuicReloadableFlag(quic_enable_tls_resumption, true);
-    SetQuicReloadableFlag(quic_enable_zero_rtt_for_tls, true);
     std::unique_ptr<QuicCryptoServerConfig> crypto_config =
         crypto_test_utils::CryptoServerConfigForTesting();
     crypto_test_utils::HandshakeWithFakeServer(
@@ -1028,7 +1028,6 @@
       kInitialSessionFlowControlWindowForTest + 1);
   config.SetMaxBidirectionalStreamsToSend(kDefaultMaxStreamsPerConnection + 1);
   config.SetMaxUnidirectionalStreamsToSend(kDefaultMaxStreamsPerConnection + 1);
-  SetQuicReloadableFlag(quic_enable_tls_resumption, true);
   std::unique_ptr<QuicCryptoServerConfig> crypto_config =
       crypto_test_utils::CryptoServerConfigForTesting();
   crypto_test_utils::HandshakeWithFakeServer(
diff --git a/quic/core/quic_session.cc b/quic/core/quic_session.cc
index 32ada39..bfcd169 100644
--- a/quic/core/quic_session.cc
+++ b/quic/core/quic_session.cc
@@ -1609,6 +1609,10 @@
   connection()->OnHandshakeComplete();
 }
 
+void QuicSession::OnZeroRttRejected() {
+  // TODO(b/153726130): Handle early data rejection.
+}
+
 void QuicSession::OnCryptoHandshakeMessageSent(
     const CryptoHandshakeMessage& /*message*/) {}
 
diff --git a/quic/core/quic_session.h b/quic/core/quic_session.h
index 490f4d0..0aa3ae4 100644
--- a/quic/core/quic_session.h
+++ b/quic/core/quic_session.h
@@ -255,6 +255,7 @@
   void DiscardOldEncryptionKey(EncryptionLevel level) override;
   void NeuterUnencryptedData() override;
   void NeuterHandshakeData() override;
+  void OnZeroRttRejected() override;
 
   // Implement StreamDelegateInterface.
   void OnStreamError(QuicErrorCode error_code,
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index 150c997..b3cfc0a 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -67,6 +67,7 @@
       pre_shared_key_(crypto_config->pre_shared_key()),
       crypto_negotiated_params_(new QuicCryptoNegotiatedParameters),
       has_application_state_(has_application_state),
+      attempting_zero_rtt_(crypto_config->early_data_enabled_for_tls()),
       tls_connection_(crypto_config->ssl_ctx(), this) {}
 
 TlsClientHandshaker::~TlsClientHandshaker() {
@@ -114,16 +115,16 @@
   }
 
   // Set a session to resume, if there is one.
+  std::unique_ptr<QuicResumptionState> cached_state;
   if (session_cache_) {
-    std::unique_ptr<QuicResumptionState> cached_state =
-        session_cache_->Lookup(server_id_, SSL_get_SSL_CTX(ssl()));
-    if (cached_state) {
-      SSL_set_session(ssl(), cached_state->tls_session.get());
-      if (GetQuicReloadableFlag(quic_enable_zero_rtt_for_tls) &&
-          SSL_SESSION_early_data_capable(cached_state->tls_session.get())) {
-        if (!PrepareZeroRttConfig(cached_state.get())) {
-          return false;
-        }
+    cached_state = session_cache_->Lookup(server_id_, SSL_get_SSL_CTX(ssl()));
+  }
+  if (cached_state) {
+    SSL_set_session(ssl(), cached_state->tls_session.get());
+    if (attempting_zero_rtt_ &&
+        SSL_SESSION_early_data_capable(cached_state->tls_session.get())) {
+      if (!PrepareZeroRttConfig(cached_state.get())) {
+        return false;
       }
     }
   }
@@ -427,6 +428,10 @@
   }
   int ssl_error = SSL_get_error(ssl(), rv);
   bool should_close = true;
+  if (ssl_error == SSL_ERROR_EARLY_DATA_REJECTED) {
+    HandleZeroRttReject();
+    return;
+  }
   switch (state_) {
     // TODO(b/153726130): handle the case where the server rejects early data.
     case STATE_HANDSHAKE_RUNNING:
@@ -455,6 +460,15 @@
 }
 
 void TlsClientHandshaker::FinishHandshake() {
+  if (SSL_in_early_data(ssl())) {
+    // SSL_do_handshake returns after sending the ClientHello if the session is
+    // 0-RTT-capable, which means that FinishHandshake will get called twice -
+    // the first time after sending the ClientHello, and the second time after
+    // the handshake is complete. If we're in the first time FinishHandshake is
+    // called, we can't do any end-of-handshake processing, so we return early
+    // from this function.
+    return;
+  }
   QUIC_LOG(INFO) << "Client: handshake finished";
   state_ = STATE_HANDSHAKE_COMPLETE;
   // Fill crypto_negotiated_params_:
@@ -504,6 +518,13 @@
   handshaker_delegate()->OnOneRttKeysAvailable();
 }
 
+void TlsClientHandshaker::HandleZeroRttReject() {
+  QUIC_LOG(INFO) << "0-RTT handshake attempted but was rejected by the server";
+  handshaker_delegate()->OnZeroRttRejected();
+  SSL_reset_early_data_reject(ssl());
+  AdvanceHandshake();
+}
+
 enum ssl_verify_result_t TlsClientHandshaker::VerifyCert(uint8_t* out_alert) {
   if (verify_result_ != ssl_verify_retry ||
       state_ == STATE_CERT_VERIFY_PENDING) {
diff --git a/quic/core/tls_client_handshaker.h b/quic/core/tls_client_handshaker.h
index cc60121..ad35ba1 100644
--- a/quic/core/tls_client_handshaker.h
+++ b/quic/core/tls_client_handshaker.h
@@ -124,6 +124,7 @@
   bool SetTransportParameters();
   bool ProcessTransportParameters(std::string* error_details);
   void FinishHandshake();
+  void HandleZeroRttReject();
 
   // Called when server completes handshake (i.e., either handshake done is
   // received or 1-RTT packet gets acknowledged).
@@ -175,6 +176,7 @@
   bool allow_invalid_sni_for_tests_ = false;
 
   const bool has_application_state_;
+  bool attempting_zero_rtt_;
 
   TlsClientConnection tls_connection_;
 
diff --git a/quic/core/tls_client_handshaker_test.cc b/quic/core/tls_client_handshaker_test.cc
index 85492b8..fff5135 100644
--- a/quic/core/tls_client_handshaker_test.cc
+++ b/quic/core/tls_client_handshaker_test.cc
@@ -163,13 +163,13 @@
   TlsClientHandshakerTest()
       : supported_versions_({GetParam()}),
         server_id_(kServerHostname, kServerPort, false),
-        crypto_config_(std::make_unique<QuicCryptoClientConfig>(
-            std::make_unique<TestProofVerifier>(),
-            std::make_unique<test::SimpleSessionCache>())),
         server_compressed_certs_cache_(
             QuicCompressedCertsCache::kQuicCompressedCertsCacheSize) {
     SetQuicReloadableFlag(quic_enable_tls_resumption, true);
     SetQuicReloadableFlag(quic_enable_zero_rtt_for_tls, true);
+    crypto_config_ = std::make_unique<QuicCryptoClientConfig>(
+        std::make_unique<TestProofVerifier>(),
+        std::make_unique<test::SimpleSessionCache>());
     server_crypto_config_ = crypto_test_utils::CryptoServerConfigForTesting();
     CreateConnection();
   }
@@ -338,6 +338,59 @@
   EXPECT_TRUE(stream()->IsResumption());
 }
 
+// TODO(b/152551499): This test is currently broken because the logic to reject
+// 0-RTT is overzealous. It currently requires a byte-for-byte match of the
+// Transport Parameters (between the ones that the server sent on the connection
+// where it issued a ticket, and the ones that the server is sending on the
+// connection where it is potentially accepting early data). This is broken
+// because the stateless reset token in the Transport Parameters necessarily
+// must be different between those two connections. Once that check is relaxed,
+// this test can be enabled.
+TEST_P(TlsClientHandshakerTest, DISABLED_ZeroRttResumption) {
+  // Finish establishing the first connection:
+  CompleteCryptoHandshake();
+
+  EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol());
+  EXPECT_TRUE(stream()->encryption_established());
+  EXPECT_TRUE(stream()->one_rtt_keys_available());
+  EXPECT_FALSE(stream()->IsResumption());
+
+  // Create a second connection
+  CreateConnection();
+  CompleteCryptoHandshake();
+
+  // TODO(b/152551499): Add a test that checks we have keys after calling
+  // stream()->CryptoConnect().
+  EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol());
+  EXPECT_TRUE(stream()->encryption_established());
+  EXPECT_TRUE(stream()->one_rtt_keys_available());
+  EXPECT_TRUE(stream()->IsResumption());
+  EXPECT_TRUE(stream()->EarlyDataAccepted());
+}
+
+// TODO(b/152551499): Also test resumption getting rejected.
+TEST_P(TlsClientHandshakerTest, ZeroRttRejection) {
+  // Finish establishing the first connection:
+  CompleteCryptoHandshake();
+
+  EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol());
+  EXPECT_TRUE(stream()->encryption_established());
+  EXPECT_TRUE(stream()->one_rtt_keys_available());
+  EXPECT_FALSE(stream()->IsResumption());
+
+  // Create a second connection, but disable 0-RTT on the server.
+  SSL_CTX_set_early_data_enabled(server_crypto_config_->ssl_ctx(), false);
+  CreateConnection();
+  EXPECT_CALL(*session_, OnZeroRttRejected());
+  CompleteCryptoHandshake();
+
+  EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol());
+  EXPECT_TRUE(stream()->encryption_established());
+  EXPECT_TRUE(stream()->one_rtt_keys_available());
+  EXPECT_TRUE(stream()->IsResumption());
+  EXPECT_FALSE(stream()->EarlyDataAccepted());
+}
+
 TEST_P(TlsClientHandshakerTest, ClientSendsNoSNI) {
   // Reconfigure client to sent an empty server hostname. The crypto config also
   // needs to be recreated to use a FakeProofVerifier since the server's cert
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index 203726a..012cb11 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -353,6 +353,16 @@
 }
 
 void TlsServerHandshaker::FinishHandshake() {
+  if (SSL_in_early_data(ssl())) {
+    // If the server accepts early data, SSL_do_handshake returns success twice:
+    // once after processing the ClientHello and sending the server's first
+    // flight, and then again after the handshake is complete. This results in
+    // FinishHandshake getting called twice. On the first call to
+    // FinishHandshake, we don't have any confirmation that the client is live,
+    // so all end of handshake processing is deferred until the handshake is
+    // actually complete.
+    return;
+  }
   if (!valid_alpn_received_) {
     QUIC_DLOG(ERROR)
         << "Server: handshake finished without receiving a known ALPN";
diff --git a/quic/test_tools/crypto_test_utils.cc b/quic/test_tools/crypto_test_utils.cc
index 4864441..643ae59 100644
--- a/quic/test_tools/crypto_test_utils.cc
+++ b/quic/test_tools/crypto_test_utils.cc
@@ -574,7 +574,7 @@
                      std::string label) {
   if (encrypter == nullptr || decrypter == nullptr) {
     ADD_FAILURE() << "Expected non-null crypters; have " << encrypter << " and "
-                  << decrypter;
+                  << decrypter << " for " << label;
     return;
   }
   quiche::QuicheStringPiece encrypter_key = encrypter->GetKey();
@@ -605,7 +605,8 @@
     const QuicDecrypter* server_decrypter(
         QuicFramerPeer::GetDecrypter(server_framer, level));
     if (level == ENCRYPTION_FORWARD_SECURE ||
-        !((level == ENCRYPTION_HANDSHAKE || client_encrypter == nullptr) &&
+        !((level == ENCRYPTION_HANDSHAKE || level == ENCRYPTION_ZERO_RTT ||
+           client_encrypter == nullptr) &&
           server_decrypter == nullptr)) {
       CompareCrypters(client_encrypter, server_decrypter,
                       "client " + EncryptionLevelString(level) + " write");
@@ -616,7 +617,8 @@
         QuicFramerPeer::GetDecrypter(client_framer, level));
     if (level == ENCRYPTION_FORWARD_SECURE ||
         !(server_encrypter == nullptr &&
-          (level == ENCRYPTION_HANDSHAKE || client_decrypter == nullptr))) {
+          (level == ENCRYPTION_HANDSHAKE || level == ENCRYPTION_ZERO_RTT ||
+           client_decrypter == nullptr))) {
       CompareCrypters(server_encrypter, client_decrypter,
                       "server " + EncryptionLevelString(level) + " write");
     }
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h
index 92b1791..491b828 100644
--- a/quic/test_tools/quic_test_utils.h
+++ b/quic/test_tools/quic_test_utils.h
@@ -1139,6 +1139,7 @@
   MOCK_METHOD(bool, ShouldCreateOutgoingUnidirectionalStream, (), (override));
   MOCK_METHOD(std::vector<std::string>, GetAlpnsToOffer, (), (const, override));
   MOCK_METHOD(void, OnAlpnSelected, (quiche::QuicheStringPiece), (override));
+  MOCK_METHOD(void, OnZeroRttRejected, (), (override));
 
   QuicCryptoClientStream* GetMutableCryptoStream() override;
   const QuicCryptoClientStream* GetCryptoStream() const override;