Retransmit initial data immediately upon receiving RETRY.

This change is IETF QUIC only and client only.

Protected by quic_reloadable_flag_quic_retransmit_after_receiving_retry.

PiperOrigin-RevId: 351231997
Change-Id: I89aa480eed3a123cfac062a39a9c578750f3d709
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc
index 0a093f1..5f7b67e 100644
--- a/quic/core/quic_connection.cc
+++ b/quic/core/quic_connection.cc
@@ -957,6 +957,10 @@
 
   // Reinstall initial crypters because the connection ID changed.
   InstallInitialCrypters(server_connection_id_);
+
+  if (GetQuicReloadableFlag(quic_retransmit_after_receiving_retry)) {
+    sent_packet_manager_.MarkInitialPacketsForRetransmission();
+  }
 }
 
 bool QuicConnection::HasIncomingConnectionId(QuicConnectionId connection_id) {
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc
index dc434a5..6582b02 100644
--- a/quic/core/quic_connection_test.cc
+++ b/quic/core/quic_connection_test.cc
@@ -9976,6 +9976,9 @@
   // Make sure the connection uses the connection ID from the test vectors,
   QuicConnectionPeer::SetServerConnectionId(&connection_,
                                             original_connection_id);
+  // Make sure our fake framer has the new post-retry INITIAL keys so that any
+  // retransmission triggered by retry can be decrypted.
+  writer_->framer()->framer()->SetInitialObfuscators(new_connection_id);
 
   // Process the RETRY packet.
   connection_.ProcessUdpPacket(
@@ -9998,8 +10001,6 @@
   EXPECT_EQ(QuicPacketCreatorPeer::GetRetryToken(
                 QuicConnectionPeer::GetPacketCreator(&connection_)),
             retry_token);
-  // Make sure our fake framer has the new post-retry INITIAL keys.
-  writer_->framer()->framer()->SetInitialObfuscators(new_connection_id);
 
   // Test validating the original_connection_id from the config.
   QuicConfig received_config;
@@ -10096,6 +10097,50 @@
                           /*wrong_retry_id_in_config=*/true);
 }
 
+TEST_P(QuicConnectionTest, ClientRetransmitsInitialPacketsOnRetry) {
+  SetQuicReloadableFlag(quic_retransmit_after_receiving_retry, true);
+  if (!connection_.version().HasIetfQuicFrames()) {
+    // TestClientRetryHandling() currently only supports IETF draft versions.
+    return;
+  }
+  connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL);
+
+  connection_.SendCryptoStreamData();
+
+  EXPECT_EQ(1u, writer_->packets_write_attempts());
+  TestClientRetryHandling(/*invalid_retry_tag=*/false,
+                          /*missing_original_id_in_config=*/false,
+                          /*wrong_original_id_in_config=*/false,
+                          /*missing_retry_id_in_config=*/false,
+                          /*wrong_retry_id_in_config=*/false);
+
+  // Verify that initial data is retransmitted immediately after receiving
+  // RETRY.
+  if (GetParam().ack_response == AckResponse::kImmediate) {
+    EXPECT_EQ(2u, writer_->packets_write_attempts());
+    EXPECT_EQ(1u, writer_->framer()->crypto_frames().size());
+  }
+}
+
+TEST_P(QuicConnectionTest, NoInitialPacketsRetransmissionOnInvalidRetry) {
+  SetQuicReloadableFlag(quic_retransmit_after_receiving_retry, true);
+  if (!connection_.version().HasIetfQuicFrames()) {
+    return;
+  }
+  connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL);
+
+  connection_.SendCryptoStreamData();
+
+  EXPECT_EQ(1u, writer_->packets_write_attempts());
+  TestClientRetryHandling(/*invalid_retry_tag=*/true,
+                          /*missing_original_id_in_config=*/false,
+                          /*wrong_original_id_in_config=*/false,
+                          /*missing_retry_id_in_config=*/false,
+                          /*wrong_retry_id_in_config=*/false);
+
+  EXPECT_EQ(1u, writer_->packets_write_attempts());
+}
+
 TEST_P(QuicConnectionTest, ClientReceivesOriginalConnectionIdWithoutRetry) {
   if (!connection_.version().UsesTls()) {
     // QUIC+TLS is required to transmit connection ID transport parameters.
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h
index 8a1a48a..4ec0f4d 100644
--- a/quic/core/quic_flags_list.h
+++ b/quic/core/quic_flags_list.h
@@ -51,6 +51,7 @@
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_parse_accept_ch_frame, true)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_pass_path_response_to_validator, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_require_handshake_confirmation, false)
+QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_retransmit_after_receiving_retry, true)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_round_up_tiny_bandwidth, true)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_send_goaway_with_connection_close, true)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_send_path_response, false)
diff --git a/quic/core/quic_sent_packet_manager.cc b/quic/core/quic_sent_packet_manager.cc
index 08b1371..475edab 100644
--- a/quic/core/quic_sent_packet_manager.cc
+++ b/quic/core/quic_sent_packet_manager.cc
@@ -493,6 +493,42 @@
   }
 }
 
+void QuicSentPacketManager::MarkInitialPacketsForRetransmission() {
+  if (unacked_packets_.use_circular_deque()) {
+    if (unacked_packets_.empty()) {
+      return;
+    }
+    QuicPacketNumber packet_number = unacked_packets_.GetLeastUnacked();
+    QuicPacketNumber largest_sent_packet =
+        unacked_packets_.largest_sent_packet();
+    for (; packet_number <= largest_sent_packet; ++packet_number) {
+      QuicTransmissionInfo* transmission_info =
+          unacked_packets_.GetMutableTransmissionInfo(packet_number);
+      if (transmission_info->encryption_level == ENCRYPTION_INITIAL) {
+        if (transmission_info->in_flight) {
+          unacked_packets_.RemoveFromInFlight(transmission_info);
+        }
+        if (unacked_packets_.HasRetransmittableFrames(*transmission_info)) {
+          MarkForRetransmission(packet_number, ALL_INITIAL_RETRANSMISSION);
+        }
+      }
+    }
+  } else {
+    QuicPacketNumber packet_number = unacked_packets_.GetLeastUnacked();
+    for (QuicUnackedPacketMap::iterator it = unacked_packets_.begin();
+         it != unacked_packets_.end(); ++it, ++packet_number) {
+      if (it->encryption_level == ENCRYPTION_INITIAL) {
+        if (it->in_flight) {
+          unacked_packets_.RemoveFromInFlight(&*it);
+        }
+        if (unacked_packets_.HasRetransmittableFrames(*it)) {
+          MarkForRetransmission(packet_number, ALL_INITIAL_RETRANSMISSION);
+        }
+      }
+    }
+  }
+}
+
 void QuicSentPacketManager::MarkZeroRttPacketsForRetransmission() {
   if (unacked_packets_.use_circular_deque()) {
     if (unacked_packets_.empty()) {
diff --git a/quic/core/quic_sent_packet_manager.h b/quic/core/quic_sent_packet_manager.h
index a32742f..96010d1 100644
--- a/quic/core/quic_sent_packet_manager.h
+++ b/quic/core/quic_sent_packet_manager.h
@@ -155,6 +155,9 @@
   // data needs to be encrypted with a new key.
   void MarkZeroRttPacketsForRetransmission();
 
+  // Request retransmission of all unacked INITIAL packets.
+  void MarkInitialPacketsForRetransmission();
+
   // Notify the sent packet manager of an external network measurement or
   // prediction for either |bandwidth| or |rtt|; either can be empty.
   void AdjustNetworkParameters(
diff --git a/quic/core/quic_sent_packet_manager_test.cc b/quic/core/quic_sent_packet_manager_test.cc
index 968465c..52c0fe6 100644
--- a/quic/core/quic_sent_packet_manager_test.cc
+++ b/quic/core/quic_sent_packet_manager_test.cc
@@ -3792,6 +3792,15 @@
   manager_.NeuterUnencryptedPackets();
 }
 
+TEST_F(QuicSentPacketManagerTest, MarkInitialPacketsForRetransmission) {
+  SetQuicReloadableFlag(quic_retransmit_after_receiving_retry, true);
+  SendCryptoPacket(1);
+  SendPingPacket(2, ENCRYPTION_HANDSHAKE);
+  // Only the INITIAL packet will be retransmitted.
+  EXPECT_CALL(notifier_, OnFrameLost(_)).Times(1);
+  manager_.MarkInitialPacketsForRetransmission();
+}
+
 TEST_F(QuicSentPacketManagerTest, NoPacketThresholdDetectionForRuntPackets) {
   EXPECT_TRUE(
       QuicSentPacketManagerPeer::UsePacketThresholdForRuntPackets(&manager_));
diff --git a/quic/core/quic_types.cc b/quic/core/quic_types.cc
index 372f953..1b33d21 100644
--- a/quic/core/quic_types.cc
+++ b/quic/core/quic_types.cc
@@ -214,6 +214,7 @@
     RETURN_STRING_LITERAL(TLP_RETRANSMISSION);
     RETURN_STRING_LITERAL(PTO_RETRANSMISSION);
     RETURN_STRING_LITERAL(PROBING_RETRANSMISSION);
+    RETURN_STRING_LITERAL(ALL_INITIAL_RETRANSMISSION);
     default:
       // Some varz rely on this behavior for statistic collection.
       if (transmission_type == LAST_TRANSMISSION_TYPE + 1) {
diff --git a/quic/core/quic_types.h b/quic/core/quic_types.h
index 2d89ce6..57075d6 100644
--- a/quic/core/quic_types.h
+++ b/quic/core/quic_types.h
@@ -173,6 +173,8 @@
   TLP_RETRANSMISSION,           // Tail loss probes.
   PTO_RETRANSMISSION,           // Retransmission due to probe timeout.
   PROBING_RETRANSMISSION,       // Retransmission in order to probe bandwidth.
+  ALL_INITIAL_RETRANSMISSION,   // Retransmit all packets encrypted with INITIAL
+                                // key.
   LAST_TRANSMISSION_TYPE = PROBING_RETRANSMISSION,
 };
 
diff --git a/quic/core/quic_utils.cc b/quic/core/quic_utils.cc
index 5d4d512..dc9c879 100644
--- a/quic/core/quic_utils.cc
+++ b/quic/core/quic_utils.cc
@@ -351,6 +351,8 @@
       return PTO_RETRANSMITTED;
     case PROBING_RETRANSMISSION:
       return PROBE_RETRANSMITTED;
+    case ALL_INITIAL_RETRANSMISSION:
+      return UNACKABLE;
     default:
       QUIC_BUG << retransmission_type << " is not a retransmission_type";
       return UNACKABLE;