QUIC Key Update support
Handles remotely initiated key updates and adds QuicConnection::InitiateKeyUpdate() to initiate a key update locally; that method is currently called only from tests.
Protected by FLAGS_quic_reloadable_flag_quic_key_update_supported.
PiperOrigin-RevId: 336385088
Change-Id: If74d032e1d34e5392312f4b619d28c9f93a95265
diff --git a/quic/core/chlo_extractor.cc b/quic/core/chlo_extractor.cc
index 53c3246..33f6c0e 100644
--- a/quic/core/chlo_extractor.cc
+++ b/quic/core/chlo_extractor.cc
@@ -82,6 +82,11 @@
bool IsValidStatelessResetToken(QuicUint128 token) const override;
void OnAuthenticatedIetfStatelessResetPacket(
const QuicIetfStatelessResetPacket& /*packet*/) override {}
+ void OnKeyUpdate() override;
+ void OnDecryptedFirstPacketInKeyPhase() override;
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+ override;
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override;
// CryptoFramerVisitorInterface implementation.
void OnError(CryptoFramer* framer) override;
@@ -310,6 +315,20 @@
return true;
}
+// Intentionally empty: this visitor takes no action on key updates.
+void ChloFramerVisitor::OnKeyUpdate() {
+  // Nothing to do.
+}
+
+// Intentionally empty: this visitor takes no action when the first packet of
+// a key phase is decrypted.
+void ChloFramerVisitor::OnDecryptedFirstPacketInKeyPhase() {
+  // Nothing to do.
+}
+
+// This visitor never supplies 1-RTT keys, so it always hands back an empty
+// decrypter pointer.
+std::unique_ptr<QuicDecrypter>
+ChloFramerVisitor::AdvanceKeysAndCreateCurrentOneRttDecrypter() {
+  return {};
+}
+
+// This visitor never supplies 1-RTT keys, so it always hands back an empty
+// encrypter pointer.
+std::unique_ptr<QuicEncrypter>
+ChloFramerVisitor::CreateCurrentOneRttEncrypter() {
+  return {};
+}
+
void ChloFramerVisitor::OnError(CryptoFramer* /*framer*/) {}
void ChloFramerVisitor::OnHandshakeMessage(
diff --git a/quic/core/crypto/crypto_utils.cc b/quic/core/crypto/crypto_utils.cc
index e6ec275..1868a3a 100644
--- a/quic/core/crypto/crypto_utils.cc
+++ b/quic/core/crypto/crypto_utils.cc
@@ -128,6 +128,12 @@
return HkdfExpandLabel(prf, pp_secret, "quic hp", out_len);
}
+// Derives the secret for key phase n+1 from the secret for phase n using
+// HKDF-Expand-Label with the "quic ku" label. The output has the same length
+// as the input secret.
+std::vector<uint8_t> CryptoUtils::GenerateNextKeyPhaseSecret(
+    const EVP_MD* prf,
+    const std::vector<uint8_t>& current_secret) {
+  const size_t next_secret_length = current_secret.size();
+  return HkdfExpandLabel(prf, current_secret, "quic ku", next_secret_length);
+}
+
namespace {
// Salt from https://tools.ietf.org/html/draft-ietf-quic-tls-27#section-5.2
diff --git a/quic/core/crypto/crypto_utils.h b/quic/core/crypto/crypto_utils.h
index 543173a..2191ae1 100644
--- a/quic/core/crypto/crypto_utils.h
+++ b/quic/core/crypto/crypto_utils.h
@@ -98,6 +98,11 @@
const std::vector<uint8_t>& pp_secret,
size_t out_len);
+ // Given a secret for key phase n, return the secret for phase n+1.
+ static std::vector<uint8_t> GenerateNextKeyPhaseSecret(
+ const EVP_MD* prf,
+ const std::vector<uint8_t>& current_secret);
+
// IETF QUIC encrypts ENCRYPTION_INITIAL messages with a version-specific key
// (to prevent network observers that are not aware of that QUIC version from
// making decisions based on the TLS handshake). This packet protection secret
diff --git a/quic/core/http/end_to_end_test.cc b/quic/core/http/end_to_end_test.cc
index c57fc6e..6145a79 100644
--- a/quic/core/http/end_to_end_test.cc
+++ b/quic/core/http/end_to_end_test.cc
@@ -4821,6 +4821,176 @@
0u);
}
+// The client initiates two key updates; each is counted once on both the
+// client and the server, and extra traffic without a new update does not
+// change the count.
+TEST_P(EndToEndTest, KeyUpdateInitiatedByClient) {
+  SetQuicReloadableFlag(quic_key_update_supported, true);
+  ASSERT_TRUE(Initialize());
+  if (!version_.UsesTls()) {
+    // Key Update is only supported in TLS handshake.
+    return;
+  }
+
+  // Traffic in the initial key phase; no update has happened yet.
+  SendSynchronousFooRequestAndCheckResponse();
+  QuicConnection* client_connection = GetClientConnection();
+  ASSERT_TRUE(client_connection);
+  EXPECT_EQ(0u, client_connection->GetStats().key_update_count);
+
+  // First client-initiated update.
+  EXPECT_TRUE(client_connection->InitiateKeyUpdate());
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(1u, client_connection->GetStats().key_update_count);
+
+  // More traffic without initiating leaves the count unchanged.
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(1u, client_connection->GetStats().key_update_count);
+
+  // Second client-initiated update.
+  EXPECT_TRUE(client_connection->InitiateKeyUpdate());
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(2u, client_connection->GetStats().key_update_count);
+
+  // Pause the server thread while reading the server connection's stats.
+  server_thread_->Pause();
+  QuicConnection* server_connection = GetServerConnection();
+  if (server_connection == nullptr) {
+    ADD_FAILURE() << "Missing server connection";
+  } else {
+    EXPECT_EQ(2u, server_connection->GetStats().key_update_count);
+  }
+  server_thread_->Resume();
+}
+
+// The server initiates two key updates; the client observes each one once.
+TEST_P(EndToEndTest, KeyUpdateInitiatedByServer) {
+  SetQuicReloadableFlag(quic_key_update_supported, true);
+  ASSERT_TRUE(Initialize());
+  if (!version_.UsesTls()) {
+    // Key Update is only supported in TLS handshake.
+    return;
+  }
+
+  SendSynchronousFooRequestAndCheckResponse();
+  QuicConnection* client_connection = GetClientConnection();
+  ASSERT_TRUE(client_connection);
+  EXPECT_EQ(0u, client_connection->GetStats().key_update_count);
+
+  // Predicate handed to server_thread_->WaitUntil(): initiates a key update
+  // on the server connection and always returns true.
+  auto initiate_server_key_update = [this]() {
+    QuicConnection* server_connection = GetServerConnection();
+    if (server_connection == nullptr) {
+      ADD_FAILURE() << "Missing server connection";
+    } else {
+      EXPECT_TRUE(server_connection->InitiateKeyUpdate());
+    }
+    return true;
+  };
+
+  // Use WaitUntil to ensure the server has executed the key update predicate
+  // before sending the Foo request; otherwise the test can be flaky if the
+  // server receives the Foo request before executing the key update.
+  server_thread_->WaitUntil(initiate_server_key_update,
+                            QuicTime::Delta::FromSeconds(5));
+
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(1u, client_connection->GetStats().key_update_count);
+
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(1u, client_connection->GetStats().key_update_count);
+
+  // Second server-initiated update.
+  server_thread_->WaitUntil(initiate_server_key_update,
+                            QuicTime::Delta::FromSeconds(5));
+
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(2u, client_connection->GetStats().key_update_count);
+
+  // Pause the server thread while reading the server connection's stats.
+  server_thread_->Pause();
+  QuicConnection* server_connection = GetServerConnection();
+  if (server_connection == nullptr) {
+    ADD_FAILURE() << "Missing server connection";
+  } else {
+    EXPECT_EQ(2u, server_connection->GetStats().key_update_count);
+  }
+  server_thread_->Resume();
+}
+
+// Both endpoints initiate a key update before receiving a packet in the new
+// key phase from the other side; each such round counts as one update.
+TEST_P(EndToEndTest, KeyUpdateInitiatedByBoth) {
+  SetQuicReloadableFlag(quic_key_update_supported, true);
+  ASSERT_TRUE(Initialize());
+  if (!version_.UsesTls()) {
+    // Key Update is only supported in TLS handshake.
+    return;
+  }
+
+  SendSynchronousFooRequestAndCheckResponse();
+
+  // Predicate handed to server_thread_->WaitUntil(): initiates a key update
+  // on the server connection and always returns true.
+  auto initiate_server_key_update = [this]() {
+    QuicConnection* server_connection = GetServerConnection();
+    if (server_connection == nullptr) {
+      ADD_FAILURE() << "Missing server connection";
+    } else {
+      EXPECT_TRUE(server_connection->InitiateKeyUpdate());
+    }
+    return true;
+  };
+
+  // Use WaitUntil to ensure the server has executed the key update predicate
+  // before the client sends the Foo request. Otherwise the Foo request from
+  // the client could trigger the server key update before the server can
+  // initiate the key update locally. That would mean the test is no longer
+  // hitting the intended test state of both sides locally initiating a key
+  // update before receiving a packet in the new key phase from the other side.
+  // Additionally, the test would fail, since InitiateKeyUpdate() would not
+  // allow another key update yet and would return false.
+  server_thread_->WaitUntil(initiate_server_key_update,
+                            QuicTime::Delta::FromSeconds(5));
+  QuicConnection* client_connection = GetClientConnection();
+  ASSERT_TRUE(client_connection);
+  EXPECT_TRUE(client_connection->InitiateKeyUpdate());
+
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(1u, client_connection->GetStats().key_update_count);
+
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(1u, client_connection->GetStats().key_update_count);
+
+  // Second round: both sides initiate again before exchanging traffic.
+  server_thread_->WaitUntil(initiate_server_key_update,
+                            QuicTime::Delta::FromSeconds(5));
+  EXPECT_TRUE(client_connection->InitiateKeyUpdate());
+
+  SendSynchronousFooRequestAndCheckResponse();
+  EXPECT_EQ(2u, client_connection->GetStats().key_update_count);
+
+  // Pause the server thread while reading the server connection's stats.
+  server_thread_->Pause();
+  QuicConnection* server_connection = GetServerConnection();
+  if (server_connection == nullptr) {
+    ADD_FAILURE() << "Missing server connection";
+  } else {
+    EXPECT_EQ(2u, server_connection->GetStats().key_update_count);
+  }
+  server_thread_->Resume();
+}
+
} // namespace
} // namespace test
} // namespace quic
diff --git a/quic/core/http/quic_spdy_session_test.cc b/quic/core/http/quic_spdy_session_test.cc
index 8af0087..bb081b3 100644
--- a/quic/core/http/quic_spdy_session_test.cc
+++ b/quic/core/http/quic_spdy_session_test.cc
@@ -145,6 +145,14 @@
}
void SetServerApplicationStateForResumption(
std::unique_ptr<ApplicationState> /*application_state*/) override {}
+  // This test crypto stream does not support key update.
+  bool KeyUpdateSupportedLocally() const override {
+    return false;
+  }
+  // Never produces 1-RTT keys; always returns an empty pointer.
+  std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+      override {
+    return {};
+  }
+  // Never produces 1-RTT keys; always returns an empty pointer.
+  std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override {
+    return {};
+  }
const QuicCryptoNegotiatedParameters& crypto_negotiated_params()
const override {
return *params_;
diff --git a/quic/core/http/quic_spdy_stream_test.cc b/quic/core/http/quic_spdy_stream_test.cc
index 0abec1d..4adbaab 100644
--- a/quic/core/http/quic_spdy_stream_test.cc
+++ b/quic/core/http/quic_spdy_stream_test.cc
@@ -133,6 +133,14 @@
}
void SetServerApplicationStateForResumption(
std::unique_ptr<ApplicationState> /*application_state*/) override {}
+  // This stub cannot create 1-RTT keys (its AdvanceKeysAndCreateCurrentOneRtt*
+  // and CreateCurrentOneRttEncrypter overrides return null), so it must not
+  // advertise key update support. This matches the equivalent stub in
+  // quic_spdy_session_test.cc.
+  bool KeyUpdateSupportedLocally() const override { return false; }
+  // Never produces 1-RTT keys; always returns an empty pointer.
+  std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+      override {
+    return {};
+  }
+  // Never produces 1-RTT keys; always returns an empty pointer.
+  std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override {
+    return {};
+  }
const QuicCryptoNegotiatedParameters& crypto_negotiated_params()
const override {
return *params_;
diff --git a/quic/core/quic_config.cc b/quic/core/quic_config.cc
index 5ea6b24..60569b3 100644
--- a/quic/core/quic_config.cc
+++ b/quic/core/quic_config.cc
@@ -448,6 +448,8 @@
initial_session_flow_control_window_bytes_(kCFCW, PRESENCE_OPTIONAL),
connection_migration_disabled_(kNCMR, PRESENCE_OPTIONAL),
support_handshake_done_(0, PRESENCE_OPTIONAL),
+ key_update_supported_remotely_(false),
+ key_update_supported_locally_(false),
alternate_server_address_ipv6_(kASAD, PRESENCE_OPTIONAL),
alternate_server_address_ipv4_(kASAD, PRESENCE_OPTIONAL),
stateless_reset_token_(kSRST, PRESENCE_OPTIONAL),
@@ -863,6 +865,18 @@
return support_handshake_done_.HasReceivedValue();
}
+// Marks key update as supported by this endpoint. Once set, the
+// key_update_not_yet_supported transport parameter is no longer sent.
+void QuicConfig::SetKeyUpdateSupportedLocally() {
+  this->key_update_supported_locally_ = true;
+}
+
+// Key update can only be used when both this endpoint and the peer support it.
+bool QuicConfig::KeyUpdateSupportedForConnection() const {
+  if (!KeyUpdateSupportedLocally()) {
+    return false;
+  }
+  return key_update_supported_remotely_;
+}
+
+// Whether SetKeyUpdateSupportedLocally() has been called (false initially).
+bool QuicConfig::KeyUpdateSupportedLocally() const {
+  const bool supported_locally = key_update_supported_locally_;
+  return supported_locally;
+}
+
void QuicConfig::SetIPv6AlternateServerAddressToSend(
const QuicSocketAddress& alternate_server_address_ipv6) {
if (!alternate_server_address_ipv6.host().IsIPv6()) {
@@ -1228,7 +1242,9 @@
params->google_connection_options = connection_options_.GetSendValues();
}
- params->key_update_not_yet_supported = true;
+ if (!KeyUpdateSupportedLocally()) {
+ params->key_update_not_yet_supported = true;
+ }
params->custom_parameters = custom_transport_parameters_to_send_;
@@ -1336,6 +1352,9 @@
if (params.support_handshake_done) {
support_handshake_done_.SetReceivedValue(1u);
}
+ if (!is_resumption && !params.key_update_not_yet_supported) {
+ key_update_supported_remotely_ = true;
+ }
active_connection_id_limit_.SetReceivedValue(
params.active_connection_id_limit.value());
diff --git a/quic/core/quic_config.h b/quic/core/quic_config.h
index 4132987..0ce577a 100644
--- a/quic/core/quic_config.h
+++ b/quic/core/quic_config.h
@@ -387,6 +387,11 @@
bool HandshakeDoneSupported() const;
bool PeerSupportsHandshakeDone() const;
+ // Key update support.
+ void SetKeyUpdateSupportedLocally();
+ bool KeyUpdateSupportedForConnection() const;
+ bool KeyUpdateSupportedLocally() const;
+
// IPv6 alternate server address.
void SetIPv6AlternateServerAddressToSend(
const QuicSocketAddress& alternate_server_address_ipv6);
@@ -580,6 +585,13 @@
// Uses the support_handshake_done transport parameter in IETF QUIC.
QuicFixedUint32 support_handshake_done_;
+ // Whether key update is supported by the peer. Uses the
+ // key_update_not_yet_supported transport parameter in IETF QUIC.
+ bool key_update_supported_remotely_;
+
+ // Whether key update is supported locally.
+ bool key_update_supported_locally_;
+
// Alternate server addresses the client could connect to.
// Uses the preferred_address transport parameter in IETF QUIC.
// Note that when QUIC_CRYPTO is in use, only one of the addresses is sent.
diff --git a/quic/core/quic_config_test.cc b/quic/core/quic_config_test.cc
index 4622a6e..73e8d97 100644
--- a/quic/core/quic_config_test.cc
+++ b/quic/core/quic_config_test.cc
@@ -55,6 +55,8 @@
EXPECT_FALSE(config_.HasReceivedInitialMaxStreamDataBytesUnidirectional());
EXPECT_EQ(kMaxIncomingPacketSize, config_.GetMaxPacketSizeToSend());
EXPECT_FALSE(config_.HasReceivedMaxPacketSize());
+ EXPECT_FALSE(config_.KeyUpdateSupportedForConnection());
+ EXPECT_FALSE(config_.KeyUpdateSupportedLocally());
}
TEST_P(QuicConfigTest, AutoSetIetfFlowControl) {
@@ -673,6 +675,79 @@
EXPECT_TRUE(config_.DisableConnectionMigration());
}
+// Neither side supports key update: local support is never enabled and the
+// peer sends key_update_not_yet_supported.
+TEST_P(QuicConfigTest, KeyUpdateNotYetSupportedTransportParameterNorLocally) {
+  if (!version_.UsesTls()) {
+    // TransportParameters are only used for QUIC+TLS.
+    return;
+  }
+  EXPECT_FALSE(config_.KeyUpdateSupportedForConnection());
+  EXPECT_FALSE(config_.KeyUpdateSupportedLocally());
+
+  std::string error_details;
+  TransportParameters peer_params;
+  peer_params.key_update_not_yet_supported = true;
+  EXPECT_THAT(config_.ProcessTransportParameters(
+                  peer_params, /* is_resumption = */ false, &error_details),
+              IsQuicNoError());
+
+  EXPECT_FALSE(config_.KeyUpdateSupportedForConnection());
+  EXPECT_FALSE(config_.KeyUpdateSupportedLocally());
+}
+
+// Supported locally, but the peer sends key_update_not_yet_supported, so key
+// update stays disabled for the connection.
+TEST_P(QuicConfigTest, KeyUpdateNotYetSupportedTransportParameter) {
+  if (!version_.UsesTls()) {
+    // TransportParameters are only used for QUIC+TLS.
+    return;
+  }
+  config_.SetKeyUpdateSupportedLocally();
+  EXPECT_FALSE(config_.KeyUpdateSupportedForConnection());
+  EXPECT_TRUE(config_.KeyUpdateSupportedLocally());
+
+  std::string error_details;
+  TransportParameters peer_params;
+  peer_params.key_update_not_yet_supported = true;
+  EXPECT_THAT(config_.ProcessTransportParameters(
+                  peer_params, /* is_resumption = */ false, &error_details),
+              IsQuicNoError());
+
+  EXPECT_FALSE(config_.KeyUpdateSupportedForConnection());
+  EXPECT_TRUE(config_.KeyUpdateSupportedLocally());
+}
+
+// The peer supports key update, but local support is never enabled, so key
+// update stays disabled for the connection.
+TEST_P(QuicConfigTest, KeyUpdateSupportedRemotelyButNotLocally) {
+  if (!version_.UsesTls()) {
+    // TransportParameters are only used for QUIC+TLS.
+    return;
+  }
+  EXPECT_FALSE(config_.KeyUpdateSupportedLocally());
+  EXPECT_FALSE(config_.KeyUpdateSupportedForConnection());
+
+  std::string error_details;
+  TransportParameters peer_params;
+  peer_params.key_update_not_yet_supported = false;
+  EXPECT_THAT(config_.ProcessTransportParameters(
+                  peer_params, /* is_resumption = */ false, &error_details),
+              IsQuicNoError());
+
+  EXPECT_FALSE(config_.KeyUpdateSupportedForConnection());
+  EXPECT_FALSE(config_.KeyUpdateSupportedLocally());
+}
+
+// Both sides support key update, so it is enabled for the connection once the
+// peer's transport parameters are processed.
+TEST_P(QuicConfigTest, KeyUpdateSupported) {
+  if (!version_.UsesTls()) {
+    // TransportParameters are only used for QUIC+TLS.
+    return;
+  }
+  config_.SetKeyUpdateSupportedLocally();
+  EXPECT_TRUE(config_.KeyUpdateSupportedLocally());
+  EXPECT_FALSE(config_.KeyUpdateSupportedForConnection());
+
+  std::string error_details;
+  TransportParameters peer_params;
+  peer_params.key_update_not_yet_supported = false;
+  EXPECT_THAT(config_.ProcessTransportParameters(
+                  peer_params, /* is_resumption = */ false, &error_details),
+              IsQuicNoError());
+
+  EXPECT_TRUE(config_.KeyUpdateSupportedForConnection());
+  EXPECT_TRUE(config_.KeyUpdateSupportedLocally());
+}
+
} // namespace
} // namespace test
} // namespace quic
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc
index 00bfa6b..77fd2c5 100644
--- a/quic/core/quic_connection.cc
+++ b/quic/core/quic_connection.cc
@@ -170,6 +170,24 @@
QuicConnection* connection_;
};
+// Delegate for discard_previous_one_rtt_keys_alarm_: when the alarm fires it
+// asks the connection to discard the keys of the previous key phase.
+class DiscardPreviousOneRttKeysAlarmDelegate : public QuicAlarm::Delegate {
+ public:
+  explicit DiscardPreviousOneRttKeysAlarmDelegate(QuicConnection* connection)
+      : connection_(connection) {}
+  DiscardPreviousOneRttKeysAlarmDelegate(
+      const DiscardPreviousOneRttKeysAlarmDelegate&) = delete;
+  DiscardPreviousOneRttKeysAlarmDelegate& operator=(
+      const DiscardPreviousOneRttKeysAlarmDelegate&) = delete;
+
+  void OnAlarm() override {
+    // The alarm is cancelled together with the connection's other alarms, so
+    // it is only expected to fire while the connection is still connected.
+    DCHECK(connection_->connected());
+    connection_->DiscardPreviousOneRttKeys();
+  }
+
+ private:
+  QuicConnection* const connection_;
+};
+
// When the clearer goes out of scope, the coalesced packet gets cleared.
class ScopedCoalescedPacketClearer {
public:
@@ -245,6 +263,7 @@
peer_address_(initial_peer_address),
direct_peer_address_(initial_peer_address),
active_effective_peer_migration_type_(NO_CHANGE),
+ support_key_update_for_connection_(false),
last_packet_decrypted_(false),
last_size_(0),
current_packet_data_(nullptr),
@@ -280,6 +299,9 @@
process_undecryptable_packets_alarm_(alarm_factory_->CreateAlarm(
arena_.New<ProcessUndecryptablePacketsAlarmDelegate>(this),
&arena_)),
+ discard_previous_one_rtt_keys_alarm_(alarm_factory_->CreateAlarm(
+ arena_.New<DiscardPreviousOneRttKeysAlarmDelegate>(this),
+ &arena_)),
visitor_(nullptr),
debug_visitor_(nullptr),
packet_creator_(server_connection_id_, &framer_, random_generator_, this),
@@ -554,6 +576,10 @@
if (!ValidateConfigConnectionIds(config)) {
return;
}
+ support_key_update_for_connection_ =
+ config.KeyUpdateSupportedForConnection();
+ framer_.SetKeyUpdateSupportForConnection(
+ support_key_update_for_connection_);
} else {
SetNetworkTimeouts(config.max_time_before_crypto_handshake(),
config.max_idle_time_before_crypto_handshake());
@@ -1889,6 +1915,43 @@
ConnectionCloseSource::FROM_PEER);
}
+// Called when the key phase changes (locally or peer initiated).
+void QuicConnection::OnKeyUpdate() {
+  DCHECK(support_key_update_for_connection_);
+  QUIC_DLOG(INFO) << ENDPOINT << "Key phase updated";
+
+  // A new key phase begins: the next 1-RTT packet sent becomes the lowest
+  // packet of this phase, so forget the previous phase's value.
+  lowest_packet_sent_in_current_key_phase_.Clear();
+  ++stats_.key_update_count;
+
+  // If another key update happens before the previous
+  // discard_previous_one_rtt_keys_alarm_ has fired, cancel it, since the old
+  // keys would already be discarded.
+  discard_previous_one_rtt_keys_alarm_->Cancel();
+}
+
+// Called after the first packet of a key phase has been decrypted; schedules
+// discarding of the previous phase's keys.
+void QuicConnection::OnDecryptedFirstPacketInKeyPhase() {
+  QUIC_DLOG(INFO) << ENDPOINT << "OnDecryptedFirstPacketInKeyPhase";
+  // An endpoint SHOULD retain old read keys for no more than three times the
+  // PTO after having received a packet protected using the new keys. After
+  // this period, old read keys and their corresponding secrets SHOULD be
+  // discarded.
+  //
+  // Note that this will cause an unnecessary
+  // discard_previous_one_rtt_keys_alarm_ on the first packet in the 1RTT
+  // encryption level, but this is harmless.
+  //
+  // NOTE(review): confirm QuicAlarm::Set() tolerates an already-set alarm in
+  // case this runs twice without an intervening OnKeyUpdate(); otherwise
+  // Update() would be the safer call.
+  const auto retention_period = sent_packet_manager_.GetPtoDelay() * 3;
+  discard_previous_one_rtt_keys_alarm_->Set(clock_->ApproximateNow() +
+                                            retention_period);
+}
+
+// Delegates creation of the next key phase's decrypter to the connection
+// visitor.
+std::unique_ptr<QuicDecrypter>
+QuicConnection::AdvanceKeysAndCreateCurrentOneRttDecrypter() {
+  QUIC_DLOG(INFO) << ENDPOINT << "AdvanceKeysAndCreateCurrentOneRttDecrypter";
+  std::unique_ptr<QuicDecrypter> next_decrypter =
+      visitor_->AdvanceKeysAndCreateCurrentOneRttDecrypter();
+  return next_decrypter;
+}
+
+// Delegates creation of the encrypter for the current key phase to the
+// connection visitor.
+std::unique_ptr<QuicEncrypter> QuicConnection::CreateCurrentOneRttEncrypter() {
+  QUIC_DLOG(INFO) << ENDPOINT << "CreateCurrentOneRttEncrypter";
+  std::unique_ptr<QuicEncrypter> current_encrypter =
+      visitor_->CreateCurrentOneRttEncrypter();
+  return current_encrypter;
+}
+
void QuicConnection::ClearLastFrames() {
should_last_packet_instigate_acks_ = false;
}
@@ -2984,6 +3047,13 @@
handshake_packet_sent_ = true;
}
+ if (packet->encryption_level == ENCRYPTION_FORWARD_SECURE &&
+ !lowest_packet_sent_in_current_key_phase_.IsInitialized()) {
+ QUIC_DLOG(INFO) << ENDPOINT << "lowest_packet_sent_in_current_key_phase_ = "
+ << packet_number;
+ lowest_packet_sent_in_current_key_phase_ = packet_number;
+ }
+
if (in_flight || !retransmission_alarm_->IsSet()) {
SetRetransmissionAlarm();
}
@@ -3451,6 +3521,26 @@
framer_.RemoveDecrypter(level);
}
+// Invoked by DiscardPreviousOneRttKeysAlarmDelegate when
+// discard_previous_one_rtt_keys_alarm_ fires; forwards to the framer.
+void QuicConnection::DiscardPreviousOneRttKeys() {
+  QuicFramer& framer = framer_;
+  framer.DiscardPreviousOneRttKeys();
+}
+
+// A key update may be initiated only if key update was negotiated and the
+// peer has acknowledged a packet sent in the current key phase, i.e. the
+// largest acked packet is at or beyond the first packet of this phase.
+bool QuicConnection::IsKeyUpdateAllowed() const {
+  if (!support_key_update_for_connection_) {
+    return false;
+  }
+  // NOTE(review): confirm GetLargestAckedPacket() reports the 1-RTT packet
+  // number space here; lowest_packet_sent_in_current_key_phase_ is always a
+  // 1-RTT (ENCRYPTION_FORWARD_SECURE) packet number.
+  const auto largest_acked = GetLargestAckedPacket();
+  if (!largest_acked.IsInitialized() ||
+      !lowest_packet_sent_in_current_key_phase_.IsInitialized()) {
+    return false;
+  }
+  return largest_acked >= lowest_packet_sent_in_current_key_phase_;
+}
+
+// Starts a locally initiated key update. Callers must check
+// IsKeyUpdateAllowed() first; violating that is reported as a bug and false is
+// returned. Otherwise returns the framer's result.
+bool QuicConnection::InitiateKeyUpdate() {
+  QUIC_DLOG(INFO) << ENDPOINT << "InitiateKeyUpdate";
+  if (IsKeyUpdateAllowed()) {
+    return framer_.DoKeyUpdate();
+  }
+  QUIC_BUG << "key update not allowed";
+  return false;
+}
+
const QuicDecrypter* QuicConnection::decrypter() const {
return framer_.decrypter();
}
@@ -3735,6 +3825,7 @@
send_alarm_->Cancel();
mtu_discovery_alarm_->Cancel();
process_undecryptable_packets_alarm_->Cancel();
+ discard_previous_one_rtt_keys_alarm_->Cancel();
blackhole_detector_.StopDetection();
idle_network_detector_.StopDetection();
}
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h
index 240138a..f6496c7 100644
--- a/quic/core/quic_connection.h
+++ b/quic/core/quic_connection.h
@@ -191,6 +191,15 @@
// Called when a packet of ENCRYPTION_HANDSHAKE gets sent.
virtual void OnHandshakePacketSent() = 0;
+
+ // Called to generate a decrypter for the next key phase. Each call should
+ // generate the key for phase n+1.
+ virtual std::unique_ptr<QuicDecrypter>
+ AdvanceKeysAndCreateCurrentOneRttDecrypter() = 0;
+
+ // Called to generate an encrypter for the same key phase of the last
+ // decrypter returned by AdvanceKeysAndCreateCurrentOneRttDecrypter().
+ virtual std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() = 0;
};
// Interface which gets callbacks from the QuicConnection at interesting
@@ -628,6 +637,11 @@
bool IsValidStatelessResetToken(QuicUint128 token) const override;
void OnAuthenticatedIetfStatelessResetPacket(
const QuicIetfStatelessResetPacket& packet) override;
+ void OnKeyUpdate() override;
+ void OnDecryptedFirstPacketInKeyPhase() override;
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+ override;
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override;
// QuicPacketCreator::DelegateInterface
bool ShouldGeneratePacket(HasRetransmittableData retransmittable,
@@ -788,6 +802,16 @@
std::unique_ptr<QuicDecrypter> decrypter);
void RemoveDecrypter(EncryptionLevel level);
+ // Discard keys for the previous key phase.
+ void DiscardPreviousOneRttKeys();
+
+ // Returns true if it is currently allowed to initiate a key update.
+ bool IsKeyUpdateAllowed() const;
+
+ // Increment the key phase. It is a bug to call this when IsKeyUpdateAllowed()
+ // is false. Returns false on error.
+ bool InitiateKeyUpdate();
+
const QuicDecrypter* decrypter() const;
const QuicDecrypter* alternative_decrypter() const;
@@ -1493,6 +1517,14 @@
// started.
QuicPacketNumber highest_packet_sent_before_effective_peer_migration_;
+ // True if Key Update is supported on this connection.
+ bool support_key_update_for_connection_;
+
+ // Tracks the lowest packet sent in the current key phase. Will be
+ // uninitialized before the first one-RTT packet has been sent, and after a
+ // key update until the first one-RTT packet of the new key phase is sent.
+ QuicPacketNumber lowest_packet_sent_in_current_key_phase_;
+
// True if the last packet has gotten far enough in the framer to be
// decrypted.
bool last_packet_decrypted_;
@@ -1587,6 +1619,9 @@
// An alarm that fires to process undecryptable packets when new decyrption
// keys are available.
QuicArenaScopedPtr<QuicAlarm> process_undecryptable_packets_alarm_;
+ // An alarm that fires to discard keys for the previous key phase some time
+ // after a key update has completed.
+ QuicArenaScopedPtr<QuicAlarm> discard_previous_one_rtt_keys_alarm_;
// Neither visitor is owned by this class.
QuicConnectionVisitorInterface* visitor_;
QuicConnectionDebugVisitor* debug_visitor_;
diff --git a/quic/core/quic_connection_stats.cc b/quic/core/quic_connection_stats.cc
index 191918c..0cc6a73 100644
--- a/quic/core/quic_connection_stats.cc
+++ b/quic/core/quic_connection_stats.cc
@@ -55,6 +55,7 @@
os << " num_ack_aggregation_epochs: " << s.num_ack_aggregation_epochs;
os << " sent_legacy_version_encapsulated_packets: "
<< s.sent_legacy_version_encapsulated_packets;
+ os << " key_update_count: " << s.key_update_count;
os << " }";
return os;
diff --git a/quic/core/quic_connection_stats.h b/quic/core/quic_connection_stats.h
index 5a568e1..87158b3 100644
--- a/quic/core/quic_connection_stats.h
+++ b/quic/core/quic_connection_stats.h
@@ -169,6 +169,11 @@
// Number of times when the connection tries to send data but gets throttled
// by amplification factor.
size_t num_amplification_throttling = 0;
+
+ // Number of key phase updates that have occurred. In the case of a locally
+ // initiated key update, this is incremented when the keys are updated, before
+ // the peer has acknowledged the key update.
+ uint32_t key_update_count = 0;
};
} // namespace quic
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc
index f480f88..e40fa67 100644
--- a/quic/core/quic_connection_test.cc
+++ b/quic/core/quic_connection_test.cc
@@ -421,6 +421,11 @@
QuicConnectionPeer::GetProcessUndecryptablePacketsAlarm(this));
}
+  // Test accessor for the connection's discard-previous-1-RTT-keys alarm,
+  // viewed as a TestAlarm in the same way as the neighbouring accessors.
+  TestAlarmFactory::TestAlarm* GetDiscardPreviousOneRttKeysAlarm() {
+    auto* alarm = QuicConnectionPeer::GetDiscardPreviousOneRttKeysAlarm(this);
+    return reinterpret_cast<TestAlarmFactory::TestAlarm*>(alarm);
+  }
+
TestAlarmFactory::TestAlarm* GetBlackholeDetectorAlarm() {
return reinterpret_cast<TestAlarmFactory::TestAlarm*>(
QuicConnectionPeer::GetBlackholeDetectorAlarm(this));
@@ -11801,6 +11806,150 @@
ProcessCoalescedPacket({{4, frames4, ENCRYPTION_FORWARD_SECURE}});
}
+// Exercises QuicConnection::InitiateKeyUpdate(). A key update is only allowed
+// once a packet sent in the current key phase has been acked.
+// discard_previous_one_rtt_keys_alarm_ is armed when an ack arrives in the new
+// key phase and is cleared by the next key update.
+TEST_P(QuicConnectionTest, InitiateKeyUpdate) {
+ if (!connection_.version().UsesTls()) {
+ return;
+ }
+
+ // Negotiate key update: the peer's transport parameters allow it and it is
+ // also enabled locally.
+ TransportParameters params;
+ params.key_update_not_yet_supported = false;
+ QuicConfig config;
+ std::string error_details;
+ EXPECT_THAT(config.ProcessTransportParameters(
+ params, /* is_resumption = */ false, &error_details),
+ IsQuicNoError());
+ config.SetKeyUpdateSupportedLocally();
+ QuicConfigPeer::SetNegotiated(&config, true);
+ if (connection_.version().AuthenticatesHandshakeConnectionIds()) {
+ QuicConfigPeer::SetReceivedOriginalConnectionId(
+ &config, connection_.connection_id());
+ QuicConfigPeer::SetReceivedInitialSourceConnectionId(
+ &config, connection_.connection_id());
+ }
+ EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _));
+ connection_.SetFromConfig(config);
+
+ // Not allowed yet: no 1-RTT packet has been sent.
+ EXPECT_FALSE(connection_.IsKeyUpdateAllowed());
+
+ // NOTE(review): peer_framer_ (a fixture member) is given a pointer to this
+ // stack-local visitor; confirm it is not used after the test body returns.
+ MockFramerVisitor peer_framer_visitor_;
+ peer_framer_.set_visitor(&peer_framer_visitor_);
+
+ use_tagging_decrypter();
+
+ // Key phase 0 uses tag 0x01 on both sides.
+ connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
+ connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE,
+ std::make_unique<TaggingEncrypter>(0x01));
+ SetDecrypter(ENCRYPTION_FORWARD_SECURE,
+ std::make_unique<StrictTaggingDecrypter>(0x01));
+ connection_.OnHandshakeComplete();
+
+ peer_framer_.SetEncrypter(ENCRYPTION_FORWARD_SECURE,
+ std::make_unique<TaggingEncrypter>(0x01));
+
+ // Key update should still not be allowed, since no packet has been acked
+ // from the current key phase.
+ EXPECT_FALSE(connection_.IsKeyUpdateAllowed());
+
+ // Send packet 1.
+ QuicPacketNumber last_packet;
+ SendStreamDataToPeer(1, "foo", 0, NO_FIN, &last_packet);
+ EXPECT_EQ(QuicPacketNumber(1u), last_packet);
+
+ // Key update should still not be allowed: a packet was sent in the current
+ // key phase, but it hasn't been acked yet.
+ EXPECT_FALSE(connection_.IsKeyUpdateAllowed());
+
+ EXPECT_FALSE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet());
+ // Receive ack for packet 1.
+ EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _));
+ QuicAckFrame frame1 = InitAckFrame(1);
+ ProcessAckPacket(&frame1);
+
+ // OnDecryptedFirstPacketInKeyPhase is called even on the first key phase,
+ // so discard_previous_one_rtt_keys_alarm_ should be set now.
+ EXPECT_TRUE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet());
+
+ // Key update should now be allowed. Move to key phase 1 (tag 0x02).
+ EXPECT_CALL(visitor_, AdvanceKeysAndCreateCurrentOneRttDecrypter())
+ .WillOnce(
+ []() { return std::make_unique<StrictTaggingDecrypter>(0x02); });
+ EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()).WillOnce([]() {
+ return std::make_unique<TaggingEncrypter>(0x02);
+ });
+ EXPECT_TRUE(connection_.InitiateKeyUpdate());
+ // discard_previous_one_rtt_keys_alarm_ should not be set until a packet from
+ // the new key phase has been received. (The alarm that was set above should
+ // be cleared if it hasn't fired before the next key update happened.)
+ EXPECT_FALSE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet());
+
+ // Pretend that peer accepts the key update.
+ EXPECT_CALL(peer_framer_visitor_,
+ AdvanceKeysAndCreateCurrentOneRttDecrypter())
+ .WillOnce(
+ []() { return std::make_unique<StrictTaggingDecrypter>(0x02); });
+ EXPECT_CALL(peer_framer_visitor_, CreateCurrentOneRttEncrypter())
+ .WillOnce([]() { return std::make_unique<TaggingEncrypter>(0x02); });
+ // Enable key update on the peer framer before asking it to update.
+ peer_framer_.SetKeyUpdateSupportForConnection(true);
+ peer_framer_.DoKeyUpdate();
+
+ // Another key update should not be allowed yet.
+ EXPECT_FALSE(connection_.IsKeyUpdateAllowed());
+
+ // Send packet 2.
+ SendStreamDataToPeer(2, "bar", 0, NO_FIN, &last_packet);
+ EXPECT_EQ(QuicPacketNumber(2u), last_packet);
+ // Receive ack for packet 2.
+ EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _));
+ QuicAckFrame frame2 = InitAckFrame(2);
+ ProcessAckPacket(&frame2);
+ // The ack arrives in the new key phase, so the discard alarm is armed again.
+ EXPECT_TRUE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet());
+
+ // Key update should be allowed again now that a packet has been acked from
+ // the current key phase. Move to key phase 2 (tag 0x03).
+ EXPECT_CALL(visitor_, AdvanceKeysAndCreateCurrentOneRttDecrypter())
+ .WillOnce(
+ []() { return std::make_unique<StrictTaggingDecrypter>(0x03); });
+ EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()).WillOnce([]() {
+ return std::make_unique<TaggingEncrypter>(0x03);
+ });
+ EXPECT_TRUE(connection_.InitiateKeyUpdate());
+
+ // Pretend that peer accepts the key update.
+ EXPECT_CALL(peer_framer_visitor_,
+ AdvanceKeysAndCreateCurrentOneRttDecrypter())
+ .WillOnce(
+ []() { return std::make_unique<StrictTaggingDecrypter>(0x03); });
+ EXPECT_CALL(peer_framer_visitor_, CreateCurrentOneRttEncrypter())
+ .WillOnce([]() { return std::make_unique<TaggingEncrypter>(0x03); });
+ peer_framer_.DoKeyUpdate();
+
+ // Another key update should not be allowed yet.
+ EXPECT_FALSE(connection_.IsKeyUpdateAllowed());
+
+ // Send packet 3.
+ SendStreamDataToPeer(3, "baz", 0, NO_FIN, &last_packet);
+ EXPECT_EQ(QuicPacketNumber(3u), last_packet);
+
+ // Another key update should not be allowed yet.
+ EXPECT_FALSE(connection_.IsKeyUpdateAllowed());
+
+ // Receive ack for packet 3.
+ EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _));
+ QuicAckFrame frame3 = InitAckFrame(3);
+ ProcessAckPacket(&frame3);
+ EXPECT_TRUE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet());
+
+ // Key update should be allowed now. Move to key phase 3 (tag 0x04), which
+ // clears the pending discard alarm.
+ EXPECT_CALL(visitor_, AdvanceKeysAndCreateCurrentOneRttDecrypter())
+ .WillOnce(
+ []() { return std::make_unique<StrictTaggingDecrypter>(0x04); });
+ EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()).WillOnce([]() {
+ return std::make_unique<TaggingEncrypter>(0x04);
+ });
+ EXPECT_TRUE(connection_.InitiateKeyUpdate());
+ EXPECT_FALSE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet());
+}
+
} // namespace
} // namespace test
} // namespace quic
diff --git a/quic/core/quic_crypto_client_handshaker.cc b/quic/core/quic_crypto_client_handshaker.cc
index b60b026..784a451 100644
--- a/quic/core/quic_crypto_client_handshaker.cc
+++ b/quic/core/quic_crypto_client_handshaker.cc
@@ -175,6 +175,24 @@
return QuicCryptoHandshaker::BufferSizeLimitForLevel(level);
}
+bool QuicCryptoClientHandshaker::KeyUpdateSupportedLocally() const {
+ // QUIC crypto (non-TLS) handshakes have no key update mechanism.
+ return false;
+}
+
+std::unique_ptr<QuicDecrypter>
+QuicCryptoClientHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter() {
+ // Not expected to be called: key update exists only in QUIC+TLS.
+ DCHECK(false);
+ return nullptr;
+}
+
+std::unique_ptr<QuicEncrypter>
+QuicCryptoClientHandshaker::CreateCurrentOneRttEncrypter() {
+ // Not expected to be called: key update exists only in QUIC+TLS.
+ DCHECK(false);
+ return nullptr;
+}
+
void QuicCryptoClientHandshaker::OnConnectionClosed(
QuicErrorCode /*error*/,
ConnectionCloseSource /*source*/) {
diff --git a/quic/core/quic_crypto_client_handshaker.h b/quic/core/quic_crypto_client_handshaker.h
index 605318e..49fa405 100644
--- a/quic/core/quic_crypto_client_handshaker.h
+++ b/quic/core/quic_crypto_client_handshaker.h
@@ -51,6 +51,10 @@
CryptoMessageParser* crypto_message_parser() override;
HandshakeState GetHandshakeState() const override;
size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
+ bool KeyUpdateSupportedLocally() const override;
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+ override;
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override;
void OnOneRttPacketAcknowledged() override {}
void OnHandshakePacketSent() override {}
void OnConnectionClosed(QuicErrorCode /*error*/,
diff --git a/quic/core/quic_crypto_client_stream.cc b/quic/core/quic_crypto_client_stream.cc
index 67a9a11..94a8948 100644
--- a/quic/core/quic_crypto_client_stream.cc
+++ b/quic/core/quic_crypto_client_stream.cc
@@ -109,6 +109,20 @@
return handshaker_->BufferSizeLimitForLevel(level);
}
+// The key-update hooks below do no work of their own; each one forwards to
+// the handshaker currently owned by this stream.
+bool QuicCryptoClientStream::KeyUpdateSupportedLocally() const {
+ return handshaker_->KeyUpdateSupportedLocally();
+}
+
+std::unique_ptr<QuicDecrypter>
+QuicCryptoClientStream::AdvanceKeysAndCreateCurrentOneRttDecrypter() {
+ return handshaker_->AdvanceKeysAndCreateCurrentOneRttDecrypter();
+}
+
+std::unique_ptr<QuicEncrypter>
+QuicCryptoClientStream::CreateCurrentOneRttEncrypter() {
+ return handshaker_->CreateCurrentOneRttEncrypter();
+}
+
std::string QuicCryptoClientStream::chlo_hash() const {
return handshaker_->chlo_hash();
}
diff --git a/quic/core/quic_crypto_client_stream.h b/quic/core/quic_crypto_client_stream.h
index 1d9b04b..123bdfe 100644
--- a/quic/core/quic_crypto_client_stream.h
+++ b/quic/core/quic_crypto_client_stream.h
@@ -151,6 +151,18 @@
// buffered at each encryption level.
virtual size_t BufferSizeLimitForLevel(EncryptionLevel level) const = 0;
+ // Returns whether the implementation supports key update.
+ virtual bool KeyUpdateSupportedLocally() const = 0;
+
+ // Called to generate a decrypter for the next key phase. Each call should
+ // generate the key for phase n+1.
+ virtual std::unique_ptr<QuicDecrypter>
+ AdvanceKeysAndCreateCurrentOneRttDecrypter() = 0;
+
+ // Called to generate an encrypter for the same key phase of the last
+ // decrypter returned by AdvanceKeysAndCreateCurrentOneRttDecrypter().
+ virtual std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() = 0;
+
// Returns current handshake state.
virtual HandshakeState GetHandshakeState() const = 0;
@@ -228,6 +240,10 @@
void SetServerApplicationStateForResumption(
std::unique_ptr<ApplicationState> application_state) override;
size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
+ bool KeyUpdateSupportedLocally() const override;
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+ override;
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override;
std::string chlo_hash() const;
diff --git a/quic/core/quic_crypto_server_stream.cc b/quic/core/quic_crypto_server_stream.cc
index e8f99eb..9d7597c 100644
--- a/quic/core/quic_crypto_server_stream.cc
+++ b/quic/core/quic_crypto_server_stream.cc
@@ -402,6 +402,24 @@
return QuicCryptoHandshaker::BufferSizeLimitForLevel(level);
}
+bool QuicCryptoServerStream::KeyUpdateSupportedLocally() const {
+ // QUIC crypto (non-TLS) handshakes have no key update mechanism.
+ return false;
+}
+
+std::unique_ptr<QuicDecrypter>
+QuicCryptoServerStream::AdvanceKeysAndCreateCurrentOneRttDecrypter() {
+ // Not expected to be called: key update exists only in QUIC+TLS.
+ DCHECK(false);
+ return nullptr;
+}
+
+std::unique_ptr<QuicEncrypter>
+QuicCryptoServerStream::CreateCurrentOneRttEncrypter() {
+ // Not expected to be called: key update exists only in QUIC+TLS.
+ DCHECK(false);
+ return nullptr;
+}
+
void QuicCryptoServerStream::ProcessClientHello(
QuicReferenceCountedPointer<ValidateClientHelloResultCallback::Result>
result,
diff --git a/quic/core/quic_crypto_server_stream.h b/quic/core/quic_crypto_server_stream.h
index 29d680e..c99a136 100644
--- a/quic/core/quic_crypto_server_stream.h
+++ b/quic/core/quic_crypto_server_stream.h
@@ -61,6 +61,10 @@
void SetServerApplicationStateForResumption(
std::unique_ptr<ApplicationState> state) override;
size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
+ bool KeyUpdateSupportedLocally() const override;
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+ override;
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override;
// From QuicCryptoHandshaker
void OnHandshakeMessage(const CryptoHandshakeMessage& message) override;
diff --git a/quic/core/quic_crypto_stream.h b/quic/core/quic_crypto_stream.h
index 21bdde6..ea15ded 100644
--- a/quic/core/quic_crypto_stream.h
+++ b/quic/core/quic_crypto_stream.h
@@ -126,6 +126,18 @@
// encryption level |level|.
virtual size_t BufferSizeLimitForLevel(EncryptionLevel level) const;
+ // Returns whether the implementation supports key update.
+ virtual bool KeyUpdateSupportedLocally() const = 0;
+
+ // Called to generate a decrypter for the next key phase. Each call should
+ // generate the key for phase n+1.
+ virtual std::unique_ptr<QuicDecrypter>
+ AdvanceKeysAndCreateCurrentOneRttDecrypter() = 0;
+
+ // Called to generate an encrypter for the same key phase of the last
+ // decrypter returned by AdvanceKeysAndCreateCurrentOneRttDecrypter().
+ virtual std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() = 0;
+
// Called to cancel retransmission of unencrypted crypto stream data.
void NeuterUnencryptedStreamData();
diff --git a/quic/core/quic_crypto_stream_test.cc b/quic/core/quic_crypto_stream_test.cc
index f7e2483..4fc9ac3 100644
--- a/quic/core/quic_crypto_stream_test.cc
+++ b/quic/core/quic_crypto_stream_test.cc
@@ -66,6 +66,14 @@
HandshakeState GetHandshakeState() const override { return HANDSHAKE_START; }
void SetServerApplicationStateForResumption(
std::unique_ptr<ApplicationState> /*application_state*/) override {}
+ // This test stream does not support key update: it reports no support and
+ // hands out no crypters.
+ bool KeyUpdateSupportedLocally() const override { return false; }
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+ override {
+ return nullptr;
+ }
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override {
+ return nullptr;
+ }
private:
QuicReferenceCountedPointer<QuicCryptoNegotiatedParameters> params_;
diff --git a/quic/core/quic_framer.cc b/quic/core/quic_framer.cc
index 6b9f0c1..6842211 100644
--- a/quic/core/quic_framer.cc
+++ b/quic/core/quic_framer.cc
@@ -412,6 +412,8 @@
process_timestamps_(false),
creation_time_(creation_time),
last_timestamp_(QuicTime::Delta::Zero()),
+ support_key_update_for_connection_(false),
+ current_key_phase_bit_(false),
first_sending_packet_number_(FirstSendingPacketNumber()),
data_producer_(nullptr),
infer_packet_header_type_from_version_(perspective ==
@@ -2136,7 +2138,7 @@
PacketNumberLengthToOnWireValue(header.packet_number_length));
} else {
type = static_cast<uint8_t>(
- FLAGS_FIXED_BIT |
+ FLAGS_FIXED_BIT | (current_key_phase_bit_ ? FLAGS_KEY_PHASE_BIT : 0) |
PacketNumberLengthToOnWireValue(header.packet_number_length));
}
return writer->WriteUInt8(type);
@@ -4235,6 +4237,41 @@
decrypter_[level] = nullptr;
}
+// Records whether 1-RTT key updates are handled by this framer. While
+// disabled, the key phase bit of received short-header packets is not
+// inspected.
+void QuicFramer::SetKeyUpdateSupportForConnection(bool enabled) {
+ support_key_update_for_connection_ = enabled;
+ QUIC_DVLOG(1) << ENDPOINT << "SetKeyUpdateSupportForConnection: "
+ << support_key_update_for_connection_;
+}
+
+// Drops the decrypter of the previous key phase. Afterwards, received packets
+// that still carry the old key phase are dropped instead of decrypted.
+void QuicFramer::DiscardPreviousOneRttKeys() {
+ DCHECK(support_key_update_for_connection_);
+ QUIC_DVLOG(1) << ENDPOINT << "Discarding previous set of 1-RTT keys";
+ previous_decrypter_.reset();
+}
+
+// Switches the framer to the next 1-RTT key phase. The current decrypter
+// becomes |previous_decrypter_|, |next_decrypter_| (derived through the
+// visitor if it does not exist yet) becomes current, a new encrypter from the
+// visitor is installed, and the key phase bit written into outgoing
+// short-header packets is flipped. Returns false, without switching phases, if
+// the visitor fails to supply either crypter.
+bool QuicFramer::DoKeyUpdate() {
+ DCHECK(support_key_update_for_connection_);
+ if (!next_decrypter_) {
+ // If key update is locally initiated, next decrypter might not be created
+ // yet. (For a peer-initiated update, the packet-decryption path has
+ // already created it while trying to decrypt the peer's packet.)
+ next_decrypter_ = visitor_->AdvanceKeysAndCreateCurrentOneRttDecrypter();
+ }
+ // Requested after the decrypter so that it matches the key phase of the
+ // last decrypter the visitor returned.
+ std::unique_ptr<QuicEncrypter> next_encrypter =
+ visitor_->CreateCurrentOneRttEncrypter();
+ if (!next_decrypter_ || !next_encrypter) {
+ QUIC_BUG << "Failed to create next crypters";
+ return false;
+ }
+ current_key_phase_bit_ = !current_key_phase_bit_;
+ QUIC_DLOG(INFO) << ENDPOINT << "DoKeyUpdate: new current_key_phase_bit_="
+ << current_key_phase_bit_;
+ // No packet in the new phase has been received yet. The number is recorded
+ // once a received packet with the new phase decrypts successfully.
+ current_key_phase_first_received_packet_number_.Clear();
+ // Keep the old key so that reordered packets from the previous phase can
+ // still be decrypted until DiscardPreviousOneRttKeys() is called.
+ previous_decrypter_ = std::move(decrypter_[ENCRYPTION_FORWARD_SECURE]);
+ decrypter_[ENCRYPTION_FORWARD_SECURE] = std::move(next_decrypter_);
+ encrypter_[ENCRYPTION_FORWARD_SECURE] = std::move(next_encrypter);
+ visitor_->OnKeyUpdate();
+ return true;
+}
+
const QuicDecrypter* QuicFramer::GetDecrypter(EncryptionLevel level) const {
DCHECK(version_.KnowsWhichDecrypterToUse());
return decrypter_[level].get();
@@ -4616,6 +4653,9 @@
EncryptionLevel level = decrypter_level_;
QuicDecrypter* decrypter = decrypter_[level].get();
QuicDecrypter* alternative_decrypter = nullptr;
+ bool key_phase_parsed = false;
+ bool key_phase;
+ bool attempt_key_update = false;
if (version().KnowsWhichDecrypterToUse()) {
if (header.form == GOOGLE_QUIC_PACKET) {
QUIC_BUG << "Attempted to decrypt GOOGLE_QUIC_PACKET with a version that "
@@ -4635,6 +4675,45 @@
perspective_ == Perspective::IS_CLIENT && header.nonce != nullptr) {
decrypter->SetDiversificationNonce(*header.nonce);
}
+ if (support_key_update_for_connection_ &&
+ header.form == IETF_QUIC_SHORT_HEADER_PACKET) {
+ DCHECK(version().UsesTls());
+ DCHECK_EQ(level, ENCRYPTION_FORWARD_SECURE);
+ key_phase = (header.type_byte & FLAGS_KEY_PHASE_BIT) != 0;
+ key_phase_parsed = true;
+ QUIC_DVLOG(1) << ENDPOINT << "packet " << header.packet_number
+ << " received key_phase=" << key_phase
+ << " current_key_phase_bit_=" << current_key_phase_bit_;
+ if (key_phase != current_key_phase_bit_) {
+ if (current_key_phase_first_received_packet_number_.IsInitialized() &&
+ header.packet_number >
+ current_key_phase_first_received_packet_number_) {
+ if (!next_decrypter_) {
+ next_decrypter_ =
+ visitor_->AdvanceKeysAndCreateCurrentOneRttDecrypter();
+ if (!next_decrypter_) {
+ QUIC_BUG << "Failed to create next_decrypter";
+ return false;
+ }
+ }
+ QUIC_DVLOG(1) << ENDPOINT << "packet " << header.packet_number
+ << " attempt_key_update=true";
+ attempt_key_update = true;
+ decrypter = next_decrypter_.get();
+ } else {
+ if (previous_decrypter_) {
+ QUIC_DVLOG(1) << ENDPOINT
+ << "trying previous_decrypter_ for packet "
+ << header.packet_number;
+ decrypter = previous_decrypter_.get();
+ } else {
+ QUIC_DVLOG(1) << ENDPOINT << "dropping packet "
+ << header.packet_number << " with old key phase";
+ return false;
+ }
+ }
+ }
+ }
} else if (alternative_decrypter_level_ != NUM_ENCRYPTION_LEVELS) {
if (!EncryptionLevelIsValid(alternative_decrypter_level_)) {
QUIC_BUG << "Attempted to decrypt with bad alternative_decrypter_level_";
@@ -4655,6 +4734,27 @@
if (success) {
visitor_->OnDecryptedPacket(level);
*decrypted_level = level;
+ if (attempt_key_update) {
+ if (!DoKeyUpdate()) {
+ set_detailed_error("Key update failed due to internal error");
+ return RaiseError(QUIC_INTERNAL_ERROR);
+ }
+ DCHECK_EQ(current_key_phase_bit_, key_phase);
+ }
+ if (key_phase_parsed &&
+ !current_key_phase_first_received_packet_number_.IsInitialized() &&
+ key_phase == current_key_phase_bit_) {
+ // Set packet number for current key phase if it hasn't been initialized
+ // yet. This is set outside of attempt_key_update since the key update
+ // may have been initiated locally, and in that case we don't know yet
+ // which packet number from the remote side to use until we receive a
+ // packet with that phase.
+ QUIC_DVLOG(1) << ENDPOINT
+ << "current_key_phase_first_received_packet_number_ = "
+ << header.packet_number;
+ current_key_phase_first_received_packet_number_ = header.packet_number;
+ visitor_->OnDecryptedFirstPacketInKeyPhase();
+ }
} else if (alternative_decrypter != nullptr) {
if (header.nonce != nullptr) {
DCHECK_EQ(perspective_, Perspective::IS_CLIENT);
diff --git a/quic/core/quic_framer.h b/quic/core/quic_framer.h
index e57e252..48a7877 100644
--- a/quic/core/quic_framer.h
+++ b/quic/core/quic_framer.h
@@ -237,6 +237,25 @@
// Called when an IETF StreamsBlocked frame has been parsed.
virtual bool OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame) = 0;
+
+ // Called when a Key Phase Update has been initiated. This is called for both
+ // locally and peer initiated key updates. If the key update was locally
+ // initiated, this does not indicate the peer has received the key update yet.
+ virtual void OnKeyUpdate() = 0;
+
+ // Called on the first decrypted packet in each key phase (including the
+ // first key phase.)
+ virtual void OnDecryptedFirstPacketInKeyPhase() = 0;
+
+ // Called when the framer needs to generate a decrypter for the next key
+ // phase. Each call should generate the key for phase n+1.
+ virtual std::unique_ptr<QuicDecrypter>
+ AdvanceKeysAndCreateCurrentOneRttDecrypter() = 0;
+
+ // Called when the framer needs to generate an encrypter. The key corresponds
+ // to the key phase of the last decrypter returned by
+ // AdvanceKeysAndCreateCurrentOneRttDecrypter().
+ virtual std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() = 0;
};
// Class for parsing and constructing QUIC packets. It has a
@@ -519,6 +538,13 @@
std::unique_ptr<QuicDecrypter> decrypter);
void RemoveDecrypter(EncryptionLevel level);
+ // Enables or disables key update support for this connection.
+ void SetKeyUpdateSupportForConnection(bool enabled);
+ // Discards the decrypter for the previous key phase.
+ void DiscardPreviousOneRttKeys();
+ // Switches to the next 1-RTT key phase. Returns false on internal error.
+ bool DoKeyUpdate();
+
const QuicDecrypter* GetDecrypter(EncryptionLevel level) const;
const QuicDecrypter* decrypter() const;
const QuicDecrypter* alternative_decrypter() const;
@@ -1059,6 +1085,23 @@
// The last timestamp received if process_timestamps_ is true.
QuicTime::Delta last_timestamp_;
+ // Whether IETF QUIC Key Update is supported on this connection.
+ bool support_key_update_for_connection_;
+ // The value of the current key phase bit, which is toggled when the keys are
+ // changed.
+ bool current_key_phase_bit_;
+ // Tracks the first packet received in the current key phase. Will be
+ // uninitialized before the first one-RTT packet has been received or after a
+ // locally initiated key update but before the first packet from the peer in
+ // the new key phase is received.
+ QuicPacketNumber current_key_phase_first_received_packet_number_;
+ // Decrypter for the previous key phase. Will be null if in the first key
+ // phase or previous keys have been discarded.
+ std::unique_ptr<QuicDecrypter> previous_decrypter_;
+ // Decrypter for the next key phase. May be null if next keys haven't been
+ // generated yet.
+ std::unique_ptr<QuicDecrypter> next_decrypter_;
+
// If this is a framer of a connection, this is the packet number of first
// sending packet. If this is a framer of a framer of dispatcher, this is the
// packet number of sent packets (for those which have packet number).
diff --git a/quic/core/quic_framer_test.cc b/quic/core/quic_framer_test.cc
index f35613a..77e03eb 100644
--- a/quic/core/quic_framer_test.cc
+++ b/quic/core/quic_framer_test.cc
@@ -185,6 +185,31 @@
std::string ciphertext_;
};
+// Returns |packet| encrypted with a TaggingEncrypter using |tag|, with the key
+// phase bit of the first byte forced to |phase|. Returns nullptr if encryption
+// fails.
+std::unique_ptr<QuicEncryptedPacket> EncryptPacketWithTagAndPhase(
+ const QuicPacket& packet,
+ uint8_t tag,
+ bool phase) {
+ std::string packet_data(packet.AsStringPiece());
+ if (phase) {
+ packet_data[0] |= FLAGS_KEY_PHASE_BIT;
+ } else {
+ packet_data[0] &= ~FLAGS_KEY_PHASE_BIT;
+ }
+
+ TaggingEncrypter crypter(tag);
+ const size_t packet_size = crypter.GetCiphertextSize(packet_data.size());
+ // The unique_ptr owns the buffer until QuicEncryptedPacket takes it over,
+ // so the early return below cannot leak it.
+ auto buffer = std::make_unique<char[]>(packet_size);
+ size_t buf_len = 0;
+ if (!crypter.EncryptPacket(0, quiche::QuicheStringPiece(), packet_data,
+ buffer.get(), &buf_len, packet_size)) {
+ return nullptr;
+ }
+
+ return std::make_unique<QuicEncryptedPacket>(buffer.release(), buf_len,
+ /*owns_buffer=*/true);
+}
+
class TestQuicVisitor : public QuicFramerVisitorInterface {
public:
TestQuicVisitor()
@@ -193,6 +218,9 @@
packet_count_(0),
frame_count_(0),
complete_packets_(0),
+ key_update_count_(0),
+ derive_next_key_count_(0),
+ decrypted_first_packet_in_key_phase_count_(0),
accept_packet_(true),
accept_public_header_(true) {}
@@ -547,6 +575,21 @@
EXPECT_EQ(0u, framer_->current_received_frame_type());
}
+ // Counts key updates reported by the framer.
+ void OnKeyUpdate() override { ++key_update_count_; }
+
+ // Counts the first decrypted packet of each key phase.
+ void OnDecryptedFirstPacketInKeyPhase() override {
+ ++decrypted_first_packet_in_key_phase_count_;
+ }
+
+ // Each derivation bumps the counter, and its new value is the tag of both
+ // the returned decrypter and the encrypter from
+ // CreateCurrentOneRttEncrypter(). The first derived key therefore has tag 1.
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+ override {
+ ++derive_next_key_count_;
+ return std::make_unique<StrictTaggingDecrypter>(derive_next_key_count_);
+ }
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override {
+ return std::make_unique<TaggingEncrypter>(derive_next_key_count_);
+ }
+
void set_framer(QuicFramer* framer) {
framer_ = framer;
transport_version_ = framer->transport_version();
@@ -558,6 +601,9 @@
int packet_count_;
int frame_count_;
int complete_packets_;
+ int key_update_count_;
+ int derive_next_key_count_;
+ int decrypted_first_packet_in_key_phase_count_;
bool accept_packet_;
bool accept_public_header_;
@@ -14759,6 +14805,536 @@
visitor_.ack_frames_[0]->ack_delay_time);
}
+// Peer-initiated key updates: a packet whose key phase bit is flipped and that
+// decrypts with the next key moves the framer to the next key phase. Packets
+// are built with the client perspective and processed with the server one.
+TEST_P(QuicFramerTest, KeyUpdate) {
+ if (!framer_.version().UsesTls()) {
+ // Key update is only used in QUIC+TLS.
+ return;
+ }
+ ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse());
+ // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter
+ // instead of TestDecrypter.
+ framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE,
+ std::make_unique<StrictTaggingDecrypter>(/*key=*/0));
+ framer_.SetKeyUpdateSupportForConnection(true);
+
+ QuicPacketHeader header;
+ header.destination_connection_id = FramerTestConnectionId();
+ header.reset_flag = false;
+ header.version_flag = false;
+ header.packet_number = kPacketNumber;
+
+ QuicFrames frames = {QuicFrame(QuicPaddingFrame())};
+
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ std::unique_ptr<QuicPacket> data(BuildDataPacket(header, frames));
+ ASSERT_TRUE(data != nullptr);
+ std::unique_ptr<QuicEncryptedPacket> encrypted(
+ EncryptPacketWithTagAndPhase(*data, 0, false));
+ ASSERT_TRUE(encrypted);
+
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ // Processed valid packet with phase=0, key=0: no key update.
+ EXPECT_EQ(0, visitor_.key_update_count_);
+ EXPECT_EQ(0, visitor_.derive_next_key_count_);
+ EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ header.packet_number += 1;
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ encrypted = EncryptPacketWithTagAndPhase(*data, 1, true);
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ // Processed valid packet with phase=1, key=1: key update should have
+ // occurred.
+ EXPECT_EQ(1, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ header.packet_number += 1;
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ encrypted = EncryptPacketWithTagAndPhase(*data, 1, true);
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ // Processed another valid packet with phase=1, key=1: no key update.
+ EXPECT_EQ(1, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ // Process another key update: packet with phase=0, key=2.
+ header.packet_number += 1;
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ encrypted = EncryptPacketWithTagAndPhase(*data, 2, false);
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(2, visitor_.key_update_count_);
+ EXPECT_EQ(2, visitor_.derive_next_key_count_);
+ EXPECT_EQ(3, visitor_.decrypted_first_packet_in_key_phase_count_);
+}
+
+// A packet from the previous key phase that arrives after the key update must
+// still decrypt (with the previous phase's key) and must not trigger another
+// key update.
+TEST_P(QuicFramerTest, KeyUpdateOldPacketAfterUpdate) {
+ if (!framer_.version().UsesTls()) {
+ // Key update is only used in QUIC+TLS.
+ return;
+ }
+ ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse());
+ // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter
+ // instead of TestDecrypter.
+ framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE,
+ std::make_unique<StrictTaggingDecrypter>(/*key=*/0));
+ framer_.SetKeyUpdateSupportForConnection(true);
+
+ QuicPacketHeader header;
+ header.destination_connection_id = FramerTestConnectionId();
+ header.reset_flag = false;
+ header.version_flag = false;
+ header.packet_number = kPacketNumber;
+
+ QuicFrames frames = {QuicFrame(QuicPaddingFrame())};
+
+ // Process packet N with phase 0.
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ std::unique_ptr<QuicPacket> data(BuildDataPacket(header, frames));
+ ASSERT_TRUE(data != nullptr);
+ std::unique_ptr<QuicEncryptedPacket> encrypted(
+ EncryptPacketWithTagAndPhase(*data, 0, false));
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(0, visitor_.key_update_count_);
+ EXPECT_EQ(0, visitor_.derive_next_key_count_);
+ EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ // Process packet N+2 with phase 1.
+ header.packet_number = kPacketNumber + 2;
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ encrypted = EncryptPacketWithTagAndPhase(*data, 1, true);
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(1, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ // Process packet N+1 with phase 0. (Receiving packet from previous phase
+ // after packet from new phase was received.)
+ header.packet_number = kPacketNumber + 1;
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ encrypted = EncryptPacketWithTagAndPhase(*data, 0, false);
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ // Packet should decrypt (N+1 precedes N+2, the first packet of the current
+ // phase, so the previous key is used) and key update count should not
+ // change.
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(1, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_);
+}
+
+// Once the previous phase's keys have been discarded, a late packet from that
+// phase must be dropped rather than decrypted or treated as a key update.
+TEST_P(QuicFramerTest, KeyUpdateOldPacketAfterDiscardPreviousOneRttKeys) {
+ if (!framer_.version().UsesTls()) {
+ // Key update is only used in QUIC+TLS.
+ return;
+ }
+ ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse());
+ // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter
+ // instead of TestDecrypter.
+ framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE,
+ std::make_unique<StrictTaggingDecrypter>(/*key=*/0));
+ framer_.SetKeyUpdateSupportForConnection(true);
+
+ QuicPacketHeader header;
+ header.destination_connection_id = FramerTestConnectionId();
+ header.reset_flag = false;
+ header.version_flag = false;
+ header.packet_number = kPacketNumber;
+
+ QuicFrames frames = {QuicFrame(QuicPaddingFrame())};
+
+ // Process packet N with phase 0.
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ std::unique_ptr<QuicPacket> data(BuildDataPacket(header, frames));
+ ASSERT_TRUE(data != nullptr);
+ std::unique_ptr<QuicEncryptedPacket> encrypted(
+ EncryptPacketWithTagAndPhase(*data, 0, false));
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(0, visitor_.key_update_count_);
+ EXPECT_EQ(0, visitor_.derive_next_key_count_);
+ EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ // Process packet N+2 with phase 1.
+ header.packet_number = kPacketNumber + 2;
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ encrypted = EncryptPacketWithTagAndPhase(*data, 1, true);
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(1, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ // Discard keys for previous key phase.
+ framer_.DiscardPreviousOneRttKeys();
+
+ // Process packet N+1 with phase 0. (Receiving packet from previous phase
+ // after packet from new phase was received.)
+ header.packet_number = kPacketNumber + 1;
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ encrypted = EncryptPacketWithTagAndPhase(*data, 0, false);
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ // Packet should not decrypt (no previous key is left) and key update count
+ // should not change.
+ EXPECT_FALSE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(1, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_);
+}
+
+// A packet in the new key phase that arrives after a higher-numbered packet of
+// the same phase must decrypt with the current key and must not trigger
+// another key update.
+TEST_P(QuicFramerTest, KeyUpdatePacketsOutOfOrder) {
+ if (!framer_.version().UsesTls()) {
+ // Key update is only used in QUIC+TLS.
+ return;
+ }
+ ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse());
+ // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter
+ // instead of TestDecrypter.
+ framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE,
+ std::make_unique<StrictTaggingDecrypter>(/*key=*/0));
+ framer_.SetKeyUpdateSupportForConnection(true);
+
+ QuicPacketHeader header;
+ header.destination_connection_id = FramerTestConnectionId();
+ header.reset_flag = false;
+ header.version_flag = false;
+ header.packet_number = kPacketNumber;
+
+ QuicFrames frames = {QuicFrame(QuicPaddingFrame())};
+
+ // Process packet N with phase 0.
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ std::unique_ptr<QuicPacket> data(BuildDataPacket(header, frames));
+ ASSERT_TRUE(data != nullptr);
+ std::unique_ptr<QuicEncryptedPacket> encrypted(
+ EncryptPacketWithTagAndPhase(*data, 0, false));
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(0, visitor_.key_update_count_);
+ EXPECT_EQ(0, visitor_.derive_next_key_count_);
+ EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ // Process packet N+2 with phase 1.
+ header.packet_number = kPacketNumber + 2;
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ encrypted = EncryptPacketWithTagAndPhase(*data, 1, true);
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(1, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ // Process packet N+1 with phase 1. (Receiving packet from new phase out of
+ // order.)
+ header.packet_number = kPacketNumber + 1;
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ encrypted = EncryptPacketWithTagAndPhase(*data, 1, true);
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ // Packet should decrypt (its phase matches the current phase) and key update
+ // count should not change.
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(1, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_);
+}
+
+// Packets whose key does not match their key phase must fail to decrypt and
+// must not cause a key update.
+TEST_P(QuicFramerTest, KeyUpdateWrongKey) {
+ if (!framer_.version().UsesTls()) {
+ // Key update is only used in QUIC+TLS.
+ return;
+ }
+ ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse());
+ // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter
+ // instead of TestDecrypter.
+ framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE,
+ std::make_unique<StrictTaggingDecrypter>(/*key=*/0));
+ framer_.SetKeyUpdateSupportForConnection(true);
+
+ QuicPacketHeader header;
+ header.destination_connection_id = FramerTestConnectionId();
+ header.reset_flag = false;
+ header.version_flag = false;
+ header.packet_number = kPacketNumber;
+
+ QuicFrames frames = {QuicFrame(QuicPaddingFrame())};
+
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ std::unique_ptr<QuicPacket> data(BuildDataPacket(header, frames));
+ ASSERT_TRUE(data != nullptr);
+ std::unique_ptr<QuicEncryptedPacket> encrypted(
+ EncryptPacketWithTagAndPhase(*data, 0, false));
+ ASSERT_TRUE(encrypted);
+
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ // Processed valid packet with phase=0, key=0: no key update.
+ EXPECT_EQ(0, visitor_.key_update_count_);
+ EXPECT_EQ(0, visitor_.derive_next_key_count_);
+ EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ header.packet_number += 1;
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ encrypted = EncryptPacketWithTagAndPhase(*data, 2, true);
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ // Packet with phase=1 but key=2 (the next key is 1), should not process and
+ // should not cause key update, but next decrypter key should have been
+ // created to attempt to decode it.
+ EXPECT_FALSE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(0, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ header.packet_number += 1;
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ encrypted = EncryptPacketWithTagAndPhase(*data, 0, true);
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ // Packet with phase=1 but key=0 (the current key), should not process and
+ // should not cause key update. The already-created next decrypter is reused.
+ EXPECT_FALSE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(0, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ header.packet_number += 1;
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ encrypted = EncryptPacketWithTagAndPhase(*data, 1, false);
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ // Packet with phase=0 but key=1 (the next key), should not process and
+ // should not cause key update.
+ EXPECT_FALSE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(0, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_);
+}
+
+TEST_P(QuicFramerTest, KeyUpdateReceivedWhenNotEnabled) {
+ if (!framer_.version().UsesTls()) {
+ // Key update is only used in QUIC+TLS.
+ return;
+ }
+ ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse());
+ // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter
+ // instead of TestDecrypter.
+ framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE,
+ std::make_unique<StrictTaggingDecrypter>(/*key=*/0));
+
+ QuicPacketHeader header;
+ header.destination_connection_id = FramerTestConnectionId();
+ header.reset_flag = false;
+ header.version_flag = false;
+ header.packet_number = kPacketNumber;
+
+ QuicFrames frames = {QuicFrame(QuicPaddingFrame())};
+
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);  // Build as client.
+ std::unique_ptr<QuicPacket> data(BuildDataPacket(header, frames));
+ ASSERT_TRUE(data != nullptr);
+ std::unique_ptr<QuicEncryptedPacket> encrypted(
+ EncryptPacketWithTagAndPhase(*data, 1, true));  // Tag 1, key phase bit set.
+ ASSERT_TRUE(encrypted);
+
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);  // Process as server.
+ // Received a packet with key phase updated even though framer hasn't had key
+ // update enabled (SetKeyUpdateSupportForConnection never called). Should
+ // fail to process.
+ EXPECT_FALSE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(0, visitor_.key_update_count_);
+ EXPECT_EQ(0, visitor_.derive_next_key_count_);  // No next key was derived.
+ EXPECT_EQ(0, visitor_.decrypted_first_packet_in_key_phase_count_);
+}
+
+TEST_P(QuicFramerTest, KeyUpdateLocallyInitiated) {
+ if (!framer_.version().UsesTls()) {
+ // Key update is only used in QUIC+TLS.
+ return;
+ }
+ ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse());
+ // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter
+ // instead of TestDecrypter.
+ framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE,
+ std::make_unique<StrictTaggingDecrypter>(/*key=*/0));
+ framer_.SetKeyUpdateSupportForConnection(true);  // Enable key update.
+
+ EXPECT_TRUE(framer_.DoKeyUpdate());  // Locally initiated key update.
+ // Key update count should be updated, but haven't received packet from peer
+ // with new key phase.
+ EXPECT_EQ(1, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(0, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ QuicPacketHeader header;
+ header.destination_connection_id = FramerTestConnectionId();
+ header.reset_flag = false;
+ header.version_flag = false;
+ header.packet_number = kPacketNumber;
+
+ QuicFrames frames = {QuicFrame(QuicPaddingFrame())};
+
+ // Process packet N with phase 1.
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ std::unique_ptr<QuicPacket> data(BuildDataPacket(header, frames));
+ ASSERT_TRUE(data != nullptr);
+ std::unique_ptr<QuicEncryptedPacket> encrypted(
+ EncryptPacketWithTagAndPhase(*data, 1, true));  // Tag 1, phase 1.
+ ASSERT_TRUE(encrypted);
+
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ // Packet should decrypt and key update count should not change and
+ // OnDecryptedFirstPacketInKeyPhase should have been called.
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(1, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ // Process packet N-1 with phase 0. (Receiving packet from previous phase
+ // after packet from new phase was received.)
+ header.packet_number = kPacketNumber - 1;
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ encrypted = EncryptPacketWithTagAndPhase(*data, 0, false);  // Tag 0, phase 0.
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ // Packet should decrypt and key update count should not change.
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(1, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ // Process packet N+1 with phase 0 and key 1. This should not decrypt even
+ // though it's using the previous key, since the packet number is higher than
+ // a packet number received using the current key.
+ header.packet_number = kPacketNumber + 1;
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ encrypted = EncryptPacketWithTagAndPhase(*data, 0, false);  // Tag 0, phase 0.
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ // Packet should not decrypt and key update count should not change.
+ EXPECT_FALSE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(1, visitor_.key_update_count_);
+ EXPECT_EQ(2, visitor_.derive_next_key_count_);  // Bumped while processing this packet.
+ EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_);
+}
+
+TEST_P(QuicFramerTest, KeyUpdateLocallyInitiatedReceivedOldPacket) {
+ if (!framer_.version().UsesTls()) {
+ // Key update is only used in QUIC+TLS.
+ return;
+ }
+ ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse());
+ // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter
+ // instead of TestDecrypter.
+ framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE,
+ std::make_unique<StrictTaggingDecrypter>(/*key=*/0));
+ framer_.SetKeyUpdateSupportForConnection(true);  // Enable key update.
+
+ EXPECT_TRUE(framer_.DoKeyUpdate());  // Locally initiated key update.
+ // Key update count should be updated, but haven't received packet
+ // from peer with new key phase.
+ EXPECT_EQ(1, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(0, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ QuicPacketHeader header;
+ header.destination_connection_id = FramerTestConnectionId();
+ header.reset_flag = false;
+ header.version_flag = false;
+ header.packet_number = kPacketNumber;
+
+ QuicFrames frames = {QuicFrame(QuicPaddingFrame())};
+
+ // Process packet N with phase 0. (Receiving packet from previous phase
+ // after locally initiated key update, but before any packet from new phase
+ // was received.)
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ std::unique_ptr<QuicPacket> data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ std::unique_ptr<QuicEncryptedPacket> encrypted =
+ EncryptPacketWithTagAndPhase(*data, 0, false);  // Tag 0, phase 0.
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ // Packet should decrypt and key update count should not change and
+ // OnDecryptedFirstPacketInKeyPhase should not have been called since the
+ // packet was from the previous key phase.
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(1, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(0, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ // Process packet N+1 with phase 1.
+ header.packet_number = kPacketNumber + 1;
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ encrypted = EncryptPacketWithTagAndPhase(*data, 1, true);  // Tag 1, phase 1.
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ // Packet should decrypt and key update count should not change, but
+ // OnDecryptedFirstPacketInKeyPhase should have been called.
+ EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(1, visitor_.key_update_count_);
+ EXPECT_EQ(1, visitor_.derive_next_key_count_);
+ EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_);
+
+ // Process packet N+2 with phase 0 and key 1. This should not decrypt even
+ // though it's using the previous key, since the packet number is higher than
+ // a packet number received using the current key.
+ header.packet_number = kPacketNumber + 2;
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+ data = BuildDataPacket(header, frames);
+ ASSERT_TRUE(data != nullptr);
+ encrypted = EncryptPacketWithTagAndPhase(*data, 0, false);  // Tag 0, phase 0.
+ ASSERT_TRUE(encrypted);
+ QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+ // Packet should not decrypt and key update count should not change.
+ EXPECT_FALSE(framer_.ProcessPacket(*encrypted));
+ EXPECT_EQ(1, visitor_.key_update_count_);
+ EXPECT_EQ(2, visitor_.derive_next_key_count_);  // Bumped while processing this packet.
+ EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_);
+}
+
} // namespace
} // namespace test
} // namespace quic
diff --git a/quic/core/quic_session.cc b/quic/core/quic_session.cc
index e28db59..a6c6e6c 100644
--- a/quic/core/quic_session.cc
+++ b/quic/core/quic_session.cc
@@ -137,6 +137,11 @@
connection_->OnSuccessfulVersionNegotiation();
}
+ if (GetQuicReloadableFlag(quic_key_update_supported) &&  // Flag first: short-circuits the crypto stream query.
+ GetMutableCryptoStream()->KeyUpdateSupportedLocally()) {
+ config_.SetKeyUpdateSupportedLocally();
+ }
+
if (QuicVersionUsesCryptoFrames(transport_version())) {
return;
}
@@ -276,6 +281,15 @@
GetMutableCryptoStream()->OnHandshakePacketSent();
}
+std::unique_ptr<QuicDecrypter>  // Delegates key advancement to the crypto stream.
+QuicSession::AdvanceKeysAndCreateCurrentOneRttDecrypter() {
+ return GetMutableCryptoStream()->AdvanceKeysAndCreateCurrentOneRttDecrypter();
+}
+
+std::unique_ptr<QuicEncrypter> QuicSession::CreateCurrentOneRttEncrypter() {
+ return GetMutableCryptoStream()->CreateCurrentOneRttEncrypter();  // Delegates to the crypto stream.
+}
+
void QuicSession::PendingStreamOnRstStream(const QuicRstStreamFrame& frame) {
DCHECK(VersionUsesHttp3(transport_version()));
QuicStreamId stream_id = frame.stream_id;
diff --git a/quic/core/quic_session.h b/quic/core/quic_session.h
index a3a982b..42f1b04 100644
--- a/quic/core/quic_session.h
+++ b/quic/core/quic_session.h
@@ -134,6 +134,9 @@
void OnPacketDecrypted(EncryptionLevel level) override;
void OnOneRttPacketAcknowledged() override;
void OnHandshakePacketSent() override;
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+ override;  // Key update hooks; both forward to the crypto stream.
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override;
// QuicStreamFrameDataProducer
WriteStreamDataResult WriteStreamData(QuicStreamId id,
diff --git a/quic/core/quic_session_test.cc b/quic/core/quic_session_test.cc
index b2c5b1f..a11e6f0 100644
--- a/quic/core/quic_session_test.cc
+++ b/quic/core/quic_session_test.cc
@@ -140,6 +140,15 @@
}
void SetServerApplicationStateForResumption(
std::unique_ptr<ApplicationState> /*application_state*/) override {}
+ MOCK_METHOD(bool, KeyUpdateSupportedLocally, (), (const, override));  // Set per test.
+ MOCK_METHOD(std::unique_ptr<QuicDecrypter>,
+ AdvanceKeysAndCreateCurrentOneRttDecrypter,
+ (),
+ (override));  // Mocked key derivation hook.
+ MOCK_METHOD(std::unique_ptr<QuicEncrypter>,
+ CreateCurrentOneRttEncrypter,
+ (),
+ (override));  // Mocked encrypter factory hook.
MOCK_METHOD(void, OnCanWrite, (), (override));
bool HasPendingCryptoRetransmission() const override { return false; }
@@ -198,6 +207,8 @@
writev_consumes_all_data_(false),
uses_pending_streams_(false),
num_incoming_streams_created_(0) {
+ EXPECT_CALL(*GetMutableCryptoStream(), KeyUpdateSupportedLocally())
+ .WillRepeatedly(Return(false));  // Default: key update unsupported.
Initialize();
this->connection()->SetEncrypter(
ENCRYPTION_FORWARD_SECURE,
@@ -2229,6 +2240,30 @@
ASSERT_TRUE(session_.connection()->can_receive_ack_frequency_frame());
}
+TEST_P(QuicSessionTestClient, KeyUpdateFlagNotSet) {
+ SetQuicReloadableFlag(quic_key_update_supported, false);
+ EXPECT_CALL(*session_.GetMutableCryptoStream(), KeyUpdateSupportedLocally())
+ .Times(0);  // Flag-off short-circuits the crypto stream query.
+ session_.Initialize();
+ EXPECT_FALSE(session_.config()->KeyUpdateSupportedLocally());
+}
+
+TEST_P(QuicSessionTestClient, KeyUpdateNotSupportedLocallyAndFlagSet) {
+ SetQuicReloadableFlag(quic_key_update_supported, true);
+ EXPECT_CALL(*session_.GetMutableCryptoStream(), KeyUpdateSupportedLocally())
+ .WillOnce(Return(false));  // Crypto stream declines key update.
+ session_.Initialize();
+ EXPECT_FALSE(session_.config()->KeyUpdateSupportedLocally());
+}
+
+TEST_P(QuicSessionTestClient, KeyUpdateSupportedLocallyAndFlagSet) {
+ SetQuicReloadableFlag(quic_key_update_supported, true);
+ EXPECT_CALL(*session_.GetMutableCryptoStream(), KeyUpdateSupportedLocally())
+ .WillOnce(Return(true));  // Crypto stream supports key update.
+ session_.Initialize();
+ EXPECT_TRUE(session_.config()->KeyUpdateSupportedLocally());
+}
+
TEST_P(QuicSessionTestClient, FailedToCreateStreamIfTooCloseToIdleTimeout) {
connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
EXPECT_TRUE(session_.CanOpenNextOutgoingBidirectionalStream());
diff --git a/quic/core/quic_types.h b/quic/core/quic_types.h
index 13307bb..4a896a0 100644
--- a/quic/core/quic_types.h
+++ b/quic/core/quic_types.h
@@ -602,8 +602,8 @@
QuicLongHeaderType type);
enum QuicPacketHeaderTypeFlags : uint8_t {
- // Bit 2: Reserved for experimentation for short header.
- FLAGS_EXPERIMENTATION_BIT = 1 << 2,
+ // Bit 2: Key phase bit for IETF QUIC short header packets.
+ FLAGS_KEY_PHASE_BIT = 1 << 2,  // Formerly FLAGS_EXPERIMENTATION_BIT.
// Bit 3: Google QUIC Demultiplexing bit, the short header always sets this
// bit to 0, allowing to distinguish Google QUIC packets from short header
// packets.
diff --git a/quic/core/quic_versions.cc b/quic/core/quic_versions.cc
index 745957f..145cdcd 100644
--- a/quic/core/quic_versions.cc
+++ b/quic/core/quic_versions.cc
@@ -663,6 +663,7 @@
void QuicVersionInitializeSupportForIetfDraft() {
// Enable necessary flags.
SetQuicRestartFlag(quic_enable_zero_rtt_for_tls_v2, true);
+ SetQuicReloadableFlag(quic_key_update_supported, true);  // Gates key update support (see QuicSession).
}
void QuicEnableVersion(const ParsedQuicVersion& version) {
diff --git a/quic/core/tls_chlo_extractor.h b/quic/core/tls_chlo_extractor.h
index 53e06b4..6d2f254 100644
--- a/quic/core/tls_chlo_extractor.h
+++ b/quic/core/tls_chlo_extractor.h
@@ -160,6 +160,15 @@
}
void OnAuthenticatedIetfStatelessResetPacket(
const QuicIetfStatelessResetPacket& /*packet*/) override {}
+ void OnKeyUpdate() override {}  // Key update hooks are no-ops here.
+ void OnDecryptedFirstPacketInKeyPhase() override {}
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+ override {
+ return nullptr;  // No decrypter is provided.
+ }
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override {
+ return nullptr;  // No encrypter is provided.
+ }
// Methods from QuicStreamSequencer::StreamInterface.
void OnDataAvailable() override;
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index c6e601e..234b4d1 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -349,6 +349,20 @@
return TlsHandshaker::BufferSizeLimitForLevel(level);
}
+bool TlsClientHandshaker::KeyUpdateSupportedLocally() const {
+ return true;  // Unconditionally supported by the TLS client handshaker.
+}
+
+std::unique_ptr<QuicDecrypter>  // Forwards to the shared TlsHandshaker implementation.
+TlsClientHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter() {
+ return TlsHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter();
+}
+
+std::unique_ptr<QuicEncrypter>  // Forwards to the shared TlsHandshaker implementation.
+TlsClientHandshaker::CreateCurrentOneRttEncrypter() {
+ return TlsHandshaker::CreateCurrentOneRttEncrypter();
+}
+
void TlsClientHandshaker::OnOneRttPacketAcknowledged() {
OnHandshakeConfirmed();
}
diff --git a/quic/core/tls_client_handshaker.h b/quic/core/tls_client_handshaker.h
index 89388cc..f5a349f 100644
--- a/quic/core/tls_client_handshaker.h
+++ b/quic/core/tls_client_handshaker.h
@@ -59,6 +59,10 @@
CryptoMessageParser* crypto_message_parser() override;
HandshakeState GetHandshakeState() const override;
size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
+ bool KeyUpdateSupportedLocally() const override;  // Always true for the client.
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+ override;  // Forwards to TlsHandshaker.
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override;
void OnOneRttPacketAcknowledged() override;
void OnHandshakePacketSent() override;
void OnConnectionClosed(QuicErrorCode error,
diff --git a/quic/core/tls_handshaker.cc b/quic/core/tls_handshaker.cc
index 362a1e8..a53dae4 100644
--- a/quic/core/tls_handshaker.cc
+++ b/quic/core/tls_handshaker.cc
@@ -68,10 +68,22 @@
void TlsHandshaker::SetWriteSecret(EncryptionLevel level,
const SSL_CIPHER* cipher,
const std::vector<uint8_t>& write_secret) {
+ QUIC_DVLOG(1) << "SetWriteSecret level=" << level;
std::unique_ptr<QuicEncrypter> encrypter =
QuicEncrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
- CryptoUtils::InitializeCrypterSecrets(Prf(cipher), write_secret,
- encrypter.get());
+ const EVP_MD* prf = Prf(cipher);
+ CryptoUtils::SetKeyAndIV(prf, write_secret, encrypter.get());  // Key and IV from the secret.
+ std::vector<uint8_t> header_protection_key =  // Kept so 1-RTT key updates can reuse it.
+ CryptoUtils::GenerateHeaderProtectionKey(prf, write_secret,
+ encrypter->GetKeySize());
+ encrypter->SetHeaderProtectionKey(quiche::QuicheStringPiece(
+ reinterpret_cast<char*>(header_protection_key.data()),
+ header_protection_key.size()));
+ if (level == ENCRYPTION_FORWARD_SECURE) {  // Remember 1-RTT state for key updates.
+ DCHECK(latest_write_secret_.empty());  // The 1-RTT write secret is set only once.
+ latest_write_secret_ = write_secret;
+ one_rtt_write_header_protection_key_ = header_protection_key;
+ }
handshaker_delegate_->OnNewEncryptionKeyAvailable(level,
std::move(encrypter));
}
@@ -79,16 +91,73 @@
bool TlsHandshaker::SetReadSecret(EncryptionLevel level,
const SSL_CIPHER* cipher,
const std::vector<uint8_t>& read_secret) {
+ QUIC_DVLOG(1) << "SetReadSecret level=" << level;
std::unique_ptr<QuicDecrypter> decrypter =
QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
- CryptoUtils::InitializeCrypterSecrets(Prf(cipher), read_secret,
- decrypter.get());
+ const EVP_MD* prf = Prf(cipher);
+ CryptoUtils::SetKeyAndIV(prf, read_secret, decrypter.get());  // Key and IV from the secret.
+ std::vector<uint8_t> header_protection_key =  // Kept so 1-RTT key updates can reuse it.
+ CryptoUtils::GenerateHeaderProtectionKey(prf, read_secret,
+ decrypter->GetKeySize());
+ decrypter->SetHeaderProtectionKey(quiche::QuicheStringPiece(
+ reinterpret_cast<char*>(header_protection_key.data()),
+ header_protection_key.size()));
+ if (level == ENCRYPTION_FORWARD_SECURE) {  // Remember 1-RTT state for key updates.
+ DCHECK(latest_read_secret_.empty());  // The 1-RTT read secret is set only once.
+ latest_read_secret_ = read_secret;
+ one_rtt_read_header_protection_key_ = header_protection_key;
+ }
return handshaker_delegate_->OnNewDecryptionKeyAvailable(
level, std::move(decrypter),
/*set_alternative_decrypter=*/false,
/*latch_once_used=*/false);
}
+std::unique_ptr<QuicDecrypter>  // Advances both 1-RTT secrets; returns a decrypter for the new read secret.
+TlsHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter() {
+ if (latest_read_secret_.empty() || latest_write_secret_.empty() ||
+ one_rtt_read_header_protection_key_.empty() ||
+ one_rtt_write_header_protection_key_.empty()) {
+ std::string error_details = "1-RTT secret(s) not set yet.";
+ QUIC_BUG << error_details;
+ CloseConnection(QUIC_INTERNAL_ERROR, error_details);
+ return nullptr;  // Connection was closed above.
+ }
+ const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl());  // NOTE(review): assumed non-null once 1-RTT secrets exist.
+ const EVP_MD* prf = Prf(cipher);
+ latest_read_secret_ =  // Read and write secrets advance together.
+ CryptoUtils::GenerateNextKeyPhaseSecret(prf, latest_read_secret_);
+ latest_write_secret_ =
+ CryptoUtils::GenerateNextKeyPhaseSecret(prf, latest_write_secret_);
+
+ std::unique_ptr<QuicDecrypter> decrypter =
+ QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
+ CryptoUtils::SetKeyAndIV(prf, latest_read_secret_, decrypter.get());
+ decrypter->SetHeaderProtectionKey(quiche::QuicheStringPiece(  // Header protection key is not rotated.
+ reinterpret_cast<char*>(one_rtt_read_header_protection_key_.data()),
+ one_rtt_read_header_protection_key_.size()));
+
+ return decrypter;  // The advanced write secret is consumed by CreateCurrentOneRttEncrypter().
+}
+
+std::unique_ptr<QuicEncrypter> TlsHandshaker::CreateCurrentOneRttEncrypter() {
+ if (latest_write_secret_.empty() ||
+ one_rtt_write_header_protection_key_.empty()) {
+ std::string error_details = "1-RTT write secret not set yet.";
+ QUIC_BUG << error_details;
+ CloseConnection(QUIC_INTERNAL_ERROR, error_details);
+ return nullptr;  // Connection was closed above.
+ }
+ const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl());  // NOTE(review): assumed non-null here.
+ std::unique_ptr<QuicEncrypter> encrypter =
+ QuicEncrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
+ CryptoUtils::SetKeyAndIV(Prf(cipher), latest_write_secret_, encrypter.get());
+ encrypter->SetHeaderProtectionKey(quiche::QuicheStringPiece(  // Header protection key is not rotated.
+ reinterpret_cast<char*>(one_rtt_write_header_protection_key_.data()),
+ one_rtt_write_header_protection_key_.size()));
+ return encrypter;  // Built from latest_write_secret_; does not advance it.
+}
+
void TlsHandshaker::WriteMessage(EncryptionLevel level,
absl::string_view data) {
stream_->WriteCryptoData(level, data);
diff --git a/quic/core/tls_handshaker.h b/quic/core/tls_handshaker.h
index ef3ade1..077e373 100644
--- a/quic/core/tls_handshaker.h
+++ b/quic/core/tls_handshaker.h
@@ -48,6 +48,8 @@
CryptoMessageParser* crypto_message_parser() { return this; }
size_t BufferSizeLimitForLevel(EncryptionLevel level) const;
ssl_early_data_reason_t EarlyDataReason() const;
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter();
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter();  // Both close the connection and return nullptr if 1-RTT keys are unset.
protected:
virtual void AdvanceHandshake() = 0;
@@ -104,6 +106,14 @@
QuicErrorCode parser_error_ = QUIC_NO_ERROR;
std::string parser_error_detail_;
+
+ // The most recent 1-RTT read and write secrets. Set by Set{Read,Write}Secret
+ // and advanced together by AdvanceKeysAndCreateCurrentOneRttDecrypter().
+ std::vector<uint8_t> latest_read_secret_;
+ std::vector<uint8_t> latest_write_secret_;
+ // 1-RTT header protection keys; these stay the same across key updates.
+ std::vector<uint8_t> one_rtt_read_header_protection_key_;
+ std::vector<uint8_t> one_rtt_write_header_protection_key_;
};
} // namespace quic
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index f108a32..405e109 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -216,6 +216,20 @@
return TlsHandshaker::BufferSizeLimitForLevel(level);
}
+bool TlsServerHandshaker::KeyUpdateSupportedLocally() const {
+ return true;  // Unconditionally supported by the TLS server handshaker.
+}
+
+std::unique_ptr<QuicDecrypter>  // Forwards to the shared TlsHandshaker implementation.
+TlsServerHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter() {
+ return TlsHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter();
+}
+
+std::unique_ptr<QuicEncrypter>  // Forwards to the shared TlsHandshaker implementation.
+TlsServerHandshaker::CreateCurrentOneRttEncrypter() {
+ return TlsHandshaker::CreateCurrentOneRttEncrypter();
+}
+
void TlsServerHandshaker::OverrideQuicConfigDefaults(QuicConfig* /*config*/) {}
void TlsServerHandshaker::AdvanceHandshake() {
diff --git a/quic/core/tls_server_handshaker.h b/quic/core/tls_server_handshaker.h
index 79d19cb..7b28825 100644
--- a/quic/core/tls_server_handshaker.h
+++ b/quic/core/tls_server_handshaker.h
@@ -66,6 +66,10 @@
void SetServerApplicationStateForResumption(
std::unique_ptr<ApplicationState> state) override;
size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
+ bool KeyUpdateSupportedLocally() const override;  // Always true for the server.
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+ override;  // Forwards to TlsHandshaker.
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override;
void SetWriteSecret(EncryptionLevel level,
const SSL_CIPHER* cipher,
const std::vector<uint8_t>& write_secret) override;
diff --git a/quic/test_tools/quic_connection_peer.cc b/quic/test_tools/quic_connection_peer.cc
index 01da28c..1f4e0c8 100644
--- a/quic/test_tools/quic_connection_peer.cc
+++ b/quic/test_tools/quic_connection_peer.cc
@@ -151,6 +151,12 @@
}
// static
+QuicAlarm* QuicConnectionPeer::GetDiscardPreviousOneRttKeysAlarm(
+ QuicConnection* connection) {  // Non-owning; the connection keeps the alarm.
+ return connection->discard_previous_one_rtt_keys_alarm_.get();
+}
+
+// static
QuicPacketWriter* QuicConnectionPeer::GetWriter(QuicConnection* connection) {
return connection->writer_;
}
diff --git a/quic/test_tools/quic_connection_peer.h b/quic/test_tools/quic_connection_peer.h
index 1375cbe..783b055 100644
--- a/quic/test_tools/quic_connection_peer.h
+++ b/quic/test_tools/quic_connection_peer.h
@@ -82,6 +82,8 @@
static QuicAlarm* GetMtuDiscoveryAlarm(QuicConnection* connection);
static QuicAlarm* GetProcessUndecryptablePacketsAlarm(
QuicConnection* connection);
+ static QuicAlarm* GetDiscardPreviousOneRttKeysAlarm(
+ QuicConnection* connection);  // Test access to the connection's alarm.
static QuicPacketWriter* GetWriter(QuicConnection* connection);
// If |owns_writer| is true, takes ownership of |writer|.
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h
index 294b7db..c6a37f0 100644
--- a/quic/test_tools/quic_test_utils.h
+++ b/quic/test_tools/quic_test_utils.h
@@ -424,6 +424,16 @@
OnAuthenticatedIetfStatelessResetPacket,
(const QuicIetfStatelessResetPacket&),
(override));
+ MOCK_METHOD(void, OnKeyUpdate, (), (override));  // Key update visitor hooks.
+ MOCK_METHOD(void, OnDecryptedFirstPacketInKeyPhase, (), (override));
+ MOCK_METHOD(std::unique_ptr<QuicDecrypter>,
+ AdvanceKeysAndCreateCurrentOneRttDecrypter,
+ (),
+ (override));
+ MOCK_METHOD(std::unique_ptr<QuicEncrypter>,
+ CreateCurrentOneRttEncrypter,
+ (),
+ (override));
};
class NoOpFramerVisitor : public QuicFramerVisitorInterface {
@@ -483,6 +493,15 @@
bool IsValidStatelessResetToken(QuicUint128 token) const override;
void OnAuthenticatedIetfStatelessResetPacket(
const QuicIetfStatelessResetPacket& /*packet*/) override {}
+ void OnKeyUpdate() override {}  // Key update hooks are no-ops here.
+ void OnDecryptedFirstPacketInKeyPhase() override {}
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+ override {
+ return nullptr;  // No decrypter is provided.
+ }
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override {
+ return nullptr;  // No encrypter is provided.
+ }
};
class MockQuicConnectionVisitor : public QuicConnectionVisitorInterface {
@@ -558,6 +577,14 @@
MOCK_METHOD(void, OnPacketDecrypted, (EncryptionLevel), (override));
MOCK_METHOD(void, OnOneRttPacketAcknowledged, (), (override));
MOCK_METHOD(void, OnHandshakePacketSent, (), (override));
+ MOCK_METHOD(std::unique_ptr<QuicDecrypter>,  // Key update hooks.
+ AdvanceKeysAndCreateCurrentOneRttDecrypter,
+ (),
+ (override));
+ MOCK_METHOD(std::unique_ptr<QuicEncrypter>,
+ CreateCurrentOneRttEncrypter,
+ (),
+ (override));
};
class MockQuicConnectionHelper : public QuicConnectionHelperInterface {
@@ -865,6 +892,14 @@
HandshakeState GetHandshakeState() const override { return HANDSHAKE_START; }
void SetServerApplicationStateForResumption(
std::unique_ptr<ApplicationState> /*application_state*/) override {}
+ bool KeyUpdateSupportedLocally() const override { return false; }  // Key update unsupported.
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+ override {
+ return nullptr;  // No decrypter is provided.
+ }
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override {
+ return nullptr;  // No encrypter is provided.
+ }
private:
QuicReferenceCountedPointer<QuicCryptoNegotiatedParameters> params_;
diff --git a/quic/test_tools/simple_quic_framer.cc b/quic/test_tools/simple_quic_framer.cc
index 2b974ee..4a77331 100644
--- a/quic/test_tools/simple_quic_framer.cc
+++ b/quic/test_tools/simple_quic_framer.cc
@@ -220,6 +220,16 @@
std::make_unique<QuicIetfStatelessResetPacket>(packet);
}
+ void OnKeyUpdate() override {}  // Key update hooks are no-ops here.
+ void OnDecryptedFirstPacketInKeyPhase() override {}
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+ override {
+ return nullptr;  // No decrypter is provided.
+ }
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override {
+ return nullptr;  // No encrypter is provided.
+ }
+
const QuicPacketHeader& header() const { return header_; }
const std::vector<QuicAckFrame>& ack_frames() const { return ack_frames_; }
const std::vector<QuicConnectionCloseFrame>& connection_close_frames() const {
diff --git a/quic/test_tools/simulator/quic_endpoint.h b/quic/test_tools/simulator/quic_endpoint.h
index 39c4855..faf2bfd 100644
--- a/quic/test_tools/simulator/quic_endpoint.h
+++ b/quic/test_tools/simulator/quic_endpoint.h
@@ -93,6 +93,13 @@
void OnPacketDecrypted(EncryptionLevel /*level*/) override {}
void OnOneRttPacketAcknowledged() override {}
void OnHandshakePacketSent() override {}
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+ override {
+ return nullptr;  // Key update is not simulated; no decrypter.
+ }
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override {
+ return nullptr;  // Key update is not simulated; no encrypter.
+ }
// End QuicConnectionVisitorInterface implementation.
diff --git a/quic/tools/quic_client_interop_test_bin.cc b/quic/tools/quic_client_interop_test_bin.cc
index 80dfd3b..15871d8 100644
--- a/quic/tools/quic_client_interop_test_bin.cc
+++ b/quic/tools/quic_client_interop_test_bin.cc
@@ -52,6 +52,8 @@
// Second row of features (anything else protocol-related)
// We switched to a different port and the server migrated to it.
kRebinding,
+ // One endpoint can update keys and its peer responds correctly.
+ kKeyUpdate,  // Reported as 'U'; inserted by AttemptRequest on success.
// Third row of features (H3 tests)
// An H3 transaction succeeded.
@@ -81,6 +83,8 @@
return 'Q';
case Feature::kRebinding:
return 'B';
+ case Feature::kKeyUpdate:
+ return 'U';  // Single-letter code for kKeyUpdate.
case Feature::kHttp3:
return '3';
case Feature::kDynamicEntryReferenced:
@@ -107,7 +111,8 @@
ParsedQuicVersion version,
bool test_version_negotiation,
bool attempt_rebind,
- bool attempt_multi_packet_chlo);
+ bool attempt_multi_packet_chlo,
+ bool attempt_key_update);
// Constructs a SpdyHeaderBlock containing the pseudo-headers needed to make a
// GET request to "/" on the hostname |authority|.
@@ -191,7 +196,8 @@
ParsedQuicVersion version,
bool test_version_negotiation,
bool attempt_rebind,
- bool attempt_multi_packet_chlo) {
+ bool attempt_multi_packet_chlo,
+ bool attempt_key_update) {
ParsedQuicVersionVector versions = {version};
if (test_version_negotiation) {
versions.insert(versions.begin(), QuicVersionReservedForNegotiation());
@@ -238,7 +244,7 @@
// Failed to negotiate version, retry without version negotiation.
AttemptRequest(addr, authority, server_id, version,
/*test_version_negotiation=*/false, attempt_rebind,
- attempt_multi_packet_chlo);
+ attempt_multi_packet_chlo, attempt_key_update);
return;
}
if (!client->session()->OneRttKeysAvailable()) {
@@ -246,7 +252,7 @@
// Failed to handshake with multi-packet client hello, retry without it.
AttemptRequest(addr, authority, server_id, version,
test_version_negotiation, attempt_rebind,
- /*attempt_multi_packet_chlo=*/false);
+ /*attempt_multi_packet_chlo=*/false, attempt_key_update);
return;
}
return;
@@ -278,7 +284,7 @@
// Rebinding does not work, retry without attempting it.
AttemptRequest(addr, authority, server_id, version,
test_version_negotiation, /*attempt_rebind=*/false,
- attempt_multi_packet_chlo);
+ attempt_multi_packet_chlo, attempt_key_update);
return;
}
InsertFeature(Feature::kRebinding);
@@ -290,6 +296,27 @@
QUIC_LOG(ERROR) << "Failed to change ephemeral port";
}
}
+
+ if (attempt_key_update) {
+ if (connection->IsKeyUpdateAllowed()) {
+ if (connection->InitiateKeyUpdate()) {
+ client->SendRequestAndWaitForResponse(header_block, "", /*fin=*/true);
+ if (!client->connected()) {
+ // Key update does not work, retry without attempting it.
+ AttemptRequest(addr, authority, server_id, version,
+ test_version_negotiation, attempt_rebind,
+ attempt_multi_packet_chlo,
+ /*attempt_key_update=*/false);
+ return;
+ }
+ InsertFeature(Feature::kKeyUpdate);
+ } else {
+ QUIC_LOG(ERROR) << "Failed to initiate key update";
+ }
+ } else {
+ QUIC_LOG(ERROR) << "Key update not allowed";
+ }
+ }
}
if (connection->connected()) {
@@ -368,7 +395,8 @@
runner.AttemptRequest(addr, authority, server_id, version,
/*test_version_negotiation=*/true,
/*attempt_rebind=*/true,
- /*attempt_multi_packet_chlo=*/true);
+ /*attempt_multi_packet_chlo=*/true,
+ /*attempt_key_update=*/true);  // First attempt tries every optional feature.
return runner.features();
}
diff --git a/quic/tools/quic_packet_printer_bin.cc b/quic/tools/quic_packet_printer_bin.cc
index a19646f..e64e1ac 100644
--- a/quic/tools/quic_packet_printer_bin.cc
+++ b/quic/tools/quic_packet_printer_bin.cc
@@ -220,6 +220,19 @@
const QuicIetfStatelessResetPacket& /*packet*/) override {
std::cerr << "OnAuthenticatedIetfStatelessResetPacket\n";
}
+ void OnKeyUpdate() override { std::cerr << "OnKeyUpdate\n"; }  // Key update hooks only log.
+ void OnDecryptedFirstPacketInKeyPhase() override {
+ std::cerr << "OnDecryptedFirstPacketInKeyPhase\n";
+ }
+ std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
+ override {
+ std::cerr << "AdvanceKeysAndCreateCurrentOneRttDecrypter\n";
+ return nullptr;  // The printer provides no decrypter.
+ }
+ std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override {
+ std::cerr << "CreateCurrentOneRttEncrypter\n";
+ return nullptr;  // The printer provides no encrypter.
+ }
private:
QuicFramer* framer_; // Unowned.