QUIC Key Update support Handles key updates initiated remotely and also adds a QuicConnection method to initiate a key update, but this method is currently only called in tests. Protected by FLAGS_quic_reloadable_flag_quic_key_update_supported. PiperOrigin-RevId: 336385088 Change-Id: If74d032e1d34e5392312f4b619d28c9f93a95265
diff --git a/quic/core/chlo_extractor.cc b/quic/core/chlo_extractor.cc index 53c3246..33f6c0e 100644 --- a/quic/core/chlo_extractor.cc +++ b/quic/core/chlo_extractor.cc
@@ -82,6 +82,11 @@ bool IsValidStatelessResetToken(QuicUint128 token) const override; void OnAuthenticatedIetfStatelessResetPacket( const QuicIetfStatelessResetPacket& /*packet*/) override {} + void OnKeyUpdate() override; + void OnDecryptedFirstPacketInKeyPhase() override; + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override; + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override; // CryptoFramerVisitorInterface implementation. void OnError(CryptoFramer* framer) override; @@ -310,6 +315,20 @@ return true; } +void ChloFramerVisitor::OnKeyUpdate() {} + +void ChloFramerVisitor::OnDecryptedFirstPacketInKeyPhase() {} + +std::unique_ptr<QuicDecrypter> +ChloFramerVisitor::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + return nullptr; +} + +std::unique_ptr<QuicEncrypter> +ChloFramerVisitor::CreateCurrentOneRttEncrypter() { + return nullptr; +} + void ChloFramerVisitor::OnError(CryptoFramer* /*framer*/) {} void ChloFramerVisitor::OnHandshakeMessage(
diff --git a/quic/core/crypto/crypto_utils.cc b/quic/core/crypto/crypto_utils.cc index e6ec275..1868a3a 100644 --- a/quic/core/crypto/crypto_utils.cc +++ b/quic/core/crypto/crypto_utils.cc
@@ -128,6 +128,12 @@ return HkdfExpandLabel(prf, pp_secret, "quic hp", out_len); } +std::vector<uint8_t> CryptoUtils::GenerateNextKeyPhaseSecret( + const EVP_MD* prf, + const std::vector<uint8_t>& current_secret) { + return HkdfExpandLabel(prf, current_secret, "quic ku", current_secret.size()); +} + namespace { // Salt from https://tools.ietf.org/html/draft-ietf-quic-tls-27#section-5.2
diff --git a/quic/core/crypto/crypto_utils.h b/quic/core/crypto/crypto_utils.h index 543173a..2191ae1 100644 --- a/quic/core/crypto/crypto_utils.h +++ b/quic/core/crypto/crypto_utils.h
@@ -98,6 +98,11 @@ const std::vector<uint8_t>& pp_secret, size_t out_len); + // Given a secret for key phase n, return the secret for phase n+1. + static std::vector<uint8_t> GenerateNextKeyPhaseSecret( + const EVP_MD* prf, + const std::vector<uint8_t>& current_secret); + // IETF QUIC encrypts ENCRYPTION_INITIAL messages with a version-specific key // (to prevent network observers that are not aware of that QUIC version from // making decisions based on the TLS handshake). This packet protection secret
diff --git a/quic/core/http/end_to_end_test.cc b/quic/core/http/end_to_end_test.cc index c57fc6e..6145a79 100644 --- a/quic/core/http/end_to_end_test.cc +++ b/quic/core/http/end_to_end_test.cc
@@ -4821,6 +4821,176 @@ 0u); } +TEST_P(EndToEndTest, KeyUpdateInitiatedByClient) { + SetQuicReloadableFlag(quic_key_update_supported, true); + + if (!version_.UsesTls()) { + // Key Update is only supported in TLS handshake. + ASSERT_TRUE(Initialize()); + return; + } + + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + EXPECT_EQ(0u, client_connection->GetStats().key_update_count); + + EXPECT_TRUE(client_connection->InitiateKeyUpdate()); + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(1u, client_connection->GetStats().key_update_count); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(1u, client_connection->GetStats().key_update_count); + + EXPECT_TRUE(client_connection->InitiateKeyUpdate()); + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(2u, client_connection->GetStats().key_update_count); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection) { + QuicConnectionStats server_stats = server_connection->GetStats(); + EXPECT_EQ(2u, server_stats.key_update_count); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, KeyUpdateInitiatedByServer) { + SetQuicReloadableFlag(quic_key_update_supported, true); + + if (!version_.UsesTls()) { + // Key Update is only supported in TLS handshake. + ASSERT_TRUE(Initialize()); + return; + } + + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + EXPECT_EQ(0u, client_connection->GetStats().key_update_count); + + // Use WaitUntil to ensure the server had executed the key update predicate + // before sending the Foo request, otherwise the test can be flaky if it + // receives the Foo request before executing the key update. + server_thread_->WaitUntil( + [this]() { + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + EXPECT_TRUE(server_connection->InitiateKeyUpdate()); + } else { + ADD_FAILURE() << "Missing server connection"; + } + return true; + }, + QuicTime::Delta::FromSeconds(5)); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(1u, client_connection->GetStats().key_update_count); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(1u, client_connection->GetStats().key_update_count); + + server_thread_->WaitUntil( + [this]() { + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + EXPECT_TRUE(server_connection->InitiateKeyUpdate()); + } else { + ADD_FAILURE() << "Missing server connection"; + } + return true; + }, + QuicTime::Delta::FromSeconds(5)); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(2u, client_connection->GetStats().key_update_count); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection) { + QuicConnectionStats server_stats = server_connection->GetStats(); + EXPECT_EQ(2u, server_stats.key_update_count); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, KeyUpdateInitiatedByBoth) { + SetQuicReloadableFlag(quic_key_update_supported, true); + + if (!version_.UsesTls()) { + // Key Update is only supported in TLS handshake. + ASSERT_TRUE(Initialize()); + return; + } + + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + + // Use WaitUntil to ensure the server had executed the key update predicate + // before the client sends the Foo request, otherwise the Foo request from + // the client could trigger the server key update before the server can + // initiate the key update locally. That would mean the test is no longer + // hitting the intended test state of both sides locally initiating a key + // update before receiving a packet in the new key phase from the other side. + // Additionally the test would fail since InitiateKeyUpdate() would not allow + // to do another key update yet and return false. + server_thread_->WaitUntil( + [this]() { + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + EXPECT_TRUE(server_connection->InitiateKeyUpdate()); + } else { + ADD_FAILURE() << "Missing server connection"; + } + return true; + }, + QuicTime::Delta::FromSeconds(5)); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + EXPECT_TRUE(client_connection->InitiateKeyUpdate()); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(1u, client_connection->GetStats().key_update_count); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(1u, client_connection->GetStats().key_update_count); + + server_thread_->WaitUntil( + [this]() { + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + EXPECT_TRUE(server_connection->InitiateKeyUpdate()); + } else { + ADD_FAILURE() << "Missing server connection"; + } + return true; + }, + QuicTime::Delta::FromSeconds(5)); + EXPECT_TRUE(client_connection->InitiateKeyUpdate()); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(2u, client_connection->GetStats().key_update_count); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection) { + QuicConnectionStats server_stats = server_connection->GetStats(); + EXPECT_EQ(2u, server_stats.key_update_count); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + } // namespace } // namespace test } // namespace quic
diff --git a/quic/core/http/quic_spdy_session_test.cc b/quic/core/http/quic_spdy_session_test.cc index 8af0087..bb081b3 100644 --- a/quic/core/http/quic_spdy_session_test.cc +++ b/quic/core/http/quic_spdy_session_test.cc
@@ -145,6 +145,14 @@ } void SetServerApplicationStateForResumption( std::unique_ptr<ApplicationState> /*application_state*/) override {} + bool KeyUpdateSupportedLocally() const override { return false; } + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + return nullptr; + } + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override { + return nullptr; + } const QuicCryptoNegotiatedParameters& crypto_negotiated_params() const override { return *params_;
diff --git a/quic/core/http/quic_spdy_stream_test.cc b/quic/core/http/quic_spdy_stream_test.cc index 0abec1d..4adbaab 100644 --- a/quic/core/http/quic_spdy_stream_test.cc +++ b/quic/core/http/quic_spdy_stream_test.cc
@@ -133,6 +133,14 @@ } void SetServerApplicationStateForResumption( std::unique_ptr<ApplicationState> /*application_state*/) override {} + bool KeyUpdateSupportedLocally() const override { return true; } + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + return nullptr; + } + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override { + return nullptr; + } const QuicCryptoNegotiatedParameters& crypto_negotiated_params() const override { return *params_;
diff --git a/quic/core/quic_config.cc b/quic/core/quic_config.cc index 5ea6b24..60569b3 100644 --- a/quic/core/quic_config.cc +++ b/quic/core/quic_config.cc
@@ -448,6 +448,8 @@ initial_session_flow_control_window_bytes_(kCFCW, PRESENCE_OPTIONAL), connection_migration_disabled_(kNCMR, PRESENCE_OPTIONAL), support_handshake_done_(0, PRESENCE_OPTIONAL), + key_update_supported_remotely_(false), + key_update_supported_locally_(false), alternate_server_address_ipv6_(kASAD, PRESENCE_OPTIONAL), alternate_server_address_ipv4_(kASAD, PRESENCE_OPTIONAL), stateless_reset_token_(kSRST, PRESENCE_OPTIONAL), @@ -863,6 +865,18 @@ return support_handshake_done_.HasReceivedValue(); } +void QuicConfig::SetKeyUpdateSupportedLocally() { + key_update_supported_locally_ = true; +} + +bool QuicConfig::KeyUpdateSupportedForConnection() const { + return key_update_supported_remotely_ && KeyUpdateSupportedLocally(); +} + +bool QuicConfig::KeyUpdateSupportedLocally() const { + return key_update_supported_locally_; +} + void QuicConfig::SetIPv6AlternateServerAddressToSend( const QuicSocketAddress& alternate_server_address_ipv6) { if (!alternate_server_address_ipv6.host().IsIPv6()) { @@ -1228,7 +1242,9 @@ params->google_connection_options = connection_options_.GetSendValues(); } - params->key_update_not_yet_supported = true; + if (!KeyUpdateSupportedLocally()) { + params->key_update_not_yet_supported = true; + } params->custom_parameters = custom_transport_parameters_to_send_; @@ -1336,6 +1352,9 @@ if (params.support_handshake_done) { support_handshake_done_.SetReceivedValue(1u); } + if (!is_resumption && !params.key_update_not_yet_supported) { + key_update_supported_remotely_ = true; + } active_connection_id_limit_.SetReceivedValue( params.active_connection_id_limit.value());
diff --git a/quic/core/quic_config.h b/quic/core/quic_config.h index 4132987..0ce577a 100644 --- a/quic/core/quic_config.h +++ b/quic/core/quic_config.h
@@ -387,6 +387,11 @@ bool HandshakeDoneSupported() const; bool PeerSupportsHandshakeDone() const; + // Key update support. + void SetKeyUpdateSupportedLocally(); + bool KeyUpdateSupportedForConnection() const; + bool KeyUpdateSupportedLocally() const; + // IPv6 alternate server address. void SetIPv6AlternateServerAddressToSend( const QuicSocketAddress& alternate_server_address_ipv6); @@ -580,6 +585,13 @@ // Uses the support_handshake_done transport parameter in IETF QUIC. QuicFixedUint32 support_handshake_done_; + // Whether key update is supported by the peer. Uses key_update_not_yet + // supported transport parameter in IETF QUIC. + bool key_update_supported_remotely_; + + // Whether key update is supported locally. + bool key_update_supported_locally_; + // Alternate server addresses the client could connect to. // Uses the preferred_address transport parameter in IETF QUIC. // Note that when QUIC_CRYPTO is in use, only one of the addresses is sent.
diff --git a/quic/core/quic_config_test.cc b/quic/core/quic_config_test.cc index 4622a6e..73e8d97 100644 --- a/quic/core/quic_config_test.cc +++ b/quic/core/quic_config_test.cc
@@ -55,6 +55,8 @@ EXPECT_FALSE(config_.HasReceivedInitialMaxStreamDataBytesUnidirectional()); EXPECT_EQ(kMaxIncomingPacketSize, config_.GetMaxPacketSizeToSend()); EXPECT_FALSE(config_.HasReceivedMaxPacketSize()); + EXPECT_FALSE(config_.KeyUpdateSupportedForConnection()); + EXPECT_FALSE(config_.KeyUpdateSupportedLocally()); } TEST_P(QuicConfigTest, AutoSetIetfFlowControl) { @@ -673,6 +675,79 @@ EXPECT_TRUE(config_.DisableConnectionMigration()); } +TEST_P(QuicConfigTest, KeyUpdateNotYetSupportedTransportParameterNorLocally) { + if (!version_.UsesTls()) { + // TransportParameters are only used for QUIC+TLS. + return; + } + EXPECT_FALSE(config_.KeyUpdateSupportedForConnection()); + EXPECT_FALSE(config_.KeyUpdateSupportedLocally()); + TransportParameters params; + params.key_update_not_yet_supported = true; + std::string error_details; + EXPECT_THAT(config_.ProcessTransportParameters( + params, /* is_resumption = */ false, &error_details), + IsQuicNoError()); + EXPECT_FALSE(config_.KeyUpdateSupportedForConnection()); + EXPECT_FALSE(config_.KeyUpdateSupportedLocally()); +} + +TEST_P(QuicConfigTest, KeyUpdateNotYetSupportedTransportParameter) { + if (!version_.UsesTls()) { + // TransportParameters are only used for QUIC+TLS. + return; + } + config_.SetKeyUpdateSupportedLocally(); + EXPECT_FALSE(config_.KeyUpdateSupportedForConnection()); + EXPECT_TRUE(config_.KeyUpdateSupportedLocally()); + + TransportParameters params; + params.key_update_not_yet_supported = true; + std::string error_details; + EXPECT_THAT(config_.ProcessTransportParameters( + params, /* is_resumption = */ false, &error_details), + IsQuicNoError()); + EXPECT_FALSE(config_.KeyUpdateSupportedForConnection()); + EXPECT_TRUE(config_.KeyUpdateSupportedLocally()); +} + +TEST_P(QuicConfigTest, KeyUpdateSupportedRemotelyButNotLocally) { + if (!version_.UsesTls()) { + // TransportParameters are only used for QUIC+TLS. + return; + } + EXPECT_FALSE(config_.KeyUpdateSupportedLocally()); + EXPECT_FALSE(config_.KeyUpdateSupportedForConnection()); + + TransportParameters params; + params.key_update_not_yet_supported = false; + std::string error_details; + EXPECT_THAT(config_.ProcessTransportParameters( + params, /* is_resumption = */ false, &error_details), + IsQuicNoError()); + EXPECT_FALSE(config_.KeyUpdateSupportedForConnection()); + EXPECT_FALSE(config_.KeyUpdateSupportedLocally()); +} + +TEST_P(QuicConfigTest, KeyUpdateSupported) { + if (!version_.UsesTls()) { + // TransportParameters are only used for QUIC+TLS. + return; + } + config_.SetKeyUpdateSupportedLocally(); + EXPECT_TRUE(config_.KeyUpdateSupportedLocally()); + EXPECT_FALSE(config_.KeyUpdateSupportedForConnection()); + + TransportParameters params; + params.key_update_not_yet_supported = false; + std::string error_details; + EXPECT_THAT(config_.ProcessTransportParameters( + params, /* is_resumption = */ false, &error_details), + IsQuicNoError()); + EXPECT_TRUE(config_.KeyUpdateSupportedForConnection()); + EXPECT_TRUE(config_.KeyUpdateSupportedLocally()); +} + } // namespace } // namespace test } // namespace quic
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc index 00bfa6b..77fd2c5 100644 --- a/quic/core/quic_connection.cc +++ b/quic/core/quic_connection.cc
@@ -170,6 +170,24 @@ QuicConnection* connection_; }; +class DiscardPreviousOneRttKeysAlarmDelegate : public QuicAlarm::Delegate { + public: + explicit DiscardPreviousOneRttKeysAlarmDelegate(QuicConnection* connection) + : connection_(connection) {} + DiscardPreviousOneRttKeysAlarmDelegate( + const DiscardPreviousOneRttKeysAlarmDelegate&) = delete; + DiscardPreviousOneRttKeysAlarmDelegate& operator=( + const DiscardPreviousOneRttKeysAlarmDelegate&) = delete; + + void OnAlarm() override { + DCHECK(connection_->connected()); + connection_->DiscardPreviousOneRttKeys(); + } + + private: + QuicConnection* connection_; +}; + // When the clearer goes out of scope, the coalesced packet gets cleared. class ScopedCoalescedPacketClearer { public: @@ -245,6 +263,7 @@ peer_address_(initial_peer_address), direct_peer_address_(initial_peer_address), active_effective_peer_migration_type_(NO_CHANGE), + support_key_update_for_connection_(false), last_packet_decrypted_(false), last_size_(0), current_packet_data_(nullptr), @@ -280,6 +299,9 @@ process_undecryptable_packets_alarm_(alarm_factory_->CreateAlarm( arena_.New<ProcessUndecryptablePacketsAlarmDelegate>(this), &arena_)), + discard_previous_one_rtt_keys_alarm_(alarm_factory_->CreateAlarm( + arena_.New<DiscardPreviousOneRttKeysAlarmDelegate>(this), + &arena_)), visitor_(nullptr), debug_visitor_(nullptr), packet_creator_(server_connection_id_, &framer_, random_generator_, this), @@ -554,6 +576,10 @@ if (!ValidateConfigConnectionIds(config)) { return; } + support_key_update_for_connection_ = + config.KeyUpdateSupportedForConnection(); + framer_.SetKeyUpdateSupportForConnection( + support_key_update_for_connection_); } else { SetNetworkTimeouts(config.max_time_before_crypto_handshake(), config.max_idle_time_before_crypto_handshake()); @@ -1889,6 +1915,43 @@ ConnectionCloseSource::FROM_PEER); } +void QuicConnection::OnKeyUpdate() { + DCHECK(support_key_update_for_connection_); + QUIC_DLOG(INFO) << ENDPOINT << "Key phase updated"; + + lowest_packet_sent_in_current_key_phase_.Clear(); + stats_.key_update_count++; + + // If another key update triggers while the previous + // discard_previous_one_rtt_keys_alarm_ hasn't fired yet, cancel it since the + // old keys would already be discarded. + discard_previous_one_rtt_keys_alarm_->Cancel(); +} + +void QuicConnection::OnDecryptedFirstPacketInKeyPhase() { + QUIC_DLOG(INFO) << ENDPOINT << "OnDecryptedFirstPacketInKeyPhase"; + // An endpoint SHOULD retain old read keys for no more than three times the + // PTO after having received a packet protected using the new keys. After this + // period, old read keys and their corresponding secrets SHOULD be discarded. + // + // Note that this will cause an unnecessary + // discard_previous_one_rtt_keys_alarm_ on the first packet in the 1RTT + // encryption level, but this is harmless. + discard_previous_one_rtt_keys_alarm_->Set( + clock_->ApproximateNow() + sent_packet_manager_.GetPtoDelay() * 3); +} + +std::unique_ptr<QuicDecrypter> +QuicConnection::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + QUIC_DLOG(INFO) << ENDPOINT << "AdvanceKeysAndCreateCurrentOneRttDecrypter"; + return visitor_->AdvanceKeysAndCreateCurrentOneRttDecrypter(); +} + +std::unique_ptr<QuicEncrypter> QuicConnection::CreateCurrentOneRttEncrypter() { + QUIC_DLOG(INFO) << ENDPOINT << "CreateCurrentOneRttEncrypter"; + return visitor_->CreateCurrentOneRttEncrypter(); +} + void QuicConnection::ClearLastFrames() { should_last_packet_instigate_acks_ = false; } @@ -2984,6 +3047,13 @@ handshake_packet_sent_ = true; } + if (packet->encryption_level == ENCRYPTION_FORWARD_SECURE && + !lowest_packet_sent_in_current_key_phase_.IsInitialized()) { + QUIC_DLOG(INFO) << ENDPOINT << "lowest_packet_sent_in_current_key_phase_ = " + << packet_number; + lowest_packet_sent_in_current_key_phase_ = packet_number; + } + if (in_flight || !retransmission_alarm_->IsSet()) { SetRetransmissionAlarm(); } @@ -3451,6 +3521,26 @@ framer_.RemoveDecrypter(level); } +void QuicConnection::DiscardPreviousOneRttKeys() { + framer_.DiscardPreviousOneRttKeys(); +} + +bool QuicConnection::IsKeyUpdateAllowed() const { + return support_key_update_for_connection_ && + GetLargestAckedPacket().IsInitialized() && + lowest_packet_sent_in_current_key_phase_.IsInitialized() && + GetLargestAckedPacket() >= lowest_packet_sent_in_current_key_phase_; +} + +bool QuicConnection::InitiateKeyUpdate() { + QUIC_DLOG(INFO) << ENDPOINT << "InitiateKeyUpdate"; + if (!IsKeyUpdateAllowed()) { + QUIC_BUG << "key update not allowed"; + return false; + } + return framer_.DoKeyUpdate(); +} + const QuicDecrypter* QuicConnection::decrypter() const { return framer_.decrypter(); } @@ -3735,6 +3825,7 @@ send_alarm_->Cancel(); mtu_discovery_alarm_->Cancel(); process_undecryptable_packets_alarm_->Cancel(); + discard_previous_one_rtt_keys_alarm_->Cancel(); blackhole_detector_.StopDetection(); idle_network_detector_.StopDetection(); }
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h index 240138a..f6496c7 100644 --- a/quic/core/quic_connection.h +++ b/quic/core/quic_connection.h
@@ -191,6 +191,15 @@ // Called when a packet of ENCRYPTION_HANDSHAKE gets sent. virtual void OnHandshakePacketSent() = 0; + + // Called to generate a decrypter for the next key phase. Each call should + // generate the key for phase n+1. + virtual std::unique_ptr<QuicDecrypter> + AdvanceKeysAndCreateCurrentOneRttDecrypter() = 0; + + // Called to generate an encrypter for the same key phase of the last + // decrypter returned by AdvanceKeysAndCreateCurrentOneRttDecrypter(). + virtual std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() = 0; }; // Interface which gets callbacks from the QuicConnection at interesting @@ -628,6 +637,11 @@ bool IsValidStatelessResetToken(QuicUint128 token) const override; void OnAuthenticatedIetfStatelessResetPacket( const QuicIetfStatelessResetPacket& packet) override; + void OnKeyUpdate() override; + void OnDecryptedFirstPacketInKeyPhase() override; + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override; + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override; // QuicPacketCreator::DelegateInterface bool ShouldGeneratePacket(HasRetransmittableData retransmittable, @@ -788,6 +802,16 @@ std::unique_ptr<QuicDecrypter> decrypter); void RemoveDecrypter(EncryptionLevel level); + // Discard keys for the previous key phase. + void DiscardPreviousOneRttKeys(); + + // Returns true if it is currently allowed to initiate a key update. + bool IsKeyUpdateAllowed() const; + + // Increment the key phase. It is a bug to call this when IsKeyUpdateAllowed() + // is false. Returns false on error. + bool InitiateKeyUpdate(); + const QuicDecrypter* decrypter() const; const QuicDecrypter* alternative_decrypter() const; @@ -1493,6 +1517,14 @@ // started. QuicPacketNumber highest_packet_sent_before_effective_peer_migration_; + // True if Key Update is supported on this connection. + bool support_key_update_for_connection_; + + // Tracks the lowest packet sent in the current key phase. Will be + // uninitialized before the first one-RTT packet has been sent or after a + // key update but before the first packet has been sent. + QuicPacketNumber lowest_packet_sent_in_current_key_phase_; + // True if the last packet has gotten far enough in the framer to be // decrypted. bool last_packet_decrypted_; @@ -1587,6 +1619,9 @@ // An alarm that fires to process undecryptable packets when new decyrption // keys are available. QuicArenaScopedPtr<QuicAlarm> process_undecryptable_packets_alarm_; + // An alarm that fires to discard keys for the previous key phase some time + // after a key update has completed. + QuicArenaScopedPtr<QuicAlarm> discard_previous_one_rtt_keys_alarm_; // Neither visitor is owned by this class. QuicConnectionVisitorInterface* visitor_; QuicConnectionDebugVisitor* debug_visitor_;
diff --git a/quic/core/quic_connection_stats.cc b/quic/core/quic_connection_stats.cc index 191918c..0cc6a73 100644 --- a/quic/core/quic_connection_stats.cc +++ b/quic/core/quic_connection_stats.cc
@@ -55,6 +55,7 @@ os << " num_ack_aggregation_epochs: " << s.num_ack_aggregation_epochs; os << " sent_legacy_version_encapsulated_packets: " << s.sent_legacy_version_encapsulated_packets; + os << " key_update_count: " << s.key_update_count; os << " }"; return os;
diff --git a/quic/core/quic_connection_stats.h b/quic/core/quic_connection_stats.h index 5a568e1..87158b3 100644 --- a/quic/core/quic_connection_stats.h +++ b/quic/core/quic_connection_stats.h
@@ -169,6 +169,11 @@ // Number of times when the connection tries to send data but gets throttled // by amplification factor. size_t num_amplification_throttling = 0; + + // Number of key phase updates that have occurred. In the case of a locally + // initiated key update, this is incremented when the keys are updated, before + // the peer has acknowledged the key update. + uint32_t key_update_count = 0; }; } // namespace quic
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc index f480f88..e40fa67 100644 --- a/quic/core/quic_connection_test.cc +++ b/quic/core/quic_connection_test.cc
@@ -421,6 +421,11 @@ QuicConnectionPeer::GetProcessUndecryptablePacketsAlarm(this)); } + TestAlarmFactory::TestAlarm* GetDiscardPreviousOneRttKeysAlarm() { + return reinterpret_cast<TestAlarmFactory::TestAlarm*>( + QuicConnectionPeer::GetDiscardPreviousOneRttKeysAlarm(this)); + } + TestAlarmFactory::TestAlarm* GetBlackholeDetectorAlarm() { return reinterpret_cast<TestAlarmFactory::TestAlarm*>( QuicConnectionPeer::GetBlackholeDetectorAlarm(this)); @@ -11801,6 +11806,150 @@ ProcessCoalescedPacket({{4, frames4, ENCRYPTION_FORWARD_SECURE}}); } +TEST_P(QuicConnectionTest, InitiateKeyUpdate) { + if (!connection_.version().UsesTls()) { + return; + } + + TransportParameters params; + params.key_update_not_yet_supported = false; + QuicConfig config; + std::string error_details; + EXPECT_THAT(config.ProcessTransportParameters( + params, /* is_resumption = */ false, &error_details), + IsQuicNoError()); + config.SetKeyUpdateSupportedLocally(); + QuicConfigPeer::SetNegotiated(&config, true); + if (connection_.version().AuthenticatesHandshakeConnectionIds()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_.connection_id()); + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + + EXPECT_FALSE(connection_.IsKeyUpdateAllowed()); + + MockFramerVisitor peer_framer_visitor_; + peer_framer_.set_visitor(&peer_framer_visitor_); + + use_tagging_decrypter(); + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique<TaggingEncrypter>(0x01)); + SetDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique<StrictTaggingDecrypter>(0x01)); + connection_.OnHandshakeComplete(); + + peer_framer_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique<TaggingEncrypter>(0x01)); + + // Key update should still not be allowed, since no packet has been acked + // from the current key phase. + EXPECT_FALSE(connection_.IsKeyUpdateAllowed()); + + // Send packet 1. + QuicPacketNumber last_packet; + SendStreamDataToPeer(1, "foo", 0, NO_FIN, &last_packet); + EXPECT_EQ(QuicPacketNumber(1u), last_packet); + + // Key update should still not be allowed, even though a packet was sent in + // the current key phase it hasn't been acked yet. + EXPECT_FALSE(connection_.IsKeyUpdateAllowed()); + + EXPECT_FALSE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet()); + // Receive ack for packet 1. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _)); + QuicAckFrame frame1 = InitAckFrame(1); + ProcessAckPacket(&frame1); + + // OnDecryptedFirstPacketInKeyPhase is called even on the first key phase, + // so discard_previous_keys_alarm_ should be set now. + EXPECT_TRUE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet()); + + // Key update should now be allowed. + EXPECT_CALL(visitor_, AdvanceKeysAndCreateCurrentOneRttDecrypter()) + .WillOnce( + []() { return std::make_unique<StrictTaggingDecrypter>(0x02); }); + EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()).WillOnce([]() { + return std::make_unique<TaggingEncrypter>(0x02); + }); + EXPECT_TRUE(connection_.InitiateKeyUpdate()); + // discard_previous_keys_alarm_ should not be set until a packet from the new + // key phase has been received. (The alarm that was set above should be + // cleared if it hasn't fired before the next key update happened.) + EXPECT_FALSE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet()); + + // Pretend that peer accepts the key update. + EXPECT_CALL(peer_framer_visitor_, + AdvanceKeysAndCreateCurrentOneRttDecrypter()) + .WillOnce( + []() { return std::make_unique<StrictTaggingDecrypter>(0x02); }); + EXPECT_CALL(peer_framer_visitor_, CreateCurrentOneRttEncrypter()) + .WillOnce([]() { return std::make_unique<TaggingEncrypter>(0x02); }); + peer_framer_.SetKeyUpdateSupportForConnection(true); + peer_framer_.DoKeyUpdate(); + + // Another key update should not be allowed yet. + EXPECT_FALSE(connection_.IsKeyUpdateAllowed()); + + // Send packet 2. + SendStreamDataToPeer(2, "bar", 0, NO_FIN, &last_packet); + EXPECT_EQ(QuicPacketNumber(2u), last_packet); + // Receive ack for packet 2. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _)); + QuicAckFrame frame2 = InitAckFrame(2); + ProcessAckPacket(&frame2); + EXPECT_TRUE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet()); + + // Key update should be allowed again now that a packet has been acked from + // the current key phase. + EXPECT_CALL(visitor_, AdvanceKeysAndCreateCurrentOneRttDecrypter()) + .WillOnce( + []() { return std::make_unique<StrictTaggingDecrypter>(0x03); }); + EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()).WillOnce([]() { + return std::make_unique<TaggingEncrypter>(0x03); + }); + EXPECT_TRUE(connection_.InitiateKeyUpdate()); + + // Pretend that peer accepts the key update. + EXPECT_CALL(peer_framer_visitor_, + AdvanceKeysAndCreateCurrentOneRttDecrypter()) + .WillOnce( + []() { return std::make_unique<StrictTaggingDecrypter>(0x03); }); + EXPECT_CALL(peer_framer_visitor_, CreateCurrentOneRttEncrypter()) + .WillOnce([]() { return std::make_unique<TaggingEncrypter>(0x03); }); + peer_framer_.DoKeyUpdate(); + + // Another key update should not be allowed yet. + EXPECT_FALSE(connection_.IsKeyUpdateAllowed()); + + // Send packet 3. + SendStreamDataToPeer(3, "baz", 0, NO_FIN, &last_packet); + EXPECT_EQ(QuicPacketNumber(3u), last_packet); + + // Another key update should not be allowed yet. + EXPECT_FALSE(connection_.IsKeyUpdateAllowed()); + + // Receive ack for packet 3. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _)); + QuicAckFrame frame3 = InitAckFrame(3); + ProcessAckPacket(&frame3); + EXPECT_TRUE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet()); + + // Key update should be allowed now. + EXPECT_CALL(visitor_, AdvanceKeysAndCreateCurrentOneRttDecrypter()) + .WillOnce( + []() { return std::make_unique<StrictTaggingDecrypter>(0x04); }); + EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()).WillOnce([]() { + return std::make_unique<TaggingEncrypter>(0x04); + }); + EXPECT_TRUE(connection_.InitiateKeyUpdate()); + EXPECT_FALSE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet()); +} + } // namespace } // namespace test } // namespace quic
diff --git a/quic/core/quic_crypto_client_handshaker.cc b/quic/core/quic_crypto_client_handshaker.cc index b60b026..784a451 100644 --- a/quic/core/quic_crypto_client_handshaker.cc +++ b/quic/core/quic_crypto_client_handshaker.cc
@@ -175,6 +175,24 @@ return QuicCryptoHandshaker::BufferSizeLimitForLevel(level); } +bool QuicCryptoClientHandshaker::KeyUpdateSupportedLocally() const { + return false; +} + +std::unique_ptr<QuicDecrypter> +QuicCryptoClientHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + // Key update is only defined in QUIC+TLS. + DCHECK(false); + return nullptr; +} + +std::unique_ptr<QuicEncrypter> +QuicCryptoClientHandshaker::CreateCurrentOneRttEncrypter() { + // Key update is only defined in QUIC+TLS. + DCHECK(false); + return nullptr; +} + void QuicCryptoClientHandshaker::OnConnectionClosed( QuicErrorCode /*error*/, ConnectionCloseSource /*source*/) {
diff --git a/quic/core/quic_crypto_client_handshaker.h b/quic/core/quic_crypto_client_handshaker.h index 605318e..49fa405 100644 --- a/quic/core/quic_crypto_client_handshaker.h +++ b/quic/core/quic_crypto_client_handshaker.h
@@ -51,6 +51,10 @@ CryptoMessageParser* crypto_message_parser() override; HandshakeState GetHandshakeState() const override; size_t BufferSizeLimitForLevel(EncryptionLevel level) const override; + bool KeyUpdateSupportedLocally() const override; + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override; + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override; void OnOneRttPacketAcknowledged() override {} void OnHandshakePacketSent() override {} void OnConnectionClosed(QuicErrorCode /*error*/,
diff --git a/quic/core/quic_crypto_client_stream.cc b/quic/core/quic_crypto_client_stream.cc index 67a9a11..94a8948 100644 --- a/quic/core/quic_crypto_client_stream.cc +++ b/quic/core/quic_crypto_client_stream.cc
@@ -109,6 +109,20 @@ return handshaker_->BufferSizeLimitForLevel(level); } +bool QuicCryptoClientStream::KeyUpdateSupportedLocally() const { + return handshaker_->KeyUpdateSupportedLocally(); +} + +std::unique_ptr<QuicDecrypter> +QuicCryptoClientStream::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + return handshaker_->AdvanceKeysAndCreateCurrentOneRttDecrypter(); +} + +std::unique_ptr<QuicEncrypter> +QuicCryptoClientStream::CreateCurrentOneRttEncrypter() { + return handshaker_->CreateCurrentOneRttEncrypter(); +} + std::string QuicCryptoClientStream::chlo_hash() const { return handshaker_->chlo_hash(); }
diff --git a/quic/core/quic_crypto_client_stream.h b/quic/core/quic_crypto_client_stream.h index 1d9b04b..123bdfe 100644 --- a/quic/core/quic_crypto_client_stream.h +++ b/quic/core/quic_crypto_client_stream.h
@@ -151,6 +151,18 @@ // buffered at each encryption level. virtual size_t BufferSizeLimitForLevel(EncryptionLevel level) const = 0; + // Returns whether the implementation supports key update. + virtual bool KeyUpdateSupportedLocally() const = 0; + + // Called to generate a decrypter for the next key phase. Each call should + // generate the key for phase n+1. + virtual std::unique_ptr<QuicDecrypter> + AdvanceKeysAndCreateCurrentOneRttDecrypter() = 0; + + // Called to generate an encrypter for the same key phase of the last + // decrypter returned by AdvanceKeysAndCreateCurrentOneRttDecrypter(). + virtual std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() = 0; + // Returns current handshake state. virtual HandshakeState GetHandshakeState() const = 0; @@ -228,6 +240,10 @@ void SetServerApplicationStateForResumption( std::unique_ptr<ApplicationState> application_state) override; size_t BufferSizeLimitForLevel(EncryptionLevel level) const override; + bool KeyUpdateSupportedLocally() const override; + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override; + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override; std::string chlo_hash() const;
diff --git a/quic/core/quic_crypto_server_stream.cc b/quic/core/quic_crypto_server_stream.cc index e8f99eb..9d7597c 100644 --- a/quic/core/quic_crypto_server_stream.cc +++ b/quic/core/quic_crypto_server_stream.cc
@@ -402,6 +402,24 @@ return QuicCryptoHandshaker::BufferSizeLimitForLevel(level); } +bool QuicCryptoServerStream::KeyUpdateSupportedLocally() const { + return false; +} + +std::unique_ptr<QuicDecrypter> +QuicCryptoServerStream::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + // Key update is only defined in QUIC+TLS. + DCHECK(false); + return nullptr; +} + +std::unique_ptr<QuicEncrypter> +QuicCryptoServerStream::CreateCurrentOneRttEncrypter() { + // Key update is only defined in QUIC+TLS. + DCHECK(false); + return nullptr; +} + void QuicCryptoServerStream::ProcessClientHello( QuicReferenceCountedPointer<ValidateClientHelloResultCallback::Result> result,
diff --git a/quic/core/quic_crypto_server_stream.h b/quic/core/quic_crypto_server_stream.h index 29d680e..c99a136 100644 --- a/quic/core/quic_crypto_server_stream.h +++ b/quic/core/quic_crypto_server_stream.h
@@ -61,6 +61,10 @@ void SetServerApplicationStateForResumption( std::unique_ptr<ApplicationState> state) override; size_t BufferSizeLimitForLevel(EncryptionLevel level) const override; + bool KeyUpdateSupportedLocally() const override; + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override; + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override; // From QuicCryptoHandshaker void OnHandshakeMessage(const CryptoHandshakeMessage& message) override;
diff --git a/quic/core/quic_crypto_stream.h b/quic/core/quic_crypto_stream.h index 21bdde6..ea15ded 100644 --- a/quic/core/quic_crypto_stream.h +++ b/quic/core/quic_crypto_stream.h
@@ -126,6 +126,18 @@ // encryption level |level|. virtual size_t BufferSizeLimitForLevel(EncryptionLevel level) const; + // Returns whether the implementation supports key update. + virtual bool KeyUpdateSupportedLocally() const = 0; + + // Called to generate a decrypter for the next key phase. Each call should + // generate the key for phase n+1. + virtual std::unique_ptr<QuicDecrypter> + AdvanceKeysAndCreateCurrentOneRttDecrypter() = 0; + + // Called to generate an encrypter for the same key phase of the last + // decrypter returned by AdvanceKeysAndCreateCurrentOneRttDecrypter(). + virtual std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() = 0; + // Called to cancel retransmission of unencrypted crypto stream data. void NeuterUnencryptedStreamData();
diff --git a/quic/core/quic_crypto_stream_test.cc b/quic/core/quic_crypto_stream_test.cc index f7e2483..4fc9ac3 100644 --- a/quic/core/quic_crypto_stream_test.cc +++ b/quic/core/quic_crypto_stream_test.cc
@@ -66,6 +66,14 @@ HandshakeState GetHandshakeState() const override { return HANDSHAKE_START; } void SetServerApplicationStateForResumption( std::unique_ptr<ApplicationState> /*application_state*/) override {} + bool KeyUpdateSupportedLocally() const override { return false; } + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + return nullptr; + } + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override { + return nullptr; + } private: QuicReferenceCountedPointer<QuicCryptoNegotiatedParameters> params_;
diff --git a/quic/core/quic_framer.cc b/quic/core/quic_framer.cc index 6b9f0c1..6842211 100644 --- a/quic/core/quic_framer.cc +++ b/quic/core/quic_framer.cc
@@ -412,6 +412,8 @@ process_timestamps_(false), creation_time_(creation_time), last_timestamp_(QuicTime::Delta::Zero()), + support_key_update_for_connection_(false), + current_key_phase_bit_(false), first_sending_packet_number_(FirstSendingPacketNumber()), data_producer_(nullptr), infer_packet_header_type_from_version_(perspective == @@ -2136,7 +2138,7 @@ PacketNumberLengthToOnWireValue(header.packet_number_length)); } else { type = static_cast<uint8_t>( - FLAGS_FIXED_BIT | + FLAGS_FIXED_BIT | (current_key_phase_bit_ ? FLAGS_KEY_PHASE_BIT : 0) | PacketNumberLengthToOnWireValue(header.packet_number_length)); } return writer->WriteUInt8(type); @@ -4235,6 +4237,41 @@ decrypter_[level] = nullptr; } +void QuicFramer::SetKeyUpdateSupportForConnection(bool enabled) { + QUIC_DVLOG(1) << ENDPOINT << "SetKeyUpdateSupportForConnection: " << enabled; + support_key_update_for_connection_ = enabled; +} + +void QuicFramer::DiscardPreviousOneRttKeys() { + DCHECK(support_key_update_for_connection_); + QUIC_DVLOG(1) << ENDPOINT << "Discarding previous set of 1-RTT keys"; + previous_decrypter_ = nullptr; +} + +bool QuicFramer::DoKeyUpdate() { + DCHECK(support_key_update_for_connection_); + if (!next_decrypter_) { + // If key update is locally initiated, next decrypter might not be created + // yet. + next_decrypter_ = visitor_->AdvanceKeysAndCreateCurrentOneRttDecrypter(); + } + std::unique_ptr<QuicEncrypter> next_encrypter = + visitor_->CreateCurrentOneRttEncrypter(); + if (!next_decrypter_ || !next_encrypter) { + QUIC_BUG << "Failed to create next crypters"; + return false; + } + current_key_phase_bit_ = !current_key_phase_bit_; + QUIC_DLOG(INFO) << ENDPOINT << "DoKeyUpdate: new current_key_phase_bit_=" + << current_key_phase_bit_; + current_key_phase_first_received_packet_number_.Clear(); + previous_decrypter_ = std::move(decrypter_[ENCRYPTION_FORWARD_SECURE]); + decrypter_[ENCRYPTION_FORWARD_SECURE] = std::move(next_decrypter_); + encrypter_[ENCRYPTION_FORWARD_SECURE] = std::move(next_encrypter); + visitor_->OnKeyUpdate(); + return true; +} + const QuicDecrypter* QuicFramer::GetDecrypter(EncryptionLevel level) const { DCHECK(version_.KnowsWhichDecrypterToUse()); return decrypter_[level].get(); @@ -4616,6 +4653,9 @@ EncryptionLevel level = decrypter_level_; QuicDecrypter* decrypter = decrypter_[level].get(); QuicDecrypter* alternative_decrypter = nullptr; + bool key_phase_parsed = false; + bool key_phase; + bool attempt_key_update = false; if (version().KnowsWhichDecrypterToUse()) { if (header.form == GOOGLE_QUIC_PACKET) { QUIC_BUG << "Attempted to decrypt GOOGLE_QUIC_PACKET with a version that " @@ -4635,6 +4675,45 @@ perspective_ == Perspective::IS_CLIENT && header.nonce != nullptr) { decrypter->SetDiversificationNonce(*header.nonce); } + if (support_key_update_for_connection_ && + header.form == IETF_QUIC_SHORT_HEADER_PACKET) { + DCHECK(version().UsesTls()); + DCHECK_EQ(level, ENCRYPTION_FORWARD_SECURE); + key_phase = (header.type_byte & FLAGS_KEY_PHASE_BIT) != 0; + key_phase_parsed = true; + QUIC_DVLOG(1) << ENDPOINT << "packet " << header.packet_number + << " received key_phase=" << key_phase + << " current_key_phase_bit_=" << current_key_phase_bit_; + if (key_phase != current_key_phase_bit_) { + if (current_key_phase_first_received_packet_number_.IsInitialized() && + header.packet_number > + current_key_phase_first_received_packet_number_) { + if (!next_decrypter_) { + next_decrypter_ = + visitor_->AdvanceKeysAndCreateCurrentOneRttDecrypter(); + if (!next_decrypter_) { + QUIC_BUG << "Failed to create next_decrypter"; + return false; + } + } + QUIC_DVLOG(1) << ENDPOINT << "packet " << header.packet_number + << " attempt_key_update=true"; + attempt_key_update = true; + decrypter = next_decrypter_.get(); + } else { + if (previous_decrypter_) { + QUIC_DVLOG(1) << ENDPOINT + << "trying previous_decrypter_ for packet " + << header.packet_number; + decrypter = previous_decrypter_.get(); + } else { + QUIC_DVLOG(1) << ENDPOINT << "dropping packet " + << header.packet_number << " with old key phase"; + return false; + } + } + } + } } else if (alternative_decrypter_level_ != NUM_ENCRYPTION_LEVELS) { if (!EncryptionLevelIsValid(alternative_decrypter_level_)) { QUIC_BUG << "Attempted to decrypt with bad alternative_decrypter_level_"; @@ -4655,6 +4734,27 @@ if (success) { visitor_->OnDecryptedPacket(level); *decrypted_level = level; + if (attempt_key_update) { + if (!DoKeyUpdate()) { + set_detailed_error("Key update failed due to internal error"); + return RaiseError(QUIC_INTERNAL_ERROR); + } + DCHECK_EQ(current_key_phase_bit_, key_phase); + } + if (key_phase_parsed && + !current_key_phase_first_received_packet_number_.IsInitialized() && + key_phase == current_key_phase_bit_) { + // Set packet number for current key phase if it hasn't been initialized + // yet. This is set outside of attempt_key_update since the key update + // may have been initiated locally, and in that case we don't know yet + // which packet number from the remote side to use until we receive a + // packet with that phase. + QUIC_DVLOG(1) << ENDPOINT + << "current_key_phase_first_received_packet_number_ = " + << header.packet_number; + current_key_phase_first_received_packet_number_ = header.packet_number; + visitor_->OnDecryptedFirstPacketInKeyPhase(); + } } else if (alternative_decrypter != nullptr) { if (header.nonce != nullptr) { DCHECK_EQ(perspective_, Perspective::IS_CLIENT);
diff --git a/quic/core/quic_framer.h b/quic/core/quic_framer.h index e57e252..48a7877 100644 --- a/quic/core/quic_framer.h +++ b/quic/core/quic_framer.h
@@ -237,6 +237,25 @@ // Called when an IETF StreamsBlocked frame has been parsed. virtual bool OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame) = 0; + + // Called when a Key Phase Update has been initiated. This is called for both + // locally and peer initiated key updates. If the key update was locally + // initiated, this does not indicate the peer has received the key update yet. + virtual void OnKeyUpdate() = 0; + + // Called on the first decrypted packet in each key phase (including the + // first key phase.) + virtual void OnDecryptedFirstPacketInKeyPhase() = 0; + + // Called when the framer needs to generate a decrypter for the next key + // phase. Each call should generate the key for phase n+1. + virtual std::unique_ptr<QuicDecrypter> + AdvanceKeysAndCreateCurrentOneRttDecrypter() = 0; + + // Called when the framer needs to generate an encrypter. The key corresponds + // to the key phase of the last decrypter returned by + // AdvanceKeysAndCreateCurrentOneRttDecrypter(). + virtual std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() = 0; }; // Class for parsing and constructing QUIC packets. It has a @@ -519,6 +538,13 @@ std::unique_ptr<QuicDecrypter> decrypter); void RemoveDecrypter(EncryptionLevel level); + // Enables key update support. + void SetKeyUpdateSupportForConnection(bool enabled); + // Discard the decrypter for the previous key phase. + void DiscardPreviousOneRttKeys(); + // Update the key phase. + bool DoKeyUpdate(); + const QuicDecrypter* GetDecrypter(EncryptionLevel level) const; const QuicDecrypter* decrypter() const; const QuicDecrypter* alternative_decrypter() const; @@ -1059,6 +1085,23 @@ // The last timestamp received if process_timestamps_ is true. QuicTime::Delta last_timestamp_; + // Whether IETF QUIC Key Update is supported on this connection. + bool support_key_update_for_connection_; + // The value of the current key phase bit, which is toggled when the keys are + // changed. + bool current_key_phase_bit_; + // Tracks the first packet received in the current key phase. Will be + // uninitialized before the first one-RTT packet has been received or after a + // locally initiated key update but before the first packet from the peer in + // the new key phase is received. + QuicPacketNumber current_key_phase_first_received_packet_number_; + // Decrypter for the previous key phase. Will be null if in the first key + // phase or previous keys have been discarded. + std::unique_ptr<QuicDecrypter> previous_decrypter_; + // Decrypter for the next key phase. May be null if next keys haven't been + // generated yet. + std::unique_ptr<QuicDecrypter> next_decrypter_; + // If this is a framer of a connection, this is the packet number of first // sending packet. If this is a framer of a framer of dispatcher, this is the // packet number of sent packets (for those which have packet number).
diff --git a/quic/core/quic_framer_test.cc b/quic/core/quic_framer_test.cc index f35613a..77e03eb 100644 --- a/quic/core/quic_framer_test.cc +++ b/quic/core/quic_framer_test.cc
@@ -185,6 +185,31 @@ std::string ciphertext_; }; +std::unique_ptr<QuicEncryptedPacket> EncryptPacketWithTagAndPhase( + const QuicPacket& packet, + uint8_t tag, + bool phase) { + std::string packet_data = std::string(packet.AsStringPiece()); + if (phase) { + packet_data[0] |= FLAGS_KEY_PHASE_BIT; + } else { + packet_data[0] &= ~FLAGS_KEY_PHASE_BIT; + } + + TaggingEncrypter crypter(tag); + const size_t packet_size = crypter.GetCiphertextSize(packet_data.size()); + char* buffer = new char[packet_size]; + size_t buf_len = 0; + if (!crypter.EncryptPacket(0, quiche::QuicheStringPiece(), packet_data, + buffer, &buf_len, packet_size)) { + delete[] buffer; + return nullptr; + } + + return std::make_unique<QuicEncryptedPacket>(buffer, buf_len, + /*owns_buffer=*/true); +} + class TestQuicVisitor : public QuicFramerVisitorInterface { public: TestQuicVisitor() @@ -193,6 +218,9 @@ packet_count_(0), frame_count_(0), complete_packets_(0), + key_update_count_(0), + derive_next_key_count_(0), + decrypted_first_packet_in_key_phase_count_(0), accept_packet_(true), accept_public_header_(true) {} @@ -547,6 +575,21 @@ EXPECT_EQ(0u, framer_->current_received_frame_type()); } + void OnKeyUpdate() override { key_update_count_++; } + + void OnDecryptedFirstPacketInKeyPhase() override { + decrypted_first_packet_in_key_phase_count_++; + } + + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + derive_next_key_count_++; + return std::make_unique<StrictTaggingDecrypter>(derive_next_key_count_); + } + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override { + return std::make_unique<TaggingEncrypter>(derive_next_key_count_); + } + void set_framer(QuicFramer* framer) { framer_ = framer; transport_version_ = framer->transport_version(); @@ -558,6 +601,9 @@ int packet_count_; int frame_count_; int complete_packets_; + int key_update_count_; + int derive_next_key_count_; + int decrypted_first_packet_in_key_phase_count_; bool accept_packet_; bool accept_public_header_; @@ -14759,6 +14805,536 @@ visitor_.ack_frames_[0]->ack_delay_time); } +TEST_P(QuicFramerTest, KeyUpdate) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique<StrictTaggingDecrypter>(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr<QuicPacket> data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr<QuicEncryptedPacket> encrypted( + EncryptPacketWithTagAndPhase(*data, 0, false)); + ASSERT_TRUE(encrypted); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + // Processed valid packet with phase=0, key=1: no key update. + EXPECT_EQ(0, visitor_.key_update_count_); + EXPECT_EQ(0, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + header.packet_number += 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 1, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + // Processed valid packet with phase=1, key=2: key update should have + // occurred. + EXPECT_EQ(1, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_); + + header.packet_number += 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 1, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + // Processed another valid packet with phase=1, key=2: no key update. + EXPECT_EQ(1, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process another key update. + header.packet_number += 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 2, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(2, visitor_.key_update_count_); + EXPECT_EQ(2, visitor_.derive_next_key_count_); + EXPECT_EQ(3, visitor_.decrypted_first_packet_in_key_phase_count_); +} + +TEST_P(QuicFramerTest, KeyUpdateOldPacketAfterUpdate) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique<StrictTaggingDecrypter>(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + // Process packet N with phase 0. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr<QuicPacket> data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr<QuicEncryptedPacket> encrypted( + EncryptPacketWithTagAndPhase(*data, 0, false)); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(0, visitor_.key_update_count_); + EXPECT_EQ(0, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N+2 with phase 1. + header.packet_number = kPacketNumber + 2; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 1, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N+1 with phase 0. (Receiving packet from previous phase + // after packet from new phase was received.) + header.packet_number = kPacketNumber + 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 0, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should decrypt and key update count should not change. + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_); +} + +TEST_P(QuicFramerTest, KeyUpdateOldPacketAfterDiscardPreviousOneRttKeys) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique<StrictTaggingDecrypter>(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + // Process packet N with phase 0. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr<QuicPacket> data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr<QuicEncryptedPacket> encrypted( + EncryptPacketWithTagAndPhase(*data, 0, false)); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(0, visitor_.key_update_count_); + EXPECT_EQ(0, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N+2 with phase 1. + header.packet_number = kPacketNumber + 2; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 1, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Discard keys for previous key phase. + framer_.DiscardPreviousOneRttKeys(); + + // Process packet N+1 with phase 0. (Receiving packet from previous phase + // after packet from new phase was received.) + header.packet_number = kPacketNumber + 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 0, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should not decrypt and key update count should not change. + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_); +} + +TEST_P(QuicFramerTest, KeyUpdatePacketsOutOfOrder) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique<StrictTaggingDecrypter>(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + // Process packet N with phase 0. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr<QuicPacket> data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr<QuicEncryptedPacket> encrypted( + EncryptPacketWithTagAndPhase(*data, 0, false)); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(0, visitor_.key_update_count_); + EXPECT_EQ(0, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N+2 with phase 1. + header.packet_number = kPacketNumber + 2; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 1, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N+1 with phase 1. (Receiving packet from new phase out of + // order.) + header.packet_number = kPacketNumber + 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 1, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should decrypt and key update count should not change. + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_); +} + +TEST_P(QuicFramerTest, KeyUpdateWrongKey) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique<StrictTaggingDecrypter>(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr<QuicPacket> data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr<QuicEncryptedPacket> encrypted( + EncryptPacketWithTagAndPhase(*data, 0, false)); + ASSERT_TRUE(encrypted); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + // Processed valid packet with phase=0, key=1: no key update. + EXPECT_EQ(0, visitor_.key_update_count_); + EXPECT_EQ(0, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + header.packet_number += 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 2, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet with phase=1 but key=3, should not process and should not cause key + // update, but next decrypter key should have been created to attempt to + // decode it. + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(0, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + header.packet_number += 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 0, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet with phase=1 but key=1, should not process and should not cause key + // update. + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(0, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + header.packet_number += 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 1, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet with phase=0 but key=2, should not process and should not cause key + // update. + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(0, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); +} + +TEST_P(QuicFramerTest, KeyUpdateReceivedWhenNotEnabled) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique<StrictTaggingDecrypter>(/*key=*/0)); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr<QuicPacket> data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr<QuicEncryptedPacket> encrypted( + EncryptPacketWithTagAndPhase(*data, 1, true)); + ASSERT_TRUE(encrypted); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Received a packet with key phase updated even though framer hasn't had key + // update enabled (SetNextOneRttCrypters never called). Should fail to + // process. + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(0, visitor_.key_update_count_); + EXPECT_EQ(0, visitor_.derive_next_key_count_); + EXPECT_EQ(0, visitor_.decrypted_first_packet_in_key_phase_count_); +} + +TEST_P(QuicFramerTest, KeyUpdateLocallyInitiated) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique<StrictTaggingDecrypter>(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + + EXPECT_TRUE(framer_.DoKeyUpdate()); + // Key update count should be updated, but haven't received packet from peer + // with new key phase. + EXPECT_EQ(1, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(0, visitor_.decrypted_first_packet_in_key_phase_count_); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + // Process packet N with phase 1. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr<QuicPacket> data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr<QuicEncryptedPacket> encrypted( + EncryptPacketWithTagAndPhase(*data, 1, true)); + ASSERT_TRUE(encrypted); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should decrypt and key update count should not change and + // OnDecryptedFirstPacketInKeyPhase should have been called. + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N-1 with phase 0. (Receiving packet from previous phase + // after packet from new phase was received.) + header.packet_number = kPacketNumber - 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 0, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should decrypt and key update count should not change. + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N+1 with phase 0 and key 1. This should not decrypt even + // though it's using the previous key, since the packet number is higher than + // a packet number received using the current key. + header.packet_number = kPacketNumber + 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 0, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should not decrypt and key update count should not change. + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1, visitor_.key_update_count_); + EXPECT_EQ(2, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); +} + +TEST_P(QuicFramerTest, KeyUpdateLocallyInitiatedReceivedOldPacket) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique<StrictTaggingDecrypter>(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + + EXPECT_TRUE(framer_.DoKeyUpdate()); + // Key update count should be updated, but haven't received packet + // from peer with new key phase. + EXPECT_EQ(1, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(0, visitor_.decrypted_first_packet_in_key_phase_count_); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + // Process packet N with phase 0. (Receiving packet from previous phase + // after locally initiated key update, but before any packet from new phase + // was received.) + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr<QuicPacket> data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + std::unique_ptr<QuicEncryptedPacket> encrypted = + EncryptPacketWithTagAndPhase(*data, 0, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should decrypt and key update count should not change and + // OnDecryptedFirstPacketInKeyPhase should not have been called since the + // packet was from the previous key phase. + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(0, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N+1 with phase 1. + header.packet_number = kPacketNumber + 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 1, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should decrypt and key update count should not change, but + // OnDecryptedFirstPacketInKeyPhase should have been called. + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1, visitor_.key_update_count_); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N+2 with phase 0 and key 1. This should not decrypt even + // though it's using the previous key, since the packet number is higher than + // a packet number received using the current key. + header.packet_number = kPacketNumber + 2; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 0, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should not decrypt and key update count should not change. + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1, visitor_.key_update_count_); + EXPECT_EQ(2, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); +} + } // namespace } // namespace test } // namespace quic
diff --git a/quic/core/quic_session.cc b/quic/core/quic_session.cc index e28db59..a6c6e6c 100644 --- a/quic/core/quic_session.cc +++ b/quic/core/quic_session.cc
@@ -137,6 +137,11 @@ connection_->OnSuccessfulVersionNegotiation(); } + if (GetQuicReloadableFlag(quic_key_update_supported) && + GetMutableCryptoStream()->KeyUpdateSupportedLocally()) { + config_.SetKeyUpdateSupportedLocally(); + } + if (QuicVersionUsesCryptoFrames(transport_version())) { return; } @@ -276,6 +281,15 @@ GetMutableCryptoStream()->OnHandshakePacketSent(); } +std::unique_ptr<QuicDecrypter> +QuicSession::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + return GetMutableCryptoStream()->AdvanceKeysAndCreateCurrentOneRttDecrypter(); +} + +std::unique_ptr<QuicEncrypter> QuicSession::CreateCurrentOneRttEncrypter() { + return GetMutableCryptoStream()->CreateCurrentOneRttEncrypter(); +} + void QuicSession::PendingStreamOnRstStream(const QuicRstStreamFrame& frame) { DCHECK(VersionUsesHttp3(transport_version())); QuicStreamId stream_id = frame.stream_id;
diff --git a/quic/core/quic_session.h b/quic/core/quic_session.h index a3a982b..42f1b04 100644 --- a/quic/core/quic_session.h +++ b/quic/core/quic_session.h
@@ -134,6 +134,9 @@ void OnPacketDecrypted(EncryptionLevel level) override; void OnOneRttPacketAcknowledged() override; void OnHandshakePacketSent() override; + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override; + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override; // QuicStreamFrameDataProducer WriteStreamDataResult WriteStreamData(QuicStreamId id,
diff --git a/quic/core/quic_session_test.cc b/quic/core/quic_session_test.cc index b2c5b1f..a11e6f0 100644 --- a/quic/core/quic_session_test.cc +++ b/quic/core/quic_session_test.cc
@@ -140,6 +140,15 @@ } void SetServerApplicationStateForResumption( std::unique_ptr<ApplicationState> /*application_state*/) override {} + MOCK_METHOD(bool, KeyUpdateSupportedLocally, (), (const, override)); + MOCK_METHOD(std::unique_ptr<QuicDecrypter>, + AdvanceKeysAndCreateCurrentOneRttDecrypter, + (), + (override)); + MOCK_METHOD(std::unique_ptr<QuicEncrypter>, + CreateCurrentOneRttEncrypter, + (), + (override)); MOCK_METHOD(void, OnCanWrite, (), (override)); bool HasPendingCryptoRetransmission() const override { return false; } @@ -198,6 +207,8 @@ writev_consumes_all_data_(false), uses_pending_streams_(false), num_incoming_streams_created_(0) { + EXPECT_CALL(*GetMutableCryptoStream(), KeyUpdateSupportedLocally()) + .WillRepeatedly(Return(false)); Initialize(); this->connection()->SetEncrypter( ENCRYPTION_FORWARD_SECURE, @@ -2229,6 +2240,30 @@ ASSERT_TRUE(session_.connection()->can_receive_ack_frequency_frame()); } +TEST_P(QuicSessionTestClient, KeyUpdateFlagNotSet) { + SetQuicReloadableFlag(quic_key_update_supported, false); + EXPECT_CALL(*session_.GetMutableCryptoStream(), KeyUpdateSupportedLocally()) + .Times(0); + session_.Initialize(); + EXPECT_FALSE(session_.config()->KeyUpdateSupportedLocally()); +} + +TEST_P(QuicSessionTestClient, KeyUpdateNotSupportedLocallyAndFlagSet) { + SetQuicReloadableFlag(quic_key_update_supported, true); + EXPECT_CALL(*session_.GetMutableCryptoStream(), KeyUpdateSupportedLocally()) + .WillOnce(Return(false)); + session_.Initialize(); + EXPECT_FALSE(session_.config()->KeyUpdateSupportedLocally()); +} + +TEST_P(QuicSessionTestClient, KeyUpdateSupportedLocallyAndFlagSet) { + SetQuicReloadableFlag(quic_key_update_supported, true); + EXPECT_CALL(*session_.GetMutableCryptoStream(), KeyUpdateSupportedLocally()) + .WillOnce(Return(true)); + session_.Initialize(); + EXPECT_TRUE(session_.config()->KeyUpdateSupportedLocally()); +} + TEST_P(QuicSessionTestClient, FailedToCreateStreamIfTooCloseToIdleTimeout) { connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); EXPECT_TRUE(session_.CanOpenNextOutgoingBidirectionalStream());
diff --git a/quic/core/quic_types.h b/quic/core/quic_types.h index 13307bb..4a896a0 100644 --- a/quic/core/quic_types.h +++ b/quic/core/quic_types.h
@@ -602,8 +602,8 @@ QuicLongHeaderType type); enum QuicPacketHeaderTypeFlags : uint8_t { - // Bit 2: Reserved for experimentation for short header. - FLAGS_EXPERIMENTATION_BIT = 1 << 2, + // Bit 2: Key phase bit for IETF QUIC short header packets. + FLAGS_KEY_PHASE_BIT = 1 << 2, // Bit 3: Google QUIC Demultiplexing bit, the short header always sets this // bit to 0, allowing to distinguish Google QUIC packets from short header // packets.
diff --git a/quic/core/quic_versions.cc b/quic/core/quic_versions.cc index 745957f..145cdcd 100644 --- a/quic/core/quic_versions.cc +++ b/quic/core/quic_versions.cc
@@ -663,6 +663,7 @@ void QuicVersionInitializeSupportForIetfDraft() { // Enable necessary flags. SetQuicRestartFlag(quic_enable_zero_rtt_for_tls_v2, true); + SetQuicReloadableFlag(quic_key_update_supported, true); } void QuicEnableVersion(const ParsedQuicVersion& version) {
diff --git a/quic/core/tls_chlo_extractor.h b/quic/core/tls_chlo_extractor.h index 53e06b4..6d2f254 100644 --- a/quic/core/tls_chlo_extractor.h +++ b/quic/core/tls_chlo_extractor.h
@@ -160,6 +160,15 @@ } void OnAuthenticatedIetfStatelessResetPacket( const QuicIetfStatelessResetPacket& /*packet*/) override {} + void OnKeyUpdate() override {} + void OnDecryptedFirstPacketInKeyPhase() override {} + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + return nullptr; + } + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override { + return nullptr; + } // Methods from QuicStreamSequencer::StreamInterface. void OnDataAvailable() override;
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc index c6e601e..234b4d1 100644 --- a/quic/core/tls_client_handshaker.cc +++ b/quic/core/tls_client_handshaker.cc
@@ -349,6 +349,20 @@ return TlsHandshaker::BufferSizeLimitForLevel(level); } +bool TlsClientHandshaker::KeyUpdateSupportedLocally() const { + return true; +} + +std::unique_ptr<QuicDecrypter> +TlsClientHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + return TlsHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter(); +} + +std::unique_ptr<QuicEncrypter> +TlsClientHandshaker::CreateCurrentOneRttEncrypter() { + return TlsHandshaker::CreateCurrentOneRttEncrypter(); +} + void TlsClientHandshaker::OnOneRttPacketAcknowledged() { OnHandshakeConfirmed(); }
diff --git a/quic/core/tls_client_handshaker.h b/quic/core/tls_client_handshaker.h index 89388cc..f5a349f 100644 --- a/quic/core/tls_client_handshaker.h +++ b/quic/core/tls_client_handshaker.h
@@ -59,6 +59,10 @@ CryptoMessageParser* crypto_message_parser() override; HandshakeState GetHandshakeState() const override; size_t BufferSizeLimitForLevel(EncryptionLevel level) const override; + bool KeyUpdateSupportedLocally() const override; + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override; + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override; void OnOneRttPacketAcknowledged() override; void OnHandshakePacketSent() override; void OnConnectionClosed(QuicErrorCode error,
diff --git a/quic/core/tls_handshaker.cc b/quic/core/tls_handshaker.cc index 362a1e8..a53dae4 100644 --- a/quic/core/tls_handshaker.cc +++ b/quic/core/tls_handshaker.cc
@@ -68,10 +68,22 @@ void TlsHandshaker::SetWriteSecret(EncryptionLevel level, const SSL_CIPHER* cipher, const std::vector<uint8_t>& write_secret) { + QUIC_DVLOG(1) << "SetWriteSecret level=" << level; std::unique_ptr<QuicEncrypter> encrypter = QuicEncrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher)); - CryptoUtils::InitializeCrypterSecrets(Prf(cipher), write_secret, - encrypter.get()); + const EVP_MD* prf = Prf(cipher); + CryptoUtils::SetKeyAndIV(prf, write_secret, encrypter.get()); + std::vector<uint8_t> header_protection_key = + CryptoUtils::GenerateHeaderProtectionKey(prf, write_secret, + encrypter->GetKeySize()); + encrypter->SetHeaderProtectionKey(quiche::QuicheStringPiece( + reinterpret_cast<char*>(header_protection_key.data()), + header_protection_key.size())); + if (level == ENCRYPTION_FORWARD_SECURE) { + DCHECK(latest_write_secret_.empty()); + latest_write_secret_ = write_secret; + one_rtt_write_header_protection_key_ = header_protection_key; + } handshaker_delegate_->OnNewEncryptionKeyAvailable(level, std::move(encrypter)); } @@ -79,16 +91,73 @@ bool TlsHandshaker::SetReadSecret(EncryptionLevel level, const SSL_CIPHER* cipher, const std::vector<uint8_t>& read_secret) { + QUIC_DVLOG(1) << "SetReadSecret level=" << level; std::unique_ptr<QuicDecrypter> decrypter = QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher)); - CryptoUtils::InitializeCrypterSecrets(Prf(cipher), read_secret, - decrypter.get()); + const EVP_MD* prf = Prf(cipher); + CryptoUtils::SetKeyAndIV(prf, read_secret, decrypter.get()); + std::vector<uint8_t> header_protection_key = + CryptoUtils::GenerateHeaderProtectionKey(prf, read_secret, + decrypter->GetKeySize()); + decrypter->SetHeaderProtectionKey(quiche::QuicheStringPiece( + reinterpret_cast<char*>(header_protection_key.data()), + header_protection_key.size())); + if (level == ENCRYPTION_FORWARD_SECURE) { + DCHECK(latest_read_secret_.empty()); + latest_read_secret_ = read_secret; + one_rtt_read_header_protection_key_ = header_protection_key; + } return handshaker_delegate_->OnNewDecryptionKeyAvailable( level, std::move(decrypter), /*set_alternative_decrypter=*/false, /*latch_once_used=*/false); } +std::unique_ptr<QuicDecrypter> +TlsHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + if (latest_read_secret_.empty() || latest_write_secret_.empty() || + one_rtt_read_header_protection_key_.empty() || + one_rtt_write_header_protection_key_.empty()) { + std::string error_details = "1-RTT secret(s) not set yet."; + QUIC_BUG << error_details; + CloseConnection(QUIC_INTERNAL_ERROR, error_details); + return nullptr; + } + const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl()); + const EVP_MD* prf = Prf(cipher); + latest_read_secret_ = + CryptoUtils::GenerateNextKeyPhaseSecret(prf, latest_read_secret_); + latest_write_secret_ = + CryptoUtils::GenerateNextKeyPhaseSecret(prf, latest_write_secret_); + + std::unique_ptr<QuicDecrypter> decrypter = + QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher)); + CryptoUtils::SetKeyAndIV(prf, latest_read_secret_, decrypter.get()); + decrypter->SetHeaderProtectionKey(quiche::QuicheStringPiece( + reinterpret_cast<char*>(one_rtt_read_header_protection_key_.data()), + one_rtt_read_header_protection_key_.size())); + + return decrypter; +} + +std::unique_ptr<QuicEncrypter> TlsHandshaker::CreateCurrentOneRttEncrypter() { + if (latest_write_secret_.empty() || + one_rtt_write_header_protection_key_.empty()) { + std::string error_details = "1-RTT write secret not set yet."; + QUIC_BUG << error_details; + CloseConnection(QUIC_INTERNAL_ERROR, error_details); + return nullptr; + } + const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl()); + std::unique_ptr<QuicEncrypter> encrypter = + QuicEncrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher)); + CryptoUtils::SetKeyAndIV(Prf(cipher), latest_write_secret_, encrypter.get()); + encrypter->SetHeaderProtectionKey(quiche::QuicheStringPiece( + reinterpret_cast<char*>(one_rtt_write_header_protection_key_.data()), + one_rtt_write_header_protection_key_.size())); + return encrypter; +} + void TlsHandshaker::WriteMessage(EncryptionLevel level, absl::string_view data) { stream_->WriteCryptoData(level, data);
diff --git a/quic/core/tls_handshaker.h b/quic/core/tls_handshaker.h index ef3ade1..077e373 100644 --- a/quic/core/tls_handshaker.h +++ b/quic/core/tls_handshaker.h
@@ -48,6 +48,8 @@ CryptoMessageParser* crypto_message_parser() { return this; } size_t BufferSizeLimitForLevel(EncryptionLevel level) const; ssl_early_data_reason_t EarlyDataReason() const; + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter(); + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter(); protected: virtual void AdvanceHandshake() = 0; @@ -104,6 +106,14 @@ QuicErrorCode parser_error_ = QUIC_NO_ERROR; std::string parser_error_detail_; + + // The most recently derived 1-RTT read and write secrets, which are updated + // on each key update. + std::vector<uint8_t> latest_read_secret_; + std::vector<uint8_t> latest_write_secret_; + // 1-RTT header protection keys, which are not changed during key update. + std::vector<uint8_t> one_rtt_read_header_protection_key_; + std::vector<uint8_t> one_rtt_write_header_protection_key_; }; } // namespace quic
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc index f108a32..405e109 100644 --- a/quic/core/tls_server_handshaker.cc +++ b/quic/core/tls_server_handshaker.cc
@@ -216,6 +216,20 @@ return TlsHandshaker::BufferSizeLimitForLevel(level); } +bool TlsServerHandshaker::KeyUpdateSupportedLocally() const { + return true; +} + +std::unique_ptr<QuicDecrypter> +TlsServerHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + return TlsHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter(); +} + +std::unique_ptr<QuicEncrypter> +TlsServerHandshaker::CreateCurrentOneRttEncrypter() { + return TlsHandshaker::CreateCurrentOneRttEncrypter(); +} + void TlsServerHandshaker::OverrideQuicConfigDefaults(QuicConfig* /*config*/) {} void TlsServerHandshaker::AdvanceHandshake() {
diff --git a/quic/core/tls_server_handshaker.h b/quic/core/tls_server_handshaker.h index 79d19cb..7b28825 100644 --- a/quic/core/tls_server_handshaker.h +++ b/quic/core/tls_server_handshaker.h
@@ -66,6 +66,10 @@ void SetServerApplicationStateForResumption( std::unique_ptr<ApplicationState> state) override; size_t BufferSizeLimitForLevel(EncryptionLevel level) const override; + bool KeyUpdateSupportedLocally() const override; + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override; + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override; void SetWriteSecret(EncryptionLevel level, const SSL_CIPHER* cipher, const std::vector<uint8_t>& write_secret) override;
diff --git a/quic/test_tools/quic_connection_peer.cc b/quic/test_tools/quic_connection_peer.cc index 01da28c..1f4e0c8 100644 --- a/quic/test_tools/quic_connection_peer.cc +++ b/quic/test_tools/quic_connection_peer.cc
@@ -151,6 +151,12 @@ } // static +QuicAlarm* QuicConnectionPeer::GetDiscardPreviousOneRttKeysAlarm( + QuicConnection* connection) { + return connection->discard_previous_one_rtt_keys_alarm_.get(); +} + +// static QuicPacketWriter* QuicConnectionPeer::GetWriter(QuicConnection* connection) { return connection->writer_; }
diff --git a/quic/test_tools/quic_connection_peer.h b/quic/test_tools/quic_connection_peer.h index 1375cbe..783b055 100644 --- a/quic/test_tools/quic_connection_peer.h +++ b/quic/test_tools/quic_connection_peer.h
@@ -82,6 +82,8 @@ static QuicAlarm* GetMtuDiscoveryAlarm(QuicConnection* connection); static QuicAlarm* GetProcessUndecryptablePacketsAlarm( QuicConnection* connection); + static QuicAlarm* GetDiscardPreviousOneRttKeysAlarm( + QuicConnection* connection); static QuicPacketWriter* GetWriter(QuicConnection* connection); // If |owns_writer| is true, takes ownership of |writer|.
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h index 294b7db..c6a37f0 100644 --- a/quic/test_tools/quic_test_utils.h +++ b/quic/test_tools/quic_test_utils.h
@@ -424,6 +424,16 @@ OnAuthenticatedIetfStatelessResetPacket, (const QuicIetfStatelessResetPacket&), (override)); + MOCK_METHOD(void, OnKeyUpdate, (), (override)); + MOCK_METHOD(void, OnDecryptedFirstPacketInKeyPhase, (), (override)); + MOCK_METHOD(std::unique_ptr<QuicDecrypter>, + AdvanceKeysAndCreateCurrentOneRttDecrypter, + (), + (override)); + MOCK_METHOD(std::unique_ptr<QuicEncrypter>, + CreateCurrentOneRttEncrypter, + (), + (override)); }; class NoOpFramerVisitor : public QuicFramerVisitorInterface { @@ -483,6 +493,15 @@ bool IsValidStatelessResetToken(QuicUint128 token) const override; void OnAuthenticatedIetfStatelessResetPacket( const QuicIetfStatelessResetPacket& /*packet*/) override {} + void OnKeyUpdate() override {} + void OnDecryptedFirstPacketInKeyPhase() override {} + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + return nullptr; + } + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override { + return nullptr; + } }; class MockQuicConnectionVisitor : public QuicConnectionVisitorInterface { @@ -558,6 +577,14 @@ MOCK_METHOD(void, OnPacketDecrypted, (EncryptionLevel), (override)); MOCK_METHOD(void, OnOneRttPacketAcknowledged, (), (override)); MOCK_METHOD(void, OnHandshakePacketSent, (), (override)); + MOCK_METHOD(std::unique_ptr<QuicDecrypter>, + AdvanceKeysAndCreateCurrentOneRttDecrypter, + (), + (override)); + MOCK_METHOD(std::unique_ptr<QuicEncrypter>, + CreateCurrentOneRttEncrypter, + (), + (override)); }; class MockQuicConnectionHelper : public QuicConnectionHelperInterface { @@ -865,6 +892,14 @@ HandshakeState GetHandshakeState() const override { return HANDSHAKE_START; } void SetServerApplicationStateForResumption( std::unique_ptr<ApplicationState> /*application_state*/) override {} + bool KeyUpdateSupportedLocally() const override { return false; } + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + return nullptr; + } + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override { + return nullptr; + } private: QuicReferenceCountedPointer<QuicCryptoNegotiatedParameters> params_;
diff --git a/quic/test_tools/simple_quic_framer.cc b/quic/test_tools/simple_quic_framer.cc index 2b974ee..4a77331 100644 --- a/quic/test_tools/simple_quic_framer.cc +++ b/quic/test_tools/simple_quic_framer.cc
@@ -220,6 +220,16 @@ std::make_unique<QuicIetfStatelessResetPacket>(packet); } + void OnKeyUpdate() override {} + void OnDecryptedFirstPacketInKeyPhase() override {} + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + return nullptr; + } + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override { + return nullptr; + } + const QuicPacketHeader& header() const { return header_; } const std::vector<QuicAckFrame>& ack_frames() const { return ack_frames_; } const std::vector<QuicConnectionCloseFrame>& connection_close_frames() const {
diff --git a/quic/test_tools/simulator/quic_endpoint.h b/quic/test_tools/simulator/quic_endpoint.h index 39c4855..faf2bfd 100644 --- a/quic/test_tools/simulator/quic_endpoint.h +++ b/quic/test_tools/simulator/quic_endpoint.h
@@ -93,6 +93,13 @@ void OnPacketDecrypted(EncryptionLevel /*level*/) override {} void OnOneRttPacketAcknowledged() override {} void OnHandshakePacketSent() override {} + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + return nullptr; + } + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override { + return nullptr; + } // End QuicConnectionVisitorInterface implementation.
diff --git a/quic/tools/quic_client_interop_test_bin.cc b/quic/tools/quic_client_interop_test_bin.cc index 80dfd3b..15871d8 100644 --- a/quic/tools/quic_client_interop_test_bin.cc +++ b/quic/tools/quic_client_interop_test_bin.cc
@@ -52,6 +52,8 @@ // Second row of features (anything else protocol-related) // We switched to a different port and the server migrated to it. kRebinding, + // One endpoint can update keys and its peer responds correctly. + kKeyUpdate, // Third row of features (H3 tests) // An H3 transaction succeeded. @@ -81,6 +83,8 @@ return 'Q'; case Feature::kRebinding: return 'B'; + case Feature::kKeyUpdate: + return 'U'; case Feature::kHttp3: return '3'; case Feature::kDynamicEntryReferenced: @@ -107,7 +111,8 @@ ParsedQuicVersion version, bool test_version_negotiation, bool attempt_rebind, - bool attempt_multi_packet_chlo); + bool attempt_multi_packet_chlo, + bool attempt_key_update); // Constructs a SpdyHeaderBlock containing the pseudo-headers needed to make a // GET request to "/" on the hostname |authority|. @@ -191,7 +196,8 @@ ParsedQuicVersion version, bool test_version_negotiation, bool attempt_rebind, - bool attempt_multi_packet_chlo) { + bool attempt_multi_packet_chlo, + bool attempt_key_update) { ParsedQuicVersionVector versions = {version}; if (test_version_negotiation) { versions.insert(versions.begin(), QuicVersionReservedForNegotiation()); @@ -238,7 +244,7 @@ // Failed to negotiate version, retry without version negotiation. AttemptRequest(addr, authority, server_id, version, /*test_version_negotiation=*/false, attempt_rebind, - attempt_multi_packet_chlo); + attempt_multi_packet_chlo, attempt_key_update); return; } if (!client->session()->OneRttKeysAvailable()) { @@ -246,7 +252,7 @@ // Failed to handshake with multi-packet client hello, retry without it. AttemptRequest(addr, authority, server_id, version, test_version_negotiation, attempt_rebind, - /*attempt_multi_packet_chlo=*/false); + /*attempt_multi_packet_chlo=*/false, attempt_key_update); return; } return; @@ -278,7 +284,7 @@ // Rebinding does not work, retry without attempting it. AttemptRequest(addr, authority, server_id, version, test_version_negotiation, /*attempt_rebind=*/false, - attempt_multi_packet_chlo); + attempt_multi_packet_chlo, attempt_key_update); return; } InsertFeature(Feature::kRebinding); @@ -290,6 +296,27 @@ QUIC_LOG(ERROR) << "Failed to change ephemeral port"; } } + + if (attempt_key_update) { + if (connection->IsKeyUpdateAllowed()) { + if (connection->InitiateKeyUpdate()) { + client->SendRequestAndWaitForResponse(header_block, "", /*fin=*/true); + if (!client->connected()) { + // Key update does not work, retry without attempting it. + AttemptRequest(addr, authority, server_id, version, + test_version_negotiation, attempt_rebind, + attempt_multi_packet_chlo, + /*attempt_key_update=*/false); + return; + } + InsertFeature(Feature::kKeyUpdate); + } else { + QUIC_LOG(ERROR) << "Failed to initiate key update"; + } + } else { + QUIC_LOG(ERROR) << "Key update not allowed"; + } + } } if (connection->connected()) { @@ -368,7 +395,8 @@ runner.AttemptRequest(addr, authority, server_id, version, /*test_version_negotiation=*/true, /*attempt_rebind=*/true, - /*attempt_multi_packet_chlo=*/true); + /*attempt_multi_packet_chlo=*/true, + /*attempt_key_update=*/true); return runner.features(); }
diff --git a/quic/tools/quic_packet_printer_bin.cc b/quic/tools/quic_packet_printer_bin.cc index a19646f..e64e1ac 100644 --- a/quic/tools/quic_packet_printer_bin.cc +++ b/quic/tools/quic_packet_printer_bin.cc
@@ -220,6 +220,19 @@ const QuicIetfStatelessResetPacket& /*packet*/) override { std::cerr << "OnAuthenticatedIetfStatelessResetPacket\n"; } + void OnKeyUpdate() override { std::cerr << "OnKeyUpdate\n"; } + void OnDecryptedFirstPacketInKeyPhase() override { + std::cerr << "OnDecryptedFirstPacketInKeyPhase\n"; + } + std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + std::cerr << "AdvanceKeysAndCreateCurrentOneRttDecrypter\n"; + return nullptr; + } + std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override { + std::cerr << "CreateCurrentOneRttEncrypter\n"; + return nullptr; + } private: QuicFramer* framer_; // Unowned.