gfe-relnote: Add SessionCache to TlsClientHandshaker, protected by reloadable flag quic_supports_tls_handshake

PiperOrigin-RevId: 279800830
Change-Id: Ib7b49726c14208f63c5b3a8c552cff36cb5d89bf
diff --git a/quic/core/crypto/quic_crypto_client_config.cc b/quic/core/crypto/quic_crypto_client_config.cc
index 6813f7c..6b00c4c 100644
--- a/quic/core/crypto/quic_crypto_client_config.cc
+++ b/quic/core/crypto/quic_crypto_client_config.cc
@@ -61,7 +61,14 @@
 
+// Equivalent to passing a null SessionCache, which leaves resumption disabled.
 QuicCryptoClientConfig::QuicCryptoClientConfig(
     std::unique_ptr<ProofVerifier> proof_verifier)
+    : QuicCryptoClientConfig(std::move(proof_verifier), nullptr) {}
+
+QuicCryptoClientConfig::QuicCryptoClientConfig(
+    std::unique_ptr<ProofVerifier> proof_verifier,
+    std::unique_ptr<SessionCache> session_cache)
     : proof_verifier_(std::move(proof_verifier)),
+      session_cache_(std::move(session_cache)),
       ssl_ctx_(TlsClientConnection::CreateSslCtx()) {
   DCHECK(proof_verifier_.get());
   SetDefaults();
@@ -850,6 +857,10 @@
   return proof_verifier_.get();
 }
 
+SessionCache* QuicCryptoClientConfig::session_cache() const {
+  return session_cache_.get();
+}
+
 SSL_CTX* QuicCryptoClientConfig::ssl_ctx() const {
   return ssl_ctx_.get();
 }
diff --git a/quic/core/crypto/quic_crypto_client_config.h b/quic/core/crypto/quic_crypto_client_config.h
index d3e627d..a3e1bcd 100644
--- a/quic/core/crypto/quic_crypto_client_config.h
+++ b/quic/core/crypto/quic_crypto_client_config.h
@@ -12,8 +12,10 @@
 #include <vector>
 
 #include "third_party/boringssl/src/include/openssl/base.h"
+#include "third_party/boringssl/src/include/openssl/ssl.h"
 #include "net/third_party/quiche/src/quic/core/crypto/crypto_handshake.h"
 #include "net/third_party/quiche/src/quic/core/crypto/crypto_protocol.h"
+#include "net/third_party/quiche/src/quic/core/crypto/transport_parameters.h"
 #include "net/third_party/quiche/src/quic/core/quic_packets.h"
 #include "net/third_party/quiche/src/quic/core/quic_server_id.h"
 #include "net/third_party/quiche/src/quic/platform/api/quic_export.h"
@@ -27,6 +29,53 @@
 class ProofVerifyDetails;
 class QuicRandom;
 
+// QuicResumptionState stores the state a client needs for performing connection
+// resumption.
+struct QUIC_EXPORT_PRIVATE QuicResumptionState {
+  // |tls_session| holds the cryptographic state necessary for a resumption. It
+  // includes the ALPN negotiated on the connection where the ticket was
+  // received.
+  bssl::UniquePtr<SSL_SESSION> tls_session;
+
+  // If the application using QUIC doesn't support 0-RTT handshakes or the
+  // client didn't receive a 0-RTT capable session ticket from the server,
+  // |transport_params| will be null. Otherwise, it will contain the transport
+  // parameters received from the server on the original connection.
+  std::unique_ptr<TransportParameters> transport_params;
+
+  // If |transport_params| is null, then |application_state| is ignored and
+  // should be empty. |application_state| contains serialized state that the
+  // client received from the server at the application layer that the client
+  // needs to remember when performing a 0-RTT handshake.
+  std::vector<uint8_t> application_state;
+};
+
+// SessionCache is an interface for storing and retrieving QuicResumptionState
+// structs, keyed by QuicServerId.
+class QUIC_EXPORT_PRIVATE SessionCache {
+ public:
+  virtual ~SessionCache() {}
+
+  // Inserts |state| into the cache, keyed by |server_id|. Insert is called
+  // after a session ticket is received. If the session ticket is valid for
+  // 0-RTT, there may be a delay between its receipt and the call to Insert
+  // while waiting for application state for |state|.
+  //
+  // Insert may be called multiple times per connection. SessionCache
+  // implementations should support storing multiple entries per server ID.
+  virtual void Insert(const QuicServerId& server_id,
+                      std::unique_ptr<QuicResumptionState> state) = 0;
+
+  // Lookup is called once at the beginning of each TLS handshake to potentially
+  // provide the saved state both for the TLS handshake and for sending 0-RTT
+  // data (if supported). |ctx| is the SSL_CTX used for that handshake. Lookup
+  // may return nullptr. Implementations should delete cache entries after
+  // returning them in Lookup so that session tickets are used only once.
+  virtual std::unique_ptr<QuicResumptionState> Lookup(
+      const QuicServerId& server_id,
+      const SSL_CTX* ctx) = 0;
+};
+
 // QuicCryptoClientConfig contains crypto-related configuration settings for a
 // client. Note that this object isn't thread-safe. It's designed to be used on
 // a single thread at a time.
@@ -203,8 +252,11 @@
     virtual bool Matches(const QuicServerId& server_id) const = 0;
   };
 
+  // DEPRECATED: Use the constructor below instead.
   explicit QuicCryptoClientConfig(
       std::unique_ptr<ProofVerifier> proof_verifier);
+  QuicCryptoClientConfig(std::unique_ptr<ProofVerifier> proof_verifier,
+                         std::unique_ptr<SessionCache> session_cache);
   QuicCryptoClientConfig(const QuicCryptoClientConfig&) = delete;
   QuicCryptoClientConfig& operator=(const QuicCryptoClientConfig&) = delete;
   ~QuicCryptoClientConfig();
@@ -309,7 +361,7 @@
       std::string* error_details);
 
   ProofVerifier* proof_verifier() const;
-
+  SessionCache* session_cache() const;
   SSL_CTX* ssl_ctx() const;
 
   // Initialize the CachedState from |canonical_crypto_config| for the
@@ -388,6 +440,7 @@
   std::vector<std::string> canonical_suffixes_;
 
   std::unique_ptr<ProofVerifier> proof_verifier_;
+  std::unique_ptr<SessionCache> session_cache_;
   bssl::UniquePtr<SSL_CTX> ssl_ctx_;
 
   // The |user_agent_id_| passed in QUIC's CHLO message.
diff --git a/quic/core/crypto/tls_client_connection.cc b/quic/core/crypto/tls_client_connection.cc
index f28af66..98aa6e7 100644
--- a/quic/core/crypto/tls_client_connection.cc
+++ b/quic/core/crypto/tls_client_connection.cc
@@ -19,6 +19,11 @@
   // certificate after the connection is complete. We need to re-verify on
   // resumption in case of expiration or revocation/distrust.
   SSL_CTX_set_custom_verify(ssl_ctx.get(), SSL_VERIFY_PEER, &VerifyCallback);
+
+  // Hand new sessions to NewSessionCallback instead of an internal cache.
+  SSL_CTX_set_session_cache_mode(
+      ssl_ctx.get(), SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_NO_INTERNAL);
+  SSL_CTX_sess_set_new_cb(ssl_ctx.get(), NewSessionCallback);
   return ssl_ctx;
 }
 
@@ -30,4 +35,13 @@
       ->delegate_->VerifyCert(out_alert);
 }
 
+// static
+int TlsClientConnection::NewSessionCallback(SSL* ssl, SSL_SESSION* session) {
+  // Returning 1 tells BoringSSL that this callback has taken ownership of
+  // |session|, which is handed to the delegate as a bssl::UniquePtr.
+  static_cast<TlsClientConnection*>(ConnectionFromSsl(ssl))
+      ->delegate_->InsertSession(bssl::UniquePtr<SSL_SESSION>(session));
+  return 1;
+}
+
 }  // namespace quic
diff --git a/quic/core/crypto/tls_client_connection.h b/quic/core/crypto/tls_client_connection.h
index 6660343..035f420 100644
--- a/quic/core/crypto/tls_client_connection.h
+++ b/quic/core/crypto/tls_client_connection.h
@@ -26,6 +26,9 @@
     // or ssl_verify_retry if verification is happening asynchronously.
     virtual enum ssl_verify_result_t VerifyCert(uint8_t* out_alert) = 0;
 
+    // Called for each NewSessionTicket received; takes ownership of |session|.
+    virtual void InsertSession(bssl::UniquePtr<SSL_SESSION> session) = 0;
+
     // Provides the delegate for callbacks that are shared between client and
     // server.
     virtual TlsConnection::Delegate* ConnectionDelegate() = 0;
@@ -43,6 +46,10 @@
   // implementation is delegated to Delegate::VerifyCert.
   static enum ssl_verify_result_t VerifyCallback(SSL* ssl, uint8_t* out_alert);
 
+  // Registered as the callback for SSL_CTX_sess_set_new_cb, which calls
+  // Delegate::InsertSession.
+  static int NewSessionCallback(SSL* ssl, SSL_SESSION* session);
+
   Delegate* delegate_;
 };
 
diff --git a/quic/core/crypto/tls_server_connection.cc b/quic/core/crypto/tls_server_connection.cc
index 927c75a..f539a08 100644
--- a/quic/core/crypto/tls_server_connection.cc
+++ b/quic/core/crypto/tls_server_connection.cc
@@ -16,6 +16,9 @@
   SSL_CTX_set_tlsext_servername_callback(ssl_ctx.get(),
                                          &SelectCertificateCallback);
   SSL_CTX_set_alpn_select_cb(ssl_ctx.get(), &SelectAlpnCallback, nullptr);
+  // Don't issue session tickets, so resumption is off by default. Callers can
+  // opt in by clearing SSL_OP_NO_TICKET on the returned SSL_CTX.
+  SSL_CTX_set_options(ssl_ctx.get(), SSL_OP_NO_TICKET);
   return ssl_ctx;
 }
 
diff --git a/quic/core/http/quic_spdy_client_session_test.cc b/quic/core/http/quic_spdy_client_session_test.cc
index d5b186c..200024f 100644
--- a/quic/core/http/quic_spdy_client_session_test.cc
+++ b/quic/core/http/quic_spdy_client_session_test.cc
@@ -166,8 +166,12 @@
       config.SetMaxIncomingBidirectionalStreamsToSend(
           server_max_incoming_streams);
     }
+    // HandshakeWithFakeServer takes the server's crypto config explicitly; use
+    // the default testing configuration.
+    QuicCryptoServerConfig crypto_config =
+        crypto_test_utils::CryptoServerConfigForTesting();
     crypto_test_utils::HandshakeWithFakeServer(
-        &config, &helper_, &alarm_factory_, connection_, stream,
+        &config, &crypto_config, &helper_, &alarm_factory_, connection_, stream,
         AlpnForVersion(connection_->version()));
   }
 
diff --git a/quic/core/quic_crypto_client_stream.cc b/quic/core/quic_crypto_client_stream.cc
index 354debb..b30cd3a 100644
--- a/quic/core/quic_crypto_client_stream.cc
+++ b/quic/core/quic_crypto_client_stream.cc
@@ -43,9 +43,10 @@
       break;
     case PROTOCOL_TLS1_3:
+      // TlsClientHandshaker takes its ProofVerifier, SSL_CTX, SessionCache and
+      // user agent ID from |crypto_config|.
       handshaker_ = std::make_unique<TlsClientHandshaker>(
-          this, session, server_id, crypto_config->proof_verifier(),
-          crypto_config->ssl_ctx(), std::move(verify_context), proof_handler,
-          crypto_config->user_agent_id());
+          server_id, this, session, std::move(verify_context), crypto_config,
+          proof_handler);
       break;
     case PROTOCOL_UNSUPPORTED:
       QUIC_BUG << "Attempting to create QuicCryptoClientStream for unknown "
diff --git a/quic/core/quic_crypto_client_stream_test.cc b/quic/core/quic_crypto_client_stream_test.cc
index 2622439..675dc20 100644
--- a/quic/core/quic_crypto_client_stream_test.cc
+++ b/quic/core/quic_crypto_client_stream_test.cc
@@ -22,6 +22,7 @@
 #include "net/third_party/quiche/src/quic/test_tools/quic_stream_sequencer_peer.h"
 #include "net/third_party/quiche/src/quic/test_tools/quic_test_utils.h"
 #include "net/third_party/quiche/src/quic/test_tools/simple_quic_framer.h"
+#include "net/third_party/quiche/src/quic/test_tools/simple_session_cache.h"
 
 using testing::_;
 
@@ -37,7 +38,10 @@
   QuicCryptoClientStreamTest()
       : supported_versions_(AllSupportedVersions()),
         server_id_(kServerHostname, kServerPort, false),
-        crypto_config_(crypto_test_utils::ProofVerifierForTesting()) {
+        crypto_config_(crypto_test_utils::ProofVerifierForTesting(),
+                       std::make_unique<test::SimpleSessionCache>()),
+        server_crypto_config_(
+            crypto_test_utils::CryptoServerConfigForTesting()) {
     CreateConnection();
   }
 
@@ -56,6 +60,17 @@
             {AlpnForVersion(connection_->version())})));
   }
 
+  void UseTlsHandshake() {
+    SetQuicReloadableFlag(quic_supports_tls_handshake, true);
+    supported_versions_.clear();
+    for (ParsedQuicVersion version : AllSupportedVersions()) {
+      if (version.handshake_protocol != PROTOCOL_TLS1_3) {
+        continue;
+      }
+      supported_versions_.push_back(version);
+    }
+  }
+
   void CompleteCryptoHandshake() {
     if (stream()->handshake_protocol() != PROTOCOL_TLS1_3) {
       EXPECT_CALL(*session_, OnProofValid(testing::_));
@@ -65,8 +80,8 @@
     stream()->CryptoConnect();
     QuicConfig config;
     crypto_test_utils::HandshakeWithFakeServer(
-        &config, &server_helper_, &alarm_factory_, connection_, stream(),
-        AlpnForVersion(connection_->version()));
+        &config, &server_crypto_config_, &server_helper_, &alarm_factory_,
+        connection_, stream(), AlpnForVersion(connection_->version()));
   }
 
   QuicCryptoClientStream* stream() {
@@ -82,6 +97,7 @@
   QuicServerId server_id_;
   CryptoHandshakeMessage message_;
   QuicCryptoClientConfig crypto_config_;
+  QuicCryptoServerConfig server_crypto_config_;
 };
 
 TEST_F(QuicCryptoClientStreamTest, NotInitiallyConected) {
@@ -97,14 +113,7 @@
 }
 
 TEST_F(QuicCryptoClientStreamTest, ConnectedAfterTlsHandshake) {
-  SetQuicReloadableFlag(quic_supports_tls_handshake, true);
-  supported_versions_.clear();
-  for (ParsedQuicVersion version : AllSupportedVersions()) {
-    if (version.handshake_protocol != PROTOCOL_TLS1_3) {
-      continue;
-    }
-    supported_versions_.push_back(version);
-  }
+  UseTlsHandshake();
   CreateConnection();
   CompleteCryptoHandshake();
   EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol());
@@ -115,27 +124,44 @@
 
 TEST_F(QuicCryptoClientStreamTest,
        ProofVerifyDetailsAvailableAfterTlsHandshake) {
-  SetQuicReloadableFlag(quic_supports_tls_handshake, true);
-  supported_versions_.clear();
-  for (ParsedQuicVersion version : AllSupportedVersions()) {
-    if (version.handshake_protocol != PROTOCOL_TLS1_3) {
-      continue;
-    }
-    supported_versions_.push_back(version);
-  }
+  UseTlsHandshake();
   CreateConnection();
 
   EXPECT_CALL(*session_, OnProofVerifyDetailsAvailable(testing::_));
   stream()->CryptoConnect();
   QuicConfig config;
   crypto_test_utils::HandshakeWithFakeServer(
-      &config, &server_helper_, &alarm_factory_, connection_, stream(),
-      AlpnForVersion(connection_->version()));
+      &config, &server_crypto_config_, &server_helper_, &alarm_factory_,
+      connection_, stream(), AlpnForVersion(connection_->version()));
   EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol());
   EXPECT_TRUE(stream()->encryption_established());
   EXPECT_TRUE(stream()->handshake_confirmed());
 }
 
+TEST_F(QuicCryptoClientStreamTest, TlsResumption) {
+  UseTlsHandshake();
+  // Enable resumption on the server by allowing it to issue session tickets:
+  SSL_CTX_clear_options(server_crypto_config_.ssl_ctx(), SSL_OP_NO_TICKET);
+  CreateConnection();
+
+  // Finish establishing the first connection:
+  CompleteCryptoHandshake();
+
+  EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol());
+  EXPECT_TRUE(stream()->encryption_established());
+  EXPECT_TRUE(stream()->handshake_confirmed());
+  EXPECT_FALSE(stream()->IsResumption());
+
+  // Create a second connection, which should resume the cached session:
+  CreateConnection();
+  CompleteCryptoHandshake();
+
+  EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol());
+  EXPECT_TRUE(stream()->encryption_established());
+  EXPECT_TRUE(stream()->handshake_confirmed());
+  EXPECT_TRUE(stream()->IsResumption());
+}
+
 TEST_F(QuicCryptoClientStreamTest, MessageAfterHandshake) {
   CompleteCryptoHandshake();
 
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index 2056953..85a278c 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -43,22 +43,21 @@
 }
 
 TlsClientHandshaker::TlsClientHandshaker(
+    const QuicServerId& server_id,
     QuicCryptoStream* stream,
     QuicSession* session,
-    const QuicServerId& server_id,
-    ProofVerifier* proof_verifier,
-    SSL_CTX* ssl_ctx,
     std::unique_ptr<ProofVerifyContext> verify_context,
-    QuicCryptoClientStream::ProofHandler* proof_handler,
-    const std::string& user_agent_id)
-    : TlsHandshaker(stream, session, ssl_ctx),
+    QuicCryptoClientConfig* crypto_config,
+    QuicCryptoClientStream::ProofHandler* proof_handler)
+    : TlsHandshaker(stream, session, crypto_config->ssl_ctx()),
       server_id_(server_id),
-      proof_verifier_(proof_verifier),
+      proof_verifier_(crypto_config->proof_verifier()),
       verify_context_(std::move(verify_context)),
       proof_handler_(proof_handler),
-      user_agent_id_(user_agent_id),
+      session_cache_(crypto_config->session_cache()),
+      user_agent_id_(crypto_config->user_agent_id()),
       crypto_negotiated_params_(new QuicCryptoNegotiatedParameters),
-      tls_connection_(ssl_ctx, this) {}
+      tls_connection_(crypto_config->ssl_ctx(), this) {}
 
 TlsClientHandshaker::~TlsClientHandshaker() {
   if (proof_verify_callback_) {
@@ -87,6 +86,17 @@
     return false;
   }
 
+  // Set a session to resume, if there is one.
+  if (session_cache_) {
+    std::unique_ptr<QuicResumptionState> cached_state =
+        session_cache_->Lookup(server_id_, SSL_get_SSL_CTX(ssl()));
+    if (cached_state) {
+      // SSL_set_session takes its own reference to the session, so it stays
+      // valid after |cached_state| is destroyed.
+      SSL_set_session(ssl(), cached_state->tls_session.get());
+    }
+  }
+
   // Start the handshake.
   AdvanceHandshake();
   return session()->connection()->connected();
@@ -199,8 +209,7 @@
 
 bool TlsClientHandshaker::IsResumption() const {
   QUIC_BUG_IF(!handshake_confirmed_);
-  // We don't support resumption (yet).
-  return false;
+  return SSL_session_reused(ssl()) == 1;
 }
 
 int TlsClientHandshaker::num_scup_messages_received() const {
@@ -246,7 +255,12 @@
     return;
   }
   if (state_ == STATE_HANDSHAKE_COMPLETE) {
-    // TODO(nharper): Handle post-handshake messages.
+    // Process post-handshake messages (e.g. NewSessionTicket, whose session is
+    // delivered to InsertSession via TlsClientConnection::NewSessionCallback).
+    int rv = SSL_process_quic_post_handshake(ssl());
+    if (rv != 1) {
+      CloseConnection(QUIC_HANDSHAKE_FAILED, "Unexpected post-handshake data");
+    }
     return;
   }
 
@@ -394,4 +408,14 @@
   }
 }
 
+void TlsClientHandshaker::InsertSession(bssl::UniquePtr<SSL_SESSION> session) {
+  if (session_cache_ == nullptr) {
+    QUIC_DVLOG(1) << "No session cache, not inserting a session";
+    return;
+  }
+  auto cache_state = std::make_unique<QuicResumptionState>();
+  cache_state->tls_session = std::move(session);
+  session_cache_->Insert(server_id_, std::move(cache_state));
+}
+
 }  // namespace quic
diff --git a/quic/core/tls_client_handshaker.h b/quic/core/tls_client_handshaker.h
index 17b529e..0b473a3 100644
--- a/quic/core/tls_client_handshaker.h
+++ b/quic/core/tls_client_handshaker.h
@@ -24,14 +24,12 @@
       public QuicCryptoClientStream::HandshakerDelegate,
       public TlsClientConnection::Delegate {
  public:
-  TlsClientHandshaker(QuicCryptoStream* stream,
+  TlsClientHandshaker(const QuicServerId& server_id,
+                      QuicCryptoStream* stream,
                       QuicSession* session,
-                      const QuicServerId& server_id,
-                      ProofVerifier* proof_verifier,
-                      SSL_CTX* ssl_ctx,
                       std::unique_ptr<ProofVerifyContext> verify_context,
-                      QuicCryptoClientStream::ProofHandler* proof_handler,
-                      const std::string& user_agent_id);
+                      QuicCryptoClientConfig* crypto_config,
+                      QuicCryptoClientStream::ProofHandler* proof_handler);
   TlsClientHandshaker(const TlsClientHandshaker&) = delete;
   TlsClientHandshaker& operator=(const TlsClientHandshaker&) = delete;
 
@@ -101,6 +99,8 @@
   bool ProcessTransportParameters(std::string* error_details);
   void FinishHandshake();
 
+  void InsertSession(bssl::UniquePtr<SSL_SESSION> session) override;
+
   QuicServerId server_id_;
 
   // Objects used for verifying the server's certificate chain.
@@ -113,6 +113,10 @@
   // certificate verification.
   QuicCryptoClientStream::ProofHandler* proof_handler_;
 
+  // Used for session resumption; may be null. |session_cache_| is owned by the
+  // QuicCryptoClientConfig passed into TlsClientHandshaker's constructor.
+  SessionCache* session_cache_;
+
   std::string user_agent_id_;
 
   // ProofVerifierCallback used for async certificate verification. This object
diff --git a/quic/core/tls_handshaker_test.cc b/quic/core/tls_handshaker_test.cc
index 96691fa..d3463a8 100644
--- a/quic/core/tls_handshaker_test.cc
+++ b/quic/core/tls_handshaker_test.cc
@@ -223,17 +223,15 @@
  public:
   explicit TestQuicCryptoClientStream(QuicSession* session)
       : TestQuicCryptoStream(session),
-        proof_verifier_(new FakeProofVerifier),
-        ssl_ctx_(TlsClientConnection::CreateSslCtx()),
+        crypto_config_(std::make_unique<FakeProofVerifier>(),
+                       /*session_cache=*/nullptr),
         handshaker_(new TlsClientHandshaker(
+            QuicServerId("test.example.com", 443, false),
             this,
             session,
-            QuicServerId("test.example.com", 443, false),
-            proof_verifier_.get(),
-            ssl_ctx_.get(),
             crypto_test_utils::ProofVerifyContextForTesting(),
-            &proof_handler_,
-            "quic-tester")) {}
+            &crypto_config_,
+            &proof_handler_)) {}
 
   ~TestQuicCryptoClientStream() override = default;
 
@@ -244,13 +242,12 @@
   bool CryptoConnect() { return handshaker_->CryptoConnect(); }
 
   FakeProofVerifier* GetFakeProofVerifier() const {
-    return proof_verifier_.get();
+    return static_cast<FakeProofVerifier*>(crypto_config_.proof_verifier());
   }
 
  private:
-  std::unique_ptr<FakeProofVerifier> proof_verifier_;
   MockProofHandler proof_handler_;
-  bssl::UniquePtr<SSL_CTX> ssl_ctx_;
+  QuicCryptoClientConfig crypto_config_;
   std::unique_ptr<TlsClientHandshaker> handshaker_;
 };