Factor out QuicTransportStream logic so that it can be reused in WebTransport over HTTP/3.

PiperOrigin-RevId: 363545740
Change-Id: I97776a85eac70c5058f6efefbae635df4b5457d3
diff --git a/quic/core/quic_stream.h b/quic/core/quic_stream.h
index d89e460..822ec18 100644
--- a/quic/core/quic_stream.h
+++ b/quic/core/quic_stream.h
@@ -376,6 +376,9 @@
 
   bool fin_buffered() const { return fin_buffered_; }
 
+  // True if buffered data in send buffer is below buffered_data_threshold_.
+  bool CanWriteNewData() const;
+
  protected:
   // Called when data of [offset, offset + data_length] is buffered in send
   // buffer.
@@ -391,9 +394,6 @@
   // a RST_STREAM has been sent.
   virtual void OnClose();
 
-  // True if buffered data in send buffer is below buffered_data_threshold_.
-  bool CanWriteNewData() const;
-
   // True if buffered data in send buffer is still below
   // buffered_data_threshold_ even after writing |length| bytes.
   bool CanWriteNewDataAfterData(QuicByteCount length) const;
diff --git a/quic/core/web_transport_interface.h b/quic/core/web_transport_interface.h
index a458e65..75c66c5 100644
--- a/quic/core/web_transport_interface.h
+++ b/quic/core/web_transport_interface.h
@@ -9,6 +9,7 @@
 #define QUICHE_QUIC_CORE_WEB_TRANSPORT_INTERFACE_H_
 
 #include <cstddef>
+#include <memory>
 
 #include "absl/base/attributes.h"
 #include "absl/strings/string_view.h"
@@ -52,6 +53,9 @@
   virtual bool CanWrite() const = 0;
   // Indicates the number of bytes that can be read from the stream.
   virtual size_t ReadableBytes() const = 0;
+
+  virtual void SetVisitor(
+      std::unique_ptr<WebTransportStreamVisitor> visitor) = 0;
 };
 
 // Visitor that gets notified about events related to a WebTransport session.
diff --git a/quic/core/web_transport_stream_adapter.cc b/quic/core/web_transport_stream_adapter.cc
new file mode 100644
index 0000000..7a5fee9
--- /dev/null
+++ b/quic/core/web_transport_stream_adapter.cc
@@ -0,0 +1,125 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "quic/core/web_transport_stream_adapter.h"
+
+namespace quic {
+
+WebTransportStreamAdapter::WebTransportStreamAdapter(
+    QuicSession* session,
+    QuicStream* stream,
+    QuicStreamSequencer* sequencer)
+    : session_(session), stream_(stream), sequencer_(sequencer) {}
+
+size_t WebTransportStreamAdapter::Read(char* buffer, size_t buffer_size) {
+  iovec iov;
+  iov.iov_base = buffer;
+  iov.iov_len = buffer_size;
+  const size_t result = sequencer_->Readv(&iov, 1);
+  if (sequencer_->IsClosed()) {
+    MaybeNotifyFinRead();
+  }
+  return result;
+}
+
+size_t WebTransportStreamAdapter::Read(std::string* output) {
+  const size_t old_size = output->size();
+  const size_t bytes_to_read = ReadableBytes();
+  output->resize(old_size + bytes_to_read);
+  size_t bytes_read = Read(&(*output)[old_size], bytes_to_read);
+  QUICHE_DCHECK_EQ(bytes_to_read, bytes_read);
+  output->resize(old_size + bytes_read);
+  return bytes_read;
+}
+
+bool WebTransportStreamAdapter::Write(absl::string_view data) {
+  if (!CanWrite()) {
+    return false;
+  }
+
+  QuicUniqueBufferPtr buffer = MakeUniqueBuffer(
+      session_->connection()->helper()->GetStreamSendBufferAllocator(),
+      data.size());
+  memcpy(buffer.get(), data.data(), data.size());
+  QuicMemSlice memslice(std::move(buffer), data.size());
+  QuicConsumedData consumed =
+      stream_->WriteMemSlices(QuicMemSliceSpan(&memslice), /*fin=*/false);
+
+  if (consumed.bytes_consumed == data.size()) {
+    return true;
+  }
+  if (consumed.bytes_consumed == 0) {
+    return false;
+  }
+  // WebTransportStream::Write() is an all-or-nothing write API.  To achieve
+  // that property, it relies on WriteMemSlices() being an all-or-nothing API.
+  // If WriteMemSlices() fails to provide that guarantee, we have no way to
+  // communicate a partial write to the caller, and thus it's safer to just
+  // close the connection.
+  QUIC_BUG(WebTransportStreamAdapter partial write)
+      << "WriteMemSlices() unexpectedly partially consumed the input "
+         "data, provided: "
+      << data.size() << ", written: " << consumed.bytes_consumed;
+  stream_->OnUnrecoverableError(
+      QUIC_INTERNAL_ERROR,
+      "WriteMemSlices() unexpectedly partially consumed the input data");
+  return false;
+}
+
+bool WebTransportStreamAdapter::SendFin() {
+  if (!CanWrite()) {
+    return false;
+  }
+
+  QuicMemSlice empty;
+  QuicConsumedData consumed =
+      stream_->WriteMemSlices(QuicMemSliceSpan(&empty), /*fin=*/true);
+  QUICHE_DCHECK_EQ(consumed.bytes_consumed, 0u);
+  return consumed.fin_consumed;
+}
+
+bool WebTransportStreamAdapter::CanWrite() const {
+  return stream_->CanWriteNewData() && !stream_->write_side_closed();
+}
+
+size_t WebTransportStreamAdapter::ReadableBytes() const {
+  return sequencer_->ReadableBytes();
+}
+
+void WebTransportStreamAdapter::OnDataAvailable() {
+  if (sequencer_->IsClosed()) {
+    MaybeNotifyFinRead();
+    return;
+  }
+
+  if (visitor_ == nullptr) {
+    return;
+  }
+  if (ReadableBytes() == 0) {
+    return;
+  }
+  visitor_->OnCanRead();
+}
+
+void WebTransportStreamAdapter::OnCanWriteNewData() {
+  // Ensure the origin check has been completed, as the stream can be notified
+  // about being writable before that.
+  if (!CanWrite()) {
+    return;
+  }
+  if (visitor_ != nullptr) {
+    visitor_->OnCanWrite();
+  }
+}
+
+void WebTransportStreamAdapter::MaybeNotifyFinRead() {
+  if (visitor_ == nullptr || fin_read_notified_) {
+    return;
+  }
+  fin_read_notified_ = true;
+  visitor_->OnFinRead();
+  stream_->OnFinRead();
+}
+
+}  // namespace quic
diff --git a/quic/core/web_transport_stream_adapter.h b/quic/core/web_transport_stream_adapter.h
new file mode 100644
index 0000000..76a445c
--- /dev/null
+++ b/quic/core/web_transport_stream_adapter.h
@@ -0,0 +1,53 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef QUICHE_QUIC_CORE_WEB_TRANSPORT_STREAM_ADAPTER_H_
+#define QUICHE_QUIC_CORE_WEB_TRANSPORT_STREAM_ADAPTER_H_
+
+#include "quic/core/quic_session.h"
+#include "quic/core/quic_stream.h"
+#include "quic/core/quic_stream_sequencer.h"
+#include "quic/core/web_transport_interface.h"
+
+namespace quic {
+
+// Converts WebTransportStream API calls into QuicStream API calls.  The users
+// of this class can either subclass it, or wrap around it.
+class QUIC_EXPORT_PRIVATE WebTransportStreamAdapter
+    : public WebTransportStream {
+ public:
+  WebTransportStreamAdapter(QuicSession* session,
+                            QuicStream* stream,
+                            QuicStreamSequencer* sequencer);
+
+  // WebTransportStream implementation.
+  size_t Read(char* buffer, size_t buffer_size) override;
+  size_t Read(std::string* output) override;
+  ABSL_MUST_USE_RESULT bool Write(absl::string_view data) override;
+  ABSL_MUST_USE_RESULT bool SendFin() override;
+  bool CanWrite() const override;
+  size_t ReadableBytes() const override;
+  void SetVisitor(std::unique_ptr<WebTransportStreamVisitor> visitor) override {
+    visitor_ = std::move(visitor);
+  }
+
+  WebTransportStreamVisitor* visitor() { return visitor_.get(); }
+
+  // Calls that need to be passed from the corresponding QuicStream methods.
+  void OnDataAvailable();
+  void OnCanWriteNewData();
+
+ private:
+  void MaybeNotifyFinRead();
+
+  QuicSession* session_;            // Unowned.
+  QuicStream* stream_;              // Unowned.
+  QuicStreamSequencer* sequencer_;  // Unowned.
+  std::unique_ptr<WebTransportStreamVisitor> visitor_;
+  bool fin_read_notified_ = false;
+};
+
+}  // namespace quic
+
+#endif  // QUICHE_QUIC_CORE_WEB_TRANSPORT_STREAM_ADAPTER_H_
diff --git a/quic/quic_transport/quic_transport_integration_test.cc b/quic/quic_transport/quic_transport_integration_test.cc
index 4e237c7..99bf41d 100644
--- a/quic/quic_transport/quic_transport_integration_test.cc
+++ b/quic/quic_transport/quic_transport_integration_test.cc
@@ -325,7 +325,7 @@
       client_->session()->AcceptIncomingUnidirectionalStream();
   ASSERT_TRUE(reply != nullptr);
   std::string buffer;
-  reply->set_visitor(VisitorExpectingFin());
+  reply->SetVisitor(VisitorExpectingFin());
   EXPECT_GT(reply->Read(&buffer), 0u);
   EXPECT_EQ(buffer, "Stream Two");
 
@@ -339,7 +339,7 @@
       [&stream_received]() { return stream_received; }, kDefaultTimeout));
   reply = client_->session()->AcceptIncomingUnidirectionalStream();
   ASSERT_TRUE(reply != nullptr);
-  reply->set_visitor(VisitorExpectingFin());
+  reply->SetVisitor(VisitorExpectingFin());
   EXPECT_GT(reply->Read(&buffer), 0u);
   EXPECT_EQ(buffer, "Stream One");
 }
diff --git a/quic/quic_transport/quic_transport_stream.cc b/quic/quic_transport/quic_transport_stream.cc
index 5e66616..738a535 100644
--- a/quic/quic_transport/quic_transport_stream.cc
+++ b/quic/quic_transport/quic_transport_stream.cc
@@ -25,6 +25,7 @@
                                           session->connection()->perspective(),
                                           session->IsIncomingStream(id),
                                           session->version())),
+      adapter_(session, this, sequencer()),
       session_interface_(session_interface) {}
 
 size_t QuicTransportStream::Read(char* buffer, size_t buffer_size) {
@@ -32,24 +33,15 @@
     return 0;
   }
 
-  iovec iov;
-  iov.iov_base = buffer;
-  iov.iov_len = buffer_size;
-  const size_t result = sequencer()->Readv(&iov, 1);
-  if (sequencer()->IsClosed()) {
-    MaybeNotifyFinRead();
-  }
-  return result;
+  return adapter_.Read(buffer, buffer_size);
 }
 
 size_t QuicTransportStream::Read(std::string* output) {
-  const size_t old_size = output->size();
-  const size_t bytes_to_read = ReadableBytes();
-  output->resize(old_size + bytes_to_read);
-  size_t bytes_read = Read(&(*output)[old_size], bytes_to_read);
-  QUICHE_DCHECK_EQ(bytes_to_read, bytes_read);
-  output->resize(old_size + bytes_read);
-  return bytes_read;
+  if (!session_interface_->IsSessionReady()) {
+    return 0;
+  }
+
+  return adapter_.Read(output);
 }
 
 bool QuicTransportStream::Write(absl::string_view data) {
@@ -57,33 +49,7 @@
     return false;
   }
 
-  QuicUniqueBufferPtr buffer = MakeUniqueBuffer(
-      session()->connection()->helper()->GetStreamSendBufferAllocator(),
-      data.size());
-  memcpy(buffer.get(), data.data(), data.size());
-  QuicMemSlice memslice(std::move(buffer), data.size());
-  QuicConsumedData consumed =
-      WriteMemSlices(QuicMemSliceSpan(&memslice), /*fin=*/false);
-
-  if (consumed.bytes_consumed == data.size()) {
-    return true;
-  }
-  if (consumed.bytes_consumed == 0) {
-    return false;
-  }
-  // QuicTransportStream::Write() is an all-or-nothing write API.  To achieve
-  // that property, it relies on WriteMemSlices() being an all-or-nothing API.
-  // If WriteMemSlices() fails to provide that guarantee, we have no way to
-  // communicate a partial write to the caller, and thus it's safer to just
-  // close the connection.
-  QUIC_BUG(quic_bug_10893_1)
-      << "WriteMemSlices() unexpectedly partially consumed the input "
-         "data, provided: "
-      << data.size() << ", written: " << consumed.bytes_consumed;
-  OnUnrecoverableError(
-      QUIC_INTERNAL_ERROR,
-      "WriteMemSlices() unexpectedly partially consumed the input data");
-  return false;
+  return adapter_.Write(data);
 }
 
 bool QuicTransportStream::SendFin() {
@@ -91,16 +57,11 @@
     return false;
   }
 
-  QuicMemSlice empty;
-  QuicConsumedData consumed =
-      WriteMemSlices(QuicMemSliceSpan(&empty), /*fin=*/true);
-  QUICHE_DCHECK_EQ(consumed.bytes_consumed, 0u);
-  return consumed.fin_consumed;
+  return adapter_.SendFin();
 }
 
 bool QuicTransportStream::CanWrite() const {
-  return session_interface_->IsSessionReady() && CanWriteNewData() &&
-         !write_side_closed();
+  return session_interface_->IsSessionReady() && adapter_.CanWrite();
 }
 
 size_t QuicTransportStream::ReadableBytes() const {
@@ -108,22 +69,11 @@
     return 0;
   }
 
-  return sequencer()->ReadableBytes();
+  return adapter_.ReadableBytes();
 }
 
 void QuicTransportStream::OnDataAvailable() {
-  if (sequencer()->IsClosed()) {
-    MaybeNotifyFinRead();
-    return;
-  }
-
-  if (visitor_ == nullptr) {
-    return;
-  }
-  if (ReadableBytes() == 0) {
-    return;
-  }
-  visitor_->OnCanRead();
+  adapter_.OnDataAvailable();
 }
 
 void QuicTransportStream::OnCanWriteNewData() {
@@ -132,18 +82,7 @@
   if (!CanWrite()) {
     return;
   }
-  if (visitor_ != nullptr) {
-    visitor_->OnCanWrite();
-  }
-}
-
-void QuicTransportStream::MaybeNotifyFinRead() {
-  if (visitor_ == nullptr || fin_read_notified_) {
-    return;
-  }
-  fin_read_notified_ = true;
-  visitor_->OnFinRead();
-  OnFinRead();
+  adapter_.OnCanWriteNewData();
 }
 
 }  // namespace quic
diff --git a/quic/quic_transport/quic_transport_stream.h b/quic/quic_transport/quic_transport_stream.h
index 776b07f..d9c9a73 100644
--- a/quic/quic_transport/quic_transport_stream.h
+++ b/quic/quic_transport/quic_transport_stream.h
@@ -14,6 +14,7 @@
 #include "quic/core/quic_stream.h"
 #include "quic/core/quic_types.h"
 #include "quic/core/web_transport_interface.h"
+#include "quic/core/web_transport_stream_adapter.h"
 #include "quic/quic_transport/quic_transport_session_interface.h"
 
 namespace quic {
@@ -47,9 +48,9 @@
   void OnDataAvailable() override;
   void OnCanWriteNewData() override;
 
-  WebTransportStreamVisitor* visitor() { return visitor_.get(); }
-  void set_visitor(std::unique_ptr<WebTransportStreamVisitor> visitor) {
-    visitor_ = std::move(visitor);
+  WebTransportStreamVisitor* visitor() { return adapter_.visitor(); }
+  void SetVisitor(std::unique_ptr<WebTransportStreamVisitor> visitor) override {
+    adapter_.SetVisitor(std::move(visitor));
   }
 
  protected:
@@ -59,9 +60,8 @@
 
   void MaybeNotifyFinRead();
 
+  WebTransportStreamAdapter adapter_;
   QuicTransportSessionInterface* session_interface_;
-  std::unique_ptr<WebTransportStreamVisitor> visitor_ = nullptr;
-  bool fin_read_notified_ = false;
 };
 
 }  // namespace quic
diff --git a/quic/quic_transport/quic_transport_stream_test.cc b/quic/quic_transport/quic_transport_stream_test.cc
index 75446eb..364780c 100644
--- a/quic/quic_transport/quic_transport_stream_test.cc
+++ b/quic/quic_transport/quic_transport_stream_test.cc
@@ -51,7 +51,7 @@
 
     auto visitor = std::make_unique<MockStreamVisitor>();
     visitor_ = visitor.get();
-    stream_->set_visitor(std::move(visitor));
+    stream_->SetVisitor(std::move(visitor));
   }
 
   void ReceiveStreamData(absl::string_view data, QuicStreamOffset offset) {
diff --git a/quic/tools/quic_transport_simple_server_session.cc b/quic/tools/quic_transport_simple_server_session.cc
index 9acf502..de9fb5c 100644
--- a/quic/tools/quic_transport_simple_server_session.cc
+++ b/quic/tools/quic_transport_simple_server_session.cc
@@ -154,21 +154,21 @@
     QuicTransportStream* stream) {
   switch (mode_) {
     case DISCARD:
-      stream->set_visitor(std::make_unique<DiscardVisitor>(stream));
+      stream->SetVisitor(std::make_unique<DiscardVisitor>(stream));
       break;
 
     case ECHO:
       switch (stream->type()) {
         case BIDIRECTIONAL:
           QUIC_DVLOG(1) << "Opening bidirectional echo stream " << stream->id();
-          stream->set_visitor(
+          stream->SetVisitor(
               std::make_unique<BidirectionalEchoVisitor>(stream));
           break;
         case READ_UNIDIRECTIONAL:
           QUIC_DVLOG(1)
               << "Started receiving data on unidirectional echo stream "
               << stream->id();
-          stream->set_visitor(
+          stream->SetVisitor(
               std::make_unique<UnidirectionalEchoReadVisitor>(this, stream));
           break;
         default:
@@ -178,7 +178,7 @@
       break;
 
     case OUTGOING_BIDIRECTIONAL:
-      stream->set_visitor(std::make_unique<DiscardVisitor>(stream));
+      stream->SetVisitor(std::make_unique<DiscardVisitor>(stream));
       ++pending_outgoing_bidirectional_streams_;
       MaybeCreateOutgoingBidirectionalStream();
       break;
@@ -252,7 +252,7 @@
     ActivateStream(std::move(stream_owned));
     QUIC_DVLOG(1) << "Opened echo response stream " << stream->id();
 
-    stream->set_visitor(
+    stream->SetVisitor(
         std::make_unique<UnidirectionalEchoWriteVisitor>(stream, data));
     stream->visitor()->OnCanWrite();
   }
@@ -267,7 +267,7 @@
     QuicTransportStream* stream = stream_owned.get();
     ActivateStream(std::move(stream_owned));
     QUIC_DVLOG(1) << "Opened outgoing bidirectional stream " << stream->id();
-    stream->set_visitor(std::make_unique<BidirectionalEchoVisitor>(stream));
+    stream->SetVisitor(std::make_unique<BidirectionalEchoVisitor>(stream));
     if (!stream->Write("hello")) {
       QUIC_DVLOG(1) << "Write failed.";
     }