Add OHTTP Relay support to masque_tcp_server PiperOrigin-RevId: 844805998
diff --git a/quiche/quic/masque/masque_tcp_server_bin.cc b/quiche/quic/masque/masque_tcp_server_bin.cc index 8ea142b..3c04f62 100644 --- a/quiche/quic/masque/masque_tcp_server_bin.cc +++ b/quiche/quic/masque/masque_tcp_server_bin.cc
@@ -20,6 +20,7 @@ #include <utility> #include <vector> +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/escaping.h" @@ -35,7 +36,9 @@ #include "quiche/quic/core/io/socket.h" #include "quiche/quic/core/quic_default_clock.h" #include "quiche/quic/core/quic_time.h" +#include "quiche/quic/masque/masque_connection_pool.h" #include "quiche/quic/masque/masque_h2_connection.h" +#include "quiche/quic/tools/quic_url.h" #include "quiche/binary_http/binary_http_message.h" #include "quiche/common/http/http_header_block.h" #include "quiche/common/platform/api/quiche_command_line_flags.h" @@ -63,6 +66,29 @@ std::string, ohttp_key, "", "Hex-encoded bytes of the OHTTP HPKE private key."); +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, relay_path, "", + "Path of the relay server to accept requests on."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, relay_gateway_url, "", + "URL of the gateway that this relay will forward requests to."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + bool, disable_certificate_verification, false, + "If true, don't verify the server certificate."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG(int, address_family, 0, + "IP address family to use. Must be 0, 4 or 6. " + "Defaults to 0 which means any."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG(std::string, client_cert_file, "", + "Path to the client certificate chain."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, client_cert_key_file, "", + "Path to the pkcs8 client certificate private key."); + using quiche::BinaryHttpRequest; using quiche::BinaryHttpResponse; using quiche::ObliviousHttpGateway; @@ -317,11 +343,21 @@ }; class MasqueTcpServer : public QuicSocketEventListener, - public MasqueH2Connection::Visitor { + public MasqueH2Connection::Visitor, + public MasqueConnectionPool::Visitor { public: - explicit MasqueTcpServer(MasqueOhttpGateway* masque_ohttp_gateway) + using RequestId = MasqueConnectionPool::RequestId; + using Message = MasqueConnectionPool::Message; + + explicit MasqueTcpServer(MasqueOhttpGateway* masque_ohttp_gateway, + SSL_CTX* client_ssl_ctx, + bool disable_certificate_verification, + int address_family_for_lookup) : event_loop_(GetDefaultEventLoop()->Create(QuicDefaultClock::Get())), - masque_ohttp_gateway_(masque_ohttp_gateway) {} + masque_ohttp_gateway_(masque_ohttp_gateway), + connection_pool_(event_loop_.get(), client_ssl_ctx, + disable_certificate_verification, + address_family_for_lookup, this) {} MasqueTcpServer(const MasqueTcpServer&) = delete; MasqueTcpServer(MasqueTcpServer&&) = delete; @@ -450,12 +486,38 @@ connections_.end()); } - bool HandleOhttpRequest(MasqueH2Connection* connection, int32_t stream_id, - const std::string& encapsulated_request) { + bool HandleOhttpGatewayRequest(MasqueH2Connection* connection, + int32_t stream_id, + const std::string& encapsulated_request) { return masque_ohttp_gateway_->HandleRequest(connection, stream_id, encapsulated_request); } + absl::Status HandleOhttpRelayRequest( + MasqueH2Connection* connection, int32_t stream_id, + const std::string& encapsulated_request) { + Message request; + request.headers[":method"] = "POST"; + request.headers[":scheme"] = relay_gateway_url_.scheme(); + request.headers[":authority"] = relay_gateway_url_.HostPort(); + request.headers[":path"] = relay_gateway_url_.path(); + request.headers["content-type"] = "message/ohttp-req"; + request.body = encapsulated_request; + absl::StatusOr<RequestId> request_id = + connection_pool_.SendRequest(request); + if (!request_id.ok()) { + QUICHE_LOG(ERROR) << "Failed to send relayed request: " + << request_id.status(); + return request_id.status(); + } + QUICHE_LOG(INFO) << "Sent relayed request"; + PendingRequest pending_request; + pending_request.connection = connection; + pending_request.stream_id = stream_id; + pending_requests_.insert({*request_id, std::move(pending_request)}); + return absl::OkStatus(); + } + void OnRequest(MasqueH2Connection* connection, int32_t stream_id, const quiche::HttpHeaderBlock& headers, const std::string& body) override { @@ -475,14 +537,24 @@ response_headers[":status"] = "200"; response_headers["content-type"] = "application/ohttp-keys"; response_body = masque_ohttp_gateway_->concatenated_keys(); - } else if (method_pair->second == "POST" && + } else if (!relay_path_.empty() && path_pair->second == relay_path_ && + method_pair->second == "POST" && content_type_pair != headers.end() && content_type_pair->second == "message/ohttp-req") { - if (HandleOhttpRequest(connection, stream_id, body)) { + if (HandleOhttpRelayRequest(connection, stream_id, body).ok()) { return; } else { response_headers[":status"] = "500"; - response_body = "Failed to handle OHTTP request"; + response_body = "Failed to handle OHTTP relay request"; + } + } else if (method_pair->second == "POST" && + content_type_pair != headers.end() && + content_type_pair->second == "message/ohttp-req") { + if (HandleOhttpGatewayRequest(connection, stream_id, body)) { + return; + } else { + response_headers[":status"] = "500"; + response_body = "Failed to handle OHTTP gateway request"; } } else if (method_pair->second == "GET" && path_pair->second == "/") { response_headers[":status"] = "200"; @@ -500,7 +572,51 @@ QUICHE_LOG(FATAL) << "Server cannot receive responses"; } + // From MasqueConnectionPool::Visitor. + void OnPoolResponse(MasqueConnectionPool* /*pool*/, RequestId request_id, + absl::StatusOr<Message>&& response) override { + auto it = pending_requests_.find(request_id); + if (it == pending_requests_.end()) { + QUICHE_LOG(ERROR) << "Received unexpected response for unknown request " + << request_id; + return; + } + PendingRequest pending_request = std::move(it->second); + pending_requests_.erase(it); + quiche::HttpHeaderBlock response_headers; + std::string response_body; + if (response.ok()) { + QUICHE_LOG(INFO) << "Forwarding relayed response to stream ID " + << pending_request.stream_id; + response_headers = std::move(response->headers); + response_body = std::move(response->body); + } else { + QUICHE_LOG(ERROR) << "Received relayed error response: " + << response.status(); + response_headers[":status"] = "500"; + response_body = "Relayed request failed"; + } + pending_request.connection->SendResponse(pending_request.stream_id, + response_headers, response_body); + pending_request.connection->AttemptToSend(); + } + + bool SetupRelay(const std::string& relay_path, + const std::string& relay_gateway_url) { + if (relay_path.empty() != relay_gateway_url.empty()) { + return false; + } + relay_path_ = relay_path; + relay_gateway_url_ = QuicUrl(relay_gateway_url, "https"); + return true; + } + private: + struct PendingRequest { + MasqueH2Connection* connection = nullptr; // Not owned. + int32_t stream_id = -1; + }; + void AcceptConnection() { absl::StatusOr<socket_api::AcceptResult> accept_result = socket_api::Accept(server_socket_, /*blocking=*/false); @@ -531,6 +647,10 @@ MasqueOhttpGateway* masque_ohttp_gateway_; // Unowned. SocketFd server_socket_ = kInvalidSocketFd; std::vector<std::unique_ptr<MasqueH2SocketConnection>> connections_; + std::string relay_path_; + QuicUrl relay_gateway_url_; + MasqueConnectionPool connection_pool_; + absl::flat_hash_map<RequestId, PendingRequest> pending_requests_; }; int RunMasqueTcpServer(int argc, char* argv[]) { @@ -565,7 +685,45 @@ return 1; } - MasqueTcpServer server(&masque_ohttp_gateway); + const bool disable_certificate_verification = + quiche::GetQuicheCommandLineFlag(FLAGS_disable_certificate_verification); + const std::string client_cert_file = + quiche::GetQuicheCommandLineFlag(FLAGS_client_cert_file); + const std::string client_cert_key_file = + quiche::GetQuicheCommandLineFlag(FLAGS_client_cert_key_file); + absl::StatusOr<bssl::UniquePtr<SSL_CTX>> client_ssl_ctx = + MasqueConnectionPool::CreateSslCtx(client_cert_file, + client_cert_key_file); + if (!client_ssl_ctx.ok()) { + QUICHE_LOG(ERROR) << "Failed to create client SSL context: " + << client_ssl_ctx.status(); + return 1; + } + const int address_family = + quiche::GetQuicheCommandLineFlag(FLAGS_address_family); + int address_family_for_lookup; + if (address_family == 0) { + address_family_for_lookup = AF_UNSPEC; + } else if (address_family == 4) { + address_family_for_lookup = AF_INET; + } else if (address_family == 6) { + address_family_for_lookup = AF_INET6; + } else { + QUICHE_LOG(ERROR) << "Invalid address_family " << address_family; + return 1; + } + + MasqueTcpServer server(&masque_ohttp_gateway, client_ssl_ctx->get(), + disable_certificate_verification, + address_family_for_lookup); + + if (!server.SetupRelay( + quiche::GetQuicheCommandLineFlag(FLAGS_relay_path), + quiche::GetQuicheCommandLineFlag(FLAGS_relay_gateway_url))) { + QUICHE_LOG(ERROR) << "Both --relay_path and --relay_gateway_url must be " + "set, or neither can be set"; + return 1; + } if (!server.SetupSslCtx(certificate_file, key_file, client_root_ca_file)) { QUICHE_LOG(ERROR) << "Failed to setup SSL context"; return 1;