Add OHTTP Relay support to masque_tcp_server

PiperOrigin-RevId: 844805998
diff --git a/quiche/quic/masque/masque_tcp_server_bin.cc b/quiche/quic/masque/masque_tcp_server_bin.cc
index 8ea142b..3c04f62 100644
--- a/quiche/quic/masque/masque_tcp_server_bin.cc
+++ b/quiche/quic/masque/masque_tcp_server_bin.cc
@@ -20,6 +20,7 @@
 #include <utility>
 #include <vector>
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/escaping.h"
@@ -35,7 +36,9 @@
 #include "quiche/quic/core/io/socket.h"
 #include "quiche/quic/core/quic_default_clock.h"
 #include "quiche/quic/core/quic_time.h"
+#include "quiche/quic/masque/masque_connection_pool.h"
 #include "quiche/quic/masque/masque_h2_connection.h"
+#include "quiche/quic/tools/quic_url.h"
 #include "quiche/binary_http/binary_http_message.h"
 #include "quiche/common/http/http_header_block.h"
 #include "quiche/common/platform/api/quiche_command_line_flags.h"
@@ -63,6 +66,29 @@
     std::string, ohttp_key, "",
     "Hex-encoded bytes of the OHTTP HPKE private key.");
 
+DEFINE_QUICHE_COMMAND_LINE_FLAG(
+    std::string, relay_path, "",
+    "Path of the relay server to accept requests on.");
+
+DEFINE_QUICHE_COMMAND_LINE_FLAG(
+    std::string, relay_gateway_url, "",
+    "URL of the gateway that this relay will forward requests to.");
+
+DEFINE_QUICHE_COMMAND_LINE_FLAG(
+    bool, disable_certificate_verification, false,
+    "If true, don't verify the server certificate.");
+
+DEFINE_QUICHE_COMMAND_LINE_FLAG(int, address_family, 0,
+                                "IP address family to use. Must be 0, 4 or 6. "
+                                "Defaults to 0 which means any.");
+
+DEFINE_QUICHE_COMMAND_LINE_FLAG(std::string, client_cert_file, "",
+                                "Path to the client certificate chain.");
+
+DEFINE_QUICHE_COMMAND_LINE_FLAG(
+    std::string, client_cert_key_file, "",
+    "Path to the pkcs8 client certificate private key.");
+
 using quiche::BinaryHttpRequest;
 using quiche::BinaryHttpResponse;
 using quiche::ObliviousHttpGateway;
@@ -317,11 +343,21 @@
 };
 
 class MasqueTcpServer : public QuicSocketEventListener,
-                        public MasqueH2Connection::Visitor {
+                        public MasqueH2Connection::Visitor,
+                        public MasqueConnectionPool::Visitor {
  public:
-  explicit MasqueTcpServer(MasqueOhttpGateway* masque_ohttp_gateway)
+  using RequestId = MasqueConnectionPool::RequestId;
+  using Message = MasqueConnectionPool::Message;
+
+  explicit MasqueTcpServer(MasqueOhttpGateway* masque_ohttp_gateway,
+                           SSL_CTX* client_ssl_ctx,
+                           bool disable_certificate_verification,
+                           int address_family_for_lookup)
       : event_loop_(GetDefaultEventLoop()->Create(QuicDefaultClock::Get())),
-        masque_ohttp_gateway_(masque_ohttp_gateway) {}
+        masque_ohttp_gateway_(masque_ohttp_gateway),
+        connection_pool_(event_loop_.get(), client_ssl_ctx,
+                         disable_certificate_verification,
+                         address_family_for_lookup, this) {}
 
   MasqueTcpServer(const MasqueTcpServer&) = delete;
   MasqueTcpServer(MasqueTcpServer&&) = delete;
@@ -450,12 +486,38 @@
         connections_.end());
   }
 
-  bool HandleOhttpRequest(MasqueH2Connection* connection, int32_t stream_id,
-                          const std::string& encapsulated_request) {
+  bool HandleOhttpGatewayRequest(MasqueH2Connection* connection,
+                                 int32_t stream_id,
+                                 const std::string& encapsulated_request) {
     return masque_ohttp_gateway_->HandleRequest(connection, stream_id,
                                                 encapsulated_request);
   }
 
+  absl::Status HandleOhttpRelayRequest(
+      MasqueH2Connection* connection, int32_t stream_id,
+      const std::string& encapsulated_request) {
+    Message request;
+    request.headers[":method"] = "POST";
+    request.headers[":scheme"] = relay_gateway_url_.scheme();
+    request.headers[":authority"] = relay_gateway_url_.HostPort();
+    request.headers[":path"] = relay_gateway_url_.path();
+    request.headers["content-type"] = "message/ohttp-req";
+    request.body = encapsulated_request;
+    absl::StatusOr<RequestId> request_id =
+        connection_pool_.SendRequest(request);
+    if (!request_id.ok()) {
+      QUICHE_LOG(ERROR) << "Failed to send relayed request: "
+                        << request_id.status();
+      return request_id.status();
+    }
+    QUICHE_LOG(INFO) << "Sent relayed request";
+    PendingRequest pending_request;
+    pending_request.connection = connection;
+    pending_request.stream_id = stream_id;
+    pending_requests_.insert({*request_id, std::move(pending_request)});
+    return absl::OkStatus();
+  }
+
   void OnRequest(MasqueH2Connection* connection, int32_t stream_id,
                  const quiche::HttpHeaderBlock& headers,
                  const std::string& body) override {
@@ -475,14 +537,24 @@
       response_headers[":status"] = "200";
       response_headers["content-type"] = "application/ohttp-keys";
       response_body = masque_ohttp_gateway_->concatenated_keys();
-    } else if (method_pair->second == "POST" &&
+    } else if (!relay_path_.empty() && path_pair->second == relay_path_ &&
+               method_pair->second == "POST" &&
                content_type_pair != headers.end() &&
                content_type_pair->second == "message/ohttp-req") {
-      if (HandleOhttpRequest(connection, stream_id, body)) {
+      if (HandleOhttpRelayRequest(connection, stream_id, body).ok()) {
         return;
       } else {
         response_headers[":status"] = "500";
-        response_body = "Failed to handle OHTTP request";
+        response_body = "Failed to handle OHTTP relay request";
+      }
+    } else if (method_pair->second == "POST" &&
+               content_type_pair != headers.end() &&
+               content_type_pair->second == "message/ohttp-req") {
+      if (HandleOhttpGatewayRequest(connection, stream_id, body)) {
+        return;
+      } else {
+        response_headers[":status"] = "500";
+        response_body = "Failed to handle OHTTP gateway request";
       }
     } else if (method_pair->second == "GET" && path_pair->second == "/") {
       response_headers[":status"] = "200";
@@ -500,7 +572,51 @@
     QUICHE_LOG(FATAL) << "Server cannot receive responses";
   }
 
+  // From MasqueConnectionPool::Visitor.
+  void OnPoolResponse(MasqueConnectionPool* /*pool*/, RequestId request_id,
+                      absl::StatusOr<Message>&& response) override {
+    auto it = pending_requests_.find(request_id);
+    if (it == pending_requests_.end()) {
+      QUICHE_LOG(ERROR) << "Received unexpected response for unknown request "
+                        << request_id;
+      return;
+    }
+    PendingRequest pending_request = std::move(it->second);
+    pending_requests_.erase(it);
+    quiche::HttpHeaderBlock response_headers;
+    std::string response_body;
+    if (response.ok()) {
+      QUICHE_LOG(INFO) << "Forwarding relayed response to stream ID "
+                       << pending_request.stream_id;
+      response_headers = std::move(response->headers);
+      response_body = std::move(response->body);
+    } else {
+      QUICHE_LOG(ERROR) << "Received relayed error response: "
+                        << response.status();
+      response_headers[":status"] = "500";
+      response_body = "Relayed request failed";
+    }
+    pending_request.connection->SendResponse(pending_request.stream_id,
+                                             response_headers, response_body);
+    pending_request.connection->AttemptToSend();
+  }
+
+  bool SetupRelay(const std::string& relay_path,
+                  const std::string& relay_gateway_url) {
+    if (relay_path.empty() != relay_gateway_url.empty()) {
+      return false;
+    }
+    relay_path_ = relay_path;
+    relay_gateway_url_ = QuicUrl(relay_gateway_url, "https");
+    return true;
+  }
+
  private:
+  struct PendingRequest {
+    MasqueH2Connection* connection = nullptr;  // Not owned.
+    int32_t stream_id = -1;
+  };
+
   void AcceptConnection() {
     absl::StatusOr<socket_api::AcceptResult> accept_result =
         socket_api::Accept(server_socket_, /*blocking=*/false);
@@ -531,6 +647,10 @@
   MasqueOhttpGateway* masque_ohttp_gateway_;  // Unowned.
   SocketFd server_socket_ = kInvalidSocketFd;
   std::vector<std::unique_ptr<MasqueH2SocketConnection>> connections_;
+  std::string relay_path_;
+  QuicUrl relay_gateway_url_;
+  MasqueConnectionPool connection_pool_;
+  absl::flat_hash_map<RequestId, PendingRequest> pending_requests_;
 };
 
 int RunMasqueTcpServer(int argc, char* argv[]) {
@@ -565,7 +685,45 @@
     return 1;
   }
 
-  MasqueTcpServer server(&masque_ohttp_gateway);
+  const bool disable_certificate_verification =
+      quiche::GetQuicheCommandLineFlag(FLAGS_disable_certificate_verification);
+  const std::string client_cert_file =
+      quiche::GetQuicheCommandLineFlag(FLAGS_client_cert_file);
+  const std::string client_cert_key_file =
+      quiche::GetQuicheCommandLineFlag(FLAGS_client_cert_key_file);
+  absl::StatusOr<bssl::UniquePtr<SSL_CTX>> client_ssl_ctx =
+      MasqueConnectionPool::CreateSslCtx(client_cert_file,
+                                         client_cert_key_file);
+  if (!client_ssl_ctx.ok()) {
+    QUICHE_LOG(ERROR) << "Failed to create client SSL context: "
+                      << client_ssl_ctx.status();
+    return 1;
+  }
+  const int address_family =
+      quiche::GetQuicheCommandLineFlag(FLAGS_address_family);
+  int address_family_for_lookup;
+  if (address_family == 0) {
+    address_family_for_lookup = AF_UNSPEC;
+  } else if (address_family == 4) {
+    address_family_for_lookup = AF_INET;
+  } else if (address_family == 6) {
+    address_family_for_lookup = AF_INET6;
+  } else {
+    QUICHE_LOG(ERROR) << "Invalid address_family " << address_family;
+    return 1;
+  }
+
+  MasqueTcpServer server(&masque_ohttp_gateway, client_ssl_ctx->get(),
+                         disable_certificate_verification,
+                         address_family_for_lookup);
+
+  if (!server.SetupRelay(
+          quiche::GetQuicheCommandLineFlag(FLAGS_relay_path),
+          quiche::GetQuicheCommandLineFlag(FLAGS_relay_gateway_url))) {
+    QUICHE_LOG(ERROR) << "Both --relay_path and --relay_gateway_url must be "
+                         "set, or neither can be set";
+    return 1;
+  }
   if (!server.SetupSslCtx(certificate_file, key_file, client_root_ca_file)) {
     QUICHE_LOG(ERROR) << "Failed to setup SSL context";
     return 1;