Close the connection when an IETF frame is received in a packet whose encryption level does not allow that frame type.

Protected by FLAGS_quic_reloadable_flag_quic_reject_unexpected_ietf_frame_types.

PiperOrigin-RevId: 360703267
Change-Id: Iba465aee2e8b709757f2c77d9b8bfd860ef89bf4
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc
index 60eccb8..4407411 100644
--- a/quic/core/quic_connection_test.cc
+++ b/quic/core/quic_connection_test.cc
@@ -1174,9 +1174,16 @@
                                                   EncryptionLevel level) {
     QuicPacketHeader header = ConstructPacketHeader(number, level);
     QuicFrames frames;
-    frames.push_back(QuicFrame(frame1_));
-    if (has_stop_waiting) {
-      frames.push_back(QuicFrame(stop_waiting_));
+    if (GetQuicReloadableFlag(quic_reject_unexpected_ietf_frame_types) &&
+        VersionHasIetfQuicFrames(version().transport_version) &&
+        (level == ENCRYPTION_INITIAL || level == ENCRYPTION_HANDSHAKE)) {
+      frames.push_back(QuicFrame(QuicPingFrame()));
+      frames.push_back(QuicFrame(QuicPaddingFrame(100)));
+    } else {
+      frames.push_back(QuicFrame(frame1_));
+      if (has_stop_waiting) {
+        frames.push_back(QuicFrame(stop_waiting_));
+      }
     }
     return ConstructPacket(header, frames);
   }
@@ -2785,7 +2792,9 @@
 
 TEST_P(QuicConnectionTest, RejectUnencryptedStreamData) {
   // EXPECT_QUIC_BUG tests are expensive so only run one instance of them.
-  if (!IsDefaultTestConfiguration()) {
+  if (!IsDefaultTestConfiguration() ||
+      (GetQuicReloadableFlag(quic_reject_unexpected_ietf_frame_types) &&
+       VersionHasIetfQuicFrames(version().transport_version))) {
     return;
   }
 
@@ -2939,28 +2948,6 @@
   }
 }
 
-TEST_P(QuicConnectionTest,
-       AckFrequencyFrameOutsideApplicationDataNumberSpaceIsIgnored) {
-  if (!GetParam().version.HasIetfQuicFrames()) {
-    return;
-  }
-  connection_.set_can_receive_ack_frequency_frame();
-
-  QuicAckFrequencyFrame ack_frequency_frame;
-  ack_frequency_frame.packet_tolerance = 3;
-  ProcessFramePacketAtLevel(1, QuicFrame(&ack_frequency_frame),
-                            ENCRYPTION_HANDSHAKE);
-
-  // Expect 30 acks, every 2nd (instead of 3rd) packet including the first
-  // packet with AckFrequencyFrame.
-  EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(30);
-  EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(60);
-  // Receives packets 2 - 61.
-  for (size_t i = 2; i <= 61; ++i) {
-    ProcessDataPacket(i);
-  }
-}
-
 TEST_P(QuicConnectionTest, AckDecimationReducesAcks) {
   const size_t kMinRttMs = 40;
   RttStats* rtt_stats = const_cast<RttStats*>(manager_->GetRttStats());
@@ -7757,7 +7744,13 @@
   EXPECT_CALL(visitor_,
               OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF));
   ForceProcessFramePacket(QuicFrame(frame1_));
-  TestConnectionCloseQuicErrorCode(QUIC_MAYBE_CORRUPTED_MEMORY);
+  if (GetQuicReloadableFlag(quic_reject_unexpected_ietf_frame_types) &&
+      VersionHasIetfQuicFrames(version().transport_version)) {
+    // INITIAL packet should not contain STREAM frame.
+    TestConnectionCloseQuicErrorCode(IETF_QUIC_PROTOCOL_VIOLATION);
+  } else {
+    TestConnectionCloseQuicErrorCode(QUIC_MAYBE_CORRUPTED_MEMORY);
+  }
 }
 
 TEST_P(QuicConnectionTest, ClientReceivesRejOnNonCryptoStream) {
@@ -7774,7 +7767,13 @@
   EXPECT_CALL(visitor_,
               OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF));
   ForceProcessFramePacket(QuicFrame(frame1_));
-  TestConnectionCloseQuicErrorCode(QUIC_MAYBE_CORRUPTED_MEMORY);
+  if (GetQuicReloadableFlag(quic_reject_unexpected_ietf_frame_types) &&
+      VersionHasIetfQuicFrames(version().transport_version)) {
+    // INITIAL packet should not contain STREAM frame.
+    TestConnectionCloseQuicErrorCode(IETF_QUIC_PROTOCOL_VIOLATION);
+  } else {
+    TestConnectionCloseQuicErrorCode(QUIC_MAYBE_CORRUPTED_MEMORY);
+  }
 }
 
 TEST_P(QuicConnectionTest, CloseConnectionOnPacketTooLarge) {
@@ -9174,23 +9173,23 @@
   // Receives packet 1000 in initial data.
   ProcessCryptoPacketAtLevel(1000, ENCRYPTION_INITIAL);
   EXPECT_TRUE(connection_.HasPendingAcks());
-  peer_framer_.SetEncrypter(ENCRYPTION_ZERO_RTT,
+  peer_framer_.SetEncrypter(ENCRYPTION_FORWARD_SECURE,
                             std::make_unique<TaggingEncrypter>(0x02));
-  SetDecrypter(ENCRYPTION_ZERO_RTT,
+  SetDecrypter(ENCRYPTION_FORWARD_SECURE,
                std::make_unique<StrictTaggingDecrypter>(0x02));
   connection_.SetEncrypter(ENCRYPTION_INITIAL,
                            std::make_unique<TaggingEncrypter>(0x02));
   // Receives packet 1000 in application data.
-  ProcessDataPacketAtLevel(1000, false, ENCRYPTION_ZERO_RTT);
+  ProcessDataPacketAtLevel(1000, false, ENCRYPTION_FORWARD_SECURE);
   EXPECT_TRUE(connection_.HasPendingAcks());
-  connection_.SendApplicationDataAtLevel(ENCRYPTION_ZERO_RTT, 5, "data", 0,
-                                         NO_FIN);
+  connection_.SendApplicationDataAtLevel(ENCRYPTION_FORWARD_SECURE, 5, "data",
+                                         0, NO_FIN);
   // Verify application data ACK gets bundled with outgoing data.
   EXPECT_EQ(2u, writer_->frame_count());
   // Make sure ACK alarm is still set because initial data is not ACKed.
   EXPECT_TRUE(connection_.HasPendingAcks());
   // Receive packet 1001 in application data.
-  ProcessDataPacketAtLevel(1001, false, ENCRYPTION_ZERO_RTT);
+  ProcessDataPacketAtLevel(1001, false, ENCRYPTION_FORWARD_SECURE);
   clock_.AdvanceTime(DefaultRetransmissionTime());
   // Simulates ACK alarm fires and verify two ACKs are flushed.
   EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(2);
@@ -9199,13 +9198,9 @@
   connection_.GetAckAlarm()->Fire();
   EXPECT_FALSE(connection_.HasPendingAcks());
   // Receives more packets in application data.
-  ProcessDataPacketAtLevel(1002, false, ENCRYPTION_ZERO_RTT);
+  ProcessDataPacketAtLevel(1002, false, ENCRYPTION_FORWARD_SECURE);
   EXPECT_TRUE(connection_.HasPendingAcks());
 
-  peer_framer_.SetEncrypter(ENCRYPTION_FORWARD_SECURE,
-                            std::make_unique<TaggingEncrypter>(0x02));
-  SetDecrypter(ENCRYPTION_FORWARD_SECURE,
-               std::make_unique<StrictTaggingDecrypter>(0x02));
   // Verify zero rtt and forward secure packets get acked in the same packet.
   EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1);
   ProcessDataPacket(1003);
@@ -11049,7 +11044,10 @@
                            std::make_unique<TaggingEncrypter>(0x01));
   connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE);
   // Verify all ENCRYPTION_HANDSHAKE packets get processed.
-  EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(6);
+  if (!GetQuicReloadableFlag(quic_reject_unexpected_ietf_frame_types) ||
+      !VersionHasIetfQuicFrames(version().transport_version)) {
+    EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(6);
+  }
   connection_.GetProcessUndecryptablePacketsAlarm()->Fire();
   EXPECT_EQ(1u, QuicConnectionPeer::NumUndecryptablePackets(&connection_));
 
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h
index 708ae5c..0e9fa10 100644
--- a/quic/core/quic_flags_list.h
+++ b/quic/core/quic_flags_list.h
@@ -48,6 +48,7 @@
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_h3_datagram, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_pass_path_response_to_validator, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_preempt_stream_data_with_handshake_packet, false)
+QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_reject_unexpected_ietf_frame_types, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_require_handshake_confirmation, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_send_path_response, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_send_timestamps, false)
diff --git a/quic/core/quic_framer.cc b/quic/core/quic_framer.cc
index fdccef9..24034fd 100644
--- a/quic/core/quic_framer.cc
+++ b/quic/core/quic_framer.cc
@@ -1899,7 +1899,7 @@
   // Handle the payload.
   if (VersionHasIetfQuicFrames(version_.transport_version)) {
     current_received_frame_type_ = 0;
-    if (!ProcessIetfFrameData(&reader, *header)) {
+    if (!ProcessIetfFrameData(&reader, *header, decrypted_level)) {
       current_received_frame_type_ = 0;
       QUICHE_DCHECK_NE(QUIC_NO_ERROR,
                        error_);  // ProcessIetfFrameData sets the error.
@@ -3098,8 +3098,40 @@
   return true;
 }
 
+// static
+bool QuicFramer::IsIetfFrameTypeExpectedForEncryptionLevel(
+    uint64_t frame_type,
+    EncryptionLevel level) {
+  // Decides whether an IETF frame of |frame_type| may appear in a packet
+  // protected at |level| (see RFC 9000 Sections 12.4 and 12.5). Frame types
+  // 0x02 (IETF_ACK) and 0x03 (IETF_ACK_ECN) are both ACK frames and must be
+  // treated identically at every level.
+  switch (level) {
+    case ENCRYPTION_INITIAL:
+    case ENCRYPTION_HANDSHAKE:
+      // Allow-list: only these frame types may be carried.
+      return frame_type == IETF_CRYPTO || frame_type == IETF_ACK ||
+             frame_type == IETF_ACK_ECN || frame_type == IETF_PING ||
+             frame_type == IETF_PADDING || frame_type == IETF_CONNECTION_CLOSE;
+    case ENCRYPTION_ZERO_RTT:
+      // Deny-list: every frame type except these may be carried.
+      return !(frame_type == IETF_ACK || frame_type == IETF_ACK_ECN ||
+               frame_type == IETF_CRYPTO || frame_type == IETF_NEW_TOKEN ||
+               frame_type == IETF_HANDSHAKE_DONE ||
+               frame_type == IETF_PATH_RESPONSE ||
+               frame_type == IETF_RETIRE_CONNECTION_ID);
+    case ENCRYPTION_FORWARD_SECURE:
+      return true;
+    default:
+      QUIC_BUG << "Unknown encryption level: " << level;
+  }
+  // Only reached for unknown levels, which are treated as unexpected.
+  return false;
+}
+
 bool QuicFramer::ProcessIetfFrameData(QuicDataReader* reader,
-                                      const QuicPacketHeader& header) {
+                                      const QuicPacketHeader& header,
+                                      EncryptionLevel decrypted_level) {
   QUICHE_DCHECK(VersionHasIetfQuicFrames(version_.transport_version))
       << "Attempt to process frames as IETF frames but version ("
       << version_.transport_version << ") does not support IETF Framing.";
@@ -3118,6 +3150,19 @@
       set_detailed_error("Unable to read frame type.");
       return RaiseError(QUIC_INVALID_FRAME_DATA);
     }
+    if (reject_unexpected_ietf_frame_types_) {
+      QUIC_RELOADABLE_FLAG_COUNT_N(quic_reject_unexpected_ietf_frame_types, 1,
+                                   2);
+      if (!IsIetfFrameTypeExpectedForEncryptionLevel(frame_type,
+                                                     decrypted_level)) {
+        QUIC_RELOADABLE_FLAG_COUNT_N(quic_reject_unexpected_ietf_frame_types, 2,
+                                     2);
+        set_detailed_error(absl::StrCat("IETF frame type ", frame_type,
+                                        " is unexpected at encryption level ",
+                                        decrypted_level));
+        return RaiseError(IETF_QUIC_PROTOCOL_VIOLATION);
+      }
+    }
     current_received_frame_type_ = frame_type;
 
     // Is now the number of bytes into which the frame type was encoded.
diff --git a/quic/core/quic_framer.h b/quic/core/quic_framer.h
index 571278d..0307594 100644
--- a/quic/core/quic_framer.h
+++ b/quic/core/quic_framer.h
@@ -826,8 +826,18 @@
       QuicPacketNumber base_packet_number,
       uint64_t* packet_number);
   bool ProcessFrameData(QuicDataReader* reader, const QuicPacketHeader& header);
+
+  // Returns true if an IETF frame of type |frame_type| is allowed in a packet
+  // decrypted at encryption |level|.
+  static bool IsIetfFrameTypeExpectedForEncryptionLevel(uint64_t frame_type,
+                                                        EncryptionLevel level);
+
+  // |decrypted_level| is the level the packet was decrypted at. If
+  // |reject_unexpected_ietf_frame_types_| is true, a frame whose type is not
+  // expected at that level is rejected with IETF_QUIC_PROTOCOL_VIOLATION.
   bool ProcessIetfFrameData(QuicDataReader* reader,
-                            const QuicPacketHeader& header);
+                            const QuicPacketHeader& header,
+                            EncryptionLevel decrypted_level);
   bool ProcessStreamFrame(QuicDataReader* reader,
                           uint8_t frame_type,
                           QuicStreamFrame* frame);
@@ -1158,6 +1168,12 @@
   // Indicates whether received RETRY packets should be dropped.
   bool drop_incoming_retry_packets_ = false;
 
+  // Initialized from the quic_reject_unexpected_ietf_frame_types flag. When
+  // true, ProcessIetfFrameData() rejects frames whose type is not expected at
+  // the packet's encryption level.
+  bool reject_unexpected_ietf_frame_types_ =
+      GetQuicReloadableFlag(quic_reject_unexpected_ietf_frame_types);
+
   // The length in bytes of the last packet number written to an IETF-framed
   // packet.
   size_t last_written_packet_number_length_;
diff --git a/quic/core/quic_framer_test.cc b/quic/core/quic_framer_test.cc
index 12b779a..dafe1f7 100644
--- a/quic/core/quic_framer_test.cc
+++ b/quic/core/quic_framer_test.cc
@@ -15164,6 +15164,48 @@
   EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_);
 }
 
+TEST_P(QuicFramerTest, ErrorWhenUnexpectedFrameTypeEncountered) {
+  // The packet below uses IETF framing and a long-header length field, and
+  // the frame-type check under test is gated by the reloadable flag.
+  if (!GetQuicReloadableFlag(quic_reject_unexpected_ietf_frame_types) ||
+      !VersionHasIetfQuicFrames(framer_.transport_version()) ||
+      !framer_.version().HasLongHeaderLengths()) {
+    return;
+  }
+  SetDecrypterLevel(ENCRYPTION_ZERO_RTT);
+  // clang-format off
+  unsigned char packet[] = {
+    // public flags (long header with packet type ZERO_RTT_PROTECTED and
+    // 4-byte packet number)
+    0xD3,
+    // version
+    QUIC_VERSION_BYTES,
+    // destination connection ID length
+    0x08,
+    // destination connection ID
+    0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10,
+    // source connection ID length
+    0x08,
+    // source connection ID
+    0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x11,
+    // long header packet length
+    0x05,
+    // packet number
+    0x12, 0x34, 0x56, 0x00,
+    // unexpected ietf ack frame type in 0-RTT packet
+    0x02,
+  };
+  // clang-format on
+
+  QuicEncryptedPacket encrypted(AsChars(packet), ABSL_ARRAYSIZE(packet), false);
+
+  EXPECT_FALSE(framer_.ProcessPacket(encrypted));
+
+  EXPECT_THAT(framer_.error(), IsError(IETF_QUIC_PROTOCOL_VIOLATION));
+  EXPECT_EQ("IETF frame type 2 is unexpected at encryption level 2",
+            framer_.detailed_error());
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic
diff --git a/quic/core/quic_packet_creator_test.cc b/quic/core/quic_packet_creator_test.cc
index 8ff44c1..11c9800 100644
--- a/quic/core/quic_packet_creator_test.cc
+++ b/quic/core/quic_packet_creator_test.cc
@@ -299,7 +299,9 @@
   for (int i = ENCRYPTION_INITIAL; i < NUM_ENCRYPTION_LEVELS; ++i) {
     EncryptionLevel level = static_cast<EncryptionLevel>(i);
     creator_.set_encryption_level(level);
-    frames_.push_back(QuicFrame(new QuicAckFrame(InitAckFrame(1))));
+    if (level != ENCRYPTION_ZERO_RTT) {
+      frames_.push_back(QuicFrame(new QuicAckFrame(InitAckFrame(1))));
+    }
     QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId(
         client_framer_.transport_version(), Perspective::IS_CLIENT);
     if (level != ENCRYPTION_INITIAL && level != ENCRYPTION_HANDSHAKE) {
@@ -308,7 +310,9 @@
     }
     SerializedPacket serialized = SerializeAllFrames(frames_);
     EXPECT_EQ(level, serialized.encryption_level);
-    delete frames_[0].ack_frame;
+    if (level != ENCRYPTION_ZERO_RTT) {
+      delete frames_[0].ack_frame;
+    }
     frames_.clear();
 
     {
@@ -318,13 +322,15 @@
       EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_));
       EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _));
       EXPECT_CALL(framer_visitor_, OnPacketHeader(_));
-      EXPECT_CALL(framer_visitor_, OnAckFrameStart(_, _))
-          .WillOnce(Return(true));
-      EXPECT_CALL(framer_visitor_,
-                  OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)))
-          .WillOnce(Return(true));
-      EXPECT_CALL(framer_visitor_, OnAckFrameEnd(QuicPacketNumber(1)))
-          .WillOnce(Return(true));
+      if (level != ENCRYPTION_ZERO_RTT) {
+        EXPECT_CALL(framer_visitor_, OnAckFrameStart(_, _))
+            .WillOnce(Return(true));
+        EXPECT_CALL(framer_visitor_,
+                    OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)))
+            .WillOnce(Return(true));
+        EXPECT_CALL(framer_visitor_, OnAckFrameEnd(QuicPacketNumber(1)))
+            .WillOnce(Return(true));
+      }
       if (level != ENCRYPTION_INITIAL && level != ENCRYPTION_HANDSHAKE) {
         EXPECT_CALL(framer_visitor_, OnStreamFrame(_));
       }
@@ -2118,7 +2124,9 @@
     EncryptionLevel level = static_cast<EncryptionLevel>(i);
     creator_.set_encryption_level(level);
     QuicAckFrame ack_frame(InitAckFrame(1));
-    frames_.push_back(QuicFrame(&ack_frame));
+    if (level != ENCRYPTION_ZERO_RTT) {
+      frames_.push_back(QuicFrame(&ack_frame));
+    }
     if (level != ENCRYPTION_INITIAL && level != ENCRYPTION_HANDSHAKE) {
       frames_.push_back(
           QuicFrame(QuicStreamFrame(1, false, 0u, absl::string_view())));
@@ -2156,20 +2164,26 @@
     EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_));
     EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _));
     EXPECT_CALL(framer_visitor_, OnPacketHeader(_));
-    EXPECT_CALL(framer_visitor_, OnAckFrameStart(_, _)).WillOnce(Return(true));
-    EXPECT_CALL(framer_visitor_,
-                OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)))
-        .WillOnce(Return(true));
-    EXPECT_CALL(framer_visitor_, OnAckFrameEnd(_)).WillOnce(Return(true));
+    if (i != ENCRYPTION_ZERO_RTT) {
+      EXPECT_CALL(framer_visitor_, OnAckFrameStart(_, _))
+          .WillOnce(Return(true));
+      EXPECT_CALL(framer_visitor_,
+                  OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)))
+          .WillOnce(Return(true));
+      EXPECT_CALL(framer_visitor_, OnAckFrameEnd(_)).WillOnce(Return(true));
+    }
     if (i == ENCRYPTION_INITIAL) {
       // Verify padding is added.
       EXPECT_CALL(framer_visitor_, OnPaddingFrame(_));
-    } else {
+    } else if (i != ENCRYPTION_ZERO_RTT) {
       EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)).Times(testing::AtMost(1));
     }
     if (i != ENCRYPTION_INITIAL && i != ENCRYPTION_HANDSHAKE) {
       EXPECT_CALL(framer_visitor_, OnStreamFrame(_));
     }
+    if (i == ENCRYPTION_ZERO_RTT) {
+      EXPECT_CALL(framer_visitor_, OnPaddingFrame(_));
+    }
     EXPECT_CALL(framer_visitor_, OnPacketComplete());
 
     server_framer_.ProcessPacket(*packets[i]);