Redo write support in InMemoryStream. Add a buffer-based version (used in this CL) and a gmock-based version (will be used in a follow-up CL). PiperOrigin-RevId: 872023477
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 2819ab5..e77fdcd 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -44,6 +44,7 @@ #include "quiche/quic/platform/api/quic_test.h" #include "quiche/quic/test_tools/quic_test_utils.h" #include "quiche/common/quiche_buffer_allocator.h" +#include "quiche/common/quiche_data_reader.h" #include "quiche/common/quiche_mem_slice.h" #include "quiche/common/quiche_stream.h" #include "quiche/common/quiche_weak_ptr.h" @@ -128,6 +129,15 @@ return publisher; } +std::optional<MoqtMessageType> PeekControlMessageType(absl::string_view data) { + quiche::QuicheDataReader reader(data); + uint64_t varint; + if (!reader.ReadVarInt62(&varint)) { + return std::nullopt; + } + return static_cast<MoqtMessageType>(varint); +} + } // namespace class MoqtSessionTest : public quic::test::QuicTest { @@ -326,7 +336,7 @@ std::make_unique<quic::test::TestAlarmFactory>(), session_callbacks_.AsSessionCallbacks()); // Load a CLIENT_SETUP message into an in-memory stream. - webtransport::test::InMemoryStream in_memory_stream(0); + webtransport::test::InMemoryStreamWithWriteBuffer in_memory_stream(0); MoqtFramer framer(session_parameters.using_webtrans); MoqtClientSetup setup; session_parameters.ToSetupParameters(setup.parameters); @@ -339,8 +349,8 @@ .WillOnce(Return(nullptr)); EXPECT_CALL(session_callbacks_.session_established_callback, Call()); server_session.OnIncomingBidirectionalStreamAvailable(); - EXPECT_EQ(static_cast<uint8_t>(in_memory_stream.last_data_sent()[0]), - static_cast<uint8_t>(MoqtMessageType::kServerSetup)); + EXPECT_EQ(PeekControlMessageType(in_memory_stream.write_buffer()), + MoqtMessageType::kServerSetup); EXPECT_NE(MoqtSessionPeer::GetControlStream(&server_session), nullptr); } @@ -2825,7 +2835,8 @@ MessageParameters parameters; parameters.authorization_tokens.emplace_back(AuthTokenType::kOutOfBand, "foo"); - auto bidi_stream = std::make_unique<webtransport::test::InMemoryStream>(4); + auto bidi_stream = + std::make_unique<webtransport::test::InMemoryStreamWithWriteBuffer>(4); MoqtFramer framer(true); MoqtSubscribeNamespace subscribe_namespace = { /*request_id=*/1, prefix, SubscribeNamespaceOption::kBoth, parameters}; @@ -2847,8 +2858,9 @@ return task_ptr; }); session_.OnIncomingBidirectionalStreamAvailable(); - EXPECT_EQ(static_cast<uint8_t>(bidi_stream->last_data_sent().data()[0]), - static_cast<uint8_t>(MoqtMessageType::kRequestOk)); + EXPECT_EQ(PeekControlMessageType(bidi_stream->write_buffer()), + MoqtMessageType::kRequestOk); + bidi_stream->write_buffer().clear(); // Deliver a NAMESPACE ASSERT_TRUE(task.IsValid()); @@ -2862,8 +2874,8 @@ task.GetIfAvailable()->InvokeCallback(); char expected_data[] = {0x08, 0x00, 0x05, 0x01, 0x03, 'b', 'a', 'r'}; absl::string_view expected_data_view(expected_data, sizeof(expected_data)); - EXPECT_EQ(expected_data_view, bidi_stream->last_data_sent().substr( - 0, expected_data_view.length())); + EXPECT_EQ(expected_data_view, + bidi_stream->write_buffer().substr(0, expected_data_view.length())); // Unsubscribe bidi_stream.reset(); @@ -2875,7 +2887,7 @@ MessageParameters parameters; parameters.authorization_tokens.emplace_back(AuthTokenType::kOutOfBand, "foo"); - webtransport::test::InMemoryStream bidi_stream(4); + webtransport::test::InMemoryStreamWithWriteBuffer bidi_stream(4); MoqtFramer framer(true); MoqtSubscribeNamespace subscribe_namespace = { /*request_id=*/1, prefix, SubscribeNamespaceOption::kBoth, parameters}; @@ -2895,8 +2907,8 @@ return nullptr; }); session_.OnIncomingBidirectionalStreamAvailable(); - EXPECT_EQ(static_cast<uint8_t>(bidi_stream.last_data_sent().data()[0]), - static_cast<uint8_t>(MoqtMessageType::kRequestError)); + EXPECT_EQ(PeekControlMessageType(bidi_stream.write_buffer()), + MoqtMessageType::kRequestError); EXPECT_TRUE(bidi_stream.fin_sent()); } @@ -2905,7 +2917,8 @@ MessageParameters parameters; parameters.authorization_tokens.emplace_back(AuthTokenType::kOutOfBand, "foo"); - webtransport::test::InMemoryStream bidi_stream1(4), bidi_stream2(8); + webtransport::test::InMemoryStreamWithWriteBuffer bidi_stream1(4), + bidi_stream2(8); MoqtFramer framer(true); MoqtSubscribeNamespace subscribe_namespace = { /*request_id=*/1, foo, SubscribeNamespaceOption::kBoth, parameters}; @@ -2925,8 +2938,8 @@ return task_ptr; }); session_.OnIncomingBidirectionalStreamAvailable(); - EXPECT_EQ(static_cast<uint8_t>(bidi_stream1.last_data_sent().data()[0]), - static_cast<uint8_t>(MoqtMessageType::kRequestOk)); + EXPECT_EQ(PeekControlMessageType(bidi_stream1.write_buffer()), + MoqtMessageType::kRequestOk); subscribe_namespace.request_id += 2; subscribe_namespace.track_namespace_prefix = foobar; @@ -2937,8 +2950,8 @@ .WillOnce(Return(&bidi_stream2)) .WillOnce(Return(nullptr)); session_.OnIncomingBidirectionalStreamAvailable(); - EXPECT_EQ(static_cast<uint8_t>(bidi_stream2.last_data_sent().data()[0]), - static_cast<uint8_t>(MoqtMessageType::kRequestError)); + EXPECT_EQ(PeekControlMessageType(bidi_stream2.write_buffer()), + MoqtMessageType::kRequestError); EXPECT_TRUE(bidi_stream2.fin_sent()); }
diff --git a/quiche/web_transport/test_tools/in_memory_stream.cc b/quiche/web_transport/test_tools/in_memory_stream.cc index 8a144bb..167b90e 100644 --- a/quiche/web_transport/test_tools/in_memory_stream.cc +++ b/quiche/web_transport/test_tools/in_memory_stream.cc
@@ -13,9 +13,11 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" #include "quiche/common/quiche_mem_slice.h" #include "quiche/common/quiche_stream.h" #include "quiche/common/vectorized_io_utils.h" +#include "quiche/web_transport/web_transport.h" namespace webtransport::test { @@ -58,11 +60,21 @@ absl::Status InMemoryStream::Writev(absl::Span<quiche::QuicheMemSlice> data, const quiche::StreamWriteOptions& options) { - if (!CanWrite()) { - return absl::PermissionDeniedError("Stream is not writable."); + absl::Status status = GetWriteStatusWithExtraChecks(); + if (!status.ok()) { + return status; } - last_data_sent_ = std::string(data[0].AsStringView()); - fin_sent_ |= options.send_fin(); + + std::string merged_data; + for (const quiche::QuicheMemSlice& slice : data) { + merged_data.append(slice.AsStringView()); + } + OnWrite(merged_data); + + if (options.send_fin()) { + OnFin(); + fin_sent_ = true; + } return absl::OkStatus(); } @@ -81,4 +93,23 @@ fin_received_ = false; } +absl::Status InMemoryStream::GetWriteStatus() const { + return absl::UnimplementedError( + "Writing not implemented; use InMemoryStreamWithMockWrite"); +} + +absl::Status InMemoryStream::GetWriteStatusWithExtraChecks() const { + if (fin_sent_) { + return absl::FailedPreconditionError( + "Can't write on a stream with FIN sent."); + } + return GetWriteStatus(); +} + +InMemoryStreamWithMockWrite::InMemoryStreamWithMockWrite(StreamId id) + : InMemoryStream(id) { + ON_CALL(*this, GetWriteStatus) + .WillByDefault(testing::Return(absl::OkStatus())); +} + } // namespace webtransport::test
diff --git a/quiche/web_transport/test_tools/in_memory_stream.h b/quiche/web_transport/test_tools/in_memory_stream.h index d4fbd35..a35e8ea 100644 --- a/quiche/web_transport/test_tools/in_memory_stream.h +++ b/quiche/web_transport/test_tools/in_memory_stream.h
@@ -14,7 +14,9 @@ #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "quiche/common/platform/api/quiche_export.h" #include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" #include "quiche/common/quiche_mem_slice.h" #include "quiche/common/quiche_stream.h" #include "quiche/web_transport/web_transport.h" @@ -22,7 +24,9 @@ namespace webtransport::test { // InMemoryStream models an incoming readable WebTransport stream where all of -// the data is read from an in-memory buffer. +// the data is read from an in-memory buffer. Writes are unsupported by +// default, but a subclass can handle those by overriding +// OnWrite/OnFin/GetWriteStatus. class QUICHE_NO_EXPORT InMemoryStream : public Stream { public: explicit InMemoryStream(StreamId id) : id_(id) {} @@ -37,7 +41,9 @@ // quiche::WriteStream implementation. absl::Status Writev(absl::Span<quiche::QuicheMemSlice> data, const quiche::StreamWriteOptions& options) override; - bool CanWrite() const override { return !fin_sent_; } + bool CanWrite() const override { + return GetWriteStatusWithExtraChecks().ok(); + } void AbruptlyTerminate(absl::Status) override { Terminate(); } @@ -73,12 +79,16 @@ peek_one_byte_at_a_time_ = peek_one_byte_at_a_time; } - // Returns what was last written to the stream. - absl::string_view last_data_sent() const { return last_data_sent_; } bool fin_sent() const { return fin_sent_; } + protected: + virtual void OnWrite(absl::string_view data) {} + virtual void OnFin() {} + virtual absl::Status GetWriteStatus() const; + private: void Terminate(); + absl::Status GetWriteStatusWithExtraChecks() const; StreamId id_; std::unique_ptr<StreamVisitor> visitor_; @@ -87,10 +97,34 @@ bool fin_received_ = false; bool abruptly_terminated_ = false; bool peek_one_byte_at_a_time_ = false; - std::string last_data_sent_; bool fin_sent_ = false; }; +// An InMemoryStream where all of the write-side interactions are exposed as +// mock methods. +class QUICHE_NO_EXPORT InMemoryStreamWithMockWrite : public InMemoryStream { + public: + explicit InMemoryStreamWithMockWrite(StreamId id); + + MOCK_METHOD(void, OnWrite, (absl::string_view data), (override)); + MOCK_METHOD(void, OnFin, (), (override)); + MOCK_METHOD(absl::Status, GetWriteStatus, (), (const, override)); +}; + +// An InMemoryStream where all writes are stored into a buffer. +class QUICHE_NO_EXPORT InMemoryStreamWithWriteBuffer : public InMemoryStream { + public: + using InMemoryStream::InMemoryStream; + + void OnWrite(absl::string_view data) { write_buffer_.append(data); } + absl::Status GetWriteStatus() const { return absl::OkStatus(); } + + std::string& write_buffer() { return write_buffer_; } + + private: + std::string write_buffer_; +}; + } // namespace webtransport::test #endif // QUICHE_WEB_TRANSPORT_TEST_TOOLS_IN_MEMORY_STREAM_H_
diff --git a/quiche/web_transport/test_tools/in_memory_stream_test.cc b/quiche/web_transport/test_tools/in_memory_stream_test.cc index 9c70b7a..2c53df7 100644 --- a/quiche/web_transport/test_tools/in_memory_stream_test.cc +++ b/quiche/web_transport/test_tools/in_memory_stream_test.cc
@@ -7,6 +7,7 @@ #include <array> #include <string> +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -19,6 +20,7 @@ namespace webtransport::test { namespace { +using ::quiche::test::StatusIs; using ::testing::ElementsAre; TEST(InMemoryStreamTest, ReadSpan) { @@ -82,22 +84,47 @@ EXPECT_TRUE(fin_reached); } -TEST(InMemoryStreamTest, Write) { - InMemoryStream stream(0); +TEST(InMemoryStreamTest, InMemoryStreamWithMockWrite) { + InMemoryStreamWithMockWrite stream(0); EXPECT_TRUE(stream.CanWrite()); + std::array write_vector = {quiche::QuicheMemSlice::Copy("test")}; quiche::StreamWriteOptions options; + EXPECT_CALL(stream, OnWrite("test")); QUICHE_EXPECT_OK(stream.Writev(absl::MakeSpan(write_vector), options)); - EXPECT_EQ(stream.last_data_sent(), "test"); EXPECT_FALSE(stream.fin_sent()); + // Send FIN. options.set_send_fin(true); write_vector = {quiche::QuicheMemSlice::Copy("test2")}; + { + testing::InSequence sequence; + EXPECT_CALL(stream, OnWrite("test2")); + EXPECT_CALL(stream, OnFin()); + } QUICHE_EXPECT_OK(stream.Writev(absl::MakeSpan(write_vector), options)); - EXPECT_EQ(stream.last_data_sent(), "test2"); EXPECT_TRUE(stream.fin_sent()); EXPECT_FALSE(stream.CanWrite()); - EXPECT_FALSE(stream.Writev(absl::MakeSpan(write_vector), options).ok()); + EXPECT_THAT(stream.Writev(absl::MakeSpan(write_vector), options), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST(InMemoryStreamTest, InMemoryStreamWithWriteBuffer) { + InMemoryStreamWithWriteBuffer stream(0); + EXPECT_TRUE(stream.CanWrite()); + + std::array write_vector = {quiche::QuicheMemSlice::Copy("foo")}; + quiche::StreamWriteOptions options; + QUICHE_EXPECT_OK(stream.Writev(absl::MakeSpan(write_vector), options)); + EXPECT_FALSE(stream.fin_sent()); + + // Send FIN. + options.set_send_fin(true); + write_vector = {quiche::QuicheMemSlice::Copy("bar")}; + QUICHE_EXPECT_OK(stream.Writev(absl::MakeSpan(write_vector), options)); + EXPECT_EQ(stream.write_buffer(), "foobar"); + EXPECT_TRUE(stream.fin_sent()); + EXPECT_FALSE(stream.CanWrite()); } } // namespace