Allow passing in headers to OHTTP client, similar to curl

PiperOrigin-RevId: 866593109
diff --git a/quiche/quic/masque/masque_ohttp_client.cc b/quiche/quic/masque/masque_ohttp_client.cc
index f07ea53..254c4f7 100644
--- a/quiche/quic/masque/masque_ohttp_client.cc
+++ b/quiche/quic/masque/masque_ohttp_client.cc
@@ -15,6 +15,7 @@
 #include "absl/cleanup/cleanup.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/ascii.h"
 #include "absl/strings/escaping.h"
 #include "absl/strings/match.h"
 #include "absl/strings/numbers.h"
@@ -60,7 +61,28 @@
 
 static constexpr uint64_t kFixedSizeResponseFramingIndicator = 0x01;
 
-absl::StatusOr<std::string> FormatPrivateToken(
+}  // namespace
+
+absl::Status MasqueOhttpClient::Config::PerRequestConfig::AddHeaders(
+    const std::vector<std::string>& headers) {
+  for (const std::string& header : headers) {
+    std::vector<absl::string_view> header_split =
+        absl::StrSplit(header, absl::MaxSplits(':', 1));
+    if (header_split.size() != 2) {
+      return absl::InvalidArgumentError(
+          absl::StrCat("Failed to parse header \"", header, "\""));
+    }
+    std::string header_name = std::string(header_split[0]);
+    absl::StripAsciiWhitespace(&header_name);
+    absl::AsciiStrToLower(&header_name);
+    std::string header_value = std::string(header_split[1]);
+    absl::StripAsciiWhitespace(&header_value);
+    headers_.push_back({std::move(header_name), std::move(header_value)});
+  }
+  return absl::OkStatus();
+}
+
+absl::Status MasqueOhttpClient::Config::PerRequestConfig::AddPrivateToken(
     const std::string& private_token) {
   // Private tokens require padded base64url and we allow any encoding for
   // convenience, so we need to unescape and re-escape.
@@ -73,11 +95,11 @@
   }
   formatted_token = absl::Base64Escape(formatted_token);
   absl::StrReplaceAll({{"+", "-"}, {"/", "_"}}, &formatted_token);
-  return absl::StrCat("PrivateToken token=\"", formatted_token, "\"");
+  headers_.push_back({"authorization", absl::StrCat("PrivateToken token=\"",
+                                                    formatted_token, "\"")});
+  return absl::OkStatus();
 }
 
-}  // namespace
-
 absl::Status MasqueOhttpClient::Config::ConfigureKeyFetchClientCert(
     const std::string& client_cert_file,
     const std::string& client_cert_key_file) {
@@ -337,12 +359,6 @@
   control_data.path = url.PathParamsQuery();
   std::string encrypted_data;
   PendingRequest pending_request(per_request_config);
-  std::string formatted_token;
-  if (!per_request_config.private_token().empty()) {
-    QUICHE_ASSIGN_OR_RETURN(
-        formatted_token,
-        FormatPrivateToken(per_request_config.private_token()));
-  }
   if (!ohttp_client_.has_value()) {
     QUICHE_LOG(FATAL) << "Cannot send OHTTP request without OHTTP client";
     return absl::InternalError(
@@ -358,8 +374,9 @@
     QUICHE_ASSIGN_OR_RETURN(encoded_data,
                             encoder.EncodeControlData(control_data));
     std::vector<quiche::BinaryHttpMessage::FieldView> headers;
-    if (!formatted_token.empty()) {
-      headers.push_back({"authorization", formatted_token});
+    for (const std::pair<std::string, std::string>& header :
+         per_request_config.headers()) {
+      headers.push_back({header.first, header.second});
     }
     QUICHE_ASSIGN_OR_RETURN(std::string encoded_headers,
                             encoder.EncodeHeaders(absl::MakeSpan(headers)));
@@ -392,10 +409,11 @@
     encoded_data += encoded_trailers;
   } else {
     BinaryHttpRequest binary_request(control_data);
-    binary_request.set_body(post_data);
-    if (!formatted_token.empty()) {
-      binary_request.AddHeaderField({"authorization", formatted_token});
+    for (const std::pair<std::string, std::string>& header :
+         per_request_config.headers()) {
+      binary_request.AddHeaderField({header.first, header.second});
     }
+    binary_request.set_body(post_data);
     QUICHE_ASSIGN_OR_RETURN(encoded_data, binary_request.Serialize());
   }
   if (pending_request.per_request_config.use_chunked_ohttp()) {
diff --git a/quiche/quic/masque/masque_ohttp_client.h b/quiche/quic/masque/masque_ohttp_client.h
index 5d0754b..c7908d5 100644
--- a/quiche/quic/masque/masque_ohttp_client.h
+++ b/quiche/quic/masque/masque_ohttp_client.h
@@ -46,6 +46,8 @@
       PerRequestConfig& operator=(PerRequestConfig&& other) = default;
 
       void SetPostData(const std::string& post_data) { post_data_ = post_data; }
+      absl::Status AddHeaders(const std::vector<std::string>& headers);
+      absl::Status AddPrivateToken(const std::string& private_token);
       void SetUseChunkedOhttp(bool use_chunked_ohttp) {
         use_chunked_ohttp_ = use_chunked_ohttp;
       }
@@ -62,9 +64,6 @@
       void SetExpectedEncapsulatedStatusCode(uint16_t status_code) {
         expected_encapsulated_status_code_ = status_code;
       }
-      void SetPrivateToken(const std::string& private_token) {
-        private_token_ = private_token;
-      }
       void SetExpectedEncapsulatedResponseBody(
           const std::string& expected_encapsulated_response_body) {
         expected_encapsulated_response_body_ =
@@ -73,7 +72,9 @@
 
       std::string url() const { return url_; }
       std::string post_data() const { return post_data_; }
-      std::string private_token() const { return private_token_; }
+      const std::vector<std::pair<std::string, std::string>>& headers() const {
+        return headers_;
+      }
       bool use_chunked_ohttp() const { return use_chunked_ohttp_; }
       std::optional<bool> use_indeterminate_length() const {
         return use_indeterminate_length_;
@@ -94,7 +95,7 @@
      private:
       std::string url_;
       std::string post_data_;
-      std::string private_token_;
+      std::vector<std::pair<std::string, std::string>> headers_;
       bool use_chunked_ohttp_ = false;
       std::optional<bool> use_indeterminate_length_;
       std::optional<std::string> expected_gateway_error_;
diff --git a/quiche/quic/masque/masque_ohttp_client_bin.cc b/quiche/quic/masque/masque_ohttp_client_bin.cc
index 8e7979b..a6440a2 100644
--- a/quiche/quic/masque/masque_ohttp_client_bin.cc
+++ b/quiche/quic/masque/masque_ohttp_client_bin.cc
@@ -58,6 +58,11 @@
     "file.");
 
 DEFINE_QUICHE_COMMAND_LINE_FLAG(
+    std::vector<std::string>, header, {},
+    "Adds a header field to the encapsulated binary request. Separate the "
+    "header name and value with a colon. Can be specified multiple times.");
+
+DEFINE_QUICHE_COMMAND_LINE_FLAG(
     std::string, private_token, "",
     "When set, the client will attach a base64-encoded private token to the "
     "encapsulated request. Accepts any base64 encoding.");
@@ -125,6 +130,8 @@
     }
     post_data = *post_data_from_file;
   }
+  std::vector<std::string> headers =
+      quiche::GetQuicheCommandLineFlag(FLAGS_header);
   std::string private_token =
       quiche::GetQuicheCommandLineFlag(FLAGS_private_token);
 
@@ -144,9 +151,12 @@
   for (size_t i = 2; i < urls.size(); ++i) {
     MasqueOhttpClient::Config::PerRequestConfig per_request_config(urls[i]);
     per_request_config.SetPostData(post_data);
+    QUICHE_RETURN_IF_ERROR(per_request_config.AddHeaders(headers));
+    if (!private_token.empty()) {
+      QUICHE_RETURN_IF_ERROR(per_request_config.AddPrivateToken(private_token));
+    }
     per_request_config.SetUseChunkedOhttp(use_chunked_ohttp);
     per_request_config.SetUseIndeterminateLength(indeterminate_length);
-    per_request_config.SetPrivateToken(private_token);
     if (expect_gateway_error.has_value()) {
       per_request_config.SetExpectedGatewayError(*expect_gateway_error);
     }