Add support for sending OHTTP request bodies in multiple chunks. This change modifies MasqueConnectionPool and MasqueH2Connection to allow sending request bodies in two parts after the initial headers. MasqueOhttpClient is updated to utilize this for chunked OHTTP requests, splitting the encrypted OHTTP request into two parts sent via separate calls. PiperOrigin-RevId: 902529170
diff --git a/quiche/quic/masque/masque_connection_pool.cc b/quiche/quic/masque/masque_connection_pool.cc index b44d8da..24e86d1 100644 --- a/quiche/quic/masque/masque_connection_pool.cc +++ b/quiche/quic/masque/masque_connection_pool.cc
@@ -174,18 +174,25 @@ void MasqueConnectionPool::OnResponse(MasqueH2Connection* connection, int32_t stream_id, const quiche::HttpHeaderBlock& headers, - const std::string& body) { + const std::string& body, + bool end_stream) { bool found = false; for (auto it = pending_requests_.begin(); it != pending_requests_.end();) { RequestId request_id = it->first; PendingRequest& pending_request = *it->second; if (pending_request.connection == connection && pending_request.stream_id == stream_id) { - pending_requests_.erase(it++); Message response; response.headers = headers.Clone(); response.body = body; - visitor_->OnPoolResponse(this, request_id, std::move(response)); + if (end_stream) { + pending_request.response_done = true; + } + if (end_stream && pending_request.end_stream_pending) { + pending_requests_.erase(it); + } + visitor_->OnPoolResponse(this, request_id, std::move(response), + end_stream); found = true; break; } @@ -197,6 +204,28 @@ } } +void MasqueConnectionPool::OnDataForStream(MasqueH2Connection* connection, + int32_t stream_id, + absl::string_view data, + bool end_stream) { + for (auto it = pending_requests_.begin(); it != pending_requests_.end(); + ++it) { + RequestId request_id = it->first; + PendingRequest& pending_request = *it->second; + if (pending_request.connection == connection && + pending_request.stream_id == stream_id) { + if (end_stream) { + pending_request.response_done = true; + } + if (end_stream && pending_request.end_stream_pending) { + pending_requests_.erase(it); + } + visitor_->OnPoolData(this, request_id, data, end_stream); + break; + } + } +} + void MasqueConnectionPool::OnStreamFailure(MasqueH2Connection* connection, int32_t stream_id, absl::Status error) { @@ -206,7 +235,7 @@ if (pending_request.connection == connection && pending_request.stream_id == stream_id) { pending_requests_.erase(it++); - visitor_->OnPoolResponse(this, request_id, error); + visitor_->OnPoolResponse(this, request_id, error, /*end_stream=*/true); break; } ++it; @@ -214,7 +243,8 @@ } absl::StatusOr<MasqueConnectionPool::RequestId> -MasqueConnectionPool::SendRequest(const Message& request, bool mtls) { +MasqueConnectionPool::SendRequest(const Message& request, bool mtls, + bool end_stream, bool stream_response) { auto authority = request.headers.find(":authority"); if (authority == request.headers.end()) { return absl::InvalidArgumentError("Request missing :authority header"); @@ -228,8 +258,8 @@ << connection->connection()->info() << " to " << authority->second; pending_request->connection = connection->connection(); - pending_request->stream_id = - connection->connection()->SendRequest(request.headers, request.body); + pending_request->stream_id = connection->connection()->SendRequest( + request.headers, request.body, end_stream, stream_response); if (pending_request->stream_id < 0) { return absl::InternalError( absl::StrCat("Failed to send request to ", authority->second)); @@ -241,10 +271,42 @@ RequestId request_id = ++next_request_id_; pending_request->request.headers = request.headers.Clone(); pending_request->request.body = request.body; + pending_request->stream_response = stream_response; + pending_request->end_stream_pending = end_stream; pending_requests_.insert({request_id, std::move(pending_request)}); return request_id; } +absl::Status MasqueConnectionPool::SendBodyChunk(RequestId request_id, + const std::string& body, + bool end_stream) { + auto it = pending_requests_.find(request_id); + if (it == pending_requests_.end()) { + return absl::InternalError( + absl::StrCat("SendBodyChunk called for unknown request ", request_id)); + } + PendingRequest& pending_request = *it->second; + if (pending_request.connection == nullptr || pending_request.stream_id < 0) { + // Connection not ready yet, append to pending body chunks. + pending_request.pending_data += body; + pending_request.end_stream_pending = end_stream; + return absl::OkStatus(); + } + if (pending_request.end_stream_pending) { + return absl::FailedPreconditionError(absl::StrCat( + "SendBodyChunk called when end_stream already pending ", request_id)); + } + pending_request.end_stream_pending = end_stream; + pending_request.connection->SendBodyChunk(pending_request.stream_id, body, + end_stream); + pending_request.connection->AttemptToSend(); + + if (end_stream && pending_request.response_done) { + pending_requests_.erase(it); + } + return absl::OkStatus(); +} + absl::StatusOr<MasqueConnectionPool::ConnectionState*> MasqueConnectionPool::GetOrCreateConnectionState(const std::string& authority, bool mtls) { @@ -288,19 +350,32 @@ ++it; continue; } - QUICHE_LOG(INFO) << "Sending pending request ID " << request_id - << " on connection " << connection->info(); - int32_t stream_id = connection->SendRequest(pending_request.request.headers, - pending_request.request.body); + bool is_request_complete = pending_request.pending_data.empty() && + pending_request.end_stream_pending; + QUICHE_LOG(INFO) << "Sending pending request ID " << request_id; + int32_t stream_id = connection->SendRequest( + pending_request.request.headers, pending_request.request.body, + is_request_complete, pending_request.stream_response); if (stream_id < 0) { QUICHE_LOG(ERROR) << "Failed to send request ID " << request_id << " on connection " << connection->info(); visitor_->OnPoolResponse(this, request_id, - absl::InternalError("Failed to send request")); + absl::InternalError("Failed to send request"), + /*end_stream=*/true); pending_requests_.erase(it++); continue; } - connection->AttemptToSend(); + + if (!pending_request.pending_data.empty()) { + connection->AttemptToSend(); + } + + if (!is_request_complete) { + connection->SendBodyChunk(stream_id, pending_request.pending_data, + pending_request.end_stream_pending); + } + pending_request.pending_data.clear(); + pending_request.stream_id = stream_id; ++it; } @@ -315,7 +390,7 @@ ++it; continue; } - visitor_->OnPoolResponse(this, request_id, error); + visitor_->OnPoolResponse(this, request_id, error, /*end_stream=*/true); pending_requests_.erase(it++); } }
diff --git a/quiche/quic/masque/masque_connection_pool.h b/quiche/quic/masque/masque_connection_pool.h index f1a2931..0c89733 100644 --- a/quiche/quic/masque/masque_connection_pool.h +++ b/quiche/quic/masque/masque_connection_pool.h
@@ -79,14 +79,23 @@ virtual ~Visitor() = default; virtual void OnPoolResponse(MasqueConnectionPool* pool, RequestId request_id, - absl::StatusOr<Message>&& response) = 0; + absl::StatusOr<Message>&& response, + bool end_stream) = 0; + virtual void OnPoolData(MasqueConnectionPool* pool, RequestId request_id, + absl::string_view data, bool end_stream) = 0; }; // If the request fails immediately, the error will be returned. Otherwise, a // request ID will be returned and the result (the response or an error) will // be delivered later with that same request ID via Visitor::OnResponse. absl::StatusOr<RequestId> SendRequest(const Message& request, - bool mtls = false); + bool mtls = false, + bool end_stream = true, + bool stream_response = false); + + // Sends a body chunk for an existing request. + absl::Status SendBodyChunk(RequestId request_id, const std::string& body, + bool end_stream); // `event_loop`, `ssl_ctx`, and `visitor` must outlive this object. explicit MasqueConnectionPool(QuicEventLoop* event_loop, SSL_CTX* ssl_ctx, @@ -106,9 +115,11 @@ const std::string& body) override; void OnResponse(MasqueH2Connection* connection, int32_t stream_id, const quiche::HttpHeaderBlock& headers, - const std::string& body) override; + const std::string& body, bool end_stream) override; void OnStreamFailure(MasqueH2Connection* connection, int32_t stream_id, absl::Status error) override; + void OnDataForStream(MasqueH2Connection* connection, int32_t stream_id, + absl::string_view data, bool end_stream) override; static absl::StatusOr<bssl::UniquePtr<SSL_CTX>> CreateSslCtx( const std::string& client_cert_file, @@ -147,8 +158,12 @@ }; struct PendingRequest { Message request; + bool end_stream_pending = true; + std::string pending_data; MasqueH2Connection* connection = nullptr; // Not owned. int32_t stream_id = -1; + bool response_done = false; + bool stream_response = false; }; absl::StatusOr<MasqueConnectionPool::ConnectionState*>
diff --git a/quiche/quic/masque/masque_h2_connection.cc b/quiche/quic/masque/masque_h2_connection.cc index 85f8eae..7ead699 100644 --- a/quiche/quic/masque/masque_h2_connection.cc +++ b/quiche/quic/masque/masque_h2_connection.cc
@@ -245,7 +245,7 @@ info.end_data = false; } else { info.payload_length = stream->body_to_send.size(); - info.end_data = true; + info.end_data = stream->end_stream_pending; } info.end_stream = info.end_data; return info; @@ -324,6 +324,10 @@ MasqueH2Stream* stream = GetOrCreateH2Stream(stream_id); QUICHE_LOG(INFO) << ENDPOINT << "OnEndHeadersForStream " << stream_id << " headers: " << stream->received_headers.DebugString(); + if (stream->response_is_streamed) { + visitor_->OnResponse(this, stream_id, stream->received_headers, + /*body=*/"", /*end_stream=*/false); + } return true; } @@ -351,7 +355,8 @@ } int32_t MasqueH2Connection::SendRequest(const quiche::HttpHeaderBlock& headers, - const std::string& body) { + const std::string& body, + bool end_stream, bool stream_response) { if (is_server_) { QUICHE_LOG(FATAL) << ENDPOINT << "Server cannot send requests"; } @@ -361,9 +366,9 @@ return -1; } std::vector<Header> h2_headers = ConvertHeaders(headers); - int32_t stream_id = - h2_adapter_->SubmitRequest(h2_headers, /*end_stream=*/body.empty(), - /*user_data=*/nullptr); + int32_t stream_id = h2_adapter_->SubmitRequest( + h2_headers, /*end_stream=*/end_stream && body.empty(), + /*user_data=*/nullptr); if (stream_id < 0) { Abort(absl::InvalidArgumentError( absl::StrCat("Failed to submit request with body of length ", @@ -377,10 +382,42 @@ QUICHE_DVLOG(2) << ENDPOINT << "Body to be sent:" << std::endl << quiche::QuicheTextUtils::HexDump(body); } - GetOrCreateH2Stream(stream_id)->body_to_send = body; + MasqueH2Stream* stream = GetOrCreateH2Stream(stream_id); + stream->response_is_streamed = stream_response; + stream->body_to_send = body; + stream->end_stream_pending = end_stream; + if (!body.empty() || !end_stream) { + h2_adapter_->ResumeStream(stream_id); + } return stream_id; } +void MasqueH2Connection::SendBodyChunk(int32_t stream_id, + const std::string& body, + bool end_stream) { + QUICHE_LOG(INFO) << ENDPOINT << "SendBodyChunk stream_id: " << stream_id + << " body length: " << body.size() + << " end_stream: " << end_stream; + QUICHE_LOG(INFO) << ENDPOINT << "SendBodyChunk Connection window: " + << h2_adapter_->GetSendWindowSize() << ", Stream window: " + << h2_adapter_->GetStreamSendWindowSize(stream_id); + MasqueH2Stream* stream = GetOrCreateH2Stream(stream_id); + if (stream->end_stream_pending) { + QUICHE_LOG(DFATAL) << ENDPOINT + << "SendBodyChunk called when end_stream already " + "pending for stream " + << stream_id; + return; + } + stream->body_to_send.append(body); + stream->end_stream_pending = end_stream; + h2_adapter_->ResumeStream(stream_id); + QUICHE_LOG(INFO) << ENDPOINT << "Sending body on stream ID " << stream_id + << " with body of length " << body.size() + << ", end_stream: " << end_stream; + AttemptToSend(); +} + std::vector<Header> MasqueH2Connection::ConvertHeaders( const quiche::HttpHeaderBlock& headers) { std::vector<Header> h2_headers; @@ -407,10 +444,16 @@ bool MasqueH2Connection::OnDataForStream(Http2StreamId stream_id, absl::string_view data) { - GetOrCreateH2Stream(stream_id)->received_body.append(data); - QUICHE_DVLOG(1) << ENDPOINT << "OnDataForStream " << stream_id - << " new data length: " << data.size() << " total length: " - << GetOrCreateH2Stream(stream_id)->received_body.size(); + MasqueH2Stream* stream = GetOrCreateH2Stream(stream_id); + if (!stream->response_is_streamed) { + stream->received_body.append(data); + } + QUICHE_LOG(INFO) << ENDPOINT << "OnDataForStream " << stream_id + << " new data length: " << data.size() + << " total length: " << stream->received_body.size(); + if (stream->response_is_streamed) { + visitor_->OnDataForStream(this, stream_id, data, /*end_stream=*/false); + } return true; } @@ -426,8 +469,13 @@ visitor_->OnRequest(this, stream_id, stream->received_headers, stream->received_body); } else { - visitor_->OnResponse(this, stream_id, stream->received_headers, - stream->received_body); + if (stream->response_is_streamed) { + visitor_->OnDataForStream(this, stream_id, /*data=*/"", + /*end_stream=*/true); + } else { + visitor_->OnResponse(this, stream_id, stream->received_headers, + stream->received_body, /*end_stream=*/true); + } } } return true;
diff --git a/quiche/quic/masque/masque_h2_connection.h b/quiche/quic/masque/masque_h2_connection.h index d71dae5..4058e9c 100644 --- a/quiche/quic/masque/masque_h2_connection.h +++ b/quiche/quic/masque/masque_h2_connection.h
@@ -46,9 +46,12 @@ const std::string& body) = 0; virtual void OnResponse(MasqueH2Connection* connection, int32_t stream_id, const quiche::HttpHeaderBlock& headers, - const std::string& body) = 0; + const std::string& body, bool end_stream) = 0; virtual void OnStreamFailure(MasqueH2Connection* connection, int32_t stream_id, absl::Status error) = 0; + virtual void OnDataForStream(MasqueH2Connection* connection, + int32_t stream_id, absl::string_view data, + bool end_stream) = 0; }; // `ssl` and `visitor` must outlive this object. @@ -67,7 +70,11 @@ // Call when there is more data to be written to SSL. bool AttemptToSend(); int32_t SendRequest(const quiche::HttpHeaderBlock& headers, - const std::string& body); + const std::string& body, bool end_stream = true, + bool stream_response = false); + // Enqueues a body chunk or fin for the given stream. + void SendBodyChunk(int32_t stream_id, const std::string& body, + bool end_stream); void SendResponse(int32_t stream_id, const quiche::HttpHeaderBlock& headers, const std::string& body); @@ -79,7 +86,9 @@ quiche::HttpHeaderBlock received_headers; std::string received_body; std::string body_to_send; + bool end_stream_pending = true; bool callback_fired = false; + bool response_is_streamed = false; }; static constexpr size_t kBioBufferSize = 16384; void Abort(absl::Status error);
diff --git a/quiche/quic/masque/masque_ohttp_client.cc b/quiche/quic/masque/masque_ohttp_client.cc index e409be7..ccaa86a 100644 --- a/quiche/quic/masque/masque_ohttp_client.cc +++ b/quiche/quic/masque/masque_ohttp_client.cc
@@ -4,6 +4,7 @@ #include "quiche/quic/masque/masque_ohttp_client.h" +#include <functional> #include <iostream> #include <memory> #include <optional> @@ -99,6 +100,102 @@ return chunks; } +class PingPongResponseVisitor : public MasqueOhttpClient::ResponseVisitor { + public: + explicit PingPongResponseVisitor(std::vector<std::string> chunks) + : chunks_(std::move(chunks)) {} + + void OnRequestStarted(quic::MasqueConnectionPool::RequestId request_id, + MasqueOhttpClient* client) override { + request_id_ = request_id; + client_ = client; + bool is_final = (chunks_.size() <= 1); + status_ = client_->SendBodyChunk(request_id_, chunks_[0], is_final); + current_chunk_idx_ = 1; + if (is_final) { + done_ = true; + } + } + + void OnResponseChunk(quic::MasqueConnectionPool::RequestId request_id, + absl::string_view) override { + if (request_id != request_id_ || current_chunk_idx_ >= chunks_.size()) { + return; + } + bool is_final = (current_chunk_idx_ == chunks_.size() - 1); + status_ = client_->SendBodyChunk(request_id_, chunks_[current_chunk_idx_], + is_final); + current_chunk_idx_++; + if (is_final) { + done_ = true; + } + } + + void OnResponseDone(quic::MasqueConnectionPool::RequestId request_id, + const MasqueOhttpClient::Message&) override { + if (request_id == request_id_) done_ = true; + } + + void OnError(quic::MasqueConnectionPool::RequestId request_id, + absl::Status status) override { + if (request_id == request_id_) { + status_ = status; + done_ = true; + } + } + + bool done() const { return done_; } + absl::Status status() const { return status_; } + + private: + MasqueOhttpClient* client_ = nullptr; + std::vector<std::string> chunks_; + size_t current_chunk_idx_ = 0; + quic::MasqueConnectionPool::RequestId request_id_ = 0; + bool done_ = false; + absl::Status status_ = absl::OkStatus(); +}; + +absl::StatusOr<std::unique_ptr<PingPongResponseVisitor>> +CreateVisitorIfPingPong(const MasqueOhttpClient::Config& config) { + const MasqueOhttpClient::Config::PerRequestConfig* ping_pong_config = nullptr; + + for (const auto& req_config : config.per_request_configs()) { + if (req_config.ping_pong_mode()) { + ping_pong_config = &req_config; + break; + } + } + + if (ping_pong_config == nullptr) { + return nullptr; + } + + if (config.per_request_configs().size() > 1) { + return absl::InvalidArgumentError( + "PingPong mode is exclusive and supports only one request at a time"); + } + + if (ping_pong_config->num_ohttp_chunks() <= 0) { + return absl::InvalidArgumentError( + "num_ohttp_chunks must be set and greater than 0 when " + "ping_pong_mode is enabled"); + } + + std::string post_data = ping_pong_config->post_data(); + QUICHE_ASSIGN_OR_RETURN( + std::vector<absl::string_view> chunks, + SplitIntoChunks(post_data, ping_pong_config->num_ohttp_chunks())); + + std::vector<std::string> string_chunks; + string_chunks.reserve(chunks.size()); + for (absl::string_view chunk : chunks) { + string_chunks.push_back(std::string(chunk)); + } + + return std::make_unique<PingPongResponseVisitor>(std::move(string_chunks)); +} + } // namespace std::string MasqueOhttpClient::Config::PerRequestConfig::method() const { @@ -200,12 +297,24 @@ if (config.ohttp_ssl_ctx() == nullptr) { QUICHE_RETURN_IF_ERROR(config.ConfigureOhttpMtls("", "")); } + QUICHE_ASSIGN_OR_RETURN( + std::unique_ptr<PingPongResponseVisitor> ping_pong_visitor, + CreateVisitorIfPingPong(config)); + quiche::QuicheSystemEventLoop system_event_loop("masque_ohttp_client"); std::unique_ptr<QuicEventLoop> event_loop = GetDefaultEventLoop()->Create(QuicDefaultClock::Get()); MasqueOhttpClient ohttp_client(std::move(config), event_loop.get()); + + if (ping_pong_visitor) { + ohttp_client.set_response_visitor(ping_pong_visitor.get()); + } + QUICHE_RETURN_IF_ERROR(ohttp_client.Start()); while (!ohttp_client.IsDone()) { + if (ping_pong_visitor && ping_pong_visitor->done()) { + break; + } ohttp_client.connection_pool_.event_loop()->RunEventLoopOnce( quic::QuicTime::Delta::FromMilliseconds(50)); } @@ -428,35 +537,38 @@ num_bhttp_chunks = (per_request_config.num_ohttp_chunks() > 0 ? 1 : 0); } if (num_bhttp_chunks > 0) { - BinaryHttpRequest::IndeterminateLengthEncoder encoder; - - QUICHE_ASSIGN_OR_RETURN(encoded_data, - encoder.EncodeControlData(control_data)); + pending_request.encoder.emplace(); + QUICHE_ASSIGN_OR_RETURN( + encoded_data, pending_request.encoder->EncodeControlData(control_data)); std::vector<quiche::BinaryHttpMessage::FieldView> headers; for (const std::pair<std::string, std::string>& header : per_request_config.headers()) { headers.push_back({header.first, header.second}); } - QUICHE_ASSIGN_OR_RETURN(std::string encoded_headers, - encoder.EncodeHeaders(absl::MakeSpan(headers))); - encoded_data += encoded_headers; - if (!post_data.empty()) { - QUICHE_ASSIGN_OR_RETURN(std::vector<absl::string_view> body_chunks, - SplitIntoChunks(post_data, num_bhttp_chunks)); - QUICHE_ASSIGN_OR_RETURN( - std::string encoded_body, - encoder.EncodeBodyChunks(absl::MakeSpan(body_chunks), - /*body_chunks_done=*/false)); - encoded_data += encoded_body; - } QUICHE_ASSIGN_OR_RETURN( - std::string encoded_final_chunk, - encoder.EncodeBodyChunks({}, /*body_chunks_done=*/true)); - encoded_data += encoded_final_chunk; - std::vector<quiche::BinaryHttpMessage::FieldView> trailers; - QUICHE_ASSIGN_OR_RETURN(std::string encoded_trailers, - encoder.EncodeTrailers(absl::MakeSpan(trailers))); - encoded_data += encoded_trailers; + std::string encoded_headers, + pending_request.encoder->EncodeHeaders(absl::MakeSpan(headers))); + encoded_data += encoded_headers; + if (!per_request_config.ping_pong_mode()) { + if (!post_data.empty()) { + QUICHE_ASSIGN_OR_RETURN(std::vector<absl::string_view> body_chunks, + SplitIntoChunks(post_data, num_bhttp_chunks)); + QUICHE_ASSIGN_OR_RETURN(std::string encoded_body, + pending_request.encoder->EncodeBodyChunks( + absl::MakeSpan(body_chunks), + /*body_chunks_done=*/false)); + encoded_data += encoded_body; + } + QUICHE_ASSIGN_OR_RETURN(std::string encoded_final_chunk, + pending_request.encoder->EncodeBodyChunks( + {}, /*body_chunks_done=*/true)); + encoded_data += encoded_final_chunk; + std::vector<quiche::BinaryHttpMessage::FieldView> trailers; + QUICHE_ASSIGN_OR_RETURN( + std::string encoded_trailers, + pending_request.encoder->EncodeTrailers(absl::MakeSpan(trailers))); + encoded_data += encoded_trailers; + } } else { BinaryHttpRequest binary_request(control_data); for (const std::pair<std::string, std::string>& header : @@ -466,7 +578,7 @@ binary_request.set_body(post_data); QUICHE_ASSIGN_OR_RETURN(encoded_data, binary_request.Serialize()); } - int num_ohttp_chunks = pending_request.per_request_config.num_ohttp_chunks(); + int num_ohttp_chunks = per_request_config.num_ohttp_chunks(); if (num_ohttp_chunks > 0) { pending_request.chunk_handler = std::make_unique<ChunkHandler>(); QUICHE_ASSIGN_OR_RETURN( @@ -478,7 +590,8 @@ SplitIntoChunks(encoded_data, num_ohttp_chunks)); for (size_t i = 0; i < ohttp_chunks.size(); i++) { - bool is_final_chunk = (i == ohttp_chunks.size() - 1); + bool is_final_chunk = (i == ohttp_chunks.size() - 1 && + !per_request_config.ping_pong_mode()); QUICHE_ASSIGN_OR_RETURN( std::string ohttp_chunk, chunked_client.EncryptRequestChunk(ohttp_chunks[i], is_final_chunk)); @@ -507,15 +620,32 @@ QUICHE_VLOG(1) << "Sending encrypted request: " << absl::BytesToHexString(encrypted_data); request.body = encrypted_data; - absl::StatusOr<RequestId> request_id = - connection_pool_.SendRequest(request, /*mtls=*/true); + bool end_stream = + num_ohttp_chunks <= 0 || !per_request_config.ping_pong_mode(); + bool stream_response = num_ohttp_chunks > 0; + absl::StatusOr<RequestId> request_id = connection_pool_.SendRequest( + request, /*mtls=*/true, end_stream, stream_response); if (!request_id.ok()) { return absl::InternalError(absl::StrCat("Failed to send request: ", request_id.status().message())); } + QUICHE_LOG(INFO) << "Sent OHTTP request for " << per_request_config.url(); + if (pending_request.chunk_handler) { + pending_request.chunk_handler->SetResponseChunkCallback( + [this, request_id = *request_id](absl::string_view chunk) { + if (response_visitor_) { + response_visitor_->OnResponseChunk(request_id, chunk); + } + }); + } pending_ohttp_requests_.insert({*request_id, std::move(pending_request)}); + + if (response_visitor_) { + response_visitor_->OnRequestStarted(*request_id, this); + } + return absl::OkStatus(); } @@ -553,6 +683,12 @@ return encapsulated_response; } ChunkHandler chunk_handler; + chunk_handler.SetResponseChunkCallback( + [this, request_id](absl::string_view chunk) { + if (response_visitor_) { + response_visitor_->OnResponseChunk(request_id, chunk); + } + }); QUICHE_RETURN_IF_ERROR( chunk_handler.OnDecryptedChunk(ohttp_response.GetPlaintextData())); QUICHE_RETURN_IF_ERROR(chunk_handler.OnChunksDone()); @@ -567,14 +703,18 @@ } absl::Status MasqueOhttpClient::ProcessOhttpResponse( - RequestId request_id, const absl::StatusOr<Message>& response) { + RequestId request_id, const absl::StatusOr<Message>& response, + bool end_stream) { auto it = pending_ohttp_requests_.find(request_id); if (it == pending_ohttp_requests_.end()) { return absl::InternalError(absl::StrCat( "Received unexpected response for unknown request ", request_id)); } - auto cleanup = - absl::MakeCleanup([this, it]() { pending_ohttp_requests_.erase(it); }); + auto cleanup = absl::MakeCleanup([this, it, end_stream]() { + if (end_stream) { + pending_ohttp_requests_.erase(it); + } + }); if (!response.ok()) { if (it->second.per_request_config.expected_gateway_error().has_value() && absl::StrContains( @@ -620,13 +760,28 @@ // If we expect a failure status code, skip decapsulation. return absl::OkStatus(); } + + if (!end_stream) { + // We delivered headers. For chunked OHTTP, we just wait for data chunks. + if (it->second.per_request_config.num_ohttp_chunks() <= 0) { + QUICHE_LOG(ERROR) << "Received partial response for non-chunked OHTTP"; + return absl::InternalError( + "Received partial response for non-chunked OHTTP"); + } + return absl::OkStatus(); + } + std::optional<Message> encapsulated_response; QUICHE_VLOG(2) << "Received encrypted response body: " << absl::BytesToHexString(response->body); if (it->second.per_request_config.num_ohttp_chunks() > 0) { - QUICHE_ASSIGN_OR_RETURN( - encapsulated_response, - it->second.chunk_handler->DecryptFullResponse(response->body)); + absl::Status decrypt_status = it->second.chunk_handler->DecryptChunk( + /*encrypted_chunk=*/"", /*end_stream=*/true); + if (!decrypt_status.ok()) { + return decrypt_status; + } + encapsulated_response = + std::move(*it->second.chunk_handler).ExtractResponse(); } else { if (!it->second.context.has_value()) { QUICHE_LOG(FATAL) << "Received OHTTP response without OHTTP context"; @@ -663,12 +818,17 @@ callback) { QUICHE_RETURN_IF_ERROR(callback(encapsulated_response->body)); } + + if (response_visitor_) { + response_visitor_->OnResponseDone(request_id, *encapsulated_response); + } return absl::OkStatus(); } void MasqueOhttpClient::OnPoolResponse(MasqueConnectionPool* /*pool*/, RequestId request_id, - absl::StatusOr<Message>&& response) { + absl::StatusOr<Message>&& response, + bool end_stream) { if (key_fetch_request_id_.has_value() && *key_fetch_request_id_ == request_id) { absl::Status status = HandleKeyResponse(response); @@ -676,29 +836,138 @@ Abort(status); } } else { - absl::Status status = ProcessOhttpResponse(request_id, response); + absl::Status status = + ProcessOhttpResponse(request_id, response, end_stream); if (!status.ok()) { Abort(status); + if (response_visitor_) { + response_visitor_->OnError(request_id, status); + } } } } +void MasqueOhttpClient::OnPoolData(MasqueConnectionPool* /*pool*/, + RequestId request_id, absl::string_view data, + bool end_stream) { + if (key_fetch_request_id_.has_value() && + *key_fetch_request_id_ == request_id) { + Abort(absl::InternalError(absl::StrCat( + "Received data for non-streamed key fetch request ", request_id))); + return; + } + auto it = pending_ohttp_requests_.find(request_id); + if (it == pending_ohttp_requests_.end()) { + Abort(absl::InternalError(absl::StrCat( + "Received data for unknown or non-OHTTP request: ", request_id))); + return; + } + PendingRequest& pending_request = it->second; + + if (!pending_request.chunk_handler) { + Abort(absl::InternalError( + absl::StrCat("Received data for non-streamed request ", request_id))); + return; + } + + auto cleanup = absl::MakeCleanup([this, it, end_stream]() { + if (end_stream) { + pending_ohttp_requests_.erase(it); + } + }); + + absl::Status status = + pending_request.chunk_handler->DecryptChunk(data, end_stream); + if (!status.ok()) { + Abort(status); + if (response_visitor_) { + response_visitor_->OnError(request_id, status); + } + return; + } + if (end_stream && response_visitor_) { + response_visitor_->OnResponseDone( + request_id, + std::move(*pending_request.chunk_handler).ExtractResponse()); + } +} + +absl::Status MasqueOhttpClient::SendBodyChunk(RequestId request_id, + absl::string_view chunk, + bool is_final) { + if (chunk.empty() && !is_final) { + return absl::OkStatus(); + } + auto it = pending_ohttp_requests_.find(request_id); + if (it == pending_ohttp_requests_.end()) { + return absl::NotFoundError( + absl::StrCat("Request ", request_id, " not found")); + } + PendingRequest& pending_request = it->second; + if (!pending_request.encoder.has_value()) { + if (!chunk.empty()) { + return absl::FailedPreconditionError( + "Cannot send non-empty body chunks for known-length requests"); + } + std::string encoded_data; + auto& chunked_client = pending_request.chunk_handler->chunked_client(); + if (!chunked_client.has_value()) { + return absl::FailedPreconditionError("Chunked client not initialized"); + } + QUICHE_ASSIGN_OR_RETURN( + std::string encrypted_chunk, + chunked_client->EncryptRequestChunk(encoded_data, is_final)); + return connection_pool_.SendBodyChunk(request_id, encrypted_chunk, + is_final); + } + + std::string encoded_data; + std::vector<absl::string_view> body_chunks; + if (!chunk.empty()) { + body_chunks.push_back(chunk); + } + + QUICHE_ASSIGN_OR_RETURN( + encoded_data, + pending_request.encoder->EncodeBodyChunks(absl::MakeSpan(body_chunks), + /*body_chunks_done=*/is_final)); + + if (is_final) { + std::vector<quiche::BinaryHttpMessage::FieldView> trailers; + QUICHE_ASSIGN_OR_RETURN( + std::string encoded_trailers, + pending_request.encoder->EncodeTrailers(absl::MakeSpan(trailers))); + encoded_data += encoded_trailers; + } + + auto& chunked_client = pending_request.chunk_handler->chunked_client(); + if (!chunked_client.has_value()) { + return absl::FailedPreconditionError("Chunked client not initialized"); + } + + QUICHE_ASSIGN_OR_RETURN( + std::string encrypted_chunk, + chunked_client->EncryptRequestChunk(encoded_data, is_final)); + + return connection_pool_.SendBodyChunk(request_id, encrypted_chunk, is_final); +} + MasqueOhttpClient::ChunkHandler::ChunkHandler() : decoder_(this) {} -absl::StatusOr<Message> MasqueOhttpClient::ChunkHandler::DecryptFullResponse( - absl::string_view encrypted_response) { +absl::Status MasqueOhttpClient::ChunkHandler::DecryptChunk( + absl::string_view encrypted_chunk, bool end_stream) { if (!chunked_client_.has_value()) { - QUICHE_LOG(FATAL) << "DecryptFullResponse called without a chunked client"; - return absl::InternalError( - "DecryptFullResponse called without a chunked client"); + QUICHE_LOG(FATAL) << "DecryptChunk called without a chunked client"; + return absl::InternalError("DecryptChunk called without a chunked client"); } - QUICHE_RETURN_IF_ERROR(chunked_client_->DecryptResponse(encrypted_response, - /*end_stream=*/true)); - return std::move(response_); + return chunked_client_->DecryptResponse(encrypted_chunk, end_stream); } absl::Status MasqueOhttpClient::ChunkHandler::OnDecryptedChunk( absl::string_view decrypted_chunk) { + decrypted_chunk_count_++; + QUICHE_LOG(INFO) << "Received decrypted chunk #" << decrypted_chunk_count_ + << " of size " << decrypted_chunk.size(); absl::StrAppend(&buffered_binary_response_, decrypted_chunk); if (!is_chunked_response_.has_value()) { quiche::QuicheDataReader reader(buffered_binary_response_); @@ -766,11 +1035,21 @@ return absl::OkStatus(); } absl::Status MasqueOhttpClient::ChunkHandler::OnFinalResponseHeadersDone() { + QUICHE_LOG(INFO) << "Received incremental OHTTP response headers: " + << response_.headers.DebugString(); return absl::OkStatus(); } absl::Status MasqueOhttpClient::ChunkHandler::OnBodyChunk( absl::string_view body_chunk) { + body_chunk_count_++; + QUICHE_LOG(INFO) << "Received body chunk #" << body_chunk_count_ + << " of size " << body_chunk.size(); response_.body += body_chunk; + + if (response_chunk_callback_) { + response_chunk_callback_(body_chunk); + } + return absl::OkStatus(); } absl::Status MasqueOhttpClient::ChunkHandler::OnBodyChunksDone() {
diff --git a/quiche/quic/masque/masque_ohttp_client.h b/quiche/quic/masque/masque_ohttp_client.h index 0d7ce97..69669a8 100644 --- a/quiche/quic/masque/masque_ohttp_client.h +++ b/quiche/quic/masque/masque_ohttp_client.h
@@ -55,6 +55,9 @@ absl::Status AddPrivateToken(const std::string& private_token); void SetNumBhttpChunks(int num_chunks) { num_bhttp_chunks_ = num_chunks; } void SetNumOhttpChunks(int num_chunks) { num_ohttp_chunks_ = num_chunks; } + void SetPingPongMode(bool ping_pong_mode) { + ping_pong_mode_ = ping_pong_mode; + } void SetExpectedGatewayError(const std::string& expected_gateway_error) { expected_gateway_error_ = expected_gateway_error; } @@ -81,6 +84,7 @@ } int num_bhttp_chunks() const { return num_bhttp_chunks_; } int num_ohttp_chunks() const { return num_ohttp_chunks_; } + bool ping_pong_mode() const { return ping_pong_mode_; } std::optional<std::string> expected_gateway_error() const { return expected_gateway_error_; } @@ -103,6 +107,7 @@ std::vector<std::pair<std::string, std::string>> outer_headers_; int num_ohttp_chunks_ = 0; int num_bhttp_chunks_ = -1; + bool ping_pong_mode_ = false; std::optional<std::string> expected_gateway_error_; std::optional<uint16_t> expected_gateway_status_code_; std::optional<uint16_t> expected_encapsulated_status_code_; @@ -172,14 +177,36 @@ std::vector<PerRequestConfig> per_request_configs_; }; + class ResponseVisitor { + public: + virtual ~ResponseVisitor() = default; + virtual void OnRequestStarted(RequestId request_id, + MasqueOhttpClient* client) = 0; + virtual void OnResponseChunk(RequestId request_id, + absl::string_view chunk) = 0; + virtual void OnResponseDone(RequestId request_id, + const Message& response) = 0; + virtual void OnError(RequestId request_id, absl::Status status) = 0; + }; + + void set_response_visitor(ResponseVisitor* visitor) { + response_visitor_ = visitor; + } + // Starts by fetching the HPKE keys and then runs the client until all // requests are complete or aborted. static absl::Status Run(Config config); + // Sends a body chunk for a chunked OHTTP request. + absl::Status SendBodyChunk(RequestId request_id, absl::string_view chunk, + bool is_final); + // From quic::MasqueConnectionPool::Visitor. void OnPoolResponse(quic::MasqueConnectionPool* /*pool*/, - RequestId request_id, - absl::StatusOr<Message>&& response) override; + RequestId request_id, absl::StatusOr<Message>&& response, + bool end_stream) override; + void OnPoolData(quic::MasqueConnectionPool* /*pool*/, RequestId request_id, + absl::string_view data, bool end_stream) override; private: // Fetch key from the key URL. @@ -203,17 +230,26 @@ public quiche::BinaryHttpResponse::IndeterminateLengthDecoder:: MessageSectionHandler { public: + using ResponseChunkCallback = std::function<void(absl::string_view)>; + explicit ChunkHandler(); + void SetResponseChunkCallback(ResponseChunkCallback callback) { + response_chunk_callback_ = std::move(callback); + } // Neither copyable nor movable to ensure pointer stability as required for // quiche::ObliviousHttpChunkHandler. ChunkHandler(const ChunkHandler& other) = delete; ChunkHandler& operator=(const ChunkHandler& other) = delete; + + std::optional<quiche::ChunkedObliviousHttpClient>& chunked_client() { + return chunked_client_; + } ChunkHandler(ChunkHandler&& other) = delete; ChunkHandler& operator=(ChunkHandler&& other) = delete; - // Decrypts the full chunked response and returns the encapsulated response. - absl::StatusOr<Message> DecryptFullResponse( - absl::string_view encrypted_response); + // Decrypts a response chunk. + absl::Status DecryptChunk(absl::string_view encrypted_chunk, + bool end_stream); void SetChunkedClient(quiche::ChunkedObliviousHttpClient chunked_client) { chunked_client_.emplace(std::move(chunked_client)); @@ -244,11 +280,14 @@ absl::Status OnTrailersDone() override; private: + ResponseChunkCallback response_chunk_callback_; std::optional<quiche::ChunkedObliviousHttpClient> chunked_client_; quiche::BinaryHttpResponse::IndeterminateLengthDecoder decoder_; Message response_; std::string buffered_binary_response_; std::optional<bool> is_chunked_response_; + size_t decrypted_chunk_count_ = 0; + size_t body_chunk_count_ = 0; }; struct PendingRequest { @@ -262,6 +301,8 @@ // std::unique_ptr to ensure pointer stability since this object is used as // a callback target. std::unique_ptr<ChunkHandler> chunk_handler; + std::optional<quiche::BinaryHttpRequest::IndeterminateLengthEncoder> + encoder; }; explicit MasqueOhttpClient(Config config, quic::QuicEventLoop* event_loop); @@ -276,7 +317,8 @@ RequestId request_id, quiche::ObliviousHttpRequest::Context& context, const Message& response); absl::Status ProcessOhttpResponse(RequestId request_id, - const absl::StatusOr<Message>& response); + const absl::StatusOr<Message>& response, + bool end_stream); absl::Status CheckStatusAndContentType( const Message& response, const std::string& content_type, std::optional<uint16_t> expected_status_code); @@ -288,6 +330,7 @@ std::optional<quiche::ObliviousHttpClient> ohttp_client_; quic::QuicUrl relay_url_; absl::flat_hash_map<RequestId, PendingRequest> pending_ohttp_requests_; + ResponseVisitor* response_visitor_ = nullptr; }; } // namespace quic
diff --git a/quiche/quic/masque/masque_ohttp_client_bin.cc b/quiche/quic/masque/masque_ohttp_client_bin.cc index 071cc51..b82a1b8 100644 --- a/quiche/quic/masque/masque_ohttp_client_bin.cc +++ b/quiche/quic/masque/masque_ohttp_client_bin.cc
@@ -67,6 +67,11 @@ "or if set to 0, the client will use standard non-chunked OHTTP."); DEFINE_QUICHE_COMMAND_LINE_FLAG( + bool, ping_pong_mode, false, + "If true, enables ping-pong mode for chunked OHTTP requests. Limitations: " + "num_ohttp_chunks must be > 0 and there can only be one request."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( std::vector<std::string>, header, {}, "Adds a header field to the encapsulated binary request. Separate the " "header name and value with a colon. Can be specified multiple times."); @@ -131,6 +136,8 @@ quiche::GetQuicheCommandLineFlag(FLAGS_expect_gateway_error); const std::optional<int16_t> expect_gateway_response_code = quiche::GetQuicheCommandLineFlag(FLAGS_expect_gateway_response_code); + const bool ping_pong_mode = + quiche::GetQuicheCommandLineFlag(FLAGS_ping_pong_mode); MasqueConnectionPool::DnsConfig dns_config; QUICHE_RETURN_IF_ERROR(dns_config.SetAddressFamily( @@ -229,6 +236,7 @@ } per_request_config.SetNumOhttpChunks(num_ohttp_chunks); per_request_config.SetNumBhttpChunks(num_bhttp_chunks); + per_request_config.SetPingPongMode(ping_pong_mode); if (expect_gateway_error.has_value()) { per_request_config.SetExpectedGatewayError(*expect_gateway_error); }
diff --git a/quiche/quic/masque/masque_tcp_client_bin.cc b/quiche/quic/masque/masque_tcp_client_bin.cc index cdcb074..97a1ee4 100644 --- a/quiche/quic/masque/masque_tcp_client_bin.cc +++ b/quiche/quic/masque/masque_tcp_client_bin.cc
@@ -373,7 +373,7 @@ void OnResponse(MasqueH2Connection* connection, int32_t stream_id, const quiche::HttpHeaderBlock& headers, - const std::string& body) override { + const std::string& body, bool end_stream) override { if (connection != h2_connection_.get()) { QUICHE_LOG(FATAL) << "Unexpected connection"; } @@ -381,8 +381,11 @@ QUICHE_LOG(FATAL) << "Unexpected stream id"; } QUICHE_LOG(INFO) << "Received h2 response headers: " - << headers.DebugString() << " body: " << body; - done_ = true; + << headers.DebugString() << " body: " << body + << " end_stream: " << end_stream; + if (end_stream) { + done_ = true; + } } void OnStreamFailure(MasqueH2Connection* connection, int32_t stream_id, @@ -398,6 +401,21 @@ done_ = true; } + void OnDataForStream(MasqueH2Connection* connection, int32_t stream_id, + absl::string_view data, bool end_stream) override { + if (connection != h2_connection_.get()) { + QUICHE_LOG(FATAL) << "Unexpected connection"; + } + if (stream_id != stream_id_) { + QUICHE_LOG(FATAL) << "Unexpected stream id"; + } + QUICHE_LOG(INFO) << "Received h2 response data: " << data + << " end_stream: " << end_stream; + if (end_stream) { + done_ = true; + } + } + private: void MaybeSendRequest() { if (request_sent_ || done_ || !tls_connected_) {
diff --git a/quiche/quic/masque/masque_tcp_server_bin.cc b/quiche/quic/masque/masque_tcp_server_bin.cc index a7c42a8..396d404 100644 --- a/quiche/quic/masque/masque_tcp_server_bin.cc +++ b/quiche/quic/masque/masque_tcp_server_bin.cc
@@ -683,7 +683,7 @@ void OnResponse(MasqueH2Connection* /*connection*/, int32_t /*stream_id*/, const quiche::HttpHeaderBlock& /*headers*/, - const std::string& /*body*/) override { + const std::string& /*body*/, bool /*end_stream*/) override { QUICHE_LOG(FATAL) << "Server cannot receive responses"; } @@ -692,9 +692,20 @@ QUICHE_LOG(ERROR) << "Stream " << stream_id << " failed: " << error; } + void OnDataForStream(MasqueH2Connection* /*connection*/, + int32_t /*stream_id*/, absl::string_view /*data*/, + bool /*end_stream*/) override { + QUICHE_LOG(FATAL) << "MasqueTcpServer does not request streamed responses"; + } + // From MasqueConnectionPool::Visitor. void OnPoolResponse(MasqueConnectionPool* /*pool*/, RequestId request_id, - absl::StatusOr<Message>&& response) override { + absl::StatusOr<Message>&& response, + bool end_stream) override { + if (!end_stream) { + QUICHE_LOG(FATAL) + << "MasqueTcpServer does not request streamed responses"; + } auto it = pending_requests_.find(request_id); if (it == pending_requests_.end()) { QUICHE_LOG(ERROR) << "Received unexpected response for unknown request " @@ -739,6 +750,11 @@ pending_request.connection->AttemptToSend(); } + void OnPoolData(MasqueConnectionPool* /*pool*/, RequestId /*request_id*/, + absl::string_view /*data*/, bool /*end_stream*/) override { + QUICHE_LOG(FATAL) << "Server received unexpected pool data"; + } + bool SetupGateway(const std::string& gateway_path, MasqueOhttpGateway* gateway) { if (gateway_path.empty() != (gateway == nullptr)) {