Consolidates stream close behavior in a smaller number of places.

This change also adds a map counting queued frames per stream, so the library does not accidentally close a stream before sending all of the queued frames.

PiperOrigin-RevId: 410097334
diff --git a/http2/adapter/oghttp2_adapter_test.cc b/http2/adapter/oghttp2_adapter_test.cc
index 09f9d0c..78e3f94 100644
--- a/http2/adapter/oghttp2_adapter_test.cc
+++ b/http2/adapter/oghttp2_adapter_test.cc
@@ -1273,7 +1273,6 @@
   EXPECT_CALL(http2_visitor_, OnFrameSent(PRIORITY, 3, _, 0x0, 0));
   EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0));
   EXPECT_CALL(http2_visitor_, OnFrameSent(RST_STREAM, 3, _, 0x0, 0x8));
-  EXPECT_CALL(http2_visitor_, OnCloseStream(3, Http2ErrorCode::NO_ERROR));
   EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(PING, 0, _, 0x0));
   EXPECT_CALL(http2_visitor_, OnFrameSent(PING, 0, _, 0x0, 0));
   EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(GOAWAY, 0, _, 0x0));
@@ -1525,9 +1524,6 @@
   const int64_t result = adapter->ProcessBytes(frames);
   EXPECT_EQ(frames.size(), static_cast<size_t>(result));
 
-  // BUG: OnCloseStream() should be invoked after OnFrameSent() for the response
-  // headers.
-  EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::NO_ERROR));
   int submit_result =
       adapter->SubmitResponse(1, ToHeaders({{":status", "200"}}), nullptr);
   EXPECT_EQ(submit_result, 0);
@@ -1539,6 +1535,7 @@
   EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0));
   EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5));
   EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0));
+  EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::NO_ERROR));
 
   int send_result = adapter->Send();
   EXPECT_EQ(0, send_result);
@@ -1665,7 +1662,6 @@
   EXPECT_FALSE(adapter->want_write());
 
   // The body source has been exhausted by the call to Send() above.
-  EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::NO_ERROR));
   int trailer_result =
       adapter->SubmitTrailer(1, ToHeaders({{":final-status", "a-ok"}}));
   ASSERT_EQ(trailer_result, 0);
@@ -1673,6 +1669,7 @@
 
   EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5));
   EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0));
+  EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::NO_ERROR));
 
   send_result = adapter->Send();
   EXPECT_EQ(0, send_result);
@@ -1777,9 +1774,6 @@
 
   EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5));
   EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0));
-  EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0));
-  EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, 1, _, 0x0, 0));
-  EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::NO_ERROR));
 
   send_result = adapter->Send();
   EXPECT_EQ(0, send_result);
diff --git a/http2/adapter/oghttp2_session.cc b/http2/adapter/oghttp2_session.cc
index d9bfd1d..5ab95a6 100644
--- a/http2/adapter/oghttp2_session.cc
+++ b/http2/adapter/oghttp2_session.cc
@@ -122,7 +122,12 @@
 
 class RunOnExit {
  public:
+  RunOnExit() = default;
   explicit RunOnExit(std::function<void()> f) : f_(std::move(f)) {}
+
+  RunOnExit(RunOnExit&& other) = default;
+  RunOnExit& operator=(RunOnExit&& other) = default;
+
   ~RunOnExit() {
     if (f_) {
       f_();
@@ -393,6 +398,12 @@
       streams_reset_.insert(frame->stream_id());
     }
   }
+  if (frame->stream_id() != 0) {
+    auto result = queued_frames_.insert({frame->stream_id(), 1});
+    if (!result.second) {
+      ++(result.first->second);
+    }
+  }
   frames_.push_back(std::move(frame));
 }
 
@@ -470,14 +481,8 @@
       // Write blocked.
       return SendResult::SEND_BLOCKED;
     } else {
-      visitor_.OnFrameSent(c.frame_type(), c.stream_id(), frame_payload_length,
-                           c.flags(), c.error_code());
-      if (static_cast<FrameType>(c.frame_type()) == FrameType::RST_STREAM) {
-        // If this endpoint is resetting the stream, the stream should be
-        // closed. This endpoint is already aware of the outbound RST_STREAM and
-        // its error code, so close with NO_ERROR.
-        CloseStream(c.stream_id(), Http2ErrorCode::NO_ERROR);
-      }
+      AfterFrameSent(c.frame_type(), c.stream_id(), frame_payload_length,
+                     c.flags(), c.error_code());
 
       frames_.pop_front();
       if (static_cast<size_t>(result) < frame.size()) {
@@ -490,6 +495,24 @@
   return SendResult::SEND_OK;
 }
 
+void OgHttp2Session::AfterFrameSent(uint8_t frame_type, uint32_t stream_id,
+                                    size_t payload_length, uint8_t flags,
+                                    uint32_t error_code) {
+  visitor_.OnFrameSent(frame_type, stream_id, payload_length, flags,
+                       error_code);
+  if (stream_id == 0) {
+    return;
+  }
+  auto iter = queued_frames_.find(stream_id);
+  if (frame_type != 0) {
+    --iter->second;
+  }
+  if (iter->second == 0) {
+    // TODO(birenroy): Consider passing through `error_code` here.
+    CloseStreamIfReady(frame_type, stream_id);
+  }
+}
+
 OgHttp2Session::SendResult OgHttp2Session::WriteForStream(
     Http2StreamId stream_id) {
   auto it = stream_map_.find(stream_id);
@@ -512,7 +535,6 @@
         QUICHE_LOG(ERROR) << "Sent fin; can't send trailers.";
       } else {
         SendTrailers(stream_id, std::move(*block_ptr));
-        MaybeCloseWithRstStream(stream_id, state);
       }
     }
     return SendResult::SEND_OK;
@@ -550,31 +572,29 @@
         connection_can_write = SendResult::SEND_BLOCKED;
         break;
       }
-      visitor_.OnFrameSent(/* DATA */ 0, stream_id, length, fin ? 0x1 : 0x0, 0);
       connection_send_window_ -= length;
       state.send_window -= length;
       available_window = std::min({connection_send_window_, state.send_window,
                                    static_cast<int32_t>(max_frame_payload_)});
+      if (fin) {
+        state.half_closed_local = true;
+      }
+      AfterFrameSent(/* DATA */ 0, stream_id, length, fin ? 0x1 : 0x0, 0);
+      if (!stream_map_.contains(stream_id)) {
+        // Note: the stream may have been closed if `fin` is true.
+        break;
+      }
     }
     if (end_data) {
-      bool sent_trailers = false;
       if (state.trailers != nullptr) {
         auto block_ptr = std::move(state.trailers);
         if (fin) {
           QUICHE_LOG(ERROR) << "Sent fin; can't send trailers.";
         } else {
           SendTrailers(stream_id, std::move(*block_ptr));
-          sent_trailers = true;
         }
       }
       state.outbound_body = nullptr;
-      if (fin || sent_trailers) {
-        state.half_closed_local = true;
-        if (MaybeCloseWithRstStream(stream_id, state)) {
-          // No more work on the stream; it has been closed.
-          break;
-        }
-      }
     }
   }
   // If the stream still exists and has data to send, it should be marked as
@@ -648,11 +668,7 @@
     return -501;  // NGHTTP2_ERR_INVALID_ARGUMENT
   }
   const bool end_stream = data_source == nullptr;
-  if (end_stream) {
-    if (iter->second.half_closed_remote) {
-      CloseStream(stream_id, Http2ErrorCode::NO_ERROR);
-    }
-  } else {
+  if (!end_stream) {
     // Add data source to stream state
     iter->second.outbound_body = std::move(data_source);
     write_scheduler_.MarkStreamReady(stream_id, false);
@@ -682,7 +698,6 @@
   if (state.outbound_body == nullptr) {
     // Enqueue trailers immediately.
     SendTrailers(stream_id, ToHeaderBlock(trailers));
-    MaybeCloseWithRstStream(stream_id, state);
   } else {
     QUICHE_LOG_IF(ERROR, state.outbound_body->send_fin())
         << "DataFrameSource will send fin, preventing trailers!";
@@ -771,6 +786,8 @@
       options_.perspective == Perspective::kClient) {
     // From the client's perspective, the stream can be closed if it's already
     // half_closed_local.
+    // TODO(birenroy): consider whether there are outbound frames queued for the
+    // stream.
     CloseStream(stream_id, Http2ErrorCode::NO_ERROR);
   }
 }
@@ -828,6 +845,8 @@
     return;
   }
   visitor_.OnRstStream(stream_id, TranslateErrorCode(error_code));
+  // TODO(birenroy): Consider whether there are outbound frames queued for the
+  // stream.
   CloseStream(stream_id, TranslateErrorCode(error_code));
 }
 
@@ -1070,23 +1089,6 @@
   EnqueueFrame(std::move(frame));
 }
 
-bool OgHttp2Session::MaybeCloseWithRstStream(Http2StreamId stream_id,
-                                             StreamState& state) {
-  if (options_.perspective == Perspective::kServer) {
-    if (state.half_closed_remote) {
-      CloseStream(stream_id, Http2ErrorCode::NO_ERROR);
-      return true;
-    } else {
-      // Since the peer has not yet ended the stream, this endpoint should
-      // send a RST_STREAM NO_ERROR. See RFC 7540 Section 8.1.
-      EnqueueFrame(absl::make_unique<spdy::SpdyRstStreamIR>(
-          stream_id, spdy::SpdyErrorCode::ERROR_CODE_NO_ERROR));
-      // Sending the RST_STREAM also invokes OnCloseStream.
-    }
-  }
-  return false;
-}
-
 void OgHttp2Session::MarkDataBuffered(Http2StreamId stream_id, size_t bytes) {
   connection_window_manager_.MarkDataBuffered(bytes);
   auto it = stream_map_.find(stream_id);
@@ -1181,5 +1183,18 @@
   }
 }
 
+void OgHttp2Session::CloseStreamIfReady(uint8_t frame_type,
+                                        uint32_t stream_id) {
+  auto iter = stream_map_.find(stream_id);
+  if (iter == stream_map_.end()) {
+    return;
+  }
+  const StreamState& state = iter->second;
+  if (static_cast<FrameType>(frame_type) == FrameType::RST_STREAM ||
+      (state.half_closed_local && state.half_closed_remote)) {
+    CloseStream(stream_id, Http2ErrorCode::NO_ERROR);
+  }
+}
+
 }  // namespace adapter
 }  // namespace http2
diff --git a/http2/adapter/oghttp2_session.h b/http2/adapter/oghttp2_session.h
index 88cdb7d..a3a0232 100644
--- a/http2/adapter/oghttp2_session.h
+++ b/http2/adapter/oghttp2_session.h
@@ -261,6 +261,10 @@
   // Serializes and sends queued frames.
   SendResult SendQueuedFrames();
 
+  void AfterFrameSent(uint8_t frame_type, uint32_t stream_id,
+                      size_t payload_length, uint8_t flags,
+                      uint32_t error_code);
+
   // Writes DATA frames for stream `stream_id`.
   SendResult WriteForStream(Http2StreamId stream_id);
 
@@ -271,10 +275,6 @@
 
   void SendTrailers(Http2StreamId stream_id, spdy::SpdyHeaderBlock trailers);
 
-  // Encapsulates the RST_STREAM NO_ERROR behavior described in RFC 7540
-  // Section 8.1. Returns true if the stream is closed.
-  bool MaybeCloseWithRstStream(Http2StreamId stream_id, StreamState& state);
-
   // Performs flow control accounting for data sent by the peer.
   void MarkDataBuffered(Http2StreamId stream_id, size_t bytes);
 
@@ -302,6 +302,8 @@
   void LatchErrorAndNotify(Http2ErrorCode error_code,
                            Http2VisitorInterface::ConnectionError error);
 
+  void CloseStreamIfReady(uint8_t frame_type, uint32_t stream_id);
+
   // Receives events when inbound frames are parsed.
   Http2VisitorInterface& visitor_;
 
@@ -347,6 +349,7 @@
   WindowManager connection_window_manager_;
 
   absl::flat_hash_set<Http2StreamId> streams_reset_;
+  absl::flat_hash_map<Http2StreamId, int> queued_frames_;
 
   MetadataSequence connection_metadata_;
 
diff --git a/http2/adapter/oghttp2_session_test.cc b/http2/adapter/oghttp2_session_test.cc
index b4b464b..1f604ee 100644
--- a/http2/adapter/oghttp2_session_test.cc
+++ b/http2/adapter/oghttp2_session_test.cc
@@ -886,8 +886,6 @@
   EXPECT_FALSE(session.want_write());
 
   // The body source has been exhausted by the call to Send() above.
-  // TODO(birenroy): Fix this strange ordering.
-  EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::NO_ERROR));
   int trailer_result = session.SubmitTrailer(
       1, ToHeaders({{"final-status", "a-ok"},
                     {"x-comment", "trailers sure are cool"}}));
@@ -896,6 +894,7 @@
 
   EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5));
   EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0));
+  EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::NO_ERROR));
 
   send_result = session.Send();
   EXPECT_EQ(0, send_result);
@@ -977,11 +976,9 @@
   EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0));
   EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0));
 
-  // TODO(birenroy): Fix this strange ordering.
-  EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::NO_ERROR));
-
   EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5));
   EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0));
+  EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::NO_ERROR));
 
   send_result = session.Send();
   EXPECT_EQ(0, send_result);