Fix intentional failure OHTTP probes PiperOrigin-RevId: 863415602
diff --git a/quiche/quic/masque/masque_ohttp_client.cc b/quiche/quic/masque/masque_ohttp_client.cc index f8e9ccf..75b0740 100644 --- a/quiche/quic/masque/masque_ohttp_client.cc +++ b/quiche/quic/masque/masque_ohttp_client.cc
@@ -213,7 +213,8 @@ } absl::Status MasqueOhttpClient::CheckStatusAndContentType( - const Message& response, const std::string& content_type) { + const Message& response, const std::string& content_type, + std::optional<uint16_t> expected_status_code) { auto status_it = response.headers.find(":status"); if (status_it == response.headers.end()) { return absl::InvalidArgumentError( @@ -224,10 +225,22 @@ return absl::InvalidArgumentError( absl::StrCat("Failed to parse ", content_type, " status code.")); } - if (status_code < 200 || status_code >= 300) { - return absl::InvalidArgumentError( - absl::StrCat("Unexpected status in ", content_type, - " response: ", status_it->second)); + if (expected_status_code.has_value()) { + if (status_code != *expected_status_code) { + return absl::InvalidArgumentError(absl::StrCat( + "Unexpected status in ", content_type, " response: ", status_code, + " (expected ", *expected_status_code, ")")); + } + if (*expected_status_code < 200 || *expected_status_code >= 300) { + // If we expect a failure status code, skip the content-type check. + return absl::OkStatus(); + } + } else { + if (status_code < 200 || status_code >= 300) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected status in ", content_type, + " response: ", status_it->second)); + } } auto content_type_it = response.headers.find("content-type"); if (content_type_it == response.headers.end()) { @@ -257,8 +270,8 @@ } QUICHE_LOG(INFO) << "Received OHTTP keys response: " << response->headers.DebugString(); - QUICHE_RETURN_IF_ERROR( - CheckStatusAndContentType(*response, "application/ohttp-keys")); + QUICHE_RETURN_IF_ERROR(CheckStatusAndContentType( + *response, "application/ohttp-keys", std::nullopt)); absl::StatusOr<ObliviousHttpKeyConfigs> key_configs = ObliviousHttpKeyConfigs::ParseConcatenatedKeys(response->body); if (!key_configs.ok()) { @@ -482,7 +495,10 @@ std::string content_type = it->second.per_request_config.use_chunked_ohttp() ? "message/ohttp-chunked-res" : "message/ohttp-res"; - absl::Status status = CheckStatusAndContentType(*response, content_type); + std::optional<uint16_t> expected_gateway_status_code = + it->second.per_request_config.expected_gateway_status_code(); + absl::Status status = CheckStatusAndContentType(*response, content_type, + expected_gateway_status_code); if (!status.ok()) { if (!response->body.empty()) { QUICHE_LOG(ERROR) << "Bad " << content_type << " with body:" << std::endl @@ -492,6 +508,12 @@ } return status; } + if (expected_gateway_status_code.has_value() && + (*expected_gateway_status_code < 200 || + *expected_gateway_status_code >= 300)) { + // If we expect a failure status code, skip decapsulation. + return absl::OkStatus(); + } std::optional<Message> encapsulated_response; if (it->second.per_request_config.use_chunked_ohttp()) { QUICHE_ASSIGN_OR_RETURN(
diff --git a/quiche/quic/masque/masque_ohttp_client.h b/quiche/quic/masque/masque_ohttp_client.h index b9a18cc..988efaa 100644 --- a/quiche/quic/masque/masque_ohttp_client.h +++ b/quiche/quic/masque/masque_ohttp_client.h
@@ -242,8 +242,9 @@ const Message& response); absl::Status ProcessOhttpResponse(RequestId request_id, const absl::StatusOr<Message>& response); - absl::Status CheckStatusAndContentType(const Message& response, - const std::string& content_type); + absl::Status CheckStatusAndContentType( + const Message& response, const std::string& content_type, + std::optional<uint16_t> expected_status_code); Config config_; quic::MasqueConnectionPool connection_pool_;
diff --git a/quiche/quic/masque/masque_ohttp_client_bin.cc b/quiche/quic/masque/masque_ohttp_client_bin.cc index 70cf317..5f50cbe 100644 --- a/quiche/quic/masque/masque_ohttp_client_bin.cc +++ b/quiche/quic/masque/masque_ohttp_client_bin.cc
@@ -4,6 +4,7 @@ #include <stdbool.h> +#include <optional> #include <string> #include <utility> #include <vector> @@ -56,6 +57,10 @@ "port. PORT2 can be empty to not override ports. Multiple overrides can be " "specified separated by semi-colons."); +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::optional<int16_t>, expect_gateway_response_code, std::nullopt, + "If set, the client will expect this response code from the gateway."); + namespace quic { namespace { absl::Status RunMasqueOhttpClient(int argc, char* argv[]) { @@ -74,6 +79,8 @@ quiche::GetQuicheCommandLineFlag(FLAGS_client_cert_file); const std::string client_cert_key_file = quiche::GetQuicheCommandLineFlag(FLAGS_client_cert_key_file); + const std::optional<int16_t> expect_gateway_response_code = + quiche::GetQuicheCommandLineFlag(FLAGS_expect_gateway_response_code); MasqueConnectionPool::DnsConfig dns_config; QUICHE_RETURN_IF_ERROR(dns_config.SetAddressFamily( @@ -102,6 +109,10 @@ per_request_config.SetPostData(post_data); per_request_config.SetUseChunkedOhttp(use_chunked_ohttp); per_request_config.SetPrivateToken(private_token); + if (expect_gateway_response_code.has_value()) { + per_request_config.SetExpectedGatewayStatusCode( + *expect_gateway_response_code); + } config.AddPerRequestConfig(per_request_config); } return MasqueOhttpClient::Run(std::move(config));