Factor out QuicTransportStream logic so that it can be reused in WebTransport over HTTP/3.
PiperOrigin-RevId: 363545740
Change-Id: I97776a85eac70c5058f6efefbae635df4b5457d3
diff --git a/quic/core/quic_stream.h b/quic/core/quic_stream.h
index d89e460..822ec18 100644
--- a/quic/core/quic_stream.h
+++ b/quic/core/quic_stream.h
@@ -376,6 +376,9 @@
bool fin_buffered() const { return fin_buffered_; }
+ // True if buffered data in send buffer is below buffered_data_threshold_.
+ bool CanWriteNewData() const;
+
protected:
// Called when data of [offset, offset + data_length] is buffered in send
// buffer.
@@ -391,9 +394,6 @@
// a RST_STREAM has been sent.
virtual void OnClose();
- // True if buffered data in send buffer is below buffered_data_threshold_.
- bool CanWriteNewData() const;
-
// True if buffered data in send buffer is still below
// buffered_data_threshold_ even after writing |length| bytes.
bool CanWriteNewDataAfterData(QuicByteCount length) const;
diff --git a/quic/core/web_transport_interface.h b/quic/core/web_transport_interface.h
index a458e65..75c66c5 100644
--- a/quic/core/web_transport_interface.h
+++ b/quic/core/web_transport_interface.h
@@ -9,6 +9,7 @@
#define QUICHE_QUIC_CORE_WEB_TRANSPORT_INTERFACE_H_
#include <cstddef>
+#include <memory>
#include "absl/base/attributes.h"
#include "absl/strings/string_view.h"
@@ -52,6 +53,9 @@
virtual bool CanWrite() const = 0;
// Indicates the number of bytes that can be read from the stream.
virtual size_t ReadableBytes() const = 0;
+
+ virtual void SetVisitor(
+ std::unique_ptr<WebTransportStreamVisitor> visitor) = 0;
};
// Visitor that gets notified about events related to a WebTransport session.
diff --git a/quic/core/web_transport_stream_adapter.cc b/quic/core/web_transport_stream_adapter.cc
new file mode 100644
index 0000000..7a5fee9
--- /dev/null
+++ b/quic/core/web_transport_stream_adapter.cc
@@ -0,0 +1,125 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "quic/core/web_transport_stream_adapter.h"
+
+namespace quic {
+
+WebTransportStreamAdapter::WebTransportStreamAdapter(
+ QuicSession* session,
+ QuicStream* stream,
+ QuicStreamSequencer* sequencer)
+ : session_(session), stream_(stream), sequencer_(sequencer) {}
+
+size_t WebTransportStreamAdapter::Read(char* buffer, size_t buffer_size) {
+ iovec iov;
+ iov.iov_base = buffer;
+ iov.iov_len = buffer_size;
+ const size_t result = sequencer_->Readv(&iov, 1);
+ if (sequencer_->IsClosed()) {
+ MaybeNotifyFinRead();
+ }
+ return result;
+}
+
+size_t WebTransportStreamAdapter::Read(std::string* output) {
+ const size_t old_size = output->size();
+ const size_t bytes_to_read = ReadableBytes();
+ output->resize(old_size + bytes_to_read);
+ size_t bytes_read = Read(&(*output)[old_size], bytes_to_read);
+ QUICHE_DCHECK_EQ(bytes_to_read, bytes_read);
+ output->resize(old_size + bytes_read);
+ return bytes_read;
+}
+
+bool WebTransportStreamAdapter::Write(absl::string_view data) {
+ if (!CanWrite()) {
+ return false;
+ }
+
+ QuicUniqueBufferPtr buffer = MakeUniqueBuffer(
+ session_->connection()->helper()->GetStreamSendBufferAllocator(),
+ data.size());
+ memcpy(buffer.get(), data.data(), data.size());
+ QuicMemSlice memslice(std::move(buffer), data.size());
+ QuicConsumedData consumed =
+ stream_->WriteMemSlices(QuicMemSliceSpan(&memslice), /*fin=*/false);
+
+ if (consumed.bytes_consumed == data.size()) {
+ return true;
+ }
+ if (consumed.bytes_consumed == 0) {
+ return false;
+ }
+ // WebTransportStream::Write() is an all-or-nothing write API. To achieve
+ // that property, it relies on WriteMemSlices() being an all-or-nothing API.
+ // If WriteMemSlices() fails to provide that guarantee, we have no way to
+ // communicate a partial write to the caller, and thus it's safer to just
+ // close the connection.
+ QUIC_BUG(WebTransportStreamAdapter partial write)
+ << "WriteMemSlices() unexpectedly partially consumed the input "
+ "data, provided: "
+ << data.size() << ", written: " << consumed.bytes_consumed;
+ stream_->OnUnrecoverableError(
+ QUIC_INTERNAL_ERROR,
+ "WriteMemSlices() unexpectedly partially consumed the input data");
+ return false;
+}
+
+bool WebTransportStreamAdapter::SendFin() {
+ if (!CanWrite()) {
+ return false;
+ }
+
+ QuicMemSlice empty;
+ QuicConsumedData consumed =
+ stream_->WriteMemSlices(QuicMemSliceSpan(&empty), /*fin=*/true);
+ QUICHE_DCHECK_EQ(consumed.bytes_consumed, 0u);
+ return consumed.fin_consumed;
+}
+
+bool WebTransportStreamAdapter::CanWrite() const {
+ return stream_->CanWriteNewData() && !stream_->write_side_closed();
+}
+
+size_t WebTransportStreamAdapter::ReadableBytes() const {
+ return sequencer_->ReadableBytes();
+}
+
+void WebTransportStreamAdapter::OnDataAvailable() {
+ if (sequencer_->IsClosed()) {
+ MaybeNotifyFinRead();
+ return;
+ }
+
+ if (visitor_ == nullptr) {
+ return;
+ }
+ if (ReadableBytes() == 0) {
+ return;
+ }
+ visitor_->OnCanRead();
+}
+
+void WebTransportStreamAdapter::OnCanWriteNewData() {
+ // Ensure the origin check has been completed, as the stream can be notified
+ // about being writable before that.
+ if (!CanWrite()) {
+ return;
+ }
+ if (visitor_ != nullptr) {
+ visitor_->OnCanWrite();
+ }
+}
+
+void WebTransportStreamAdapter::MaybeNotifyFinRead() {
+ if (visitor_ == nullptr || fin_read_notified_) {
+ return;
+ }
+ fin_read_notified_ = true;
+ visitor_->OnFinRead();
+ stream_->OnFinRead();
+}
+
+} // namespace quic
diff --git a/quic/core/web_transport_stream_adapter.h b/quic/core/web_transport_stream_adapter.h
new file mode 100644
index 0000000..76a445c
--- /dev/null
+++ b/quic/core/web_transport_stream_adapter.h
@@ -0,0 +1,53 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef QUICHE_QUIC_CORE_WEB_TRANSPORT_STREAM_ADAPTER_H_
+#define QUICHE_QUIC_CORE_WEB_TRANSPORT_STREAM_ADAPTER_H_
+
+#include "quic/core/quic_session.h"
+#include "quic/core/quic_stream.h"
+#include "quic/core/quic_stream_sequencer.h"
+#include "quic/core/web_transport_interface.h"
+
+namespace quic {
+
+// Converts WebTransportStream API calls into QuicStream API calls. The users
+// of this class can either subclass it, or wrap around it.
+class QUIC_EXPORT_PRIVATE WebTransportStreamAdapter
+ : public WebTransportStream {
+ public:
+ WebTransportStreamAdapter(QuicSession* session,
+ QuicStream* stream,
+ QuicStreamSequencer* sequencer);
+
+ // WebTransportStream implementation.
+ size_t Read(char* buffer, size_t buffer_size) override;
+ size_t Read(std::string* output) override;
+ ABSL_MUST_USE_RESULT bool Write(absl::string_view data) override;
+ ABSL_MUST_USE_RESULT bool SendFin() override;
+ bool CanWrite() const override;
+ size_t ReadableBytes() const override;
+ void SetVisitor(std::unique_ptr<WebTransportStreamVisitor> visitor) override {
+ visitor_ = std::move(visitor);
+ }
+
+ WebTransportStreamVisitor* visitor() { return visitor_.get(); }
+
+ // Calls that need to be passed from the corresponding QuicStream methods.
+ void OnDataAvailable();
+ void OnCanWriteNewData();
+
+ private:
+ void MaybeNotifyFinRead();
+
+ QuicSession* session_; // Unowned.
+ QuicStream* stream_; // Unowned.
+ QuicStreamSequencer* sequencer_; // Unowned.
+ std::unique_ptr<WebTransportStreamVisitor> visitor_;
+ bool fin_read_notified_ = false;
+};
+
+} // namespace quic
+
+#endif // QUICHE_QUIC_CORE_WEB_TRANSPORT_STREAM_ADAPTER_H_
diff --git a/quic/quic_transport/quic_transport_integration_test.cc b/quic/quic_transport/quic_transport_integration_test.cc
index 4e237c7..99bf41d 100644
--- a/quic/quic_transport/quic_transport_integration_test.cc
+++ b/quic/quic_transport/quic_transport_integration_test.cc
@@ -325,7 +325,7 @@
client_->session()->AcceptIncomingUnidirectionalStream();
ASSERT_TRUE(reply != nullptr);
std::string buffer;
- reply->set_visitor(VisitorExpectingFin());
+ reply->SetVisitor(VisitorExpectingFin());
EXPECT_GT(reply->Read(&buffer), 0u);
EXPECT_EQ(buffer, "Stream Two");
@@ -339,7 +339,7 @@
[&stream_received]() { return stream_received; }, kDefaultTimeout));
reply = client_->session()->AcceptIncomingUnidirectionalStream();
ASSERT_TRUE(reply != nullptr);
- reply->set_visitor(VisitorExpectingFin());
+ reply->SetVisitor(VisitorExpectingFin());
EXPECT_GT(reply->Read(&buffer), 0u);
EXPECT_EQ(buffer, "Stream One");
}
diff --git a/quic/quic_transport/quic_transport_stream.cc b/quic/quic_transport/quic_transport_stream.cc
index 5e66616..738a535 100644
--- a/quic/quic_transport/quic_transport_stream.cc
+++ b/quic/quic_transport/quic_transport_stream.cc
@@ -25,6 +25,7 @@
session->connection()->perspective(),
session->IsIncomingStream(id),
session->version())),
+ adapter_(session, this, sequencer()),
session_interface_(session_interface) {}
size_t QuicTransportStream::Read(char* buffer, size_t buffer_size) {
@@ -32,24 +33,15 @@
return 0;
}
- iovec iov;
- iov.iov_base = buffer;
- iov.iov_len = buffer_size;
- const size_t result = sequencer()->Readv(&iov, 1);
- if (sequencer()->IsClosed()) {
- MaybeNotifyFinRead();
- }
- return result;
+ return adapter_.Read(buffer, buffer_size);
}
size_t QuicTransportStream::Read(std::string* output) {
- const size_t old_size = output->size();
- const size_t bytes_to_read = ReadableBytes();
- output->resize(old_size + bytes_to_read);
- size_t bytes_read = Read(&(*output)[old_size], bytes_to_read);
- QUICHE_DCHECK_EQ(bytes_to_read, bytes_read);
- output->resize(old_size + bytes_read);
- return bytes_read;
+ if (!session_interface_->IsSessionReady()) {
+ return 0;
+ }
+
+ return adapter_.Read(output);
}
bool QuicTransportStream::Write(absl::string_view data) {
@@ -57,33 +49,7 @@
return false;
}
- QuicUniqueBufferPtr buffer = MakeUniqueBuffer(
- session()->connection()->helper()->GetStreamSendBufferAllocator(),
- data.size());
- memcpy(buffer.get(), data.data(), data.size());
- QuicMemSlice memslice(std::move(buffer), data.size());
- QuicConsumedData consumed =
- WriteMemSlices(QuicMemSliceSpan(&memslice), /*fin=*/false);
-
- if (consumed.bytes_consumed == data.size()) {
- return true;
- }
- if (consumed.bytes_consumed == 0) {
- return false;
- }
- // QuicTransportStream::Write() is an all-or-nothing write API. To achieve
- // that property, it relies on WriteMemSlices() being an all-or-nothing API.
- // If WriteMemSlices() fails to provide that guarantee, we have no way to
- // communicate a partial write to the caller, and thus it's safer to just
- // close the connection.
- QUIC_BUG(quic_bug_10893_1)
- << "WriteMemSlices() unexpectedly partially consumed the input "
- "data, provided: "
- << data.size() << ", written: " << consumed.bytes_consumed;
- OnUnrecoverableError(
- QUIC_INTERNAL_ERROR,
- "WriteMemSlices() unexpectedly partially consumed the input data");
- return false;
+ return adapter_.Write(data);
}
bool QuicTransportStream::SendFin() {
@@ -91,16 +57,11 @@
return false;
}
- QuicMemSlice empty;
- QuicConsumedData consumed =
- WriteMemSlices(QuicMemSliceSpan(&empty), /*fin=*/true);
- QUICHE_DCHECK_EQ(consumed.bytes_consumed, 0u);
- return consumed.fin_consumed;
+ return adapter_.SendFin();
}
bool QuicTransportStream::CanWrite() const {
- return session_interface_->IsSessionReady() && CanWriteNewData() &&
- !write_side_closed();
+ return session_interface_->IsSessionReady() && adapter_.CanWrite();
}
size_t QuicTransportStream::ReadableBytes() const {
@@ -108,22 +69,11 @@
return 0;
}
- return sequencer()->ReadableBytes();
+ return adapter_.ReadableBytes();
}
void QuicTransportStream::OnDataAvailable() {
- if (sequencer()->IsClosed()) {
- MaybeNotifyFinRead();
- return;
- }
-
- if (visitor_ == nullptr) {
- return;
- }
- if (ReadableBytes() == 0) {
- return;
- }
- visitor_->OnCanRead();
+ adapter_.OnDataAvailable();
}
void QuicTransportStream::OnCanWriteNewData() {
@@ -132,18 +82,7 @@
if (!CanWrite()) {
return;
}
- if (visitor_ != nullptr) {
- visitor_->OnCanWrite();
- }
-}
-
-void QuicTransportStream::MaybeNotifyFinRead() {
- if (visitor_ == nullptr || fin_read_notified_) {
- return;
- }
- fin_read_notified_ = true;
- visitor_->OnFinRead();
- OnFinRead();
+ adapter_.OnCanWriteNewData();
}
} // namespace quic
diff --git a/quic/quic_transport/quic_transport_stream.h b/quic/quic_transport/quic_transport_stream.h
index 776b07f..d9c9a73 100644
--- a/quic/quic_transport/quic_transport_stream.h
+++ b/quic/quic_transport/quic_transport_stream.h
@@ -14,6 +14,7 @@
#include "quic/core/quic_stream.h"
#include "quic/core/quic_types.h"
#include "quic/core/web_transport_interface.h"
+#include "quic/core/web_transport_stream_adapter.h"
#include "quic/quic_transport/quic_transport_session_interface.h"
namespace quic {
@@ -47,9 +48,9 @@
void OnDataAvailable() override;
void OnCanWriteNewData() override;
- WebTransportStreamVisitor* visitor() { return visitor_.get(); }
- void set_visitor(std::unique_ptr<WebTransportStreamVisitor> visitor) {
- visitor_ = std::move(visitor);
+ WebTransportStreamVisitor* visitor() { return adapter_.visitor(); }
+ void SetVisitor(std::unique_ptr<WebTransportStreamVisitor> visitor) override {
+ adapter_.SetVisitor(std::move(visitor));
}
protected:
@@ -59,9 +60,8 @@
void MaybeNotifyFinRead();
+ WebTransportStreamAdapter adapter_;
QuicTransportSessionInterface* session_interface_;
- std::unique_ptr<WebTransportStreamVisitor> visitor_ = nullptr;
- bool fin_read_notified_ = false;
};
} // namespace quic
diff --git a/quic/quic_transport/quic_transport_stream_test.cc b/quic/quic_transport/quic_transport_stream_test.cc
index 75446eb..364780c 100644
--- a/quic/quic_transport/quic_transport_stream_test.cc
+++ b/quic/quic_transport/quic_transport_stream_test.cc
@@ -51,7 +51,7 @@
auto visitor = std::make_unique<MockStreamVisitor>();
visitor_ = visitor.get();
- stream_->set_visitor(std::move(visitor));
+ stream_->SetVisitor(std::move(visitor));
}
void ReceiveStreamData(absl::string_view data, QuicStreamOffset offset) {
diff --git a/quic/tools/quic_transport_simple_server_session.cc b/quic/tools/quic_transport_simple_server_session.cc
index 9acf502..de9fb5c 100644
--- a/quic/tools/quic_transport_simple_server_session.cc
+++ b/quic/tools/quic_transport_simple_server_session.cc
@@ -154,21 +154,21 @@
QuicTransportStream* stream) {
switch (mode_) {
case DISCARD:
- stream->set_visitor(std::make_unique<DiscardVisitor>(stream));
+ stream->SetVisitor(std::make_unique<DiscardVisitor>(stream));
break;
case ECHO:
switch (stream->type()) {
case BIDIRECTIONAL:
QUIC_DVLOG(1) << "Opening bidirectional echo stream " << stream->id();
- stream->set_visitor(
+ stream->SetVisitor(
std::make_unique<BidirectionalEchoVisitor>(stream));
break;
case READ_UNIDIRECTIONAL:
QUIC_DVLOG(1)
<< "Started receiving data on unidirectional echo stream "
<< stream->id();
- stream->set_visitor(
+ stream->SetVisitor(
std::make_unique<UnidirectionalEchoReadVisitor>(this, stream));
break;
default:
@@ -178,7 +178,7 @@
break;
case OUTGOING_BIDIRECTIONAL:
- stream->set_visitor(std::make_unique<DiscardVisitor>(stream));
+ stream->SetVisitor(std::make_unique<DiscardVisitor>(stream));
++pending_outgoing_bidirectional_streams_;
MaybeCreateOutgoingBidirectionalStream();
break;
@@ -252,7 +252,7 @@
ActivateStream(std::move(stream_owned));
QUIC_DVLOG(1) << "Opened echo response stream " << stream->id();
- stream->set_visitor(
+ stream->SetVisitor(
std::make_unique<UnidirectionalEchoWriteVisitor>(stream, data));
stream->visitor()->OnCanWrite();
}
@@ -267,7 +267,7 @@
QuicTransportStream* stream = stream_owned.get();
ActivateStream(std::move(stream_owned));
QUIC_DVLOG(1) << "Opened outgoing bidirectional stream " << stream->id();
- stream->set_visitor(std::make_unique<BidirectionalEchoVisitor>(stream));
+ stream->SetVisitor(std::make_unique<BidirectionalEchoVisitor>(stream));
if (!stream->Write("hello")) {
QUIC_DVLOG(1) << "Write failed.";
}