Prevent timeout in OHTTP test client when the stream gets reset PiperOrigin-RevId: 866890021
diff --git a/quiche/quic/masque/masque_connection_pool.cc b/quiche/quic/masque/masque_connection_pool.cc index b6bb8ef..ebf8944 100644 --- a/quiche/quic/masque/masque_connection_pool.cc +++ b/quiche/quic/masque/masque_connection_pool.cc
@@ -197,6 +197,22 @@ } } +void MasqueConnectionPool::OnStreamFailure(MasqueH2Connection* connection, + int32_t stream_id, + absl::Status error) { + for (auto it = pending_requests_.begin(); it != pending_requests_.end();) { + RequestId request_id = it->first; + PendingRequest& pending_request = *it->second; + if (pending_request.connection == connection && + pending_request.stream_id == stream_id) { + pending_requests_.erase(it++); + visitor_->OnPoolResponse(this, request_id, error); + break; + } + ++it; + } +} + absl::StatusOr<MasqueConnectionPool::RequestId> MasqueConnectionPool::SendRequest(const Message& request, bool mtls) { auto authority = request.headers.find(":authority");
diff --git a/quiche/quic/masque/masque_connection_pool.h b/quiche/quic/masque/masque_connection_pool.h index 5ac335d..f1a2931 100644 --- a/quiche/quic/masque/masque_connection_pool.h +++ b/quiche/quic/masque/masque_connection_pool.h
@@ -107,6 +107,8 @@ void OnResponse(MasqueH2Connection* connection, int32_t stream_id, const quiche::HttpHeaderBlock& headers, const std::string& body) override; + void OnStreamFailure(MasqueH2Connection* connection, int32_t stream_id, + absl::Status error) override; static absl::StatusOr<bssl::UniquePtr<SSL_CTX>> CreateSslCtx( const std::string& client_cert_file,
diff --git a/quiche/quic/masque/masque_h2_connection.cc b/quiche/quic/masque/masque_h2_connection.cc index ba38888..d7b037d 100644 --- a/quiche/quic/masque/masque_h2_connection.cc +++ b/quiche/quic/masque/masque_h2_connection.cc
@@ -409,12 +409,15 @@ << " body length: " << stream->received_body.size(); QUICHE_DVLOG(2) << ENDPOINT << "Body: " << std::endl << quiche::QuicheTextUtils::HexDump(stream->received_body); - if (is_server_) { - visitor_->OnRequest(this, stream_id, stream->received_headers, - stream->received_body); - } else { - visitor_->OnResponse(this, stream_id, stream->received_headers, - stream->received_body); + if (!stream->callback_fired) { + stream->callback_fired = true; + if (is_server_) { + visitor_->OnRequest(this, stream_id, stream->received_headers, + stream->received_body); + } else { + visitor_->OnResponse(this, stream_id, stream->received_headers, + stream->received_body); + } } return true; } @@ -424,6 +427,17 @@ QUICHE_LOG(INFO) << ENDPOINT << "Stream " << stream_id << " reset with error code " << Http2ErrorCodeToString(error_code); + auto it = h2_streams_.find(stream_id); + if (it != h2_streams_.end()) { + if (!it->second->callback_fired) { + it->second->callback_fired = true; + visitor_->OnStreamFailure( + this, stream_id, + absl::InvalidArgumentError( + absl::StrCat("Stream ", stream_id, " reset with error code ", + Http2ErrorCodeToString(error_code)))); + } + } } bool MasqueH2Connection::OnCloseStream(Http2StreamId stream_id, @@ -431,7 +445,17 @@ QUICHE_LOG(INFO) << ENDPOINT << "Stream " << stream_id << " closed with error code " << Http2ErrorCodeToString(error_code); - h2_streams_.erase(stream_id); + auto it = h2_streams_.find(stream_id); + if (it != h2_streams_.end()) { + if (!it->second->callback_fired) { + visitor_->OnStreamFailure( + this, stream_id, + absl::InternalError( + absl::StrCat("Stream ", stream_id, " closed with error code ", + Http2ErrorCodeToString(error_code)))); + } + h2_streams_.erase(it); + } return true; }
diff --git a/quiche/quic/masque/masque_h2_connection.h b/quiche/quic/masque/masque_h2_connection.h index bd16403..ac4ee61 100644 --- a/quiche/quic/masque/masque_h2_connection.h +++ b/quiche/quic/masque/masque_h2_connection.h
@@ -47,6 +47,8 @@ virtual void OnResponse(MasqueH2Connection* connection, int32_t stream_id, const quiche::HttpHeaderBlock& headers, const std::string& body) = 0; + virtual void OnStreamFailure(MasqueH2Connection* connection, + int32_t stream_id, absl::Status error) = 0; }; // `ssl` and `visitor` must outlive this object. @@ -74,6 +76,7 @@ quiche::HttpHeaderBlock received_headers; std::string received_body; std::string body_to_send; + bool callback_fired = false; }; static constexpr size_t kBioBufferSize = 16384; void Abort(absl::Status error);
diff --git a/quiche/quic/masque/masque_tcp_client_bin.cc b/quiche/quic/masque/masque_tcp_client_bin.cc index 8ecf710..cdcb074 100644 --- a/quiche/quic/masque/masque_tcp_client_bin.cc +++ b/quiche/quic/masque/masque_tcp_client_bin.cc
@@ -385,6 +385,19 @@ done_ = true; } + void OnStreamFailure(MasqueH2Connection* connection, int32_t stream_id, + absl::Status error) override { + if (connection != h2_connection_.get()) { + QUICHE_LOG(FATAL) << "Unexpected connection"; + } + if (stream_id != stream_id_) { + QUICHE_LOG(FATAL) << "Unexpected stream id"; + } + QUICHE_LOG(ERROR) << "Stream " << stream_id + << " failed: " << error.message(); + done_ = true; + } + private: void MaybeSendRequest() { if (request_sent_ || done_ || !tls_connected_) {
diff --git a/quiche/quic/masque/masque_tcp_server_bin.cc b/quiche/quic/masque/masque_tcp_server_bin.cc index 9307e17..a7c42a8 100644 --- a/quiche/quic/masque/masque_tcp_server_bin.cc +++ b/quiche/quic/masque/masque_tcp_server_bin.cc
@@ -687,6 +687,11 @@ QUICHE_LOG(FATAL) << "Server cannot receive responses"; } + void OnStreamFailure(MasqueH2Connection* /*connection*/, int32_t stream_id, + absl::Status error) override { + QUICHE_LOG(ERROR) << "Stream " << stream_id << " failed: " << error; + } + // From MasqueConnectionPool::Visitor. void OnPoolResponse(MasqueConnectionPool* /*pool*/, RequestId request_id, absl::StatusOr<Message>&& response) override {