In WebTransport over HTTP/2, implement basic stream support. This code supports ~all stream logic that does not involve flow control. PiperOrigin-RevId: 599530029
diff --git a/quiche/common/capsule.cc b/quiche/common/capsule.cc index 421f11d..6cfcd4c 100644 --- a/quiche/common/capsule.cc +++ b/quiche/common/capsule.cc
@@ -20,6 +20,7 @@ #include "absl/types/span.h" #include "absl/types/variant.h" #include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_export.h" #include "quiche/common/platform/api/quiche_logging.h" #include "quiche/common/quiche_buffer_allocator.h" #include "quiche/common/quiche_data_reader.h" @@ -366,6 +367,21 @@ return *std::move(buffer); } +QUICHE_EXPORT QuicheBuffer SerializeWebTransportStreamCapsuleHeader( + webtransport::StreamId stream_id, bool fin, uint64_t write_size, + QuicheBufferAllocator* allocator) { + absl::StatusOr<QuicheBuffer> buffer = SerializeIntoBuffer( + allocator, + WireVarInt62(fin ? CapsuleType::WT_STREAM_WITH_FIN + : CapsuleType::WT_STREAM), + WireVarInt62(write_size + QuicheDataWriter::GetVarInt62Len(stream_id)), + WireVarInt62(stream_id)); + if (!buffer.ok()) { + return QuicheBuffer(); + } + return *std::move(buffer); +} + QuicheBuffer SerializeCapsule(const Capsule& capsule, quiche::QuicheBufferAllocator* allocator) { absl::StatusOr<QuicheBuffer> serialized =
diff --git a/quiche/common/capsule.h b/quiche/common/capsule.h index 3cadc4d..aa42530 100644 --- a/quiche/common/capsule.h +++ b/quiche/common/capsule.h
@@ -395,6 +395,11 @@ QUICHE_EXPORT QuicheBuffer SerializeDatagramCapsuleHeader( uint64_t datagram_size, QuicheBufferAllocator* allocator); +// Serializes the header for a WT_STREAM or a WT_STREAM_WITH_FIN capsule. +QUICHE_EXPORT QuicheBuffer SerializeWebTransportStreamCapsuleHeader( + webtransport::StreamId stream_id, bool fin, uint64_t write_size, + QuicheBufferAllocator* allocator); + } // namespace quiche #endif // QUICHE_COMMON_CAPSULE_H_
diff --git a/quiche/common/capsule_test.cc b/quiche/common/capsule_test.cc index 0aed714..fd07608 100644 --- a/quiche/common/capsule_test.cc +++ b/quiche/common/capsule_test.cc
@@ -297,6 +297,19 @@ ValidateParserIsEmpty(); TestSerialization(expected_capsule, capsule_fragment); } +TEST_F(CapsuleTest, WebTransportStreamDataHeader) { + std::string capsule_fragment = absl::HexStringToBytes( + "990b4d3b" // WT_STREAM without FIN + "04" // capsule length + "17" // stream ID + // three bytes of stream payload implied below + ); + QuicheBufferAllocator* allocator = SimpleBufferAllocator::Get(); + QuicheBuffer capsule_header = + quiche::SerializeWebTransportStreamCapsuleHeader(0x17, /*fin=*/false, 3, + allocator); + EXPECT_EQ(capsule_header.AsStringView(), capsule_fragment); +} TEST_F(CapsuleTest, WebTransportStreamDataWithFin) { std::string capsule_fragment = absl::HexStringToBytes( "990b4d3c" // data with FIN
diff --git a/quiche/common/quiche_stream.h b/quiche/common/quiche_stream.h index 9de876b..ca1ef7b 100644 --- a/quiche/common/quiche_stream.h +++ b/quiche/common/quiche_stream.h
@@ -191,6 +191,15 @@ return stream.Writev(absl::Span<const absl::string_view>(), options); } +inline size_t TotalStringViewSpanSize( + absl::Span<const absl::string_view> span) { + size_t total = 0; + for (absl::string_view view : span) { + total += view.size(); + } + return total; +} + } // namespace quiche #endif // QUICHE_COMMON_QUICHE_STREAM_H_
diff --git a/quiche/web_transport/encapsulated/encapsulated_web_transport.cc b/quiche/web_transport/encapsulated/encapsulated_web_transport.cc index 0aa3763..adbbb53 100644 --- a/quiche/web_transport/encapsulated/encapsulated_web_transport.cc +++ b/quiche/web_transport/encapsulated/encapsulated_web_transport.cc
@@ -4,22 +4,36 @@ #include "quiche/web_transport/encapsulated/encapsulated_web_transport.h" -#include <array> -#include <cstdint> -#include <memory> -#include <string> -#include <utility> +#include <stdbool.h> +#include <algorithm> +#include <array> +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <iterator> +#include <memory> +#include <optional> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/container/node_hash_map.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "quiche/common/capsule.h" #include "quiche/common/http/http_header_block.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" #include "quiche/common/platform/api/quiche_logging.h" #include "quiche/common/quiche_buffer_allocator.h" #include "quiche/common/quiche_callbacks.h" +#include "quiche/common/quiche_circular_deque.h" #include "quiche/common/quiche_status_utils.h" #include "quiche/common/quiche_stream.h" #include "quiche/web_transport/web_transport.h" @@ -36,13 +50,20 @@ // over TCP. constexpr uint64_t kEncapsulatedMaxDatagramSize = 9000; +constexpr StreamPriority kDefaultPriority = StreamPriority{0, 0}; + } // namespace EncapsulatedSession::EncapsulatedSession( Perspective perspective, FatalErrorCallback fatal_error_callback) : perspective_(perspective), fatal_error_callback_(std::move(fatal_error_callback)), - capsule_parser_(this) {} + capsule_parser_(this), + next_outgoing_bidi_stream_(perspective == Perspective::kClient ? 0 : 1), + next_outgoing_unidi_stream_(perspective == Perspective::kClient ? 2 : 3) { + QUICHE_DCHECK(IsIdOpenedBy(next_outgoing_bidi_stream_, perspective)); + QUICHE_DCHECK(IsIdOpenedBy(next_outgoing_unidi_stream_, perspective)); +} void EncapsulatedSession::InitializeClient( std::unique_ptr<SessionVisitor> visitor, @@ -115,26 +136,66 @@ OnCanWrite(); } -Stream* EncapsulatedSession::AcceptIncomingBidirectionalStream() { - return nullptr; -} -Stream* EncapsulatedSession::AcceptIncomingUnidirectionalStream() { - return nullptr; -} -bool EncapsulatedSession::CanOpenNextOutgoingBidirectionalStream() { - return false; -} -bool EncapsulatedSession::CanOpenNextOutgoingUnidirectionalStream() { - return false; -} -Stream* EncapsulatedSession::OpenOutgoingBidirectionalStream() { - return nullptr; -} -Stream* EncapsulatedSession::OpenOutgoingUnidirectionalStream() { +Stream* EncapsulatedSession::AcceptIncomingStream( + quiche::QuicheCircularDeque<StreamId>& queue) { + while (!queue.empty()) { + StreamId id = queue.front(); + queue.pop_front(); + Stream* stream = GetStreamById(id); + if (stream == nullptr) { + // Stream got reset and garbage collected before the peer ever had a + // chance to look at it. + continue; + } + return stream; + } return nullptr; } -Stream* EncapsulatedSession::GetStreamById(StreamId /*id*/) { return nullptr; } +Stream* EncapsulatedSession::AcceptIncomingBidirectionalStream() { + return AcceptIncomingStream(incoming_bidirectional_streams_); +} +Stream* EncapsulatedSession::AcceptIncomingUnidirectionalStream() { + return AcceptIncomingStream(incoming_unidirectional_streams_); +} +bool EncapsulatedSession::CanOpenNextOutgoingBidirectionalStream() { + // TODO: implement flow control. + return true; +} +bool EncapsulatedSession::CanOpenNextOutgoingUnidirectionalStream() { + // TODO: implement flow control. + return true; +} +Stream* EncapsulatedSession::OpenOutgoingStream(StreamId& counter) { + StreamId stream_id = counter; + counter += 4; + auto [it, inserted] = streams_.emplace( + std::piecewise_construct, std::forward_as_tuple(stream_id), + std::forward_as_tuple(this, stream_id)); + QUICHE_DCHECK(inserted); + return &it->second; +} +Stream* EncapsulatedSession::OpenOutgoingBidirectionalStream() { + if (!CanOpenNextOutgoingBidirectionalStream()) { + return nullptr; + } + return OpenOutgoingStream(next_outgoing_bidi_stream_); +} +Stream* EncapsulatedSession::OpenOutgoingUnidirectionalStream() { + if (!CanOpenNextOutgoingUnidirectionalStream()) { + return nullptr; + } + return OpenOutgoingStream(next_outgoing_unidi_stream_); +} + +Stream* EncapsulatedSession::GetStreamById(StreamId id) { + auto it = streams_.find(id); + if (it == streams_.end()) { + return nullptr; + } + return &it->second; +} + DatagramStats EncapsulatedSession::GetDatagramStats() { DatagramStats stats; stats.expired_outgoing = 0; @@ -149,8 +210,7 @@ } void EncapsulatedSession::NotifySessionDraining() { - control_capsule_queue_.push_back(quiche::SerializeCapsule( - quiche::Capsule(quiche::DrainWebTransportSessionCapsule()), allocator_)); + SendControlCapsule(quiche::DrainWebTransportSessionCapsule()); OnCanWrite(); } void EncapsulatedSession::SetOnDraining( @@ -256,7 +316,21 @@ control_capsule_queue_.pop_front(); } - // TODO(b/264263113): send stream data. + while (writer_->CanWrite()) { + absl::StatusOr<StreamId> next_id = scheduler_.PopFront(); + if (!next_id.ok()) { + QUICHE_DCHECK_EQ(next_id.status().code(), absl::StatusCode::kNotFound); + return; + } + auto it = streams_.find(*next_id); + if (it == streams_.end()) { + QUICHE_BUG(WT_H2_NextStreamNotInTheMap); + OnFatalError("Next scheduled stream is not in the map"); + return; + } + QUICHE_DCHECK(it->second.HasPendingWrite()); + it->second.FlushPendingWrite(); + } } void EncapsulatedSession::OnCanRead() { @@ -271,6 +345,9 @@ capsule_parser_.ErrorIfThereIsRemainingBufferedData(); OnSessionClosed(0, ""); } + if (state_ == kSessionOpen) { + GarbageCollectStreams(); + } } bool EncapsulatedSession::OnCapsule(const quiche::Capsule& capsule) { @@ -290,17 +367,114 @@ std::string( capsule.close_web_transport_session_capsule().error_message)); break; + case CapsuleType::WT_STREAM: + case CapsuleType::WT_STREAM_WITH_FIN: + ProcessStreamCapsule(capsule, + capsule.web_transport_stream_data().stream_id); + break; + case CapsuleType::WT_RESET_STREAM: + ProcessStreamCapsule(capsule, + capsule.web_transport_reset_stream().stream_id); + break; + case CapsuleType::WT_STOP_SENDING: + ProcessStreamCapsule(capsule, + capsule.web_transport_stop_sending().stream_id); + break; default: break; } - return true; + return state_ != kSessionClosed; } void EncapsulatedSession::OnCapsuleParseFailure( absl::string_view error_message) { + if (state_ == kSessionClosed) { + return; + } OnFatalError(absl::StrCat("Stream parse error: ", error_message)); } +void EncapsulatedSession::ProcessStreamCapsule(const quiche::Capsule& capsule, + StreamId stream_id) { + bool new_stream_created = false; + auto it = streams_.find(stream_id); + if (it == streams_.end()) { + if (IsOutgoing(stream_id)) { + // Ignore this frame, as it is possible that it refers to an outgoing + // stream that has been closed. + return; + } + // TODO: check flow control here. + it = streams_.emplace_hint(it, std::piecewise_construct, + std::forward_as_tuple(stream_id), + std::forward_as_tuple(this, stream_id)); + new_stream_created = true; + } + InnerStream& stream = it->second; + stream.ProcessCapsule(capsule); + if (new_stream_created) { + if (IsBidirectionalId(stream_id)) { + incoming_bidirectional_streams_.push_back(stream_id); + visitor_->OnIncomingBidirectionalStreamAvailable(); + } else { + incoming_unidirectional_streams_.push_back(stream_id); + visitor_->OnIncomingUnidirectionalStreamAvailable(); + } + } +} + +void EncapsulatedSession::InnerStream::ProcessCapsule( + const quiche::Capsule& capsule) { + switch (capsule.capsule_type()) { + case CapsuleType::WT_STREAM: + case CapsuleType::WT_STREAM_WITH_FIN: { + if (fin_received_) { + session_->OnFatalError( + "Received stream data for a stream that has already received a " + "FIN"); + return; + } + if (read_side_closed_) { + // It is possible that we sent STOP_SENDING but it has not been received + // yet. Ignore. + return; + } + fin_received_ = capsule.capsule_type() == CapsuleType::WT_STREAM_WITH_FIN; + const quiche::WebTransportStreamDataCapsule& data = + capsule.web_transport_stream_data(); + if (!data.data.empty()) { + incoming_reads_.push_back(IncomingRead{data.data, std::string()}); + } + // Fast path: if the visitor consumes all of the incoming reads, we don't + // need to copy data from the capsule parser. + if (visitor_ != nullptr) { + visitor_->OnCanRead(); + } + // Slow path: copy all data that the visitor have not consumed. + for (IncomingRead& read : incoming_reads_) { + QUICHE_DCHECK(!read.data.empty()); + if (read.storage.empty()) { + read.storage = std::string(read.data); + read.data = read.storage; + } + } + return; + } + case CapsuleType::WT_RESET_STREAM: + CloseReadSide(capsule.web_transport_reset_stream().error_code); + return; + case CapsuleType::WT_STOP_SENDING: + CloseWriteSide(capsule.web_transport_stop_sending().error_code); + return; + default: + QUICHE_BUG(WT_H2_ProcessStreamCapsule_Unknown) + << "Unexpected capsule dispatched to InnerStream: " << capsule; + session_->OnFatalError( + "Internal error: Unexpected capsule dispatched to InnerStream"); + return; + } +} + void EncapsulatedSession::OpenSession() { state_ = kSessionOpen; visitor_->OnSessionReady(); @@ -344,6 +518,7 @@ state_ = kSessionClosed; if (fatal_error_callback_) { std::move(fatal_error_callback_)(error_message); + fatal_error_callback_ = nullptr; } } @@ -352,4 +527,260 @@ error, " while trying to write encapsulated WebTransport data")); } +EncapsulatedSession::InnerStream::InnerStream(EncapsulatedSession* session, + StreamId id) + : session_(session), + id_(id), + read_side_closed_(IsUnidirectionalId(id) && + IsIdOpenedBy(id, session->perspective_)), + write_side_closed_(IsUnidirectionalId(id) && + !IsIdOpenedBy(id, session->perspective_)) { + if (!write_side_closed_) { + absl::Status status = session_->scheduler_.Register(id_, kDefaultPriority); + if (!status.ok()) { + QUICHE_BUG(WT_H2_FailedToRegisterNewStream) << status; + session_->OnFatalError( + "Failed to register new stream with the scheduler"); + return; + } + } +} + +quiche::ReadStream::ReadResult EncapsulatedSession::InnerStream::Read( + absl::Span<char> output) { + const size_t total_size = output.size(); + for (const IncomingRead& read : incoming_reads_) { + size_t size_to_read = std::min(read.size(), output.size()); + if (size_to_read == 0) { + break; + } + memcpy(output.data(), read.data.data(), size_to_read); + output = output.subspan(size_to_read); + } + bool fin_consumed = SkipBytes(total_size); + return ReadResult{total_size, fin_consumed}; +} +quiche::ReadStream::ReadResult EncapsulatedSession::InnerStream::Read( + std::string* output) { + const size_t total_size = ReadableBytes(); + const size_t initial_offset = output->size(); + output->resize(initial_offset + total_size); + return Read(absl::Span<char>(&((*output)[initial_offset]), total_size)); +} +size_t EncapsulatedSession::InnerStream::ReadableBytes() const { + size_t total_size = 0; + for (const IncomingRead& read : incoming_reads_) { + total_size += read.size(); + } + return total_size; +} +quiche::ReadStream::PeekResult +EncapsulatedSession::InnerStream::PeekNextReadableRegion() const { + if (incoming_reads_.empty()) { + return PeekResult{absl::string_view(), fin_received_, fin_received_}; + } + return PeekResult{incoming_reads_.front().data, + fin_received_ && incoming_reads_.size() == 1, + fin_received_}; +} + +bool EncapsulatedSession::InnerStream::SkipBytes(size_t bytes) { + size_t remaining = bytes; + while (remaining > 0) { + if (incoming_reads_.empty()) { + QUICHE_BUG(WT_H2_SkipBytes_toomuch) + << "Requested to skip " << remaining + << " bytes that are not present in the read buffer."; + return false; + } + IncomingRead& current = incoming_reads_.front(); + if (remaining < current.size()) { + current.data = current.data.substr(remaining); + return false; + } + remaining -= current.size(); + incoming_reads_.pop_front(); + } + if (incoming_reads_.empty() && fin_received_) { + fin_consumed_ = true; + CloseReadSide(std::nullopt); + return true; + } + return false; +} + +absl::Status EncapsulatedSession::InnerStream::Writev( + const absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + if (write_side_closed_) { + return absl::FailedPreconditionError( + "Trying to write into an already-closed stream"); + } + if (fin_buffered_) { + return absl::FailedPreconditionError("FIN already buffered"); + } + if (!CanWrite()) { + return absl::FailedPreconditionError( + "Trying to write into a stream when CanWrite() = false"); + } + + const absl::StatusOr<bool> should_yield = + session_->scheduler_.ShouldYield(id_); + if (!should_yield.ok()) { + QUICHE_BUG(WT_H2_Writev_NotRegistered) << should_yield.status(); + session_->OnFatalError("Stream not registered with the scheduler"); + return absl::InternalError("Stream not registered with the scheduler"); + } + const bool write_blocked = !session_->writer_->CanWrite() || *should_yield || + !pending_write_.empty(); + if (write_blocked) { + fin_buffered_ = options.send_fin(); + for (absl::string_view chunk : data) { + absl::StrAppend(&pending_write_, chunk); + } + absl::Status status = session_->scheduler_.Schedule(id_); + if (!status.ok()) { + QUICHE_BUG(WT_H2_Writev_CantSchedule) << status; + session_->OnFatalError("Could not schedule a write-blocked stream"); + return absl::InternalError("Could not schedule a write-blocked stream"); + } + return absl::OkStatus(); + } + + size_t bytes_written = WriteInner(data, options.send_fin()); + // TODO: handle partial writes when flow control requires those. + QUICHE_DCHECK(bytes_written == 0 || + bytes_written == quiche::TotalStringViewSpanSize(data)); + if (bytes_written == 0) { + for (absl::string_view chunk : data) { + absl::StrAppend(&pending_write_, chunk); + } + } + + if (options.send_fin()) { + CloseWriteSide(std::nullopt); + } + return absl::OkStatus(); +} + +bool EncapsulatedSession::InnerStream::CanWrite() const { + return session_->state_ != EncapsulatedSession::kSessionClosed && + !write_side_closed_ && + (pending_write_.size() <= session_->max_stream_data_buffered_); +} + +void EncapsulatedSession::InnerStream::FlushPendingWrite() { + QUICHE_DCHECK(!write_side_closed_); + QUICHE_DCHECK(session_->writer_->CanWrite()); + QUICHE_DCHECK(!pending_write_.empty()); + absl::string_view to_write = pending_write_; + size_t bytes_written = + WriteInner(absl::MakeSpan(&to_write, 1), fin_buffered_); + if (bytes_written < to_write.size()) { + pending_write_ = pending_write_.substr(bytes_written); + return; + } + pending_write_.clear(); + if (fin_buffered_) { + CloseWriteSide(std::nullopt); + } + if (!write_side_closed_ && visitor_ != nullptr) { + visitor_->OnCanWrite(); + } +} + +size_t EncapsulatedSession::InnerStream::WriteInner( + absl::Span<const absl::string_view> data, bool fin) { + size_t total_size = quiche::TotalStringViewSpanSize(data); + if (total_size == 0 && !fin) { + session_->OnFatalError("Attempted to make an empty write with fin=false"); + return 0; + } + quiche::QuicheBuffer header = + quiche::SerializeWebTransportStreamCapsuleHeader(id_, fin, total_size, + session_->allocator_); + std::vector<absl::string_view> views_to_write; + views_to_write.reserve(data.size() + 1); + views_to_write.push_back(header.AsStringView()); + absl::c_copy(data, std::back_inserter(views_to_write)); + absl::Status write_status = session_->writer_->Writev( + views_to_write, quiche::kDefaultStreamWriteOptions); + if (!write_status.ok()) { + session_->OnWriteError(write_status); + return 0; + } + return total_size; +} + +void EncapsulatedSession::InnerStream::AbruptlyTerminate(absl::Status error) { + QUICHE_DLOG(INFO) << "Abruptly terminating the stream due to error: " + << error; + ResetDueToInternalError(); +} + +void EncapsulatedSession::InnerStream::ResetWithUserCode( + StreamErrorCode error) { + if (reset_frame_sent_) { + return; + } + reset_frame_sent_ = true; + + session_->SendControlCapsule( + quiche::WebTransportResetStreamCapsule{id_, error}); + CloseWriteSide(std::nullopt); +} + +void EncapsulatedSession::InnerStream::SendStopSending(StreamErrorCode error) { + if (stop_sending_sent_) { + return; + } + stop_sending_sent_ = true; + + session_->SendControlCapsule( + quiche::WebTransportStopSendingCapsule{id_, error}); + CloseReadSide(std::nullopt); +} + +void EncapsulatedSession::InnerStream::CloseReadSide( + std::optional<StreamErrorCode> error) { + if (read_side_closed_) { + return; + } + read_side_closed_ = true; + incoming_reads_.clear(); + if (error.has_value() && visitor_ != nullptr) { + visitor_->OnResetStreamReceived(*error); + } + if (CanBeGarbageCollected()) { + session_->streams_to_garbage_collect_.push_back(id_); + } +} + +void EncapsulatedSession::InnerStream::CloseWriteSide( + std::optional<StreamErrorCode> error) { + if (write_side_closed_) { + return; + } + write_side_closed_ = true; + pending_write_.clear(); + absl::Status status = session_->scheduler_.Unregister(id_); + if (!status.ok()) { + session_->OnFatalError("Failed to unregister closed stream"); + return; + } + if (error.has_value() && visitor_ != nullptr) { + visitor_->OnStopSendingReceived(*error); + } + if (CanBeGarbageCollected()) { + session_->streams_to_garbage_collect_.push_back(id_); + } +} + +void EncapsulatedSession::GarbageCollectStreams() { + for (StreamId id : streams_to_garbage_collect_) { + streams_.erase(id); + } + streams_to_garbage_collect_.clear(); +} + } // namespace webtransport
diff --git a/quiche/web_transport/encapsulated/encapsulated_web_transport.h b/quiche/web_transport/encapsulated/encapsulated_web_transport.h index 85c14c9..3247267 100644 --- a/quiche/web_transport/encapsulated/encapsulated_web_transport.h +++ b/quiche/web_transport/encapsulated/encapsulated_web_transport.h
@@ -5,13 +5,20 @@ #ifndef QUICHE_WEB_TRANSPORT_ENCAPSULATED_ENCAPSULATED_WEB_TRANSPORT_H_ #define QUICHE_WEB_TRANSPORT_ENCAPSULATED_ENCAPSULATED_WEB_TRANSPORT_H_ +#include <cstddef> #include <cstdint> #include <memory> +#include <optional> #include <string> +#include <utility> +#include <vector> +#include "absl/base/attributes.h" +#include "absl/container/node_hash_map.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "absl/types/span.h" #include "quiche/common/capsule.h" #include "quiche/common/http/http_header_block.h" #include "quiche/common/platform/api/quiche_export.h" @@ -21,9 +28,18 @@ #include "quiche/common/quiche_stream.h" #include "quiche/common/simple_buffer_allocator.h" #include "quiche/web_transport/web_transport.h" +#include "quiche/web_transport/web_transport_priority_scheduler.h" namespace webtransport { +constexpr bool IsUnidirectionalId(StreamId id) { return id & 0b10; } +constexpr bool IsBidirectionalId(StreamId id) { + return !IsUnidirectionalId(id); +} +constexpr bool IsIdOpenedBy(StreamId id, Perspective perspective) { + return (id & 0b01) ^ (perspective == Perspective::kClient); +} + using FatalErrorCallback = quiche::SingleUseCallback<void(absl::string_view)>; // Implementation of the WebTransport over HTTP/2 protocol; works over any @@ -97,7 +113,103 @@ State state() const { return state_; } + // Cleans up the state for all of the streams that have been closed. QUIC + // uses timers to safely delete closed streams while minimizing the risk that + // something on stack holds an active pointer to them; WebTransport over + // HTTP/2 does not have any timers in it, making that approach inapplicable + // here. This class does automatically run garbage collection at the end of + // every OnCanRead() call (since it's a top-level entrypoint that is likely to + // come directly from I/O handler), but if the application does not happen to + // read data frequently, manual calls to this function may be requried. + void GarbageCollectStreams(); + private: + // If the amount of data buffered in the socket exceeds the amount specified + // here, CanWrite() will start returning false. + static constexpr size_t kDefaultMaxBufferedStreamData = 16 * 1024; + + class InnerStream : public Stream { + public: + InnerStream(EncapsulatedSession* session, StreamId id); + InnerStream(const InnerStream&) = delete; + InnerStream(InnerStream&&) = delete; + InnerStream& operator=(const InnerStream&) = delete; + InnerStream& operator=(InnerStream&&) = delete; + + // ReadStream implementation. + ABSL_MUST_USE_RESULT ReadResult Read(absl::Span<char> output) override; + ABSL_MUST_USE_RESULT ReadResult Read(std::string* output) override; + size_t ReadableBytes() const override; + PeekResult PeekNextReadableRegion() const override; + bool SkipBytes(size_t bytes) override; + + // WriteStream implementation. + absl::Status Writev(absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) override; + bool CanWrite() const override; + + // TerminableStream implementation. + void AbruptlyTerminate(absl::Status error) override; + + // Stream implementation. + StreamId GetStreamId() const override { return id_; } + StreamVisitor* visitor() override { return visitor_.get(); } + void SetVisitor(std::unique_ptr<StreamVisitor> visitor) override { + visitor_ = std::move(visitor); + } + + void ResetWithUserCode(StreamErrorCode error) override; + void SendStopSending(StreamErrorCode error) override; + + void ResetDueToInternalError() override { ResetWithUserCode(0); } + void MaybeResetDueToStreamObjectGone() override { ResetWithUserCode(0); } + + void CloseReadSide(std::optional<StreamErrorCode> error); + void CloseWriteSide(std::optional<StreamErrorCode> error); + bool CanBeGarbageCollected() const { + return read_side_closed_ && write_side_closed_; + } + + bool HasPendingWrite() const { return !pending_write_.empty(); } + void FlushPendingWrite(); + + void ProcessCapsule(const quiche::Capsule& capsule); + + private: + // Struct for storing data that can potentially either stored inside the + // object or inside some other object on the stack. Here is roughly how this + // works: + // 1. A read is enqueued with `data` pointing to a temporary buffer, and + // `storage` being empty. + // 2. Visitor::OnCanRead() is called, potentially causing the user to + // consume the data from the temporary buffer directly. + // 3. If user does not consume data immediately, it's copied to `storage` + // (and the pointer to `data` is updated) so that it can be read later. + struct IncomingRead { + absl::string_view data; + std::string storage; + + size_t size() const { return data.size(); } + }; + + // Tries to send `data`; may send less if limited by flow control. + [[nodiscard]] size_t WriteInner(absl::Span<const absl::string_view> data, + bool fin); + + EncapsulatedSession* session_; + StreamId id_; + std::unique_ptr<StreamVisitor> visitor_; + quiche::QuicheCircularDeque<IncomingRead> incoming_reads_; + std::string pending_write_; + bool read_side_closed_; + bool write_side_closed_; + bool reset_frame_sent_ = false; + bool stop_sending_sent_ = false; + bool fin_received_ = false; + bool fin_consumed_ = false; + bool fin_buffered_ = false; + }; + struct BufferedClose { SessionErrorCode error_code = 0; std::string error_message; @@ -114,6 +226,18 @@ quiche::SimpleBufferAllocator::Get(); quiche::CapsuleParser capsule_parser_; + size_t max_stream_data_buffered_ = kDefaultMaxBufferedStreamData; + + PriorityScheduler scheduler_; + absl::node_hash_map<StreamId, InnerStream> + streams_; // Streams unregister themselves with scheduler on deletion, + // and thus have to be above it. + quiche::QuicheCircularDeque<StreamId> incoming_bidirectional_streams_; + quiche::QuicheCircularDeque<StreamId> incoming_unidirectional_streams_; + std::vector<StreamId> streams_to_garbage_collect_; + StreamId next_outgoing_bidi_stream_; + StreamId next_outgoing_unidi_stream_; + bool session_close_notified_ = false; bool fin_sent_ = false; @@ -126,6 +250,20 @@ const std::string& error_message); void OnFatalError(absl::string_view error_message); void OnWriteError(absl::Status error); + + bool IsOutgoing(StreamId id) { return IsIdOpenedBy(id, perspective_); } + bool IsIncoming(StreamId id) { return !IsOutgoing(id); } + + template <typename CapsuleType> + void SendControlCapsule(CapsuleType capsule) { + control_capsule_queue_.push_back(quiche::SerializeCapsule( + quiche::Capsule(std::move(capsule)), allocator_)); + OnCanWrite(); + } + + Stream* AcceptIncomingStream(quiche::QuicheCircularDeque<StreamId>& queue); + Stream* OpenOutgoingStream(StreamId& counter); + void ProcessStreamCapsule(const quiche::Capsule& capsule, StreamId stream_id); }; } // namespace webtransport
diff --git a/quiche/web_transport/encapsulated/encapsulated_web_transport_test.cc b/quiche/web_transport/encapsulated/encapsulated_web_transport_test.cc index 15728e4..8c0c79f 100644 --- a/quiche/web_transport/encapsulated/encapsulated_web_transport_test.cc +++ b/quiche/web_transport/encapsulated/encapsulated_web_transport_test.cc
@@ -4,9 +4,11 @@ #include "quiche/web_transport/encapsulated/encapsulated_web_transport.h" +#include <array> #include <memory> #include <string> #include <utility> +#include <vector> #include "absl/status/status.h" #include "absl/strings/string_view.h" @@ -18,6 +20,7 @@ #include "quiche/common/quiche_stream.h" #include "quiche/common/simple_buffer_allocator.h" #include "quiche/common/test_tools/mock_streams.h" +#include "quiche/common/test_tools/quiche_test_utils.h" #include "quiche/web_transport/test_tools/mock_web_transport.h" #include "quiche/web_transport/web_transport.h" @@ -26,7 +29,9 @@ using ::quiche::Capsule; using ::quiche::CapsuleType; +using ::quiche::test::StatusIs; using ::testing::_; +using ::testing::ElementsAre; using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::Return; @@ -78,6 +83,14 @@ session_->OnCanRead(); } + template <typename CapsuleType> + void ProcessIncomingCapsule(const CapsuleType& capsule) { + quiche::QuicheBuffer buffer = quiche::SerializeCapsule( + quiche::Capsule(capsule), quiche::SimpleBufferAllocator::Get()); + read_buffer_.append(buffer.data(), buffer.size()); + session_->OnCanRead(); + } + void DefaultHandshakeForClient(EncapsulatedSession& session) { quiche::HttpHeaderBlock outgoing_headers, incoming_headers; session.InitializeClient(CreateAndStoreVisitor(), outgoing_headers, @@ -96,6 +109,18 @@ testing::MockFunction<void(absl::string_view)> fatal_error_callback_; }; +TEST_F(EncapsulatedWebTransportTest, IsOpenedBy) { + EXPECT_EQ(IsIdOpenedBy(0x00, Perspective::kClient), true); + EXPECT_EQ(IsIdOpenedBy(0x01, Perspective::kClient), false); + EXPECT_EQ(IsIdOpenedBy(0x02, Perspective::kClient), true); + EXPECT_EQ(IsIdOpenedBy(0x03, Perspective::kClient), false); + + EXPECT_EQ(IsIdOpenedBy(0x00, Perspective::kServer), false); + EXPECT_EQ(IsIdOpenedBy(0x01, Perspective::kServer), true); + EXPECT_EQ(IsIdOpenedBy(0x02, Perspective::kServer), false); + EXPECT_EQ(IsIdOpenedBy(0x03, Perspective::kServer), true); +} + TEST_F(EncapsulatedWebTransportTest, SetupClientSession) { std::unique_ptr<EncapsulatedSession> session = CreateTransport(Perspective::kClient); @@ -314,5 +339,448 @@ session->NotifySessionDraining(); } +TEST_F(EncapsulatedWebTransportTest, SimpleRead) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + bool stream_received = false; + EXPECT_CALL(*visitor_, OnIncomingBidirectionalStreamAvailable()) + .WillOnce([&] { stream_received = true; }); + std::string data = "test"; + ProcessIncomingCapsule(quiche::WebTransportStreamDataCapsule{1, data, false}); + // Make sure data gets copied. + data[0] = 'q'; + EXPECT_TRUE(stream_received); + Stream* stream = session->AcceptIncomingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + EXPECT_EQ(stream->GetStreamId(), 1u); + EXPECT_EQ(stream->visitor(), nullptr); + EXPECT_EQ(stream->ReadableBytes(), 4u); + + quiche::ReadStream::PeekResult peek = stream->PeekNextReadableRegion(); + EXPECT_EQ(peek.peeked_data, "test"); + EXPECT_FALSE(peek.fin_next); + EXPECT_FALSE(peek.all_data_received); + + std::string buffer; + quiche::ReadStream::ReadResult read = stream->Read(&buffer); + EXPECT_EQ(read.bytes_read, 4); + EXPECT_FALSE(read.fin); + EXPECT_EQ(buffer, "test"); + EXPECT_EQ(stream->ReadableBytes(), 0u); +} + +class MockStreamVisitorWithDestructor : public MockStreamVisitor { + public: + ~MockStreamVisitorWithDestructor() { OnDelete(); } + + MOCK_METHOD(void, OnDelete, (), ()); +}; + +MockStreamVisitorWithDestructor* SetupVisitor(Stream& stream) { + auto visitor = std::make_unique<MockStreamVisitorWithDestructor>(); + MockStreamVisitorWithDestructor* result = visitor.get(); + stream.SetVisitor(std::move(visitor)); + return result; +} + +TEST_F(EncapsulatedWebTransportTest, ImmediateRead) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + EXPECT_CALL(*visitor_, OnIncomingBidirectionalStreamAvailable()); + ProcessIncomingCapsule( + quiche::WebTransportStreamDataCapsule{1, "abcd", false}); + Stream* stream = session->AcceptIncomingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + EXPECT_EQ(stream->ReadableBytes(), 4u); + + MockStreamVisitor* visitor = SetupVisitor(*stream); + EXPECT_CALL(*visitor, OnCanRead()).WillOnce([&] { + std::string output; + (void)stream->Read(&output); + EXPECT_EQ(output, "abcdef"); + }); + ProcessIncomingCapsule(quiche::WebTransportStreamDataCapsule{1, "ef", false}); +} + +TEST_F(EncapsulatedWebTransportTest, FinPeek) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + EXPECT_CALL(*visitor_, OnIncomingBidirectionalStreamAvailable()); + ProcessIncomingCapsule( + quiche::WebTransportStreamDataCapsule{1, "abcd", false}); + Stream* stream = session->AcceptIncomingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + EXPECT_EQ(stream->ReadableBytes(), 4u); + + ProcessIncomingCapsule(quiche::WebTransportStreamDataCapsule{1, "ef", true}); + + quiche::ReadStream::PeekResult peek = stream->PeekNextReadableRegion(); + EXPECT_EQ(peek.peeked_data, "abcd"); + EXPECT_FALSE(peek.fin_next); + EXPECT_TRUE(peek.all_data_received); + + EXPECT_FALSE(stream->SkipBytes(2)); + peek = stream->PeekNextReadableRegion(); + EXPECT_FALSE(peek.fin_next); + EXPECT_TRUE(peek.all_data_received); + + EXPECT_FALSE(stream->SkipBytes(2)); + peek = stream->PeekNextReadableRegion(); + EXPECT_EQ(peek.peeked_data, "ef"); + EXPECT_TRUE(peek.fin_next); + EXPECT_TRUE(peek.all_data_received); + + EXPECT_TRUE(stream->SkipBytes(2)); +} + +TEST_F(EncapsulatedWebTransportTest, FinRead) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + EXPECT_CALL(*visitor_, OnIncomingBidirectionalStreamAvailable()); + ProcessIncomingCapsule( + quiche::WebTransportStreamDataCapsule{1, "abcdef", true}); + Stream* stream = session->AcceptIncomingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + EXPECT_EQ(stream->ReadableBytes(), 6u); + + std::array<char, 3> buffer; + quiche::ReadStream::ReadResult read = stream->Read(absl::MakeSpan(buffer)); + EXPECT_THAT(buffer, ElementsAre('a', 'b', 'c')); + EXPECT_EQ(read.bytes_read, 3); + EXPECT_FALSE(read.fin); + + read = stream->Read(absl::MakeSpan(buffer)); + EXPECT_THAT(buffer, ElementsAre('d', 'e', 'f')); + EXPECT_EQ(read.bytes_read, 3); + EXPECT_TRUE(read.fin); +} + +TEST_F(EncapsulatedWebTransportTest, LargeRead) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + EXPECT_CALL(*visitor_, OnIncomingBidirectionalStreamAvailable()); + ProcessIncomingCapsule(quiche::WebTransportStreamDataCapsule{ + 1, std::string(64 * 1024, 'a'), true}); + Stream* stream = session->AcceptIncomingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + EXPECT_EQ(stream->ReadableBytes(), 65536u); + + for (int i = 0; i < 64; i++) { + std::array<char, 1024> buffer; + quiche::ReadStream::ReadResult read = stream->Read(absl::MakeSpan(buffer)); + EXPECT_EQ(read.bytes_read, 1024); + EXPECT_EQ(read.fin, i == 63); + } +} + +TEST_F(EncapsulatedWebTransportTest, DoubleFinReceived) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + EXPECT_CALL(*visitor_, OnIncomingBidirectionalStreamAvailable()); + ProcessIncomingCapsule(quiche::WebTransportStreamDataCapsule{1, "abc", true}); + Stream* stream = session->AcceptIncomingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + + EXPECT_CALL(fatal_error_callback_, Call(_)) + .WillOnce([](absl::string_view error) { + EXPECT_THAT(error, HasSubstr("has already received a FIN")); + }); + ProcessIncomingCapsule(quiche::WebTransportStreamDataCapsule{1, "def", true}); +} + +TEST_F(EncapsulatedWebTransportTest, CanWriteUnidiBidi) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + EXPECT_CALL(*visitor_, OnIncomingBidirectionalStreamAvailable()); + EXPECT_CALL(*visitor_, OnIncomingUnidirectionalStreamAvailable()); + ProcessIncomingCapsule(quiche::WebTransportStreamDataCapsule{1, "abc", true}); + ProcessIncomingCapsule(quiche::WebTransportStreamDataCapsule{3, "abc", true}); + + Stream* stream = session->AcceptIncomingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + EXPECT_TRUE(stream->CanWrite()); + + stream = session->AcceptIncomingUnidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + EXPECT_FALSE(stream->CanWrite()); + + stream = session->OpenOutgoingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + EXPECT_TRUE(stream->CanWrite()); + + stream = session->OpenOutgoingUnidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + EXPECT_TRUE(stream->CanWrite()); +} + +TEST_F(EncapsulatedWebTransportTest, ReadOnlyGarbageCollection) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + EXPECT_CALL(*visitor_, OnIncomingUnidirectionalStreamAvailable()); + ProcessIncomingCapsule(quiche::WebTransportStreamDataCapsule{3, "abc", true}); + + Stream* stream = session->AcceptIncomingUnidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + EXPECT_TRUE(stream->SkipBytes(3)); + + MockStreamVisitorWithDestructor* visitor = SetupVisitor(*stream); + bool deleted = false; + EXPECT_CALL(*visitor, OnDelete()).WillOnce([&] { deleted = true; }); + session->GarbageCollectStreams(); + EXPECT_TRUE(deleted); +} + +TEST_F(EncapsulatedWebTransportTest, WriteOnlyGarbageCollection) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + + Stream* stream = session->OpenOutgoingUnidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + + MockStreamVisitorWithDestructor* visitor = SetupVisitor(*stream); + bool deleted = false; + EXPECT_CALL(*visitor, OnDelete()).WillOnce([&] { deleted = true; }); + EXPECT_CALL(*this, OnCapsule(_)).WillOnce(Return(true)); + + quiche::StreamWriteOptions options; + options.set_send_fin(true); + EXPECT_THAT(stream->Writev(absl::Span<const absl::string_view>(), options), + StatusIs(absl::StatusCode::kOk)); + session->GarbageCollectStreams(); + EXPECT_TRUE(deleted); +} + +TEST_F(EncapsulatedWebTransportTest, SimpleWrite) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + EXPECT_CALL(*visitor_, OnIncomingBidirectionalStreamAvailable()); + ProcessIncomingCapsule(quiche::WebTransportStreamDataCapsule{1, "", true}); + Stream* stream = session->AcceptIncomingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + + EXPECT_CALL(*this, OnCapsule(_)).WillOnce([](const Capsule& capsule) { + EXPECT_EQ(capsule.capsule_type(), CapsuleType::WT_STREAM); + EXPECT_EQ(capsule.web_transport_stream_data().stream_id, 1u); + EXPECT_EQ(capsule.web_transport_stream_data().fin, false); + EXPECT_EQ(capsule.web_transport_stream_data().data, "test"); + return true; + }); + absl::Status status = quiche::WriteIntoStream(*stream, "test"); + EXPECT_THAT(status, StatusIs(absl::StatusCode::kOk)); +} + +TEST_F(EncapsulatedWebTransportTest, WriteWithFin) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + Stream* stream = session->OpenOutgoingUnidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + + EXPECT_CALL(*this, OnCapsule(_)).WillOnce([](const Capsule& capsule) { + EXPECT_EQ(capsule.capsule_type(), CapsuleType::WT_STREAM_WITH_FIN); + EXPECT_EQ(capsule.web_transport_stream_data().stream_id, 2u); + EXPECT_EQ(capsule.web_transport_stream_data().fin, true); + EXPECT_EQ(capsule.web_transport_stream_data().data, "test"); + return true; + }); + quiche::StreamWriteOptions options; + options.set_send_fin(true); + EXPECT_TRUE(stream->CanWrite()); + absl::Status status = quiche::WriteIntoStream(*stream, "test", options); + EXPECT_THAT(status, StatusIs(absl::StatusCode::kOk)); + EXPECT_FALSE(stream->CanWrite()); +} + +TEST_F(EncapsulatedWebTransportTest, FinOnlyWrite) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + Stream* stream = session->OpenOutgoingUnidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + + EXPECT_CALL(*this, OnCapsule(_)).WillOnce([](const Capsule& capsule) { + EXPECT_EQ(capsule.capsule_type(), CapsuleType::WT_STREAM_WITH_FIN); + EXPECT_EQ(capsule.web_transport_stream_data().stream_id, 2u); + EXPECT_EQ(capsule.web_transport_stream_data().fin, true); + EXPECT_EQ(capsule.web_transport_stream_data().data, ""); + return true; + }); + quiche::StreamWriteOptions options; + options.set_send_fin(true); + EXPECT_TRUE(stream->CanWrite()); + absl::Status status = + stream->Writev(absl::Span<const absl::string_view>(), options); + EXPECT_THAT(status, StatusIs(absl::StatusCode::kOk)); + EXPECT_FALSE(stream->CanWrite()); +} + +TEST_F(EncapsulatedWebTransportTest, BufferedWriteThenUnbuffer) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + Stream* stream = session->OpenOutgoingUnidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + + EXPECT_CALL(writer_, CanWrite()).WillOnce(Return(false)); + absl::Status status = quiche::WriteIntoStream(*stream, "abc"); + EXPECT_THAT(status, StatusIs(absl::StatusCode::kOk)); + + // While the stream cannot be written right now, we should be still able to + // buffer data into it. + EXPECT_TRUE(stream->CanWrite()); + EXPECT_CALL(writer_, CanWrite()).WillRepeatedly(Return(true)); + status = quiche::WriteIntoStream(*stream, "def"); + EXPECT_THAT(status, StatusIs(absl::StatusCode::kOk)); + + EXPECT_CALL(*this, OnCapsule(_)).WillOnce([](const Capsule& capsule) { + EXPECT_EQ(capsule.capsule_type(), CapsuleType::WT_STREAM); + EXPECT_EQ(capsule.web_transport_stream_data().stream_id, 2u); + EXPECT_EQ(capsule.web_transport_stream_data().data, "abcdef"); + return true; + }); + session_->OnCanWrite(); +} + +TEST_F(EncapsulatedWebTransportTest, BufferedWriteThenFlush) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + Stream* stream = session->OpenOutgoingUnidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + + EXPECT_CALL(writer_, CanWrite()).Times(2).WillRepeatedly(Return(false)); + absl::Status status = quiche::WriteIntoStream(*stream, "abc"); + EXPECT_THAT(status, StatusIs(absl::StatusCode::kOk)); + status = quiche::WriteIntoStream(*stream, "def"); + EXPECT_THAT(status, StatusIs(absl::StatusCode::kOk)); + + EXPECT_CALL(writer_, CanWrite()).WillRepeatedly(Return(true)); + EXPECT_CALL(*this, OnCapsule(_)).WillOnce([](const Capsule& capsule) { + EXPECT_EQ(capsule.capsule_type(), CapsuleType::WT_STREAM); + EXPECT_EQ(capsule.web_transport_stream_data().stream_id, 2u); + EXPECT_EQ(capsule.web_transport_stream_data().data, "abcdef"); + return true; + }); + session_->OnCanWrite(); +} + +TEST_F(EncapsulatedWebTransportTest, BufferedStreamBlocksAnother) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + Stream* stream1 = session->OpenOutgoingUnidirectionalStream(); + Stream* stream2 = session->OpenOutgoingUnidirectionalStream(); + ASSERT_TRUE(stream1 != nullptr); + ASSERT_TRUE(stream2 != nullptr); + + EXPECT_CALL(*this, OnCapsule(_)).Times(0); + EXPECT_CALL(writer_, CanWrite()).WillOnce(Return(false)); + absl::Status status = quiche::WriteIntoStream(*stream1, "abc"); + EXPECT_THAT(status, StatusIs(absl::StatusCode::kOk)); + // ShouldYield will return false here, causing the write to get buffered. + EXPECT_CALL(writer_, CanWrite()).WillRepeatedly(Return(true)); + status = quiche::WriteIntoStream(*stream2, "abc"); + EXPECT_THAT(status, StatusIs(absl::StatusCode::kOk)); + + std::vector<StreamId> writes; + EXPECT_CALL(*this, OnCapsule(_)).WillRepeatedly([&](const Capsule& capsule) { + EXPECT_EQ(capsule.capsule_type(), CapsuleType::WT_STREAM); + writes.push_back(capsule.web_transport_stream_data().stream_id); + return true; + }); + session_->OnCanWrite(); + EXPECT_THAT(writes, ElementsAre(2, 6)); +} + +TEST_F(EncapsulatedWebTransportTest, SendReset) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + Stream* stream = session->OpenOutgoingUnidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + + MockStreamVisitorWithDestructor* visitor = SetupVisitor(*stream); + EXPECT_CALL(*this, OnCapsule(_)).WillOnce([&](const Capsule& capsule) { + EXPECT_EQ(capsule.capsule_type(), CapsuleType::WT_RESET_STREAM); + EXPECT_EQ(capsule.web_transport_reset_stream().stream_id, 2u); + EXPECT_EQ(capsule.web_transport_reset_stream().error_code, 1234u); + return true; + }); + stream->ResetWithUserCode(1234u); + + bool deleted = false; + EXPECT_CALL(*visitor, OnDelete()).WillOnce([&] { deleted = true; }); + session->GarbageCollectStreams(); + EXPECT_TRUE(deleted); +} + +TEST_F(EncapsulatedWebTransportTest, ReceiveReset) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + EXPECT_CALL(*visitor_, OnIncomingUnidirectionalStreamAvailable()); + ProcessIncomingCapsule(quiche::WebTransportStreamDataCapsule{3, "", true}); + Stream* stream = session->AcceptIncomingUnidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + + MockStreamVisitorWithDestructor* visitor = SetupVisitor(*stream); + EXPECT_CALL(*visitor, OnResetStreamReceived(1234u)); + EXPECT_TRUE(session->GetStreamById(3) != nullptr); + ProcessIncomingCapsule(quiche::WebTransportResetStreamCapsule{3u, 1234u}); + // Reading from the underlying transport automatically triggers garbage + // collection. + EXPECT_TRUE(session->GetStreamById(3) == nullptr); +} + +TEST_F(EncapsulatedWebTransportTest, SendStopSending) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + EXPECT_CALL(*visitor_, OnIncomingUnidirectionalStreamAvailable()); + ProcessIncomingCapsule(quiche::WebTransportStreamDataCapsule{3, "", true}); + Stream* stream = session->AcceptIncomingUnidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + + MockStreamVisitorWithDestructor* visitor = SetupVisitor(*stream); + EXPECT_CALL(*this, OnCapsule(_)).WillOnce([&](const Capsule& capsule) { + EXPECT_EQ(capsule.capsule_type(), CapsuleType::WT_STOP_SENDING); + EXPECT_EQ(capsule.web_transport_stop_sending().stream_id, 3u); + EXPECT_EQ(capsule.web_transport_stop_sending().error_code, 1234u); + return true; + }); + stream->SendStopSending(1234u); + + bool deleted = false; + EXPECT_CALL(*visitor, OnDelete()).WillOnce([&] { deleted = true; }); + session->GarbageCollectStreams(); + EXPECT_TRUE(deleted); +} + +TEST_F(EncapsulatedWebTransportTest, ReceiveStopSending) { + std::unique_ptr<EncapsulatedSession> session = + CreateTransport(Perspective::kClient); + DefaultHandshakeForClient(*session); + Stream* stream = session->OpenOutgoingUnidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + + MockStreamVisitorWithDestructor* visitor = SetupVisitor(*stream); + EXPECT_CALL(*visitor, OnStopSendingReceived(1234u)); + EXPECT_TRUE(session->GetStreamById(2) != nullptr); + ProcessIncomingCapsule(quiche::WebTransportStopSendingCapsule{2u, 1234u}); + // Reading from the underlying transport automatically triggers garbage + // collection. + EXPECT_TRUE(session->GetStreamById(2) == nullptr); +} + } // namespace } // namespace webtransport::test