Add retry_token, resumption_attempted and early_data_attempted to quic::ParsedClientHello. This is a follow up to cl/406887508. PiperOrigin-RevId: 409242611
diff --git a/quic/core/quic_buffered_packet_store.cc b/quic/core/quic_buffered_packet_store.cc index 291434f..f597ce0 100644 --- a/quic/core/quic_buffered_packet_store.cc +++ b/quic/core/quic_buffered_packet_store.cc
@@ -255,11 +255,10 @@ } bool QuicBufferedPacketStore::IngestPacketForTlsChloExtraction( - const QuicConnectionId& connection_id, - const ParsedQuicVersion& version, - const QuicReceivedPacket& packet, - std::vector<std::string>* out_alpns, - std::string* out_sni) { + const QuicConnectionId& connection_id, const ParsedQuicVersion& version, + const QuicReceivedPacket& packet, std::vector<std::string>* out_alpns, + std::string* out_sni, bool* out_resumption_attempted, + bool* out_early_data_attempted) { QUICHE_DCHECK_NE(out_alpns, nullptr); QUICHE_DCHECK_NE(out_sni, nullptr); QUICHE_DCHECK_EQ(version.handshake_protocol, PROTOCOL_TLS1_3); @@ -273,8 +272,11 @@ if (!it->second.tls_chlo_extractor.HasParsedFullChlo()) { return false; } - *out_alpns = it->second.tls_chlo_extractor.alpns(); - *out_sni = it->second.tls_chlo_extractor.server_name(); + const TlsChloExtractor& tls_chlo_extractor = it->second.tls_chlo_extractor; + *out_alpns = tls_chlo_extractor.alpns(); + *out_sni = tls_chlo_extractor.server_name(); + *out_resumption_attempted = tls_chlo_extractor.resumption_attempted(); + *out_early_data_attempted = tls_chlo_extractor.early_data_attempted(); return true; }
diff --git a/quic/core/quic_buffered_packet_store.h b/quic/core/quic_buffered_packet_store.h index b363802..3206b37 100644 --- a/quic/core/quic_buffered_packet_store.h +++ b/quic/core/quic_buffered_packet_store.h
@@ -117,11 +117,16 @@ // Returns whether we've now parsed a full multi-packet TLS CHLO. // When this returns true, |out_alpns| is populated with the list of ALPNs // extracted from the CHLO. |out_sni| is populated with the SNI tag in CHLO. + // |out_resumption_attempted| is populated if the CHLO has the + // 'pre_shared_key' TLS extension. |out_early_data_attempted| is populated if + // the CHLO has the 'early_data' TLS extension. bool IngestPacketForTlsChloExtraction(const QuicConnectionId& connection_id, const ParsedQuicVersion& version, const QuicReceivedPacket& packet, std::vector<std::string>* out_alpns, - std::string* out_sni); + std::string* out_sni, + bool* out_resumption_attempted, + bool* out_early_data_attempted); // Returns the list of buffered packets for |connection_id| and removes them // from the store. Returns an empty list if no early arrived packets for this
diff --git a/quic/core/quic_dispatcher.cc b/quic/core/quic_dispatcher.cc index 1484b74..ef6db5d 100644 --- a/quic/core/quic_dispatcher.cc +++ b/quic/core/quic_dispatcher.cc
@@ -780,13 +780,15 @@ bool has_full_tls_chlo = false; std::string sni; std::vector<std::string> alpns; + bool resumption_attempted = false, early_data_attempted = false; if (buffered_packets_.HasBufferedPackets( packet_info.destination_connection_id)) { // If we already have buffered packets for this connection ID, // use the associated TlsChloExtractor to parse this packet. has_full_tls_chlo = buffered_packets_.IngestPacketForTlsChloExtraction( packet_info.destination_connection_id, packet_info.version, - packet_info.packet, &alpns, &sni); + packet_info.packet, &alpns, &sni, &resumption_attempted, + &early_data_attempted); } else { // If we do not have a BufferedPacketList for this connection ID, // create a single-use one to check whether this packet contains a @@ -798,6 +800,8 @@ has_full_tls_chlo = true; alpns = tls_chlo_extractor.alpns(); sni = tls_chlo_extractor.server_name(); + resumption_attempted = tls_chlo_extractor.resumption_attempted(); + early_data_attempted = tls_chlo_extractor.early_data_attempted(); } } if (!has_full_tls_chlo) { @@ -811,6 +815,11 @@ ParsedClientHello parsed_chlo; parsed_chlo.sni = std::move(sni); parsed_chlo.alpns = std::move(alpns); + if (packet_info.retry_token.has_value()) { + parsed_chlo.retry_token = std::string(*packet_info.retry_token); + } + parsed_chlo.resumption_attempted = resumption_attempted; + parsed_chlo.early_data_attempted = early_data_attempted; return parsed_chlo; }
diff --git a/quic/core/quic_dispatcher_test.cc b/quic/core/quic_dispatcher_test.cc index 1dfda6c..4d59c13 100644 --- a/quic/core/quic_dispatcher_test.cc +++ b/quic/core/quic_dispatcher_test.cc
@@ -68,12 +68,8 @@ QuicConnection* connection, const QuicCryptoServerConfig* crypto_config, QuicCompressedCertsCache* compressed_certs_cache) - : QuicServerSessionBase(config, - CurrentSupportedVersions(), - connection, - nullptr, - nullptr, - crypto_config, + : QuicServerSessionBase(config, CurrentSupportedVersions(), connection, + nullptr, nullptr, crypto_config, compressed_certs_cache) { Initialize(); } @@ -83,26 +79,17 @@ ~TestQuicSpdyServerSession() override { DeleteConnection(); } - MOCK_METHOD(void, - OnConnectionClosed, + MOCK_METHOD(void, OnConnectionClosed, (const QuicConnectionCloseFrame& frame, ConnectionCloseSource source), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateIncomingStream, - (QuicStreamId id), + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (QuicStreamId id), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateIncomingStream, - (PendingStream*), + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (PendingStream*), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateOutgoingBidirectionalStream, - (), + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingBidirectionalStream, (), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateOutgoingUnidirectionalStream, - (), + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingUnidirectionalStream, (), (override)); std::unique_ptr<QuicCryptoServerStreamBase> CreateQuicCryptoServerStream( @@ -138,10 +125,8 @@ const ParsedClientHello& parsed_chlo), (override)); - MOCK_METHOD(bool, - ShouldCreateOrBufferPacketForConnection, - (const ReceivedPacketInfo& packet_info), - (override)); + MOCK_METHOD(bool, ShouldCreateOrBufferPacketForConnection, + (const ReceivedPacketInfo& packet_info), (override)); struct TestQuicPerPacketContext : public QuicPerPacketContext { std::string custom_packet_context; @@ -179,9 +164,7 @@ MockQuicConnectionHelper* helper, MockAlarmFactory* alarm_factory, QuicDispatcher* dispatcher) - : MockQuicConnection(connection_id, - helper, - alarm_factory, + : MockQuicConnection(connection_id, helper, alarm_factory, Perspective::IS_SERVER), dispatcher_(dispatcher), active_connection_ids_({connection_id}) {} @@ -225,15 +208,12 @@ : version_(GetParam()), version_manager_(AllSupportedVersions()), crypto_config_(QuicCryptoServerConfig::TESTING, - QuicRandom::GetInstance(), - std::move(proof_source), + QuicRandom::GetInstance(), std::move(proof_source), KeyExchangeSource::Default()), server_address_(QuicIpAddress::Any4(), 5), - dispatcher_( - new NiceMock<TestDispatcher>(&config_, - &crypto_config_, - &version_manager_, - mock_helper_.GetRandomGenerator())), + dispatcher_(new NiceMock<TestDispatcher>( + &config_, &crypto_config_, &version_manager_, + mock_helper_.GetRandomGenerator())), time_wait_list_manager_(nullptr), session1_(nullptr), session2_(nullptr), @@ -268,8 +248,7 @@ // using the version under test. void ProcessPacket(QuicSocketAddress peer_address, QuicConnectionId server_connection_id, - bool has_version_flag, - const std::string& data) { + bool has_version_flag, const std::string& data) { ProcessPacket(peer_address, server_connection_id, has_version_flag, data, CONNECTION_ID_PRESENT, PACKET_4BYTE_PACKET_NUMBER); } @@ -278,8 +257,7 @@ // using the version under test. void ProcessPacket(QuicSocketAddress peer_address, QuicConnectionId server_connection_id, - bool has_version_flag, - const std::string& data, + bool has_version_flag, const std::string& data, QuicConnectionIdIncluded server_connection_id_included, QuicPacketNumberLength packet_number_length) { ProcessPacket(peer_address, server_connection_id, has_version_flag, data, @@ -289,8 +267,7 @@ // Process a packet using the version under test. void ProcessPacket(QuicSocketAddress peer_address, QuicConnectionId server_connection_id, - bool has_version_flag, - const std::string& data, + bool has_version_flag, const std::string& data, QuicConnectionIdIncluded server_connection_id_included, QuicPacketNumberLength packet_number_length, uint64_t packet_number) { @@ -302,10 +279,8 @@ // Processes a packet. void ProcessPacket(QuicSocketAddress peer_address, QuicConnectionId server_connection_id, - bool has_version_flag, - ParsedQuicVersion version, - const std::string& data, - bool full_padding, + bool has_version_flag, ParsedQuicVersion version, + const std::string& data, bool full_padding, QuicConnectionIdIncluded server_connection_id_included, QuicPacketNumberLength packet_number_length, uint64_t packet_number) { @@ -319,10 +294,8 @@ void ProcessPacket(QuicSocketAddress peer_address, QuicConnectionId server_connection_id, QuicConnectionId client_connection_id, - bool has_version_flag, - ParsedQuicVersion version, - const std::string& data, - bool full_padding, + bool has_version_flag, ParsedQuicVersion version, + const std::string& data, bool full_padding, QuicConnectionIdIncluded server_connection_id_included, QuicConnectionIdIncluded client_connection_id_included, QuicPacketNumberLength packet_number_length, @@ -340,8 +313,7 @@ void ProcessReceivedPacket( std::unique_ptr<QuicReceivedPacket> received_packet, - const QuicSocketAddress& peer_address, - const ParsedQuicVersion& version, + const QuicSocketAddress& peer_address, const ParsedQuicVersion& version, const QuicConnectionId& server_connection_id) { if (version.UsesQuicCrypto() && ChloExtractor::Extract(*received_packet, version, {}, nullptr, @@ -367,12 +339,9 @@ } std::unique_ptr<QuicSession> CreateSession( - TestDispatcher* dispatcher, - const QuicConfig& config, - QuicConnectionId connection_id, - const QuicSocketAddress& /*peer_address*/, - MockQuicConnectionHelper* helper, - MockAlarmFactory* alarm_factory, + TestDispatcher* dispatcher, const QuicConfig& config, + QuicConnectionId connection_id, const QuicSocketAddress& /*peer_address*/, + MockQuicConnectionHelper* helper, MockAlarmFactory* alarm_factory, const QuicCryptoServerConfig* crypto_config, QuicCompressedCertsCache* compressed_certs_cache, TestQuicSpdyServerSession** session_ptr) { @@ -414,8 +383,7 @@ } void ProcessUndecryptableEarlyPacket( - const ParsedQuicVersion& version, - const QuicSocketAddress& peer_address, + const ParsedQuicVersion& version, const QuicSocketAddress& peer_address, const QuicConnectionId& server_connection_id) { std::unique_ptr<QuicEncryptedPacket> encrypted_packet = GetUndecryptableEarlyPacket(version, server_connection_id); @@ -441,15 +409,41 @@ const QuicSocketAddress& peer_address, const QuicConnectionId& server_connection_id, const QuicConnectionId& client_connection_id) { + ProcessFirstFlight(version, peer_address, server_connection_id, + client_connection_id, TestClientCryptoConfig()); + } + + void ProcessFirstFlight( + const ParsedQuicVersion& version, const QuicSocketAddress& peer_address, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id, + std::unique_ptr<QuicCryptoClientConfig> client_crypto_config) { std::vector<std::unique_ptr<QuicReceivedPacket>> packets = - GetFirstFlightOfPackets(version, server_connection_id, - client_connection_id); + GetFirstFlightOfPackets(version, DefaultQuicConfig(), + server_connection_id, client_connection_id, + std::move(client_crypto_config)); for (auto&& packet : packets) { ProcessReceivedPacket(std::move(packet), peer_address, version, server_connection_id); } } + std::unique_ptr<QuicCryptoClientConfig> TestClientCryptoConfig() { + auto client_crypto_config = std::make_unique<QuicCryptoClientConfig>( + crypto_test_utils::ProofVerifierForTesting()); + if (address_token_.has_value()) { + client_crypto_config->LookupOrCreate(TestServerId()) + ->set_source_address_token(*address_token_); + } + return client_crypto_config; + } + + // If called, the first flight packets generated in |ProcessFirstFlight| will + // contain the given |address_token|. + void SetAddressToken(std::string address_token) { + address_token_ = std::move(address_token); + } + std::string ExpectedAlpnForVersion(ParsedQuicVersion version) { return AlpnForVersion(version); } @@ -460,6 +454,9 @@ ParsedClientHello parsed_chlo; parsed_chlo.alpns = {ExpectedAlpn()}; parsed_chlo.sni = TestHostname(); + if (address_token_.has_value()) { + parsed_chlo.retry_token = *address_token_; + } return parsed_chlo; } @@ -521,6 +518,7 @@ std::map<QuicConnectionId, std::list<std::string>> data_connection_map_; QuicBufferedPacketStore* store_; uint64_t connection_id_; + absl::optional<std::string> address_token_; }; class QuicDispatcherTestAllVersions : public QuicDispatcherTestBase {}; @@ -540,11 +538,14 @@ if (version_.UsesQuicCrypto()) { return; } + SetAddressToken("hsdifghdsaifnasdpfjdsk"); + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); - EXPECT_CALL(*dispatcher_, - CreateQuicSession(TestConnectionId(1), _, client_address, - Eq(ExpectedAlpn()), _, _)) + EXPECT_CALL( + *dispatcher_, + CreateQuicSession(TestConnectionId(1), _, client_address, + Eq(ExpectedAlpn()), _, Eq(ParsedClientHelloForTest()))) .WillOnce(Return(ByMove(CreateSession( dispatcher_.get(), config_, TestConnectionId(1), client_address, &mock_helper_, &mock_alarm_factory_, &crypto_config_, @@ -566,6 +567,8 @@ if (!version_.UsesTls()) { return; } + SetAddressToken("857293462398"); + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); QuicConnectionId server_connection_id = TestConnectionId(); QuicConfig client_config = DefaultQuicConfig(); @@ -576,7 +579,9 @@ client_config.custom_transport_parameters_to_send()[kCustomParameterId] = kCustomParameterValue; std::vector<std::unique_ptr<QuicReceivedPacket>> packets = - GetFirstFlightOfPackets(version_, client_config, server_connection_id); + GetFirstFlightOfPackets(version_, client_config, server_connection_id, + EmptyQuicConnectionId(), + TestClientCryptoConfig()); ASSERT_EQ(packets.size(), 2u); if (add_reordering) { std::swap(packets[0], packets[1]); @@ -1533,8 +1538,7 @@ public: bool IsWriteBlocked() const override { return false; } - WriteResult WritePacket(const char* buffer, - size_t buf_len, + WriteResult WritePacket(const char* buffer, size_t buf_len, const QuicIpAddress& /*self_client_address*/, const QuicSocketAddress& /*peer_client_address*/, PerPacketOptions* /*options*/) override { @@ -1864,8 +1868,7 @@ bool IsWriteBlocked() const override { return write_blocked_; } void SetWritable() override { write_blocked_ = false; } - WriteResult WritePacket(const char* /*buffer*/, - size_t /*buf_len*/, + WriteResult WritePacket(const char* /*buffer*/, size_t /*buf_len*/, const QuicIpAddress& /*self_client_address*/, const QuicSocketAddress& /*peer_client_address*/, PerPacketOptions* /*options*/) override { @@ -2370,8 +2373,7 @@ } void ProcessUndecryptableEarlyPacket( - const ParsedQuicVersion& version, - const QuicSocketAddress& peer_address, + const ParsedQuicVersion& version, const QuicSocketAddress& peer_address, const QuicConnectionId& server_connection_id) { QuicDispatcherTestBase::ProcessUndecryptableEarlyPacket( version, peer_address, server_connection_id); @@ -2394,8 +2396,7 @@ QuicSocketAddress client_addr_; }; -INSTANTIATE_TEST_SUITE_P(BufferedPacketStoreTests, - BufferedPacketStoreTest, +INSTANTIATE_TEST_SUITE_P(BufferedPacketStoreTests, BufferedPacketStoreTest, ::testing::ValuesIn(CurrentSupportedVersions()), ::testing::PrintToStringParamName());
diff --git a/quic/core/quic_types.cc b/quic/core/quic_types.cc index 6bd58e9..70c9101 100644 --- a/quic/core/quic_types.cc +++ b/quic/core/quic_types.cc
@@ -405,13 +405,17 @@ bool operator==(const ParsedClientHello& a, const ParsedClientHello& b) { return a.sni == b.sni && a.uaid == b.uaid && a.alpns == b.alpns && a.legacy_version_encapsulation_inner_packet == - b.legacy_version_encapsulation_inner_packet; + b.legacy_version_encapsulation_inner_packet && + a.retry_token == b.retry_token && + a.resumption_attempted == b.resumption_attempted && + a.early_data_attempted == b.early_data_attempted; } std::ostream& operator<<(std::ostream& os, const ParsedClientHello& parsed_chlo) { os << "{ sni:" << parsed_chlo.sni << ", uaid:" << parsed_chlo.uaid << ", alpns:" << quiche::PrintElements(parsed_chlo.alpns) + << ", len(retry_token):" << parsed_chlo.retry_token.size() << ", len(inner_packet):" << parsed_chlo.legacy_version_encapsulation_inner_packet.size() << " }"; return os;
diff --git a/quic/core/quic_types.h b/quic/core/quic_types.h index b048121..200e843 100644 --- a/quic/core/quic_types.h +++ b/quic/core/quic_types.h
@@ -872,6 +872,11 @@ std::string uaid; // QUIC crypto only. std::vector<std::string> alpns; // QUIC crypto and TLS. std::string legacy_version_encapsulation_inner_packet; // QUIC crypto only. + // The unvalidated retry token from the last received packet of a potentially + // multi-packet client hello. TLS only. + std::string retry_token; + bool resumption_attempted = false; // TLS only. + bool early_data_attempted = false; // TLS only. }; QUIC_EXPORT_PRIVATE bool operator==(const ParsedClientHello& a,
diff --git a/quic/core/tls_chlo_extractor.cc b/quic/core/tls_chlo_extractor.cc index ac1fb18..3a2df79 100644 --- a/quic/core/tls_chlo_extractor.cc +++ b/quic/core/tls_chlo_extractor.cc
@@ -20,6 +20,16 @@ namespace quic { +namespace { +bool HasExtension(const SSL_CLIENT_HELLO* client_hello, uint16_t extension) { + const uint8_t* unused_extension_bytes; + size_t unused_extension_len; + return 1 == SSL_early_callback_ctx_extension_get(client_hello, extension, + &unused_extension_bytes, + &unused_extension_len); +} +} // namespace + TlsChloExtractor::TlsChloExtractor() : crypto_stream_sequencer_(this), state_(State::kInitial), @@ -280,6 +290,11 @@ if (server_name) { server_name_ = std::string(server_name); } + + resumption_attempted_ = + HasExtension(client_hello, TLSEXT_TYPE_pre_shared_key); + early_data_attempted_ = HasExtension(client_hello, TLSEXT_TYPE_early_data); + const uint8_t* alpn_data; size_t alpn_len; int rv = SSL_early_callback_ctx_extension_get(
diff --git a/quic/core/tls_chlo_extractor.h b/quic/core/tls_chlo_extractor.h index 4b1cf31..29296dd 100644 --- a/quic/core/tls_chlo_extractor.h +++ b/quic/core/tls_chlo_extractor.h
@@ -45,6 +45,8 @@ State state() const { return state_; } std::vector<std::string> alpns() const { return alpns_; } std::string server_name() const { return server_name_; } + bool resumption_attempted() const { return resumption_attempted_; } + bool early_data_attempted() const { return early_data_attempted_; } // Converts |state| to a human-readable string suitable for logging. static std::string StateToString(State state); @@ -246,6 +248,12 @@ std::vector<std::string> alpns_; // SNI parsed from the CHLO. std::string server_name_; + // Whether resumption is attempted from the CHLO, indicated by the + // 'pre_shared_key' TLS extension. + bool resumption_attempted_ = false; + // Whether early data is attempted from the CHLO, indicated by the + // 'early_data' TLS extension. + bool early_data_attempted_ = false; }; // Convenience method to facilitate logging TlsChloExtractor::State.
diff --git a/quic/core/tls_chlo_extractor_test.cc b/quic/core/tls_chlo_extractor_test.cc index 8b5ee42..8a2ce09 100644 --- a/quic/core/tls_chlo_extractor_test.cc +++ b/quic/core/tls_chlo_extractor_test.cc
@@ -3,26 +3,87 @@ // found in the LICENSE file. #include "quic/core/tls_chlo_extractor.h" + #include <memory> +#include "third_party/boringssl/src/include/openssl/ssl.h" #include "quic/core/http/quic_spdy_client_session.h" #include "quic/core/quic_connection.h" #include "quic/core/quic_packet_writer_wrapper.h" +#include "quic/core/quic_types.h" #include "quic/core/quic_versions.h" #include "quic/platform/api/quic_test.h" #include "quic/test_tools/crypto_test_utils.h" #include "quic/test_tools/first_flight.h" #include "quic/test_tools/quic_test_utils.h" +#include "quic/test_tools/simple_session_cache.h" namespace quic { namespace test { namespace { +using testing::_; +using testing::AnyNumber; + class TlsChloExtractorTest : public QuicTestWithParam<ParsedQuicVersion> { protected: - TlsChloExtractorTest() : version_(GetParam()) {} + TlsChloExtractorTest() : version_(GetParam()), server_id_(TestServerId()) {} void Initialize() { packets_ = GetFirstFlightOfPackets(version_, config_); } + void Initialize(std::unique_ptr<QuicCryptoClientConfig> crypto_config) { + packets_ = GetFirstFlightOfPackets(version_, config_, TestConnectionId(), + EmptyQuicConnectionId(), + std::move(crypto_config)); + } + + // Perform a full handshake in order to insert a SSL_SESSION into + // crypto_config->session_cache(), which can be used by a TLS resumption. + void PerformFullHandshake(QuicCryptoClientConfig* crypto_config) const { + ASSERT_NE(crypto_config->session_cache(), nullptr); + MockQuicConnectionHelper client_helper, server_helper; + MockAlarmFactory alarm_factory; + ParsedQuicVersionVector supported_versions = {version_}; + PacketSavingConnection* client_connection = + new PacketSavingConnection(&client_helper, &alarm_factory, + Perspective::IS_CLIENT, supported_versions); + // Advance the time, because timers do not like uninitialized times. + client_connection->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + QuicClientPushPromiseIndex push_promise_index; + QuicSpdyClientSession client_session(config_, supported_versions, + client_connection, server_id_, + crypto_config, &push_promise_index); + client_session.Initialize(); + + std::unique_ptr<QuicCryptoServerConfig> server_crypto_config = + crypto_test_utils::CryptoServerConfigForTesting(); + QuicConfig server_config; + + EXPECT_CALL(*client_connection, SendCryptoData(_, _, _)).Times(AnyNumber()); + client_session.GetMutableCryptoStream()->CryptoConnect(); + + crypto_test_utils::HandshakeWithFakeServer( + &server_config, server_crypto_config.get(), &server_helper, + &alarm_factory, client_connection, + client_session.GetMutableCryptoStream(), + AlpnForVersion(client_connection->version())); + + // For some reason, the test client can not receive the server settings and + // the SSL_SESSION will not be inserted to client's session_cache. We create + // a dummy settings and call SetServerApplicationStateForResumption manually + // to ensure the SSL_SESSION is cached. + // TODO(wub): Fix crypto_test_utils::HandshakeWithFakeServer to make sure a + // SSL_SESSION is cached at the client, and remove the rest of the function. + SettingsFrame server_settings; + server_settings.values[SETTINGS_QPACK_MAX_TABLE_CAPACITY] = + kDefaultQpackMaxDynamicTableCapacity; + std::unique_ptr<char[]> buffer; + uint64_t length = + HttpEncoder::SerializeSettingsFrame(server_settings, &buffer); + client_session.GetMutableCryptoStream() + ->SetServerApplicationStateForResumption( + std::make_unique<ApplicationState>(buffer.get(), + buffer.get() + length)); + } void IngestPackets() { for (const std::unique_ptr<QuicReceivedPacket>& packet : packets_) { @@ -62,6 +123,7 @@ } ParsedQuicVersion version_; + QuicServerId server_id_; TlsChloExtractor tls_chlo_extractor_; QuicConfig config_; std::vector<std::unique_ptr<QuicReceivedPacket>> packets_; @@ -79,6 +141,42 @@ ValidateChloDetails(); EXPECT_EQ(tls_chlo_extractor_.state(), TlsChloExtractor::State::kParsedFullSinglePacketChlo); + EXPECT_FALSE(tls_chlo_extractor_.resumption_attempted()); + EXPECT_FALSE(tls_chlo_extractor_.early_data_attempted()); +} + +TEST_P(TlsChloExtractorTest, TlsExtentionInfo_ResumptionOnly) { + auto crypto_client_config = std::make_unique<QuicCryptoClientConfig>( + crypto_test_utils::ProofVerifierForTesting(), + std::make_unique<SimpleSessionCache>()); + PerformFullHandshake(crypto_client_config.get()); + + SSL_CTX_set_early_data_enabled(crypto_client_config->ssl_ctx(), 0); + Initialize(std::move(crypto_client_config)); + EXPECT_GE(packets_.size(), 1u); + IngestPackets(); + ValidateChloDetails(); + EXPECT_EQ(tls_chlo_extractor_.state(), + TlsChloExtractor::State::kParsedFullSinglePacketChlo); + EXPECT_TRUE(tls_chlo_extractor_.resumption_attempted()); + EXPECT_FALSE(tls_chlo_extractor_.early_data_attempted()); +} + +TEST_P(TlsChloExtractorTest, TlsExtentionInfo_ZeroRtt) { + auto crypto_client_config = std::make_unique<QuicCryptoClientConfig>( + crypto_test_utils::ProofVerifierForTesting(), + std::make_unique<SimpleSessionCache>()); + PerformFullHandshake(crypto_client_config.get()); + + IncreaseSizeOfChlo(); + Initialize(std::move(crypto_client_config)); + EXPECT_GE(packets_.size(), 1u); + IngestPackets(); + ValidateChloDetails(); + EXPECT_EQ(tls_chlo_extractor_.state(), + TlsChloExtractor::State::kParsedFullMultiPacketChlo); + EXPECT_TRUE(tls_chlo_extractor_.resumption_attempted()); + EXPECT_TRUE(tls_chlo_extractor_.early_data_attempted()); } TEST_P(TlsChloExtractorTest, MultiPacket) {
diff --git a/quic/test_tools/crypto_test_utils.cc b/quic/test_tools/crypto_test_utils.cc index febd8d3..634c16f 100644 --- a/quic/test_tools/crypto_test_utils.cc +++ b/quic/test_tools/crypto_test_utils.cc
@@ -228,7 +228,7 @@ MockQuicConnectionHelper* helper, MockAlarmFactory* alarm_factory, PacketSavingConnection* client_conn, - QuicCryptoClientStream* client, + QuicCryptoClientStreamBase* client, std::string alpn) { auto* server_conn = new testing::NiceMock<PacketSavingConnection>( helper, alarm_factory, Perspective::IS_SERVER, @@ -593,7 +593,7 @@ } // namespace -void CompareClientAndServerKeys(QuicCryptoClientStream* client, +void CompareClientAndServerKeys(QuicCryptoClientStreamBase* client, QuicCryptoServerStreamBase* server) { QuicFramer* client_framer = QuicConnectionPeer::GetFramer( QuicStreamPeer::session(client)->connection());
diff --git a/quic/test_tools/crypto_test_utils.h b/quic/test_tools/crypto_test_utils.h index d839c6c..87e354a 100644 --- a/quic/test_tools/crypto_test_utils.h +++ b/quic/test_tools/crypto_test_utils.h
@@ -78,7 +78,7 @@ MockQuicConnectionHelper* helper, MockAlarmFactory* alarm_factory, PacketSavingConnection* client_conn, - QuicCryptoClientStream* client, + QuicCryptoClientStreamBase* client, std::string alpn); // returns: the number of client hellos that the client sent. @@ -195,7 +195,7 @@ QuicCompressedCertsCache* compressed_certs_cache, CryptoHandshakeMessage* out); -void CompareClientAndServerKeys(QuicCryptoClientStream* client, +void CompareClientAndServerKeys(QuicCryptoClientStreamBase* client, QuicCryptoServerStreamBase* server); // Return a CHLO nonce in hexadecimal.
diff --git a/quic/test_tools/first_flight.cc b/quic/test_tools/first_flight.cc index 435bb83..ec0a5e3 100644 --- a/quic/test_tools/first_flight.cc +++ b/quic/test_tools/first_flight.cc
@@ -32,18 +32,28 @@ FirstFlightExtractor(const ParsedQuicVersion& version, const QuicConfig& config, const QuicConnectionId& server_connection_id, - const QuicConnectionId& client_connection_id) + const QuicConnectionId& client_connection_id, + std::unique_ptr<QuicCryptoClientConfig> crypto_config) : version_(version), server_connection_id_(server_connection_id), client_connection_id_(client_connection_id), writer_(this), config_(config), - crypto_config_(crypto_test_utils::ProofVerifierForTesting()) { + crypto_config_(std::move(crypto_config)) { EXPECT_NE(version_, UnsupportedQuicVersion()); } + FirstFlightExtractor(const ParsedQuicVersion& version, + const QuicConfig& config, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id) + : FirstFlightExtractor( + version, config, server_connection_id, client_connection_id, + std::make_unique<QuicCryptoClientConfig>( + crypto_test_utils::ProofVerifierForTesting())) {} + void GenerateFirstFlight() { - crypto_config_.set_alpn(AlpnForVersion(version_)); + crypto_config_->set_alpn(AlpnForVersion(version_)); connection_ = new QuicConnection(server_connection_id_, /*initial_self_address=*/QuicSocketAddress(), @@ -55,7 +65,7 @@ session_ = std::make_unique<QuicSpdyClientSession>( config_, ParsedQuicVersionVector{version_}, connection_, // session_ takes ownership of connection_ here. - TestServerId(), &crypto_config_, &push_promise_index_); + TestServerId(), crypto_config_.get(), &push_promise_index_); session_->Initialize(); session_->CryptoConnect(); } @@ -84,7 +94,7 @@ MockAlarmFactory alarm_factory_; DelegatedPacketWriter writer_; QuicConfig config_; - QuicCryptoClientConfig crypto_config_; + std::unique_ptr<QuicCryptoClientConfig> crypto_config_; QuicClientPushPromiseIndex push_promise_index_; QuicConnection* connection_; // Owned by session_. std::unique_ptr<QuicSpdyClientSession> session_; @@ -92,6 +102,18 @@ }; std::vector<std::unique_ptr<QuicReceivedPacket>> GetFirstFlightOfPackets( + const ParsedQuicVersion& version, const QuicConfig& config, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id, + std::unique_ptr<QuicCryptoClientConfig> crypto_config) { + FirstFlightExtractor first_flight_extractor( + version, config, server_connection_id, client_connection_id, + std::move(crypto_config)); + first_flight_extractor.GenerateFirstFlight(); + return first_flight_extractor.ConsumePackets(); +} + +std::vector<std::unique_ptr<QuicReceivedPacket>> GetFirstFlightOfPackets( const ParsedQuicVersion& version, const QuicConfig& config, const QuicConnectionId& server_connection_id,
diff --git a/quic/test_tools/first_flight.h b/quic/test_tools/first_flight.h index c18c879..448a189 100644 --- a/quic/test_tools/first_flight.h +++ b/quic/test_tools/first_flight.h
@@ -8,6 +8,7 @@ #include <memory> #include <vector> +#include "quic/core/crypto/quic_crypto_client_config.h" #include "quic/core/quic_config.h" #include "quic/core/quic_connection_id.h" #include "quic/core/quic_packet_writer.h" @@ -74,16 +75,23 @@ // HTTP/3 connection. In most cases, this array will only contain one packet // that carries the CHLO. std::vector<std::unique_ptr<QuicReceivedPacket>> GetFirstFlightOfPackets( - const ParsedQuicVersion& version, - const QuicConfig& config, + const ParsedQuicVersion& version, const QuicConfig& config, const QuicConnectionId& server_connection_id, - const QuicConnectionId& client_connection_id); + const QuicConnectionId& client_connection_id, + std::unique_ptr<QuicCryptoClientConfig> crypto_config); // Below are various convenience overloads that use default values for the // omitted parameters: // |config| = DefaultQuicConfig(), // |server_connection_id| = TestConnectionId(), // |client_connection_id| = EmptyQuicConnectionId(). +// |crypto_config| = +// QuicCryptoClientConfig(crypto_test_utils::ProofVerifierForTesting()) +std::vector<std::unique_ptr<QuicReceivedPacket>> GetFirstFlightOfPackets( + const ParsedQuicVersion& version, const QuicConfig& config, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id); + std::vector<std::unique_ptr<QuicReceivedPacket>> GetFirstFlightOfPackets( const ParsedQuicVersion& version, const QuicConfig& config,
diff --git a/quic/test_tools/quic_test_utils.cc b/quic/test_tools/quic_test_utils.cc index da095dc..1149721 100644 --- a/quic/test_tools/quic_test_utils.cc +++ b/quic/test_tools/quic_test_utils.cc
@@ -88,9 +88,7 @@ sizeof(kStatelessResetTokenDataForTest)); } -std::string TestHostname() { - return "test.example.org"; -} +std::string TestHostname() { return "test.example.com"; } QuicServerId TestServerId() { return QuicServerId(TestHostname(), kTestPort);