In quic, do not include max_ack_delay when calculating pto timeout for initial and handshake packet number spaces. protected by enabled gfe2_reloadable_flag_quic_fix_pto_timeout.

PiperOrigin-RevId: 318495662
Change-Id: Idce3e16c7f7761bf2519aadf1bb239f69be5494d
diff --git a/quic/core/quic_sent_packet_manager.cc b/quic/core/quic_sent_packet_manager.cc
index 4373013..2b7b11e 100644
--- a/quic/core/quic_sent_packet_manager.cc
+++ b/quic/core/quic_sent_packet_manager.cc
@@ -493,8 +493,15 @@
   }
 }
 
-bool QuicSentPacketManager::ShouldAddMaxAckDelay() const {
+bool QuicSentPacketManager::ShouldAddMaxAckDelay(
+    PacketNumberSpace space) const {
   DCHECK(pto_enabled_);
+  if (fix_pto_timeout_ && supports_multiple_packet_number_spaces() &&
+      space != APPLICATION_DATA) {
+    // When the PTO is armed for Initial or Handshake packet number spaces,
+    // the max_ack_delay is 0.
+    return false;
+  }
   if (always_include_max_ack_delay_for_pto_timeout_) {
     return true;
   }
@@ -1109,7 +1116,7 @@
               clock_->ApproximateNow(),
               std::max(unacked_packets_.GetFirstInFlightTransmissionInfo()
                                ->sent_time +
-                           GetProbeTimeoutDelay(),
+                           GetProbeTimeoutDelay(NUM_PACKET_NUMBER_SPACES),
                        unacked_packets_.GetLastInFlightPacketSentTime() +
                            first_pto_srtt_multiplier_ *
                                rtt_stats_.SmoothedOrInitialRtt()));
@@ -1117,7 +1124,7 @@
         // Ensure PTO never gets set to a time in the past.
         return std::max(clock_->ApproximateNow(),
                         unacked_packets_.GetLastInFlightPacketSentTime() +
-                            GetProbeTimeoutDelay());
+                            GetProbeTimeoutDelay(NUM_PACKET_NUMBER_SPACES));
       }
 
       PacketNumberSpace packet_number_space = NUM_PACKET_NUMBER_SPACES;
@@ -1142,13 +1149,15 @@
           return std::max(
               clock_->ApproximateNow(),
               std::max(
-                  first_application_info->sent_time + GetProbeTimeoutDelay(),
+                  first_application_info->sent_time +
+                      GetProbeTimeoutDelay(packet_number_space),
                   earliest_right_edge + first_pto_srtt_multiplier_ *
                                             rtt_stats_.SmoothedOrInitialRtt()));
         }
       }
-      return std::max(clock_->ApproximateNow(),
-                      earliest_right_edge + GetProbeTimeoutDelay());
+      return std::max(
+          clock_->ApproximateNow(),
+          earliest_right_edge + GetProbeTimeoutDelay(packet_number_space));
     }
   }
   DCHECK(false);
@@ -1227,7 +1236,8 @@
   return retransmission_delay;
 }
 
-const QuicTime::Delta QuicSentPacketManager::GetProbeTimeoutDelay() const {
+const QuicTime::Delta QuicSentPacketManager::GetProbeTimeoutDelay(
+    PacketNumberSpace space) const {
   DCHECK(pto_enabled_);
   if (rtt_stats_.smoothed_rtt().IsZero()) {
     // Respect kMinHandshakeTimeoutMs to avoid a potential amplification attack.
@@ -1243,7 +1253,8 @@
   QuicTime::Delta pto_delay =
       rtt_stats_.smoothed_rtt() +
       std::max(pto_rttvar_multiplier_ * rtt_var, kAlarmGranularity) +
-      (ShouldAddMaxAckDelay() ? peer_max_ack_delay_ : QuicTime::Delta::Zero());
+      (ShouldAddMaxAckDelay(space) ? peer_max_ack_delay_
+                                   : QuicTime::Delta::Zero());
   pto_delay =
       pto_delay * (1 << (consecutive_pto_count_ -
                          std::min(consecutive_pto_count_,
diff --git a/quic/core/quic_sent_packet_manager.h b/quic/core/quic_sent_packet_manager.h
index 72307b3..b69bf5c 100644
--- a/quic/core/quic_sent_packet_manager.h
+++ b/quic/core/quic_sent_packet_manager.h
@@ -439,7 +439,7 @@
   const QuicTime::Delta GetRetransmissionDelay() const;
 
   // Returns the probe timeout.
-  const QuicTime::Delta GetProbeTimeoutDelay() const;
+  const QuicTime::Delta GetProbeTimeoutDelay(PacketNumberSpace space) const;
 
   // Update the RTT if the ack is for the largest acked packet number.
   // Returns true if the rtt was updated.
@@ -505,7 +505,7 @@
 
   // Indicates whether including peer_max_ack_delay_ when calculating PTO
   // timeout.
-  bool ShouldAddMaxAckDelay() const;
+  bool ShouldAddMaxAckDelay(PacketNumberSpace space) const;
 
   // Gets the earliest in flight packet sent time to calculate PTO. Also
   // updates |packet_number_space| if a PTO timer should be armed.
@@ -662,6 +662,8 @@
   // The multiplier for caculating PTO timeout before any RTT sample is
   // available.
   float pto_multiplier_without_rtt_samples_;
+
+  const bool fix_pto_timeout_ = GetQuicReloadableFlag(quic_fix_pto_timeout);
 };
 
 }  // namespace quic
diff --git a/quic/core/quic_sent_packet_manager_test.cc b/quic/core/quic_sent_packet_manager_test.cc
index 97d83f5..7af3063 100644
--- a/quic/core/quic_sent_packet_manager_test.cc
+++ b/quic/core/quic_sent_packet_manager_test.cc
@@ -3256,7 +3256,9 @@
       GetQuicReloadableFlag(quic_default_on_pto) ? 2 : 4;
   QuicTime::Delta expected_pto_delay =
       srtt + pto_rttvar_multiplier * rtt_stats->mean_deviation() +
-      QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs);
+      (GetQuicReloadableFlag(quic_fix_pto_timeout)
+           ? QuicTime::Delta::Zero()
+           : QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs));
   EXPECT_EQ(clock_.Now() + expected_pto_delay,
             manager_.GetRetransmissionTime());
 
@@ -3317,6 +3319,9 @@
   clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10));
   SendDataPacket(7, ENCRYPTION_HANDSHAKE);
   // Verify PTO timeout is now based on packet 6.
+  expected_pto_delay =
+      srtt + pto_rttvar_multiplier * rtt_stats->mean_deviation() +
+      QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs);
   EXPECT_EQ(packet6_sent_time + expected_pto_delay * 2,
             manager_.GetRetransmissionTime());
 
@@ -3348,7 +3353,9 @@
       GetQuicReloadableFlag(quic_default_on_pto) ? 2 : 4;
   QuicTime::Delta expected_pto_delay =
       srtt + pto_rttvar_multiplier * rtt_stats->mean_deviation() +
-      QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs);
+      (GetQuicReloadableFlag(quic_fix_pto_timeout)
+           ? QuicTime::Delta::Zero()
+           : QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs));
   EXPECT_EQ(packet1_sent_time + expected_pto_delay,
             manager_.GetRetransmissionTime());
 
@@ -3381,6 +3388,9 @@
 
   // Discard handshake keys.
   manager_.SetHandshakeConfirmed();
+  expected_pto_delay =
+      srtt + pto_rttvar_multiplier * rtt_stats->mean_deviation() +
+      QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs);
   // Verify PTO timeout is now based on packet 3 as handshake is
   // complete/confirmed.
   EXPECT_EQ(packet3_sent_time + expected_pto_delay,
@@ -3618,7 +3628,9 @@
       GetQuicReloadableFlag(quic_default_on_pto) ? 2 : 4;
   QuicTime::Delta expected_pto_delay =
       srtt + pto_rttvar_multiplier * rtt_stats->mean_deviation() +
-      QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs);
+      (GetQuicReloadableFlag(quic_fix_pto_timeout)
+           ? QuicTime::Delta::Zero()
+           : QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs));
   EXPECT_EQ(packet1_sent_time + expected_pto_delay,
             manager_.GetRetransmissionTime());
 
@@ -3653,6 +3665,9 @@
   manager_.SetHandshakeConfirmed();
   // Verify PTO timeout is now based on packet 3 as handshake is
   // complete/confirmed.
+  expected_pto_delay =
+      srtt + pto_rttvar_multiplier * rtt_stats->mean_deviation() +
+      QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs);
   EXPECT_EQ(packet3_sent_time + expected_pto_delay,
             manager_.GetRetransmissionTime());
 
@@ -3695,7 +3710,9 @@
       GetQuicReloadableFlag(quic_default_on_pto) ? 2 : 4;
   QuicTime::Delta expected_pto_delay =
       srtt + pto_rttvar_multiplier * rtt_stats->mean_deviation() +
-      QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs);
+      (GetQuicReloadableFlag(quic_fix_pto_timeout)
+           ? QuicTime::Delta::Zero()
+           : QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs));
   EXPECT_EQ(packet1_sent_time + expected_pto_delay,
             manager_.GetRetransmissionTime());
 
@@ -3730,6 +3747,9 @@
   manager_.SetHandshakeConfirmed();
   // Verify PTO timeout is now based on packet 3 as handshake is
   // complete/confirmed.
+  expected_pto_delay =
+      srtt + pto_rttvar_multiplier * rtt_stats->mean_deviation() +
+      QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs);
   EXPECT_EQ(packet3_sent_time + expected_pto_delay,
             manager_.GetRetransmissionTime());
 
@@ -3964,7 +3984,9 @@
   const QuicTime::Delta pto_delay =
       rtt_stats->smoothed_rtt() +
       pto_rttvar_multiplier * rtt_stats->mean_deviation() +
-      QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs);
+      (GetQuicReloadableFlag(quic_fix_pto_timeout)
+           ? QuicTime::Delta::Zero()
+           : QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs));
   if (GetQuicReloadableFlag(quic_fix_last_inflight_packets_sent_time)) {
     // Verify PTO is armed based on handshake data.
     EXPECT_EQ(packet2_sent_time + pto_delay, manager_.GetRetransmissionTime());
@@ -4069,7 +4091,9 @@
       GetQuicReloadableFlag(quic_default_on_pto) ? 2 : 4;
   QuicTime::Delta expected_pto_delay =
       srtt + pto_rttvar_multiplier * rtt_stats->mean_deviation() +
-      QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs);
+      (GetQuicReloadableFlag(quic_fix_pto_timeout)
+           ? QuicTime::Delta::Zero()
+           : QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs));
   EXPECT_EQ(packet1_sent_time + expected_pto_delay,
             manager_.GetRetransmissionTime());