Shorten PTO when a HANDSHAKE or 1-RTT packet is received prior to the INITIAL. Client only.

PiperOrigin-RevId: 338475279
Change-Id: Idde6ef503847b1c9e0a5fb4c3e1e336102d06fdb
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc
index cdb93a3..7f7089b 100644
--- a/quic/core/quic_connection.cc
+++ b/quic/core/quic_connection.cc
@@ -3744,6 +3744,9 @@
   }
   QUIC_DVLOG(1) << ENDPOINT << "Queueing undecryptable packet.";
   undecryptable_packets_.emplace_back(packet, decryption_level);
+  if (perspective_ == Perspective::IS_CLIENT) {
+    SetRetransmissionAlarm();
+  }
 }
 
 void QuicConnection::MaybeProcessUndecryptablePackets() {
@@ -3831,6 +3834,9 @@
     }
     undecryptable_packets_.clear();
   }
+  if (perspective_ == Perspective::IS_CLIENT) {
+    SetRetransmissionAlarm();
+  }
 }
 
 void QuicConnection::QueueCoalescedPacket(const QuicEncryptedPacket& packet) {
@@ -4123,8 +4129,7 @@
     return;
   }
 
-  retransmission_alarm_->Update(sent_packet_manager_.GetRetransmissionTime(),
-                                kAlarmGranularity);
+  retransmission_alarm_->Update(GetRetransmissionDeadline(), kAlarmGranularity);
 }
 
 void QuicConnection::MaybeSetMtuAlarm(QuicPacketNumber sent_packet_number) {
@@ -5385,6 +5390,20 @@
   return num_rtos_for_blackhole_detection_ > 0;
 }
 
+QuicTime QuicConnection::GetRetransmissionDeadline() const {
+  if (perspective_ == Perspective::IS_CLIENT &&
+      SupportsMultiplePacketNumberSpaces() && !IsHandshakeConfirmed() &&
+      stats_.pto_count == 0 &&
+      !framer_.HasDecrypterOfEncryptionLevel(ENCRYPTION_HANDSHAKE) &&
+      !undecryptable_packets_.empty()) {
+    // Retransmits ClientHello quickly when a Handshake or 1-RTT packet is
+    // received prior to having Handshake keys. Adding kAlarmGranulary will
+    // avoid spurious retransmissions in the case of small-scale reordering.
+    return clock_->ApproximateNow() + kAlarmGranularity;
+  }
+  return sent_packet_manager_.GetRetransmissionTime();
+}
+
 void QuicConnection::SendPathChallenge(QuicPathFrameBuffer* data_buffer,
                                        const QuicSocketAddress& self_address,
                                        const QuicSocketAddress& peer_address,
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h
index 59fa80d..376c623 100644
--- a/quic/core/quic_connection.h
+++ b/quic/core/quic_connection.h
@@ -1450,6 +1450,9 @@
   // Returns true if network blackhole should be detected.
   bool ShouldDetectBlackhole() const;
 
+  // Returns retransmission deadline.
+  QuicTime GetRetransmissionDeadline() const;
+
   // Validate connection IDs used during the handshake. Closes the connection
   // on validation failure.
   bool ValidateConfigConnectionIds(const QuicConfig& config);
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc
index 7d736b3..0b94236 100644
--- a/quic/core/quic_connection_test.cc
+++ b/quic/core/quic_connection_test.cc
@@ -12861,6 +12861,67 @@
             QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs));
 }
 
+TEST_P(QuicConnectionTest, FastRecoveryOfLostServerHello) {
+  if (!connection_.SupportsMultiplePacketNumberSpaces()) {
+    return;
+  }
+  EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _));
+  QuicConfig config;
+  connection_.SetFromConfig(config);
+
+  use_tagging_decrypter();
+  connection_.SetEncrypter(ENCRYPTION_INITIAL,
+                           std::make_unique<TaggingEncrypter>(0x01));
+  connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL);
+  connection_.SendCryptoStreamData();
+  clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(20));
+
+  // Assume ServerHello gets lost.
+  peer_framer_.SetEncrypter(ENCRYPTION_HANDSHAKE,
+                            std::make_unique<TaggingEncrypter>(0x02));
+  ProcessCryptoPacketAtLevel(2, ENCRYPTION_HANDSHAKE);
+  ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet());
+  // Shorten PTO for fast recovery from lost ServerHello.
+  EXPECT_EQ(clock_.ApproximateNow() + kAlarmGranularity,
+            connection_.GetRetransmissionAlarm()->deadline());
+}
+
+TEST_P(QuicConnectionTest, ServerHelloGetsReordered) {
+  if (!connection_.SupportsMultiplePacketNumberSpaces()) {
+    return;
+  }
+  EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _));
+  QuicConfig config;
+  connection_.SetFromConfig(config);
+  EXPECT_CALL(visitor_, OnCryptoFrame(_))
+      .WillRepeatedly(Invoke([=](const QuicCryptoFrame& frame) {
+        if (frame.level == ENCRYPTION_INITIAL) {
+          // Install handshake read keys.
+          SetDecrypter(ENCRYPTION_HANDSHAKE,
+                       std::make_unique<StrictTaggingDecrypter>(0x02));
+          connection_.SetEncrypter(ENCRYPTION_HANDSHAKE,
+                                   std::make_unique<TaggingEncrypter>(0x02));
+          connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE);
+        }
+      }));
+
+  use_tagging_decrypter();
+  connection_.SetEncrypter(ENCRYPTION_INITIAL,
+                           std::make_unique<TaggingEncrypter>(0x01));
+  connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL);
+  connection_.SendCryptoStreamData();
+  clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(20));
+
+  // Assume ServerHello gets reordered.
+  peer_framer_.SetEncrypter(ENCRYPTION_HANDSHAKE,
+                            std::make_unique<TaggingEncrypter>(0x02));
+  ProcessCryptoPacketAtLevel(2, ENCRYPTION_HANDSHAKE);
+  ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL);
+  // Verify fast recovery is not enabled.
+  EXPECT_EQ(connection_.sent_packet_manager().GetRetransmissionTime(),
+            connection_.GetRetransmissionAlarm()->deadline());
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic