gfe-relnote: If batch writer is used in a QuicConnection, flush it right after an MTU probe is sent. Protected by --gfe2_reloadable_flag_quic_batch_writer_flush_after_mtu_probe.

PiperOrigin-RevId: 292547985
Change-Id: I14c3eab267c5d2a78c0fa742285910c7b1281161
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc
index 09820f3..9d9f1fc 100644
--- a/quic/core/quic_connection.cc
+++ b/quic/core/quic_connection.cc
@@ -337,7 +337,9 @@
           GetQuicReloadableFlag(quic_use_handshaker_delegate2) ||
           version().handshake_protocol == PROTOCOL_TLS1_3),
       check_handshake_timeout_before_idle_timeout_(GetQuicReloadableFlag(
-          quic_check_handshake_timeout_before_idle_timeout)) {
+          quic_check_handshake_timeout_before_idle_timeout)),
+      batch_writer_flush_after_mtu_probe_(
+          GetQuicReloadableFlag(quic_batch_writer_flush_after_mtu_probe)) {
   QUIC_DLOG(INFO) << ENDPOINT << "Created connection with server connection ID "
                   << server_connection_id
                   << " and version: " << ParsedQuicVersionToString(version());
@@ -2192,8 +2194,9 @@
                     ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET);
     return true;
   }
-  SerializedPacketFate fate = DeterminePacketFate(
-      /*is_mtu_discovery=*/packet->encrypted_length > long_term_mtu_);
+  const bool looks_like_mtu_probe = packet->retransmittable_frames.empty() &&
+                                    packet->encrypted_length > long_term_mtu_;
+  SerializedPacketFate fate = DeterminePacketFate(looks_like_mtu_probe);
   // Termination packets are encrypted and saved, so don't exit early.
   const bool is_termination_packet = IsTerminationPacket(*packet);
   QuicPacketNumber packet_number = packet->packet_number;
@@ -2211,8 +2214,6 @@
         new QuicEncryptedPacket(buffer_copy, encrypted_length, true));
   }
 
-  const bool looks_like_mtu_probe = packet->retransmittable_frames.empty() &&
-                                    packet->encrypted_length > long_term_mtu_;
   DCHECK_LE(encrypted_length, kMaxOutgoingPacketSize);
   if (!looks_like_mtu_probe) {
     DCHECK_LE(encrypted_length, packet_creator_.max_packet_length());
@@ -2285,6 +2286,18 @@
       result = writer_->WritePacket(packet->encrypted_buffer, encrypted_length,
                                     self_address().host(), peer_address(),
                                     per_packet_options_);
+      // This is a workaround for an issue with Linux UDP GSO batch writers.
+      // When sending a GSO packet with 2 segments, if the first segment is
+      // larger than the path MTU, instead of EMSGSIZE, the Linux kernel returns
+      // EINVAL, which translates to WRITE_STATUS_ERROR and causes the
+      // connection to be closed. By manually flushing the writer here, the MTU
+      // probe is sent in a normal (non-GSO) packet, so the kernel can return
+      // EMSGSIZE and we will not close the connection.
+      if (batch_writer_flush_after_mtu_probe_ && looks_like_mtu_probe &&
+          writer_->IsBatchMode()) {
+        QUIC_RELOADABLE_FLAG_COUNT(quic_batch_writer_flush_after_mtu_probe);
+        result = writer_->Flush();
+      }
       break;
     case FAILED_TO_WRITE_COALESCED_PACKET:
       // Failed to send existing coalesced packet when determining packet fate,
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h
index bcedee0..12a899f 100644
--- a/quic/core/quic_connection.h
+++ b/quic/core/quic_connection.h
@@ -1520,6 +1520,9 @@
 
   // Latched value of quic_check_handshake_timeout_before_idle_timeout.
   const bool check_handshake_timeout_before_idle_timeout_;
+
+  // Latched value of quic_batch_writer_flush_after_mtu_probe.
+  const bool batch_writer_flush_after_mtu_probe_;
 };
 
 }  // namespace quic
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc
index 7809cc5..8f02bbb 100644
--- a/quic/core/quic_connection_test.cc
+++ b/quic/core/quic_connection_test.cc
@@ -340,24 +340,7 @@
   TestPacketWriter(ParsedQuicVersion version, MockClock* clock)
       : version_(version),
         framer_(SupportedVersions(version_), Perspective::IS_SERVER),
-        last_packet_size_(0),
-        write_blocked_(false),
-        write_should_fail_(false),
-        block_on_next_flush_(false),
-        block_on_next_write_(false),
-        next_packet_too_large_(false),
-        always_get_packet_too_large_(false),
-        is_write_blocked_data_buffered_(false),
-        is_batch_mode_(false),
-        final_bytes_of_last_packet_(0),
-        final_bytes_of_previous_packet_(0),
-        use_tagging_decrypter_(false),
-        packets_write_attempts_(0),
-        connection_close_packets_(0),
-        clock_(clock),
-        write_pause_time_delta_(QuicTime::Delta::Zero()),
-        max_packet_size_(kMaxOutgoingPacketSize),
-        supports_release_time_(false) {
+        clock_(clock) {
     QuicFramerPeer::SetLastSerializedServerConnectionId(framer_.framer(),
                                                         TestConnectionId());
     framer_.framer()->SetInitialObfuscators(TestConnectionId());
@@ -430,6 +413,10 @@
     if (!write_pause_time_delta_.IsZero()) {
       clock_->AdvanceTime(write_pause_time_delta_);
     }
+    if (is_batch_mode_) {
+      bytes_buffered_ += last_packet_size_;
+      return WriteResult(WRITE_STATUS_OK, 0);
+    }
     return WriteResult(WRITE_STATUS_OK, last_packet_size_);
   }
 
@@ -459,12 +446,15 @@
   }
 
   WriteResult Flush() override {
+    flush_attempts_++;
     if (block_on_next_flush_) {
       block_on_next_flush_ = false;
       SetWriteBlocked();
       return WriteResult(WRITE_STATUS_BLOCKED, /*errno*/ -1);
     }
-    return WriteResult(WRITE_STATUS_OK, 0);
+    int bytes_flushed = bytes_buffered_;
+    bytes_buffered_ = 0;
+    return WriteResult(WRITE_STATUS_OK, bytes_flushed);
   }
 
   void BlockOnNextFlush() { block_on_next_flush_ = true; }
@@ -572,7 +562,9 @@
 
   void use_tagging_decrypter() { use_tagging_decrypter_ = true; }
 
-  uint32_t packets_write_attempts() { return packets_write_attempts_; }
+  uint32_t packets_write_attempts() const { return packets_write_attempts_; }
+
+  uint32_t flush_attempts() const { return flush_attempts_; }
 
   uint32_t connection_close_packets() const {
     return connection_close_packets_;
@@ -597,27 +589,32 @@
  private:
   ParsedQuicVersion version_;
   SimpleQuicFramer framer_;
-  size_t last_packet_size_;
+  size_t last_packet_size_ = 0;
   QuicPacketHeader last_packet_header_;
-  bool write_blocked_;
-  bool write_should_fail_;
-  bool block_on_next_flush_;
-  bool block_on_next_write_;
-  bool next_packet_too_large_;
-  bool always_get_packet_too_large_;
-  bool is_write_blocked_data_buffered_;
-  bool is_batch_mode_;
-  uint32_t final_bytes_of_last_packet_;
-  uint32_t final_bytes_of_previous_packet_;
-  bool use_tagging_decrypter_;
-  uint32_t packets_write_attempts_;
-  uint32_t connection_close_packets_;
-  MockClock* clock_;
+  bool write_blocked_ = false;
+  bool write_should_fail_ = false;
+  bool block_on_next_flush_ = false;
+  bool block_on_next_write_ = false;
+  bool next_packet_too_large_ = false;
+  bool always_get_packet_too_large_ = false;
+  bool is_write_blocked_data_buffered_ = false;
+  bool is_batch_mode_ = false;
+  // Number of times Flush() was called.
+  uint32_t flush_attempts_ = 0;
+  // (Batch mode only) Number of bytes buffered in writer. It is used as the
+  // return value of a successful Flush().
+  uint32_t bytes_buffered_ = 0;
+  uint32_t final_bytes_of_last_packet_ = 0;
+  uint32_t final_bytes_of_previous_packet_ = 0;
+  bool use_tagging_decrypter_ = false;
+  uint32_t packets_write_attempts_ = 0;
+  uint32_t connection_close_packets_ = 0;
+  MockClock* clock_ = nullptr;
   // If non-zero, the clock will pause during WritePacket for this amount of
   // time.
-  QuicTime::Delta write_pause_time_delta_;
-  QuicByteCount max_packet_size_;
-  bool supports_release_time_;
+  QuicTime::Delta write_pause_time_delta_ = QuicTime::Delta::Zero();
+  QuicByteCount max_packet_size_ = kMaxOutgoingPacketSize;
+  bool supports_release_time_ = false;
 };
 
 class TestConnection : public QuicConnection {
@@ -4836,6 +4833,25 @@
   EXPECT_EQ(QuicPacketNumber(4u), creator_->packet_number());
 }
 
+// Verifies that when an MTU probe packet is sent and buffered in a batch
+// writer, the writer is flushed immediately.
+TEST_P(QuicConnectionTest, BatchWriterFlushedAfterMtuDiscoveryPacket) {
+  writer_->SetBatchMode(true);
+  MtuDiscoveryTestInit();
+
+  // Send an MTU probe.
+  const size_t target_mtu = kDefaultMaxPacketSize + 100;
+  QuicByteCount mtu_probe_size;
+  EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _))
+      .WillOnce(SaveArg<3>(&mtu_probe_size));
+  const uint32_t prior_flush_attempts = writer_->flush_attempts();
+  connection_.SendMtuDiscoveryPacket(target_mtu);
+  EXPECT_EQ(target_mtu, mtu_probe_size);
+  if (GetQuicReloadableFlag(quic_batch_writer_flush_after_mtu_probe)) {
+    EXPECT_EQ(writer_->flush_attempts(), prior_flush_attempts + 1);
+  }
+}
+
 // Tests whether MTU discovery does not happen when it is not explicitly enabled
 // by the connection options.
 TEST_P(QuicConnectionTest, MtuDiscoveryDisabled) {