gfe-relnote: In QUIC with TLS, notify the handshaker when the connection is closed and skip SetWriteSecret once the connection has been closed. Protected by gfe2_reloadable_flag_quic_notify_handshaker_on_connection_close.

PiperOrigin-RevId: 309753183
Change-Id: I829b92d82ca84f85ab60aa09940a3b205641a34b
diff --git a/quic/core/http/end_to_end_test.cc b/quic/core/http/end_to_end_test.cc
index a692c94..af52b12 100644
--- a/quic/core/http/end_to_end_test.cc
+++ b/quic/core/http/end_to_end_test.cc
@@ -3522,9 +3522,11 @@
 // Regression test for b/116200989.
 TEST_P(EndToEndTest,
        SendStatelessResetIfServerConnectionClosedLocallyDuringHandshake) {
-  if (version_.UsesTls()) {
-    // TODO(b/155317149): Enable this test for TLS.
-    Initialize();
+  // For TLS versions this test depends on the handshaker being notified of
+  // connection close, which is protected by the flag below.
+  if (version_.UsesTls() &&
+      !GetQuicReloadableFlag(quic_notify_handshaker_on_connection_close)) {
+    ASSERT_TRUE(Initialize());
     return;
   }
   connect_to_server_on_initialize_ = false;
diff --git a/quic/core/http/quic_spdy_stream_test.cc b/quic/core/http/quic_spdy_stream_test.cc
index 8a4dfba..d400fc7 100644
--- a/quic/core/http/quic_spdy_stream_test.cc
+++ b/quic/core/http/quic_spdy_stream_test.cc
@@ -125,6 +125,8 @@
   void OnPacketDecrypted(EncryptionLevel /*level*/) override {}
   void OnOneRttPacketAcknowledged() override {}
   void OnHandshakePacketSent() override {}
+  void OnConnectionClosed(QuicErrorCode /*error*/,  // Stub: no-op.
+                          ConnectionCloseSource /*source*/) override {}
   void OnHandshakeDoneReceived() override {}
 
   MOCK_METHOD(void, OnCanWrite, (), (override));
diff --git a/quic/core/quic_crypto_client_handshaker.h b/quic/core/quic_crypto_client_handshaker.h
index 40b490e..5ba93f2 100644
--- a/quic/core/quic_crypto_client_handshaker.h
+++ b/quic/core/quic_crypto_client_handshaker.h
@@ -52,6 +52,8 @@
   size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
   void OnOneRttPacketAcknowledged() override {}
   void OnHandshakePacketSent() override {}
+  void OnConnectionClosed(QuicErrorCode /*error*/,  // No-op on close.
+                          ConnectionCloseSource /*source*/) override {}
   void OnHandshakeDoneReceived() override;
   void OnApplicationState(
       std::unique_ptr<ApplicationState> /*application_state*/) override {
diff --git a/quic/core/quic_crypto_client_stream.cc b/quic/core/quic_crypto_client_stream.cc
index 532a982..36dc4cd 100644
--- a/quic/core/quic_crypto_client_stream.cc
+++ b/quic/core/quic_crypto_client_stream.cc
@@ -117,6 +117,11 @@
   handshaker_->OnHandshakePacketSent();
 }
 
+void QuicCryptoClientStream::OnConnectionClosed(QuicErrorCode error,
+                                                ConnectionCloseSource source) {
+  handshaker_->OnConnectionClosed(error, source);  // Forward to handshaker_.
+}
+
 void QuicCryptoClientStream::OnHandshakeDoneReceived() {
   handshaker_->OnHandshakeDoneReceived();
 }
diff --git a/quic/core/quic_crypto_client_stream.h b/quic/core/quic_crypto_client_stream.h
index 82b0ef9..23f83c7 100644
--- a/quic/core/quic_crypto_client_stream.h
+++ b/quic/core/quic_crypto_client_stream.h
@@ -159,6 +159,10 @@
     // Called when a packet of ENCRYPTION_HANDSHAKE gets sent.
     virtual void OnHandshakePacketSent() = 0;
 
+    // Called when the connection is closed, with the close error and source.
+    virtual void OnConnectionClosed(QuicErrorCode error,
+                                    ConnectionCloseSource source) = 0;
+
     // Called when handshake done has been received.
     virtual void OnHandshakeDoneReceived() = 0;
 
@@ -215,6 +219,8 @@
   void OnPacketDecrypted(EncryptionLevel /*level*/) override {}
   void OnOneRttPacketAcknowledged() override;
   void OnHandshakePacketSent() override;
+  void OnConnectionClosed(QuicErrorCode error,  // Forwards to handshaker_.
+                          ConnectionCloseSource source) override;
   void OnHandshakeDoneReceived() override;
   HandshakeState GetHandshakeState() const override;
   size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
diff --git a/quic/core/quic_crypto_server_stream.h b/quic/core/quic_crypto_server_stream.h
index 8eef07c..52d4874 100644
--- a/quic/core/quic_crypto_server_stream.h
+++ b/quic/core/quic_crypto_server_stream.h
@@ -43,6 +43,8 @@
   void OnPacketDecrypted(EncryptionLevel level) override;
   void OnOneRttPacketAcknowledged() override {}
   void OnHandshakePacketSent() override {}
+  void OnConnectionClosed(QuicErrorCode /*error*/,  // No-op on close.
+                          ConnectionCloseSource /*source*/) override {}
   void OnHandshakeDoneReceived() override;
   bool ShouldSendExpectCTHeader() const override;
 
diff --git a/quic/core/quic_session.cc b/quic/core/quic_session.cc
index 99d1ff4..29dfb93 100644
--- a/quic/core/quic_session.cc
+++ b/quic/core/quic_session.cc
@@ -398,6 +398,11 @@
     on_closed_frame_ = frame;
   }
 
+  if (GetQuicReloadableFlag(quic_notify_handshaker_on_connection_close)) {
+    QUIC_RELOADABLE_FLAG_COUNT(quic_notify_handshaker_on_connection_close);
+    GetMutableCryptoStream()->OnConnectionClosed(frame.quic_error_code, source);
+  }  // Notified before non-static streams are deleted below.
+
   // Copy all non static streams in a new map for the ease of deleting.
   QuicSmallMap<QuicStreamId, QuicStream*, 10> non_static_streams;
   for (const auto& it : stream_map_) {
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index d6435ff..324c378 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -305,6 +305,12 @@
   initial_keys_dropped_ = true;
 }
 
+// Records the close in state_ so later SetWriteSecret calls return early.
+void TlsClientHandshaker::OnConnectionClosed(QuicErrorCode /*error*/,
+                                             ConnectionCloseSource /*source*/) {
+  state_ = STATE_CONNECTION_CLOSED;
+}
+
 void TlsClientHandshaker::OnHandshakeDoneReceived() {
   if (!one_rtt_keys_available_) {
     CloseConnection(QUIC_HANDSHAKE_FAILED,
@@ -318,6 +324,12 @@
     EncryptionLevel level,
     const SSL_CIPHER* cipher,
     const std::vector<uint8_t>& write_secret) {
+  // Do not proceed with a new write secret once the connection is closed.
+  // Flag-gated like the server handshaker, so flag-off behavior is unchanged.
+  if (GetQuicReloadableFlag(quic_notify_handshaker_on_connection_close) &&
+      state_ == STATE_CONNECTION_CLOSED) {
+    return;
+  }
   if (level == ENCRYPTION_FORWARD_SECURE) {
     encryption_established_ = true;
   }
diff --git a/quic/core/tls_client_handshaker.h b/quic/core/tls_client_handshaker.h
index 91b8be8..fdd68c2 100644
--- a/quic/core/tls_client_handshaker.h
+++ b/quic/core/tls_client_handshaker.h
@@ -60,6 +60,8 @@
   size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
   void OnOneRttPacketAcknowledged() override;
   void OnHandshakePacketSent() override;
+  void OnConnectionClosed(QuicErrorCode error,  // Records close in state_.
+                          ConnectionCloseSource source) override;
   void OnHandshakeDoneReceived() override;
   void SetWriteSecret(EncryptionLevel level,
                       const SSL_CIPHER* cipher,
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index baecf70..890abae 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -162,6 +162,11 @@
   return false;
 }
 
+void TlsServerHandshaker::OnConnectionClosed(QuicErrorCode /*error*/,
+                                             ConnectionCloseSource /*source*/) {
+  state_ = STATE_CONNECTION_CLOSED;  // Checked by SetWriteSecret.
+}
+
 bool TlsServerHandshaker::encryption_established() const {
   return encryption_established_;
 }
@@ -326,6 +331,10 @@
     EncryptionLevel level,
     const SSL_CIPHER* cipher,
     const std::vector<uint8_t>& write_secret) {
+  if (GetQuicReloadableFlag(quic_notify_handshaker_on_connection_close) &&
+      state_ == STATE_CONNECTION_CLOSED) {
+    return;  // Connection already closed: skip the new write secret.
+  }
   if (level == ENCRYPTION_FORWARD_SECURE) {
     encryption_established_ = true;
     // Fill crypto_negotiated_params_:
diff --git a/quic/core/tls_server_handshaker.h b/quic/core/tls_server_handshaker.h
index 28fa121..c62dbbb 100644
--- a/quic/core/tls_server_handshaker.h
+++ b/quic/core/tls_server_handshaker.h
@@ -48,6 +48,8 @@
   void OnPacketDecrypted(EncryptionLevel level) override;
   void OnOneRttPacketAcknowledged() override {}
   void OnHandshakePacketSent() override {}
+  void OnConnectionClosed(QuicErrorCode error,  // Records close in state_.
+                          ConnectionCloseSource source) override;
   void OnHandshakeDoneReceived() override;
   bool ShouldSendExpectCTHeader() const override;