gfe-relnote: In QUIC, use standard deviation to calculate PTO timeout. Protected by gfe2_reloadable_flag_quic_use_standard_deviation_for_pto. PiperOrigin-RevId: 298487425 Change-Id: Ie29e4226796af161a9a5b777cd59486ea5cb5d3d
diff --git a/quic/core/congestion_control/rtt_stats.cc b/quic/core/congestion_control/rtt_stats.cc index f4cbf38..d3853d7 100644 --- a/quic/core/congestion_control/rtt_stats.cc +++ b/quic/core/congestion_control/rtt_stats.cc
@@ -27,6 +27,7 @@ smoothed_rtt_(QuicTime::Delta::Zero()), previous_srtt_(QuicTime::Delta::Zero()), mean_deviation_(QuicTime::Delta::Zero()), + calculate_standard_deviation_(false), initial_rtt_(QuicTime::Delta::FromMilliseconds(kInitialRttMs)), max_ack_delay_(QuicTime::Delta::Zero()), last_update_time_(QuicTime::Zero()), @@ -77,6 +78,9 @@ } } latest_rtt_ = rtt_sample; + if (calculate_standard_deviation_) { + standard_deviation_calculator_.OnNewRttSample(rtt_sample, smoothed_rtt_); + } // First time call. if (smoothed_rtt_.IsZero()) { smoothed_rtt_ = rtt_sample; @@ -101,4 +105,30 @@ max_ack_delay_ = QuicTime::Delta::Zero(); } +QuicTime::Delta RttStats::GetStandardOrMeanDeviation() const { + DCHECK(calculate_standard_deviation_); + if (!standard_deviation_calculator_.has_valid_standard_deviation) { + return mean_deviation_; + } + return standard_deviation_calculator_.CalculateStandardDeviation(); +} + +void RttStats::StandardDeviationCaculator::OnNewRttSample( + QuicTime::Delta rtt_sample, + QuicTime::Delta smoothed_rtt) { + double new_value = rtt_sample.ToMicroseconds(); + if (smoothed_rtt.IsZero()) { + return; + } + has_valid_standard_deviation = true; + const double delta = new_value - smoothed_rtt.ToMicroseconds(); + m2 = kOneMinusBeta * m2 + kBeta * pow(delta, 2); +} + +QuicTime::Delta +RttStats::StandardDeviationCaculator::CalculateStandardDeviation() const { + DCHECK(has_valid_standard_deviation); + return QuicTime::Delta::FromMicroseconds(sqrt(m2)); +} + } // namespace quic
diff --git a/quic/core/congestion_control/rtt_stats.h b/quic/core/congestion_control/rtt_stats.h index 1c0466a..9062c7a 100644 --- a/quic/core/congestion_control/rtt_stats.h +++ b/quic/core/congestion_control/rtt_stats.h
@@ -23,6 +23,24 @@ class QUIC_EXPORT_PRIVATE RttStats { public: + // Calculates running standard-deviation using Welford's algorithm: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance# + // Welford's_Online_algorithm. + struct QUIC_EXPORT_PRIVATE StandardDeviationCaculator { + StandardDeviationCaculator() {} + + // Called when a new RTT sample is available. + void OnNewRttSample(QuicTime::Delta rtt_sample, + QuicTime::Delta smoothed_rtt); + // Calculates the standard deviation. + QuicTime::Delta CalculateStandardDeviation() const; + + bool has_valid_standard_deviation = false; + + private: + double m2 = 0; + }; + RttStats(); RttStats(const RttStats&) = delete; RttStats& operator=(const RttStats&) = delete; @@ -73,6 +91,10 @@ QuicTime::Delta mean_deviation() const { return mean_deviation_; } + // Returns standard deviation if there is a valid one. Otherwise, returns + // mean_deviation_. + QuicTime::Delta GetStandardOrMeanDeviation() const; + QuicTime last_update_time() const { return last_update_time_; } bool ignore_max_ack_delay() const { return ignore_max_ack_delay_; } @@ -85,6 +107,10 @@ max_ack_delay_ = std::max(max_ack_delay_, initial_max_ack_delay); } + void EnableStandardDeviationCalculation() { + calculate_standard_deviation_ = true; + } + private: friend class test::RttStatsPeer; @@ -96,6 +122,10 @@ // Approximation of standard deviation, the error is roughly 1.25 times // larger than the standard deviation, for a normally distributed signal. QuicTime::Delta mean_deviation_; + // Standard deviation calculator. Only used calculate_standard_deviation_ is + // true. + StandardDeviationCaculator standard_deviation_calculator_; + bool calculate_standard_deviation_; QuicTime::Delta initial_rtt_; // The maximum ack delay observed over the connection after excluding ack // delays that were too large to be included in an RTT measurement.
diff --git a/quic/core/congestion_control/rtt_stats_test.cc b/quic/core/congestion_control/rtt_stats_test.cc index be5f4d6..1b0e545 100644 --- a/quic/core/congestion_control/rtt_stats_test.cc +++ b/quic/core/congestion_control/rtt_stats_test.cc
@@ -9,6 +9,7 @@ #include "net/third_party/quiche/src/quic/platform/api/quic_logging.h" #include "net/third_party/quiche/src/quic/platform/api/quic_mock_log.h" #include "net/third_party/quiche/src/quic/platform/api/quic_test.h" +#include "net/third_party/quiche/src/quic/test_tools/quic_test_utils.h" #include "net/third_party/quiche/src/quic/test_tools/rtt_stats_peer.h" using testing::_; @@ -216,5 +217,54 @@ EXPECT_EQ(QuicTime::Delta::Zero(), rtt_stats_.min_rtt()); } +TEST_F(RttStatsTest, StandardDeviationCaculatorTest1) { + // All samples are the same. + rtt_stats_.EnableStandardDeviationCalculation(); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10), + QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_EQ(rtt_stats_.mean_deviation(), + rtt_stats_.GetStandardOrMeanDeviation()); + + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10), + QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_EQ(QuicTime::Delta::Zero(), rtt_stats_.GetStandardOrMeanDeviation()); +} + +TEST_F(RttStatsTest, StandardDeviationCaculatorTest2) { + // Small variance. + rtt_stats_.EnableStandardDeviationCalculation(); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(9), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(11), + QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_LT(QuicTime::Delta::FromMicroseconds(500), + rtt_stats_.GetStandardOrMeanDeviation()); + EXPECT_GT(QuicTime::Delta::FromMilliseconds(1), + rtt_stats_.GetStandardOrMeanDeviation()); +} + +TEST_F(RttStatsTest, StandardDeviationCaculatorTest3) { + // Some variance. + rtt_stats_.EnableStandardDeviationCalculation(); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(50), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(50), + QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_APPROX_EQ(rtt_stats_.mean_deviation(), + rtt_stats_.GetStandardOrMeanDeviation(), 0.25f); +} + } // namespace test } // namespace quic
diff --git a/quic/core/crypto/crypto_protocol.h b/quic/core/crypto/crypto_protocol.h index eb8d75c..9532b99 100644 --- a/quic/core/crypto/crypto_protocol.h +++ b/quic/core/crypto/crypto_protocol.h
@@ -213,6 +213,8 @@ const QuicTag kPAG1 = TAG('P', 'A', 'G', '1'); // Make 1st PTO more aggressive const QuicTag kPAG2 = TAG('P', 'A', 'G', '2'); // Make first 2 PTOs more // aggressive +const QuicTag kPSDA = TAG('P', 'S', 'D', 'A'); // Use standard deviation when + // calculating PTO timeout. const QuicTag kPLE1 = TAG('P', 'L', 'E', '1'); // Arm the 1st PTO with // earliest in flight sent time // and at least 0.5*srtt from
diff --git a/quic/core/quic_sent_packet_manager.cc b/quic/core/quic_sent_packet_manager.cc index b6eae54..c8fd641 100644 --- a/quic/core/quic_sent_packet_manager.cc +++ b/quic/core/quic_sent_packet_manager.cc
@@ -106,7 +106,8 @@ num_tlp_timeout_ptos_(0), one_rtt_packet_acked_(false), one_rtt_packet_sent_(false), - first_pto_srtt_multiplier_(0) { + first_pto_srtt_multiplier_(0), + use_standard_deviation_for_pto_(false) { SetSendAlgorithm(congestion_control_type); } @@ -198,6 +199,12 @@ first_pto_srtt_multiplier_ = 1.5; } } + if (GetQuicReloadableFlag(quic_use_standard_deviation_for_pto) && + config.HasClientSentConnectionOption(kPSDA, perspective)) { + QUIC_RELOADABLE_FLAG_COUNT(quic_use_standard_deviation_for_pto); + use_standard_deviation_for_pto_ = true; + rtt_stats_.EnableStandardDeviationCalculation(); + } } // Configure congestion control. @@ -1143,10 +1150,12 @@ } return 2 * rtt_stats_.initial_rtt(); } + const QuicTime::Delta rtt_var = use_standard_deviation_for_pto_ + ? rtt_stats_.GetStandardOrMeanDeviation() + : rtt_stats_.mean_deviation(); QuicTime::Delta pto_delay = rtt_stats_.smoothed_rtt() + - std::max(pto_rttvar_multiplier_ * rtt_stats_.mean_deviation(), - kAlarmGranularity) + + std::max(pto_rttvar_multiplier_ * rtt_var, kAlarmGranularity) + (ShouldAddMaxAckDelay() ? peer_max_ack_delay_ : QuicTime::Delta::Zero()); pto_delay = pto_delay * (1 << (consecutive_pto_count_ -
diff --git a/quic/core/quic_sent_packet_manager.h b/quic/core/quic_sent_packet_manager.h index 99ab378..fa0634b 100644 --- a/quic/core/quic_sent_packet_manager.h +++ b/quic/core/quic_sent_packet_manager.h
@@ -653,6 +653,10 @@ // delay and multiplier * srtt from last in flight packet. float first_pto_srtt_multiplier_; + // If true, use standard deviation (instead of mean deviation) when + // calculating PTO timeout. + bool use_standard_deviation_for_pto_; + const bool avoid_overestimate_bandwidth_with_aggregation_ = GetQuicReloadableFlag(quic_avoid_overestimate_bandwidth_with_aggregation); };
diff --git a/quic/core/quic_sent_packet_manager_test.cc b/quic/core/quic_sent_packet_manager_test.cc index 36d476f..2a39f9b 100644 --- a/quic/core/quic_sent_packet_manager_test.cc +++ b/quic/core/quic_sent_packet_manager_test.cc
@@ -3506,6 +3506,44 @@ manager_.GetRetransmissionTime()); } +TEST_F(QuicSentPacketManagerTest, ComputingProbeTimeoutUsingStandardDeviation) { + SetQuicReloadableFlag(quic_use_standard_deviation_for_pto, true); + EnablePto(k1PTO); + // Use PTOS and PSDA. + QuicConfig config; + QuicTagVector options; + options.push_back(kPTOS); + options.push_back(kPSDA); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + EXPECT_TRUE(manager_.skip_packet_number_for_pto()); + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(10 * kDefaultTCPMSS)); + RttStats* rtt_stats = const_cast<RttStats*>(manager_.GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(50), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(50), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(75), + QuicTime::Delta::Zero(), QuicTime::Zero()); + QuicTime::Delta srtt = rtt_stats->smoothed_rtt(); + + SendDataPacket(1, ENCRYPTION_FORWARD_SECURE); + // Verify PTO is correctly set using standard deviation. + QuicTime::Delta expected_pto_delay = + srtt + 4 * rtt_stats->GetStandardOrMeanDeviation() + + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + EXPECT_EQ(clock_.Now() + expected_pto_delay, + manager_.GetRetransmissionTime()); +} + TEST_F(QuicSentPacketManagerTest, ComputingProbeTimeoutByLeftEdgeMultiplePacketNumberSpaces) { SetQuicReloadableFlag(quic_arm_pto_with_earliest_sent_time, true);