gfe-relnote: In QUIC, use the RTT standard deviation instead of the mean deviation to calculate the PTO timeout when the client sends the PSDA connection option. Protected by gfe2_reloadable_flag_quic_use_standard_deviation_for_pto.

PiperOrigin-RevId: 298487425
Change-Id: Ie29e4226796af161a9a5b777cd59486ea5cb5d3d
diff --git a/quic/core/congestion_control/rtt_stats.cc b/quic/core/congestion_control/rtt_stats.cc
index f4cbf38..d3853d7 100644
--- a/quic/core/congestion_control/rtt_stats.cc
+++ b/quic/core/congestion_control/rtt_stats.cc
@@ -27,6 +27,7 @@
       smoothed_rtt_(QuicTime::Delta::Zero()),
       previous_srtt_(QuicTime::Delta::Zero()),
       mean_deviation_(QuicTime::Delta::Zero()),
+      calculate_standard_deviation_(false),
       initial_rtt_(QuicTime::Delta::FromMilliseconds(kInitialRttMs)),
       max_ack_delay_(QuicTime::Delta::Zero()),
       last_update_time_(QuicTime::Zero()),
@@ -77,6 +78,9 @@
     }
   }
   latest_rtt_ = rtt_sample;
+  if (calculate_standard_deviation_) {
+    standard_deviation_calculator_.OnNewRttSample(rtt_sample, smoothed_rtt_);
+  }
   // First time call.
   if (smoothed_rtt_.IsZero()) {
     smoothed_rtt_ = rtt_sample;
@@ -101,4 +105,30 @@
   max_ack_delay_ = QuicTime::Delta::Zero();
 }
 
+// Standard deviation once one is available, otherwise mean_deviation_.
+QuicTime::Delta RttStats::GetStandardOrMeanDeviation() const {
+  DCHECK(calculate_standard_deviation_);
+  return standard_deviation_calculator_.has_valid_standard_deviation
+             ? standard_deviation_calculator_.CalculateStandardDeviation()
+             : mean_deviation_;
+}
+
+// Folds |rtt_sample| into an exponentially weighted variance around
+// |smoothed_rtt|; ignored while |smoothed_rtt| is still zero.
+void RttStats::StandardDeviationCaculator::OnNewRttSample(
+    QuicTime::Delta rtt_sample, QuicTime::Delta smoothed_rtt) {
+  if (smoothed_rtt.IsZero()) {
+    return;
+  }
+  has_valid_standard_deviation = true;
+  const double delta = (rtt_sample - smoothed_rtt).ToMicroseconds();
+  m2 = kOneMinusBeta * m2 + kBeta * (delta * delta);
+}
+
+QuicTime::Delta
+RttStats::StandardDeviationCaculator::CalculateStandardDeviation() const {
+  DCHECK(has_valid_standard_deviation);  // m2 is a variance in us^2.
+  return QuicTime::Delta::FromMicroseconds(static_cast<int64_t>(sqrt(m2)));
+}
+
 }  // namespace quic
diff --git a/quic/core/congestion_control/rtt_stats.h b/quic/core/congestion_control/rtt_stats.h
index 1c0466a..9062c7a 100644
--- a/quic/core/congestion_control/rtt_stats.h
+++ b/quic/core/congestion_control/rtt_stats.h
@@ -23,6 +23,24 @@
 
 class QUIC_EXPORT_PRIVATE RttStats {
  public:
+  // Tracks a running RTT standard deviation. The variance is an exponentially
+  // weighted moving average of squared (rtt_sample - smoothed_rtt) deviations
+  // with gain kBeta (not Welford's unweighted online algorithm).
+  struct QUIC_EXPORT_PRIVATE StandardDeviationCaculator {
+    StandardDeviationCaculator() = default;
+    // Folds a new RTT sample into the variance; no-op if smoothed_rtt is zero.
+    void OnNewRttSample(QuicTime::Delta rtt_sample,
+                        QuicTime::Delta smoothed_rtt);
+    // Returns sqrt of the variance. Requires has_valid_standard_deviation.
+    QuicTime::Delta CalculateStandardDeviation() const;
+    // True once a sample arrived with a non-zero smoothed_rtt.
+    bool has_valid_standard_deviation = false;
+
+   private:
+    // Smoothed variance, in microseconds squared.
+    double m2 = 0;
+  };
+
   RttStats();
   RttStats(const RttStats&) = delete;
   RttStats& operator=(const RttStats&) = delete;
@@ -73,6 +91,10 @@
 
   QuicTime::Delta mean_deviation() const { return mean_deviation_; }
 
+  // Returns the standard deviation if a valid one exists, otherwise
+  // mean_deviation_. Requires EnableStandardDeviationCalculation().
+  QuicTime::Delta GetStandardOrMeanDeviation() const;
+
   QuicTime last_update_time() const { return last_update_time_; }
 
   bool ignore_max_ack_delay() const { return ignore_max_ack_delay_; }
@@ -85,6 +107,10 @@
     max_ack_delay_ = std::max(max_ack_delay_, initial_max_ack_delay);
   }
 
+  // Makes each new RTT sample also feed the standard deviation calculator.
+  void EnableStandardDeviationCalculation() {
+    calculate_standard_deviation_ = true;
+  }
  private:
   friend class test::RttStatsPeer;
 
@@ -96,6 +122,10 @@
   // Approximation of standard deviation, the error is roughly 1.25 times
   // larger than the standard deviation, for a normally distributed signal.
   QuicTime::Delta mean_deviation_;
+  // Only receives RTT samples while calculate_standard_deviation_ is true.
+  StandardDeviationCaculator standard_deviation_calculator_;
+  // Set by EnableStandardDeviationCalculation(); false on construction.
+  bool calculate_standard_deviation_;
   QuicTime::Delta initial_rtt_;
   // The maximum ack delay observed over the connection after excluding ack
   // delays that were too large to be included in an RTT measurement.
diff --git a/quic/core/congestion_control/rtt_stats_test.cc b/quic/core/congestion_control/rtt_stats_test.cc
index be5f4d6..1b0e545 100644
--- a/quic/core/congestion_control/rtt_stats_test.cc
+++ b/quic/core/congestion_control/rtt_stats_test.cc
@@ -9,6 +9,7 @@
 #include "net/third_party/quiche/src/quic/platform/api/quic_logging.h"
 #include "net/third_party/quiche/src/quic/platform/api/quic_mock_log.h"
 #include "net/third_party/quiche/src/quic/platform/api/quic_test.h"
+#include "net/third_party/quiche/src/quic/test_tools/quic_test_utils.h"
 #include "net/third_party/quiche/src/quic/test_tools/rtt_stats_peer.h"
 
 using testing::_;
@@ -216,5 +217,54 @@
   EXPECT_EQ(QuicTime::Delta::Zero(), rtt_stats_.min_rtt());
 }
 
+TEST_F(RttStatsTest, StandardDeviationCaculatorTest1) {
+  // All samples are identical; with one sample, mean_deviation is returned.
+  rtt_stats_.EnableStandardDeviationCalculation();
+  rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10),
+                       QuicTime::Delta::Zero(), QuicTime::Zero());
+  EXPECT_EQ(rtt_stats_.mean_deviation(),
+            rtt_stats_.GetStandardOrMeanDeviation());
+  // Further identical samples give a valid standard deviation of zero.
+  rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10),
+                       QuicTime::Delta::Zero(), QuicTime::Zero());
+  rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10),
+                       QuicTime::Delta::Zero(), QuicTime::Zero());
+  EXPECT_EQ(QuicTime::Delta::Zero(), rtt_stats_.GetStandardOrMeanDeviation());
+}
+
+TEST_F(RttStatsTest, StandardDeviationCaculatorTest2) {
+  // Small variance (10, 10, 10, 9, 11 ms): expect a stddev in (0.5ms, 1ms).
+  rtt_stats_.EnableStandardDeviationCalculation();
+  rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10),
+                       QuicTime::Delta::Zero(), QuicTime::Zero());
+  rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10),
+                       QuicTime::Delta::Zero(), QuicTime::Zero());
+  rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10),
+                       QuicTime::Delta::Zero(), QuicTime::Zero());
+  rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(9),
+                       QuicTime::Delta::Zero(), QuicTime::Zero());
+  rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(11),
+                       QuicTime::Delta::Zero(), QuicTime::Zero());
+  EXPECT_LT(QuicTime::Delta::FromMicroseconds(500),
+            rtt_stats_.GetStandardOrMeanDeviation());
+  EXPECT_GT(QuicTime::Delta::FromMilliseconds(1),
+            rtt_stats_.GetStandardOrMeanDeviation());
+}
+
+TEST_F(RttStatsTest, StandardDeviationCaculatorTest3) {
+  // Some variance (50, 100, 100, 50 ms): stddev close to mean_deviation.
+  rtt_stats_.EnableStandardDeviationCalculation();
+  rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(50),
+                       QuicTime::Delta::Zero(), QuicTime::Zero());
+  rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(100),
+                       QuicTime::Delta::Zero(), QuicTime::Zero());
+  rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(100),
+                       QuicTime::Delta::Zero(), QuicTime::Zero());
+  rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(50),
+                       QuicTime::Delta::Zero(), QuicTime::Zero());
+  EXPECT_APPROX_EQ(rtt_stats_.mean_deviation(),
+                   rtt_stats_.GetStandardOrMeanDeviation(), 0.25f);
+}
+
 }  // namespace test
 }  // namespace quic
diff --git a/quic/core/crypto/crypto_protocol.h b/quic/core/crypto/crypto_protocol.h
index eb8d75c..9532b99 100644
--- a/quic/core/crypto/crypto_protocol.h
+++ b/quic/core/crypto/crypto_protocol.h
@@ -213,6 +213,8 @@
 const QuicTag kPAG1 = TAG('P', 'A', 'G', '1');   // Make 1st PTO more aggressive
 const QuicTag kPAG2 = TAG('P', 'A', 'G', '2');   // Make first 2 PTOs more
                                                  // aggressive
+const QuicTag kPSDA = TAG('P', 'S', 'D', 'A');   // Use RTT standard deviation
+                                                 // for PTO, not mean deviation.
 const QuicTag kPLE1 = TAG('P', 'L', 'E', '1');   // Arm the 1st PTO with
                                                  // earliest in flight sent time
                                                  // and at least 0.5*srtt from
diff --git a/quic/core/quic_sent_packet_manager.cc b/quic/core/quic_sent_packet_manager.cc
index b6eae54..c8fd641 100644
--- a/quic/core/quic_sent_packet_manager.cc
+++ b/quic/core/quic_sent_packet_manager.cc
@@ -106,7 +106,8 @@
       num_tlp_timeout_ptos_(0),
       one_rtt_packet_acked_(false),
       one_rtt_packet_sent_(false),
-      first_pto_srtt_multiplier_(0) {
+      first_pto_srtt_multiplier_(0),
+      use_standard_deviation_for_pto_(false) {
   SetSendAlgorithm(congestion_control_type);
 }
 
@@ -198,6 +199,12 @@
         first_pto_srtt_multiplier_ = 1.5;
       }
     }
+    if (GetQuicReloadableFlag(quic_use_standard_deviation_for_pto) &&
+        config.HasClientSentConnectionOption(kPSDA, perspective)) {
+      QUIC_RELOADABLE_FLAG_COUNT(quic_use_standard_deviation_for_pto);
+      use_standard_deviation_for_pto_ = true;
+      rtt_stats_.EnableStandardDeviationCalculation();
+    }
   }
 
   // Configure congestion control.
@@ -1143,10 +1150,12 @@
     }
     return 2 * rtt_stats_.initial_rtt();
   }
+  const QuicTime::Delta rtt_var = use_standard_deviation_for_pto_
+                                      ? rtt_stats_.GetStandardOrMeanDeviation()
+                                      : rtt_stats_.mean_deviation();
   QuicTime::Delta pto_delay =
       rtt_stats_.smoothed_rtt() +
-      std::max(pto_rttvar_multiplier_ * rtt_stats_.mean_deviation(),
-               kAlarmGranularity) +
+      std::max(pto_rttvar_multiplier_ * rtt_var, kAlarmGranularity) +
       (ShouldAddMaxAckDelay() ? peer_max_ack_delay_ : QuicTime::Delta::Zero());
   pto_delay =
       pto_delay * (1 << (consecutive_pto_count_ -
diff --git a/quic/core/quic_sent_packet_manager.h b/quic/core/quic_sent_packet_manager.h
index 99ab378..fa0634b 100644
--- a/quic/core/quic_sent_packet_manager.h
+++ b/quic/core/quic_sent_packet_manager.h
@@ -653,6 +653,10 @@
   // delay and multiplier * srtt from last in flight packet.
   float first_pto_srtt_multiplier_;
 
+  // If true, use standard deviation (instead of mean deviation) when
+  // calculating PTO timeout. Enabled by the kPSDA connection option.
+  bool use_standard_deviation_for_pto_;
+
   const bool avoid_overestimate_bandwidth_with_aggregation_ =
       GetQuicReloadableFlag(quic_avoid_overestimate_bandwidth_with_aggregation);
 };
diff --git a/quic/core/quic_sent_packet_manager_test.cc b/quic/core/quic_sent_packet_manager_test.cc
index 36d476f..2a39f9b 100644
--- a/quic/core/quic_sent_packet_manager_test.cc
+++ b/quic/core/quic_sent_packet_manager_test.cc
@@ -3506,6 +3506,44 @@
             manager_.GetRetransmissionTime());
 }
 
+TEST_F(QuicSentPacketManagerTest, ComputingProbeTimeoutUsingStandardDeviation) {
+  SetQuicReloadableFlag(quic_use_standard_deviation_for_pto, true);
+  EnablePto(k1PTO);
+  // Use PTOS (skip packet number for PTO) and PSDA (standard deviation).
+  QuicConfig config;
+  QuicTagVector options;
+  options.push_back(kPTOS);
+  options.push_back(kPSDA);
+  QuicConfigPeer::SetReceivedConnectionOptions(&config, options);
+  EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _));
+  EXPECT_CALL(*network_change_visitor_, OnCongestionChange());
+  manager_.SetFromConfig(config);
+  EXPECT_TRUE(manager_.skip_packet_number_for_pto());
+  EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true));
+  EXPECT_CALL(*send_algorithm_, PacingRate(_))
+      .WillRepeatedly(Return(QuicBandwidth::Zero()));
+  EXPECT_CALL(*send_algorithm_, GetCongestionWindow())
+      .WillRepeatedly(Return(10 * kDefaultTCPMSS));
+  RttStats* rtt_stats = const_cast<RttStats*>(manager_.GetRttStats());
+  rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(100),
+                       QuicTime::Delta::Zero(), QuicTime::Zero());
+  rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(50),
+                       QuicTime::Delta::Zero(), QuicTime::Zero());
+  rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(50),
+                       QuicTime::Delta::Zero(), QuicTime::Zero());
+  rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(75),
+                       QuicTime::Delta::Zero(), QuicTime::Zero());
+  QuicTime::Delta srtt = rtt_stats->smoothed_rtt();
+
+  SendDataPacket(1, ENCRYPTION_FORWARD_SECURE);
+  // Expected PTO: srtt + 4 * standard deviation + default delayed-ack time.
+  QuicTime::Delta expected_pto_delay =
+      srtt + 4 * rtt_stats->GetStandardOrMeanDeviation() +
+      QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs);
+  EXPECT_EQ(clock_.Now() + expected_pto_delay,
+            manager_.GetRetransmissionTime());
+}
+
 TEST_F(QuicSentPacketManagerTest,
        ComputingProbeTimeoutByLeftEdgeMultiplePacketNumberSpaces) {
   SetQuicReloadableFlag(quic_arm_pto_with_earliest_sent_time, true);