Suppress sending in the middle of packet processing while the connection still has initial keys.

Protected by FLAGS_quic_reloadable_flag_quic_suppress_write_mid_packet_processing.

PiperOrigin-RevId: 393207715
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc
index ce1d5a5..b77bb1b 100644
--- a/quic/core/quic_connection.cc
+++ b/quic/core/quic_connection.cc
@@ -3310,6 +3310,20 @@
     return false;
   }
 
+  if (GetQuicReloadableFlag(quic_suppress_write_mid_packet_processing) &&
+      version().CanSendCoalescedPackets() &&
+      framer_.HasEncrypterOfEncryptionLevel(ENCRYPTION_INITIAL) &&
+      framer_.is_processing_packet()) {
+    QUIC_RELOADABLE_FLAG_COUNT(quic_suppress_write_mid_packet_processing);
+    // While we still have initial keys, suppress sending in mid of packet
+    // processing.
+    // TODO(fayang): always suppress sending while in the mid of packet
+    // processing.
+    QUIC_DVLOG(1) << ENDPOINT
+                  << "Suppress sending in the mid of packet processing";
+    return false;
+  }
+
   if (fill_coalesced_packet_) {
     // Try to coalesce packet, only allow to write when creator is on soft max
     // packet length. Given the next created packet is going to fill current
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc
index 8f5ff90..871e5e8 100644
--- a/quic/core/quic_connection_test.cc
+++ b/quic/core/quic_connection_test.cc
@@ -1451,6 +1451,9 @@
     }
     connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
     peer_creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE);
+    // Discard INITIAL key.
+    connection_.RemoveEncrypter(ENCRYPTION_INITIAL);
+    connection_.NeuterUnencryptedPackets();
     // Prevent packets from being coalesced.
     EXPECT_CALL(visitor_, GetHandshakeState())
         .WillRepeatedly(Return(HANDSHAKE_CONFIRMED));
@@ -10094,12 +10097,13 @@
   EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet());
 
   // Receives packet 1.
+  EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
   ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL);
 
   const size_t anti_amplification_factor =
       GetQuicFlag(FLAGS_quic_anti_amplification_factor);
   // Verify now packets can be sent.
-  for (size_t i = 0; i < anti_amplification_factor; ++i) {
+  for (size_t i = 1; i < anti_amplification_factor; ++i) {
     EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
     connection_.SendCryptoDataWithString("foo", i * 3);
     // Verify retransmission alarm is not set if throttled by anti-amplification
@@ -10112,10 +10116,11 @@
   connection_.SendCryptoDataWithString("foo", anti_amplification_factor * 3);
 
   // Receives packet 2.
+  EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
   ProcessCryptoPacketAtLevel(2, ENCRYPTION_INITIAL);
   // Verify more packets can be sent.
-  for (size_t i = anti_amplification_factor; i < anti_amplification_factor * 2;
-       ++i) {
+  for (size_t i = anti_amplification_factor + 1;
+       i < anti_amplification_factor * 2; ++i) {
     EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
     connection_.SendCryptoDataWithString("foo", i * 3);
   }
@@ -10124,6 +10129,7 @@
   connection_.SendCryptoDataWithString("foo",
                                        2 * anti_amplification_factor * 3);
 
+  EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
   ProcessPacket(3);
   // Verify anti-amplification limit is gone after address validation.
   for (size_t i = 0; i < 100; ++i) {
@@ -10160,11 +10166,12 @@
   EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet());
 
   // Receives packet 1.
+  EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
   ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL);
 
   const size_t anti_amplification_factor = 3;
   // Verify now packets can be sent.
-  for (size_t i = 0; i < anti_amplification_factor; ++i) {
+  for (size_t i = 1; i < anti_amplification_factor; ++i) {
     EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
     connection_.SendCryptoDataWithString("foo", i * 3);
     // Verify retransmission alarm is not set if throttled by anti-amplification
@@ -10177,10 +10184,11 @@
   connection_.SendCryptoDataWithString("foo", anti_amplification_factor * 3);
 
   // Receives packet 2.
+  EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
   ProcessCryptoPacketAtLevel(2, ENCRYPTION_INITIAL);
   // Verify more packets can be sent.
-  for (size_t i = anti_amplification_factor; i < anti_amplification_factor * 2;
-       ++i) {
+  for (size_t i = anti_amplification_factor + 1;
+       i < anti_amplification_factor * 2; ++i) {
     EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
     connection_.SendCryptoDataWithString("foo", i * 3);
   }
@@ -10189,6 +10197,7 @@
   connection_.SendCryptoDataWithString("foo",
                                        2 * anti_amplification_factor * 3);
 
+  EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
   ProcessPacket(3);
   // Verify anti-amplification limit is gone after address validation.
   for (size_t i = 0; i < 100; ++i) {
@@ -10225,11 +10234,12 @@
   EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet());
 
   // Receives packet 1.
+  EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
   ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL);
 
   const size_t anti_amplification_factor = 10;
   // Verify now packets can be sent.
-  for (size_t i = 0; i < anti_amplification_factor; ++i) {
+  for (size_t i = 1; i < anti_amplification_factor; ++i) {
     EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
     connection_.SendCryptoDataWithString("foo", i * 3);
     // Verify retransmission alarm is not set if throttled by anti-amplification
@@ -10242,10 +10252,11 @@
   connection_.SendCryptoDataWithString("foo", anti_amplification_factor * 3);
 
   // Receives packet 2.
+  EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
   ProcessCryptoPacketAtLevel(2, ENCRYPTION_INITIAL);
   // Verify more packets can be sent.
-  for (size_t i = anti_amplification_factor; i < anti_amplification_factor * 2;
-       ++i) {
+  for (size_t i = anti_amplification_factor + 1;
+       i < anti_amplification_factor * 2; ++i) {
     EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
     connection_.SendCryptoDataWithString("foo", i * 3);
   }
@@ -10254,6 +10265,7 @@
   connection_.SendCryptoDataWithString("foo",
                                        2 * anti_amplification_factor * 3);
 
+  EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
   ProcessPacket(3);
   // Verify anti-amplification limit is gone after address validation.
   for (size_t i = 0; i < 100; ++i) {
@@ -11009,6 +11021,11 @@
   EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_));
   EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _));
   connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
+  // Discard INITIAL key.
+  connection_.RemoveEncrypter(ENCRYPTION_INITIAL);
+  connection_.NeuterUnencryptedPackets();
+  EXPECT_CALL(visitor_, GetHandshakeState())
+      .WillRepeatedly(Return(HANDSHAKE_COMPLETE));
 
   ProcessPacket(2);
   ProcessPacket(3);
@@ -11064,7 +11081,7 @@
   connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
 
   EXPECT_CALL(visitor_, OnStreamFrame(_)).WillOnce(Invoke([this]() {
-    connection_.SendControlFrame(QuicFrame(new QuicWindowUpdateFrame(1, 0, 0)));
+    notifier_.WriteOrBufferWindowUpate(0, 0);
   }));
   EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
   ProcessDataPacket(1);
@@ -12786,7 +12803,7 @@
           connection_.SetEncrypter(ENCRYPTION_HANDSHAKE,
                                    std::make_unique<TaggingEncrypter>(0x03));
           connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE);
-          connection_.SendCryptoStreamData();
+          connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE);
           connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE,
                                    std::make_unique<TaggingEncrypter>(0x04));
           connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
@@ -12801,7 +12818,7 @@
   use_tagging_decrypter();
   connection_.SetEncrypter(ENCRYPTION_INITIAL,
                            std::make_unique<TaggingEncrypter>(0x01));
-  connection_.SendCryptoStreamData();
+  connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_INITIAL);
   // Send 0-RTT packet.
   connection_.SetEncrypter(ENCRYPTION_ZERO_RTT,
                            std::make_unique<TaggingEncrypter>(0x02));
@@ -13770,6 +13787,8 @@
   connection_.RemoveEncrypter(ENCRYPTION_INITIAL);
   connection_.NeuterUnencryptedPackets();
   connection_.OnHandshakeComplete();
+  EXPECT_CALL(visitor_, GetHandshakeState())
+      .WillRepeatedly(Return(HANDSHAKE_COMPLETE));
 
   EXPECT_CALL(visitor_, OnStreamFrame(_)).WillOnce(Invoke([=]() {
     connection_.SendStreamData3();
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h
index eba1302..b2940bb 100644
--- a/quic/core/quic_flags_list.h
+++ b/quic/core/quic_flags_list.h
@@ -105,6 +105,8 @@
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_conservative_bursts, false)
 // If true, stop resetting ideal_next_packet_send_time_ in pacing sender.
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_donot_reset_ideal_next_packet_send_time, false)
+// If true, suppress crypto data write in mid of packet processing.
+QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_suppress_write_mid_packet_processing, false)
 // If true, time_wait_list can support multiple connection IDs.
 QUIC_FLAG(FLAGS_quic_restart_flag_quic_time_wait_list_support_multiple_cid_v2, true)
 // If true, update ACK timeout for NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames.
diff --git a/quic/test_tools/simple_session_notifier.cc b/quic/test_tools/simple_session_notifier.cc
index a1383b0..ce498ca 100644
--- a/quic/test_tools/simple_session_notifier.cc
+++ b/quic/test_tools/simple_session_notifier.cc
@@ -111,6 +111,21 @@
   WriteBufferedControlFrames();
 }
 
+void SimpleSessionNotifier::WriteOrBufferWindowUpate(
+    QuicStreamId id, QuicStreamOffset byte_offset) {
+  QUIC_DVLOG(1) << "Writing WINDOW_UPDATE";
+  const bool had_buffered_data =
+      HasBufferedStreamData() || HasBufferedControlFrames();
+  QuicControlFrameId control_frame_id = ++last_control_frame_id_;
+  control_frames_.emplace_back((
+      QuicFrame(new QuicWindowUpdateFrame(control_frame_id, id, byte_offset))));
+  if (had_buffered_data) {
+    QUIC_DLOG(WARNING) << "Connection is write blocked";
+    return;
+  }
+  WriteBufferedControlFrames();
+}
+
 void SimpleSessionNotifier::WriteOrBufferPing() {
   QUIC_DVLOG(1) << "Writing PING_FRAME";
   const bool had_buffered_data =
@@ -175,8 +190,7 @@
       !RetransmitLostStreamData()) {
     return;
   }
-  // Write buffered control frames.
-  if (!WriteBufferedControlFrames()) {
+  if (!WriteBufferedCryptoData() || !WriteBufferedControlFrames()) {
     return;
   }
   // Write new data.
@@ -666,6 +680,26 @@
   return !HasLostStreamData();
 }
 
+bool SimpleSessionNotifier::WriteBufferedCryptoData() {
+  for (size_t i = 0; i < NUM_ENCRYPTION_LEVELS; ++i) {
+    const StreamState& state = crypto_state_[i];
+    QuicIntervalSet<QuicStreamOffset> buffered_crypto_data(0,
+                                                           state.bytes_total);
+    buffered_crypto_data.Difference(crypto_bytes_transferred_[i]);
+    for (const auto& interval : buffered_crypto_data) {
+      size_t bytes_written = connection_->SendCryptoData(
+          static_cast<EncryptionLevel>(i), interval.Length(), interval.min());
+      crypto_state_[i].bytes_sent += bytes_written;
+      crypto_bytes_transferred_[i].Add(interval.min(),
+                                       interval.min() + bytes_written);
+      if (bytes_written < interval.Length()) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
 bool SimpleSessionNotifier::WriteBufferedControlFrames() {
   while (HasBufferedControlFrames()) {
     QuicFrame frame_to_send =
diff --git a/quic/test_tools/simple_session_notifier.h b/quic/test_tools/simple_session_notifier.h
index 8bdadb2..5712884 100644
--- a/quic/test_tools/simple_session_notifier.h
+++ b/quic/test_tools/simple_session_notifier.h
@@ -34,6 +34,10 @@
   void WriteOrBufferRstStream(QuicStreamId id,
                               QuicRstStreamErrorCode error,
                               QuicStreamOffset bytes_written);
+
+  // Tries to write WINDOW_UPDATE.
+  void WriteOrBufferWindowUpate(QuicStreamId id, QuicStreamOffset byte_offset);
+
   // Tries to write PING.
   void WriteOrBufferPing();
 
@@ -127,6 +131,8 @@
 
   bool WriteBufferedControlFrames();
 
+  bool WriteBufferedCryptoData();
+
   bool IsControlFrameOutstanding(const QuicFrame& frame) const;
 
   bool HasBufferedControlFrames() const;