Add QuicConnectionVisitorInterface::OnKeyUpdate upcall. PiperOrigin-RevId: 337897975 Change-Id: I550a9fd8f290f24f99cf80ee2ed3ecbc86f68879
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc index 0aa7a4d..a19eae0 100644 --- a/quic/core/quic_connection.cc +++ b/quic/core/quic_connection.cc
@@ -1943,6 +1943,8 @@ // discard_previous_one_rtt_keys_alarm_ hasn't fired yet, cancel it since the // old keys would already be discarded. discard_previous_one_rtt_keys_alarm_->Cancel(); + + visitor_->OnKeyUpdate(reason); } void QuicConnection::OnDecryptedFirstPacketInKeyPhase() {
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h index 2fc8c8a..73c6706 100644 --- a/quic/core/quic_connection.h +++ b/quic/core/quic_connection.h
@@ -196,6 +196,9 @@ // Called when a packet of ENCRYPTION_HANDSHAKE gets sent. virtual void OnHandshakePacketSent() = 0; + // Called when a key update has occurred. + virtual void OnKeyUpdate(KeyUpdateReason reason) = 0; + // Called to generate a decrypter for the next key phase. Each call should // generate the key for phase n+1. virtual std::unique_ptr<QuicDecrypter>
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc index 483ddb1..fd84b30 100644 --- a/quic/core/quic_connection_test.cc +++ b/quic/core/quic_connection_test.cc
@@ -12096,6 +12096,7 @@ EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()).WillOnce([]() { return std::make_unique<TaggingEncrypter>(0x02); }); + EXPECT_CALL(visitor_, OnKeyUpdate(KeyUpdateReason::kLocalForTests)); EXPECT_TRUE(connection_.InitiateKeyUpdate(KeyUpdateReason::kLocalForTests)); // discard_previous_keys_alarm_ should not be set until a packet from the new // key phase has been received. (The alarm that was set above should be @@ -12132,6 +12133,7 @@ EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()).WillOnce([]() { return std::make_unique<TaggingEncrypter>(0x03); }); + EXPECT_CALL(visitor_, OnKeyUpdate(KeyUpdateReason::kLocalForTests)); EXPECT_TRUE(connection_.InitiateKeyUpdate(KeyUpdateReason::kLocalForTests)); // Pretend that peer accepts the key update. @@ -12166,6 +12168,7 @@ EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()).WillOnce([]() { return std::make_unique<TaggingEncrypter>(0x04); }); + EXPECT_CALL(visitor_, OnKeyUpdate(KeyUpdateReason::kLocalForTests)); EXPECT_TRUE(connection_.InitiateKeyUpdate(KeyUpdateReason::kLocalForTests)); EXPECT_FALSE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet()); } @@ -12230,6 +12233,8 @@ .WillOnce([current_tag]() { return std::make_unique<TaggingEncrypter>(current_tag); }); + EXPECT_CALL(visitor_, + OnKeyUpdate(KeyUpdateReason::kLocalKeyUpdateLimitOverride)); } // Send packet. QuicPacketNumber last_packet;
diff --git a/quic/core/quic_session.h b/quic/core/quic_session.h index 8fca9bd..d162d59 100644 --- a/quic/core/quic_session.h +++ b/quic/core/quic_session.h
@@ -137,6 +137,7 @@ void OnPacketDecrypted(EncryptionLevel level) override; void OnOneRttPacketAcknowledged() override; void OnHandshakePacketSent() override; + void OnKeyUpdate(KeyUpdateReason /*reason*/) override {} std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() override; std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter() override;
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h index e577174..cff9d1a 100644 --- a/quic/test_tools/quic_test_utils.h +++ b/quic/test_tools/quic_test_utils.h
@@ -581,6 +581,7 @@ MOCK_METHOD(void, OnPacketDecrypted, (EncryptionLevel), (override)); MOCK_METHOD(void, OnOneRttPacketAcknowledged, (), (override)); MOCK_METHOD(void, OnHandshakePacketSent, (), (override)); + MOCK_METHOD(void, OnKeyUpdate, (KeyUpdateReason), (override)); MOCK_METHOD(std::unique_ptr<QuicDecrypter>, AdvanceKeysAndCreateCurrentOneRttDecrypter, (),
diff --git a/quic/test_tools/simulator/quic_endpoint.h b/quic/test_tools/simulator/quic_endpoint.h index 7f5bef2..681cd33 100644 --- a/quic/test_tools/simulator/quic_endpoint.h +++ b/quic/test_tools/simulator/quic_endpoint.h
@@ -94,6 +94,7 @@ void OnPacketDecrypted(EncryptionLevel /*level*/) override {} void OnOneRttPacketAcknowledged() override {} void OnHandshakePacketSent() override {} + void OnKeyUpdate(KeyUpdateReason /*reason*/) override {} std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter() override { return nullptr;