Create QUICHE toy CONNECT proxy server

PiperOrigin-RevId: 466743774
diff --git a/build/source_list.bzl b/build/source_list.bzl
index 313dea4..90dd052 100644
--- a/build/source_list.bzl
+++ b/build/source_list.bzl
@@ -693,6 +693,7 @@
     "common/platform/api/quiche_file_utils.h",
     "common/platform/api/quiche_system_event_loop.h",
     "quic/platform/api/quic_default_proof_providers.h",
+    "quic/tools/connect_server_backend.h",
     "quic/tools/connect_tunnel.h",
     "quic/tools/fake_proof_verifier.h",
     "quic/tools/quic_backend_response.h",
@@ -717,6 +718,7 @@
 ]
 quiche_tool_support_srcs = [
     "common/platform/api/quiche_file_utils.cc",
+    "quic/tools/connect_server_backend.cc",
     "quic/tools/connect_tunnel.cc",
     "quic/tools/quic_backend_response.cc",
     "quic/tools/quic_client_base.cc",
diff --git a/build/source_list.gni b/build/source_list.gni
index 0c08a73..842512d 100644
--- a/build/source_list.gni
+++ b/build/source_list.gni
@@ -693,6 +693,7 @@
     "src/quiche/common/platform/api/quiche_file_utils.h",
     "src/quiche/common/platform/api/quiche_system_event_loop.h",
     "src/quiche/quic/platform/api/quic_default_proof_providers.h",
+    "src/quiche/quic/tools/connect_server_backend.h",
     "src/quiche/quic/tools/connect_tunnel.h",
     "src/quiche/quic/tools/fake_proof_verifier.h",
     "src/quiche/quic/tools/quic_backend_response.h",
@@ -717,6 +718,7 @@
 ]
 quiche_tool_support_srcs = [
     "src/quiche/common/platform/api/quiche_file_utils.cc",
+    "src/quiche/quic/tools/connect_server_backend.cc",
     "src/quiche/quic/tools/connect_tunnel.cc",
     "src/quiche/quic/tools/quic_backend_response.cc",
     "src/quiche/quic/tools/quic_client_base.cc",
diff --git a/build/source_list.json b/build/source_list.json
index a7fdf5b..6d14856 100644
--- a/build/source_list.json
+++ b/build/source_list.json
@@ -692,6 +692,7 @@
     "quiche/common/platform/api/quiche_file_utils.h",
     "quiche/common/platform/api/quiche_system_event_loop.h",
     "quiche/quic/platform/api/quic_default_proof_providers.h",
+    "quiche/quic/tools/connect_server_backend.h",
     "quiche/quic/tools/connect_tunnel.h",
     "quiche/quic/tools/fake_proof_verifier.h",
     "quiche/quic/tools/quic_backend_response.h",
@@ -716,6 +717,7 @@
   ],
   "quiche_tool_support_srcs": [
     "quiche/common/platform/api/quiche_file_utils.cc",
+    "quiche/quic/tools/connect_server_backend.cc",
     "quiche/quic/tools/connect_tunnel.cc",
     "quiche/quic/tools/quic_backend_response.cc",
     "quiche/quic/tools/quic_client_base.cc",
diff --git a/quiche/quic/tools/connect_server_backend.cc b/quiche/quic/tools/connect_server_backend.cc
new file mode 100644
index 0000000..baee9c1
--- /dev/null
+++ b/quiche/quic/tools/connect_server_backend.cc
@@ -0,0 +1,133 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "quiche/quic/tools/connect_server_backend.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "quiche/quic/core/io/socket_factory.h"
+#include "quiche/quic/tools/connect_tunnel.h"
+#include "quiche/quic/tools/quic_simple_server_backend.h"
+#include "quiche/common/platform/api/quiche_bug_tracker.h"
+#include "quiche/common/platform/api/quiche_logging.h"
+#include "quiche/spdy/core/http2_header_block.h"
+
+namespace quic {
+
+namespace {
+
+void SendErrorResponse(QuicSimpleServerBackend::RequestHandler* request_handler,
+                       absl::string_view error_code) {
+  spdy::Http2HeaderBlock headers;
+  headers[":status"] = error_code;
+  QuicBackendResponse response;
+  response.set_headers(std::move(headers));
+  request_handler->OnResponseBackendComplete(&response);
+}
+
+}  // namespace
+
+ConnectServerBackend::ConnectServerBackend(
+    std::unique_ptr<QuicSimpleServerBackend> non_connect_backend,
+    absl::flat_hash_set<ConnectTunnel::HostAndPort> acceptable_destinations)
+    : non_connect_backend_(std::move(non_connect_backend)),
+      acceptable_destinations_(std::move(acceptable_destinations)) {
+  QUICHE_DCHECK(non_connect_backend_);
+}
+
+ConnectServerBackend::~ConnectServerBackend() {
+  // Expect all streams to be closed before destroying backend.
+  QUICHE_DCHECK(tunnels_.empty());
+}
+
+bool ConnectServerBackend::InitializeBackend(const std::string&) {
+  return true;
+}
+
+bool ConnectServerBackend::IsBackendInitialized() const { return true; }
+
+void ConnectServerBackend::SetSocketFactory(SocketFactory* socket_factory) {
+  QUICHE_DCHECK_NE(socket_factory_, socket_factory);
+  QUICHE_DCHECK(tunnels_.empty());
+  socket_factory_ = socket_factory;
+}
+
+void ConnectServerBackend::FetchResponseFromBackend(
+    const spdy::Http2HeaderBlock& request_headers,
+    const std::string& request_body, RequestHandler* request_handler) {
+  // Not a CONNECT request, so send to `non_connect_backend_`.
+  non_connect_backend_->FetchResponseFromBackend(request_headers, request_body,
+                                                 request_handler);
+}
+
+void ConnectServerBackend::HandleConnectHeaders(
+    const spdy::Http2HeaderBlock& request_headers,
+    RequestHandler* request_handler) {
+  QUICHE_DCHECK(request_headers.contains(":method") &&
+                request_headers.find(":method")->second == "CONNECT");
+
+  if (!socket_factory_) {
+    QUICHE_BUG(connect_server_backend_no_socket_factory)
+        << "Must set socket factory before ConnectServerBackend receives "
+           "requests.";
+    SendErrorResponse(request_handler, "500");
+    return;
+  }
+
+  if (request_headers.contains(":protocol")) {
+    // Anything other than normal CONNECT not supported.
+    // TODO(ericorth): Add CONNECT-UDP support.
+    non_connect_backend_->HandleConnectHeaders(request_headers,
+                                               request_handler);
+    return;
+  }
+
+  auto [tunnel_it, inserted] = tunnels_.emplace(
+      request_handler->stream_id(),
+      std::make_unique<ConnectTunnel>(request_handler, socket_factory_,
+                                      acceptable_destinations_));
+  QUICHE_DCHECK(inserted);
+
+  tunnel_it->second->OpenTunnel(request_headers);
+}
+
+void ConnectServerBackend::HandleConnectData(absl::string_view data,
+                                             bool data_complete,
+                                             RequestHandler* request_handler) {
+  auto tunnel_it = tunnels_.find(request_handler->stream_id());
+  if (tunnel_it == tunnels_.end()) {
+    // If tunnel not found, perhaps it's something being handled for
+    // non-CONNECT. Possible because this method could be called for anything
+    // with a ":method":"CONNECT" header, but this class does not handle such
+    // requests if they have a ":protocol" header.
+    non_connect_backend_->HandleConnectData(data, data_complete,
+                                            request_handler);
+    return;
+  }
+
+  if (!data.empty()) {
+    tunnel_it->second->SendDataToDestination(data);
+  }
+  if (data_complete) {
+    tunnel_it->second->OnClientStreamClose();
+    tunnels_.erase(tunnel_it);
+  }
+}
+
+void ConnectServerBackend::CloseBackendResponseStream(
+    QuicSimpleServerBackend::RequestHandler* request_handler) {
+  auto tunnel_it = tunnels_.find(request_handler->stream_id());
+  if (tunnel_it != tunnels_.end()) {
+    tunnel_it->second->OnClientStreamClose();
+    tunnels_.erase(tunnel_it);
+  }
+
+  non_connect_backend_->CloseBackendResponseStream(request_handler);
+}
+
+}  // namespace quic
diff --git a/quiche/quic/tools/connect_server_backend.h b/quiche/quic/tools/connect_server_backend.h
new file mode 100644
index 0000000..a1cd843
--- /dev/null
+++ b/quiche/quic/tools/connect_server_backend.h
@@ -0,0 +1,60 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef QUICHE_QUIC_CONNECT_PROXY_CONNECT_SERVER_BACKEND_H_
+#define QUICHE_QUIC_CONNECT_PROXY_CONNECT_SERVER_BACKEND_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "quiche/quic/core/io/socket_factory.h"
+#include "quiche/quic/core/quic_types.h"
+#include "quiche/quic/tools/connect_tunnel.h"
+#include "quiche/quic/tools/quic_simple_server_backend.h"
+
+namespace quic {
+
+// QUIC server backend that handles CONNECT requests. Non-CONNECT requests are
+// delegated to a separate backend.
+class ConnectServerBackend : public QuicSimpleServerBackend {
+ public:
+  ConnectServerBackend(
+      std::unique_ptr<QuicSimpleServerBackend> non_connect_backend,
+      absl::flat_hash_set<ConnectTunnel::HostAndPort> acceptable_destinations);
+
+  ConnectServerBackend(const ConnectServerBackend&) = delete;
+  ConnectServerBackend& operator=(const ConnectServerBackend&) = delete;
+
+  ~ConnectServerBackend() override;
+
+  // QuicSimpleServerBackend:
+  bool InitializeBackend(const std::string& backend_url) override;
+  bool IsBackendInitialized() const override;
+  void SetSocketFactory(SocketFactory* socket_factory) override;
+  void FetchResponseFromBackend(const spdy::Http2HeaderBlock& request_headers,
+                                const std::string& request_body,
+                                RequestHandler* request_handler) override;
+  void HandleConnectHeaders(const spdy::Http2HeaderBlock& request_headers,
+                            RequestHandler* request_handler) override;
+  void HandleConnectData(absl::string_view data, bool data_complete,
+                         RequestHandler* request_handler) override;
+  void CloseBackendResponseStream(
+      QuicSimpleServerBackend::RequestHandler* request_handler) override;
+
+ private:
+  std::unique_ptr<QuicSimpleServerBackend> non_connect_backend_;
+  const absl::flat_hash_set<ConnectTunnel::HostAndPort>
+      acceptable_destinations_;
+
+  SocketFactory* socket_factory_;  // unowned
+  absl::flat_hash_map<QuicStreamId, std::unique_ptr<ConnectTunnel>> tunnels_;
+};
+
+}  // namespace quic
+
+#endif  // QUICHE_QUIC_CONNECT_PROXY_CONNECT_SERVER_BACKEND_H_
diff --git a/quiche/quic/tools/connect_tunnel.cc b/quiche/quic/tools/connect_tunnel.cc
index bd0b66e..ef2a5eb 100644
--- a/quiche/quic/tools/connect_tunnel.cc
+++ b/quiche/quic/tools/connect_tunnel.cc
@@ -81,8 +81,8 @@
   }
   QUICHE_DCHECK_LE(parsed_port_number, std::numeric_limits<uint16_t>::max());
 
-  return std::make_pair(std::move(hostname),
-                        static_cast<uint16_t>(parsed_port_number));
+  return ConnectTunnel::HostAndPort(std::move(hostname),
+                                    static_cast<uint16_t>(parsed_port_number));
 }
 
 absl::optional<ConnectTunnel::HostAndPort> ValidateHeadersAndGetAuthority(
@@ -114,22 +114,28 @@
   return ValidateAndParseAuthorityString(authority_it->second);
 }
 
-bool ValidateAuthority(
-    const ConnectTunnel::HostAndPort& authority,
-    const absl::flat_hash_set<std::pair<std::string, uint16_t>>&
-        acceptable_destinations) {
+bool ValidateAuthority(const ConnectTunnel::HostAndPort& authority,
+                       const absl::flat_hash_set<ConnectTunnel::HostAndPort>&
+                           acceptable_destinations) {
   if (acceptable_destinations.contains(authority)) {
     return true;
   }
 
   QUICHE_DVLOG(1) << "CONNECT request authority: "
-                  << absl::StrCat(authority.first, ":", authority.second)
+                  << absl::StrCat(authority.host, ":", authority.port)
                   << " is not an acceptable allow-listed destiation ";
   return false;
 }
 
 }  // namespace
 
+ConnectTunnel::HostAndPort::HostAndPort(std::string host, uint16_t port)
+    : host(std::move(host)), port(port) {}
+
+bool ConnectTunnel::HostAndPort::operator==(const HostAndPort& other) const {
+  return host == other.host && port == other.port;
+}
+
 ConnectTunnel::ConnectTunnel(
     QuicSimpleServerBackend::RequestHandler* client_stream_request_handler,
     SocketFactory* socket_factory,
@@ -168,9 +174,8 @@
     return;
   }
 
-  QuicSocketAddress address =
-      tools::LookupAddress(AF_UNSPEC, authority.value().first,
-                           absl::StrCat(authority.value().second));
+  QuicSocketAddress address = tools::LookupAddress(
+      AF_UNSPEC, authority->host, absl::StrCat(authority->port));
   if (!address.IsInitialized()) {
     TerminateClientStream("host resolution error");
     return;
@@ -193,7 +198,7 @@
 
   QUICHE_DVLOG(1) << "CONNECT tunnel opened from stream "
                   << client_stream_request_handler_->stream_id() << " to "
-                  << authority.value().first << ":" << authority.value().second;
+                  << authority->host << ":" << authority->port;
 
   SendConnectResponse();
   BeginAsyncReadFromDestination();
diff --git a/quiche/quic/tools/connect_tunnel.h b/quiche/quic/tools/connect_tunnel.h
index f638324..d18d63f 100644
--- a/quiche/quic/tools/connect_tunnel.h
+++ b/quiche/quic/tools/connect_tunnel.h
@@ -26,7 +26,19 @@
 // Manages a single connection tunneled over a CONNECT proxy.
 class ConnectTunnel : public StreamClientSocket::AsyncVisitor {
  public:
-  using HostAndPort = std::pair<std::string, uint16_t>;
+  struct HostAndPort {
+    HostAndPort(std::string host, uint16_t port);
+
+    bool operator==(const HostAndPort& other) const;
+
+    template <typename H>
+    friend H AbslHashValue(H h, const HostAndPort& host_and_port) {
+      return H::combine(std::move(h), host_and_port.host, host_and_port.port);
+    }
+
+    std::string host;
+    uint16_t port;
+  };
 
   // `client_stream_request_handler` and `socket_factory` must both outlive the
   // created ConnectTunnel.
diff --git a/quiche/quic/tools/quic_server.cc b/quiche/quic/tools/quic_server.cc
index 53985e7..0c3233a 100644
--- a/quiche/quic/tools/quic_server.cc
+++ b/quiche/quic/tools/quic_server.cc
@@ -16,6 +16,7 @@
 
 #include "quiche/quic/core/crypto/crypto_handshake.h"
 #include "quiche/quic/core/crypto/quic_random.h"
+#include "quiche/quic/core/io/event_loop_socket_factory.h"
 #include "quiche/quic/core/io/quic_default_event_loop.h"
 #include "quiche/quic/core/io/quic_event_loop.h"
 #include "quiche/quic/core/quic_clock.h"
@@ -32,6 +33,7 @@
 #include "quiche/quic/tools/quic_simple_crypto_server_stream_helper.h"
 #include "quiche/quic/tools/quic_simple_dispatcher.h"
 #include "quiche/quic/tools/quic_simple_server_backend.h"
+#include "quiche/common/simple_buffer_allocator.h"
 
 namespace quic {
 
@@ -103,11 +105,20 @@
 QuicServer::~QuicServer() {
   close(fd_);
   fd_ = -1;
+
+  // Should be fine without because nothing should send requests to the backend
+  // after `this` is destroyed, but for extra pointer safety, clear the socket
+  // factory from the backend before the socket factory is destroyed.
+  quic_simple_server_backend_->SetSocketFactory(nullptr);
 }
 
 bool QuicServer::CreateUDPSocketAndListen(const QuicSocketAddress& address) {
   event_loop_ = CreateEventLoop();
 
+  socket_factory_ = std::make_unique<EventLoopSocketFactory>(
+      event_loop_.get(), quiche::SimpleBufferAllocator::Get());
+  quic_simple_server_backend_->SetSocketFactory(socket_factory_.get());
+
   QuicUdpSocketApi socket_api;
   fd_ = socket_api.Create(address.host().AddressFamilyToInt(),
                           /*receive_buffer_size =*/kDefaultSocketReceiveBuffer,
diff --git a/quiche/quic/tools/quic_server.h b/quiche/quic/tools/quic_server.h
index 7cb870f..d080cb1 100644
--- a/quiche/quic/tools/quic_server.h
+++ b/quiche/quic/tools/quic_server.h
@@ -16,6 +16,7 @@
 #include "absl/strings/string_view.h"
 #include "quiche/quic/core/crypto/quic_crypto_server_config.h"
 #include "quiche/quic/core/io/quic_event_loop.h"
+#include "quiche/quic/core/io/socket_factory.h"
 #include "quiche/quic/core/quic_config.h"
 #include "quiche/quic/core/quic_packet_writer.h"
 #include "quiche/quic/core/quic_udp_socket.h"
@@ -114,6 +115,9 @@
 
   // Schedules alarms and notifies the server of the I/O events.
   std::unique_ptr<QuicEventLoop> event_loop_;
+  // Used by some backends to create additional sockets, e.g. for upstream
+  // destination connections for proxying.
+  std::unique_ptr<SocketFactory> socket_factory_;
   // Accepts data from the framer and demuxes clients to sessions.
   std::unique_ptr<QuicDispatcher> dispatcher_;
 
diff --git a/quiche/quic/tools/quic_simple_server_backend.h b/quiche/quic/tools/quic_simple_server_backend.h
index bbb0d93..26eaa65 100644
--- a/quiche/quic/tools/quic_simple_server_backend.h
+++ b/quiche/quic/tools/quic_simple_server_backend.h
@@ -9,10 +9,10 @@
 #include <memory>
 
 #include "absl/strings/string_view.h"
+#include "quiche/quic/core/io/socket_factory.h"
 #include "quiche/quic/core/quic_error_codes.h"
 #include "quiche/quic/core/quic_types.h"
 #include "quiche/quic/core/web_transport_interface.h"
-#include "quiche/quic/platform/api/quic_logging.h"
 #include "quiche/quic/tools/quic_backend_response.h"
 #include "quiche/spdy/core/http2_header_block.h"
 
@@ -59,6 +59,10 @@
   // Returns true if the backend has been successfully initialized
   // and could be used to fetch HTTP requests
   virtual bool IsBackendInitialized() const = 0;
+  // Passes the socket factory in use by the QuicServer. Must live as long as
+  // incoming requests/data are still sent to the backend, or until cleared by
+  // calling with null. Must not be called while backend is handling requests.
+  virtual void SetSocketFactory(SocketFactory* /*socket_factory*/) {}
   // Triggers a HTTP request to be sent to the backend server or cache
   // If response is immediately available, the function synchronously calls
   // the `request_handler` with the HTTP response.
diff --git a/quiche/quic/tools/quic_toy_server.cc b/quiche/quic/tools/quic_toy_server.cc
index 6b3c76f..500e5a1 100644
--- a/quiche/quic/tools/quic_toy_server.cc
+++ b/quiche/quic/tools/quic_toy_server.cc
@@ -4,14 +4,20 @@
 
 #include "quiche/quic/tools/quic_toy_server.h"
 
+#include <limits>
 #include <utility>
 #include <vector>
 
+#include "absl/strings/str_split.h"
+#include "url/third_party/mozilla/url_parse.h"
 #include "quiche/quic/core/quic_versions.h"
 #include "quiche/quic/platform/api/quic_default_proof_providers.h"
 #include "quiche/quic/platform/api/quic_socket_address.h"
+#include "quiche/quic/tools/connect_server_backend.h"
+#include "quiche/quic/tools/connect_tunnel.h"
 #include "quiche/quic/tools/quic_memory_cache_backend.h"
 #include "quiche/common/platform/api/quiche_command_line_flags.h"
+#include "quiche/common/platform/api/quiche_logging.h"
 
 DEFINE_QUICHE_COMMAND_LINE_FLAG(int32_t, port, 6121,
                                 "The port the quic server will listen on.");
@@ -39,8 +45,54 @@
 DEFINE_QUICHE_COMMAND_LINE_FLAG(bool, enable_webtransport, false,
                                 "If true, WebTransport support is enabled.");
 
+DEFINE_QUICHE_COMMAND_LINE_FLAG(
+    std::string, connect_proxy_destinations, "",
+    "Specifies a comma-separated list of destinations (\"hostname:port\") to "
+    "which the quic server will allow tunneling via CONNECT.");
+
 namespace quic {
 
+namespace {
+
+ConnectTunnel::HostAndPort ParseProxyDestination(
+    absl::string_view destination) {
+  url::Component username_component;
+  url::Component password_component;
+  url::Component host_component;
+  url::Component port_component;
+
+  url::ParseAuthority(destination.data(), url::Component(0, destination.size()),
+                      &username_component, &password_component, &host_component,
+                      &port_component);
+
+  // Only support "host:port"
+  QUICHE_CHECK(!username_component.is_valid() &&
+               !password_component.is_valid());
+  QUICHE_CHECK(host_component.is_nonempty() && port_component.is_nonempty());
+
+  QUICHE_CHECK_LT(static_cast<size_t>(host_component.end()),
+                  destination.size());
+  if (host_component.len > 2 && destination[host_component.begin] == '[' &&
+      destination[host_component.end() - 1] == ']') {
+    // Strip "[]" off IPv6 literals.
+    host_component.begin += 1;
+    host_component.len -= 2;
+  }
+  std::string hostname(destination.data() + host_component.begin,
+                       host_component.len);
+
+  int parsed_port_number = url::ParsePort(destination.data(), port_component);
+
+  // Require specified and valid port.
+  QUICHE_CHECK_GT(parsed_port_number, 0);
+  QUICHE_CHECK_LE(parsed_port_number, std::numeric_limits<uint16_t>::max());
+
+  return ConnectTunnel::HostAndPort(std::move(hostname),
+                                    static_cast<uint16_t>(parsed_port_number));
+}
+
+}  // namespace
+
 std::unique_ptr<quic::QuicSimpleServerBackend>
 QuicToyServer::MemoryCacheBackendFactory::CreateBackend() {
   auto memory_cache_backend = std::make_unique<QuicMemoryCacheBackend>();
@@ -55,6 +107,21 @@
   if (quiche::GetQuicheCommandLineFlag(FLAGS_enable_webtransport)) {
     memory_cache_backend->EnableWebTransport();
   }
+
+  if (!quiche::GetQuicheCommandLineFlag(FLAGS_connect_proxy_destinations)
+           .empty()) {
+    absl::flat_hash_set<ConnectTunnel::HostAndPort> connect_proxy_destinations;
+    for (absl::string_view destination : absl::StrSplit(
+             quiche::GetQuicheCommandLineFlag(FLAGS_connect_proxy_destinations),
+             ',', absl::SkipEmpty())) {
+      connect_proxy_destinations.insert(ParseProxyDestination(destination));
+    }
+    QUICHE_CHECK(!connect_proxy_destinations.empty());
+
+    return std::make_unique<ConnectServerBackend>(
+        std::move(memory_cache_backend), std::move(connect_proxy_destinations));
+  }
+
   return memory_cache_backend;
 }