Create ConnectUdpTunnel Will be used by QUICHE toy server in a subsequent CL to create a CONNECT-UDP-capable backend. PiperOrigin-RevId: 479730001
diff --git a/build/source_list.bzl b/build/source_list.bzl index 62d8265..9ab0ba3 100644 --- a/build/source_list.bzl +++ b/build/source_list.bzl
@@ -704,6 +704,7 @@ "quic/platform/api/quic_default_proof_providers.h", "quic/tools/connect_server_backend.h", "quic/tools/connect_tunnel.h", + "quic/tools/connect_udp_tunnel.h", "quic/tools/fake_proof_verifier.h", "quic/tools/quic_backend_response.h", "quic/tools/quic_client_base.h", @@ -727,6 +728,7 @@ "common/platform/api/quiche_file_utils.cc", "quic/tools/connect_server_backend.cc", "quic/tools/connect_tunnel.cc", + "quic/tools/connect_udp_tunnel.cc", "quic/tools/quic_backend_response.cc", "quic/tools/quic_client_base.cc", "quic/tools/quic_memory_cache_backend.cc", @@ -1245,6 +1247,7 @@ "quic/test_tools/simulator/quic_endpoint_test.cc", "quic/test_tools/simulator/simulator_test.cc", "quic/tools/connect_tunnel_test.cc", + "quic/tools/connect_udp_tunnel_test.cc", "quic/tools/quic_memory_cache_backend_test.cc", "quic/tools/quic_tcp_like_trace_converter_test.cc", "quic/tools/simple_ticket_crypter_test.cc",
diff --git a/build/source_list.gni b/build/source_list.gni index 7dba73c..35107f0 100644 --- a/build/source_list.gni +++ b/build/source_list.gni
@@ -704,6 +704,7 @@ "src/quiche/quic/platform/api/quic_default_proof_providers.h", "src/quiche/quic/tools/connect_server_backend.h", "src/quiche/quic/tools/connect_tunnel.h", + "src/quiche/quic/tools/connect_udp_tunnel.h", "src/quiche/quic/tools/fake_proof_verifier.h", "src/quiche/quic/tools/quic_backend_response.h", "src/quiche/quic/tools/quic_client_base.h", @@ -727,6 +728,7 @@ "src/quiche/common/platform/api/quiche_file_utils.cc", "src/quiche/quic/tools/connect_server_backend.cc", "src/quiche/quic/tools/connect_tunnel.cc", + "src/quiche/quic/tools/connect_udp_tunnel.cc", "src/quiche/quic/tools/quic_backend_response.cc", "src/quiche/quic/tools/quic_client_base.cc", "src/quiche/quic/tools/quic_memory_cache_backend.cc", @@ -1245,6 +1247,7 @@ "src/quiche/quic/test_tools/simulator/quic_endpoint_test.cc", "src/quiche/quic/test_tools/simulator/simulator_test.cc", "src/quiche/quic/tools/connect_tunnel_test.cc", + "src/quiche/quic/tools/connect_udp_tunnel_test.cc", "src/quiche/quic/tools/quic_memory_cache_backend_test.cc", "src/quiche/quic/tools/quic_tcp_like_trace_converter_test.cc", "src/quiche/quic/tools/simple_ticket_crypter_test.cc",
diff --git a/build/source_list.json b/build/source_list.json index db01f1e..a4394c8 100644 --- a/build/source_list.json +++ b/build/source_list.json
@@ -703,6 +703,7 @@ "quiche/quic/platform/api/quic_default_proof_providers.h", "quiche/quic/tools/connect_server_backend.h", "quiche/quic/tools/connect_tunnel.h", + "quiche/quic/tools/connect_udp_tunnel.h", "quiche/quic/tools/fake_proof_verifier.h", "quiche/quic/tools/quic_backend_response.h", "quiche/quic/tools/quic_client_base.h", @@ -726,6 +727,7 @@ "quiche/common/platform/api/quiche_file_utils.cc", "quiche/quic/tools/connect_server_backend.cc", "quiche/quic/tools/connect_tunnel.cc", + "quiche/quic/tools/connect_udp_tunnel.cc", "quiche/quic/tools/quic_backend_response.cc", "quiche/quic/tools/quic_client_base.cc", "quiche/quic/tools/quic_memory_cache_backend.cc", @@ -1244,6 +1246,7 @@ "quiche/quic/test_tools/simulator/quic_endpoint_test.cc", "quiche/quic/test_tools/simulator/simulator_test.cc", "quiche/quic/tools/connect_tunnel_test.cc", + "quiche/quic/tools/connect_udp_tunnel_test.cc", "quiche/quic/tools/quic_memory_cache_backend_test.cc", "quiche/quic/tools/quic_tcp_like_trace_converter_test.cc", "quiche/quic/tools/simple_ticket_crypter_test.cc",
diff --git a/quiche/quic/core/http/quic_spdy_stream.h b/quiche/quic/core/http/quic_spdy_stream.h index b5b6e9e..4718591 100644 --- a/quiche/quic/core/http/quic_spdy_stream.h +++ b/quiche/quic/core/http/quic_spdy_stream.h
@@ -254,8 +254,9 @@ bool OnCapsule(const Capsule& capsule) override; void OnCapsuleParseFailure(const std::string& error_message) override; - // Sends an HTTP/3 datagram. The stream ID is not part of |payload|. - MessageStatus SendHttp3Datagram(absl::string_view payload); + // Sends an HTTP/3 datagram. The stream ID is not part of |payload|. Virtual + // to allow mocking in tests. + virtual MessageStatus SendHttp3Datagram(absl::string_view payload); class QUIC_EXPORT_PRIVATE Http3DatagramVisitor { public:
diff --git a/quiche/quic/tools/connect_udp_tunnel.cc b/quiche/quic/tools/connect_udp_tunnel.cc new file mode 100644 index 0000000..ebeef8b --- /dev/null +++ b/quiche/quic/tools/connect_udp_tunnel.cc
@@ -0,0 +1,424 @@ +// Copyright 2022 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/connect_udp_tunnel.h" + +#include <cstdint> +#include <string> +#include <utility> +#include <vector> + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "url/url_canon.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/socket_factory.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/tools/quic_backend_response.h" +#include "quiche/quic/tools/quic_name_lookup.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/common/masque/connect_udp_datagram_payload.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/platform/api/quiche_url_utils.h" +#include "quiche/common/structured_headers.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +namespace structured_headers = quiche::structured_headers; + +namespace { + +// Arbitrarily chosen. No effort has been made to figure out an optimal size. +constexpr size_t kReadSize = 4 * 1024; + +// Only support the default path +// ("/.well-known/masque/udp/{target_host}/{target_port}/") +absl::optional<QuicServerId> ValidateAndParseTargetFromPath( + absl::string_view path) { + std::string canonicalized_path_str; + url::StdStringCanonOutput canon_output(&canonicalized_path_str); + url::Component path_component; + url::CanonicalizePath(path.data(), url::Component(0, path.size()), + &canon_output, &path_component); + if (!path_component.is_nonempty()) { + QUICHE_DVLOG(1) << "CONNECT-UDP request with non-canonicalizable path: " + << path; + return absl::nullopt; + } + canon_output.Complete(); + absl::string_view canonicalized_path = + absl::string_view(canonicalized_path_str) + .substr(path_component.begin, path_component.len); + + std::vector<absl::string_view> path_split = + absl::StrSplit(canonicalized_path, '/'); + if (path_split.size() != 7 || !path_split[0].empty() || + path_split[1] != ".well-known" || path_split[2] != "masque" || + path_split[3] != "udp" || path_split[4].empty() || + path_split[5].empty() || !path_split[6].empty()) { + QUICHE_DVLOG(1) << "CONNECT-UDP request with bad path: " + << canonicalized_path; + return absl::nullopt; + } + + absl::optional<std::string> decoded_host = + quiche::AsciiUrlDecode(path_split[4]); + if (!decoded_host.has_value()) { + QUICHE_DVLOG(1) << "CONNECT-UDP request with undecodable host: " + << path_split[4]; + return absl::nullopt; + } + // Empty host checked above after path split. Expect decoding to never result + // in an empty decoded host from non-empty encoded host. + QUICHE_DCHECK(!decoded_host.value().empty()); + + absl::optional<std::string> decoded_port = + quiche::AsciiUrlDecode(path_split[5]); + if (!decoded_port.has_value()) { + QUICHE_DVLOG(1) << "CONNECT-UDP request with undecodable port: " + << path_split[5]; + return absl::nullopt; + } + // Empty port checked above after path split. Expect decoding to never result + // in an empty decoded port from non-empty encoded port. + QUICHE_DCHECK(!decoded_port.value().empty()); + + int parsed_port_number = + url::ParsePort(decoded_port.value().data(), + url::Component(0, decoded_port.value().size())); + // Negative result is either invalid or unspecified, either of which is + // disallowed for this parse. Port 0 is technically valid but reserved and not + // really usable in practice, so easiest to just disallow it here. + if (parsed_port_number <= 0) { + QUICHE_DVLOG(1) << "CONNECT-UDP request with bad port: " + << decoded_port.value(); + return absl::nullopt; + } + // Expect url::ParsePort() to validate port is uint16_t and otherwise return + // negative number checked for above. + QUICHE_DCHECK_LE(parsed_port_number, std::numeric_limits<uint16_t>::max()); + + return QuicServerId(decoded_host.value(), + static_cast<uint16_t>(parsed_port_number)); +} + +// Validate header expectations from RFC 9298, section 3.4. +absl::optional<QuicServerId> ValidateHeadersAndGetTarget( + const spdy::Http2HeaderBlock& request_headers) { + QUICHE_DCHECK(request_headers.contains(":method")); + QUICHE_DCHECK(request_headers.find(":method")->second == "CONNECT"); + QUICHE_DCHECK(request_headers.contains(":protocol")); + QUICHE_DCHECK(request_headers.find(":protocol")->second == "connect-udp"); + + auto authority_it = request_headers.find(":authority"); + if (authority_it == request_headers.end() || authority_it->second.empty()) { + QUICHE_DVLOG(1) << "CONNECT-UDP request missing authority"; + return absl::nullopt; + } + // For toy server simplicity, skip validating that the authority matches the + // current server. + + auto scheme_it = request_headers.find(":scheme"); + if (scheme_it == request_headers.end() || scheme_it->second.empty()) { + QUICHE_DVLOG(1) << "CONNECT-UDP request missing scheme"; + return absl::nullopt; + } else if (scheme_it->second != "https") { + QUICHE_DVLOG(1) << "CONNECT-UDP request contains unexpected scheme: " + << scheme_it->second; + return absl::nullopt; + } + + auto path_it = request_headers.find(":path"); + if (path_it == request_headers.end() || path_it->second.empty()) { + QUICHE_DVLOG(1) << "CONNECT-UDP request missing path"; + return absl::nullopt; + } + absl::optional<QuicServerId> target_server_id = + ValidateAndParseTargetFromPath(path_it->second); + + return target_server_id; +} + +bool ValidateTarget( + const QuicServerId& target, + const absl::flat_hash_set<QuicServerId>& acceptable_targets) { + if (acceptable_targets.contains(target)) { + return true; + } + + QUICHE_DVLOG(1) + << "CONNECT-UDP request target is not an acceptable allow-listed target: " + << target.ToHostPortString(); + return false; +} + +} // namespace + +ConnectUdpTunnel::ConnectUdpTunnel( + QuicSimpleServerBackend::RequestHandler* client_stream_request_handler, + SocketFactory* socket_factory, uint64_t server_label, + absl::flat_hash_set<QuicServerId> acceptable_targets) + : acceptable_targets_(std::move(acceptable_targets)), + socket_factory_(socket_factory), + server_label_(server_label), + client_stream_request_handler_(client_stream_request_handler) { + QUICHE_DCHECK(client_stream_request_handler_); + QUICHE_DCHECK(socket_factory_); +} + +ConnectUdpTunnel::~ConnectUdpTunnel() { + // Expect client and target sides of tunnel to both be closed before + // destruction. + QUICHE_DCHECK(!IsTunnelOpenToTarget()); + QUICHE_DCHECK(!receive_started_); + QUICHE_DCHECK(!datagram_visitor_registered_); +} + +void ConnectUdpTunnel::OpenTunnel( + const spdy::Http2HeaderBlock& request_headers) { + QUICHE_DCHECK(!IsTunnelOpenToTarget()); + + absl::optional<QuicServerId> target = + ValidateHeadersAndGetTarget(request_headers); + if (!target.has_value()) { + // Malformed request. + TerminateClientStream( + "invalid request headers", + QuicResetStreamError::FromIetf(QuicHttp3ErrorCode::MESSAGE_ERROR)); + return; + } + + if (!ValidateTarget(target.value(), acceptable_targets_)) { + SendErrorResponse("403", "destination_ip_prohibited", + "disallowed proxy target"); + return; + } + + // TODO(ericorth): Validate that the IP address doesn't fall into diallowed + // ranges per RFC 9298, Section 7. + QuicSocketAddress address = tools::LookupAddress(AF_UNSPEC, target.value()); + if (!address.IsInitialized()) { + SendErrorResponse("500", "dns_error", "host resolution error"); + return; + } + + target_socket_ = socket_factory_->CreateConnectingUdpClientSocket( + address, + /*receive_buffer_size=*/0, + /*send_buffer_size=*/0, + /*async_visitor=*/this); + QUICHE_DCHECK(target_socket_); + + absl::Status connect_result = target_socket_->ConnectBlocking(); + if (!connect_result.ok()) { + SendErrorResponse( + "502", "destination_ip_unroutable", + absl::StrCat("UDP socket error: ", connect_result.ToString())); + return; + } + + QUICHE_DVLOG(1) << "CONNECT-UDP tunnel opened from stream " + << client_stream_request_handler_->stream_id() << " to " + << target.value().ToHostPortString(); + + client_stream_request_handler_->GetStream()->RegisterHttp3DatagramVisitor( + this); + datagram_visitor_registered_ = true; + + SendConnectResponse(); + BeginAsyncReadFromTarget(); +} + +bool ConnectUdpTunnel::IsTunnelOpenToTarget() const { return !!target_socket_; } + +void ConnectUdpTunnel::OnClientStreamClose() { + QUICHE_CHECK(client_stream_request_handler_); + + QUICHE_DVLOG(1) << "CONNECT-UDP stream " + << client_stream_request_handler_->stream_id() << " closed"; + + if (datagram_visitor_registered_) { + client_stream_request_handler_->GetStream() + ->UnregisterHttp3DatagramVisitor(); + datagram_visitor_registered_ = false; + } + client_stream_request_handler_ = nullptr; + + if (IsTunnelOpenToTarget()) { + target_socket_->Disconnect(); + } + + // Clear socket pointer. + target_socket_.reset(); +} + +void ConnectUdpTunnel::ConnectComplete(absl::Status /*status*/) { + // Async connect not expected. + QUICHE_NOTREACHED(); +} + +void ConnectUdpTunnel::ReceiveComplete( + absl::StatusOr<quiche::QuicheMemSlice> data) { + QUICHE_DCHECK(IsTunnelOpenToTarget()); + QUICHE_DCHECK(receive_started_); + + receive_started_ = false; + + if (!data.ok()) { + if (client_stream_request_handler_) { + QUICHE_LOG(WARNING) << "Error receiving CONNECT-UDP data from target: " + << data.status(); + } else { + // This typically just means a receive operation was cancelled on calling + // target_socket_->Disconnect(). + QUICHE_DVLOG(1) << "Error receiving CONNECT-UDP data from target after " + "stream already closed."; + } + return; + } + + QUICHE_DCHECK(client_stream_request_handler_); + quiche::ConnectUdpDatagramUdpPacketPayload payload( + data.value().AsStringView()); + client_stream_request_handler_->GetStream()->SendHttp3Datagram( + payload.Serialize()); + + BeginAsyncReadFromTarget(); +} + +void ConnectUdpTunnel::SendComplete(absl::Status /*status*/) { + // Async send not expected. + QUICHE_NOTREACHED(); +} + +void ConnectUdpTunnel::OnHttp3Datagram(QuicStreamId stream_id, + absl::string_view payload) { + QUICHE_DCHECK(IsTunnelOpenToTarget()); + QUICHE_DCHECK_EQ(stream_id, client_stream_request_handler_->stream_id()); + QUICHE_DCHECK(!payload.empty()); + + std::unique_ptr<quiche::ConnectUdpDatagramPayload> parsed_payload = + quiche::ConnectUdpDatagramPayload::Parse(payload); + if (!parsed_payload) { + QUICHE_DVLOG(1) << "Ignoring HTTP Datagram payload, due to inability to " + "parse as CONNECT-UDP payload."; + return; + } + + switch (parsed_payload->GetType()) { + case quiche::ConnectUdpDatagramPayload::Type::kUdpPacket: + SendUdpPacketToTarget(parsed_payload->GetUdpProxyingPayload()); + break; + case quiche::ConnectUdpDatagramPayload::Type::kUnknown: + QUICHE_DVLOG(1) + << "Ignoring HTTP Datagram payload with unrecognized context ID."; + } +} + +void ConnectUdpTunnel::BeginAsyncReadFromTarget() { + QUICHE_DCHECK(IsTunnelOpenToTarget()); + QUICHE_DCHECK(client_stream_request_handler_); + QUICHE_DCHECK(!receive_started_); + + receive_started_ = true; + target_socket_->ReceiveAsync(kReadSize); +} + +void ConnectUdpTunnel::SendUdpPacketToTarget(absl::string_view packet) { + absl::Status send_result = target_socket_->SendBlocking(std::string(packet)); + if (!send_result.ok()) { + QUICHE_LOG(WARNING) << "Error sending CONNECT-UDP datagram to target: " + << send_result; + } +} + +void ConnectUdpTunnel::SendConnectResponse() { + QUICHE_DCHECK(IsTunnelOpenToTarget()); + QUICHE_DCHECK(client_stream_request_handler_); + + spdy::Http2HeaderBlock response_headers; + response_headers[":status"] = "200"; + + absl::optional<std::string> capsule_protocol_value = + structured_headers::SerializeItem(structured_headers::Item(true)); + QUICHE_CHECK(capsule_protocol_value.has_value()); + response_headers["Capsule-Protocol"] = capsule_protocol_value.value(); + + QuicBackendResponse response; + response.set_headers(std::move(response_headers)); + // Need to leave the stream open after sending the CONNECT response. + response.set_response_type(QuicBackendResponse::INCOMPLETE_RESPONSE); + + client_stream_request_handler_->OnResponseBackendComplete(&response); +} + +void ConnectUdpTunnel::SendErrorResponse(absl::string_view status, + absl::string_view proxy_status_error, + absl::string_view error_details) { + QUICHE_DCHECK(!status.empty()); + QUICHE_DCHECK(!proxy_status_error.empty()); + QUICHE_DCHECK(!error_details.empty()); + QUICHE_DCHECK(client_stream_request_handler_); + +#ifndef NDEBUG + // Expect a valid status code (number, 100 to 599 inclusive) and not a + // Successful code (200 to 299 inclusive). + int status_num = 0; + bool is_num = absl::SimpleAtoi(status, &status_num); + QUICHE_DCHECK(is_num); + QUICHE_DCHECK_GE(status_num, 100); + QUICHE_DCHECK_LT(status_num, 600); + QUICHE_DCHECK(status_num < 200 || status_num >= 300); +#endif // !NDEBUG + + spdy::Http2HeaderBlock headers; + headers[":status"] = status; + + structured_headers::Item proxy_status_item( + absl::StrCat("QuicToyServer", server_label_)); + structured_headers::Item proxy_status_error_item( + std::string{proxy_status_error}); + structured_headers::Item proxy_status_details_item( + std::string{error_details}); + structured_headers::ParameterizedMember proxy_status_member( + std::move(proxy_status_item), + {{"error", std::move(proxy_status_error_item)}, + {"details", std::move(proxy_status_details_item)}}); + absl::optional<std::string> proxy_status_value = + structured_headers::SerializeList({proxy_status_member}); + QUICHE_CHECK(proxy_status_value.has_value()); + headers["Proxy-Status"] = proxy_status_value.value(); + + QuicBackendResponse response; + response.set_headers(std::move(headers)); + + client_stream_request_handler_->OnResponseBackendComplete(&response); +} + +void ConnectUdpTunnel::TerminateClientStream( + absl::string_view error_description, QuicResetStreamError error_code) { + QUICHE_DCHECK(client_stream_request_handler_); + + std::string error_description_str = + error_description.empty() ? "" + : absl::StrCat(" due to ", error_description); + QUICHE_DVLOG(1) << "Terminating CONNECT stream " + << client_stream_request_handler_->stream_id() + << " with error code " << error_code.ietf_application_code() + << error_description_str; + + client_stream_request_handler_->TerminateStreamWithError(error_code); +} + +} // namespace quic
diff --git a/quiche/quic/tools/connect_udp_tunnel.h b/quiche/quic/tools/connect_udp_tunnel.h new file mode 100644 index 0000000..f254b61 --- /dev/null +++ b/quiche/quic/tools/connect_udp_tunnel.h
@@ -0,0 +1,97 @@ +// Copyright 2022 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_CONNECT_TUNNEL_H_ +#define QUICHE_QUIC_TOOLS_CONNECT_TUNNEL_H_ + +#include <cstdint> +#include <memory> +#include <string> +#include <utility> + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/connecting_client_socket.h" +#include "quiche/quic/core/http/quic_spdy_stream.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/socket_factory.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +// Manages a single UDP tunnel for a CONNECT-UDP proxy (see RFC 9298). +class ConnectUdpTunnel : public ConnectingClientSocket::AsyncVisitor, + public QuicSpdyStream::Http3DatagramVisitor { + public: + // `client_stream_request_handler` and `socket_factory` must both outlive the + // created ConnectUdpTunnel. `server_label` is an identifier (typically + // randomly generated) to indentify the server or backend in error headers, + // per the requirements of RFC 9209, Section 2. + ConnectUdpTunnel( + QuicSimpleServerBackend::RequestHandler* client_stream_request_handler, + SocketFactory* socket_factory, uint64_t server_label, + absl::flat_hash_set<QuicServerId> acceptable_targets); + ~ConnectUdpTunnel(); + ConnectUdpTunnel(const ConnectUdpTunnel&) = delete; + ConnectUdpTunnel& operator=(const ConnectUdpTunnel&) = delete; + + // Attempts to open UDP tunnel to target server and then sends appropriate + // success/error response to the request stream. `request_headers` must + // represent headers from a CONNECT-UDP request, that is ":method"="CONNECT" + // and ":protocol"="connect-udp". + void OpenTunnel(const spdy::Http2HeaderBlock& request_headers); + + // Returns true iff the tunnel to the target server is currently open + bool IsTunnelOpenToTarget() const; + + // Called when the client stream has been closed. Tunnel to target + // server is closed if open. The RequestHandler will no longer be + // interacted with after completion. + void OnClientStreamClose(); + + // ConnectingClientSocket::AsyncVisitor: + void ConnectComplete(absl::Status status) override; + void ReceiveComplete(absl::StatusOr<quiche::QuicheMemSlice> data) override; + void SendComplete(absl::Status status) override; + + // QuicSpdyStream::Http3DatagramVisitor: + void OnHttp3Datagram(QuicStreamId stream_id, + absl::string_view payload) override; + + private: + void BeginAsyncReadFromTarget(); + void OnDataReceivedFromTarget(bool success); + + void SendUdpPacketToTarget(absl::string_view packet); + + void SendConnectResponse(); + void SendErrorResponse(absl::string_view status, + absl::string_view proxy_status_error, + absl::string_view error_details); + void TerminateClientStream(absl::string_view error_description, + QuicResetStreamError error_code); + + const absl::flat_hash_set<QuicServerId> acceptable_targets_; + SocketFactory* const socket_factory_; + const uint64_t server_label_; + + // Null when client stream closed. + QuicSimpleServerBackend::RequestHandler* client_stream_request_handler_; + + // Null when target connection disconnected. + std::unique_ptr<ConnectingClientSocket> target_socket_; + + bool receive_started_ = false; + bool datagram_visitor_registered_ = false; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_CONNECT_TUNNEL_H_
diff --git a/quiche/quic/tools/connect_udp_tunnel_test.cc b/quiche/quic/tools/connect_udp_tunnel_test.cc new file mode 100644 index 0000000..23b1c11 --- /dev/null +++ b/quiche/quic/tools/connect_udp_tunnel_test.cc
@@ -0,0 +1,362 @@ +// Copyright 2022 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/connect_udp_tunnel.h" + +#include <memory> +#include <string> + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "url/url_canon_stdstring.h" +#include "url/url_util.h" +#include "quiche/quic/core/connecting_client_socket.h" +#include "quiche/quic/core/http/quic_spdy_stream.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/socket_factory.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test_loopback.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/common/masque/connect_udp_datagram_payload.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quic::test { +namespace { + +using ::testing::_; +using ::testing::AnyOf; +using ::testing::Eq; +using ::testing::Ge; +using ::testing::Gt; +using ::testing::HasSubstr; +using ::testing::InvokeWithoutArgs; +using ::testing::IsEmpty; +using ::testing::Matcher; +using ::testing::NiceMock; +using ::testing::Pair; +using ::testing::Property; +using ::testing::Return; +using ::testing::StrictMock; +using ::testing::UnorderedElementsAre; + +constexpr QuicStreamId kStreamId = 100; + +class MockStream : public QuicSpdyStream { + public: + explicit MockStream(QuicSpdySession* spdy_session) + : QuicSpdyStream(kStreamId, spdy_session, BIDIRECTIONAL) {} + + void OnBodyAvailable() override {} + + MOCK_METHOD(MessageStatus, SendHttp3Datagram, (absl::string_view data), + (override)); +}; + +class MockRequestHandler : public QuicSimpleServerBackend::RequestHandler { + public: + QuicConnectionId connection_id() const override { + return TestConnectionId(41212); + } + QuicStreamId stream_id() const override { return kStreamId; } + std::string peer_host() const override { return "127.0.0.1"; } + + MOCK_METHOD(QuicSpdyStream*, GetStream, (), (override)); + MOCK_METHOD(void, OnResponseBackendComplete, + (const QuicBackendResponse* response), (override)); + MOCK_METHOD(void, SendStreamData, (absl::string_view data, bool close_stream), + (override)); + MOCK_METHOD(void, TerminateStreamWithError, (QuicResetStreamError error), + (override)); +}; + +class MockSocketFactory : public SocketFactory { + public: + MOCK_METHOD(std::unique_ptr<ConnectingClientSocket>, CreateTcpClientSocket, + (const QuicSocketAddress& peer_address, + QuicByteCount receive_buffer_size, + QuicByteCount send_buffer_size, + ConnectingClientSocket::AsyncVisitor* async_visitor), + (override)); + MOCK_METHOD(std::unique_ptr<ConnectingClientSocket>, + CreateConnectingUdpClientSocket, + (const QuicSocketAddress& peer_address, + QuicByteCount receive_buffer_size, + QuicByteCount send_buffer_size, + ConnectingClientSocket::AsyncVisitor* async_visitor), + (override)); +}; + +class MockSocket : public ConnectingClientSocket { + public: + MOCK_METHOD(absl::Status, ConnectBlocking, (), (override)); + MOCK_METHOD(void, ConnectAsync, (), (override)); + MOCK_METHOD(void, Disconnect, (), (override)); + MOCK_METHOD(absl::StatusOr<QuicSocketAddress>, GetLocalAddress, (), + (override)); + MOCK_METHOD(absl::StatusOr<quiche::QuicheMemSlice>, ReceiveBlocking, + (QuicByteCount max_size), (override)); + MOCK_METHOD(void, ReceiveAsync, (QuicByteCount max_size), (override)); + MOCK_METHOD(absl::Status, SendBlocking, (std::string data), (override)); + MOCK_METHOD(absl::Status, SendBlocking, (quiche::QuicheMemSlice data), + (override)); + MOCK_METHOD(void, SendAsync, (std::string data), (override)); + MOCK_METHOD(void, SendAsync, (quiche::QuicheMemSlice data), (override)); +}; + +class ConnectUdpTunnelTest : public quiche::test::QuicheTest { + public: + void SetUp() override { + auto socket = std::make_unique<StrictMock<MockSocket>>(); + socket_ = socket.get(); + ON_CALL(socket_factory_, + CreateConnectingUdpClientSocket( + AnyOf(QuicSocketAddress(TestLoopback4(), kAcceptablePort), + QuicSocketAddress(TestLoopback6(), kAcceptablePort)), + _, _, &tunnel_)) + .WillByDefault(Return(ByMove(std::move(socket)))); + + EXPECT_CALL(request_handler_, GetStream()).WillRepeatedly(Return(&stream_)); + } + + protected: + static constexpr absl::string_view kAcceptableTarget = "localhost"; + static constexpr uint16_t kAcceptablePort = 977; + + NiceMock<MockQuicConnectionHelper> connection_helper_; + NiceMock<MockAlarmFactory> alarm_factory_; + NiceMock<MockQuicSpdySession> session_{new NiceMock<MockQuicConnection>( + &connection_helper_, &alarm_factory_, Perspective::IS_SERVER)}; + StrictMock<MockStream> stream_{&session_}; + + StrictMock<MockRequestHandler> request_handler_; + NiceMock<MockSocketFactory> socket_factory_; + StrictMock<MockSocket>* socket_; + + ConnectUdpTunnel tunnel_{ + &request_handler_, + &socket_factory_, + /*server_label=*/123, + /*acceptable_targets=*/ + {{std::string(kAcceptableTarget), kAcceptablePort}, + {TestLoopback4().ToString(), kAcceptablePort}, + {absl::StrCat("[", TestLoopback6().ToString(), "]"), kAcceptablePort}}}; +}; + +TEST_F(ConnectUdpTunnelTest, OpenTunnel) { + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + EXPECT_CALL( + request_handler_, + OnResponseBackendComplete( + AllOf(Property(&QuicBackendResponse::response_type, + QuicBackendResponse::INCOMPLETE_RESPONSE), + Property(&QuicBackendResponse::headers, + UnorderedElementsAre(Pair(":status", "200"), + Pair("Capsule-Protocol", "?1"))), + Property(&QuicBackendResponse::trailers, IsEmpty()), + Property(&QuicBackendResponse::body, IsEmpty())))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":protocol"] = "connect-udp"; + request_headers[":authority"] = "proxy.test"; + request_headers[":scheme"] = "https"; + request_headers[":path"] = absl::StrCat( + "/.well-known/masque/udp/", kAcceptableTarget, "/", kAcceptablePort, "/"); + + tunnel_.OpenTunnel(request_headers); + EXPECT_TRUE(tunnel_.IsTunnelOpenToTarget()); + tunnel_.OnClientStreamClose(); + EXPECT_FALSE(tunnel_.IsTunnelOpenToTarget()); +} + +TEST_F(ConnectUdpTunnelTest, OpenTunnelToIpv4LiteralTarget) { + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + EXPECT_CALL( + request_handler_, + OnResponseBackendComplete( + AllOf(Property(&QuicBackendResponse::response_type, + QuicBackendResponse::INCOMPLETE_RESPONSE), + Property(&QuicBackendResponse::headers, + UnorderedElementsAre(Pair(":status", "200"), + Pair("Capsule-Protocol", "?1"))), + Property(&QuicBackendResponse::trailers, IsEmpty()), + Property(&QuicBackendResponse::body, IsEmpty())))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":protocol"] = "connect-udp"; + request_headers[":authority"] = "proxy.test"; + request_headers[":scheme"] = "https"; + request_headers[":path"] = + absl::StrCat("/.well-known/masque/udp/", TestLoopback4().ToString(), "/", + kAcceptablePort, "/"); + + tunnel_.OpenTunnel(request_headers); + EXPECT_TRUE(tunnel_.IsTunnelOpenToTarget()); + tunnel_.OnClientStreamClose(); + EXPECT_FALSE(tunnel_.IsTunnelOpenToTarget()); +} + +std::string PercentEncode(absl::string_view input) { + std::string encoded; + url::StdStringCanonOutput canon_output(&encoded); + url::EncodeURIComponent(input.data(), input.size(), &canon_output); + canon_output.Complete(); + return encoded; +} + +TEST_F(ConnectUdpTunnelTest, OpenTunnelToIpv6LiteralTarget) { + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + EXPECT_CALL( + request_handler_, + OnResponseBackendComplete( + AllOf(Property(&QuicBackendResponse::response_type, + QuicBackendResponse::INCOMPLETE_RESPONSE), + Property(&QuicBackendResponse::headers, + UnorderedElementsAre(Pair(":status", "200"), + Pair("Capsule-Protocol", "?1"))), + Property(&QuicBackendResponse::trailers, IsEmpty()), + Property(&QuicBackendResponse::body, IsEmpty())))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":protocol"] = "connect-udp"; + request_headers[":authority"] = "proxy.test"; + request_headers[":scheme"] = "https"; + request_headers[":path"] = absl::StrCat( + "/.well-known/masque/udp/", + PercentEncode(absl::StrCat("[", TestLoopback6().ToString(), "]")), "/", + kAcceptablePort, "/"); + + tunnel_.OpenTunnel(request_headers); + EXPECT_TRUE(tunnel_.IsTunnelOpenToTarget()); + tunnel_.OnClientStreamClose(); + EXPECT_FALSE(tunnel_.IsTunnelOpenToTarget()); +} + +TEST_F(ConnectUdpTunnelTest, OpenTunnelWithMalformedRequest) { + EXPECT_CALL(request_handler_, + TerminateStreamWithError(Property( + &QuicResetStreamError::ietf_application_code, + static_cast<uint64_t>(QuicHttp3ErrorCode::MESSAGE_ERROR)))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":protocol"] = "connect-udp"; + request_headers[":authority"] = "proxy.test"; + request_headers[":scheme"] = "https"; + // No ":path" header. + + tunnel_.OpenTunnel(request_headers); + EXPECT_FALSE(tunnel_.IsTunnelOpenToTarget()); + tunnel_.OnClientStreamClose(); +} + +TEST_F(ConnectUdpTunnelTest, OpenTunnelWithUnacceptableTarget) { + EXPECT_CALL(request_handler_, + OnResponseBackendComplete(AllOf( + Property(&QuicBackendResponse::response_type, + QuicBackendResponse::REGULAR_RESPONSE), + Property(&QuicBackendResponse::headers, + UnorderedElementsAre( + Pair(":status", "403"), + Pair("Proxy-Status", + HasSubstr("destination_ip_prohibited")))), + Property(&QuicBackendResponse::trailers, IsEmpty())))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":protocol"] = "connect-udp"; + request_headers[":authority"] = "proxy.test"; + request_headers[":scheme"] = "https"; + request_headers[":path"] = "/.well-known/masque/udp/unacceptable.test/100/"; + + tunnel_.OpenTunnel(request_headers); + EXPECT_FALSE(tunnel_.IsTunnelOpenToTarget()); + tunnel_.OnClientStreamClose(); +} + +TEST_F(ConnectUdpTunnelTest, ReceiveFromTarget) { + static constexpr absl::string_view kData = "\x11\x22\x33\x44\x55"; + + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Ge(kData.size()))).Times(2); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + EXPECT_CALL(request_handler_, OnResponseBackendComplete(_)); + + EXPECT_CALL( + stream_, + SendHttp3Datagram( + quiche::ConnectUdpDatagramUdpPacketPayload(kData).Serialize())) + .WillOnce(Return(MESSAGE_STATUS_SUCCESS)); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":protocol"] = "connect-udp"; + request_headers[":authority"] = "proxy.test"; + request_headers[":scheme"] = "https"; + request_headers[":path"] = absl::StrCat( + "/.well-known/masque/udp/", kAcceptableTarget, "/", kAcceptablePort, "/"); + + tunnel_.OpenTunnel(request_headers); + + // Simulate receiving `kData`. + tunnel_.ReceiveComplete(MemSliceFromString(kData)); + + tunnel_.OnClientStreamClose(); +} + +TEST_F(ConnectUdpTunnelTest, SendToTarget) { + static constexpr absl::string_view kData = "\x11\x22\x33\x44\x55"; + + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, SendBlocking(Matcher<std::string>(Eq(kData)))) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + EXPECT_CALL(request_handler_, OnResponseBackendComplete(_)); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":protocol"] = "connect-udp"; + request_headers[":authority"] = "proxy.test"; + request_headers[":scheme"] = "https"; + request_headers[":path"] = absl::StrCat( + "/.well-known/masque/udp/", kAcceptableTarget, "/", kAcceptablePort, "/"); + + tunnel_.OpenTunnel(request_headers); + tunnel_.OnHttp3Datagram( + kStreamId, quiche::ConnectUdpDatagramUdpPacketPayload(kData).Serialize()); + tunnel_.OnClientStreamClose(); +} + +} // namespace +} // namespace quic::test