Allow QUIC Key Update on first received 1-RTT packet

As per the specification, the first key update is allowed even if we haven't decrypted any packets: <<An endpoint MUST NOT initiate a key update prior to having confirmed the handshake (Section 4.1.2). An endpoint MUST NOT initiate a subsequent key update unless it has received an acknowledgment for a packet that was sent protected with keys from the current key phase.>> This issue was found during the IETF 110 hackathon, as another implementation was performing a key update immediately post handshake.

Protected by FLAGS_quic_reloadable_flag_quic_fix_key_update_on_first_packet.

PiperOrigin-RevId: 360781943
Change-Id: I678cd85ffdd4d014184fa6db86e8c42c64c421cd
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h
index 0e9fa10..83b2a65 100644
--- a/quic/core/quic_flags_list.h
+++ b/quic/core/quic_flags_list.h
@@ -43,6 +43,7 @@
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_enable_version_rfcv1, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_encrypted_control_frames, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_encrypted_goaway, true)
+QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_fix_key_update_on_first_packet, true)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_fix_on_stream_reset, true)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_fix_willing_and_able_to_write2, true)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_h3_datagram, false)
diff --git a/quic/core/quic_framer.cc b/quic/core/quic_framer.cc
index 24034fd..51fa3ea 100644
--- a/quic/core/quic_framer.cc
+++ b/quic/core/quic_framer.cc
@@ -4335,6 +4335,7 @@
     QUIC_BUG << "Failed to create next crypters";
     return false;
   }
+  key_update_performed_ = true;
   current_key_phase_bit_ = !current_key_phase_bit_;
   QUIC_DLOG(INFO) << ENDPOINT << "DoKeyUpdate: new current_key_phase_bit_="
                   << current_key_phase_bit_;
@@ -4772,9 +4773,13 @@
                     << " received key_phase=" << key_phase
                     << " current_key_phase_bit_=" << current_key_phase_bit_;
       if (key_phase != current_key_phase_bit_) {
-        if (current_key_phase_first_received_packet_number_.IsInitialized() &&
-            header.packet_number >
-                current_key_phase_first_received_packet_number_) {
+        if ((current_key_phase_first_received_packet_number_.IsInitialized() &&
+             header.packet_number >
+                 current_key_phase_first_received_packet_number_) ||
+            (GetQuicReloadableFlag(quic_fix_key_update_on_first_packet) &&
+             !current_key_phase_first_received_packet_number_.IsInitialized() &&
+             !key_update_performed_)) {
+          QUIC_RELOADABLE_FLAG_COUNT(quic_fix_key_update_on_first_packet);
           if (!next_decrypter_) {
             next_decrypter_ =
                 visitor_->AdvanceKeysAndCreateCurrentOneRttDecrypter();
diff --git a/quic/core/quic_framer.h b/quic/core/quic_framer.h
index 0307594..fe395ce 100644
--- a/quic/core/quic_framer.h
+++ b/quic/core/quic_framer.h
@@ -1117,6 +1117,8 @@
   // The value of the current key phase bit, which is toggled when the keys are
   // changed.
   bool current_key_phase_bit_;
+  // Whether we have performed a key update at least once.
+  bool key_update_performed_ = false;
   // Tracks the first packet received in the current key phase. Will be
   // uninitialized before the first one-RTT packet has been received or after a
   // locally initiated key update but before the first packet from the peer in
diff --git a/quic/core/quic_framer_test.cc b/quic/core/quic_framer_test.cc
index dafe1f7..9f829e5 100644
--- a/quic/core/quic_framer_test.cc
+++ b/quic/core/quic_framer_test.cc
@@ -15164,6 +15164,42 @@
   EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_);
 }
 
+TEST_P(QuicFramerTest, KeyUpdateOnFirstReceivedPacket) {
+  if (!framer_.version().UsesTls()) {
+    // Key update is only used in QUIC+TLS.
+    return;
+  }
+  SetQuicReloadableFlag(quic_fix_key_update_on_first_packet, true);
+  ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse());
+  // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter
+  // instead of TestDecrypter.
+  framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE,
+                           std::make_unique<StrictTaggingDecrypter>(/*key=*/0));
+  framer_.SetKeyUpdateSupportForConnection(true);
+
+  QuicPacketHeader header;
+  header.destination_connection_id = FramerTestConnectionId();
+  header.reset_flag = false;
+  header.version_flag = false;
+  header.packet_number = QuicPacketNumber(123);
+
+  QuicFrames frames = {QuicFrame(QuicPaddingFrame())};
+
+  QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT);
+  std::unique_ptr<QuicPacket> data(BuildDataPacket(header, frames));
+  ASSERT_TRUE(data != nullptr);
+  std::unique_ptr<QuicEncryptedPacket> encrypted(
+      EncryptPacketWithTagAndPhase(*data, /*tag=*/1, /*phase=*/true));
+  ASSERT_TRUE(encrypted);
+
+  QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER);
+  EXPECT_TRUE(framer_.ProcessPacket(*encrypted));
+  // Processed valid packet with phase=1, key=1: do key update.
+  EXPECT_EQ(1u, visitor_.key_update_count());
+  EXPECT_EQ(1, visitor_.derive_next_key_count_);
+  EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_);
+}
+
 TEST_P(QuicFramerTest, ErrorWhenUnexpectedFrameTypeEncountered) {
   if (!GetQuicReloadableFlag(quic_reject_unexpected_ietf_frame_types) ||
       !VersionHasIetfQuicFrames(framer_.transport_version()) ||
diff --git a/quic/core/quic_versions.cc b/quic/core/quic_versions.cc
index 97e2194..081791d 100644
--- a/quic/core/quic_versions.cc
+++ b/quic/core/quic_versions.cc
@@ -593,6 +593,7 @@
 
 void QuicVersionInitializeSupportForIetfDraft() {
   // Enable necessary flags.
+  SetQuicReloadableFlag(quic_fix_key_update_on_first_packet, true);
 }
 
 void QuicEnableVersion(const ParsedQuicVersion& version) {