Properly handle TLS/TCP being write-blocked in MasqueH2Connection

PiperOrigin-RevId: 865522892
diff --git a/quiche/quic/masque/masque_connection_pool.cc b/quiche/quic/masque/masque_connection_pool.cc
index 50fd485..b6bb8ef 100644
--- a/quiche/quic/masque/masque_connection_pool.cc
+++ b/quiche/quic/masque/masque_connection_pool.cc
@@ -498,6 +498,7 @@
 
   SSL_CTX_set_min_proto_version(ctx.get(), TLS1_2_VERSION);
   SSL_CTX_set_max_proto_version(ctx.get(), TLS1_3_VERSION);
+  SSL_CTX_set_mode(ctx.get(), SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
 
   return ctx;
 }
@@ -531,6 +532,8 @@
 
   SSL_CTX_set_min_proto_version(ctx.get(), TLS1_2_VERSION);
   SSL_CTX_set_max_proto_version(ctx.get(), TLS1_3_VERSION);
+  SSL_CTX_set_mode(ctx.get(), SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
+
   return ctx;
 }
 
diff --git a/quiche/quic/masque/masque_h2_connection.cc b/quiche/quic/masque/masque_h2_connection.cc
index ba287f3..ba38888 100644
--- a/quiche/quic/masque/masque_h2_connection.cc
+++ b/quiche/quic/masque/masque_h2_connection.cc
@@ -169,34 +169,60 @@
   return AttemptToSend();
 }
 
-int MasqueH2Connection::WriteDataToTls(absl::string_view data) {
+bool MasqueH2Connection::WriteDataToTls(absl::string_view data) {
   QUICHE_DVLOG(2) << ENDPOINT << "Writing " << data.size()
                   << " app bytes to TLS:" << std::endl
                   << quiche::QuicheTextUtils::HexDump(data);
-  int ssl_write_ret = SSL_write(ssl_, data.data(), data.size());
+  const bool buffered = !tls_write_buffer_.empty();
+  const char* buffer_to_write;
+  size_t size_to_write;
+  if (buffered) {
+    absl::StrAppend(&tls_write_buffer_, data);
+    buffer_to_write = tls_write_buffer_.data();
+    size_to_write = tls_write_buffer_.size();
+  } else {
+    buffer_to_write = data.data();
+    size_to_write = data.size();
+  }
+  const int ssl_write_ret = SSL_write(ssl_, buffer_to_write, size_to_write);
   if (ssl_write_ret <= 0) {
     int ssl_err = SSL_get_error(ssl_, ssl_write_ret);
-    QUICHE_LOG(ERROR) << FormatSslError("Error while writing data to TLS",
-                                        ssl_err, ssl_write_ret);
-    return -1;
+    if (ssl_err == SSL_ERROR_WANT_WRITE) {
+      if (!buffered) {
+        tls_write_buffer_ = std::string(data);
+      }
+      QUICHE_DVLOG(1) << ENDPOINT << "SSL_write will require another write, "
+                      << "buffered " << tls_write_buffer_.size() << " bytes";
+      return true;
+    }
+    Abort(SslErrorStatus("Error while writing data to TLS", ssl_err,
+                         ssl_write_ret));
+    return false;
   }
-  if (ssl_write_ret == static_cast<int>(data.size())) {
-    QUICHE_DVLOG(1) << ENDPOINT << "Wrote " << data.size() << " bytes to TLS";
+  if (ssl_write_ret == static_cast<int>(size_to_write)) {
+    QUICHE_DVLOG(1) << ENDPOINT << "Wrote " << size_to_write << " bytes to TLS";
+    if (buffered) {
+      tls_write_buffer_.clear();
+    }
   } else {
     QUICHE_DVLOG(1) << ENDPOINT << "Wrote " << ssl_write_ret << " / "
-                    << data.size() << "bytes to TLS";
+                    << size_to_write << " bytes to TLS and buffered the rest";
+    if (buffered) {
+      tls_write_buffer_.erase(0, ssl_write_ret);
+    } else {
+      tls_write_buffer_ = std::string(data.substr(ssl_write_ret));
+    }
   }
-  return ssl_write_ret;
+  return true;
 }
 
 int64_t MasqueH2Connection::OnReadyToSend(absl::string_view serialized) {
   QUICHE_DVLOG(1) << ENDPOINT << "Writing " << serialized.size()
                   << " bytes of h2 data to TLS";
-  int write_res = WriteDataToTls(serialized);
-  if (write_res < 0) {
+  if (!WriteDataToTls(serialized)) {
     return kSendError;
   }
-  return write_res;
+  return serialized.size();
 }
 
 MasqueH2Connection::DataFrameHeaderInfo
@@ -218,21 +244,19 @@
 bool MasqueH2Connection::SendDataFrame(Http2StreamId stream_id,
                                        absl::string_view frame_header,
                                        size_t payload_bytes) {
-  if (WriteDataToTls(frame_header) < 0) {
+  if (!WriteDataToTls(frame_header)) {
     return false;
   }
   MasqueH2Stream* stream = GetOrCreateH2Stream(stream_id);
   size_t length_to_write = std::min(payload_bytes, stream->body_to_send.size());
-  int length_written =
-      WriteDataToTls(stream->body_to_send.substr(0, length_to_write));
-  if (length_written < 0) {
+  if (!WriteDataToTls(stream->body_to_send.substr(0, length_to_write))) {
     return false;
   }
-  if (length_written == static_cast<int>(stream->body_to_send.size())) {
+  if (length_to_write == stream->body_to_send.size()) {
     stream->body_to_send.clear();
   } else {
     // Remove the written bytes from the start of `body_to_send`.
-    stream->body_to_send = stream->body_to_send.substr(length_written);
+    stream->body_to_send = stream->body_to_send.substr(length_to_write);
   }
   return true;
 }
@@ -268,6 +292,13 @@
 }
 
 bool MasqueH2Connection::AttemptToSend() {
+  if (!tls_write_buffer_.empty()) {
+    QUICHE_DVLOG(1) << ENDPOINT << "Attempting to write "
+                    << tls_write_buffer_.size() << " buffered bytes to TLS";
+    if (!WriteDataToTls("")) {
+      return false;
+    }
+  }
   if (!h2_adapter_) {
     return false;
   }
diff --git a/quiche/quic/masque/masque_h2_connection.h b/quiche/quic/masque/masque_h2_connection.h
index 9139857..bd16403 100644
--- a/quiche/quic/masque/masque_h2_connection.h
+++ b/quiche/quic/masque/masque_h2_connection.h
@@ -83,7 +83,7 @@
   std::vector<http2::adapter::Header> ConvertHeaders(
       const quiche::HttpHeaderBlock& headers);
 
-  int WriteDataToTls(absl::string_view data);
+  bool WriteDataToTls(absl::string_view data);
 
   // From http2::adapter::Http2VisitorInterface.
   int64_t OnReadyToSend(absl::string_view serialized) override;
@@ -142,6 +142,7 @@
   absl::flat_hash_map<Http2StreamId, std::unique_ptr<MasqueH2Stream>>
       h2_streams_;
   Visitor* visitor_;
+  std::string tls_write_buffer_;
 };
 
 // Formats an SSL error that was provided by BoringSSL.
diff --git a/quiche/quic/masque/masque_tcp_client_bin.cc b/quiche/quic/masque/masque_tcp_client_bin.cc
index d9ad722..8ecf710 100644
--- a/quiche/quic/masque/masque_tcp_client_bin.cc
+++ b/quiche/quic/masque/masque_tcp_client_bin.cc
@@ -98,6 +98,7 @@
 
   SSL_CTX_set_min_proto_version(ctx.get(), TLS1_2_VERSION);
   SSL_CTX_set_max_proto_version(ctx.get(), TLS1_3_VERSION);
+  SSL_CTX_set_mode(ctx.get(), SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
 
   return ctx;
 }
@@ -437,19 +438,47 @@
     QUICHE_DVLOG(2) << "Writing " << data.size()
                     << " app bytes to TLS:" << std::endl
                     << quiche::QuicheTextUtils::HexDump(data);
-    int ssl_write_ret = SSL_write(ssl_.get(), data.data(), data.size());
+    const bool buffered = !tls_write_buffer_.empty();
+    const char* buffer_to_write;
+    size_t size_to_write;
+    if (buffered) {
+      absl::StrAppend(&tls_write_buffer_, data);
+      buffer_to_write = tls_write_buffer_.data();
+      size_to_write = tls_write_buffer_.size();
+    } else {
+      buffer_to_write = data.data();
+      size_to_write = data.size();
+    }
+    const int ssl_write_ret =
+        SSL_write(ssl_.get(), buffer_to_write, size_to_write);
     if (ssl_write_ret <= 0) {
       int ssl_err = SSL_get_error(ssl_.get(), ssl_write_ret);
+      if (ssl_err == SSL_ERROR_WANT_WRITE) {
+        QUICHE_DVLOG(1) << "SSL_write will require another write, "
+                        << "buffered " << tls_write_buffer_.size() << " bytes";
+        if (!buffered) {
+          tls_write_buffer_ = std::string(data);
+        }
+        return data.size();
+      }
       QUICHE_LOG(ERROR) << FormatSslError("Error while writing request to TLS",
                                           ssl_err, ssl_write_ret);
       done_ = true;
       return -1;
     }
-    if (ssl_write_ret == static_cast<int>(data.size())) {
-      QUICHE_DVLOG(1) << "Wrote " << data.size() << " bytes to TLS";
+    if (ssl_write_ret == static_cast<int>(size_to_write)) {
+      QUICHE_DVLOG(1) << "Wrote " << size_to_write << " bytes to TLS";
+      if (buffered) {
+        tls_write_buffer_.clear();
+      }
     } else {
-      QUICHE_DVLOG(1) << "Wrote " << ssl_write_ret << " / " << data.size()
-                      << "bytes to TLS";
+      QUICHE_DVLOG(1) << "Wrote " << ssl_write_ret << " / " << size_to_write
+                      << " bytes to TLS and buffered the rest";
+      if (buffered) {
+        tls_write_buffer_.erase(0, ssl_write_ret);
+      } else {
+        tls_write_buffer_ = std::string(data.substr(ssl_write_ret));
+      }
     }
     SendToTransport();
     return ssl_write_ret;
@@ -514,6 +543,7 @@
   bool done_ = false;
   int32_t stream_id_ = -1;
   std::unique_ptr<MasqueH2Connection> h2_connection_;
+  std::string tls_write_buffer_;
 };
 
 int RunMasqueTcpClient(int argc, char* argv[]) {
diff --git a/quiche/quic/masque/masque_tcp_server_bin.cc b/quiche/quic/masque/masque_tcp_server_bin.cc
index 94f894c..9307e17 100644
--- a/quiche/quic/masque/masque_tcp_server_bin.cc
+++ b/quiche/quic/masque/masque_tcp_server_bin.cc
@@ -463,6 +463,7 @@
 
     SSL_CTX_set_min_proto_version(ctx_.get(), TLS1_2_VERSION);
     SSL_CTX_set_max_proto_version(ctx_.get(), TLS1_3_VERSION);
+    SSL_CTX_set_mode(ctx_.get(), SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
 
     return true;
   }