Create ConnectTunnel for toy server use Handles a single connection through a CONNECT server PiperOrigin-RevId: 466374783
diff --git a/build/source_list.bzl b/build/source_list.bzl index a928fe6..313dea4 100644 --- a/build/source_list.bzl +++ b/build/source_list.bzl
@@ -693,6 +693,7 @@ "common/platform/api/quiche_file_utils.h", "common/platform/api/quiche_system_event_loop.h", "quic/platform/api/quic_default_proof_providers.h", + "quic/tools/connect_tunnel.h", "quic/tools/fake_proof_verifier.h", "quic/tools/quic_backend_response.h", "quic/tools/quic_client_base.h", @@ -716,6 +717,7 @@ ] quiche_tool_support_srcs = [ "common/platform/api/quiche_file_utils.cc", + "quic/tools/connect_tunnel.cc", "quic/tools/quic_backend_response.cc", "quic/tools/quic_client_base.cc", "quic/tools/quic_client_default_network_helper.cc", @@ -1270,6 +1272,7 @@ "quic/test_tools/simple_session_notifier_test.cc", "quic/test_tools/simulator/quic_endpoint_test.cc", "quic/test_tools/simulator/simulator_test.cc", + "quic/tools/connect_tunnel_test.cc", "quic/tools/quic_default_client_test.cc", "quic/tools/quic_memory_cache_backend_test.cc", "quic/tools/quic_tcp_like_trace_converter_test.cc",
diff --git a/build/source_list.gni b/build/source_list.gni index 2fd7580..0c08a73 100644 --- a/build/source_list.gni +++ b/build/source_list.gni
@@ -693,6 +693,7 @@ "src/quiche/common/platform/api/quiche_file_utils.h", "src/quiche/common/platform/api/quiche_system_event_loop.h", "src/quiche/quic/platform/api/quic_default_proof_providers.h", + "src/quiche/quic/tools/connect_tunnel.h", "src/quiche/quic/tools/fake_proof_verifier.h", "src/quiche/quic/tools/quic_backend_response.h", "src/quiche/quic/tools/quic_client_base.h", @@ -716,6 +717,7 @@ ] quiche_tool_support_srcs = [ "src/quiche/common/platform/api/quiche_file_utils.cc", + "src/quiche/quic/tools/connect_tunnel.cc", "src/quiche/quic/tools/quic_backend_response.cc", "src/quiche/quic/tools/quic_client_base.cc", "src/quiche/quic/tools/quic_client_default_network_helper.cc", @@ -1270,6 +1272,7 @@ "src/quiche/quic/test_tools/simple_session_notifier_test.cc", "src/quiche/quic/test_tools/simulator/quic_endpoint_test.cc", "src/quiche/quic/test_tools/simulator/simulator_test.cc", + "src/quiche/quic/tools/connect_tunnel_test.cc", "src/quiche/quic/tools/quic_default_client_test.cc", "src/quiche/quic/tools/quic_memory_cache_backend_test.cc", "src/quiche/quic/tools/quic_tcp_like_trace_converter_test.cc",
diff --git a/build/source_list.json b/build/source_list.json index 02f608b..a7fdf5b 100644 --- a/build/source_list.json +++ b/build/source_list.json
@@ -692,6 +692,7 @@ "quiche/common/platform/api/quiche_file_utils.h", "quiche/common/platform/api/quiche_system_event_loop.h", "quiche/quic/platform/api/quic_default_proof_providers.h", + "quiche/quic/tools/connect_tunnel.h", "quiche/quic/tools/fake_proof_verifier.h", "quiche/quic/tools/quic_backend_response.h", "quiche/quic/tools/quic_client_base.h", @@ -715,6 +716,7 @@ ], "quiche_tool_support_srcs": [ "quiche/common/platform/api/quiche_file_utils.cc", + "quiche/quic/tools/connect_tunnel.cc", "quiche/quic/tools/quic_backend_response.cc", "quiche/quic/tools/quic_client_base.cc", "quiche/quic/tools/quic_client_default_network_helper.cc", @@ -1269,6 +1271,7 @@ "quiche/quic/test_tools/simple_session_notifier_test.cc", "quiche/quic/test_tools/simulator/quic_endpoint_test.cc", "quiche/quic/test_tools/simulator/simulator_test.cc", + "quiche/quic/tools/connect_tunnel_test.cc", "quiche/quic/tools/quic_default_client_test.cc", "quiche/quic/tools/quic_memory_cache_backend_test.cc", "quiche/quic/tools/quic_tcp_like_trace_converter_test.cc",
diff --git a/quiche/quic/core/quic_error_codes.cc b/quiche/quic/core/quic_error_codes.cc index e63141b..9cd540a 100644 --- a/quiche/quic/core/quic_error_codes.cc +++ b/quiche/quic/core/quic_error_codes.cc
@@ -973,6 +973,17 @@ IetfResetStreamErrorCodeToRstStreamErrorCode(code), code); } +// static +QuicResetStreamError QuicResetStreamError::FromIetf(QuicHttp3ErrorCode code) { + return FromIetf(static_cast<uint64_t>(code)); +} + +// static +QuicResetStreamError QuicResetStreamError::FromIetf( + QuicHttpQpackErrorCode code) { + return FromIetf(static_cast<uint64_t>(code)); +} + #undef RETURN_STRING_LITERAL // undef for jumbo builds } // namespace quic
diff --git a/quiche/quic/core/quic_error_codes.h b/quiche/quic/core/quic_error_codes.h index 34fc0c1..bbd1d39 100644 --- a/quiche/quic/core/quic_error_codes.h +++ b/quiche/quic/core/quic_error_codes.h
@@ -627,6 +627,37 @@ static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()), "QuicErrorCode exceeds four octets"); +// Wire values for HTTP/3 errors. +// https://www.rfc-editor.org/rfc/rfc9114.html#http-error-codes +enum class QuicHttp3ErrorCode { + // NO_ERROR is defined as a C preprocessor macro on Windows. + HTTP3_NO_ERROR = 0x100, + GENERAL_PROTOCOL_ERROR = 0x101, + INTERNAL_ERROR = 0x102, + STREAM_CREATION_ERROR = 0x103, + CLOSED_CRITICAL_STREAM = 0x104, + FRAME_UNEXPECTED = 0x105, + FRAME_ERROR = 0x106, + EXCESSIVE_LOAD = 0x107, + ID_ERROR = 0x108, + SETTINGS_ERROR = 0x109, + MISSING_SETTINGS = 0x10A, + REQUEST_REJECTED = 0x10B, + REQUEST_CANCELLED = 0x10C, + REQUEST_INCOMPLETE = 0x10D, + MESSAGE_ERROR = 0x10E, + CONNECT_ERROR = 0x10F, + VERSION_FALLBACK = 0x110, +}; + +// Wire values for QPACK errors. +// https://www.rfc-editor.org/rfc/rfc9204.html#error-code-registration +enum class QuicHttpQpackErrorCode { + DECOMPRESSION_FAILED = 0x200, + ENCODER_STREAM_ERROR = 0x201, + DECODER_STREAM_ERROR = 0x202 +}; + // Represents a reason for resetting a stream in both gQUIC and IETF error code // space. Both error codes have to be present. class QUIC_EXPORT_PRIVATE QuicResetStreamError { @@ -637,6 +668,8 @@ // Constructs a QuicResetStreamError from an IETF error code; the internal // error code is inferred. static QuicResetStreamError FromIetf(uint64_t code); + static QuicResetStreamError FromIetf(QuicHttp3ErrorCode code); + static QuicResetStreamError FromIetf(QuicHttpQpackErrorCode code); // Constructs a QuicResetStreamError with no error. static QuicResetStreamError NoError() { return FromInternal(QUIC_STREAM_NO_ERROR); @@ -716,37 +749,6 @@ QUIC_EXPORT_PRIVATE QuicErrorCodeToIetfMapping QuicErrorCodeToTransportErrorCode(QuicErrorCode error); -// Wire values for HTTP/3 errors. -// https://quicwg.org/base-drafts/draft-ietf-quic-http.html#http-error-codes -enum class QuicHttp3ErrorCode { - // NO_ERROR is defined as a C preprocessor macro on Windows. - HTTP3_NO_ERROR = 0x100, - GENERAL_PROTOCOL_ERROR = 0x101, - INTERNAL_ERROR = 0x102, - STREAM_CREATION_ERROR = 0x103, - CLOSED_CRITICAL_STREAM = 0x104, - FRAME_UNEXPECTED = 0x105, - FRAME_ERROR = 0x106, - EXCESSIVE_LOAD = 0x107, - ID_ERROR = 0x108, - SETTINGS_ERROR = 0x109, - MISSING_SETTINGS = 0x10A, - REQUEST_REJECTED = 0x10B, - REQUEST_CANCELLED = 0x10C, - REQUEST_INCOMPLETE = 0x10D, - MESSAGE_ERROR = 0x10E, - CONNECT_ERROR = 0x10F, - VERSION_FALLBACK = 0x110, -}; - -// Wire values for QPACK errors. -// https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#error-code-registration -enum class QuicHttpQpackErrorCode { - DECOMPRESSION_FAILED = 0x200, - ENCODER_STREAM_ERROR = 0x201, - DECODER_STREAM_ERROR = 0x202 -}; - // Convert a QuicRstStreamErrorCode to an application error code to be used in // an IETF QUIC RESET_STREAM frame QUIC_EXPORT_PRIVATE uint64_t RstStreamErrorCodeToIetfResetStreamErrorCode(
diff --git a/quiche/quic/tools/connect_tunnel.cc b/quiche/quic/tools/connect_tunnel.cc new file mode 100644 index 0000000..bd0b66e --- /dev/null +++ b/quiche/quic/tools/connect_tunnel.cc
@@ -0,0 +1,337 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/connect_tunnel.h" + +#include <cstdint> +#include <limits> +#include <string> +#include <utility> +#include <vector> + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "url/third_party/mozilla/url_parse.h" +#include "quiche/quic/core/io/socket_factory.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/tools/quic_backend_response.h" +#include "quiche/quic/tools/quic_client.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +namespace { + +// Arbitrarily chosen. No effort has been made to figure out an optimal size. +constexpr size_t kReadSize = 4 * 1024; + +absl::optional<ConnectTunnel::HostAndPort> ValidateAndParseAuthorityString( + absl::string_view authority_string) { + url::Component username_component; + url::Component password_component; + url::Component host_component; + url::Component port_component; + + url::ParseAuthority(authority_string.data(), + url::Component(0, authority_string.size()), + &username_component, &password_component, &host_component, + &port_component); + + // A valid CONNECT authority must contain host and port and nothing else, per + // https://www.rfc-editor.org/rfc/rfc9110.html#name-connect. + if (username_component.is_valid() || password_component.is_valid() || + !host_component.is_nonempty() || !port_component.is_nonempty()) { + QUICHE_DVLOG(1) << "CONNECT request authority is malformed: " + << authority_string; + return absl::nullopt; + } + + QUICHE_DCHECK_LT(static_cast<size_t>(host_component.end()), + authority_string.length()); + if (authority_string.length() > 2 && + authority_string.data()[host_component.begin] == '[' && + authority_string.data()[host_component.end() - 1] == ']') { + // Strip "[]" off IPv6 literals. + host_component.begin += 1; + host_component.len -= 2; + } + std::string hostname(authority_string.data() + host_component.begin, + host_component.len); + + int parsed_port_number = + url::ParsePort(authority_string.data(), port_component); + // Negative result is either invalid or unspecified, either of which is + // disallowed for a CONNECT authority. Port 0 is technically valid but + // reserved and not really usable in practice, so easiest to just disallow it + // here. + if (parsed_port_number <= 0) { + QUICHE_DVLOG(1) << "CONNECT request authority port is malformed: " + << authority_string; + return absl::nullopt; + } + QUICHE_DCHECK_LE(parsed_port_number, std::numeric_limits<uint16_t>::max()); + + return std::make_pair(std::move(hostname), + static_cast<uint16_t>(parsed_port_number)); +} + +absl::optional<ConnectTunnel::HostAndPort> ValidateHeadersAndGetAuthority( + const spdy::Http2HeaderBlock& request_headers) { + QUICHE_DCHECK(request_headers.contains(":method")); + QUICHE_DCHECK(request_headers.find(":method")->second == "CONNECT"); + QUICHE_DCHECK(!request_headers.contains(":protocol")); + + auto scheme_it = request_headers.find(":scheme"); + if (scheme_it != request_headers.end()) { + QUICHE_DVLOG(1) << "CONNECT request contains unexpected scheme: " + << scheme_it->second; + return absl::nullopt; + } + + auto path_it = request_headers.find(":path"); + if (path_it != request_headers.end()) { + QUICHE_DVLOG(1) << "CONNECT request contains unexpected path: " + << path_it->second; + return absl::nullopt; + } + + auto authority_it = request_headers.find(":authority"); + if (authority_it == request_headers.end() || authority_it->second.empty()) { + QUICHE_DVLOG(1) << "CONNECT request missing authority"; + return absl::nullopt; + } + + return ValidateAndParseAuthorityString(authority_it->second); +} + +bool ValidateAuthority( + const ConnectTunnel::HostAndPort& authority, + const absl::flat_hash_set<std::pair<std::string, uint16_t>>& + acceptable_destinations) { + if (acceptable_destinations.contains(authority)) { + return true; + } + + QUICHE_DVLOG(1) << "CONNECT request authority: " + << absl::StrCat(authority.first, ":", authority.second) + << " is not an acceptable allow-listed destiation "; + return false; +} + +} // namespace + +ConnectTunnel::ConnectTunnel( + QuicSimpleServerBackend::RequestHandler* client_stream_request_handler, + SocketFactory* socket_factory, + absl::flat_hash_set<HostAndPort> acceptable_destinations) + : acceptable_destinations_(std::move(acceptable_destinations)), + socket_factory_(socket_factory), + client_stream_request_handler_(client_stream_request_handler) { + QUICHE_DCHECK(client_stream_request_handler_); + QUICHE_DCHECK(socket_factory_); +} + +ConnectTunnel::~ConnectTunnel() { + // Expect client and destination sides of tunnel to both be closed before + // destruction. + QUICHE_DCHECK_EQ(client_stream_request_handler_, nullptr); + QUICHE_DCHECK(!IsConnectedToDestination()); + QUICHE_DCHECK(!receive_started_); +} + +void ConnectTunnel::OpenTunnel(const spdy::Http2HeaderBlock& request_headers) { + QUICHE_DCHECK(!IsConnectedToDestination()); + + absl::optional<HostAndPort> authority = + ValidateHeadersAndGetAuthority(request_headers); + if (!authority) { + TerminateClientStream( + "invalid request headers", + QuicResetStreamError::FromIetf(QuicHttp3ErrorCode::MESSAGE_ERROR)); + return; + } + + if (!ValidateAuthority(authority.value(), acceptable_destinations_)) { + TerminateClientStream( + "disallowed request authority", + QuicResetStreamError::FromIetf(QuicHttp3ErrorCode::REQUEST_REJECTED)); + return; + } + + QuicSocketAddress address = + tools::LookupAddress(AF_UNSPEC, authority.value().first, + absl::StrCat(authority.value().second)); + if (!address.IsInitialized()) { + TerminateClientStream("host resolution error"); + return; + } + + destination_socket_ = + socket_factory_->CreateTcpClientSocket(address, + /*receive_buffer_size=*/0, + /*send_buffer_size=*/0, + /*async_visitor=*/this); + QUICHE_DCHECK(destination_socket_); + + absl::Status connect_result = destination_socket_->ConnectBlocking(); + if (!connect_result.ok()) { + TerminateClientStream( + "error connecting TCP socket to destination server: " + + connect_result.ToString()); + return; + } + + QUICHE_DVLOG(1) << "CONNECT tunnel opened from stream " + << client_stream_request_handler_->stream_id() << " to " + << authority.value().first << ":" << authority.value().second; + + SendConnectResponse(); + BeginAsyncReadFromDestination(); +} + +bool ConnectTunnel::IsConnectedToDestination() const { + return !!destination_socket_; +} + +void ConnectTunnel::SendDataToDestination(absl::string_view data) { + QUICHE_DCHECK(IsConnectedToDestination()); + QUICHE_DCHECK(!data.empty()); + + absl::Status send_result = + destination_socket_->SendBlocking(std::string(data)); + if (!send_result.ok()) { + TerminateClientStream("TCP error sending data to destination server: " + + send_result.ToString()); + } +} + +void ConnectTunnel::OnClientStreamClose() { + QUICHE_DCHECK(client_stream_request_handler_); + + QUICHE_DVLOG(1) << "CONNECT stream " + << client_stream_request_handler_->stream_id() << " closed"; + + client_stream_request_handler_ = nullptr; + + if (IsConnectedToDestination()) { + // TODO(ericorth): Consider just calling shutdown() on the socket rather + // than fully disconnecting in order to allow a graceful TCP FIN stream + // shutdown per + // https://www.rfc-editor.org/rfc/rfc9114.html#name-the-connect-method. + // Would require shutdown support in the socket library, and would need to + // deal with the tunnel/socket outliving the client stream. + destination_socket_->Disconnect(); + } + + // Clear socket pointer. + destination_socket_.reset(); +} + +void ConnectTunnel::ConnectComplete(absl::Status /*status*/) { + // Async connect not expected. + QUICHE_NOTREACHED(); +} + +void ConnectTunnel::ReceiveComplete( + absl::StatusOr<quiche::QuicheMemSlice> data) { + QUICHE_DCHECK(IsConnectedToDestination()); + QUICHE_DCHECK(receive_started_); + + receive_started_ = false; + + if (!data.ok()) { + if (client_stream_request_handler_) { + TerminateClientStream("TCP error receiving data from destination server"); + } else { + // This typically just means a receive operation was cancelled on calling + // destination_socket_->Disconnect(). + QUICHE_DVLOG(1) << "TCP error receiving data from destination server " + "after stream already closed."; + } + return; + } else if (data.value().empty()) { + OnDestinationConnectionClosed(); + return; + } + + QUICHE_DCHECK(client_stream_request_handler_); + client_stream_request_handler_->SendStreamData(data.value().AsStringView(), + /*close_stream=*/false); + + BeginAsyncReadFromDestination(); +} + +void ConnectTunnel::SendComplete(absl::Status /*status*/) { + // Async send not expected. + QUICHE_NOTREACHED(); +} + +void ConnectTunnel::BeginAsyncReadFromDestination() { + QUICHE_DCHECK(IsConnectedToDestination()); + QUICHE_DCHECK(client_stream_request_handler_); + QUICHE_DCHECK(!receive_started_); + + receive_started_ = true; + destination_socket_->ReceiveAsync(kReadSize); +} + +void ConnectTunnel::OnDestinationConnectionClosed() { + QUICHE_DCHECK(IsConnectedToDestination()); + QUICHE_DCHECK(client_stream_request_handler_); + + QUICHE_DVLOG(1) << "CONNECT stream " + << client_stream_request_handler_->stream_id() + << " destination connection closed"; + destination_socket_->Disconnect(); + + // Clear socket pointer. + destination_socket_.reset(); + + // Extra check that nothing in the Disconnect could lead to terminating the + // stream. + QUICHE_DCHECK(client_stream_request_handler_); + + client_stream_request_handler_->SendStreamData("", /*close_stream=*/true); +} + +void ConnectTunnel::SendConnectResponse() { + QUICHE_DCHECK(IsConnectedToDestination()); + QUICHE_DCHECK(client_stream_request_handler_); + + spdy::Http2HeaderBlock response_headers; + response_headers[":status"] = "200"; + + QuicBackendResponse response; + response.set_headers(std::move(response_headers)); + // Need to leave the stream open after sending the CONNECT response. + response.set_response_type(QuicBackendResponse::INCOMPLETE_RESPONSE); + + client_stream_request_handler_->OnResponseBackendComplete(&response); +} + +void ConnectTunnel::TerminateClientStream(absl::string_view error_description, + QuicResetStreamError error_code) { + QUICHE_DCHECK(client_stream_request_handler_); + + std::string error_description_str = + error_description.empty() ? "" + : absl::StrCat(" due to ", error_description); + QUICHE_DVLOG(1) << "Terminating CONNECT stream " + << client_stream_request_handler_->stream_id() + << " with error code " << error_code.ietf_application_code() + << error_description_str; + + client_stream_request_handler_->TerminateStreamWithError(error_code); +} + +} // namespace quic
diff --git a/quiche/quic/tools/connect_tunnel.h b/quiche/quic/tools/connect_tunnel.h new file mode 100644 index 0000000..f638324 --- /dev/null +++ b/quiche/quic/tools/connect_tunnel.h
@@ -0,0 +1,90 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_CONNECT_TUNNEL_H_ +#define QUICHE_QUIC_TOOLS_CONNECT_TUNNEL_H_ + +#include <cstdint> +#include <memory> +#include <string> +#include <utility> + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/io/socket_factory.h" +#include "quiche/quic/core/io/stream_client_socket.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +// Manages a single connection tunneled over a CONNECT proxy. +class ConnectTunnel : public StreamClientSocket::AsyncVisitor { + public: + using HostAndPort = std::pair<std::string, uint16_t>; + + // `client_stream_request_handler` and `socket_factory` must both outlive the + // created ConnectTunnel. + ConnectTunnel( + QuicSimpleServerBackend::RequestHandler* client_stream_request_handler, + SocketFactory* socket_factory, + absl::flat_hash_set<HostAndPort> acceptable_destinations); + ~ConnectTunnel(); + ConnectTunnel(const ConnectTunnel&) = delete; + ConnectTunnel& operator=(const ConnectTunnel&) = delete; + + // Attempts to open TCP connection to destination server and then sends + // appropriate success/error response to the request stream. `request_headers` + // must represent headers from a CONNECT request, that is ":method"="CONNECT" + // and no ":protocol". + void OpenTunnel(const spdy::Http2HeaderBlock& request_headers); + + // Returns true iff the connection to the destination server is currently open + bool IsConnectedToDestination() const; + + void SendDataToDestination(absl::string_view data); + + // Called when the client stream has been closed. Connection to destination + // server is closed if connected. The RequestHandler will no longer be + // interacted with after completion. + void OnClientStreamClose(); + + // StreamClientSocket::AsyncVisitor: + void ConnectComplete(absl::Status status) override; + void ReceiveComplete(absl::StatusOr<quiche::QuicheMemSlice> data) override; + void SendComplete(absl::Status status) override; + + private: + void BeginAsyncReadFromDestination(); + void OnDataReceivedFromDestination(bool success); + + // For normal (FIN) closure. Errors (RST) should result in directly calling + // TerminateClientStream(). + void OnDestinationConnectionClosed(); + + void SendConnectResponse(); + void TerminateClientStream( + absl::string_view error_description, + QuicResetStreamError error_code = + QuicResetStreamError::FromIetf(QuicHttp3ErrorCode::CONNECT_ERROR)); + + const absl::flat_hash_set<HostAndPort> acceptable_destinations_; + SocketFactory* const socket_factory_; + + // Null when client stream closed. + QuicSimpleServerBackend::RequestHandler* client_stream_request_handler_; + + // Null when destination connection disconnected. + std::unique_ptr<StreamClientSocket> destination_socket_; + + bool receive_started_ = false; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_CONNECT_TUNNEL_H_
diff --git a/quiche/quic/tools/connect_tunnel_test.cc b/quiche/quic/tools/connect_tunnel_test.cc new file mode 100644 index 0000000..a9d4c1d --- /dev/null +++ b/quiche/quic/tools/connect_tunnel_test.cc
@@ -0,0 +1,342 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/connect_tunnel.h" + +#include <cstdint> +#include <utility> + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/io/socket_factory.h" +#include "quiche/quic/core/io/stream_client_socket.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test_loopback.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/tools/quic_backend_response.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic::test { +namespace { + +using ::testing::_; +using ::testing::AllOf; +using ::testing::AnyOf; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Ge; +using ::testing::Gt; +using ::testing::InvokeWithoutArgs; +using ::testing::IsEmpty; +using ::testing::Matcher; +using ::testing::NiceMock; +using ::testing::Pair; +using ::testing::Property; +using ::testing::Return; +using ::testing::StrictMock; + +class MockRequestHandler : public QuicSimpleServerBackend::RequestHandler { + public: + QuicConnectionId connection_id() const override { + return TestConnectionId(41212); + } + QuicStreamId stream_id() const override { return 100; } + std::string peer_host() const override { return "127.0.0.1"; } + + MOCK_METHOD(void, OnResponseBackendComplete, + (const QuicBackendResponse* response), (override)); + MOCK_METHOD(void, SendStreamData, (absl::string_view data, bool close_stream), + (override)); + MOCK_METHOD(void, TerminateStreamWithError, (QuicResetStreamError error), + (override)); +}; + +class MockSocketFactory : public SocketFactory { + public: + MOCK_METHOD(std::unique_ptr<StreamClientSocket>, CreateTcpClientSocket, + (const quic::QuicSocketAddress& peer_address, + QuicByteCount receive_buffer_size, + QuicByteCount send_buffer_size, + StreamClientSocket::AsyncVisitor* async_visitor), + (override)); +}; + +class MockSocket : public StreamClientSocket { + public: + MOCK_METHOD(absl::Status, ConnectBlocking, (), (override)); + MOCK_METHOD(void, ConnectAsync, (), (override)); + MOCK_METHOD(void, Disconnect, (), (override)); + MOCK_METHOD(absl::StatusOr<quiche::QuicheMemSlice>, ReceiveBlocking, + (QuicByteCount max_size), (override)); + MOCK_METHOD(void, ReceiveAsync, (QuicByteCount max_size), (override)); + MOCK_METHOD(absl::Status, SendBlocking, (std::string data), (override)); + MOCK_METHOD(absl::Status, SendBlocking, (quiche::QuicheMemSlice data), + (override)); + MOCK_METHOD(void, SendAsync, (std::string data), (override)); + MOCK_METHOD(void, SendAsync, (quiche::QuicheMemSlice data), (override)); +}; + +class ConnectTunnelTest : public quiche::test::QuicheTest { + public: + void SetUp() override { + auto socket = std::make_unique<StrictMock<MockSocket>>(); + socket_ = socket.get(); + ON_CALL(socket_factory_, + CreateTcpClientSocket( + AnyOf(QuicSocketAddress(TestLoopback4(), kAcceptablePort), + QuicSocketAddress(TestLoopback6(), kAcceptablePort)), + _, _, &tunnel_)) + .WillByDefault(Return(ByMove(std::move(socket)))); + } + + protected: + static constexpr absl::string_view kAcceptableDestination = "localhost"; + static constexpr uint16_t kAcceptablePort = 977; + + StrictMock<MockRequestHandler> request_handler_; + NiceMock<MockSocketFactory> socket_factory_; + StrictMock<MockSocket>* socket_; + + ConnectTunnel tunnel_{&request_handler_, + &socket_factory_, + /*acceptable_destinations=*/ + {{std::string(kAcceptableDestination), kAcceptablePort}, + {TestLoopback4().ToString(), kAcceptablePort}, + {TestLoopback6().ToString(), kAcceptablePort}}}; +}; + +TEST_F(ConnectTunnelTest, OpenTunnel) { + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + spdy::Http2HeaderBlock expected_response_headers; + expected_response_headers[":status"] = "200"; + QuicBackendResponse expected_response; + expected_response.set_headers(std::move(expected_response_headers)); + expected_response.set_response_type(QuicBackendResponse::INCOMPLETE_RESPONSE); + EXPECT_CALL(request_handler_, + OnResponseBackendComplete( + AllOf(Property(&QuicBackendResponse::response_type, + QuicBackendResponse::INCOMPLETE_RESPONSE), + Property(&QuicBackendResponse::headers, + ElementsAre(Pair(":status", "200"))), + Property(&QuicBackendResponse::trailers, IsEmpty()), + Property(&QuicBackendResponse::body, IsEmpty())))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":authority"] = + absl::StrCat(kAcceptableDestination, ":", kAcceptablePort); + + tunnel_.OpenTunnel(request_headers); + EXPECT_TRUE(tunnel_.IsConnectedToDestination()); + tunnel_.OnClientStreamClose(); + EXPECT_FALSE(tunnel_.IsConnectedToDestination()); +} + +TEST_F(ConnectTunnelTest, OpenTunnelToIpv4LiteralDestination) { + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + spdy::Http2HeaderBlock expected_response_headers; + expected_response_headers[":status"] = "200"; + QuicBackendResponse expected_response; + expected_response.set_headers(std::move(expected_response_headers)); + expected_response.set_response_type(QuicBackendResponse::INCOMPLETE_RESPONSE); + EXPECT_CALL(request_handler_, + OnResponseBackendComplete( + AllOf(Property(&QuicBackendResponse::response_type, + QuicBackendResponse::INCOMPLETE_RESPONSE), + Property(&QuicBackendResponse::headers, + ElementsAre(Pair(":status", "200"))), + Property(&QuicBackendResponse::trailers, IsEmpty()), + Property(&QuicBackendResponse::body, IsEmpty())))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":authority"] = + absl::StrCat(TestLoopback4().ToString(), ":", kAcceptablePort); + + tunnel_.OpenTunnel(request_headers); + EXPECT_TRUE(tunnel_.IsConnectedToDestination()); + tunnel_.OnClientStreamClose(); + EXPECT_FALSE(tunnel_.IsConnectedToDestination()); +} + +TEST_F(ConnectTunnelTest, OpenTunnelToIpv6LiteralDestination) { + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + spdy::Http2HeaderBlock expected_response_headers; + expected_response_headers[":status"] = "200"; + QuicBackendResponse expected_response; + expected_response.set_headers(std::move(expected_response_headers)); + expected_response.set_response_type(QuicBackendResponse::INCOMPLETE_RESPONSE); + EXPECT_CALL(request_handler_, + OnResponseBackendComplete( + AllOf(Property(&QuicBackendResponse::response_type, + QuicBackendResponse::INCOMPLETE_RESPONSE), + Property(&QuicBackendResponse::headers, + ElementsAre(Pair(":status", "200"))), + Property(&QuicBackendResponse::trailers, IsEmpty()), + Property(&QuicBackendResponse::body, IsEmpty())))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":authority"] = + absl::StrCat("[", TestLoopback6().ToString(), "]:", kAcceptablePort); + + tunnel_.OpenTunnel(request_headers); + EXPECT_TRUE(tunnel_.IsConnectedToDestination()); + tunnel_.OnClientStreamClose(); + EXPECT_FALSE(tunnel_.IsConnectedToDestination()); +} + +TEST_F(ConnectTunnelTest, OpenTunnelWithMalformedRequest) { + EXPECT_CALL(request_handler_, + TerminateStreamWithError(Property( + &QuicResetStreamError::ietf_application_code, + static_cast<uint64_t>(QuicHttp3ErrorCode::MESSAGE_ERROR)))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + // No ":authority" header. + + tunnel_.OpenTunnel(request_headers); + EXPECT_FALSE(tunnel_.IsConnectedToDestination()); + tunnel_.OnClientStreamClose(); +} + +TEST_F(ConnectTunnelTest, OpenTunnelWithUnacceptableDestination) { + EXPECT_CALL( + request_handler_, + TerminateStreamWithError(Property( + &QuicResetStreamError::ietf_application_code, + static_cast<uint64_t>(QuicHttp3ErrorCode::REQUEST_REJECTED)))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":authority"] = "unacceptable.test:100"; + + tunnel_.OpenTunnel(request_headers); + EXPECT_FALSE(tunnel_.IsConnectedToDestination()); + tunnel_.OnClientStreamClose(); +} + +TEST_F(ConnectTunnelTest, ReceiveFromDestination) { + static constexpr absl::string_view kData = "\x11\x22\x33\x44\x55"; + + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Ge(kData.size()))).Times(2); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + EXPECT_CALL(request_handler_, OnResponseBackendComplete(_)); + + EXPECT_CALL(request_handler_, SendStreamData(kData, /*close_stream=*/false)); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":authority"] = + absl::StrCat(kAcceptableDestination, ":", kAcceptablePort); + + tunnel_.OpenTunnel(request_headers); + + // Simulate receiving `kData`. + tunnel_.ReceiveComplete(MemSliceFromString(kData)); + + tunnel_.OnClientStreamClose(); +} + +TEST_F(ConnectTunnelTest, SendToDestination) { + static constexpr absl::string_view kData = "\x11\x22\x33\x44\x55"; + + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, SendBlocking(Matcher<std::string>(Eq(kData)))) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + EXPECT_CALL(request_handler_, OnResponseBackendComplete(_)); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":authority"] = + absl::StrCat(kAcceptableDestination, ":", kAcceptablePort); + + tunnel_.OpenTunnel(request_headers); + tunnel_.SendDataToDestination(kData); + tunnel_.OnClientStreamClose(); +} + +TEST_F(ConnectTunnelTest, DestinationDisconnect) { + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, Disconnect()); + + EXPECT_CALL(request_handler_, OnResponseBackendComplete(_)); + EXPECT_CALL(request_handler_, SendStreamData("", /*close_stream=*/true)); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":authority"] = + absl::StrCat(kAcceptableDestination, ":", kAcceptablePort); + + tunnel_.OpenTunnel(request_headers); + + // Simulate receiving empty data. + tunnel_.ReceiveComplete(quiche::QuicheMemSlice()); + + EXPECT_FALSE(tunnel_.IsConnectedToDestination()); + + tunnel_.OnClientStreamClose(); +} + +TEST_F(ConnectTunnelTest, DestinationTcpConnectionError) { + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, Disconnect()); + + EXPECT_CALL(request_handler_, OnResponseBackendComplete(_)); + EXPECT_CALL(request_handler_, + TerminateStreamWithError(Property( + &QuicResetStreamError::ietf_application_code, + static_cast<uint64_t>(QuicHttp3ErrorCode::CONNECT_ERROR)))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":authority"] = + absl::StrCat(kAcceptableDestination, ":", kAcceptablePort); + + tunnel_.OpenTunnel(request_headers); + + // Simulate receving error. + tunnel_.ReceiveComplete(absl::UnknownError("error")); + + tunnel_.OnClientStreamClose(); +} + +} // namespace +} // namespace quic::test