Only set QUIC TLS 0-RTT client state if a 0-RTT handshake was attempted

Even if an SSL_SESSION is early data capable, it is still possible (for
various reasons) that BoringSSL will decide not to do a 0-RTT handshake.
TlsClientHandshaker should wait for a signal from BoringSSL that it is
attempting a 0-RTT handshake before it sets saved transport and application
state for early data.

Client-side only quic behavior change, not flag protected.

PiperOrigin-RevId: 320646436
Change-Id: Ib1cfe2640d3cd62e23344ede852b91756a44f687
diff --git a/quic/core/crypto/quic_crypto_client_config.cc b/quic/core/crypto/quic_crypto_client_config.cc
index 9314e99..4680785 100644
--- a/quic/core/crypto/quic_crypto_client_config.cc
+++ b/quic/core/crypto/quic_crypto_client_config.cc
@@ -67,9 +67,8 @@
     std::unique_ptr<SessionCache> session_cache)
     : proof_verifier_(std::move(proof_verifier)),
       session_cache_(std::move(session_cache)),
-      enable_zero_rtt_for_tls_(
-          GetQuicReloadableFlag(quic_enable_zero_rtt_for_tls)),
-      ssl_ctx_(TlsClientConnection::CreateSslCtx(enable_zero_rtt_for_tls_)),
+      ssl_ctx_(TlsClientConnection::CreateSslCtx(
+          GetQuicReloadableFlag(quic_enable_zero_rtt_for_tls))),
       disable_chlo_padding_(GetQuicReloadableFlag(quic_dont_pad_chlo)) {
   DCHECK(proof_verifier_.get());
   SetDefaults();
diff --git a/quic/core/crypto/quic_crypto_client_config.h b/quic/core/crypto/quic_crypto_client_config.h
index ce0efe5..e355c5e 100644
--- a/quic/core/crypto/quic_crypto_client_config.h
+++ b/quic/core/crypto/quic_crypto_client_config.h
@@ -354,8 +354,6 @@
   void set_proof_source(std::unique_ptr<ProofSource> proof_source);
   SSL_CTX* ssl_ctx() const;
 
-  bool early_data_enabled_for_tls() const { return enable_zero_rtt_for_tls_; }
-
   // Initialize the CachedState from |canonical_crypto_config| for the
   // |canonical_server_id| as the initial CachedState for |server_id|. We will
   // copy config data only if |canonical_crypto_config| has valid proof.
@@ -443,8 +441,6 @@
   std::unique_ptr<SessionCache> session_cache_;
   std::unique_ptr<ProofSource> proof_source_;
 
-  // Latched value of reloadable flag quic_enable_zero_rtt_for_tls
-  bool enable_zero_rtt_for_tls_;
   bssl::UniquePtr<SSL_CTX> ssl_ctx_;
 
   // The |user_agent_id_| passed in QUIC's CHLO message.
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index 7d000d9..76f9b82 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -68,7 +68,6 @@
       pre_shared_key_(crypto_config->pre_shared_key()),
       crypto_negotiated_params_(new QuicCryptoNegotiatedParameters),
       has_application_state_(has_application_state),
-      attempting_zero_rtt_(crypto_config->early_data_enabled_for_tls()),
       tls_connection_(crypto_config->ssl_ctx(), this) {}
 
 TlsClientHandshaker::~TlsClientHandshaker() {
@@ -116,18 +115,11 @@
   }
 
   // Set a session to resume, if there is one.
-  std::unique_ptr<QuicResumptionState> cached_state;
   if (session_cache_) {
-    cached_state = session_cache_->Lookup(server_id_, SSL_get_SSL_CTX(ssl()));
+    cached_state_ = session_cache_->Lookup(server_id_, SSL_get_SSL_CTX(ssl()));
   }
-  if (cached_state) {
-    SSL_set_session(ssl(), cached_state->tls_session.get());
-    if (attempting_zero_rtt_ &&
-        SSL_SESSION_early_data_capable(cached_state->tls_session.get())) {
-      if (!PrepareZeroRttConfig(cached_state.get())) {
-        return false;
-      }
-    }
+  if (cached_state_) {
+    SSL_set_session(ssl(), cached_state_->tls_session.get());
   }
 
   // Start the handshake.
@@ -467,8 +459,11 @@
     // 0-RTT-capable, which means that FinishHandshake will get called twice -
     // the first time after sending the ClientHello, and the second time after
     // the handshake is complete. If we're in the first time FinishHandshake is
-    // called, we can't do any end-of-handshake processing, so we return early
-    // from this function.
+    // called, we can't do any end-of-handshake processing.
+
+    // If we're attempting a 0-RTT handshake, then we need to let the transport
+    // and application know what state to apply to early data.
+    PrepareZeroRttConfig(cached_state_.get());
     return;
   }
   QUIC_LOG(INFO) << "Client: handshake finished";
diff --git a/quic/core/tls_client_handshaker.h b/quic/core/tls_client_handshaker.h
index 573c055..bf05ca8 100644
--- a/quic/core/tls_client_handshaker.h
+++ b/quic/core/tls_client_handshaker.h
@@ -176,7 +176,9 @@
   bool allow_invalid_sni_for_tests_ = false;
 
   const bool has_application_state_;
-  bool attempting_zero_rtt_;
+  // Contains the state for performing a resumption, if one is attempted. This
+  // will always be non-null if a 0-RTT resumption is attempted.
+  std::unique_ptr<QuicResumptionState> cached_state_;
 
   TlsClientConnection tls_connection_;
 
diff --git a/quic/core/tls_client_handshaker_test.cc b/quic/core/tls_client_handshaker_test.cc
index f586028..1832d90 100644
--- a/quic/core/tls_client_handshaker_test.cc
+++ b/quic/core/tls_client_handshaker_test.cc
@@ -202,13 +202,18 @@
   }
 
   void CompleteCryptoHandshake() {
+    CompleteCryptoHandshakeWithServerALPN(
+        AlpnForVersion(connection_->version()));
+  }
+
+  void CompleteCryptoHandshakeWithServerALPN(const std::string& alpn) {
     EXPECT_CALL(*connection_, SendCryptoData(_, _, _))
         .Times(testing::AnyNumber());
     stream()->CryptoConnect();
     QuicConfig config;
     crypto_test_utils::HandshakeWithFakeServer(
         &config, server_crypto_config_.get(), &server_helper_, &alarm_factory_,
-        connection_, stream(), AlpnForVersion(connection_->version()));
+        connection_, stream(), alpn);
   }
 
   QuicCryptoClientStream* stream() {
@@ -358,6 +363,10 @@
 
   // Create a second connection
   CreateConnection();
+  // OnConfigNegotiated should be called twice - once when processing saved
+  // 0-RTT transport parameters, and then again when receiving transport
+  // parameters from the server.
+  EXPECT_CALL(*session_, OnConfigNegotiated()).Times(2);
   CompleteCryptoHandshake();
 
   // TODO(b/152551499): Add a test that checks we have keys after calling
@@ -383,6 +392,11 @@
   SSL_CTX_set_early_data_enabled(server_crypto_config_->ssl_ctx(), false);
   CreateConnection();
 
+  // OnConfigNegotiated should be called twice - once when processing saved
+  // 0-RTT transport parameters, and then again when receiving transport
+  // parameters from the server.
+  EXPECT_CALL(*session_, OnConfigNegotiated()).Times(2);
+
   // 4 packets will be sent in this connection: initial handshake packet, 0-RTT
   // packet containing SETTINGS, handshake packet upon 0-RTT rejection, 0-RTT
   // packet retransmission.
@@ -470,6 +484,32 @@
   EXPECT_TRUE(server_stream()->encryption_established());
 }
 
+TEST_P(TlsClientHandshakerTest, ZeroRTTNotAttemptedOnALPNChange) {
+  // Finish establishing the first connection:
+  CompleteCryptoHandshake();
+
+  EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol());
+  EXPECT_TRUE(stream()->encryption_established());
+  EXPECT_TRUE(stream()->one_rtt_keys_available());
+  EXPECT_FALSE(stream()->IsResumption());
+
+  // Create a second connection
+  CreateConnection();
+  // Override the ALPN to send on the second connection.
+  const std::string kTestAlpn = "Test ALPN";
+  EXPECT_CALL(*session_, GetAlpnsToOffer())
+      .WillRepeatedly(testing::Return(std::vector<std::string>({kTestAlpn})));
+  // OnConfigNegotiated should only be called once: when transport parameters
+  // are received from the server.
+  EXPECT_CALL(*session_, OnConfigNegotiated()).Times(1);
+
+  CompleteCryptoHandshakeWithServerALPN(kTestAlpn);
+  EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol());
+  EXPECT_TRUE(stream()->encryption_established());
+  EXPECT_TRUE(stream()->one_rtt_keys_available());
+  EXPECT_FALSE(stream()->EarlyDataAccepted());
+}
+
 TEST_P(TlsClientHandshakerTest, InvalidSNI) {
   // Test that a client will skip sending SNI if configured to send an invalid
   // hostname. In this case, the inclusion of '!' is invalid.
diff --git a/quic/test_tools/quic_test_utils.cc b/quic/test_tools/quic_test_utils.cc
index 32a8450..7cf2dab 100644
--- a/quic/test_tools/quic_test_utils.cc
+++ b/quic/test_tools/quic_test_utils.cc
@@ -776,6 +776,9 @@
       server_id, this, crypto_test_utils::ProofVerifyContextForTesting(),
       crypto_config, this, /*has_application_state = */ false);
   Initialize();
+  ON_CALL(*this, OnConfigNegotiated())
+      .WillByDefault(
+          Invoke(this, &TestQuicSpdyClientSession::RealOnConfigNegotiated));
 }
 
 TestQuicSpdyClientSession::~TestQuicSpdyClientSession() {}
@@ -793,6 +796,10 @@
   return crypto_stream_.get();
 }
 
+void TestQuicSpdyClientSession::RealOnConfigNegotiated() {
+  QuicSpdyClientSessionBase::OnConfigNegotiated();
+}
+
 TestPushPromiseDelegate::TestPushPromiseDelegate(bool match)
     : match_(match), rendezvous_fired_(false), rendezvous_stream_(nullptr) {}
 
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h
index 7be7edf..870f99d 100644
--- a/quic/test_tools/quic_test_utils.h
+++ b/quic/test_tools/quic_test_utils.h
@@ -1163,6 +1163,7 @@
   MOCK_METHOD(bool, ShouldCreateOutgoingUnidirectionalStream, (), (override));
   MOCK_METHOD(std::vector<std::string>, GetAlpnsToOffer, (), (const, override));
   MOCK_METHOD(void, OnAlpnSelected, (quiche::QuicheStringPiece), (override));
+  MOCK_METHOD(void, OnConfigNegotiated, (), (override));
 
   QuicCryptoClientStream* GetMutableCryptoStream() override;
   const QuicCryptoClientStream* GetCryptoStream() const override;
@@ -1179,6 +1180,10 @@
   }
 
  private:
+  // Calls the parent class's OnConfigNegotiated method. Used to set the default
+  // mock behavior for OnConfigNegotiated.
+  void RealOnConfigNegotiated();
+
   std::unique_ptr<QuicCryptoClientStream> crypto_stream_;
   QuicClientPushPromiseIndex push_promise_index_;
   std::vector<CryptoHandshakeMessage> sent_crypto_handshake_messages_;