gfe-relnote: In QUIC, use standard deviation to calculate PTO timeout. Protected by gfe2_reloadable_flag_quic_use_standard_deviation_for_pto.
PiperOrigin-RevId: 298487425
Change-Id: Ie29e4226796af161a9a5b777cd59486ea5cb5d3d
diff --git a/quic/core/congestion_control/rtt_stats.cc b/quic/core/congestion_control/rtt_stats.cc
index f4cbf38..d3853d7 100644
--- a/quic/core/congestion_control/rtt_stats.cc
+++ b/quic/core/congestion_control/rtt_stats.cc
@@ -27,6 +27,7 @@
smoothed_rtt_(QuicTime::Delta::Zero()),
previous_srtt_(QuicTime::Delta::Zero()),
mean_deviation_(QuicTime::Delta::Zero()),
+ calculate_standard_deviation_(false),
initial_rtt_(QuicTime::Delta::FromMilliseconds(kInitialRttMs)),
max_ack_delay_(QuicTime::Delta::Zero()),
last_update_time_(QuicTime::Zero()),
@@ -77,6 +78,9 @@
}
}
latest_rtt_ = rtt_sample;
+ if (calculate_standard_deviation_) {
+ standard_deviation_calculator_.OnNewRttSample(rtt_sample, smoothed_rtt_);
+ }
// First time call.
if (smoothed_rtt_.IsZero()) {
smoothed_rtt_ = rtt_sample;
@@ -101,4 +105,30 @@
max_ack_delay_ = QuicTime::Delta::Zero();
}
+QuicTime::Delta RttStats::GetStandardOrMeanDeviation() const {
+  DCHECK(calculate_standard_deviation_);
+  // Until the calculator has recorded a sample against a non-zero smoothed
+  // RTT it has nothing to report, so fall back to the mean deviation.
+  return standard_deviation_calculator_.has_valid_standard_deviation
+             ? standard_deviation_calculator_.CalculateStandardDeviation()
+             : mean_deviation_;
+}
+
+// Folds |rtt_sample| into the running average of squared deviations from
+// |smoothed_rtt|. Both values are handled in microseconds.
+void RttStats::StandardDeviationCaculator::OnNewRttSample(
+    QuicTime::Delta rtt_sample,
+    QuicTime::Delta smoothed_rtt) {
+  // A zero smoothed RTT means there is nothing yet to deviate from (the very
+  // first sample), so the sample is not recorded.
+  if (smoothed_rtt.IsZero()) {
+    return;
+  }
+  has_valid_standard_deviation = true;
+  const double delta = static_cast<double>(rtt_sample.ToMicroseconds()) -
+                       static_cast<double>(smoothed_rtt.ToMicroseconds());
+  // Exponentially weighted moving average of the squared deviation: older
+  // history decays by kOneMinusBeta and the new sample is weighted by kBeta.
+  // Squaring by multiplication avoids a general-purpose pow() call.
+  m2 = kOneMinusBeta * m2 + kBeta * (delta * delta);
+}
+
+// Returns the square root of the running average of squared deviations.
+// Requires has_valid_standard_deviation (DCHECKed).
+QuicTime::Delta
+RttStats::StandardDeviationCaculator::CalculateStandardDeviation() const {
+  DCHECK(has_valid_standard_deviation);
+  // m2 is in microseconds squared, so its square root is in microseconds.
+  // The cast makes the truncation to whole microseconds explicit instead of
+  // relying on an implicit double-to-integer conversion.
+  return QuicTime::Delta::FromMicroseconds(static_cast<int64_t>(sqrt(m2)));
+}
+
} // namespace quic
diff --git a/quic/core/congestion_control/rtt_stats.h b/quic/core/congestion_control/rtt_stats.h
index 1c0466a..9062c7a 100644
--- a/quic/core/congestion_control/rtt_stats.h
+++ b/quic/core/congestion_control/rtt_stats.h
@@ -23,6 +23,24 @@
class QUIC_EXPORT_PRIVATE RttStats {
public:
+  // Tracks an exponentially weighted moving average of the squared deviation
+  // of RTT samples from the smoothed RTT, and reports its square root as the
+  // RTT standard deviation. Despite the similar goal, this is not Welford's
+  // algorithm: past history decays by kOneMinusBeta on every sample instead of
+  // all samples being weighted equally.
+  // TODO(review): the type name misspells "Calculator"; rename it together
+  // with its out-of-line definitions and tests.
+  struct QUIC_EXPORT_PRIVATE StandardDeviationCaculator {
+    StandardDeviationCaculator() = default;
+
+    // Called when a new RTT sample is available. |smoothed_rtt| is the value
+    // the sample's deviation is measured against; when it is zero the sample
+    // is ignored.
+    void OnNewRttSample(QuicTime::Delta rtt_sample,
+                        QuicTime::Delta smoothed_rtt);
+    // Returns the square root of the running average of squared deviations.
+    // Only valid once |has_valid_standard_deviation| is true.
+    QuicTime::Delta CalculateStandardDeviation() const;
+
+    // Becomes true once a sample has been recorded against a non-zero
+    // smoothed RTT.
+    bool has_valid_standard_deviation = false;
+
+   private:
+    // Running weighted average of squared deviations, in microseconds squared.
+    double m2 = 0;
+  };
+
RttStats();
RttStats(const RttStats&) = delete;
RttStats& operator=(const RttStats&) = delete;
@@ -73,6 +91,10 @@
QuicTime::Delta mean_deviation() const { return mean_deviation_; }
+  // Returns the standard deviation if there is a valid one. Otherwise, returns
+  // mean_deviation_. Must only be called after
+  // EnableStandardDeviationCalculation() (DCHECKed).
+  QuicTime::Delta GetStandardOrMeanDeviation() const;
+
QuicTime last_update_time() const { return last_update_time_; }
bool ignore_max_ack_delay() const { return ignore_max_ack_delay_; }
@@ -85,6 +107,10 @@
max_ack_delay_ = std::max(max_ack_delay_, initial_max_ack_delay);
}
+  // Starts feeding RTT samples to the standard deviation calculator. Only
+  // samples observed after this call are recorded; there is no way to turn
+  // the calculation back off.
+  void EnableStandardDeviationCalculation() {
+    calculate_standard_deviation_ = true;
+  }
+
private:
friend class test::RttStatsPeer;
@@ -96,6 +122,10 @@
// Approximation of standard deviation, the error is roughly 1.25 times
// larger than the standard deviation, for a normally distributed signal.
QuicTime::Delta mean_deviation_;
+  // Standard deviation calculator. Only fed RTT samples when
+  // calculate_standard_deviation_ is true.
+  StandardDeviationCaculator standard_deviation_calculator_;
+  // Set by EnableStandardDeviationCalculation().
+  bool calculate_standard_deviation_;
QuicTime::Delta initial_rtt_;
// The maximum ack delay observed over the connection after excluding ack
// delays that were too large to be included in an RTT measurement.
diff --git a/quic/core/congestion_control/rtt_stats_test.cc b/quic/core/congestion_control/rtt_stats_test.cc
index be5f4d6..1b0e545 100644
--- a/quic/core/congestion_control/rtt_stats_test.cc
+++ b/quic/core/congestion_control/rtt_stats_test.cc
@@ -9,6 +9,7 @@
#include "net/third_party/quiche/src/quic/platform/api/quic_logging.h"
#include "net/third_party/quiche/src/quic/platform/api/quic_mock_log.h"
#include "net/third_party/quiche/src/quic/platform/api/quic_test.h"
+#include "net/third_party/quiche/src/quic/test_tools/quic_test_utils.h"
#include "net/third_party/quiche/src/quic/test_tools/rtt_stats_peer.h"
using testing::_;
@@ -216,5 +217,54 @@
EXPECT_EQ(QuicTime::Delta::Zero(), rtt_stats_.min_rtt());
}
+TEST_F(RttStatsTest, StandardDeviationCaculatorTest1) {
+  // Identical samples. After the first one there is no smoothed RTT to
+  // deviate from, so the mean deviation is reported; after further identical
+  // samples the standard deviation collapses to zero.
+  rtt_stats_.EnableStandardDeviationCalculation();
+  const QuicTime::Delta sample = QuicTime::Delta::FromMilliseconds(10);
+  rtt_stats_.UpdateRtt(sample, QuicTime::Delta::Zero(), QuicTime::Zero());
+  EXPECT_EQ(rtt_stats_.mean_deviation(),
+            rtt_stats_.GetStandardOrMeanDeviation());
+
+  for (int i = 0; i < 2; ++i) {
+    rtt_stats_.UpdateRtt(sample, QuicTime::Delta::Zero(), QuicTime::Zero());
+  }
+  EXPECT_EQ(QuicTime::Delta::Zero(), rtt_stats_.GetStandardOrMeanDeviation());
+}
+
+TEST_F(RttStatsTest, StandardDeviationCaculatorTest2) {
+  // Small variance: every sample is within 1ms of 10ms.
+  rtt_stats_.EnableStandardDeviationCalculation();
+  for (int rtt_ms : {10, 10, 10, 9, 11}) {
+    rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(rtt_ms),
+                         QuicTime::Delta::Zero(), QuicTime::Zero());
+  }
+  const QuicTime::Delta deviation = rtt_stats_.GetStandardOrMeanDeviation();
+  EXPECT_LT(QuicTime::Delta::FromMicroseconds(500), deviation);
+  EXPECT_GT(QuicTime::Delta::FromMilliseconds(1), deviation);
+}
+
+TEST_F(RttStatsTest, StandardDeviationCaculatorTest3) {
+  // Larger variance: samples of 50ms and 100ms. The standard deviation should
+  // be approximately the mean deviation (EXPECT_APPROX_EQ margin 0.25f).
+  rtt_stats_.EnableStandardDeviationCalculation();
+  for (int rtt_ms : {50, 100, 100, 50}) {
+    rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(rtt_ms),
+                         QuicTime::Delta::Zero(), QuicTime::Zero());
+  }
+  EXPECT_APPROX_EQ(rtt_stats_.mean_deviation(),
+                   rtt_stats_.GetStandardOrMeanDeviation(), 0.25f);
+}
+
} // namespace test
} // namespace quic
diff --git a/quic/core/crypto/crypto_protocol.h b/quic/core/crypto/crypto_protocol.h
index eb8d75c..9532b99 100644
--- a/quic/core/crypto/crypto_protocol.h
+++ b/quic/core/crypto/crypto_protocol.h
@@ -213,6 +213,8 @@
const QuicTag kPAG1 = TAG('P', 'A', 'G', '1'); // Make 1st PTO more aggressive
const QuicTag kPAG2 = TAG('P', 'A', 'G', '2'); // Make first 2 PTOs more
// aggressive
+const QuicTag kPSDA = TAG('P', 'S', 'D', 'A');  // Use RTT standard deviation
+                                                // instead of mean deviation
+                                                // when computing PTO delay.
const QuicTag kPLE1 = TAG('P', 'L', 'E', '1'); // Arm the 1st PTO with
// earliest in flight sent time
// and at least 0.5*srtt from
diff --git a/quic/core/quic_sent_packet_manager.cc b/quic/core/quic_sent_packet_manager.cc
index b6eae54..c8fd641 100644
--- a/quic/core/quic_sent_packet_manager.cc
+++ b/quic/core/quic_sent_packet_manager.cc
@@ -106,7 +106,8 @@
num_tlp_timeout_ptos_(0),
one_rtt_packet_acked_(false),
one_rtt_packet_sent_(false),
- first_pto_srtt_multiplier_(0) {
+ first_pto_srtt_multiplier_(0),
+ use_standard_deviation_for_pto_(false) {
SetSendAlgorithm(congestion_control_type);
}
@@ -198,6 +199,12 @@
first_pto_srtt_multiplier_ = 1.5;
}
}
+ if (GetQuicReloadableFlag(quic_use_standard_deviation_for_pto) &&
+ config.HasClientSentConnectionOption(kPSDA, perspective)) {
+ QUIC_RELOADABLE_FLAG_COUNT(quic_use_standard_deviation_for_pto);
+ use_standard_deviation_for_pto_ = true;
+ rtt_stats_.EnableStandardDeviationCalculation();
+ }
}
// Configure congestion control.
@@ -1143,10 +1150,12 @@
}
return 2 * rtt_stats_.initial_rtt();
}
+ const QuicTime::Delta rtt_var = use_standard_deviation_for_pto_
+ ? rtt_stats_.GetStandardOrMeanDeviation()
+ : rtt_stats_.mean_deviation();
QuicTime::Delta pto_delay =
rtt_stats_.smoothed_rtt() +
- std::max(pto_rttvar_multiplier_ * rtt_stats_.mean_deviation(),
- kAlarmGranularity) +
+ std::max(pto_rttvar_multiplier_ * rtt_var, kAlarmGranularity) +
(ShouldAddMaxAckDelay() ? peer_max_ack_delay_ : QuicTime::Delta::Zero());
pto_delay =
pto_delay * (1 << (consecutive_pto_count_ -
diff --git a/quic/core/quic_sent_packet_manager.h b/quic/core/quic_sent_packet_manager.h
index 99ab378..fa0634b 100644
--- a/quic/core/quic_sent_packet_manager.h
+++ b/quic/core/quic_sent_packet_manager.h
@@ -653,6 +653,10 @@
// delay and multiplier * srtt from last in flight packet.
float first_pto_srtt_multiplier_;
+ // If true, use standard deviation (instead of mean deviation) when
+ // calculating PTO timeout.
+ bool use_standard_deviation_for_pto_;
+
const bool avoid_overestimate_bandwidth_with_aggregation_ =
GetQuicReloadableFlag(quic_avoid_overestimate_bandwidth_with_aggregation);
};
diff --git a/quic/core/quic_sent_packet_manager_test.cc b/quic/core/quic_sent_packet_manager_test.cc
index 36d476f..2a39f9b 100644
--- a/quic/core/quic_sent_packet_manager_test.cc
+++ b/quic/core/quic_sent_packet_manager_test.cc
@@ -3506,6 +3506,44 @@
manager_.GetRetransmissionTime());
}
+// Verifies that when the flag is on and the client sent PSDA, the PTO delay
+// uses RttStats::GetStandardOrMeanDeviation() instead of mean_deviation().
+TEST_F(QuicSentPacketManagerTest, ComputingProbeTimeoutUsingStandardDeviation) {
+ SetQuicReloadableFlag(quic_use_standard_deviation_for_pto, true);
+ EnablePto(k1PTO);
+ // Use PTOS (checked below via skip_packet_number_for_pto()) and PSDA
+ // (standard deviation for PTO).
+ QuicConfig config;
+ QuicTagVector options;
+ options.push_back(kPTOS);
+ options.push_back(kPSDA);
+ QuicConfigPeer::SetReceivedConnectionOptions(&config, options);
+ // These expectations must be set before SetFromConfig() triggers the calls.
+ EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _));
+ EXPECT_CALL(*network_change_visitor_, OnCongestionChange());
+ manager_.SetFromConfig(config);
+ EXPECT_TRUE(manager_.skip_packet_number_for_pto());
+ EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true));
+ EXPECT_CALL(*send_algorithm_, PacingRate(_))
+ .WillRepeatedly(Return(QuicBandwidth::Zero()));
+ EXPECT_CALL(*send_algorithm_, GetCongestionWindow())
+ .WillRepeatedly(Return(10 * kDefaultTCPMSS));
+ // GetRttStats() returns a const pointer; cast it so the test can inject RTT
+ // samples directly. SetFromConfig() above enabled the standard deviation
+ // calculation, so these samples reach the calculator.
+ RttStats* rtt_stats = const_cast<RttStats*>(manager_.GetRttStats());
+ rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(100),
+ QuicTime::Delta::Zero(), QuicTime::Zero());
+ rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(50),
+ QuicTime::Delta::Zero(), QuicTime::Zero());
+ rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(50),
+ QuicTime::Delta::Zero(), QuicTime::Zero());
+ rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(75),
+ QuicTime::Delta::Zero(), QuicTime::Zero());
+ QuicTime::Delta srtt = rtt_stats->smoothed_rtt();
+
+ SendDataPacket(1, ENCRYPTION_FORWARD_SECURE);
+ // Verify PTO is correctly set using standard deviation.
+ // NOTE(review): 4 is assumed to equal pto_rttvar_multiplier_ for this config,
+ // the peer max ack delay is assumed to be kDefaultDelayedAckTimeMs, and the
+ // deviation term is assumed to exceed kAlarmGranularity — confirm.
+ QuicTime::Delta expected_pto_delay =
+ srtt + 4 * rtt_stats->GetStandardOrMeanDeviation() +
+ QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs);
+ EXPECT_EQ(clock_.Now() + expected_pto_delay,
+ manager_.GetRetransmissionTime());
+}
+
TEST_F(QuicSentPacketManagerTest,
ComputingProbeTimeoutByLeftEdgeMultiplePacketNumberSpaces) {
SetQuicReloadableFlag(quic_arm_pto_with_earliest_sent_time, true);