Refactor QuicSession to allow subclasses to do their own thing on pending streams.

gfe-relnote: code refactor and v99 only. Not used in production.
PiperOrigin-RevId: 248015584
Change-Id: Ie0ac855070b304bd90a8e68392cff3cf4c5ac52a
diff --git a/quic/core/http/quic_spdy_session.cc b/quic/core/http/quic_spdy_session.cc
index 163b890..1bcf9cc 100644
--- a/quic/core/http/quic_spdy_session.cc
+++ b/quic/core/http/quic_spdy_session.cc
@@ -558,9 +558,9 @@
   return GetNumOpenDynamicStreams() > 0;
 }
 
-bool QuicSpdySession::ShouldBufferIncomingStream(QuicStreamId id) const {
-  DCHECK_EQ(QUIC_VERSION_99, connection()->transport_version());
-  return !QuicUtils::IsBidirectionalStreamId(id);
+bool QuicSpdySession::UsesPendingStreams() const {
+  DCHECK(VersionHasControlStreams(connection()->transport_version()));
+  return true;
 }
 
 size_t QuicSpdySession::WriteHeadersOnHeadersStreamImpl(
@@ -724,7 +724,7 @@
   return false;
 }
 
-void QuicSpdySession::ProcessPendingStreamType(PendingStream* pending) {
+void QuicSpdySession::ProcessPendingStream(PendingStream* pending) {
   DCHECK(VersionHasControlStreams(connection()->transport_version()));
   struct iovec iov;
   if (!pending->sequencer()->GetReadableRegion(&iov)) {
diff --git a/quic/core/http/quic_spdy_session.h b/quic/core/http/quic_spdy_session.h
index 46f2b8a..65fa346 100644
--- a/quic/core/http/quic_spdy_session.h
+++ b/quic/core/http/quic_spdy_session.h
@@ -195,12 +195,12 @@
   // Returns true if there are open HTTP requests.
   bool ShouldKeepConnectionAlive() const override;
 
-  // Overridden to buffer incoming streams for version 99.
-  bool ShouldBufferIncomingStream(QuicStreamId id) const override;
+  // Overridden to buffer incoming unidirectional streams for version 99.
+  bool UsesPendingStreams() const override;
 
   // Overridden to Process HTTP/3 stream types. No action will be taken if
   // stream type cannot be read.
-  void ProcessPendingStreamType(PendingStream* pending) override;
+  void ProcessPendingStream(PendingStream* pending) override;
 
   size_t WriteHeadersOnHeadersStreamImpl(
       QuicStreamId id,
diff --git a/quic/core/http/quic_spdy_session_test.cc b/quic/core/http/quic_spdy_session_test.cc
index 1059a94..347a5f7 100644
--- a/quic/core/http/quic_spdy_session_test.cc
+++ b/quic/core/http/quic_spdy_session_test.cc
@@ -267,8 +267,8 @@
 
   using QuicSession::closed_streams;
   using QuicSession::zombie_streams;
-  using QuicSpdySession::ProcessPendingStreamType;
-  using QuicSpdySession::ShouldBufferIncomingStream;
+  using QuicSpdySession::ProcessPendingStream;
+  using QuicSpdySession::UsesPendingStreams;
 
  private:
   StrictMock<TestCryptoStream> crypto_stream_;
@@ -417,22 +417,11 @@
                          QuicSpdySessionTestServer,
                          ::testing::ValuesIn(AllSupportedVersions()));
 
-TEST_P(QuicSpdySessionTestServer, ShouldBufferIncomingStreamUnidirectional) {
-  if (!IsVersion99()) {
+TEST_P(QuicSpdySessionTestServer, UsesPendingStreams) {
+  if (!VersionHasControlStreams(transport_version())) {
     return;
   }
-  EXPECT_TRUE(session_.ShouldBufferIncomingStream(
-      QuicUtils::GetFirstUnidirectionalStreamId(
-          connection_->transport_version(), Perspective::IS_CLIENT)));
-}
-
-TEST_P(QuicSpdySessionTestServer, ShouldBufferIncomingStreamBidirectional) {
-  if (!IsVersion99()) {
-    return;
-  }
-  EXPECT_FALSE(session_.ShouldBufferIncomingStream(
-      QuicUtils::GetFirstBidirectionalStreamId(connection_->transport_version(),
-                                               Perspective::IS_CLIENT)));
+  EXPECT_TRUE(session_.UsesPendingStreams());
 }
 
 TEST_P(QuicSpdySessionTestServer, PeerAddress) {
@@ -1609,6 +1598,13 @@
                          QuicSpdySessionTestClient,
                          ::testing::ValuesIn(AllSupportedVersions()));
 
+TEST_P(QuicSpdySessionTestClient, UsesPendingStreams) {
+  if (!VersionHasControlStreams(transport_version())) {
+    return;
+  }
+  EXPECT_TRUE(session_.UsesPendingStreams());
+}
+
 TEST_P(QuicSpdySessionTestClient, AvailableStreamsClient) {
   ASSERT_TRUE(session_.GetOrCreateDynamicStream(
                   GetNthServerInitiatedBidirectionalId(2)) != nullptr);
@@ -1889,7 +1885,7 @@
 
   // A stop sending frame will be sent to indicate unknown type.
   EXPECT_CALL(*connection_, SendControlFrame(_));
-  session_.ProcessPendingStreamType(&pending);
+  session_.ProcessPendingStream(&pending);
 }
 
 TEST_P(QuicSpdySessionTestServer, SimplePendingStreamTypeOutOfOrderDelivery) {
@@ -1905,13 +1901,13 @@
                   'a', 'b', 'c'};
   QuicStreamFrame data1(pending.id(), true, 1, QuicStringPiece(&input[1], 3));
   pending.OnStreamFrame(data1);
-  session_.ProcessPendingStreamType(&pending);
+  session_.ProcessPendingStream(&pending);
 
   QuicStreamFrame data2(pending.id(), false, 0, QuicStringPiece(input, 1));
   pending.OnStreamFrame(data2);
 
   EXPECT_CALL(*connection_, SendControlFrame(_));
-  session_.ProcessPendingStreamType(&pending);
+  session_.ProcessPendingStream(&pending);
 }
 
 TEST_P(QuicSpdySessionTestServer,
@@ -1929,17 +1925,17 @@
 
   QuicStreamFrame data1(pending.id(), true, 2, QuicStringPiece(&input[2], 3));
   pending.OnStreamFrame(data1);
-  session_.ProcessPendingStreamType(&pending);
+  session_.ProcessPendingStream(&pending);
 
   QuicStreamFrame data2(pending.id(), false, 0, QuicStringPiece(input, 1));
   pending.OnStreamFrame(data2);
-  session_.ProcessPendingStreamType(&pending);
+  session_.ProcessPendingStream(&pending);
 
   QuicStreamFrame data3(pending.id(), false, 1, QuicStringPiece(&input[1], 1));
   pending.OnStreamFrame(data3);
 
   EXPECT_CALL(*connection_, SendControlFrame(_));
-  session_.ProcessPendingStreamType(&pending);
+  session_.ProcessPendingStream(&pending);
 }
 
 }  // namespace
diff --git a/quic/core/quic_session.cc b/quic/core/quic_session.cc
index 8bdf3b5..0545fa6 100644
--- a/quic/core/quic_session.cc
+++ b/quic/core/quic_session.cc
@@ -149,6 +149,24 @@
   }
 }
 
+void QuicSession::PendingStreamOnStreamFrame(const QuicStreamFrame& frame) {
+  DCHECK(VersionHasControlStreams(connection()->transport_version()));
+  QuicStreamId stream_id = frame.stream_id;
+
+  PendingStream* pending = GetOrCreatePendingStream(stream_id);
+
+  if (!pending) {
+    if (frame.fin) {
+      QuicStreamOffset final_byte_offset = frame.offset + frame.data_length;
+      OnFinalByteOffsetReceived(stream_id, final_byte_offset);
+    }
+    return;
+  }
+
+  pending->OnStreamFrame(frame);
+  ProcessPendingStream(pending);
+}
+
 void QuicSession::OnStreamFrame(const QuicStreamFrame& frame) {
   // TODO(rch) deal with the error case of stream id 0.
   QuicStreamId stream_id = frame.stream_id;
@@ -167,12 +185,18 @@
     return;
   }
 
-  StreamHandler handler = GetOrCreateStreamImpl(stream_id, frame.offset != 0);
-  if (handler.is_pending) {
-    handler.pending->OnStreamFrame(frame);
+  if (VersionHasControlStreams(connection()->transport_version()) &&
+      UsesPendingStreams() &&
+      QuicUtils::GetStreamType(stream_id, perspective(),
+                               IsIncomingStream(stream_id)) ==
+          READ_UNIDIRECTIONAL &&
+      dynamic_stream_map_.find(stream_id) == dynamic_stream_map_.end()) {
+    PendingStreamOnStreamFrame(frame);
     return;
   }
 
+  StreamHandler handler = GetOrCreateStreamImpl(stream_id);
+
   if (!handler.stream) {
     // The stream no longer exists, but we may still be interested in the
     // final stream byte offset sent by the peer. A frame with a FIN can give
@@ -296,6 +320,21 @@
   return true;
 }
 
+void QuicSession::PendingStreamOnRstStream(const QuicRstStreamFrame& frame) {
+  DCHECK(VersionHasControlStreams(connection()->transport_version()));
+  QuicStreamId stream_id = frame.stream_id;
+
+  PendingStream* pending = GetOrCreatePendingStream(stream_id);
+
+  if (!pending) {
+    HandleRstOnValidNonexistentStream(frame);
+    return;
+  }
+
+  pending->OnRstStreamFrame(frame);
+  ClosePendingStream(stream_id);
+}
+
 void QuicSession::OnRstStream(const QuicRstStreamFrame& frame) {
   QuicStreamId stream_id = frame.stream_id;
   if (stream_id ==
@@ -317,15 +356,18 @@
     visitor_->OnRstStreamReceived(frame);
   }
 
-  // may_buffer is true here to allow subclasses to buffer streams until the
-  // first byte of payload arrives which would allow sessions to delay
-  // creation of the stream until the type is known.
-  StreamHandler handler = GetOrCreateStreamImpl(stream_id, /*may_buffer=*/true);
-  if (handler.is_pending) {
-    handler.pending->OnRstStreamFrame(frame);
-    ClosePendingStream(stream_id);
+  if (VersionHasControlStreams(connection()->transport_version()) &&
+      UsesPendingStreams() &&
+      QuicUtils::GetStreamType(stream_id, perspective(),
+                               IsIncomingStream(stream_id)) ==
+          READ_UNIDIRECTIONAL &&
+      dynamic_stream_map_.find(stream_id) == dynamic_stream_map_.end()) {
+    PendingStreamOnRstStream(frame);
     return;
   }
+
+  StreamHandler handler = GetOrCreateStreamImpl(stream_id);
+
   if (!handler.stream) {
     HandleRstOnValidNonexistentStream(frame);
     return;  // Errors are handled by GetOrCreateStream.
@@ -1200,15 +1242,13 @@
 }
 
 QuicStream* QuicSession::GetOrCreateStream(const QuicStreamId stream_id) {
-  StreamHandler handler =
-      GetOrCreateStreamImpl(stream_id, /*may_buffer=*/false);
+  StreamHandler handler = GetOrCreateStreamImpl(stream_id);
   DCHECK(!handler.is_pending);
   return handler.stream;
 }
 
 QuicSession::StreamHandler QuicSession::GetOrCreateStreamImpl(
-    QuicStreamId stream_id,
-    bool may_buffer) {
+    QuicStreamId stream_id) {
   if (eliminate_static_stream_map_ &&
       stream_id ==
           QuicUtils::GetCryptoStreamId(connection_->transport_version())) {
@@ -1219,7 +1259,7 @@
   if (it != static_stream_map_.end()) {
     return StreamHandler(it->second);
   }
-  return GetOrCreateDynamicStreamImpl(stream_id, may_buffer);
+  return GetOrCreateDynamicStreamImpl(stream_id);
 }
 
 void QuicSession::StreamDraining(QuicStreamId stream_id) {
@@ -1254,17 +1294,32 @@
   return write_blocked_streams()->ShouldYield(stream_id);
 }
 
+PendingStream* QuicSession::GetOrCreatePendingStream(QuicStreamId stream_id) {
+  auto it = pending_stream_map_.find(stream_id);
+  if (it != pending_stream_map_.end()) {
+    return it->second.get();
+  }
+
+  if (IsClosedStream(stream_id) ||
+      !MaybeIncreaseLargestPeerStreamId(stream_id)) {
+    return nullptr;
+  }
+
+  auto pending = QuicMakeUnique<PendingStream>(stream_id, this);
+  PendingStream* unowned_pending = pending.get();
+  pending_stream_map_[stream_id] = std::move(pending);
+  return unowned_pending;
+}
+
 QuicStream* QuicSession::GetOrCreateDynamicStream(
     const QuicStreamId stream_id) {
-  StreamHandler handler =
-      GetOrCreateDynamicStreamImpl(stream_id, /*may_buffer=*/false);
+  StreamHandler handler = GetOrCreateDynamicStreamImpl(stream_id);
   DCHECK(!handler.is_pending);
   return handler.stream;
 }
 
 QuicSession::StreamHandler QuicSession::GetOrCreateDynamicStreamImpl(
-    QuicStreamId stream_id,
-    bool may_buffer) {
+    QuicStreamId stream_id) {
   DCHECK(!QuicContainsKey(static_stream_map_, stream_id))
       << "Attempt to call GetOrCreateDynamicStream for a static stream";
 
@@ -1282,21 +1337,6 @@
     return StreamHandler();
   }
 
-  auto pending_it = pending_stream_map_.find(stream_id);
-  if (pending_it != pending_stream_map_.end()) {
-    DCHECK_EQ(QUIC_VERSION_99, connection_->transport_version());
-    if (may_buffer) {
-      return StreamHandler(pending_it->second.get());
-    }
-    // The stream limit accounting has already been taken care of
-    // when the PendingStream was created, so there is no need to
-    // do so here. Now we can create the actual stream from the
-    // PendingStream.
-    StreamHandler handler(CreateIncomingStream(std::move(*pending_it->second)));
-    pending_stream_map_.erase(pending_it);
-    return handler;
-  }
-
   // TODO(fkastenholz): If we are creating a new stream and we have
   // sent a goaway, we should ignore the stream creation. Need to
   // add code to A) test if goaway was sent ("if (goaway_sent_)") and
@@ -1317,18 +1357,6 @@
     }
   }
 
-  if (connection_->transport_version() == QUIC_VERSION_99 && may_buffer &&
-      ShouldBufferIncomingStream(stream_id)) {
-    ++num_dynamic_incoming_streams_;
-    // Since STREAM frames may arrive out of order, delay creating the
-    // stream object until the first byte arrives. Buffer the frames and
-    // handle flow control accounting in the PendingStream.
-    auto pending = QuicMakeUnique<PendingStream>(stream_id, this);
-    StreamHandler handler(pending.get());
-    pending_stream_map_[stream_id] = std::move(pending);
-    return handler;
-  }
-
   return StreamHandler(CreateIncomingStream(stream_id));
 }
 
diff --git a/quic/core/quic_session.h b/quic/core/quic_session.h
index e0f2458..f16477c 100644
--- a/quic/core/quic_session.h
+++ b/quic/core/quic_session.h
@@ -483,11 +483,11 @@
   virtual void OnFinalByteOffsetReceived(QuicStreamId id,
                                          QuicStreamOffset final_byte_offset);
 
-  // Returns true if incoming streams should be buffered until the first
-  // byte of the stream arrives.
-  virtual bool ShouldBufferIncomingStream(QuicStreamId id) const {
-    return false;
-  }
+  // Returns true if incoming unidirectional streams should be buffered until
+  // the first byte of the stream arrives.
+  // If a subclass returns true here, it should make sure to implement
+  // ProcessPendingStream().
+  virtual bool UsesPendingStreams() const { return false; }
 
   // Register (|id|, |stream|) with the static stream map. Override previous
   // registrations with the same id.
@@ -573,11 +573,11 @@
     };
   };
 
-  StreamHandler GetOrCreateStreamImpl(QuicStreamId stream_id, bool may_buffer);
+  StreamHandler GetOrCreateStreamImpl(QuicStreamId stream_id);
 
   // Processes the stream type information of |pending| depending on
-  // different kinds of sessions's own rules.
-  virtual void ProcessPendingStreamType(PendingStream* pending) {}
+  // different kinds of sessions' own rules.
+  virtual void ProcessPendingStream(PendingStream* pending) {}
 
   bool eliminate_static_stream_map() const {
     return eliminate_static_stream_map_;
@@ -613,8 +613,9 @@
   // closed.
   QuicStream* GetStream(QuicStreamId id) const;
 
-  StreamHandler GetOrCreateDynamicStreamImpl(QuicStreamId stream_id,
-                                             bool may_buffer);
+  StreamHandler GetOrCreateDynamicStreamImpl(QuicStreamId stream_id);
+
+  PendingStream* GetOrCreatePendingStream(QuicStreamId stream_id);
 
   // Let streams and control frame managers retransmit lost data, returns true
   // if all lost data is retransmitted. Returns false otherwise.
@@ -623,6 +624,14 @@
   // Closes the pending stream |stream_id| before it has been created.
   void ClosePendingStream(QuicStreamId stream_id);
 
+  // Creates or gets pending stream, feeds it with |frame|, and processes the
+  // pending stream.
+  void PendingStreamOnStreamFrame(const QuicStreamFrame& frame);
+
+  // Creates or gets pending strea, feed it with |frame|, and closes the pending
+  // stream.
+  void PendingStreamOnRstStream(const QuicRstStreamFrame& frame);
+
   // Keep track of highest received byte offset of locally closed streams, while
   // waiting for a definitive final highest offset from the peer.
   std::map<QuicStreamId, QuicStreamOffset>
diff --git a/quic/core/quic_session_test.cc b/quic/core/quic_session_test.cc
index a441dc4..8f418c0 100644
--- a/quic/core/quic_session_test.cc
+++ b/quic/core/quic_session_test.cc
@@ -138,7 +138,7 @@
                     CurrentSupportedVersions()),
         crypto_stream_(this),
         writev_consumes_all_data_(false),
-        should_buffer_incoming_streams_(false),
+        uses_pending_streams_(false),
         num_incoming_streams_created_(0) {
     Initialize();
     this->connection()->SetEncrypter(
@@ -209,6 +209,17 @@
     return stream;
   }
 
+  // QuicSession doesn't do anything in this method. So it's overridden here to
+  // test that the session handles pending streams correctly in terms of
+  // receiving stream frames.
+  void ProcessPendingStream(PendingStream* pending) override {
+    struct iovec iov;
+    if (pending->sequencer()->GetReadableRegion(&iov)) {
+      // Create TestStream once the first byte is received.
+      CreateIncomingStream(std::move(*pending));
+    }
+  }
+
   bool IsClosedStream(QuicStreamId id) {
     return QuicSession::IsClosedStream(id);
   }
@@ -279,12 +290,10 @@
     return WritevData(stream, stream->id(), bytes, 0, FIN);
   }
 
-  bool ShouldBufferIncomingStream(QuicStreamId id) const override {
-    return should_buffer_incoming_streams_;
-  }
+  bool UsesPendingStreams() const override { return uses_pending_streams_; }
 
-  void set_should_buffer_incoming_streams(bool should_buffer_incoming_streams) {
-    should_buffer_incoming_streams_ = should_buffer_incoming_streams;
+  void set_uses_pending_streams(bool uses_pending_streams) {
+    uses_pending_streams_ = uses_pending_streams;
   }
 
   int num_incoming_streams_created() const {
@@ -299,7 +308,7 @@
   StrictMock<TestCryptoStream> crypto_stream_;
 
   bool writev_consumes_all_data_;
-  bool should_buffer_incoming_streams_;
+  bool uses_pending_streams_;
   QuicFrame save_frame_;
   int num_incoming_streams_created_;
 };
@@ -1575,7 +1584,7 @@
 }
 
 TEST_P(QuicSessionTestServer, NoPendingStreams) {
-  session_.set_should_buffer_incoming_streams(false);
+  session_.set_uses_pending_streams(false);
 
   QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId(
       transport_version(), Perspective::IS_CLIENT);
@@ -1592,7 +1601,7 @@
   if (connection_->transport_version() != QUIC_VERSION_99) {
     return;
   }
-  session_.set_should_buffer_incoming_streams(true);
+  session_.set_uses_pending_streams(true);
 
   QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId(
       transport_version(), Perspective::IS_CLIENT);
@@ -1609,7 +1618,7 @@
   if (connection_->transport_version() != QUIC_VERSION_99) {
     return;
   }
-  session_.set_should_buffer_incoming_streams(true);
+  session_.set_uses_pending_streams(true);
 
   QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId(
       transport_version(), Perspective::IS_CLIENT);