Close connection on packet serialization failures in SerializePacket.

Protected by FLAGS_quic_reloadable_flag_quic_close_connection_on_serialization_failure.

PiperOrigin-RevId: 329331250
Change-Id: I04f27188e8c3288a58697f2b9bbaf825c0e3c652
diff --git a/quic/core/quic_packet_creator.cc b/quic/core/quic_packet_creator.cc
index 0bf8bdc..c06d167 100644
--- a/quic/core/quic_packet_creator.cc
+++ b/quic/core/quic_packet_creator.cc
@@ -133,6 +133,9 @@
       fully_pad_crypto_handshake_packets_(true),
       latched_hard_max_packet_length_(0),
       max_datagram_frame_size_(0) {
+  if (close_connection_on_serialization_failure_) {
+    QUIC_RELOADABLE_FLAG_COUNT(quic_close_connection_on_serialization_failure);
+  }
   SetMaxPacketLength(kDefaultMaxPacketSize);
   if (!framer_->version().UsesTls()) {
     // QUIC+TLS negotiates the maximum datagram frame size via the
@@ -470,12 +473,18 @@
   }
 
   DCHECK_EQ(nullptr, packet_.encrypted_buffer);
-  SerializePacket(std::move(external_buffer), kMaxOutgoingPacketSize);
+  const bool success =
+      SerializePacket(std::move(external_buffer), kMaxOutgoingPacketSize);
+  if (close_connection_on_serialization_failure_ && !success) {
+    return;
+  }
   OnSerializedPacket();
 }
 
 void QuicPacketCreator::OnSerializedPacket() {
-  if (packet_.encrypted_buffer == nullptr) {
+  if (close_connection_on_serialization_failure_) {
+    QUIC_BUG_IF(packet_.encrypted_buffer == nullptr);
+  } else if (packet_.encrypted_buffer == nullptr) {
     const std::string error_details = "Failed to SerializePacket.";
     QUIC_BUG << error_details;
     delegate_->OnUnrecoverableError(QUIC_FAILED_TO_SERIALIZE_PACKET,
@@ -543,8 +552,11 @@
       return 0;
     }
   }
-  SerializePacket(QuicOwnedPacketBuffer(buffer, nullptr), buffer_len);
-  // TODO(b/166255274): report unrecoverable error on serialization failures.
+  const bool success =
+      SerializePacket(QuicOwnedPacketBuffer(buffer, nullptr), buffer_len);
+  if (close_connection_on_serialization_failure_ && !success) {
+    return 0;
+  }
   const size_t encrypted_length = packet_.encrypted_length;
   // Clear frames in packet_. No need to DeleteFrames since frames are owned by
   // initial_packet.
@@ -562,6 +574,7 @@
     bool fin,
     TransmissionType transmission_type,
     size_t* num_bytes_consumed) {
+  // TODO(b/167222597): consider using ScopedSerializationFailureHandler.
   DCHECK(queued_frames_.empty());
   DCHECK(!QuicUtils::IsCryptoStreamId(transport_version(), id));
   // Write out the packet header
@@ -738,11 +751,22 @@
   return false;
 }
 
-void QuicPacketCreator::SerializePacket(QuicOwnedPacketBuffer encrypted_buffer,
+bool QuicPacketCreator::SerializePacket(QuicOwnedPacketBuffer encrypted_buffer,
                                         size_t encrypted_buffer_len) {
-  const bool use_queued_frames_cleaner = GetQuicReloadableFlag(
-      quic_neuter_initial_packet_in_coalescer_with_initial_key_discarded);
-  ScopedQueuedFramesCleaner cleaner(use_queued_frames_cleaner ? this : nullptr);
+  if (close_connection_on_serialization_failure_ &&
+      packet_.encrypted_buffer != nullptr) {
+    const std::string error_details =
+        "Packet's encrypted buffer is not empty before serialization";
+    QUIC_BUG << error_details;
+    delegate_->OnUnrecoverableError(QUIC_FAILED_TO_SERIALIZE_PACKET,
+                                    error_details);
+    return false;
+  }
+  const bool use_handler =
+      GetQuicReloadableFlag(
+          quic_neuter_initial_packet_in_coalescer_with_initial_key_discarded) ||
+      close_connection_on_serialization_failure_;
+  ScopedSerializationFailureHandler handler(use_handler ? this : nullptr);
 
   DCHECK_LT(0u, encrypted_buffer_len);
   QUIC_BUG_IF(queued_frames_.empty() && pending_padding_bytes_ == 0)
@@ -772,7 +796,7 @@
              << QuicFramesToString(queued_frames_)
              << " at missing encryption_level " << packet_.encryption_level
              << " using " << framer_->version();
-    return;
+    return false;
   }
 
   DCHECK_GE(max_plaintext_size_, packet_size_);
@@ -790,7 +814,7 @@
              << latched_hard_max_packet_length_
              << ", max_packet_length_: " << max_packet_length_
              << ", header: " << header;
-    return;
+    return false;
   }
 
   // ACK Frames will be truncated due to length only if they're the only frame
@@ -811,11 +835,11 @@
       encrypted_buffer_len, encrypted_buffer.buffer);
   if (encrypted_length == 0) {
     QUIC_BUG << "Failed to encrypt packet number " << packet_.packet_number;
-    return;
+    return false;
   }
 
   packet_size_ = 0;
-  if (!use_queued_frames_cleaner) {
+  if (!use_handler) {
     queued_frames_.clear();
   }
   packet_.encrypted_buffer = encrypted_buffer.buffer;
@@ -823,6 +847,7 @@
 
   encrypted_buffer.buffer = nullptr;
   packet_.release_encrypted_buffer = std::move(encrypted_buffer).release_buffer;
+  return true;
 }
 
 std::unique_ptr<QuicEncryptedPacket>
@@ -1983,15 +2008,25 @@
   creator_->SetDefaultPeerAddress(old_peer_address_);
 }
 
-QuicPacketCreator::ScopedQueuedFramesCleaner::ScopedQueuedFramesCleaner(
-    QuicPacketCreator* creator)
+QuicPacketCreator::ScopedSerializationFailureHandler::
+    ScopedSerializationFailureHandler(QuicPacketCreator* creator)
     : creator_(creator) {}
 
-QuicPacketCreator::ScopedQueuedFramesCleaner::~ScopedQueuedFramesCleaner() {
+QuicPacketCreator::ScopedSerializationFailureHandler::
+    ~ScopedSerializationFailureHandler() {
   if (creator_ == nullptr) {
     return;
   }
+  // Always clear queued_frames_.
   creator_->queued_frames_.clear();
+
+  if (creator_->close_connection_on_serialization_failure_ &&
+      creator_->packet_.encrypted_buffer == nullptr) {
+    const std::string error_details = "Failed to SerializePacket.";
+    QUIC_BUG << error_details;
+    creator_->delegate_->OnUnrecoverableError(QUIC_FAILED_TO_SERIALIZE_PACKET,
+                                              error_details);
+  }
 }
 
 void QuicPacketCreator::set_encryption_level(EncryptionLevel level) {
diff --git a/quic/core/quic_packet_creator.h b/quic/core/quic_packet_creator.h
index 02bc9e3..a82cec4 100644
--- a/quic/core/quic_packet_creator.h
+++ b/quic/core/quic_packet_creator.h
@@ -27,6 +27,7 @@
 #include "net/third_party/quiche/src/quic/core/quic_packets.h"
 #include "net/third_party/quiche/src/quic/core/quic_types.h"
 #include "net/third_party/quiche/src/quic/platform/api/quic_export.h"
+#include "net/third_party/quiche/src/quic/platform/api/quic_macros.h"
 #include "net/third_party/quiche/src/common/platform/api/quiche_string_piece.h"
 
 namespace quic {
@@ -472,11 +473,12 @@
  private:
   friend class test::QuicPacketCreatorPeer;
 
-  // Used to clear queued_frames_ of creator upon exiting the scope.
-  class QUIC_EXPORT_PRIVATE ScopedQueuedFramesCleaner {
+  // Used to 1) clear queued_frames_, 2) report unrecoverable error (if
+  // serialization fails) upon exiting the scope.
+  class QUIC_EXPORT_PRIVATE ScopedSerializationFailureHandler {
    public:
-    explicit ScopedQueuedFramesCleaner(QuicPacketCreator* creator);
-    ~ScopedQueuedFramesCleaner();
+    explicit ScopedSerializationFailureHandler(QuicPacketCreator* creator);
+    ~ScopedSerializationFailureHandler();
 
    private:
     QuicPacketCreator* creator_;  // Unowned.
@@ -507,10 +509,11 @@
 
   // Serializes all frames which have been added and adds any which should be
   // retransmitted to packet_.retransmittable_frames. All frames must fit into
-  // a single packet.
-  // Fails if |encrypted_buffer_len| isn't long enough for the encrypted packet.
-  void SerializePacket(QuicOwnedPacketBuffer encrypted_buffer,
-                       size_t encrypted_buffer_len);
+  // a single packet. Returns true on success, otherwise, returns false.
+  // Fails if |encrypted_buffer| is not large enough for the encrypted packet.
+  QUIC_MUST_USE_RESULT bool SerializePacket(
+      QuicOwnedPacketBuffer encrypted_buffer,
+      size_t encrypted_buffer_len);
 
   // Called after a new SerialiedPacket is created to call the delegate's
   // OnSerializedPacket and reset state.
@@ -663,6 +666,9 @@
 
   const bool coalesced_packet_of_higher_space_ =
       GetQuicReloadableFlag(quic_coalesced_packet_of_higher_space2);
+
+  const bool close_connection_on_serialization_failure_ =
+      GetQuicReloadableFlag(quic_close_connection_on_serialization_failure);
 };
 
 }  // namespace quic
diff --git a/quic/test_tools/quic_packet_creator_peer.cc b/quic/test_tools/quic_packet_creator_peer.cc
index 151575a..f94620b 100644
--- a/quic/test_tools/quic_packet_creator_peer.cc
+++ b/quic/test_tools/quic_packet_creator_peer.cc
@@ -111,7 +111,9 @@
     bool success = creator->AddFrame(frame, NOT_RETRANSMISSION);
     DCHECK(success);
   }
-  creator->SerializePacket(QuicOwnedPacketBuffer(buffer, nullptr), buffer_len);
+  const bool success = creator->SerializePacket(
+      QuicOwnedPacketBuffer(buffer, nullptr), buffer_len);
+  DCHECK(success);
   SerializedPacket packet = std::move(creator->packet_);
   // The caller takes ownership of the QuicEncryptedPacket.
   creator->packet_.encrypted_buffer = nullptr;