gfe-relnote: In QUIC with TLS, do not proceed in SetWriteSecret if connection has been closed. Protected by gfe2_reloadable_flag_quic_notify_handshaker_on_connection_close.
PiperOrigin-RevId: 309753183
Change-Id: I829b92d82ca84f85ab60aa09940a3b205641a34b
diff --git a/quic/core/http/end_to_end_test.cc b/quic/core/http/end_to_end_test.cc
index a692c94..af52b12 100644
--- a/quic/core/http/end_to_end_test.cc
+++ b/quic/core/http/end_to_end_test.cc
@@ -3522,9 +3522,10 @@
// Regression test for b/116200989.
TEST_P(EndToEndTest,
SendStatelessResetIfServerConnectionClosedLocallyDuringHandshake) {
- if (version_.UsesTls()) {
- // TODO(b/155317149): Enable this test for TLS.
- Initialize();
+ if (version_.UsesTls() &&
+ !GetQuicReloadableFlag(quic_notify_handshaker_on_connection_close)) {
+ // TLS handshakers are only notified of connection close when the flag is on.
+ ASSERT_TRUE(Initialize());
return;
}
connect_to_server_on_initialize_ = false;
diff --git a/quic/core/http/quic_spdy_stream_test.cc b/quic/core/http/quic_spdy_stream_test.cc
index 8a4dfba..d400fc7 100644
--- a/quic/core/http/quic_spdy_stream_test.cc
+++ b/quic/core/http/quic_spdy_stream_test.cc
@@ -125,6 +125,8 @@
void OnPacketDecrypted(EncryptionLevel /*level*/) override {}
void OnOneRttPacketAcknowledged() override {}
void OnHandshakePacketSent() override {}
+ void OnConnectionClosed(QuicErrorCode /*error*/,
+ ConnectionCloseSource /*source*/) override {}  // No-op in this test double.
void OnHandshakeDoneReceived() override {}
MOCK_METHOD(void, OnCanWrite, (), (override));
diff --git a/quic/core/quic_crypto_client_handshaker.h b/quic/core/quic_crypto_client_handshaker.h
index 40b490e..5ba93f2 100644
--- a/quic/core/quic_crypto_client_handshaker.h
+++ b/quic/core/quic_crypto_client_handshaker.h
@@ -52,6 +52,8 @@
size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
void OnOneRttPacketAcknowledged() override {}
void OnHandshakePacketSent() override {}
+ void OnConnectionClosed(QuicErrorCode /*error*/,
+ ConnectionCloseSource /*source*/) override {}  // Intentionally a no-op.
void OnHandshakeDoneReceived() override;
void OnApplicationState(
std::unique_ptr<ApplicationState> /*application_state*/) override {
diff --git a/quic/core/quic_crypto_client_stream.cc b/quic/core/quic_crypto_client_stream.cc
index 532a982..36dc4cd 100644
--- a/quic/core/quic_crypto_client_stream.cc
+++ b/quic/core/quic_crypto_client_stream.cc
@@ -117,6 +117,11 @@
handshaker_->OnHandshakePacketSent();
}
+void QuicCryptoClientStream::OnConnectionClosed(QuicErrorCode error,
+ ConnectionCloseSource source) {
+ handshaker_->OnConnectionClosed(error, source);  // Delegate to handshaker.
+}
+
void QuicCryptoClientStream::OnHandshakeDoneReceived() {
handshaker_->OnHandshakeDoneReceived();
}
diff --git a/quic/core/quic_crypto_client_stream.h b/quic/core/quic_crypto_client_stream.h
index 82b0ef9..23f83c7 100644
--- a/quic/core/quic_crypto_client_stream.h
+++ b/quic/core/quic_crypto_client_stream.h
@@ -159,6 +159,10 @@
// Called when a packet of ENCRYPTION_HANDSHAKE gets sent.
virtual void OnHandshakePacketSent() = 0;
+ // Called when the connection gets closed, with the close error and source.
+ virtual void OnConnectionClosed(QuicErrorCode error,
+ ConnectionCloseSource source) = 0;
+
// Called when handshake done has been received.
virtual void OnHandshakeDoneReceived() = 0;
@@ -215,6 +219,8 @@
void OnPacketDecrypted(EncryptionLevel /*level*/) override {}
void OnOneRttPacketAcknowledged() override;
void OnHandshakePacketSent() override;
+ void OnConnectionClosed(QuicErrorCode error,
+ ConnectionCloseSource source) override;
void OnHandshakeDoneReceived() override;
HandshakeState GetHandshakeState() const override;
size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
diff --git a/quic/core/quic_crypto_server_stream.h b/quic/core/quic_crypto_server_stream.h
index 8eef07c..52d4874 100644
--- a/quic/core/quic_crypto_server_stream.h
+++ b/quic/core/quic_crypto_server_stream.h
@@ -43,6 +43,8 @@
void OnPacketDecrypted(EncryptionLevel level) override;
void OnOneRttPacketAcknowledged() override {}
void OnHandshakePacketSent() override {}
+ void OnConnectionClosed(QuicErrorCode /*error*/,
+ ConnectionCloseSource /*source*/) override {}  // Intentionally a no-op.
void OnHandshakeDoneReceived() override;
bool ShouldSendExpectCTHeader() const override;
diff --git a/quic/core/quic_session.cc b/quic/core/quic_session.cc
index 99d1ff4..29dfb93 100644
--- a/quic/core/quic_session.cc
+++ b/quic/core/quic_session.cc
@@ -398,6 +398,11 @@
on_closed_frame_ = frame;
}
+ if (GetQuicReloadableFlag(quic_notify_handshaker_on_connection_close)) {
+ QUIC_RELOADABLE_FLAG_COUNT(quic_notify_handshaker_on_connection_close);
+ GetMutableCryptoStream()->OnConnectionClosed(frame.quic_error_code, source);
+ }  // Notify the crypto stream before non-static streams are deleted.
+
// Copy all non static streams in a new map for the ease of deleting.
QuicSmallMap<QuicStreamId, QuicStream*, 10> non_static_streams;
for (const auto& it : stream_map_) {
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index d6435ff..324c378 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -305,6 +305,11 @@
initial_keys_dropped_ = true;
}
+void TlsClientHandshaker::OnConnectionClosed(QuicErrorCode /*error*/,
+ ConnectionCloseSource /*source*/) {
+ state_ = STATE_CONNECTION_CLOSED;
+}
+
void TlsClientHandshaker::OnHandshakeDoneReceived() {
if (!one_rtt_keys_available_) {
CloseConnection(QUIC_HANDSHAKE_FAILED,
@@ -318,6 +323,10 @@
EncryptionLevel level,
const SSL_CIPHER* cipher,
const std::vector<uint8_t>& write_secret) {
+ if (GetQuicReloadableFlag(quic_notify_handshaker_on_connection_close) &&
+ state_ == STATE_CONNECTION_CLOSED) {
+ return;  // Do not set write secrets once the connection is closed.
+ }
if (level == ENCRYPTION_FORWARD_SECURE) {
encryption_established_ = true;
}
diff --git a/quic/core/tls_client_handshaker.h b/quic/core/tls_client_handshaker.h
index 91b8be8..fdd68c2 100644
--- a/quic/core/tls_client_handshaker.h
+++ b/quic/core/tls_client_handshaker.h
@@ -60,6 +60,8 @@
size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
void OnOneRttPacketAcknowledged() override;
void OnHandshakePacketSent() override;
+ void OnConnectionClosed(QuicErrorCode error,
+ ConnectionCloseSource source) override;  // Records the close.
void OnHandshakeDoneReceived() override;
void SetWriteSecret(EncryptionLevel level,
const SSL_CIPHER* cipher,
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index baecf70..890abae 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -162,6 +162,11 @@
return false;
}
+void TlsServerHandshaker::OnConnectionClosed(QuicErrorCode /*error*/,
+ ConnectionCloseSource /*source*/) {
+ state_ = STATE_CONNECTION_CLOSED;
+}
+
bool TlsServerHandshaker::encryption_established() const {
return encryption_established_;
}
@@ -326,6 +331,10 @@
EncryptionLevel level,
const SSL_CIPHER* cipher,
const std::vector<uint8_t>& write_secret) {
+ if (GetQuicReloadableFlag(quic_notify_handshaker_on_connection_close) &&
+ state_ == STATE_CONNECTION_CLOSED) {
+ return;  // Do not set write secrets once the connection is closed.
+ }
if (level == ENCRYPTION_FORWARD_SECURE) {
encryption_established_ = true;
// Fill crypto_negotiated_params_:
diff --git a/quic/core/tls_server_handshaker.h b/quic/core/tls_server_handshaker.h
index 28fa121..c62dbbb 100644
--- a/quic/core/tls_server_handshaker.h
+++ b/quic/core/tls_server_handshaker.h
@@ -48,6 +48,8 @@
void OnPacketDecrypted(EncryptionLevel level) override;
void OnOneRttPacketAcknowledged() override {}
void OnHandshakePacketSent() override {}
+ void OnConnectionClosed(QuicErrorCode error,
+ ConnectionCloseSource source) override;  // Records the close.
void OnHandshakeDoneReceived() override;
bool ShouldSendExpectCTHeader() const override;