Allow non-OHTTP requests in OHTTP test client PiperOrigin-RevId: 919856801
diff --git a/quiche/quic/masque/masque_ohttp_client.cc b/quiche/quic/masque/masque_ohttp_client.cc index 0f2b5d4..1efe7ad 100644 --- a/quiche/quic/masque/masque_ohttp_client.cc +++ b/quiche/quic/masque/masque_ohttp_client.cc
@@ -395,6 +395,11 @@ return absl::OkStatus(); } +bool MasqueOhttpClient::Config::skip_ohttp() const { + static constexpr absl::string_view kSkip = "skip"; + return key_fetch_url_ == kSkip && relay_url_ == kSkip; +} + MasqueOhttpClient::MasqueOhttpClient(Config config, quic::QuicEventLoop* event_loop, absl::string_view info_string) @@ -444,7 +449,7 @@ } absl::Status MasqueOhttpClient::Start() { - absl::Status status = StartKeyFetch(config_.key_fetch_url()); + absl::Status status = StartKeyFetch(); if (!status.ok()) { Abort(status); return status; @@ -455,7 +460,7 @@ if (!status_.ok()) { return true; } - if (!ohttp_client_.has_value()) { + if (!config_.skip_ohttp() && !ohttp_client_.has_value()) { // Key fetch request is still pending. return false; } @@ -487,19 +492,23 @@ return url; } -absl::Status MasqueOhttpClient::StartKeyFetch(const std::string& url_string) { +absl::Status MasqueOhttpClient::StartKeyFetch() { + if (config_.skip_ohttp()) { + return HandleKeyData(""); + } std::string decoded_key_data; - if (absl::HexStringToBytes(url_string, &decoded_key_data) && + if (absl::HexStringToBytes(config_.key_fetch_url(), &decoded_key_data) && !decoded_key_data.empty()) { return HandleKeyData(decoded_key_data); } - QuicUrl url(url_string, "https"); - if (url.host().empty() && !absl::StrContains(url_string, "://")) { - url = QuicUrl(absl::StrCat("https://", url_string)); + QuicUrl url(config_.key_fetch_url(), "https"); + if (url.host().empty() && + !absl::StrContains(config_.key_fetch_url(), "://")) { + url = QuicUrl(absl::StrCat("https://", config_.key_fetch_url())); } if (url.host().empty()) { - return absl::InvalidArgumentError( - absl::StrCat("Failed to parse OHTTP key URL \"", url_string, "\"")); + return absl::InvalidArgumentError(absl::StrCat( + "Failed to parse OHTTP key URL \"", config_.key_fetch_url(), "\"")); } Message request; request.headers[":method"] = "GET"; @@ -588,6 +597,12 @@ } absl::Status MasqueOhttpClient::HandleKeyData(const std::string& key_data) { + if (config_.skip_ohttp()) { + for (const auto& per_request_config : config_.per_request_configs()) { + QUICHE_RETURN_IF_ERROR(SendDirectRequest(per_request_config)); + } + return absl::OkStatus(); + } absl::StatusOr<ObliviousHttpKeyConfigs> key_configs = ObliviousHttpKeyConfigs::ParseConcatenatedKeys(key_data); if (!key_configs.ok()) { @@ -757,7 +772,7 @@ absl::StatusOr<RequestId> request_id = connection_pool_.SendRequest( request, /*mtls=*/true, end_stream, stream_response); if (!request_id.ok()) { - return absl::InternalError(absl::StrCat("Failed to send request: ", + return absl::InternalError(absl::StrCat("Failed to send OHTTP request: ", request_id.status().message())); } @@ -781,6 +796,47 @@ return absl::OkStatus(); } +absl::Status MasqueOhttpClient::SendDirectRequest( + const Config::PerRequestConfig& per_request_config) { + QuicUrl url(per_request_config.url(), "https"); + if (url.host().empty() && + !absl::StrContains(per_request_config.url(), "://")) { + url = QuicUrl(absl::StrCat("https://", per_request_config.url())); + } + if (url.host().empty()) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse URL ", per_request_config.url())); + } + Message request; + std::string post_data = per_request_config.post_data(); + request.headers[":method"] = per_request_config.method(); + request.headers[":scheme"] = url.scheme(); + request.headers[":authority"] = url.HostPort(); + request.headers[":path"] = url.PathParamsQuery(); + if (config_.handle_gzip_response()) { + request.headers["accept-encoding"] = "gzip"; + } + for (const std::pair<std::string, std::string>& header : + per_request_config.headers()) { + request.headers[header.first] = header.second; + } + request.body = per_request_config.post_data(); + absl::StatusOr<RequestId> request_id = connection_pool_.SendRequest( + request, /*mtls=*/true, /*end_stream=*/true, /*stream_response=*/false); + if (!request_id.ok()) { + return absl::InternalError(absl::StrCat("Failed to send direct request: ", + request_id.status().message())); + } + + QUICHE_LOG(INFO) << ENDPOINT << "Sent direct request for " + << per_request_config.url(); + + pending_ohttp_requests_.insert( + {*request_id, PendingRequest(per_request_config)}); + + return absl::OkStatus(); +} + absl::StatusOr<Message> MasqueOhttpClient::TryExtractEncapsulatedResponse( RequestId request_id, quiche::ObliviousHttpRequest::Context& context, const Message& response) { @@ -835,8 +891,7 @@ } absl::Status MasqueOhttpClient::ProcessOhttpResponse( - RequestId request_id, const absl::StatusOr<Message>& response, - bool end_stream) { + RequestId request_id, absl::StatusOr<Message>& response, bool end_stream) { auto it = pending_ohttp_requests_.find(request_id); if (it == pending_ohttp_requests_.end()) { return absl::InternalError(absl::StrCat( @@ -856,6 +911,10 @@ } return response.status(); } + if (config_.skip_ohttp()) { + return ProcessEncapsulatedResponse(request_id, *response, + it->second.per_request_config); + } int16_t gateway_status_code = MasqueConnectionPool::GetStatusCode(*response); if (it->second.per_request_config.expected_gateway_status_code() .has_value()) { @@ -932,32 +991,37 @@ << request_id << ". Body length is " << encapsulated_response->body.size() << ". Headers:" << encapsulated_response->headers.DebugString(); + + return ProcessEncapsulatedResponse(request_id, *encapsulated_response, + it->second.per_request_config); +} + +absl::Status MasqueOhttpClient::ProcessEncapsulatedResponse( + RequestId request_id, Message& response, + const Config::PerRequestConfig& per_request_config) { if (config_.handle_gzip_response()) { - auto content_encoding_it = - encapsulated_response->headers.find("content-encoding"); - if (content_encoding_it != encapsulated_response->headers.end() && + auto content_encoding_it = response.headers.find("content-encoding"); + if (content_encoding_it != response.headers.end() && absl::EqualsIgnoreCase(content_encoding_it->second, "gzip")) { - size_t compressed_size = encapsulated_response->body.size(); + size_t compressed_size = response.body.size(); QUICHE_ASSIGN_OR_RETURN(std::string decompressed_body, - GzipDecompress(encapsulated_response->body)); + GzipDecompress(response.body)); QUICHE_LOG(INFO) << ENDPOINT << "Successfully decompressed gzip response from size " << compressed_size << " to size " << decompressed_body.size(); - encapsulated_response->body = std::move(decompressed_body); + response.body = std::move(decompressed_body); } } - std::cout << encapsulated_response->body; + std::cout << response.body; int16_t encapsulated_status_code = - MasqueConnectionPool::GetStatusCode(*encapsulated_response); - if (it->second.per_request_config.expected_encapsulated_status_code() - .has_value()) { + MasqueConnectionPool::GetStatusCode(response); + if (per_request_config.expected_encapsulated_status_code().has_value()) { if (encapsulated_status_code != - *it->second.per_request_config.expected_encapsulated_status_code()) { + *per_request_config.expected_encapsulated_status_code()) { return absl::InvalidArgumentError(absl::StrCat( "Unexpected encapsulated status code: ", encapsulated_status_code, - " != ", - *it->second.per_request_config.expected_encapsulated_status_code())); + " != ", *per_request_config.expected_encapsulated_status_code())); } } else if (encapsulated_status_code < 200 || encapsulated_status_code >= 300) { @@ -965,14 +1029,15 @@ "Bad encapsulated status code: ", encapsulated_status_code)); } if (const auto& callback = - it->second.per_request_config.encapsulated_response_body_callback(); + per_request_config.encapsulated_response_body_callback(); callback) { - QUICHE_RETURN_IF_ERROR(callback(encapsulated_response->body)); + QUICHE_RETURN_IF_ERROR(callback(response.body)); } if (response_visitor_) { - response_visitor_->OnResponseDone(request_id, *encapsulated_response); + response_visitor_->OnResponseDone(request_id, response); } + return absl::OkStatus(); }
diff --git a/quiche/quic/masque/masque_ohttp_client.h b/quiche/quic/masque/masque_ohttp_client.h index b497ad6..7b30b03 100644 --- a/quiche/quic/masque/masque_ohttp_client.h +++ b/quiche/quic/masque/masque_ohttp_client.h
@@ -215,6 +215,7 @@ return key_fetch_headers_; } bool handle_gzip_response() const { return handle_gzip_response_; } + bool skip_ohttp() const; private: std::string key_fetch_url_; @@ -263,7 +264,7 @@ private: // Fetch key from the key URL. - absl::Status StartKeyFetch(const std::string& url_string); + absl::Status StartKeyFetch(); // Handles the key response. absl::Status HandleKeyResponse(const absl::StatusOr<Message>& response); @@ -275,6 +276,10 @@ absl::Status SendOhttpRequest( const Config::PerRequestConfig& per_request_config); + // Sends a direct HTTP request (without OHTTP) for the given URL. + absl::Status SendDirectRequest( + const Config::PerRequestConfig& per_request_config); + // Signals the client to abort. void Abort(absl::Status status); @@ -379,8 +384,11 @@ RequestId request_id, quiche::ObliviousHttpRequest::Context& context, const Message& response); absl::Status ProcessOhttpResponse(RequestId request_id, - const absl::StatusOr<Message>& response, + absl::StatusOr<Message>& response, bool end_stream); + absl::Status ProcessEncapsulatedResponse( + RequestId request_id, Message& response, + const Config::PerRequestConfig& per_request_config); absl::Status CheckStatusAndContentType( const Message& response, const std::string& content_type, std::optional<uint16_t> expected_status_code);