Add a simple QuicTransport server for testing and demo purposes. The server currently has two modes, echo and discard. I've changed the integration test to use the simple server instead of a session with a mock visitor. This CL also adds some of the missing APIs and fixes some of the bugs I've found while doing this. gfe-relnote: n/a (code not used) PiperOrigin-RevId: 278456752 Change-Id: Idf9aa654aa0d66673f300f2f5425f0716d6c3e14
diff --git a/quic/quic_transport/quic_transport_client_session.cc b/quic/quic_transport/quic_transport_client_session.cc index f738371..72ed9c5 100644 --- a/quic/quic_transport/quic_transport_client_session.cc +++ b/quic/quic_transport/quic_transport_client_session.cc
@@ -16,9 +16,11 @@ #include "net/third_party/quiche/src/quic/core/quic_session.h" #include "net/third_party/quiche/src/quic/core/quic_types.h" #include "net/third_party/quiche/src/quic/core/quic_versions.h" +#include "net/third_party/quiche/src/quic/platform/api/quic_bug_tracker.h" #include "net/third_party/quiche/src/quic/platform/api/quic_logging.h" #include "net/third_party/quiche/src/quic/platform/api/quic_string_piece.h" #include "net/third_party/quiche/src/quic/platform/api/quic_text_utils.h" +#include "net/third_party/quiche/src/quic/quic_transport/quic_transport_protocol.h" #include "net/third_party/quiche/src/quic/quic_transport/quic_transport_stream.h" namespace quic { @@ -65,17 +67,15 @@ QuicStream* QuicTransportClientSession::CreateIncomingStream(QuicStreamId id) { QUIC_DVLOG(1) << "Creating incoming QuicTransport stream " << id; - auto stream = std::make_unique<QuicTransportStream>(id, this, this); - QuicTransportStream* stream_ptr = stream.get(); - ActivateStream(std::move(stream)); - if (stream_ptr->type() == BIDIRECTIONAL) { - incoming_bidirectional_streams_.push_back(stream_ptr); + QuicTransportStream* stream = CreateStream(id); + if (stream->type() == BIDIRECTIONAL) { + incoming_bidirectional_streams_.push_back(stream); visitor_->OnIncomingBidirectionalStreamAvailable(); } else { - incoming_unidirectional_streams_.push_back(stream_ptr); + incoming_unidirectional_streams_.push_back(stream); visitor_->OnIncomingUnidirectionalStreamAvailable(); } - return stream_ptr; + return stream; } void QuicTransportClientSession::OnCryptoHandshakeEvent( @@ -108,6 +108,31 @@ return stream; } +QuicTransportStream* +QuicTransportClientSession::OpenOutgoingBidirectionalStream() { + if (!CanOpenNextOutgoingBidirectionalStream()) { + QUIC_BUG << "Attempted to open a stream in violation of flow control"; + return nullptr; + } + return CreateStream(GetNextOutgoingBidirectionalStreamId()); +} + +QuicTransportStream* +QuicTransportClientSession::OpenOutgoingUnidirectionalStream() { + if (!CanOpenNextOutgoingUnidirectionalStream()) { + QUIC_BUG << "Attempted to open a stream in violation of flow control"; + return nullptr; + } + return CreateStream(GetNextOutgoingUnidirectionalStreamId()); +} + +QuicTransportStream* QuicTransportClientSession::CreateStream(QuicStreamId id) { + auto stream = std::make_unique<QuicTransportStream>(id, this, this); + QuicTransportStream* stream_ptr = stream.get(); + ActivateStream(std::move(stream)); + return stream_ptr; +} + std::string QuicTransportClientSession::SerializeClientIndication() { std::string serialized_origin = origin_.Serialize(); if (serialized_origin.size() > std::numeric_limits<uint16_t>::max()) { @@ -151,8 +176,11 @@ } auto client_indication_owned = std::make_unique<ClientIndication>( - /*stream_id=*/ClientIndicationStream(), this, /*is_static=*/false, - WRITE_UNIDIRECTIONAL); + /*stream_id=*/GetNextOutgoingUnidirectionalStreamId(), this, + /*is_static=*/false, WRITE_UNIDIRECTIONAL); + QUIC_BUG_IF(client_indication_owned->id() != ClientIndicationStream()) + << "Client indication stream is " << client_indication_owned->id() + << " instead of expected " << ClientIndicationStream(); ClientIndication* client_indication = client_indication_owned.get(); ActivateStream(std::move(client_indication_owned));
diff --git a/quic/quic_transport/quic_transport_client_session.h b/quic/quic_transport/quic_transport_client_session.h index 12e5fdf..4df008c 100644 --- a/quic/quic_transport/quic_transport_client_session.h +++ b/quic/quic_transport/quic_transport_client_session.h
@@ -88,6 +88,11 @@ QuicTransportStream* AcceptIncomingBidirectionalStream(); QuicTransportStream* AcceptIncomingUnidirectionalStream(); + using QuicSession::CanOpenNextOutgoingBidirectionalStream; + using QuicSession::CanOpenNextOutgoingUnidirectionalStream; + QuicTransportStream* OpenOutgoingBidirectionalStream(); + QuicTransportStream* OpenOutgoingUnidirectionalStream(); + protected: class QUIC_EXPORT_PRIVATE ClientIndication : public QuicStream { public: @@ -100,6 +105,9 @@ } }; + // Creates and activates a QuicTransportStream for the given ID. + QuicTransportStream* CreateStream(QuicStreamId id); + // Serializes the client indication as described in // https://vasilvv.github.io/webtransport/draft-vvv-webtransport-quic.html#rfc.section.3.2 std::string SerializeClientIndication();
diff --git a/quic/quic_transport/quic_transport_integration_test.cc b/quic/quic_transport/quic_transport_integration_test.cc index 9936830..dfccb14 100644 --- a/quic/quic_transport/quic_transport_integration_test.cc +++ b/quic/quic_transport/quic_transport_integration_test.cc
@@ -6,6 +6,7 @@ // server sessions. #include <memory> +#include <vector> #include "url/gurl.h" #include "url/origin.h" @@ -15,9 +16,11 @@ #include "net/third_party/quiche/src/quic/core/quic_error_codes.h" #include "net/third_party/quiche/src/quic/core/quic_types.h" #include "net/third_party/quiche/src/quic/core/quic_versions.h" +#include "net/third_party/quiche/src/quic/platform/api/quic_flags.h" #include "net/third_party/quiche/src/quic/platform/api/quic_test.h" #include "net/third_party/quiche/src/quic/quic_transport/quic_transport_client_session.h" #include "net/third_party/quiche/src/quic/quic_transport/quic_transport_server_session.h" +#include "net/third_party/quiche/src/quic/quic_transport/quic_transport_stream.h" #include "net/third_party/quiche/src/quic/test_tools/crypto_test_utils.h" #include "net/third_party/quiche/src/quic/test_tools/quic_test_utils.h" #include "net/third_party/quiche/src/quic/test_tools/quic_transport_test_tools.h" @@ -25,6 +28,7 @@ #include "net/third_party/quiche/src/quic/test_tools/simulator/quic_endpoint_base.h" #include "net/third_party/quiche/src/quic/test_tools/simulator/simulator.h" #include "net/third_party/quiche/src/quic/test_tools/simulator/switch.h" +#include "net/third_party/quiche/src/quic/tools/quic_transport_simple_server_session.h" namespace quic { namespace test { @@ -32,8 +36,7 @@ using simulator::QuicEndpointBase; using simulator::Simulator; -using testing::_; -using testing::Return; +using testing::Assign; url::Origin GetTestOrigin() { constexpr char kTestOrigin[] = "https://test-origin.test"; @@ -85,6 +88,7 @@ } QuicTransportClientSession* session() { return &session_; } + MockClientVisitor* visitor() { return &visitor_; } private: QuicCryptoClientConfig crypto_config_; @@ -96,7 +100,9 @@ public: QuicTransportServerEndpoint(Simulator* simulator, const std::string& name, - const std::string& peer_name) + const std::string& peer_name, + QuicTransportSimpleServerSession::Mode mode, + std::vector<url::Origin> accepted_origins) : QuicTransportEndpointBase(simulator, name, peer_name, @@ -108,25 +114,31 @@ compressed_certs_cache_( QuicCompressedCertsCache::kQuicCompressedCertsCacheSize), session_(connection_.get(), + /*owns_connection=*/false, nullptr, DefaultQuicConfig(), GetVersions(), &crypto_config_, &compressed_certs_cache_, - &visitor_) { + mode, + accepted_origins) { session_.Initialize(); } QuicTransportServerSession* session() { return &session_; } - MockServerVisitor* visitor() { return &visitor_; } private: QuicCryptoServerConfig crypto_config_; QuicCompressedCertsCache compressed_certs_cache_; - QuicTransportServerSession session_; - MockServerVisitor visitor_; + QuicTransportSimpleServerSession session_; }; +std::unique_ptr<MockStreamVisitor> VisitorExpectingFin() { + auto visitor = std::make_unique<MockStreamVisitor>(); + EXPECT_CALL(*visitor, OnFinRead()); + return visitor; +} + constexpr QuicBandwidth kClientBandwidth = QuicBandwidth::FromKBitsPerSecond(10000); constexpr QuicTime::Delta kClientPropagationDelay = @@ -142,19 +154,18 @@ (kClientPropagationDelay + kServerPropagationDelay + kTransferTime) * 2; const QuicByteCount kBdp = kRtt * kServerBandwidth; -constexpr QuicTime::Delta kHandshakeTimeout = QuicTime::Delta::FromSeconds(3); +constexpr QuicTime::Delta kDefaultTimeout = QuicTime::Delta::FromSeconds(3); class QuicTransportIntegrationTest : public QuicTest { public: QuicTransportIntegrationTest() : switch_(&simulator_, "Switch", 8, 2 * kBdp) {} - void CreateDefaultEndpoints() { + void CreateDefaultEndpoints(QuicTransportSimpleServerSession::Mode mode) { client_ = std::make_unique<QuicTransportClientEndpoint>( &simulator_, "Client", "Server", GetTestOrigin()); - server_ = std::make_unique<QuicTransportServerEndpoint>(&simulator_, - "Server", "Client"); - ON_CALL(*server_->visitor(), CheckOrigin(_)).WillByDefault(Return(true)); + server_ = std::make_unique<QuicTransportServerEndpoint>( + &simulator_, "Server", "Client", mode, accepted_origins_); } void WireUpEndpoints() { @@ -173,7 +184,7 @@ return IsHandshakeDone(client_->session()) && IsHandshakeDone(server_->session()); }, - kHandshakeTimeout); + kDefaultTimeout); EXPECT_TRUE(result); } @@ -190,10 +201,12 @@ std::unique_ptr<QuicTransportClientEndpoint> client_; std::unique_ptr<QuicTransportServerEndpoint> server_; + + std::vector<url::Origin> accepted_origins_ = {GetTestOrigin()}; }; TEST_F(QuicTransportIntegrationTest, SuccessfulHandshake) { - CreateDefaultEndpoints(); + CreateDefaultEndpoints(QuicTransportSimpleServerSession::DISCARD); WireUpEndpoints(); RunHandshake(); EXPECT_TRUE(client_->session()->IsSessionReady()); @@ -201,15 +214,14 @@ } TEST_F(QuicTransportIntegrationTest, OriginMismatch) { - CreateDefaultEndpoints(); + accepted_origins_ = {url::Origin::Create(GURL{"https://wrong-origin.test"})}; + CreateDefaultEndpoints(QuicTransportSimpleServerSession::DISCARD); WireUpEndpoints(); - EXPECT_CALL(*server_->visitor(), CheckOrigin(_)) - .WillRepeatedly(Return(false)); RunHandshake(); // Wait until the client receives CONNECTION_CLOSE. simulator_.RunUntilOrTimeout( [this]() { return !client_->session()->connection()->connected(); }, - kHandshakeTimeout); + kDefaultTimeout); EXPECT_TRUE(client_->session()->IsSessionReady()); EXPECT_FALSE(server_->session()->IsSessionReady()); EXPECT_FALSE(client_->session()->connection()->connected()); @@ -220,6 +232,100 @@ QUIC_TRANSPORT_INVALID_CLIENT_INDICATION); } +TEST_F(QuicTransportIntegrationTest, SendOutgoingStreams) { + CreateDefaultEndpoints(QuicTransportSimpleServerSession::DISCARD); + WireUpEndpoints(); + RunHandshake(); + + std::vector<QuicTransportStream*> streams; + for (int i = 0; i < 10; i++) { + QuicTransportStream* stream = + client_->session()->OpenOutgoingUnidirectionalStream(); + ASSERT_TRUE(stream->Write("test")); + streams.push_back(stream); + } + ASSERT_TRUE(simulator_.RunUntilOrTimeout( + [this]() { + return server_->session()->GetNumOpenIncomingStreams() == 10; + }, + kDefaultTimeout)); + + for (QuicTransportStream* stream : streams) { + ASSERT_TRUE(stream->SendFin()); + } + ASSERT_TRUE(simulator_.RunUntilOrTimeout( + [this]() { return server_->session()->GetNumOpenIncomingStreams() == 0; }, + kDefaultTimeout)); +} + +TEST_F(QuicTransportIntegrationTest, EchoBidirectionalStreams) { + CreateDefaultEndpoints(QuicTransportSimpleServerSession::ECHO); + WireUpEndpoints(); + RunHandshake(); + + QuicTransportStream* stream = + client_->session()->OpenOutgoingBidirectionalStream(); + EXPECT_TRUE(stream->Write("Hello!")); + + ASSERT_TRUE(simulator_.RunUntilOrTimeout( + [stream]() { return stream->ReadableBytes() == strlen("Hello!"); }, + kDefaultTimeout)); + std::string received; + EXPECT_EQ(stream->Read(&received), strlen("Hello!")); + EXPECT_EQ(received, "Hello!"); + + EXPECT_TRUE(stream->SendFin()); + ASSERT_TRUE(simulator_.RunUntilOrTimeout( + [this]() { return server_->session()->GetNumOpenIncomingStreams() == 0; }, + kDefaultTimeout)); +} + +TEST_F(QuicTransportIntegrationTest, EchoUnidirectionalStreams) { + CreateDefaultEndpoints(QuicTransportSimpleServerSession::ECHO); + WireUpEndpoints(); + RunHandshake(); + + // Send two streams, but only send FIN on the second one. + QuicTransportStream* stream1 = + client_->session()->OpenOutgoingUnidirectionalStream(); + EXPECT_TRUE(stream1->Write("Stream One")); + QuicTransportStream* stream2 = + client_->session()->OpenOutgoingUnidirectionalStream(); + EXPECT_TRUE(stream2->Write("Stream Two")); + EXPECT_TRUE(stream2->SendFin()); + + // Wait until a stream is received. + bool stream_received = false; + EXPECT_CALL(*client_->visitor(), OnIncomingUnidirectionalStreamAvailable()) + .Times(2) + .WillRepeatedly(Assign(&stream_received, true)); + ASSERT_TRUE(simulator_.RunUntilOrTimeout( + [&stream_received]() { return stream_received; }, kDefaultTimeout)); + + // Receive a reply stream and expect it to be the second one. + QuicTransportStream* reply = + client_->session()->AcceptIncomingUnidirectionalStream(); + ASSERT_TRUE(reply != nullptr); + std::string buffer; + reply->set_visitor(VisitorExpectingFin()); + EXPECT_GT(reply->Read(&buffer), 0u); + EXPECT_EQ(buffer, "Stream Two"); + + // Reset reply-related variables. + stream_received = false; + buffer = ""; + + // Send FIN on the first stream, and expect to receive it back. + EXPECT_TRUE(stream1->SendFin()); + ASSERT_TRUE(simulator_.RunUntilOrTimeout( + [&stream_received]() { return stream_received; }, kDefaultTimeout)); + reply = client_->session()->AcceptIncomingUnidirectionalStream(); + ASSERT_TRUE(reply != nullptr); + reply->set_visitor(VisitorExpectingFin()); + EXPECT_GT(reply->Read(&buffer), 0u); + EXPECT_EQ(buffer, "Stream One"); +} + } // namespace } // namespace test } // namespace quic
diff --git a/quic/quic_transport/quic_transport_server_session.cc b/quic/quic_transport/quic_transport_server_session.cc index fe4aec1..7f00acd 100644 --- a/quic/quic_transport/quic_transport_server_session.cc +++ b/quic/quic_transport/quic_transport_server_session.cc
@@ -66,6 +66,7 @@ auto stream = std::make_unique<QuicTransportStream>(id, this, this); QuicTransportStream* stream_ptr = stream.get(); ActivateStream(std::move(stream)); + OnIncomingDataStream(stream_ptr); return stream_ptr; }
diff --git a/quic/quic_transport/quic_transport_server_session.h b/quic/quic_transport/quic_transport_server_session.h index c488775..b3fcfa0 100644 --- a/quic/quic_transport/quic_transport_server_session.h +++ b/quic/quic_transport/quic_transport_server_session.h
@@ -12,6 +12,7 @@ #include "net/third_party/quiche/src/quic/platform/api/quic_string_piece.h" #include "net/third_party/quiche/src/quic/quic_transport/quic_transport_protocol.h" #include "net/third_party/quiche/src/quic/quic_transport/quic_transport_session_interface.h" +#include "net/third_party/quiche/src/quic/quic_transport/quic_transport_stream.h" namespace quic { @@ -97,6 +98,8 @@ // https://vasilvv.github.io/webtransport/draft-vvv-webtransport-quic.html#rfc.section.3.2 void ProcessClientIndication(QuicStringPiece indication); + virtual void OnIncomingDataStream(QuicTransportStream* /*stream*/) {} + std::unique_ptr<QuicCryptoServerStream> crypto_stream_; bool ready_ = false; ServerVisitor* visitor_;
diff --git a/quic/quic_transport/quic_transport_stream.cc b/quic/quic_transport/quic_transport_stream.cc index 9546119..61f4345 100644 --- a/quic/quic_transport/quic_transport_stream.cc +++ b/quic/quic_transport/quic_transport_stream.cc
@@ -32,7 +32,21 @@ iovec iov; iov.iov_base = buffer; iov.iov_len = buffer_size; - return sequencer()->Readv(&iov, 1); + const size_t result = sequencer()->Readv(&iov, 1); + if (sequencer()->IsClosed() && visitor_ != nullptr) { + visitor_->OnFinRead(); + } + return result; +} + +size_t QuicTransportStream::Read(std::string* output) { + const size_t old_size = output->size(); + const size_t bytes_to_read = ReadableBytes(); + output->resize(old_size + bytes_to_read); + size_t bytes_read = Read(&(*output)[old_size], bytes_to_read); + DCHECK_EQ(bytes_to_read, bytes_read); + output->resize(old_size + bytes_read); + return bytes_read; } bool QuicTransportStream::Write(QuicStringPiece data) { @@ -40,6 +54,7 @@ return false; } + // TODO(vasilvv): use WriteMemSlices() WriteOrBufferData(data, /*fin=*/false, nullptr); return true; } @@ -66,12 +81,21 @@ } void QuicTransportStream::OnDataAvailable() { + if (sequencer()->IsClosed()) { + if (visitor_ != nullptr) { + visitor_->OnFinRead(); + } + OnFinRead(); + return; + } + + if (visitor_ == nullptr) { + return; + } if (ReadableBytes() == 0) { return; } - if (visitor_ != nullptr) { - visitor_->OnCanRead(); - } + visitor_->OnCanRead(); } void QuicTransportStream::OnCanWriteNewData() {
diff --git a/quic/quic_transport/quic_transport_stream.h b/quic/quic_transport/quic_transport_stream.h index 44694db..1651a1c 100644 --- a/quic/quic_transport/quic_transport_stream.h +++ b/quic/quic_transport/quic_transport_stream.h
@@ -6,6 +6,7 @@ #define QUICHE_QUIC_QUIC_TRANSPORT_QUIC_TRANSPORT_STREAM_H_ #include <cstddef> +#include <memory> #include "net/third_party/quiche/src/quic/core/quic_session.h" #include "net/third_party/quiche/src/quic/core/quic_stream.h" @@ -25,6 +26,7 @@ public: virtual ~Visitor() {} virtual void OnCanRead() = 0; + virtual void OnFinRead() = 0; virtual void OnCanWrite() = 0; }; @@ -35,6 +37,8 @@ // Reads at most |buffer_size| bytes into |buffer| and returns the number of // bytes actually read. size_t Read(char* buffer, size_t buffer_size); + // Reads all available data and appends it to the end of |output|. + size_t Read(std::string* output); // Writes |data| into the stream. Returns true on success. QUIC_MUST_USE_RESULT bool Write(QuicStringPiece data); // Sends the FIN on the stream. Returns true on success. @@ -49,11 +53,18 @@ void OnDataAvailable() override; void OnCanWriteNewData() override; - void set_visitor(Visitor* visitor) { visitor_ = visitor; } + Visitor* visitor() { return visitor_.get(); } + void set_visitor(std::unique_ptr<Visitor> visitor) { + visitor_ = std::move(visitor); + } protected: + // Hide the methods that allow writing data without checking IsSessionReady(). + using QuicStream::WriteMemSlices; + using QuicStream::WriteOrBufferData; + QuicTransportSessionInterface* session_interface_; - Visitor* visitor_ = nullptr; + std::unique_ptr<Visitor> visitor_ = nullptr; }; } // namespace quic
diff --git a/quic/quic_transport/quic_transport_stream_test.cc b/quic/quic_transport/quic_transport_stream_test.cc index 8c5ca5b..c291b54 100644 --- a/quic/quic_transport/quic_transport_stream_test.cc +++ b/quic/quic_transport/quic_transport_stream_test.cc
@@ -3,6 +3,7 @@ // found in the LICENSE file. #include "net/third_party/quiche/src/quic/quic_transport/quic_transport_stream.h" + #include <memory> #include "net/third_party/quiche/src/quic/core/frames/quic_window_update_frame.h" @@ -12,6 +13,7 @@ #include "net/third_party/quiche/src/quic/quic_transport/quic_transport_session_interface.h" #include "net/third_party/quiche/src/quic/test_tools/quic_config_peer.h" #include "net/third_party/quiche/src/quic/test_tools/quic_test_utils.h" +#include "net/third_party/quiche/src/quic/test_tools/quic_transport_test_tools.h" namespace quic { namespace test { @@ -28,12 +30,6 @@ MOCK_CONST_METHOD0(IsSessionReady, bool()); }; -class MockVisitor : public QuicTransportStream::Visitor { - public: - MOCK_METHOD0(OnCanRead, void()); - MOCK_METHOD0(OnCanWrite, void()); -}; - class QuicTransportStreamTest : public QuicTest { public: QuicTransportStreamTest() @@ -46,7 +42,10 @@ stream_ = new QuicTransportStream(0, &session_, &interface_); session_.ActivateStream(QuicWrapUnique(stream_)); - stream_->set_visitor(&visitor_); + + auto visitor = std::make_unique<MockStreamVisitor>(); + visitor_ = visitor.get(); + stream_->set_visitor(std::move(visitor)); } void ReceiveStreamData(QuicStringPiece data, QuicStreamOffset offset) { @@ -61,8 +60,8 @@ MockQuicConnection* connection_; // Owned by |session_|. MockQuicSession session_; MockQuicTransportSessionInterface interface_; - MockVisitor visitor_; QuicTransportStream* stream_; // Owned by |session_|. + MockStreamVisitor* visitor_; // Owned by |stream_|. }; TEST_F(QuicTransportStreamTest, NotReady) { @@ -95,10 +94,30 @@ TEST_F(QuicTransportStreamTest, ReceiveData) { EXPECT_CALL(interface_, IsSessionReady()).WillRepeatedly(Return(true)); - EXPECT_CALL(visitor_, OnCanRead()); + EXPECT_CALL(*visitor_, OnCanRead()); ReceiveStreamData("test", 0); } +TEST_F(QuicTransportStreamTest, FinReadWithNoDataPending) { + EXPECT_CALL(interface_, IsSessionReady()).WillRepeatedly(Return(true)); + EXPECT_CALL(*visitor_, OnFinRead()); + QuicStreamFrame frame(0, true, 0, ""); + stream_->OnStreamFrame(frame); +} + +TEST_F(QuicTransportStreamTest, FinReadWithDataPending) { + EXPECT_CALL(interface_, IsSessionReady()).WillRepeatedly(Return(true)); + + EXPECT_CALL(*visitor_, OnCanRead()); + EXPECT_CALL(*visitor_, OnFinRead()).Times(0); + QuicStreamFrame frame(0, true, 0, "test"); + stream_->OnStreamFrame(frame); + + EXPECT_CALL(*visitor_, OnFinRead()).Times(1); + std::string buffer; + ASSERT_EQ(stream_->Read(&buffer), 4u); +} + } // namespace } // namespace test } // namespace quic
diff --git a/quic/test_tools/quic_transport_test_tools.h b/quic/test_tools/quic_transport_test_tools.h index bf2dc0e..c6a8b46 100644 --- a/quic/test_tools/quic_transport_test_tools.h +++ b/quic/test_tools/quic_transport_test_tools.h
@@ -23,6 +23,13 @@ MOCK_METHOD1(CheckOrigin, bool(url::Origin)); }; +class MockStreamVisitor : public QuicTransportStream::Visitor { + public: + MOCK_METHOD0(OnCanRead, void()); + MOCK_METHOD0(OnFinRead, void()); + MOCK_METHOD0(OnCanWrite, void()); +}; + } // namespace test } // namespace quic
diff --git a/quic/tools/quic_transport_simple_server_dispatcher.cc b/quic/tools/quic_transport_simple_server_dispatcher.cc new file mode 100644 index 0000000..da9899f --- /dev/null +++ b/quic/tools/quic_transport_simple_server_dispatcher.cc
@@ -0,0 +1,52 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/third_party/quiche/src/quic/tools/quic_transport_simple_server_dispatcher.h" + +#include <memory> + +#include "net/third_party/quiche/src/quic/core/quic_connection.h" +#include "net/third_party/quiche/src/quic/core/quic_dispatcher.h" +#include "net/third_party/quiche/src/quic/core/quic_types.h" +#include "net/third_party/quiche/src/quic/core/quic_versions.h" +#include "net/third_party/quiche/src/quic/tools/quic_transport_simple_server_session.h" + +namespace quic { + +QuicTransportSimpleServerDispatcher::QuicTransportSimpleServerDispatcher( + const QuicConfig* config, + const QuicCryptoServerConfig* crypto_config, + QuicVersionManager* version_manager, + std::unique_ptr<QuicConnectionHelperInterface> helper, + std::unique_ptr<QuicCryptoServerStream::Helper> session_helper, + std::unique_ptr<QuicAlarmFactory> alarm_factory, + uint8_t expected_server_connection_id_length, + QuicTransportSimpleServerSession::Mode mode, + std::vector<url::Origin> accepted_origins) + : QuicDispatcher(config, + crypto_config, + version_manager, + std::move(helper), + std::move(session_helper), + std::move(alarm_factory), + expected_server_connection_id_length), + mode_(mode), + accepted_origins_(accepted_origins) {} + +QuicSession* QuicTransportSimpleServerDispatcher::CreateQuicSession( + QuicConnectionId server_connection_id, + const QuicSocketAddress& peer_address, + QuicStringPiece /*alpn*/, + const ParsedQuicVersion& version) { + auto connection = std::make_unique<QuicConnection>( + server_connection_id, peer_address, helper(), alarm_factory(), writer(), + /*owns_writer=*/false, Perspective::IS_SERVER, + ParsedQuicVersionVector{version}); + return new QuicTransportSimpleServerSession( + connection.release(), /*owns_connection=*/true, this, config(), + GetSupportedVersions(), crypto_config(), compressed_certs_cache(), mode_, + accepted_origins_); +} + +} // namespace quic
diff --git a/quic/tools/quic_transport_simple_server_dispatcher.h b/quic/tools/quic_transport_simple_server_dispatcher.h new file mode 100644 index 0000000..ea4eb8b --- /dev/null +++ b/quic/tools/quic_transport_simple_server_dispatcher.h
@@ -0,0 +1,41 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_TRANSPORT_SIMPLE_SERVER_DISPATCHER_H_ +#define QUICHE_QUIC_TOOLS_QUIC_TRANSPORT_SIMPLE_SERVER_DISPATCHER_H_ + +#include "url/origin.h" +#include "net/third_party/quiche/src/quic/core/quic_dispatcher.h" +#include "net/third_party/quiche/src/quic/tools/quic_transport_simple_server_session.h" + +namespace quic { + +// Dispatcher that creates a QuicTransportSimpleServerSession for every incoming +// connection. +class QuicTransportSimpleServerDispatcher : public QuicDispatcher { + public: + QuicTransportSimpleServerDispatcher( + const QuicConfig* config, + const QuicCryptoServerConfig* crypto_config, + QuicVersionManager* version_manager, + std::unique_ptr<QuicConnectionHelperInterface> helper, + std::unique_ptr<QuicCryptoServerStream::Helper> session_helper, + std::unique_ptr<QuicAlarmFactory> alarm_factory, + uint8_t expected_server_connection_id_length, + QuicTransportSimpleServerSession::Mode mode, + std::vector<url::Origin> accepted_origins); + + protected: + QuicSession* CreateQuicSession(QuicConnectionId server_connection_id, + const QuicSocketAddress& peer_address, + QuicStringPiece alpn, + const ParsedQuicVersion& version) override; + + QuicTransportSimpleServerSession::Mode mode_; + std::vector<url::Origin> accepted_origins_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_TRANSPORT_SIMPLE_SERVER_DISPATCHER_H_
diff --git a/quic/tools/quic_transport_simple_server_session.cc b/quic/tools/quic_transport_simple_server_session.cc new file mode 100644 index 0000000..273a9ae --- /dev/null +++ b/quic/tools/quic_transport_simple_server_session.cc
@@ -0,0 +1,228 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/third_party/quiche/src/quic/tools/quic_transport_simple_server_session.h" + +#include <memory> + +#include "url/gurl.h" +#include "url/origin.h" +#include "net/third_party/quiche/src/quic/core/quic_types.h" +#include "net/third_party/quiche/src/quic/core/quic_versions.h" +#include "net/third_party/quiche/src/quic/platform/api/quic_flags.h" +#include "net/third_party/quiche/src/quic/platform/api/quic_logging.h" +#include "net/third_party/quiche/src/quic/platform/api/quic_string_piece.h" +#include "net/third_party/quiche/src/quic/platform/api/quic_text_utils.h" +#include "net/third_party/quiche/src/quic/quic_transport/quic_transport_protocol.h" +#include "net/third_party/quiche/src/quic/quic_transport/quic_transport_stream.h" + +namespace quic { + +namespace { + +// Discards any incoming data. +class DiscardVisitor : public QuicTransportStream::Visitor { + public: + DiscardVisitor(QuicTransportStream* stream) : stream_(stream) {} + + void OnCanRead() override { + std::string buffer; + size_t bytes_read = stream_->Read(&buffer); + QUIC_DVLOG(2) << "Read " << bytes_read << " bytes from stream " + << stream_->id(); + } + + void OnFinRead() override {} + void OnCanWrite() override {} + + private: + QuicTransportStream* stream_; +}; + +// Echoes any incoming data back on the same stream. +class BidirectionalEchoVisitor : public QuicTransportStream::Visitor { + public: + BidirectionalEchoVisitor(QuicTransportStream* stream) : stream_(stream) {} + + void OnCanRead() override { + stream_->Read(&buffer_); + OnCanWrite(); + } + + void OnFinRead() override { + bool success = stream_->SendFin(); + DCHECK(success); + } + + void OnCanWrite() override { + if (buffer_.empty()) { + return; + } + + bool success = stream_->Write(buffer_); + if (success) { + buffer_ = ""; + } + } + + private: + QuicTransportStream* stream_; + std::string buffer_; +}; + +// Buffers all of the data and calls EchoStreamBack() on the parent session. +class UnidirectionalEchoReadVisitor : public QuicTransportStream::Visitor { + public: + UnidirectionalEchoReadVisitor(QuicTransportSimpleServerSession* session, + QuicTransportStream* stream) + : session_(session), stream_(stream) {} + + void OnCanRead() override { + bool success = stream_->Read(&buffer_); + DCHECK(success); + } + + void OnFinRead() override { + QUIC_DVLOG(1) << "Finished receiving data on stream " << stream_->id() + << ", queueing up the echo"; + session_->EchoStreamBack(buffer_); + } + + void OnCanWrite() override { QUIC_NOTREACHED(); } + + private: + QuicTransportSimpleServerSession* session_; + QuicTransportStream* stream_; + std::string buffer_; +}; + +// Sends supplied data. +class UnidirectionalEchoWriteVisitor : public QuicTransportStream::Visitor { + public: + UnidirectionalEchoWriteVisitor(QuicTransportStream* stream, + const std::string& data) + : stream_(stream), data_(data) {} + + void OnCanRead() override { QUIC_NOTREACHED(); } + void OnFinRead() override { QUIC_NOTREACHED(); } + void OnCanWrite() override { + if (data_.empty()) { + return; + } + if (!stream_->Write(data_)) { + return; + } + data_ = ""; + bool fin_sent = stream_->SendFin(); + DCHECK(fin_sent); + } + + private: + QuicTransportStream* stream_; + std::string data_; +}; + +} // namespace + +QuicTransportSimpleServerSession::QuicTransportSimpleServerSession( + QuicConnection* connection, + bool owns_connection, + Visitor* owner, + const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + Mode mode, + std::vector<url::Origin> accepted_origins) + : QuicTransportServerSession(connection, + owner, + config, + supported_versions, + crypto_config, + compressed_certs_cache, + this), + connection_(connection), + owns_connection_(owns_connection), + mode_(mode), + accepted_origins_(accepted_origins) { + Initialize(); +} + +QuicTransportSimpleServerSession::~QuicTransportSimpleServerSession() { + if (owns_connection_) { + delete connection_; + } +} + +void QuicTransportSimpleServerSession::OnIncomingDataStream( + QuicTransportStream* stream) { + switch (mode_) { + case DISCARD: + stream->set_visitor(std::make_unique<DiscardVisitor>(stream)); + break; + + case ECHO: + switch (stream->type()) { + case BIDIRECTIONAL: + QUIC_DVLOG(1) << "Opening bidirectional echo stream " << stream->id(); + stream->set_visitor( + std::make_unique<BidirectionalEchoVisitor>(stream)); + break; + case READ_UNIDIRECTIONAL: + QUIC_DVLOG(1) + << "Started receiving data on unidirectional echo stream " + << stream->id(); + stream->set_visitor( + std::make_unique<UnidirectionalEchoReadVisitor>(this, stream)); + break; + default: + QUIC_NOTREACHED(); + break; + } + break; + } +} + +void QuicTransportSimpleServerSession::OnCanCreateNewOutgoingStream( + bool unidirectional) { + if (mode_ == ECHO && unidirectional) { + MaybeEchoStreamsBack(); + } +} + +bool QuicTransportSimpleServerSession::CheckOrigin(url::Origin origin) { + if (accepted_origins_.empty()) { + return true; + } + + for (const url::Origin& accepted_origin : accepted_origins_) { + if (origin.IsSameOriginWith(accepted_origin)) { + return true; + } + } + return false; +} + +void QuicTransportSimpleServerSession::MaybeEchoStreamsBack() { + while (!streams_to_echo_back_.empty() && + CanOpenNextOutgoingUnidirectionalStream()) { + // Remove the stream from the queue first, in order to avoid accidentally + // entering an infinite loop in case any of the following code calls + // OnCanCreateNewOutgoingStream(). + std::string data = std::move(streams_to_echo_back_.front()); + streams_to_echo_back_.pop_front(); + + auto stream_owned = std::make_unique<QuicTransportStream>( + GetNextOutgoingUnidirectionalStreamId(), this, this); + QuicTransportStream* stream = stream_owned.get(); + ActivateStream(std::move(stream_owned)); + QUIC_DVLOG(1) << "Opened echo response stream " << stream->id(); + + stream->set_visitor( + std::make_unique<UnidirectionalEchoWriteVisitor>(stream, data)); + stream->visitor()->OnCanWrite(); + } +} + +} // namespace quic
diff --git a/quic/tools/quic_transport_simple_server_session.h b/quic/tools/quic_transport_simple_server_session.h new file mode 100644 index 0000000..11f82f2 --- /dev/null +++ b/quic/tools/quic_transport_simple_server_session.h
@@ -0,0 +1,72 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_TRANSPORT_SIMPLE_SERVER_SESSION_H_ +#define QUICHE_QUIC_TOOLS_QUIC_TRANSPORT_SIMPLE_SERVER_SESSION_H_ + +#include <memory> +#include <vector> + +#include "url/origin.h" +#include "net/third_party/quiche/src/quic/core/quic_types.h" +#include "net/third_party/quiche/src/quic/core/quic_versions.h" +#include "net/third_party/quiche/src/quic/platform/api/quic_containers.h" +#include "net/third_party/quiche/src/quic/platform/api/quic_flags.h" +#include "net/third_party/quiche/src/quic/quic_transport/quic_transport_server_session.h" +#include "net/third_party/quiche/src/quic/quic_transport/quic_transport_stream.h" + +namespace quic { + +// QuicTransport simple server is a non-production server that can be used for +// testing QuicTransport. It has two modes that can be changed using the +// command line flags, "echo" and "discard". +class QuicTransportSimpleServerSession + : public QuicTransportServerSession, + QuicTransportServerSession::ServerVisitor { + public: + enum Mode { + // In DISCARD mode, any data on incoming streams is discarded and no + // outgoing streams are initiated. + DISCARD, + // In ECHO mode, any data sent on a bidirectional stream is echoed back. + // Any data sent on a unidirectional stream is buffered, and echoed back on + // a server-initiated unidirectional stream that is sent as soon as a FIN is + // received on the incoming stream. + ECHO, + }; + + QuicTransportSimpleServerSession( + QuicConnection* connection, + bool owns_connection, + Visitor* owner, + const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + Mode mode, + std::vector<url::Origin> accepted_origins); + ~QuicTransportSimpleServerSession(); + + void OnIncomingDataStream(QuicTransportStream* stream) override; + void OnCanCreateNewOutgoingStream(bool unidirectional) override; + bool CheckOrigin(url::Origin origin) override; + + void EchoStreamBack(const std::string& data) { + streams_to_echo_back_.push_back(data); + MaybeEchoStreamsBack(); + } + + private: + void MaybeEchoStreamsBack(); + + QuicConnection* connection_; + const bool owns_connection_; + Mode mode_; + std::vector<url::Origin> accepted_origins_; + QuicDeque<std::string> streams_to_echo_back_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_TRANSPORT_SIMPLE_SERVER_SESSION_H_