Add retry_token, resumption_attempted and early_data_attempted to quic::ParsedClientHello.

This is a follow up to cl/406887508.

PiperOrigin-RevId: 409242611
diff --git a/quic/core/quic_buffered_packet_store.cc b/quic/core/quic_buffered_packet_store.cc
index 291434f..f597ce0 100644
--- a/quic/core/quic_buffered_packet_store.cc
+++ b/quic/core/quic_buffered_packet_store.cc
@@ -255,11 +255,10 @@
 }
 
 bool QuicBufferedPacketStore::IngestPacketForTlsChloExtraction(
-    const QuicConnectionId& connection_id,
-    const ParsedQuicVersion& version,
-    const QuicReceivedPacket& packet,
-    std::vector<std::string>* out_alpns,
-    std::string* out_sni) {
+    const QuicConnectionId& connection_id, const ParsedQuicVersion& version,
+    const QuicReceivedPacket& packet, std::vector<std::string>* out_alpns,
+    std::string* out_sni, bool* out_resumption_attempted,
+    bool* out_early_data_attempted) {
   QUICHE_DCHECK_NE(out_alpns, nullptr);
   QUICHE_DCHECK_NE(out_sni, nullptr);
   QUICHE_DCHECK_EQ(version.handshake_protocol, PROTOCOL_TLS1_3);
@@ -273,8 +272,11 @@
   if (!it->second.tls_chlo_extractor.HasParsedFullChlo()) {
     return false;
   }
-  *out_alpns = it->second.tls_chlo_extractor.alpns();
-  *out_sni = it->second.tls_chlo_extractor.server_name();
+  const TlsChloExtractor& tls_chlo_extractor = it->second.tls_chlo_extractor;
+  *out_alpns = tls_chlo_extractor.alpns();
+  *out_sni = tls_chlo_extractor.server_name();
+  *out_resumption_attempted = tls_chlo_extractor.resumption_attempted();
+  *out_early_data_attempted = tls_chlo_extractor.early_data_attempted();
   return true;
 }
 
diff --git a/quic/core/quic_buffered_packet_store.h b/quic/core/quic_buffered_packet_store.h
index b363802..3206b37 100644
--- a/quic/core/quic_buffered_packet_store.h
+++ b/quic/core/quic_buffered_packet_store.h
@@ -117,11 +117,16 @@
   // Returns whether we've now parsed a full multi-packet TLS CHLO.
   // When this returns true, |out_alpns| is populated with the list of ALPNs
   // extracted from the CHLO. |out_sni| is populated with the SNI tag in CHLO.
+  // |out_resumption_attempted| is populated if the CHLO has the
+  // 'pre_shared_key' TLS extension. |out_early_data_attempted| is populated if
+  // the CHLO has the 'early_data' TLS extension.
   bool IngestPacketForTlsChloExtraction(const QuicConnectionId& connection_id,
                                         const ParsedQuicVersion& version,
                                         const QuicReceivedPacket& packet,
                                         std::vector<std::string>* out_alpns,
-                                        std::string* out_sni);
+                                        std::string* out_sni,
+                                        bool* out_resumption_attempted,
+                                        bool* out_early_data_attempted);
 
   // Returns the list of buffered packets for |connection_id| and removes them
   // from the store. Returns an empty list if no early arrived packets for this
diff --git a/quic/core/quic_dispatcher.cc b/quic/core/quic_dispatcher.cc
index 1484b74..ef6db5d 100644
--- a/quic/core/quic_dispatcher.cc
+++ b/quic/core/quic_dispatcher.cc
@@ -780,13 +780,15 @@
     bool has_full_tls_chlo = false;
     std::string sni;
     std::vector<std::string> alpns;
+    bool resumption_attempted = false, early_data_attempted = false;
     if (buffered_packets_.HasBufferedPackets(
             packet_info.destination_connection_id)) {
       // If we already have buffered packets for this connection ID,
       // use the associated TlsChloExtractor to parse this packet.
       has_full_tls_chlo = buffered_packets_.IngestPacketForTlsChloExtraction(
           packet_info.destination_connection_id, packet_info.version,
-          packet_info.packet, &alpns, &sni);
+          packet_info.packet, &alpns, &sni, &resumption_attempted,
+          &early_data_attempted);
     } else {
       // If we do not have a BufferedPacketList for this connection ID,
       // create a single-use one to check whether this packet contains a
@@ -798,6 +800,8 @@
         has_full_tls_chlo = true;
         alpns = tls_chlo_extractor.alpns();
         sni = tls_chlo_extractor.server_name();
+        resumption_attempted = tls_chlo_extractor.resumption_attempted();
+        early_data_attempted = tls_chlo_extractor.early_data_attempted();
       }
     }
     if (!has_full_tls_chlo) {
@@ -811,6 +815,11 @@
     ParsedClientHello parsed_chlo;
     parsed_chlo.sni = std::move(sni);
     parsed_chlo.alpns = std::move(alpns);
+    if (packet_info.retry_token.has_value()) {
+      parsed_chlo.retry_token = std::string(*packet_info.retry_token);
+    }
+    parsed_chlo.resumption_attempted = resumption_attempted;
+    parsed_chlo.early_data_attempted = early_data_attempted;
     return parsed_chlo;
   }
 
diff --git a/quic/core/quic_dispatcher_test.cc b/quic/core/quic_dispatcher_test.cc
index 1dfda6c..4d59c13 100644
--- a/quic/core/quic_dispatcher_test.cc
+++ b/quic/core/quic_dispatcher_test.cc
@@ -68,12 +68,8 @@
                             QuicConnection* connection,
                             const QuicCryptoServerConfig* crypto_config,
                             QuicCompressedCertsCache* compressed_certs_cache)
-      : QuicServerSessionBase(config,
-                              CurrentSupportedVersions(),
-                              connection,
-                              nullptr,
-                              nullptr,
-                              crypto_config,
+      : QuicServerSessionBase(config, CurrentSupportedVersions(), connection,
+                              nullptr, nullptr, crypto_config,
                               compressed_certs_cache) {
     Initialize();
   }
@@ -83,26 +79,17 @@
 
   ~TestQuicSpdyServerSession() override { DeleteConnection(); }
 
-  MOCK_METHOD(void,
-              OnConnectionClosed,
+  MOCK_METHOD(void, OnConnectionClosed,
               (const QuicConnectionCloseFrame& frame,
                ConnectionCloseSource source),
               (override));
-  MOCK_METHOD(QuicSpdyStream*,
-              CreateIncomingStream,
-              (QuicStreamId id),
+  MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (QuicStreamId id),
               (override));
-  MOCK_METHOD(QuicSpdyStream*,
-              CreateIncomingStream,
-              (PendingStream*),
+  MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (PendingStream*),
               (override));
-  MOCK_METHOD(QuicSpdyStream*,
-              CreateOutgoingBidirectionalStream,
-              (),
+  MOCK_METHOD(QuicSpdyStream*, CreateOutgoingBidirectionalStream, (),
               (override));
-  MOCK_METHOD(QuicSpdyStream*,
-              CreateOutgoingUnidirectionalStream,
-              (),
+  MOCK_METHOD(QuicSpdyStream*, CreateOutgoingUnidirectionalStream, (),
               (override));
 
   std::unique_ptr<QuicCryptoServerStreamBase> CreateQuicCryptoServerStream(
@@ -138,10 +125,8 @@
                const ParsedClientHello& parsed_chlo),
               (override));
 
-  MOCK_METHOD(bool,
-              ShouldCreateOrBufferPacketForConnection,
-              (const ReceivedPacketInfo& packet_info),
-              (override));
+  MOCK_METHOD(bool, ShouldCreateOrBufferPacketForConnection,
+              (const ReceivedPacketInfo& packet_info), (override));
 
   struct TestQuicPerPacketContext : public QuicPerPacketContext {
     std::string custom_packet_context;
@@ -179,9 +164,7 @@
                        MockQuicConnectionHelper* helper,
                        MockAlarmFactory* alarm_factory,
                        QuicDispatcher* dispatcher)
-      : MockQuicConnection(connection_id,
-                           helper,
-                           alarm_factory,
+      : MockQuicConnection(connection_id, helper, alarm_factory,
                            Perspective::IS_SERVER),
         dispatcher_(dispatcher),
         active_connection_ids_({connection_id}) {}
@@ -225,15 +208,12 @@
       : version_(GetParam()),
         version_manager_(AllSupportedVersions()),
         crypto_config_(QuicCryptoServerConfig::TESTING,
-                       QuicRandom::GetInstance(),
-                       std::move(proof_source),
+                       QuicRandom::GetInstance(), std::move(proof_source),
                        KeyExchangeSource::Default()),
         server_address_(QuicIpAddress::Any4(), 5),
-        dispatcher_(
-            new NiceMock<TestDispatcher>(&config_,
-                                         &crypto_config_,
-                                         &version_manager_,
-                                         mock_helper_.GetRandomGenerator())),
+        dispatcher_(new NiceMock<TestDispatcher>(
+            &config_, &crypto_config_, &version_manager_,
+            mock_helper_.GetRandomGenerator())),
         time_wait_list_manager_(nullptr),
         session1_(nullptr),
         session2_(nullptr),
@@ -268,8 +248,7 @@
   // using the version under test.
   void ProcessPacket(QuicSocketAddress peer_address,
                      QuicConnectionId server_connection_id,
-                     bool has_version_flag,
-                     const std::string& data) {
+                     bool has_version_flag, const std::string& data) {
     ProcessPacket(peer_address, server_connection_id, has_version_flag, data,
                   CONNECTION_ID_PRESENT, PACKET_4BYTE_PACKET_NUMBER);
   }
@@ -278,8 +257,7 @@
   // using the version under test.
   void ProcessPacket(QuicSocketAddress peer_address,
                      QuicConnectionId server_connection_id,
-                     bool has_version_flag,
-                     const std::string& data,
+                     bool has_version_flag, const std::string& data,
                      QuicConnectionIdIncluded server_connection_id_included,
                      QuicPacketNumberLength packet_number_length) {
     ProcessPacket(peer_address, server_connection_id, has_version_flag, data,
@@ -289,8 +267,7 @@
   // Process a packet using the version under test.
   void ProcessPacket(QuicSocketAddress peer_address,
                      QuicConnectionId server_connection_id,
-                     bool has_version_flag,
-                     const std::string& data,
+                     bool has_version_flag, const std::string& data,
                      QuicConnectionIdIncluded server_connection_id_included,
                      QuicPacketNumberLength packet_number_length,
                      uint64_t packet_number) {
@@ -302,10 +279,8 @@
   // Processes a packet.
   void ProcessPacket(QuicSocketAddress peer_address,
                      QuicConnectionId server_connection_id,
-                     bool has_version_flag,
-                     ParsedQuicVersion version,
-                     const std::string& data,
-                     bool full_padding,
+                     bool has_version_flag, ParsedQuicVersion version,
+                     const std::string& data, bool full_padding,
                      QuicConnectionIdIncluded server_connection_id_included,
                      QuicPacketNumberLength packet_number_length,
                      uint64_t packet_number) {
@@ -319,10 +294,8 @@
   void ProcessPacket(QuicSocketAddress peer_address,
                      QuicConnectionId server_connection_id,
                      QuicConnectionId client_connection_id,
-                     bool has_version_flag,
-                     ParsedQuicVersion version,
-                     const std::string& data,
-                     bool full_padding,
+                     bool has_version_flag, ParsedQuicVersion version,
+                     const std::string& data, bool full_padding,
                      QuicConnectionIdIncluded server_connection_id_included,
                      QuicConnectionIdIncluded client_connection_id_included,
                      QuicPacketNumberLength packet_number_length,
@@ -340,8 +313,7 @@
 
   void ProcessReceivedPacket(
       std::unique_ptr<QuicReceivedPacket> received_packet,
-      const QuicSocketAddress& peer_address,
-      const ParsedQuicVersion& version,
+      const QuicSocketAddress& peer_address, const ParsedQuicVersion& version,
       const QuicConnectionId& server_connection_id) {
     if (version.UsesQuicCrypto() &&
         ChloExtractor::Extract(*received_packet, version, {}, nullptr,
@@ -367,12 +339,9 @@
   }
 
   std::unique_ptr<QuicSession> CreateSession(
-      TestDispatcher* dispatcher,
-      const QuicConfig& config,
-      QuicConnectionId connection_id,
-      const QuicSocketAddress& /*peer_address*/,
-      MockQuicConnectionHelper* helper,
-      MockAlarmFactory* alarm_factory,
+      TestDispatcher* dispatcher, const QuicConfig& config,
+      QuicConnectionId connection_id, const QuicSocketAddress& /*peer_address*/,
+      MockQuicConnectionHelper* helper, MockAlarmFactory* alarm_factory,
       const QuicCryptoServerConfig* crypto_config,
       QuicCompressedCertsCache* compressed_certs_cache,
       TestQuicSpdyServerSession** session_ptr) {
@@ -414,8 +383,7 @@
   }
 
   void ProcessUndecryptableEarlyPacket(
-      const ParsedQuicVersion& version,
-      const QuicSocketAddress& peer_address,
+      const ParsedQuicVersion& version, const QuicSocketAddress& peer_address,
       const QuicConnectionId& server_connection_id) {
     std::unique_ptr<QuicEncryptedPacket> encrypted_packet =
         GetUndecryptableEarlyPacket(version, server_connection_id);
@@ -441,15 +409,41 @@
                           const QuicSocketAddress& peer_address,
                           const QuicConnectionId& server_connection_id,
                           const QuicConnectionId& client_connection_id) {
+    ProcessFirstFlight(version, peer_address, server_connection_id,
+                       client_connection_id, TestClientCryptoConfig());
+  }
+
+  void ProcessFirstFlight(
+      const ParsedQuicVersion& version, const QuicSocketAddress& peer_address,
+      const QuicConnectionId& server_connection_id,
+      const QuicConnectionId& client_connection_id,
+      std::unique_ptr<QuicCryptoClientConfig> client_crypto_config) {
     std::vector<std::unique_ptr<QuicReceivedPacket>> packets =
-        GetFirstFlightOfPackets(version, server_connection_id,
-                                client_connection_id);
+        GetFirstFlightOfPackets(version, DefaultQuicConfig(),
+                                server_connection_id, client_connection_id,
+                                std::move(client_crypto_config));
     for (auto&& packet : packets) {
       ProcessReceivedPacket(std::move(packet), peer_address, version,
                             server_connection_id);
     }
   }
 
+  std::unique_ptr<QuicCryptoClientConfig> TestClientCryptoConfig() {
+    auto client_crypto_config = std::make_unique<QuicCryptoClientConfig>(
+        crypto_test_utils::ProofVerifierForTesting());
+    if (address_token_.has_value()) {
+      client_crypto_config->LookupOrCreate(TestServerId())
+          ->set_source_address_token(*address_token_);
+    }
+    return client_crypto_config;
+  }
+
+  // If called, the first flight packets generated in |ProcessFirstFlight| will
+  // contain the given |address_token|.
+  void SetAddressToken(std::string address_token) {
+    address_token_ = std::move(address_token);
+  }
+
   std::string ExpectedAlpnForVersion(ParsedQuicVersion version) {
     return AlpnForVersion(version);
   }
@@ -460,6 +454,9 @@
     ParsedClientHello parsed_chlo;
     parsed_chlo.alpns = {ExpectedAlpn()};
     parsed_chlo.sni = TestHostname();
+    if (address_token_.has_value()) {
+      parsed_chlo.retry_token = *address_token_;
+    }
     return parsed_chlo;
   }
 
@@ -521,6 +518,7 @@
   std::map<QuicConnectionId, std::list<std::string>> data_connection_map_;
   QuicBufferedPacketStore* store_;
   uint64_t connection_id_;
+  absl::optional<std::string> address_token_;
 };
 
 class QuicDispatcherTestAllVersions : public QuicDispatcherTestBase {};
@@ -540,11 +538,14 @@
   if (version_.UsesQuicCrypto()) {
     return;
   }
+  SetAddressToken("hsdifghdsaifnasdpfjdsk");
+
   QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1);
 
-  EXPECT_CALL(*dispatcher_,
-              CreateQuicSession(TestConnectionId(1), _, client_address,
-                                Eq(ExpectedAlpn()), _, _))
+  EXPECT_CALL(
+      *dispatcher_,
+      CreateQuicSession(TestConnectionId(1), _, client_address,
+                        Eq(ExpectedAlpn()), _, Eq(ParsedClientHelloForTest())))
       .WillOnce(Return(ByMove(CreateSession(
           dispatcher_.get(), config_, TestConnectionId(1), client_address,
           &mock_helper_, &mock_alarm_factory_, &crypto_config_,
@@ -566,6 +567,8 @@
   if (!version_.UsesTls()) {
     return;
   }
+  SetAddressToken("857293462398");
+
   QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1);
   QuicConnectionId server_connection_id = TestConnectionId();
   QuicConfig client_config = DefaultQuicConfig();
@@ -576,7 +579,9 @@
   client_config.custom_transport_parameters_to_send()[kCustomParameterId] =
       kCustomParameterValue;
   std::vector<std::unique_ptr<QuicReceivedPacket>> packets =
-      GetFirstFlightOfPackets(version_, client_config, server_connection_id);
+      GetFirstFlightOfPackets(version_, client_config, server_connection_id,
+                              EmptyQuicConnectionId(),
+                              TestClientCryptoConfig());
   ASSERT_EQ(packets.size(), 2u);
   if (add_reordering) {
     std::swap(packets[0], packets[1]);
@@ -1533,8 +1538,7 @@
  public:
   bool IsWriteBlocked() const override { return false; }
 
-  WriteResult WritePacket(const char* buffer,
-                          size_t buf_len,
+  WriteResult WritePacket(const char* buffer, size_t buf_len,
                           const QuicIpAddress& /*self_client_address*/,
                           const QuicSocketAddress& /*peer_client_address*/,
                           PerPacketOptions* /*options*/) override {
@@ -1864,8 +1868,7 @@
   bool IsWriteBlocked() const override { return write_blocked_; }
   void SetWritable() override { write_blocked_ = false; }
 
-  WriteResult WritePacket(const char* /*buffer*/,
-                          size_t /*buf_len*/,
+  WriteResult WritePacket(const char* /*buffer*/, size_t /*buf_len*/,
                           const QuicIpAddress& /*self_client_address*/,
                           const QuicSocketAddress& /*peer_client_address*/,
                           PerPacketOptions* /*options*/) override {
@@ -2370,8 +2373,7 @@
   }
 
   void ProcessUndecryptableEarlyPacket(
-      const ParsedQuicVersion& version,
-      const QuicSocketAddress& peer_address,
+      const ParsedQuicVersion& version, const QuicSocketAddress& peer_address,
       const QuicConnectionId& server_connection_id) {
     QuicDispatcherTestBase::ProcessUndecryptableEarlyPacket(
         version, peer_address, server_connection_id);
@@ -2394,8 +2396,7 @@
   QuicSocketAddress client_addr_;
 };
 
-INSTANTIATE_TEST_SUITE_P(BufferedPacketStoreTests,
-                         BufferedPacketStoreTest,
+INSTANTIATE_TEST_SUITE_P(BufferedPacketStoreTests, BufferedPacketStoreTest,
                          ::testing::ValuesIn(CurrentSupportedVersions()),
                          ::testing::PrintToStringParamName());
 
diff --git a/quic/core/quic_types.cc b/quic/core/quic_types.cc
index 6bd58e9..70c9101 100644
--- a/quic/core/quic_types.cc
+++ b/quic/core/quic_types.cc
@@ -405,13 +405,17 @@
 bool operator==(const ParsedClientHello& a, const ParsedClientHello& b) {
   return a.sni == b.sni && a.uaid == b.uaid && a.alpns == b.alpns &&
          a.legacy_version_encapsulation_inner_packet ==
-             b.legacy_version_encapsulation_inner_packet;
+             b.legacy_version_encapsulation_inner_packet &&
+         a.retry_token == b.retry_token &&
+         a.resumption_attempted == b.resumption_attempted &&
+         a.early_data_attempted == b.early_data_attempted;
 }
 
 std::ostream& operator<<(std::ostream& os,
                          const ParsedClientHello& parsed_chlo) {
   os << "{ sni:" << parsed_chlo.sni << ", uaid:" << parsed_chlo.uaid
      << ", alpns:" << quiche::PrintElements(parsed_chlo.alpns)
+     << ", len(retry_token):" << parsed_chlo.retry_token.size()
      << ", len(inner_packet):"
      << parsed_chlo.legacy_version_encapsulation_inner_packet.size() << " }";
   return os;
diff --git a/quic/core/quic_types.h b/quic/core/quic_types.h
index b048121..200e843 100644
--- a/quic/core/quic_types.h
+++ b/quic/core/quic_types.h
@@ -872,6 +872,11 @@
   std::string uaid;                // QUIC crypto only.
   std::vector<std::string> alpns;  // QUIC crypto and TLS.
   std::string legacy_version_encapsulation_inner_packet;  // QUIC crypto only.
+  // The unvalidated retry token from the last received packet of a potentially
+  // multi-packet client hello. TLS only.
+  std::string retry_token;
+  bool resumption_attempted = false;  // TLS only.
+  bool early_data_attempted = false;  // TLS only.
 };
 
 QUIC_EXPORT_PRIVATE bool operator==(const ParsedClientHello& a,
diff --git a/quic/core/tls_chlo_extractor.cc b/quic/core/tls_chlo_extractor.cc
index ac1fb18..3a2df79 100644
--- a/quic/core/tls_chlo_extractor.cc
+++ b/quic/core/tls_chlo_extractor.cc
@@ -20,6 +20,16 @@
 
 namespace quic {
 
+namespace {
+bool HasExtension(const SSL_CLIENT_HELLO* client_hello, uint16_t extension) {
+  const uint8_t* unused_extension_bytes;
+  size_t unused_extension_len;
+  return 1 == SSL_early_callback_ctx_extension_get(client_hello, extension,
+                                                   &unused_extension_bytes,
+                                                   &unused_extension_len);
+}
+}  // namespace
+
 TlsChloExtractor::TlsChloExtractor()
     : crypto_stream_sequencer_(this),
       state_(State::kInitial),
@@ -280,6 +290,11 @@
   if (server_name) {
     server_name_ = std::string(server_name);
   }
+
+  resumption_attempted_ =
+      HasExtension(client_hello, TLSEXT_TYPE_pre_shared_key);
+  early_data_attempted_ = HasExtension(client_hello, TLSEXT_TYPE_early_data);
+
   const uint8_t* alpn_data;
   size_t alpn_len;
   int rv = SSL_early_callback_ctx_extension_get(
diff --git a/quic/core/tls_chlo_extractor.h b/quic/core/tls_chlo_extractor.h
index 4b1cf31..29296dd 100644
--- a/quic/core/tls_chlo_extractor.h
+++ b/quic/core/tls_chlo_extractor.h
@@ -45,6 +45,8 @@
   State state() const { return state_; }
   std::vector<std::string> alpns() const { return alpns_; }
   std::string server_name() const { return server_name_; }
+  bool resumption_attempted() const { return resumption_attempted_; }
+  bool early_data_attempted() const { return early_data_attempted_; }
 
   // Converts |state| to a human-readable string suitable for logging.
   static std::string StateToString(State state);
@@ -246,6 +248,12 @@
   std::vector<std::string> alpns_;
   // SNI parsed from the CHLO.
   std::string server_name_;
+  // Whether resumption is attempted from the CHLO, indicated by the
+  // 'pre_shared_key' TLS extension.
+  bool resumption_attempted_ = false;
+  // Whether early data is attempted from the CHLO, indicated by the
+  // 'early_data' TLS extension.
+  bool early_data_attempted_ = false;
 };
 
 // Convenience method to facilitate logging TlsChloExtractor::State.
diff --git a/quic/core/tls_chlo_extractor_test.cc b/quic/core/tls_chlo_extractor_test.cc
index 8b5ee42..8a2ce09 100644
--- a/quic/core/tls_chlo_extractor_test.cc
+++ b/quic/core/tls_chlo_extractor_test.cc
@@ -3,26 +3,87 @@
 // found in the LICENSE file.
 
 #include "quic/core/tls_chlo_extractor.h"
+
 #include <memory>
 
+#include "third_party/boringssl/src/include/openssl/ssl.h"
 #include "quic/core/http/quic_spdy_client_session.h"
 #include "quic/core/quic_connection.h"
 #include "quic/core/quic_packet_writer_wrapper.h"
+#include "quic/core/quic_types.h"
 #include "quic/core/quic_versions.h"
 #include "quic/platform/api/quic_test.h"
 #include "quic/test_tools/crypto_test_utils.h"
 #include "quic/test_tools/first_flight.h"
 #include "quic/test_tools/quic_test_utils.h"
+#include "quic/test_tools/simple_session_cache.h"
 
 namespace quic {
 namespace test {
 namespace {
 
+using testing::_;
+using testing::AnyNumber;
+
 class TlsChloExtractorTest : public QuicTestWithParam<ParsedQuicVersion> {
  protected:
-  TlsChloExtractorTest() : version_(GetParam()) {}
+  TlsChloExtractorTest() : version_(GetParam()), server_id_(TestServerId()) {}
 
   void Initialize() { packets_ = GetFirstFlightOfPackets(version_, config_); }
+  void Initialize(std::unique_ptr<QuicCryptoClientConfig> crypto_config) {
+    packets_ = GetFirstFlightOfPackets(version_, config_, TestConnectionId(),
+                                       EmptyQuicConnectionId(),
+                                       std::move(crypto_config));
+  }
+
+  // Perform a full handshake in order to insert a SSL_SESSION into
+  // crypto_config->session_cache(), which can be used by a TLS resumption.
+  void PerformFullHandshake(QuicCryptoClientConfig* crypto_config) const {
+    ASSERT_NE(crypto_config->session_cache(), nullptr);
+    MockQuicConnectionHelper client_helper, server_helper;
+    MockAlarmFactory alarm_factory;
+    ParsedQuicVersionVector supported_versions = {version_};
+    PacketSavingConnection* client_connection =
+        new PacketSavingConnection(&client_helper, &alarm_factory,
+                                   Perspective::IS_CLIENT, supported_versions);
+    // Advance the time, because timers do not like uninitialized times.
+    client_connection->AdvanceTime(QuicTime::Delta::FromSeconds(1));
+    QuicClientPushPromiseIndex push_promise_index;
+    QuicSpdyClientSession client_session(config_, supported_versions,
+                                         client_connection, server_id_,
+                                         crypto_config, &push_promise_index);
+    client_session.Initialize();
+
+    std::unique_ptr<QuicCryptoServerConfig> server_crypto_config =
+        crypto_test_utils::CryptoServerConfigForTesting();
+    QuicConfig server_config;
+
+    EXPECT_CALL(*client_connection, SendCryptoData(_, _, _)).Times(AnyNumber());
+    client_session.GetMutableCryptoStream()->CryptoConnect();
+
+    crypto_test_utils::HandshakeWithFakeServer(
+        &server_config, server_crypto_config.get(), &server_helper,
+        &alarm_factory, client_connection,
+        client_session.GetMutableCryptoStream(),
+        AlpnForVersion(client_connection->version()));
+
+    // For some reason, the test client can not receive the server settings and
+    // the SSL_SESSION will not be inserted to client's session_cache. We create
+    // a dummy settings and call SetServerApplicationStateForResumption manually
+    // to ensure the SSL_SESSION is cached.
+    // TODO(wub): Fix crypto_test_utils::HandshakeWithFakeServer to make sure a
+    // SSL_SESSION is cached at the client, and remove the rest of the function.
+    SettingsFrame server_settings;
+    server_settings.values[SETTINGS_QPACK_MAX_TABLE_CAPACITY] =
+        kDefaultQpackMaxDynamicTableCapacity;
+    std::unique_ptr<char[]> buffer;
+    uint64_t length =
+        HttpEncoder::SerializeSettingsFrame(server_settings, &buffer);
+    client_session.GetMutableCryptoStream()
+        ->SetServerApplicationStateForResumption(
+            std::make_unique<ApplicationState>(buffer.get(),
+                                               buffer.get() + length));
+  }
 
   void IngestPackets() {
     for (const std::unique_ptr<QuicReceivedPacket>& packet : packets_) {
@@ -62,6 +123,7 @@
   }
 
   ParsedQuicVersion version_;
+  QuicServerId server_id_;
   TlsChloExtractor tls_chlo_extractor_;
   QuicConfig config_;
   std::vector<std::unique_ptr<QuicReceivedPacket>> packets_;
@@ -79,6 +141,42 @@
   ValidateChloDetails();
   EXPECT_EQ(tls_chlo_extractor_.state(),
             TlsChloExtractor::State::kParsedFullSinglePacketChlo);
+  EXPECT_FALSE(tls_chlo_extractor_.resumption_attempted());
+  EXPECT_FALSE(tls_chlo_extractor_.early_data_attempted());
+}
+
+TEST_P(TlsChloExtractorTest, TlsExtentionInfo_ResumptionOnly) {
+  auto crypto_client_config = std::make_unique<QuicCryptoClientConfig>(
+      crypto_test_utils::ProofVerifierForTesting(),
+      std::make_unique<SimpleSessionCache>());
+  PerformFullHandshake(crypto_client_config.get());
+
+  SSL_CTX_set_early_data_enabled(crypto_client_config->ssl_ctx(), 0);
+  Initialize(std::move(crypto_client_config));
+  EXPECT_GE(packets_.size(), 1u);
+  IngestPackets();
+  ValidateChloDetails();
+  EXPECT_EQ(tls_chlo_extractor_.state(),
+            TlsChloExtractor::State::kParsedFullSinglePacketChlo);
+  EXPECT_TRUE(tls_chlo_extractor_.resumption_attempted());
+  EXPECT_FALSE(tls_chlo_extractor_.early_data_attempted());
+}
+
+TEST_P(TlsChloExtractorTest, TlsExtentionInfo_ZeroRtt) {
+  auto crypto_client_config = std::make_unique<QuicCryptoClientConfig>(
+      crypto_test_utils::ProofVerifierForTesting(),
+      std::make_unique<SimpleSessionCache>());
+  PerformFullHandshake(crypto_client_config.get());
+
+  IncreaseSizeOfChlo();
+  Initialize(std::move(crypto_client_config));
+  EXPECT_GE(packets_.size(), 1u);
+  IngestPackets();
+  ValidateChloDetails();
+  EXPECT_EQ(tls_chlo_extractor_.state(),
+            TlsChloExtractor::State::kParsedFullMultiPacketChlo);
+  EXPECT_TRUE(tls_chlo_extractor_.resumption_attempted());
+  EXPECT_TRUE(tls_chlo_extractor_.early_data_attempted());
 }
 
 TEST_P(TlsChloExtractorTest, MultiPacket) {
diff --git a/quic/test_tools/crypto_test_utils.cc b/quic/test_tools/crypto_test_utils.cc
index febd8d3..634c16f 100644
--- a/quic/test_tools/crypto_test_utils.cc
+++ b/quic/test_tools/crypto_test_utils.cc
@@ -228,7 +228,7 @@
                             MockQuicConnectionHelper* helper,
                             MockAlarmFactory* alarm_factory,
                             PacketSavingConnection* client_conn,
-                            QuicCryptoClientStream* client,
+                            QuicCryptoClientStreamBase* client,
                             std::string alpn) {
   auto* server_conn = new testing::NiceMock<PacketSavingConnection>(
       helper, alarm_factory, Perspective::IS_SERVER,
@@ -593,7 +593,7 @@
 
 }  // namespace
 
-void CompareClientAndServerKeys(QuicCryptoClientStream* client,
+void CompareClientAndServerKeys(QuicCryptoClientStreamBase* client,
                                 QuicCryptoServerStreamBase* server) {
   QuicFramer* client_framer = QuicConnectionPeer::GetFramer(
       QuicStreamPeer::session(client)->connection());
diff --git a/quic/test_tools/crypto_test_utils.h b/quic/test_tools/crypto_test_utils.h
index d839c6c..87e354a 100644
--- a/quic/test_tools/crypto_test_utils.h
+++ b/quic/test_tools/crypto_test_utils.h
@@ -78,7 +78,7 @@
                             MockQuicConnectionHelper* helper,
                             MockAlarmFactory* alarm_factory,
                             PacketSavingConnection* client_conn,
-                            QuicCryptoClientStream* client,
+                            QuicCryptoClientStreamBase* client,
                             std::string alpn);
 
 // returns: the number of client hellos that the client sent.
@@ -195,7 +195,7 @@
     QuicCompressedCertsCache* compressed_certs_cache,
     CryptoHandshakeMessage* out);
 
-void CompareClientAndServerKeys(QuicCryptoClientStream* client,
+void CompareClientAndServerKeys(QuicCryptoClientStreamBase* client,
                                 QuicCryptoServerStreamBase* server);
 
 // Return a CHLO nonce in hexadecimal.
diff --git a/quic/test_tools/first_flight.cc b/quic/test_tools/first_flight.cc
index 435bb83..ec0a5e3 100644
--- a/quic/test_tools/first_flight.cc
+++ b/quic/test_tools/first_flight.cc
@@ -32,18 +32,28 @@
   FirstFlightExtractor(const ParsedQuicVersion& version,
                        const QuicConfig& config,
                        const QuicConnectionId& server_connection_id,
-                       const QuicConnectionId& client_connection_id)
+                       const QuicConnectionId& client_connection_id,
+                       std::unique_ptr<QuicCryptoClientConfig> crypto_config)
       : version_(version),
         server_connection_id_(server_connection_id),
         client_connection_id_(client_connection_id),
         writer_(this),
         config_(config),
-        crypto_config_(crypto_test_utils::ProofVerifierForTesting()) {
+        crypto_config_(std::move(crypto_config)) {
     EXPECT_NE(version_, UnsupportedQuicVersion());
   }
 
+  FirstFlightExtractor(const ParsedQuicVersion& version,
+                       const QuicConfig& config,
+                       const QuicConnectionId& server_connection_id,
+                       const QuicConnectionId& client_connection_id)
+      : FirstFlightExtractor(
+            version, config, server_connection_id, client_connection_id,
+            std::make_unique<QuicCryptoClientConfig>(
+                crypto_test_utils::ProofVerifierForTesting())) {}
+
   void GenerateFirstFlight() {
-    crypto_config_.set_alpn(AlpnForVersion(version_));
+    crypto_config_->set_alpn(AlpnForVersion(version_));
     connection_ =
         new QuicConnection(server_connection_id_,
                            /*initial_self_address=*/QuicSocketAddress(),
@@ -55,7 +65,7 @@
     session_ = std::make_unique<QuicSpdyClientSession>(
         config_, ParsedQuicVersionVector{version_},
         connection_,  // session_ takes ownership of connection_ here.
-        TestServerId(), &crypto_config_, &push_promise_index_);
+        TestServerId(), crypto_config_.get(), &push_promise_index_);
     session_->Initialize();
     session_->CryptoConnect();
   }
@@ -84,7 +94,7 @@
   MockAlarmFactory alarm_factory_;
   DelegatedPacketWriter writer_;
   QuicConfig config_;
-  QuicCryptoClientConfig crypto_config_;
+  std::unique_ptr<QuicCryptoClientConfig> crypto_config_;
   QuicClientPushPromiseIndex push_promise_index_;
   QuicConnection* connection_;  // Owned by session_.
   std::unique_ptr<QuicSpdyClientSession> session_;
@@ -92,6 +102,18 @@
 };
 
 std::vector<std::unique_ptr<QuicReceivedPacket>> GetFirstFlightOfPackets(
+    const ParsedQuicVersion& version, const QuicConfig& config,
+    const QuicConnectionId& server_connection_id,
+    const QuicConnectionId& client_connection_id,
+    std::unique_ptr<QuicCryptoClientConfig> crypto_config) {
+  FirstFlightExtractor first_flight_extractor(
+      version, config, server_connection_id, client_connection_id,
+      std::move(crypto_config));
+  first_flight_extractor.GenerateFirstFlight();
+  return first_flight_extractor.ConsumePackets();
+}
+
+std::vector<std::unique_ptr<QuicReceivedPacket>> GetFirstFlightOfPackets(
     const ParsedQuicVersion& version,
     const QuicConfig& config,
     const QuicConnectionId& server_connection_id,
diff --git a/quic/test_tools/first_flight.h b/quic/test_tools/first_flight.h
index c18c879..448a189 100644
--- a/quic/test_tools/first_flight.h
+++ b/quic/test_tools/first_flight.h
@@ -8,6 +8,7 @@
 #include <memory>
 #include <vector>
 
+#include "quic/core/crypto/quic_crypto_client_config.h"
 #include "quic/core/quic_config.h"
 #include "quic/core/quic_connection_id.h"
 #include "quic/core/quic_packet_writer.h"
@@ -74,16 +75,23 @@
 // HTTP/3 connection. In most cases, this array will only contain one packet
 // that carries the CHLO.
 std::vector<std::unique_ptr<QuicReceivedPacket>> GetFirstFlightOfPackets(
-    const ParsedQuicVersion& version,
-    const QuicConfig& config,
+    const ParsedQuicVersion& version, const QuicConfig& config,
     const QuicConnectionId& server_connection_id,
-    const QuicConnectionId& client_connection_id);
+    const QuicConnectionId& client_connection_id,
+    std::unique_ptr<QuicCryptoClientConfig> crypto_config);
 
 // Below are various convenience overloads that use default values for the
 // omitted parameters:
 // |config| = DefaultQuicConfig(),
 // |server_connection_id| = TestConnectionId(),
 // |client_connection_id| = EmptyQuicConnectionId().
+// |crypto_config| =
+//     QuicCryptoClientConfig(crypto_test_utils::ProofVerifierForTesting())
+std::vector<std::unique_ptr<QuicReceivedPacket>> GetFirstFlightOfPackets(
+    const ParsedQuicVersion& version, const QuicConfig& config,
+    const QuicConnectionId& server_connection_id,
+    const QuicConnectionId& client_connection_id);
+
 std::vector<std::unique_ptr<QuicReceivedPacket>> GetFirstFlightOfPackets(
     const ParsedQuicVersion& version,
     const QuicConfig& config,
diff --git a/quic/test_tools/quic_test_utils.cc b/quic/test_tools/quic_test_utils.cc
index da095dc..1149721 100644
--- a/quic/test_tools/quic_test_utils.cc
+++ b/quic/test_tools/quic_test_utils.cc
@@ -88,9 +88,7 @@
                                   sizeof(kStatelessResetTokenDataForTest));
 }
 
-std::string TestHostname() {
-  return "test.example.org";
-}
+std::string TestHostname() { return "test.example.com"; }
 
 QuicServerId TestServerId() {
   return QuicServerId(TestHostname(), kTestPort);