Prevent timeout in OHTTP test client when the stream gets reset

PiperOrigin-RevId: 866890021
diff --git a/quiche/quic/masque/masque_connection_pool.cc b/quiche/quic/masque/masque_connection_pool.cc
index b6bb8ef..ebf8944 100644
--- a/quiche/quic/masque/masque_connection_pool.cc
+++ b/quiche/quic/masque/masque_connection_pool.cc
@@ -197,6 +197,22 @@
   }
 }
 
+void MasqueConnectionPool::OnStreamFailure(MasqueH2Connection* connection,
+                                           int32_t stream_id,
+                                           absl::Status error) {
+  for (auto it = pending_requests_.begin(); it != pending_requests_.end();) {
+    RequestId request_id = it->first;
+    PendingRequest& pending_request = *it->second;
+    if (pending_request.connection == connection &&
+        pending_request.stream_id == stream_id) {
+      pending_requests_.erase(it++);
+      visitor_->OnPoolResponse(this, request_id, error);
+      break;
+    }
+    ++it;
+  }
+}
+
 absl::StatusOr<MasqueConnectionPool::RequestId>
 MasqueConnectionPool::SendRequest(const Message& request, bool mtls) {
   auto authority = request.headers.find(":authority");
diff --git a/quiche/quic/masque/masque_connection_pool.h b/quiche/quic/masque/masque_connection_pool.h
index 5ac335d..f1a2931 100644
--- a/quiche/quic/masque/masque_connection_pool.h
+++ b/quiche/quic/masque/masque_connection_pool.h
@@ -107,6 +107,8 @@
   void OnResponse(MasqueH2Connection* connection, int32_t stream_id,
                   const quiche::HttpHeaderBlock& headers,
                   const std::string& body) override;
+  void OnStreamFailure(MasqueH2Connection* connection, int32_t stream_id,
+                       absl::Status error) override;
 
   static absl::StatusOr<bssl::UniquePtr<SSL_CTX>> CreateSslCtx(
       const std::string& client_cert_file,
diff --git a/quiche/quic/masque/masque_h2_connection.cc b/quiche/quic/masque/masque_h2_connection.cc
index ba38888..d7b037d 100644
--- a/quiche/quic/masque/masque_h2_connection.cc
+++ b/quiche/quic/masque/masque_h2_connection.cc
@@ -409,12 +409,15 @@
                    << " body length: " << stream->received_body.size();
   QUICHE_DVLOG(2) << ENDPOINT << "Body: " << std::endl
                   << quiche::QuicheTextUtils::HexDump(stream->received_body);
-  if (is_server_) {
-    visitor_->OnRequest(this, stream_id, stream->received_headers,
-                        stream->received_body);
-  } else {
-    visitor_->OnResponse(this, stream_id, stream->received_headers,
-                         stream->received_body);
+  if (!stream->callback_fired) {
+    stream->callback_fired = true;
+    if (is_server_) {
+      visitor_->OnRequest(this, stream_id, stream->received_headers,
+                          stream->received_body);
+    } else {
+      visitor_->OnResponse(this, stream_id, stream->received_headers,
+                           stream->received_body);
+    }
   }
   return true;
 }
@@ -424,6 +427,17 @@
   QUICHE_LOG(INFO) << ENDPOINT << "Stream " << stream_id
                    << " reset with error code "
                    << Http2ErrorCodeToString(error_code);
+  auto it = h2_streams_.find(stream_id);
+  if (it != h2_streams_.end()) {
+    if (!it->second->callback_fired) {
+      it->second->callback_fired = true;
+      visitor_->OnStreamFailure(
+          this, stream_id,
+          absl::InvalidArgumentError(
+              absl::StrCat("Stream ", stream_id, " reset with error code ",
+                           Http2ErrorCodeToString(error_code))));
+    }
+  }
 }
 
 bool MasqueH2Connection::OnCloseStream(Http2StreamId stream_id,
@@ -431,7 +445,17 @@
   QUICHE_LOG(INFO) << ENDPOINT << "Stream " << stream_id
                    << " closed with error code "
                    << Http2ErrorCodeToString(error_code);
-  h2_streams_.erase(stream_id);
+  auto it = h2_streams_.find(stream_id);
+  if (it != h2_streams_.end()) {
+    if (!it->second->callback_fired) {
+      visitor_->OnStreamFailure(
+          this, stream_id,
+          absl::InternalError(
+              absl::StrCat("Stream ", stream_id, " closed with error code ",
+                           Http2ErrorCodeToString(error_code))));
+    }
+    h2_streams_.erase(it);
+  }
   return true;
 }
 
diff --git a/quiche/quic/masque/masque_h2_connection.h b/quiche/quic/masque/masque_h2_connection.h
index bd16403..ac4ee61 100644
--- a/quiche/quic/masque/masque_h2_connection.h
+++ b/quiche/quic/masque/masque_h2_connection.h
@@ -47,6 +47,8 @@
     virtual void OnResponse(MasqueH2Connection* connection, int32_t stream_id,
                             const quiche::HttpHeaderBlock& headers,
                             const std::string& body) = 0;
+    virtual void OnStreamFailure(MasqueH2Connection* connection,
+                                 int32_t stream_id, absl::Status error) = 0;
   };
 
   // `ssl` and `visitor` must outlive this object.
@@ -74,6 +76,7 @@
     quiche::HttpHeaderBlock received_headers;
     std::string received_body;
     std::string body_to_send;
+    bool callback_fired = false;
   };
   static constexpr size_t kBioBufferSize = 16384;
   void Abort(absl::Status error);
diff --git a/quiche/quic/masque/masque_tcp_client_bin.cc b/quiche/quic/masque/masque_tcp_client_bin.cc
index 8ecf710..cdcb074 100644
--- a/quiche/quic/masque/masque_tcp_client_bin.cc
+++ b/quiche/quic/masque/masque_tcp_client_bin.cc
@@ -385,6 +385,19 @@
     done_ = true;
   }
 
+  void OnStreamFailure(MasqueH2Connection* connection, int32_t stream_id,
+                       absl::Status error) override {
+    if (connection != h2_connection_.get()) {
+      QUICHE_LOG(FATAL) << "Unexpected connection";
+    }
+    if (stream_id != stream_id_) {
+      QUICHE_LOG(FATAL) << "Unexpected stream id";
+    }
+    QUICHE_LOG(ERROR) << "Stream " << stream_id
+                      << " failed: " << error.message();
+    done_ = true;
+  }
+
  private:
   void MaybeSendRequest() {
     if (request_sent_ || done_ || !tls_connected_) {
diff --git a/quiche/quic/masque/masque_tcp_server_bin.cc b/quiche/quic/masque/masque_tcp_server_bin.cc
index 9307e17..a7c42a8 100644
--- a/quiche/quic/masque/masque_tcp_server_bin.cc
+++ b/quiche/quic/masque/masque_tcp_server_bin.cc
@@ -687,6 +687,11 @@
     QUICHE_LOG(FATAL) << "Server cannot receive responses";
   }
 
+  void OnStreamFailure(MasqueH2Connection* /*connection*/, int32_t stream_id,
+                       absl::Status error) override {
+    QUICHE_LOG(ERROR) << "Stream " << stream_id << " failed: " << error;
+  }
+
   // From MasqueConnectionPool::Visitor.
   void OnPoolResponse(MasqueConnectionPool* /*pool*/, RequestId request_id,
                       absl::StatusOr<Message>&& response) override {