Properly handle TLS/TCP being write-blocked in MasqueH2Connection PiperOrigin-RevId: 865522892
diff --git a/quiche/quic/masque/masque_connection_pool.cc b/quiche/quic/masque/masque_connection_pool.cc index 50fd485..b6bb8ef 100644 --- a/quiche/quic/masque/masque_connection_pool.cc +++ b/quiche/quic/masque/masque_connection_pool.cc
@@ -498,6 +498,7 @@ SSL_CTX_set_min_proto_version(ctx.get(), TLS1_2_VERSION); SSL_CTX_set_max_proto_version(ctx.get(), TLS1_3_VERSION); + SSL_CTX_set_mode(ctx.get(), SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); return ctx; } @@ -531,6 +532,8 @@ SSL_CTX_set_min_proto_version(ctx.get(), TLS1_2_VERSION); SSL_CTX_set_max_proto_version(ctx.get(), TLS1_3_VERSION); + SSL_CTX_set_mode(ctx.get(), SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); + return ctx; }
diff --git a/quiche/quic/masque/masque_h2_connection.cc b/quiche/quic/masque/masque_h2_connection.cc index ba287f3..ba38888 100644 --- a/quiche/quic/masque/masque_h2_connection.cc +++ b/quiche/quic/masque/masque_h2_connection.cc
@@ -169,34 +169,60 @@ return AttemptToSend(); } -int MasqueH2Connection::WriteDataToTls(absl::string_view data) { +bool MasqueH2Connection::WriteDataToTls(absl::string_view data) { QUICHE_DVLOG(2) << ENDPOINT << "Writing " << data.size() << " app bytes to TLS:" << std::endl << quiche::QuicheTextUtils::HexDump(data); - int ssl_write_ret = SSL_write(ssl_, data.data(), data.size()); + const bool buffered = !tls_write_buffer_.empty(); + const char* buffer_to_write; + size_t size_to_write; + if (buffered) { + absl::StrAppend(&tls_write_buffer_, data); + buffer_to_write = tls_write_buffer_.data(); + size_to_write = tls_write_buffer_.size(); + } else { + buffer_to_write = data.data(); + size_to_write = data.size(); + } + const int ssl_write_ret = SSL_write(ssl_, buffer_to_write, size_to_write); if (ssl_write_ret <= 0) { int ssl_err = SSL_get_error(ssl_, ssl_write_ret); - QUICHE_LOG(ERROR) << FormatSslError("Error while writing data to TLS", - ssl_err, ssl_write_ret); - return -1; + if (ssl_err == SSL_ERROR_WANT_WRITE) { + if (!buffered) { + tls_write_buffer_ = std::string(data); + } + QUICHE_DVLOG(1) << ENDPOINT << "SSL_write will require another write, " + << "buffered " << tls_write_buffer_.size() << " bytes"; + return true; + } + Abort(SslErrorStatus("Error while writing data to TLS", ssl_err, + ssl_write_ret)); + return false; } - if (ssl_write_ret == static_cast<int>(data.size())) { - QUICHE_DVLOG(1) << ENDPOINT << "Wrote " << data.size() << " bytes to TLS"; + if (ssl_write_ret == static_cast<int>(size_to_write)) { + QUICHE_DVLOG(1) << ENDPOINT << "Wrote " << size_to_write << " bytes to TLS"; + if (buffered) { + tls_write_buffer_.clear(); + } } else { QUICHE_DVLOG(1) << ENDPOINT << "Wrote " << ssl_write_ret << " / " - << data.size() << "bytes to TLS"; + << size_to_write << " bytes to TLS and buffered the rest"; + if (buffered) { + tls_write_buffer_.erase(0, ssl_write_ret); + } else { + tls_write_buffer_ = std::string(data.substr(ssl_write_ret)); + } } - return ssl_write_ret; + return true; } int64_t MasqueH2Connection::OnReadyToSend(absl::string_view serialized) { QUICHE_DVLOG(1) << ENDPOINT << "Writing " << serialized.size() << " bytes of h2 data to TLS"; - int write_res = WriteDataToTls(serialized); - if (write_res < 0) { + if (!WriteDataToTls(serialized)) { return kSendError; } - return write_res; + return serialized.size(); } MasqueH2Connection::DataFrameHeaderInfo @@ -218,21 +244,19 @@ bool MasqueH2Connection::SendDataFrame(Http2StreamId stream_id, absl::string_view frame_header, size_t payload_bytes) { - if (WriteDataToTls(frame_header) < 0) { + if (!WriteDataToTls(frame_header)) { return false; } MasqueH2Stream* stream = GetOrCreateH2Stream(stream_id); size_t length_to_write = std::min(payload_bytes, stream->body_to_send.size()); - int length_written = - WriteDataToTls(stream->body_to_send.substr(0, length_to_write)); - if (length_written < 0) { + if (!WriteDataToTls(stream->body_to_send.substr(0, length_to_write))) { return false; } - if (length_written == static_cast<int>(stream->body_to_send.size())) { + if (length_to_write == stream->body_to_send.size()) { stream->body_to_send.clear(); } else { // Remove the written bytes from the start of `body_to_send`. - stream->body_to_send = stream->body_to_send.substr(length_written); + stream->body_to_send = stream->body_to_send.substr(length_to_write); } return true; } @@ -268,6 +292,13 @@ } bool MasqueH2Connection::AttemptToSend() { + if (!tls_write_buffer_.empty()) { + QUICHE_DVLOG(1) << ENDPOINT << "Attempting to write " + << tls_write_buffer_.size() << " buffered bytes to TLS"; + if (!WriteDataToTls("")) { + return false; + } + } if (!h2_adapter_) { return false; }
diff --git a/quiche/quic/masque/masque_h2_connection.h b/quiche/quic/masque/masque_h2_connection.h index 9139857..bd16403 100644 --- a/quiche/quic/masque/masque_h2_connection.h +++ b/quiche/quic/masque/masque_h2_connection.h
@@ -83,7 +83,7 @@ std::vector<http2::adapter::Header> ConvertHeaders( const quiche::HttpHeaderBlock& headers); - int WriteDataToTls(absl::string_view data); + bool WriteDataToTls(absl::string_view data); // From http2::adapter::Http2VisitorInterface. int64_t OnReadyToSend(absl::string_view serialized) override; @@ -142,6 +142,7 @@ absl::flat_hash_map<Http2StreamId, std::unique_ptr<MasqueH2Stream>> h2_streams_; Visitor* visitor_; + std::string tls_write_buffer_; }; // Formats an SSL error that was provided by BoringSSL.
diff --git a/quiche/quic/masque/masque_tcp_client_bin.cc b/quiche/quic/masque/masque_tcp_client_bin.cc index d9ad722..8ecf710 100644 --- a/quiche/quic/masque/masque_tcp_client_bin.cc +++ b/quiche/quic/masque/masque_tcp_client_bin.cc
@@ -98,6 +98,7 @@ SSL_CTX_set_min_proto_version(ctx.get(), TLS1_2_VERSION); SSL_CTX_set_max_proto_version(ctx.get(), TLS1_3_VERSION); + SSL_CTX_set_mode(ctx.get(), SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); return ctx; } @@ -437,19 +438,47 @@ QUICHE_DVLOG(2) << "Writing " << data.size() << " app bytes to TLS:" << std::endl << quiche::QuicheTextUtils::HexDump(data); - int ssl_write_ret = SSL_write(ssl_.get(), data.data(), data.size()); + const bool buffered = !tls_write_buffer_.empty(); + const char* buffer_to_write; + size_t size_to_write; + if (buffered) { + absl::StrAppend(&tls_write_buffer_, data); + buffer_to_write = tls_write_buffer_.data(); + size_to_write = tls_write_buffer_.size(); + } else { + buffer_to_write = data.data(); + size_to_write = data.size(); + } + const int ssl_write_ret = + SSL_write(ssl_.get(), buffer_to_write, size_to_write); if (ssl_write_ret <= 0) { int ssl_err = SSL_get_error(ssl_.get(), ssl_write_ret); + if (ssl_err == SSL_ERROR_WANT_WRITE) { + QUICHE_DVLOG(1) << "SSL_write will require another write, " + << "buffered " << tls_write_buffer_.size() << " bytes"; + if (!buffered) { + tls_write_buffer_ = std::string(data); + } + return data.size(); + } QUICHE_LOG(ERROR) << FormatSslError("Error while writing request to TLS", ssl_err, ssl_write_ret); done_ = true; return -1; } - if (ssl_write_ret == static_cast<int>(data.size())) { - QUICHE_DVLOG(1) << "Wrote " << data.size() << " bytes to TLS"; + if (ssl_write_ret == static_cast<int>(size_to_write)) { + QUICHE_DVLOG(1) << "Wrote " << size_to_write << " bytes to TLS"; + if (buffered) { + tls_write_buffer_.clear(); + } } else { - QUICHE_DVLOG(1) << "Wrote " << ssl_write_ret << " / " << data.size() - << "bytes to TLS"; + QUICHE_DVLOG(1) << "Wrote " << ssl_write_ret << " / " << size_to_write + << " bytes to TLS and buffered the rest"; + if (buffered) { + tls_write_buffer_.erase(0, ssl_write_ret); + } else { + tls_write_buffer_ = std::string(data.substr(ssl_write_ret)); + } } SendToTransport(); return ssl_write_ret; @@ -514,6 +543,7 @@ bool done_ = false; int32_t stream_id_ = -1; std::unique_ptr<MasqueH2Connection> h2_connection_; + std::string tls_write_buffer_; }; int RunMasqueTcpClient(int argc, char* argv[]) {
diff --git a/quiche/quic/masque/masque_tcp_server_bin.cc b/quiche/quic/masque/masque_tcp_server_bin.cc index 94f894c..9307e17 100644 --- a/quiche/quic/masque/masque_tcp_server_bin.cc +++ b/quiche/quic/masque/masque_tcp_server_bin.cc
@@ -463,6 +463,7 @@ SSL_CTX_set_min_proto_version(ctx_.get(), TLS1_2_VERSION); SSL_CTX_set_max_proto_version(ctx_.get(), TLS1_3_VERSION); + SSL_CTX_set_mode(ctx_.get(), SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); return true; }