In quic, determine a serialized packet fate before it gets serialized, and use the fate to determine whether add full padding later. protected by gfe2_reloadable_flag_quic_determine_serialized_packet_fate_early.

PiperOrigin-RevId: 319272370
Change-Id: I43b3b823d50cd5f9c1b8c7d946eb1319706ee41e
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc
index 53dc7d9..8c395a5 100644
--- a/quic/core/quic_connection.cc
+++ b/quic/core/quic_connection.cc
@@ -2506,7 +2506,8 @@
 }
 
 bool QuicConnection::WritePacket(SerializedPacket* packet) {
-  if (ShouldDiscardPacket(*packet)) {
+  if (!packet_creator_.determine_serialized_packet_fate_early() &&
+      ShouldDiscardPacket(packet->encryption_level)) {
     ++stats_.packets_discarded;
     return true;
   }
@@ -2518,11 +2519,12 @@
                     ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET);
     return true;
   }
-
   const bool is_mtu_discovery = QuicUtils::ContainsFrameType(
       packet->nonretransmittable_frames, MTU_DISCOVERY_FRAME);
-
-  SerializedPacketFate fate = DeterminePacketFate(is_mtu_discovery);
+  const SerializedPacketFate fate =
+      packet_creator_.determine_serialized_packet_fate_early()
+          ? packet->fate
+          : GetSerializedPacketFate(is_mtu_discovery, packet->encryption_level);
   // Termination packets are encrypted and saved, so don't exit early.
   const bool is_termination_packet = IsTerminationPacket(*packet);
   QuicPacketNumber packet_number = packet->packet_number;
@@ -2565,6 +2567,10 @@
   QuicTime packet_send_time = CalculatePacketSentTime();
   WriteResult result(WRITE_STATUS_OK, encrypted_length);
   switch (fate) {
+    case DISCARD:
+      DCHECK(packet_creator_.determine_serialized_packet_fate_early());
+      ++stats_.packets_discarded;
+      return true;
     case COALESCE:
       QUIC_BUG_IF(!version().CanSendCoalescedPackets());
       if (!coalesced_packet_.MaybeCoalescePacket(
@@ -2835,21 +2841,20 @@
          (IsWriteError(result.status) && result.error_code == QUIC_EMSGSIZE);
 }
 
-bool QuicConnection::ShouldDiscardPacket(const SerializedPacket& packet) {
+bool QuicConnection::ShouldDiscardPacket(EncryptionLevel encryption_level) {
   if (!connected_) {
     QUIC_DLOG(INFO) << ENDPOINT
                     << "Not sending packet as connection is disconnected.";
     return true;
   }
 
-  QuicPacketNumber packet_number = packet.packet_number;
   if (encryption_level_ == ENCRYPTION_FORWARD_SECURE &&
-      packet.encryption_level == ENCRYPTION_INITIAL) {
+      encryption_level == ENCRYPTION_INITIAL) {
     // Drop packets that are NULL encrypted since the peer won't accept them
     // anymore.
     QUIC_DLOG(INFO) << ENDPOINT
-                    << "Dropping NULL encrypted packet: " << packet_number
-                    << " since the connection is forward secure.";
+                    << "Dropping NULL encrypted packet since the connection is "
+                       "forward secure.";
     return true;
   }
 
@@ -4490,8 +4495,14 @@
                  bytes_received_before_address_validation_;
 }
 
-SerializedPacketFate QuicConnection::DeterminePacketFate(
-    bool is_mtu_discovery) {
+SerializedPacketFate QuicConnection::GetSerializedPacketFate(
+    bool is_mtu_discovery,
+    EncryptionLevel encryption_level) {
+  if (packet_creator_.determine_serialized_packet_fate_early()) {
+    if (ShouldDiscardPacket(encryption_level)) {
+      return DISCARD;
+    }
+  }
   if (legacy_version_encapsulation_in_progress_) {
     DCHECK(!is_mtu_discovery);
     return LEGACY_VERSION_ENCAPSULATE;
@@ -4505,7 +4516,8 @@
       return COALESCE;
     }
     // Packet cannot be coalesced, flush existing coalesced packet.
-    if (!FlushCoalescedPacket()) {
+    if (!packet_creator_.determine_serialized_packet_fate_early() &&
+        !FlushCoalescedPacket()) {
       return FAILED_TO_WRITE_COALESCED_PACKET;
     }
   }
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h
index b3e44b3..b6bbc38 100644
--- a/quic/core/quic_connection.h
+++ b/quic/core/quic_connection.h
@@ -598,6 +598,9 @@
   void OnSerializedPacket(SerializedPacket packet) override;
   void OnUnrecoverableError(QuicErrorCode error,
                             const std::string& error_details) override;
+  SerializedPacketFate GetSerializedPacketFate(
+      bool is_mtu_discovery,
+      EncryptionLevel encryption_level) override;
 
   // QuicSentPacketManager::NetworkChangeVisitor
   void OnCongestionChange() override;
@@ -1039,7 +1042,7 @@
                                          const std::string& details);
 
   // Returns true if the packet should be discarded and not sent.
-  virtual bool ShouldDiscardPacket(const SerializedPacket& packet);
+  virtual bool ShouldDiscardPacket(EncryptionLevel encryption_level);
 
   // Retransmits packets continuously until blocked by the congestion control.
   // If there are no packets to retransmit, does not do anything.
@@ -1244,9 +1247,6 @@
   // Called to update ACK timeout when an retransmittable frame has been parsed.
   void MaybeUpdateAckTimeout();
 
-  // Returns packet fate when trying to write a packet via WritePacket().
-  SerializedPacketFate DeterminePacketFate(bool is_mtu_discovery);
-
   // Serialize and send coalesced_packet. Returns false if serialization fails
   // or the write causes errors, otherwise, returns true.
   bool FlushCoalescedPacket();
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc
index 8b676fd..ace423a 100644
--- a/quic/core/quic_connection_test.cc
+++ b/quic/core/quic_connection_test.cc
@@ -3534,6 +3534,8 @@
 
 TEST_P(QuicConnectionTest, LargeSendWithPendingAck) {
   connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
+  EXPECT_CALL(visitor_, GetHandshakeState())
+      .WillRepeatedly(Return(HANDSHAKE_CONFIRMED));
   // Set the ack alarm by processing a ping frame.
   EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_));
 
@@ -7232,6 +7234,11 @@
                               ConnectionCloseBehavior::SILENT_CLOSE);
   EXPECT_FALSE(connection_.connected());
   EXPECT_FALSE(connection_.CanWrite(HAS_RETRANSMITTABLE_DATA));
+  if (GetQuicReloadableFlag(quic_determine_serialized_packet_fate_early)) {
+    EXPECT_EQ(DISCARD, connection_.GetSerializedPacketFate(
+                           /*is_mtu_discovery=*/false, ENCRYPTION_INITIAL));
+    return;
+  }
   std::unique_ptr<QuicPacket> packet =
       ConstructDataPacket(1, !kHasStopWaiting, ENCRYPTION_INITIAL);
   EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, QuicPacketNumber(1), _, _))
@@ -8339,7 +8346,7 @@
   ASSERT_EQ(0u, connection_.GetStats().packets_sent);
   connection_.set_fill_up_link_during_probing(true);
   EXPECT_CALL(visitor_, GetHandshakeState())
-      .WillRepeatedly(Return(HANDSHAKE_COMPLETE));
+      .WillRepeatedly(Return(HANDSHAKE_CONFIRMED));
   connection_.OnHandshakeComplete();
   connection_.SendStreamData3();
 
diff --git a/quic/core/quic_dispatcher.cc b/quic/core/quic_dispatcher.cc
index 9fccaab..7cf9f51 100644
--- a/quic/core/quic_dispatcher.cc
+++ b/quic/core/quic_dispatcher.cc
@@ -87,6 +87,12 @@
     return {};
   }
 
+  SerializedPacketFate GetSerializedPacketFate(
+      bool /*is_mtu_discovery*/,
+      EncryptionLevel /*encryption_level*/) override {
+    return SEND_TO_WRITER;
+  }
+
   // QuicStreamFrameDataProducer
   WriteStreamDataResult WriteStreamData(QuicStreamId /*id*/,
                                         QuicStreamOffset offset,
diff --git a/quic/core/quic_legacy_version_encapsulator.cc b/quic/core/quic_legacy_version_encapsulator.cc
index 816b8c7..4afd342 100644
--- a/quic/core/quic_legacy_version_encapsulator.cc
+++ b/quic/core/quic_legacy_version_encapsulator.cc
@@ -77,6 +77,12 @@
   return QuicFrames();
 }
 
+SerializedPacketFate QuicLegacyVersionEncapsulator::GetSerializedPacketFate(
+    bool /*is_mtu_discovery*/,
+    EncryptionLevel /*encryption_level*/) {
+  return SEND_TO_WRITER;
+}
+
 // static
 QuicPacketLength QuicLegacyVersionEncapsulator::Encapsulate(
     quiche::QuicheStringPiece sni,
diff --git a/quic/core/quic_legacy_version_encapsulator.h b/quic/core/quic_legacy_version_encapsulator.h
index 3d9c156..413829b 100644
--- a/quic/core/quic_legacy_version_encapsulator.h
+++ b/quic/core/quic_legacy_version_encapsulator.h
@@ -47,6 +47,9 @@
   bool ShouldGeneratePacket(HasRetransmittableData retransmittable,
                             IsHandshake handshake) override;
   const QuicFrames MaybeBundleAckOpportunistically() override;
+  SerializedPacketFate GetSerializedPacketFate(
+      bool is_mtu_discovery,
+      EncryptionLevel encryption_level) override;
 
   ~QuicLegacyVersionEncapsulator() override;
 
diff --git a/quic/core/quic_packet_creator.cc b/quic/core/quic_packet_creator.cc
index e6f01ea..cf5540e 100644
--- a/quic/core/quic_packet_creator.cc
+++ b/quic/core/quic_packet_creator.cc
@@ -496,6 +496,7 @@
   packet_.transmission_type = NOT_RETRANSMISSION;
   packet_.encrypted_buffer = nullptr;
   packet_.encrypted_length = 0;
+  packet_.fate = SEND_TO_WRITER;
   if (avoid_leak_writer_buffer_) {
     QUIC_RELOADABLE_FLAG_COUNT_N(quic_avoid_leak_writer_buffer, 2, 3);
     QUIC_BUG_IF(packet_.release_encrypted_buffer != nullptr)
@@ -568,6 +569,14 @@
   // Write out the packet header
   QuicPacketHeader header;
   FillPacketHeader(&header);
+  if (determine_serialized_packet_fate_early_) {
+    packet_.fate = delegate_->GetSerializedPacketFate(
+        /*is_mtu_discovery=*/false, packet_.encryption_level);
+    QUIC_DVLOG(1) << ENDPOINT << "fate of packet " << packet_.packet_number
+                  << ": " << SerializedPacketFateToString(packet_.fate)
+                  << " of "
+                  << EncryptionLevelToString(packet_.encryption_level);
+  }
 
   QUIC_CACHELINE_ALIGNED char stack_buffer[kMaxOutgoingPacketSize];
   QuicOwnedPacketBuffer packet_buffer(delegate_->GetPacketBuffer());
@@ -758,6 +767,17 @@
   QuicPacketHeader header;
   // FillPacketHeader increments packet_number_.
   FillPacketHeader(&header);
+  if (determine_serialized_packet_fate_early_ && delegate_ != nullptr) {
+    QUIC_RELOADABLE_FLAG_COUNT(quic_determine_serialized_packet_fate_early);
+    packet_.fate = delegate_->GetSerializedPacketFate(
+        /*is_mtu_discovery=*/QuicUtils::ContainsFrameType(queued_frames_,
+                                                          MTU_DISCOVERY_FRAME),
+        packet_.encryption_level);
+    QUIC_DVLOG(1) << ENDPOINT << "fate of packet " << packet_.packet_number
+                  << ": " << SerializedPacketFateToString(packet_.fate)
+                  << " of "
+                  << EncryptionLevelToString(packet_.encryption_level);
+  }
 
   MaybeAddPadding();
 
@@ -1758,27 +1778,36 @@
     needs_full_padding_ = true;
   }
 
-  // Packet coalescer pads INITIAL packets, so the creator should not.
-  if (framer_->version().CanSendCoalescedPackets() &&
-      (packet_.encryption_level == ENCRYPTION_INITIAL ||
-       packet_.encryption_level == ENCRYPTION_HANDSHAKE)) {
-    // TODO(fayang): MTU discovery packets should not ever be sent as
-    // ENCRYPTION_INITIAL or ENCRYPTION_HANDSHAKE.
-    bool is_mtu_discovery = false;
-    for (const auto& frame : packet_.nonretransmittable_frames) {
-      if (frame.type == MTU_DISCOVERY_FRAME) {
-        is_mtu_discovery = true;
-        break;
-      }
-    }
-    if (!is_mtu_discovery) {
-      // Do not add full padding if connection tries to coalesce packet.
+  if (determine_serialized_packet_fate_early_) {
+    if (packet_.fate == COALESCE ||
+        packet_.fate == LEGACY_VERSION_ENCAPSULATE) {
+      // Do not add full padding if the packet is going to be coalesced or
+      // encapsulated.
       needs_full_padding_ = false;
     }
-  }
+  } else {
+    // Packet coalescer pads INITIAL packets, so the creator should not.
+    if (framer_->version().CanSendCoalescedPackets() &&
+        (packet_.encryption_level == ENCRYPTION_INITIAL ||
+         packet_.encryption_level == ENCRYPTION_HANDSHAKE)) {
+      // TODO(fayang): MTU discovery packets should not ever be sent as
+      // ENCRYPTION_INITIAL or ENCRYPTION_HANDSHAKE.
+      bool is_mtu_discovery = false;
+      for (const auto& frame : packet_.nonretransmittable_frames) {
+        if (frame.type == MTU_DISCOVERY_FRAME) {
+          is_mtu_discovery = true;
+          break;
+        }
+      }
+      if (!is_mtu_discovery) {
+        // Do not add full padding if connection tries to coalesce packet.
+        needs_full_padding_ = false;
+      }
+    }
 
-  if (disable_padding_override_) {
-    needs_full_padding_ = false;
+    if (disable_padding_override_) {
+      needs_full_padding_ = false;
+    }
   }
 
   // Header protection requires a minimum plaintext packet size.
diff --git a/quic/core/quic_packet_creator.h b/quic/core/quic_packet_creator.h
index c579e42..e5a73c3 100644
--- a/quic/core/quic_packet_creator.h
+++ b/quic/core/quic_packet_creator.h
@@ -58,6 +58,13 @@
     // Called when there is data to be sent. Retrieves updated ACK frame from
     // the delegate.
     virtual const QuicFrames MaybeBundleAckOpportunistically() = 0;
+
+    // Returns the packet fate for serialized packets which will be handed over
+    // to delegate via OnSerializedPacket(). Called when a packet is about to be
+    // serialized.
+    virtual SerializedPacketFate GetSerializedPacketFate(
+        bool is_mtu_discovery,
+        EncryptionLevel encryption_level) = 0;
   };
 
   // Interface which gets callbacks from the QuicPacketCreator at interesting
@@ -442,6 +449,10 @@
     disable_padding_override_ = should_disable_padding;
   }
 
+  bool determine_serialized_packet_fate_early() const {
+    return determine_serialized_packet_fate_early_;
+  }
+
  private:
   friend class test::QuicPacketCreatorPeer;
 
@@ -628,6 +639,8 @@
       GetQuicReloadableFlag(quic_fix_min_crypto_frame_size);
 
   // When true, this will override the padding generation code to disable it.
+  // TODO(fayang): remove this when deprecating
+  // quic_determine_serialized_packet_fate_early.
   bool disable_padding_override_ = false;
 
   const bool update_packet_size_ =
@@ -635,6 +648,9 @@
 
   const bool fix_extra_padding_bytes_ =
       GetQuicReloadableFlag(quic_fix_extra_padding_bytes);
+
+  const bool determine_serialized_packet_fate_early_ =
+      GetQuicReloadableFlag(quic_determine_serialized_packet_fate_early);
 };
 
 }  // namespace quic
diff --git a/quic/core/quic_packet_creator_test.cc b/quic/core/quic_packet_creator_test.cc
index 771a34a..65377ef 100644
--- a/quic/core/quic_packet_creator_test.cc
+++ b/quic/core/quic_packet_creator_test.cc
@@ -164,6 +164,10 @@
         creator_(connection_id_, &client_framer_, &delegate_, &producer_) {
     EXPECT_CALL(delegate_, GetPacketBuffer())
         .WillRepeatedly(Return(QuicPacketBuffer()));
+    if (GetQuicReloadableFlag(quic_determine_serialized_packet_fate_early)) {
+      EXPECT_CALL(delegate_, GetSerializedPacketFate(_, _))
+          .WillRepeatedly(Return(SEND_TO_WRITER));
+    }
     creator_.SetEncrypter(ENCRYPTION_INITIAL, std::make_unique<NullEncrypter>(
                                                   Perspective::IS_CLIENT));
     creator_.SetEncrypter(ENCRYPTION_HANDSHAKE, std::make_unique<NullEncrypter>(
@@ -506,6 +510,11 @@
     EXPECT_CALL(delegate_, OnSerializedPacket(_))
         .WillRepeatedly(
             Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket));
+    if (GetQuicReloadableFlag(quic_determine_serialized_packet_fate_early) &&
+        client_framer_.version().CanSendCoalescedPackets()) {
+      EXPECT_CALL(delegate_, GetSerializedPacketFate(_, _))
+          .WillRepeatedly(Return(COALESCE));
+    }
     if (!QuicVersionUsesCryptoFrames(client_framer_.transport_version())) {
       ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket(
           QuicUtils::GetCryptoStreamId(client_framer_.transport_version()),
@@ -2292,6 +2301,10 @@
               OnUnrecoverableError,
               (QuicErrorCode, const std::string&),
               (override));
+  MOCK_METHOD(SerializedPacketFate,
+              GetSerializedPacketFate,
+              (bool, EncryptionLevel),
+              (override));
 
   void SetCanWriteAnything() {
     EXPECT_CALL(*this, ShouldGeneratePacket(_, _)).WillRepeatedly(Return(true));
@@ -2443,6 +2456,10 @@
         ack_frame_(InitAckFrame(1)) {
     EXPECT_CALL(delegate_, GetPacketBuffer())
         .WillRepeatedly(Return(QuicPacketBuffer()));
+    if (GetQuicReloadableFlag(quic_determine_serialized_packet_fate_early)) {
+      EXPECT_CALL(delegate_, GetSerializedPacketFate(_, _))
+          .WillRepeatedly(Return(SEND_TO_WRITER));
+    }
     creator_.SetEncrypter(
         ENCRYPTION_FORWARD_SECURE,
         std::make_unique<NullEncrypter>(Perspective::IS_CLIENT));
diff --git a/quic/core/quic_packets.cc b/quic/core/quic_packets.cc
index 83b9773..2098bfe 100644
--- a/quic/core/quic_packets.cc
+++ b/quic/core/quic_packets.cc
@@ -462,7 +462,8 @@
       has_ack(has_ack),
       has_stop_waiting(has_stop_waiting),
       transmission_type(NOT_RETRANSMISSION),
-      has_ack_frame_copy(false) {}
+      has_ack_frame_copy(false),
+      fate(SEND_TO_WRITER) {}
 
 SerializedPacket::SerializedPacket(SerializedPacket&& other)
     : has_crypto_handshake(other.has_crypto_handshake),
@@ -473,7 +474,8 @@
       has_stop_waiting(other.has_stop_waiting),
       transmission_type(other.transmission_type),
       largest_acked(other.largest_acked),
-      has_ack_frame_copy(other.has_ack_frame_copy) {
+      has_ack_frame_copy(other.has_ack_frame_copy),
+      fate(other.fate) {
   if (this != &other) {
     if (release_encrypted_buffer && encrypted_buffer != nullptr) {
       release_encrypted_buffer(encrypted_buffer);
@@ -516,6 +518,7 @@
   copy->encryption_level = serialized.encryption_level;
   copy->transmission_type = serialized.transmission_type;
   copy->largest_acked = serialized.largest_acked;
+  copy->fate = serialized.fate;
 
   if (copy_buffer) {
     copy->encrypted_buffer = CopyBuffer(serialized);
diff --git a/quic/core/quic_packets.h b/quic/core/quic_packets.h
index 51105ad..450b842 100644
--- a/quic/core/quic_packets.h
+++ b/quic/core/quic_packets.h
@@ -403,6 +403,7 @@
   // Indicates whether this packet has a copy of ack frame in
   // nonretransmittable_frames.
   bool has_ack_frame_copy;
+  SerializedPacketFate fate;
 };
 
 // Make a copy of |serialized| (including the underlying frames). |copy_buffer|
diff --git a/quic/core/quic_types.h b/quic/core/quic_types.h
index b8e900d..57d0d7e 100644
--- a/quic/core/quic_types.h
+++ b/quic/core/quic_types.h
@@ -698,9 +698,12 @@
 
 // Indicates the fate of a serialized packet in WritePacket().
 enum SerializedPacketFate : uint8_t {
-  COALESCE,                          // Try to coalesce packet.
-  BUFFER,                            // Buffer packet in buffered_packets_.
-  SEND_TO_WRITER,                    // Send packet to writer.
+  DISCARD,         // Discard the packet.
+  COALESCE,        // Try to coalesce packet.
+  BUFFER,          // Buffer packet in buffered_packets_.
+  SEND_TO_WRITER,  // Send packet to writer.
+  // TODO(fayang): remove FAILED_TO_WRITE_COALESCED_PACKET when deprecating
+  // quic_determine_serialized_packet_fate_early.
   FAILED_TO_WRITE_COALESCED_PACKET,  // Packet cannot be coalesced, error occurs
                                      // when sending existing coalesced packet.
   LEGACY_VERSION_ENCAPSULATE,  // Perform Legacy Version Encapsulation on this
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h
index 5ad5575..650213e 100644
--- a/quic/test_tools/quic_test_utils.h
+++ b/quic/test_tools/quic_test_utils.h
@@ -1493,6 +1493,10 @@
               MaybeBundleAckOpportunistically,
               (),
               (override));
+  MOCK_METHOD(SerializedPacketFate,
+              GetSerializedPacketFate,
+              (bool, EncryptionLevel),
+              (override));
 };
 
 class MockSessionNotifier : public SessionNotifierInterface {