Allow passing in headers to OHTTP client, similar to curl PiperOrigin-RevId: 866593109
diff --git a/quiche/quic/masque/masque_ohttp_client.cc b/quiche/quic/masque/masque_ohttp_client.cc index f07ea53..254c4f7 100644 --- a/quiche/quic/masque/masque_ohttp_client.cc +++ b/quiche/quic/masque/masque_ohttp_client.cc
@@ -15,6 +15,7 @@ #include "absl/cleanup/cleanup.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/ascii.h" #include "absl/strings/escaping.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" @@ -60,7 +61,28 @@ static constexpr uint64_t kFixedSizeResponseFramingIndicator = 0x01; -absl::StatusOr<std::string> FormatPrivateToken( +} // namespace + +absl::Status MasqueOhttpClient::Config::PerRequestConfig::AddHeaders( + const std::vector<std::string>& headers) { + for (const std::string& header : headers) { + std::vector<absl::string_view> header_split = + absl::StrSplit(header, absl::MaxSplits(':', 1)); + if (header_split.size() != 2) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse header \"", header, "\"")); + } + std::string header_name = std::string(header_split[0]); + absl::StripAsciiWhitespace(&header_name); + absl::AsciiStrToLower(&header_name); + std::string header_value = std::string(header_split[1]); + absl::StripAsciiWhitespace(&header_value); + headers_.push_back({std::move(header_name), std::move(header_value)}); + } + return absl::OkStatus(); +} + +absl::Status MasqueOhttpClient::Config::PerRequestConfig::AddPrivateToken( const std::string& private_token) { // Private tokens require padded base64url and we allow any encoding for // convenience, so we need to unescape and re-escape. @@ -73,11 +95,11 @@ } formatted_token = absl::Base64Escape(formatted_token); absl::StrReplaceAll({{"+", "-"}, {"/", "_"}}, &formatted_token); - return absl::StrCat("PrivateToken token=\"", formatted_token, "\""); + headers_.push_back({"authorization", absl::StrCat("PrivateToken token=\"", + formatted_token, "\"")}); + return absl::OkStatus(); } -} // namespace - absl::Status MasqueOhttpClient::Config::ConfigureKeyFetchClientCert( const std::string& client_cert_file, const std::string& client_cert_key_file) { @@ -337,12 +359,6 @@ control_data.path = url.PathParamsQuery(); std::string encrypted_data; PendingRequest pending_request(per_request_config); - std::string formatted_token; - if (!per_request_config.private_token().empty()) { - QUICHE_ASSIGN_OR_RETURN( - formatted_token, - FormatPrivateToken(per_request_config.private_token())); - } if (!ohttp_client_.has_value()) { QUICHE_LOG(FATAL) << "Cannot send OHTTP request without OHTTP client"; return absl::InternalError( @@ -358,8 +374,9 @@ QUICHE_ASSIGN_OR_RETURN(encoded_data, encoder.EncodeControlData(control_data)); std::vector<quiche::BinaryHttpMessage::FieldView> headers; - if (!formatted_token.empty()) { - headers.push_back({"authorization", formatted_token}); + for (const std::pair<std::string, std::string>& header : + per_request_config.headers()) { + headers.push_back({header.first, header.second}); } QUICHE_ASSIGN_OR_RETURN(std::string encoded_headers, encoder.EncodeHeaders(absl::MakeSpan(headers))); @@ -392,10 +409,11 @@ encoded_data += encoded_trailers; } else { BinaryHttpRequest binary_request(control_data); - binary_request.set_body(post_data); - if (!formatted_token.empty()) { - binary_request.AddHeaderField({"authorization", formatted_token}); + for (const std::pair<std::string, std::string>& header : + per_request_config.headers()) { + binary_request.AddHeaderField({header.first, header.second}); } + binary_request.set_body(post_data); QUICHE_ASSIGN_OR_RETURN(encoded_data, binary_request.Serialize()); } if (pending_request.per_request_config.use_chunked_ohttp()) {
diff --git a/quiche/quic/masque/masque_ohttp_client.h b/quiche/quic/masque/masque_ohttp_client.h index 5d0754b..c7908d5 100644 --- a/quiche/quic/masque/masque_ohttp_client.h +++ b/quiche/quic/masque/masque_ohttp_client.h
@@ -46,6 +46,8 @@ PerRequestConfig& operator=(PerRequestConfig&& other) = default; void SetPostData(const std::string& post_data) { post_data_ = post_data; } + absl::Status AddHeaders(const std::vector<std::string>& headers); + absl::Status AddPrivateToken(const std::string& private_token); void SetUseChunkedOhttp(bool use_chunked_ohttp) { use_chunked_ohttp_ = use_chunked_ohttp; } @@ -62,9 +64,6 @@ void SetExpectedEncapsulatedStatusCode(uint16_t status_code) { expected_encapsulated_status_code_ = status_code; } - void SetPrivateToken(const std::string& private_token) { - private_token_ = private_token; - } void SetExpectedEncapsulatedResponseBody( const std::string& expected_encapsulated_response_body) { expected_encapsulated_response_body_ = @@ -73,7 +72,9 @@ std::string url() const { return url_; } std::string post_data() const { return post_data_; } - std::string private_token() const { return private_token_; } + const std::vector<std::pair<std::string, std::string>>& headers() const { + return headers_; + } bool use_chunked_ohttp() const { return use_chunked_ohttp_; } std::optional<bool> use_indeterminate_length() const { return use_indeterminate_length_; @@ -94,7 +95,7 @@ private: std::string url_; std::string post_data_; - std::string private_token_; + std::vector<std::pair<std::string, std::string>> headers_; bool use_chunked_ohttp_ = false; std::optional<bool> use_indeterminate_length_; std::optional<std::string> expected_gateway_error_;
diff --git a/quiche/quic/masque/masque_ohttp_client_bin.cc b/quiche/quic/masque/masque_ohttp_client_bin.cc index 8e7979b..a6440a2 100644 --- a/quiche/quic/masque/masque_ohttp_client_bin.cc +++ b/quiche/quic/masque/masque_ohttp_client_bin.cc
@@ -58,6 +58,11 @@ "file."); DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::vector<std::string>, header, {}, + "Adds a header field to the encapsulated binary request. Separate the " + "header name and value with a colon. Can be specified multiple times."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( std::string, private_token, "", "When set, the client will attach a base64-encoded private token to the " "encapsulated request. Accepts any base64 encoding."); @@ -125,6 +130,8 @@ } post_data = *post_data_from_file; } + std::vector<std::string> headers = + quiche::GetQuicheCommandLineFlag(FLAGS_header); std::string private_token = quiche::GetQuicheCommandLineFlag(FLAGS_private_token); @@ -144,9 +151,12 @@ for (size_t i = 2; i < urls.size(); ++i) { MasqueOhttpClient::Config::PerRequestConfig per_request_config(urls[i]); per_request_config.SetPostData(post_data); + QUICHE_RETURN_IF_ERROR(per_request_config.AddHeaders(headers)); + if (!private_token.empty()) { + QUICHE_RETURN_IF_ERROR(per_request_config.AddPrivateToken(private_token)); + } per_request_config.SetUseChunkedOhttp(use_chunked_ohttp); per_request_config.SetUseIndeterminateLength(indeterminate_length); - per_request_config.SetPrivateToken(private_token); if (expect_gateway_error.has_value()) { per_request_config.SetExpectedGatewayError(*expect_gateway_error); }