Add QuicConnectionVisitorInterface::OnKeyUpdate upcall.
PiperOrigin-RevId: 337897975
Change-Id: I550a9fd8f290f24f99cf80ee2ed3ecbc86f68879
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc
index 0aa7a4d..a19eae0 100644
--- a/quic/core/quic_connection.cc
+++ b/quic/core/quic_connection.cc
@@ -1943,6 +1943,8 @@
// discard_previous_one_rtt_keys_alarm_ hasn't fired yet, cancel it since the
// old keys would already be discarded.
discard_previous_one_rtt_keys_alarm_->Cancel();
+
+ visitor_->OnKeyUpdate(reason);
}
void QuicConnection::OnDecryptedFirstPacketInKeyPhase() {
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h
index 2fc8c8a..73c6706 100644
--- a/quic/core/quic_connection.h
+++ b/quic/core/quic_connection.h
@@ -196,6 +196,9 @@
// Called when a packet of ENCRYPTION_HANDSHAKE gets sent.
virtual void OnHandshakePacketSent() = 0;
+ // Called when a key update has occurred.
+ virtual void OnKeyUpdate(KeyUpdateReason reason) = 0;
+
// Called to generate a decrypter for the next key phase. Each call should
// generate the key for phase n+1.
virtual std::unique_ptr<QuicDecrypter>
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc
index 483ddb1..fd84b30 100644
--- a/quic/core/quic_connection_test.cc
+++ b/quic/core/quic_connection_test.cc
@@ -12096,6 +12096,7 @@
EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()).WillOnce([]() {
return std::make_unique<TaggingEncrypter>(0x02);
});
+ EXPECT_CALL(visitor_, OnKeyUpdate(KeyUpdateReason::kLocalForTests));
EXPECT_TRUE(connection_.InitiateKeyUpdate(KeyUpdateReason::kLocalForTests));
// discard_previous_keys_alarm_ should not be set until a packet from the new
// key phase has been received. (The alarm that was set above should be
@@ -12132,6 +12133,7 @@
EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()).WillOnce([]() {
return std::make_unique<TaggingEncrypter>(0x03);
});
+ EXPECT_CALL(visitor_, OnKeyUpdate(KeyUpdateReason::kLocalForTests));
EXPECT_TRUE(connection_.InitiateKeyUpdate(KeyUpdateReason::kLocalForTests));
// Pretend that peer accepts the key update.
@@ -12166,6 +12168,7 @@
EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()).WillOnce([]() {
return std::make_unique<TaggingEncrypter>(0x04);
});
+ EXPECT_CALL(visitor_, OnKeyUpdate(KeyUpdateReason::kLocalForTests));
EXPECT_TRUE(connection_.InitiateKeyUpdate(KeyUpdateReason::kLocalForTests));
EXPECT_FALSE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet());
}
@@ -12230,6 +12233,8 @@
.WillOnce([current_tag]() {
return std::make_unique<TaggingEncrypter>(current_tag);
});
+ EXPECT_CALL(visitor_,
+ OnKeyUpdate(KeyUpdateReason::kLocalKeyUpdateLimitOverride));
}
// Send packet.
QuicPacketNumber last_packet;
diff --git a/quic/core/quic_session.h b/quic/core/quic_session.h
index 8fca9bd..d162d59 100644
--- a/quic/core/quic_session.h
+++ b/quic/core/quic_session.h
@@ -137,6 +137,7 @@
void OnPacketDecrypted(EncryptionLevel level) override;
void OnOneRttPacketAcknowledged() override;
void OnHandshakePacketSent() override;
+ void OnKeyUpdate(KeyUpdateReason /*reason*/) override {}
std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
override;
std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override;
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h
index e577174..cff9d1a 100644
--- a/quic/test_tools/quic_test_utils.h
+++ b/quic/test_tools/quic_test_utils.h
@@ -581,6 +581,7 @@
MOCK_METHOD(void, OnPacketDecrypted, (EncryptionLevel), (override));
MOCK_METHOD(void, OnOneRttPacketAcknowledged, (), (override));
MOCK_METHOD(void, OnHandshakePacketSent, (), (override));
+ MOCK_METHOD(void, OnKeyUpdate, (KeyUpdateReason), (override));
MOCK_METHOD(std::unique_ptr<QuicDecrypter>,
AdvanceKeysAndCreateCurrentOneRttDecrypter,
(),
diff --git a/quic/test_tools/simulator/quic_endpoint.h b/quic/test_tools/simulator/quic_endpoint.h
index 7f5bef2..681cd33 100644
--- a/quic/test_tools/simulator/quic_endpoint.h
+++ b/quic/test_tools/simulator/quic_endpoint.h
@@ -94,6 +94,7 @@
void OnPacketDecrypted(EncryptionLevel /*level*/) override {}
void OnOneRttPacketAcknowledged() override {}
void OnHandshakePacketSent() override {}
+ void OnKeyUpdate(KeyUpdateReason /*reason*/) override {}
std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter()
override {
return nullptr;