Add OHTTP client and gateway support to MASQUE TCP test tools

Right now the gateway generates a fake response to allow testing, but in the future the gateway will forward that request to the right origin.

This CL also fixes a bug in MasqueConnectionPool when handling multiple requests to the same origin.

PiperOrigin-RevId: 735888511
diff --git a/quiche/quic/masque/masque_connection_pool.cc b/quiche/quic/masque/masque_connection_pool.cc
index 2463438..5413c27 100644
--- a/quiche/quic/masque/masque_connection_pool.cc
+++ b/quiche/quic/masque/masque_connection_pool.cc
@@ -99,8 +99,17 @@
     return absl::InternalError(
         absl::StrCat("Failed to create connection to ", authority->second));
   }
-  RequestId request_id = ++next_request_id_;
   auto pending_request = std::make_unique<PendingRequest>();
+  if (connection->connection() != nullptr) {
+    pending_request->connection = connection->connection();
+    pending_request->stream_id =
+        connection->connection()->SendRequest(request.headers, request.body);
+    if (pending_request->stream_id < 0) {
+      return absl::InternalError(
+          absl::StrCat("Failed to send request to ", authority->second));
+    }
+  }
+  RequestId request_id = ++next_request_id_;
   pending_request->request.headers = request.headers.Clone();
   pending_request->request.body = request.body;
   pending_requests_.insert({request_id, std::move(pending_request)});
diff --git a/quiche/quic/masque/masque_connection_pool.h b/quiche/quic/masque/masque_connection_pool.h
index 3b6042e..7b61bea 100644
--- a/quiche/quic/masque/masque_connection_pool.h
+++ b/quiche/quic/masque/masque_connection_pool.h
@@ -78,6 +78,8 @@
     void OnSocketEvent(QuicEventLoop *event_loop, SocketFd fd,
                        QuicSocketEventMask events) override;
 
+    MasqueH2Connection *connection() { return connection_.get(); }
+
    private:
     static enum ssl_verify_result_t VerifyCallback(SSL *ssl,
                                                    uint8_t *out_alert);
diff --git a/quiche/quic/masque/masque_ohttp_client_bin.cc b/quiche/quic/masque/masque_ohttp_client_bin.cc
index 41e8b48..92b7d56 100644
--- a/quiche/quic/masque/masque_ohttp_client_bin.cc
+++ b/quiche/quic/masque/masque_ohttp_client_bin.cc
@@ -4,15 +4,19 @@
 
 #include <stdbool.h>
 
+#include <cstddef>
 #include <memory>
 #include <optional>
 #include <ostream>
 #include <string>
+#include <utility>
 #include <vector>
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "openssl/base.h"
 #include "quiche/quic/core/io/quic_default_event_loop.h"
 #include "quiche/quic/core/io/quic_event_loop.h"
@@ -20,11 +24,12 @@
 #include "quiche/quic/core/quic_time.h"
 #include "quiche/quic/masque/masque_connection_pool.h"
 #include "quiche/quic/tools/quic_url.h"
-#include "quiche/common/http/http_header_block.h"
+#include "quiche/binary_http/binary_http_message.h"
 #include "quiche/common/platform/api/quiche_command_line_flags.h"
 #include "quiche/common/platform/api/quiche_logging.h"
 #include "quiche/common/platform/api/quiche_system_event_loop.h"
 #include "quiche/oblivious_http/common/oblivious_http_header_key_config.h"
+#include "quiche/oblivious_http/oblivious_http_client.h"
 
 DEFINE_QUICHE_COMMAND_LINE_FLAG(
     bool, disable_certificate_verification, false,
@@ -41,7 +46,13 @@
     std::string, client_cert_key_file, "",
     "Path to the pkcs8 client certificate private key.");
 
+using quiche::BinaryHttpRequest;
+using quiche::BinaryHttpResponse;
+using quiche::ObliviousHttpClient;
+using quiche::ObliviousHttpHeaderKeyConfig;
 using quiche::ObliviousHttpKeyConfigs;
+using quiche::ObliviousHttpRequest;
+using quiche::ObliviousHttpResponse;
 
 namespace quic {
 namespace {
@@ -70,7 +81,16 @@
     }
     return true;
   }
-  bool IsDone() { return done_; }
+  bool IsDone() {
+    if (aborted_) {
+      return true;
+    }
+    if (!ohttp_client_.has_value()) {
+      // Key fetch request is still pending.
+      return false;
+    }
+    return pending_ohttp_requests_.empty();
+  }
 
   // From MasqueConnectionPool::Visitor.
   void OnResponse(MasqueConnectionPool * /*pool*/, RequestId request_id,
@@ -79,6 +99,42 @@
         *key_fetch_request_id_ == request_id) {
       key_fetch_request_id_ = std::nullopt;
       HandleKeyResponse(response);
+    } else {
+      auto it = pending_ohttp_requests_.find(request_id);
+      if (it == pending_ohttp_requests_.end()) {
+        QUICHE_LOG(ERROR) << "Received unexpected response for unknown request "
+                          << request_id;
+        Abort();
+        return;
+      }
+      if (response.ok()) {
+        if (!ohttp_client_.has_value()) {
+          QUICHE_LOG(FATAL) << "Received OHTTP response without OHTTP client";
+          return;
+        }
+        absl::StatusOr<ObliviousHttpResponse> ohttp_response =
+            ohttp_client_->DecryptObliviousHttpResponse(response->body,
+                                                        it->second);
+        if (ohttp_response.ok()) {
+          QUICHE_LOG(INFO) << "Received OHTTP response for " << request_id;
+          absl::StatusOr<BinaryHttpResponse> binary_response =
+              BinaryHttpResponse::Create(ohttp_response->GetPlaintextData());
+          if (binary_response.ok()) {
+            QUICHE_LOG(INFO) << "Successfully decoded OHTTP response: "
+                             << binary_response->body();
+          } else {
+            QUICHE_LOG(ERROR) << "Failed to parse binary response: "
+                              << binary_response.status();
+          }
+        } else {
+          QUICHE_LOG(ERROR) << "Failed to decrypt OHTTP response: "
+                            << ohttp_response.status();
+        }
+      } else {
+        QUICHE_LOG(ERROR) << "OHTTP request " << request_id
+                          << " failed: " << response.status();
+      }
+      pending_ohttp_requests_.erase(it);
     }
   }
 
@@ -128,24 +184,113 @@
     QUICHE_LOG(INFO) << "Successfully got " << key_configs->NumKeys()
                      << " OHTTP keys: " << std::endl
                      << key_configs->DebugString();
-    // TODO(dschinazi): Use the keys to send requests.
-    Abort();
+    if (urls_.size() <= 2) {
+      QUICHE_LOG(INFO) << "No OHTTP URLs to request, exiting.";
+      Abort();
+      return;
+    }
+    relay_url_ = QuicUrl(urls_[1], "https");
+    if (relay_url_.host().empty() && !absl::StrContains(urls_[1], "://")) {
+      relay_url_ = QuicUrl(absl::StrCat("https://", urls_[1]));
+    }
+    QUICHE_LOG(INFO) << "Using relay URL: " << relay_url_.ToString();
+    ObliviousHttpHeaderKeyConfig key_config = key_configs->PreferredConfig();
+    absl::StatusOr<absl::string_view> public_key =
+        key_configs->GetPublicKeyForId(key_config.GetKeyId());
+    if (!public_key.ok()) {
+      QUICHE_LOG(ERROR) << "Failed to get public key for key ID "
+                        << static_cast<int>(key_config.GetKeyId()) << ": "
+                        << public_key.status();
+      Abort();
+      return;
+    }
+    absl::StatusOr<ObliviousHttpClient> ohttp_client =
+        ObliviousHttpClient::Create(*public_key, key_config);
+    if (!ohttp_client.ok()) {
+      QUICHE_LOG(ERROR) << "Failed to create OHTTP client: "
+                        << ohttp_client.status();
+      Abort();
+      return;
+    }
+    ohttp_client_.emplace(std::move(*ohttp_client));
+    for (size_t i = 2; i < urls_.size(); ++i) {
+      SendOhttpRequestForUrl(urls_[i]);
+    }
   }
 
-  void Abort() { done_ = true; }
+  void SendOhttpRequestForUrl(const std::string &url_string) {
+    QuicUrl url(url_string, "https");
+    if (url.host().empty() && !absl::StrContains(url_string, "://")) {
+      url = QuicUrl(absl::StrCat("https://", url_string));
+    }
+    if (url.host().empty()) {
+      QUICHE_LOG(ERROR) << "Failed to parse key URL \"" << url_string << "\"";
+      return;
+    }
+    BinaryHttpRequest::ControlData control_data;
+    control_data.method = "GET";
+    control_data.scheme = url.scheme();
+    control_data.authority = url.HostPort();
+    control_data.path = url.path();
+    BinaryHttpRequest binary_request(control_data);
+    absl::StatusOr<std::string> encoded_request = binary_request.Serialize();
+    if (!encoded_request.ok()) {
+      QUICHE_LOG(ERROR) << "Failed to encode request: "
+                        << encoded_request.status();
+      return;
+    }
+    if (!ohttp_client_.has_value()) {
+      QUICHE_LOG(FATAL) << "Cannot send OHTTP request without OHTTP client";
+      return;
+    }
+    absl::StatusOr<ObliviousHttpRequest> ohttp_request =
+        ohttp_client_->CreateObliviousHttpRequest(*encoded_request);
+    if (!ohttp_request.ok()) {
+      QUICHE_LOG(ERROR) << "Failed to create OHTTP request: "
+                        << ohttp_request.status();
+      return;
+    }
+    Message request;
+    request.headers[":method"] = "POST";
+    request.headers[":scheme"] = relay_url_.scheme();
+    request.headers[":authority"] = relay_url_.HostPort();
+    request.headers[":path"] = relay_url_.path();
+    request.headers["host"] = relay_url_.HostPort();
+    request.headers["content-type"] = "message/ohttp-req";
+    request.body = ohttp_request->EncapsulateAndSerialize();
+    absl::StatusOr<RequestId> request_id =
+        connection_pool_.SendRequest(request);
+    if (!request_id.ok()) {
+      QUICHE_LOG(ERROR) << "Failed to send request: " << request_id.status();
+      return;
+    }
+    QUICHE_LOG(INFO) << "Sent OHTTP request for " << url_string;
+    auto context = std::move(*ohttp_request).ReleaseContext();
+    pending_ohttp_requests_.insert({*request_id, std::move(context)});
+  }
+
+  void Abort() {
+    QUICHE_LOG(INFO) << "Aborting";
+    aborted_ = true;
+  }
 
   std::vector<std::string> urls_;
   MasqueConnectionPool connection_pool_;
   std::optional<RequestId> key_fetch_request_id_;
-  bool done_ = false;
+  bool aborted_ = false;
+  std::optional<ObliviousHttpClient> ohttp_client_;
+  QuicUrl relay_url_;
+  absl::flat_hash_map<RequestId, ObliviousHttpRequest::Context>
+      pending_ohttp_requests_;
 };
 
 int RunMasqueOhttpClient(int argc, char *argv[]) {
-  const char *usage = "Usage: masque_ohttp_client <url>";
+  const char *usage =
+      "Usage: masque_ohttp_client <key-url> <relay-url> <url>...";
   std::vector<std::string> urls =
       quiche::QuicheParseCommandLineFlags(usage, argc, argv);
 
-  quiche::QuicheSystemEventLoop system_event_loop("masque_client");
+  quiche::QuicheSystemEventLoop system_event_loop("masque_ohttp_client");
   const bool disable_certificate_verification =
       quiche::GetQuicheCommandLineFlag(FLAGS_disable_certificate_verification);
 
diff --git a/quiche/quic/masque/masque_tcp_server_bin.cc b/quiche/quic/masque/masque_tcp_server_bin.cc
index d443a96..1401413 100644
--- a/quiche/quic/masque/masque_tcp_server_bin.cc
+++ b/quiche/quic/masque/masque_tcp_server_bin.cc
@@ -14,6 +14,7 @@
 #include <cstddef>
 #include <cstdint>
 #include <memory>
+#include <optional>
 #include <ostream>
 #include <string>
 #include <utility>
@@ -22,6 +23,7 @@
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/escaping.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "openssl/base.h"
 #include "openssl/bio.h"
@@ -34,6 +36,7 @@
 #include "quiche/quic/core/quic_default_clock.h"
 #include "quiche/quic/core/quic_time.h"
 #include "quiche/quic/masque/masque_h2_connection.h"
+#include "quiche/binary_http/binary_http_message.h"
 #include "quiche/common/http/http_header_block.h"
 #include "quiche/common/platform/api/quiche_command_line_flags.h"
 #include "quiche/common/platform/api/quiche_logging.h"
@@ -42,6 +45,7 @@
 #include "quiche/common/quiche_ip_address_family.h"
 #include "quiche/common/quiche_socket_address.h"
 #include "quiche/oblivious_http/common/oblivious_http_header_key_config.h"
+#include "quiche/oblivious_http/oblivious_http_gateway.h"
 
 DEFINE_QUICHE_COMMAND_LINE_FLAG(int32_t, port, 9661,
                                 "The port the MASQUE server will listen on.");
@@ -59,8 +63,13 @@
     std::string, ohttp_key, "",
     "Hex-encoded bytes of the OHTTP HPKE private key.");
 
+using quiche::BinaryHttpRequest;
+using quiche::BinaryHttpResponse;
+using quiche::ObliviousHttpGateway;
 using quiche::ObliviousHttpHeaderKeyConfig;
 using quiche::ObliviousHttpKeyConfigs;
+using quiche::ObliviousHttpRequest;
+using quiche::ObliviousHttpResponse;
 
 namespace quic {
 
@@ -144,6 +153,73 @@
       return false;
     }
     concatenated_keys_ = *concatenated_keys;
+    absl::StatusOr<ObliviousHttpGateway> ohttp_gateway =
+        ObliviousHttpGateway::Create(hpke_private_key_,
+                                     *ohttp_header_key_config);
+    if (!ohttp_gateway.ok()) {
+      QUICHE_LOG(ERROR) << "Failed to create OHTTP gateway: "
+                        << ohttp_gateway.status();
+      return false;
+    }
+    ohttp_gateway_.emplace(std::move(*ohttp_gateway));
+    return true;
+  }
+
+  bool HandleRequest(MasqueH2Connection *connection, int32_t stream_id,
+                     const std::string &encapsulated_request) {
+    if (!ohttp_gateway_.has_value()) {
+      QUICHE_LOG(ERROR) << "Not ready to handle OHTTP request";
+      return false;
+    }
+    absl::StatusOr<ObliviousHttpRequest> decrypted_request =
+        ohttp_gateway_->DecryptObliviousHttpRequest(encapsulated_request);
+    if (!decrypted_request.ok()) {
+      QUICHE_LOG(ERROR) << "Failed to decrypt OHTTP request: "
+                        << decrypted_request.status();
+      return false;
+    }
+    absl::StatusOr<BinaryHttpRequest> binary_request =
+        BinaryHttpRequest::Create(decrypted_request->GetPlaintextData());
+    if (!binary_request.ok()) {
+      QUICHE_LOG(ERROR) << "Failed to parse binary request: "
+                        << binary_request.status();
+      return false;
+    }
+    const BinaryHttpRequest::ControlData &control_data =
+        binary_request->control_data();
+    // TODO(dschinazi): Send the decapsulated request to the authority instead
+    // of replying with a fake local response.
+    std::string response_body = absl::StrCat(
+        "OHTTP Response! Request method: ", control_data.method,
+        " scheme: ", control_data.scheme, " path: ", control_data.path,
+        " authority: ", control_data.authority);
+
+    BinaryHttpResponse binary_response(/*status_code=*/200);
+    binary_response.swap_body(response_body);
+    absl::StatusOr<std::string> encoded_response = binary_response.Serialize();
+    if (!encoded_response.ok()) {
+      QUICHE_LOG(ERROR) << "Failed to encode response: "
+                        << encoded_response.status();
+      return false;
+    }
+
+    auto context = std::move(*decrypted_request).ReleaseContext();
+    absl::StatusOr<ObliviousHttpResponse> ohttp_response =
+        ohttp_gateway_->CreateObliviousHttpResponse(*encoded_response, context);
+    if (!ohttp_response.ok()) {
+      QUICHE_LOG(ERROR) << "Failed to create OHTTP response: "
+                        << ohttp_response.status();
+      return false;
+    }
+    std::string encapsulated_response =
+        ohttp_response->EncapsulateAndSerialize();
+    QUICHE_LOG(INFO) << "Sending OHTTP response";
+
+    quiche::HttpHeaderBlock response_headers;
+    response_headers[":status"] = "200";
+    response_headers["content-type"] = "message/ohttp-res";
+    connection->SendResponse(stream_id, response_headers,
+                             encapsulated_response);
     return true;
   }
 
@@ -155,6 +231,7 @@
   const EVP_HPKE_KEM *kem_ = EVP_hpke_x25519_hkdf_sha256();
   bssl::UniquePtr<EVP_HPKE_KEY> hpke_key_;
   std::string concatenated_keys_;
+  std::optional<ObliviousHttpGateway> ohttp_gateway_;
 };
 
 static int SelectAlpnCallback(SSL * /*ssl*/, const uint8_t **out,
@@ -240,9 +317,9 @@
 class MasqueTcpServer : public QuicSocketEventListener,
                         public MasqueH2Connection::Visitor {
  public:
-  explicit MasqueTcpServer(MasqueOhttpGateway *ohttp_gateway)
+  explicit MasqueTcpServer(MasqueOhttpGateway *masque_ohttp_gateway)
       : event_loop_(GetDefaultEventLoop()->Create(QuicDefaultClock::Get())),
-        ohttp_gateway_(ohttp_gateway) {}
+        masque_ohttp_gateway_(masque_ohttp_gateway) {}
 
   MasqueTcpServer(const MasqueTcpServer &) = delete;
   MasqueTcpServer(MasqueTcpServer &&) = delete;
@@ -371,9 +448,15 @@
         connections_.end());
   }
 
+  bool HandleOhttpRequest(MasqueH2Connection *connection, int32_t stream_id,
+                          const std::string &encapsulated_request) {
+    return masque_ohttp_gateway_->HandleRequest(connection, stream_id,
+                                                encapsulated_request);
+  }
+
   void OnRequest(MasqueH2Connection *connection, int32_t stream_id,
                  const quiche::HttpHeaderBlock &headers,
-                 const std::string & /*body*/) override {
+                 const std::string &body) override {
     quiche::HttpHeaderBlock response_headers;
     std::string response_body;
     auto path_pair = headers.find(":path");
@@ -389,7 +472,16 @@
                content_type_pair->second == "application/ohttp-keys") {
       response_headers[":status"] = "200";
       response_headers["content-type"] = "application/ohttp-keys";
-      response_body = ohttp_gateway_->concatenated_keys();
+      response_body = masque_ohttp_gateway_->concatenated_keys();
+    } else if (method_pair->second == "POST" &&
+               content_type_pair != headers.end() &&
+               content_type_pair->second == "message/ohttp-req") {
+      if (HandleOhttpRequest(connection, stream_id, body)) {
+        return;
+      } else {
+        response_headers[":status"] = "500";
+        response_body = "Failed to handle OHTTP request";
+      }
     } else if (method_pair->second == "GET" && path_pair->second == "/") {
       response_headers[":status"] = "200";
       response_body = "<h1>This is a response body</h1>";
@@ -434,7 +526,7 @@
 
   std::unique_ptr<QuicEventLoop> event_loop_;
   bssl::UniquePtr<SSL_CTX> ctx_;
-  MasqueOhttpGateway *ohttp_gateway_;  // Unowned.
+  MasqueOhttpGateway *masque_ohttp_gateway_;  // Unowned.
   SocketFd server_socket_ = kInvalidSocketFd;
   std::vector<std::unique_ptr<MasqueH2SocketConnection>> connections_;
 };
@@ -464,13 +556,14 @@
 
   quiche::QuicheSystemEventLoop system_event_loop("masque_tcp_server");
 
-  MasqueOhttpGateway ohttp_gateway;
-  if (!ohttp_gateway.Setup(quiche::GetQuicheCommandLineFlag(FLAGS_ohttp_key))) {
+  MasqueOhttpGateway masque_ohttp_gateway;
+  if (!masque_ohttp_gateway.Setup(
+          quiche::GetQuicheCommandLineFlag(FLAGS_ohttp_key))) {
     QUICHE_LOG(ERROR) << "Failed to setup OHTTP";
     return 1;
   }
 
-  MasqueTcpServer server(&ohttp_gateway);
+  MasqueTcpServer server(&masque_ohttp_gateway);
   if (!server.SetupSslCtx(certificate_file, key_file, client_root_ca_file)) {
     QUICHE_LOG(ERROR) << "Failed to setup SSL context";
     return 1;