Drop initial keys at the end of writing to avoid potentially missing write keys in the middle of writing.

Protected by FLAGS_quic_reloadable_flag_quic_fix_missing_initial_keys2.

PiperOrigin-RevId: 334170250
Change-Id: I24e118e69c4cba3ddbed57df67340596446fc6d8
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc
index ab32a38..2c791f2 100644
--- a/quic/core/quic_connection_test.cc
+++ b/quic/core/quic_connection_test.cc
@@ -10276,7 +10276,9 @@
                                ? QuicPacketNumber(3)
                                : QuicPacketNumber(4),
                            _, _));
-  EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1);
+  if (!GetQuicReloadableFlag(quic_fix_missing_initial_keys2)) {
+    EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1);
+  }
   connection_.GetRetransmissionAlarm()->Fire();
   // Verify 1-RTT packet gets coalesced with handshake retransmission.
   EXPECT_EQ(0x01010101u, writer_->final_bytes_of_last_packet());
@@ -10298,7 +10300,9 @@
               OnPacketSent(_, _, handshake_retransmission + 1, _, _));
   EXPECT_CALL(*send_algorithm_,
               OnPacketSent(_, _, handshake_retransmission, _, _));
-  EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1);
+  if (!GetQuicReloadableFlag(quic_fix_missing_initial_keys2)) {
+    EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1);
+  }
   connection_.GetRetransmissionAlarm()->Fire();
   // Verify 1-RTT packet gets coalesced with handshake retransmission.
   EXPECT_EQ(0x01010101u, writer_->final_bytes_of_last_packet());
@@ -11078,7 +11082,11 @@
   connection_.SetEncrypter(ENCRYPTION_HANDSHAKE,
                            std::make_unique<TaggingEncrypter>(0x02));
   connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE);
-  EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(2);
+  if (GetQuicReloadableFlag(quic_fix_missing_initial_keys2)) {
+    EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1);
+  } else {
+    EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(2);
+  }
   connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE);
   // Verify PTO time does not change.
   EXPECT_EQ(expected_pto_time,
@@ -11117,8 +11125,11 @@
   // Receives packet 1000 in handshake data.
   ProcessCryptoPacketAtLevel(1000, ENCRYPTION_HANDSHAKE);
   EXPECT_TRUE(connection_.HasPendingAcks());
-
-  EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(2);
+  if (GetQuicReloadableFlag(quic_fix_missing_initial_keys2)) {
+    EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1);
+  } else {
+    EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(2);
+  }
   connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE);
 
   // Receives packet 1001 in handshake data.
@@ -11181,7 +11192,11 @@
   connection_.SetEncrypter(ENCRYPTION_HANDSHAKE,
                            std::make_unique<TaggingEncrypter>(0x02));
   connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE);
-  EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(3);
+  if (GetQuicReloadableFlag(quic_fix_missing_initial_keys2)) {
+    EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1);
+  } else {
+    EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(3);
+  }
   // Send HANDSHAKE 2 and 3.
   connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE);
   connection_.SendCryptoDataWithString("bar", 3, ENCRYPTION_HANDSHAKE);
@@ -11247,7 +11262,11 @@
   connection_.SetEncrypter(ENCRYPTION_HANDSHAKE,
                            std::make_unique<TaggingEncrypter>(0x02));
   connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE);
-  EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(2);
+  if (GetQuicReloadableFlag(quic_fix_missing_initial_keys2)) {
+    EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1);
+  } else {
+    EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(2);
+  }
   std::string handshake_crypto_data(1024, 'a');
   connection_.SendCryptoDataWithString(handshake_crypto_data, 0,
                                        ENCRYPTION_HANDSHAKE);
@@ -11319,7 +11338,11 @@
                            std::make_unique<TaggingEncrypter>(0x02));
   connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE);
   // Verify HANDSHAKE packet is coalesced with INITIAL retransmission.
-  EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(2);
+  if (GetQuicReloadableFlag(quic_fix_missing_initial_keys2)) {
+    EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1);
+  } else {
+    EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(2);
+  }
   std::string handshake_crypto_data(1024, 'a');
   connection_.SendCryptoDataWithString(handshake_crypto_data, 0,
                                        ENCRYPTION_HANDSHAKE);
@@ -11610,7 +11633,7 @@
   SetQuicReloadableFlag(
       quic_neuter_initial_packet_in_coalescer_with_initial_key_discarded, true);
   if (!connection_.version().CanSendCoalescedPackets() ||
-      !GetQuicReloadableFlag(quic_fix_missing_initial_keys)) {
+      !GetQuicReloadableFlag(quic_fix_missing_initial_keys2)) {
     // Cannot set quic_fix_missing_initial_keys in the test since connection_ is
     // created since the setup.
     return;
@@ -11625,15 +11648,17 @@
   connection_.SetEncrypter(ENCRYPTION_HANDSHAKE,
                            std::make_unique<TaggingEncrypter>(0x02));
   connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE);
-  {
-    QuicConnection::ScopedPacketFlusher flusher(&connection_);
-    connection_.GetAckAlarm()->Fire();
-    // Verify this ACK packet is on hold.
-    EXPECT_EQ(0u, writer_->packets_write_attempts());
-
-    // Discard INITIAL key while there is an INITIAL packet in the coalescer.
+  EXPECT_CALL(visitor_, OnHandshakePacketSent()).WillOnce(Invoke([this]() {
     connection_.RemoveEncrypter(ENCRYPTION_INITIAL);
     connection_.NeuterUnencryptedPackets();
+  }));
+  {
+    QuicConnection::ScopedPacketFlusher flusher(&connection_);
+    connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE);
+    // Verify the packet is on hold.
+    EXPECT_EQ(0u, writer_->packets_write_attempts());
+    // Flush pending ACKs.
+    connection_.GetAckAlarm()->Fire();
   }
   // If not setting
   // quic_neuter_initial_packet_in_coalescer_with_initial_key_discarded, there
@@ -12042,7 +12067,9 @@
   connection_.GetSendAlarm()->Set(clock_.ApproximateNow());
 
   // Fire ACK alarm.
-  EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1);
+  if (!GetQuicReloadableFlag(quic_fix_missing_initial_keys2)) {
+    EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1);
+  }
   connection_.GetAckAlarm()->Fire();
   if (GetQuicReloadableFlag(quic_fix_pto_pending_timer_count)) {
     // Verify 1-RTT packet is coalesced with handshake packet.
@@ -12057,7 +12084,9 @@
 
   ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet());
   if (GetQuicReloadableFlag(quic_fix_pto_pending_timer_count)) {
-    EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1);
+    if (!GetQuicReloadableFlag(quic_fix_missing_initial_keys2)) {
+      EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1);
+    }
   } else {
     EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(0);
     EXPECT_CALL(visitor_, SendPing()).WillOnce(Invoke([this]() {
@@ -12080,12 +12109,11 @@
 // Regression test for b/168294218.
 TEST_P(QuicConnectionTest, CoalescerHandlesInitialKeyDiscard) {
   if (!connection_.version().CanSendCoalescedPackets() ||
-      !GetQuicReloadableFlag(quic_fix_missing_initial_keys)) {
+      !GetQuicReloadableFlag(quic_fix_missing_initial_keys2)) {
     return;
   }
   SetQuicReloadableFlag(quic_discard_initial_packet_with_key_dropped, true);
-  // Verify only one HANDSHAKE packet gets sent.
-  EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
+  EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(2);
   EXPECT_CALL(visitor_, OnHandshakePacketSent()).WillOnce(Invoke([this]() {
     connection_.RemoveEncrypter(ENCRYPTION_INITIAL);
     connection_.NeuterUnencryptedPackets();
@@ -12113,7 +12141,7 @@
 // Regresstion test for b/168294218
 TEST_P(QuicConnectionTest, ZeroRttRejectionAndMissingInitialKeys) {
   if (!connection_.SupportsMultiplePacketNumberSpaces() ||
-      !GetQuicReloadableFlag(quic_fix_missing_initial_keys)) {
+      !GetQuicReloadableFlag(quic_fix_missing_initial_keys2)) {
     return;
   }
   // Not defer send in response to packet.