Implement the QuicTransport server session subclass. This currently does not handle incoming streams, as those require special logic to prevent access to application data before the indication is received. gfe-relnote: n/a (not used in production) PiperOrigin-RevId: 274228824 Change-Id: Ie1cd37ecfb739d1242a3cdc40186bca00f8373fd
diff --git a/quic/core/quic_crypto_client_stream_test.cc b/quic/core/quic_crypto_client_stream_test.cc index 1001301..8e1ef25 100644 --- a/quic/core/quic_crypto_client_stream_test.cc +++ b/quic/core/quic_crypto_client_stream_test.cc
@@ -51,6 +51,9 @@ session_ = std::make_unique<TestQuicSpdyClientSession>( connection_, DefaultQuicConfig(), supported_versions_, server_id_, &crypto_config_); + EXPECT_CALL(*session_, GetAlpnsToOffer()) + .WillRepeatedly(testing::Return(std::vector<std::string>( + {AlpnForVersion(connection_->version())}))); } void CompleteCryptoHandshake() {
diff --git a/quic/core/quic_crypto_server_stream_test.cc b/quic/core/quic_crypto_server_stream_test.cc index 8d71f26..360c268 100644 --- a/quic/core/quic_crypto_server_stream_test.cc +++ b/quic/core/quic_crypto_server_stream_test.cc
@@ -132,7 +132,8 @@ return crypto_test_utils::HandshakeWithFakeClient( helpers_.back().get(), alarm_factories_.back().get(), - server_connection_, server_stream(), server_id_, client_options_); + server_connection_, server_stream(), server_id_, client_options_, + /*alpn=*/""); } // Performs a single round of handshake message-exchange between the
diff --git a/quic/core/quic_error_codes.cc b/quic/core/quic_error_codes.cc index e69cd60..b1e8ed2 100644 --- a/quic/core/quic_error_codes.cc +++ b/quic/core/quic_error_codes.cc
@@ -159,6 +159,7 @@ RETURN_STRING_LITERAL( QUIC_WINDOW_UPDATE_RECEIVED_ON_READ_UNIDIRECTIONAL_STREAM); RETURN_STRING_LITERAL(QUIC_TOO_MANY_BUFFERED_CONTROL_FRAMES); + RETURN_STRING_LITERAL(QUIC_TRANSPORT_INVALID_CLIENT_INDICATION); RETURN_STRING_LITERAL(QUIC_LAST_ERROR); // Intentionally have no default case, so we'll break the build
diff --git a/quic/core/quic_error_codes.h b/quic/core/quic_error_codes.h index 298029f..ce5c721 100644 --- a/quic/core/quic_error_codes.h +++ b/quic/core/quic_error_codes.h
@@ -339,8 +339,11 @@ // There are too many buffered control frames in control frame manager. QUIC_TOO_MANY_BUFFERED_CONTROL_FRAMES = 124, + // QuicTransport received invalid client indication. + QUIC_TRANSPORT_INVALID_CLIENT_INDICATION = 125, + // No error. Used as bound while iterating. - QUIC_LAST_ERROR = 125, + QUIC_LAST_ERROR = 126, }; // QuicErrorCodes is encoded as four octets on-the-wire when doing Google QUIC, // or a varint62 when doing IETF QUIC. Ensure that its value does not exceed
diff --git a/quic/core/quic_types.cc b/quic/core/quic_types.cc index 499bc66..db48c5c 100644 --- a/quic/core/quic_types.cc +++ b/quic/core/quic_types.cc
@@ -411,6 +411,8 @@ case QUIC_TOO_MANY_BUFFERED_CONTROL_FRAMES: return {true, {static_cast<uint64_t>(QUIC_TOO_MANY_BUFFERED_CONTROL_FRAMES)}}; + case QUIC_TRANSPORT_INVALID_CLIENT_INDICATION: + return {false, {0u}}; case QUIC_LAST_ERROR: return {false, {static_cast<uint64_t>(QUIC_LAST_ERROR)}}; }
diff --git a/quic/quic_transport/quic_transport_protocol.h b/quic/quic_transport/quic_transport_protocol.h index bef5f8d..307354f 100644 --- a/quic/quic_transport/quic_transport_protocol.h +++ b/quic/quic_transport/quic_transport_protocol.h
@@ -21,6 +21,11 @@ return 2; } +// The maximum allowed size of the client indication. +QUIC_EXPORT constexpr QuicByteCount ClientIndicationMaxSize() { + return 65536; +} + // The keys of the fields in the client indication. enum class QuicTransportClientIndicationKeys : uint16_t { kOrigin = 0x0000,
diff --git a/quic/quic_transport/quic_transport_server_session.cc b/quic/quic_transport/quic_transport_server_session.cc new file mode 100644 index 0000000..de4db0f --- /dev/null +++ b/quic/quic_transport/quic_transport_server_session.cc
@@ -0,0 +1,164 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/third_party/quiche/src/quic/quic_transport/quic_transport_server_session.h" + +#include <memory> + +#include "url/gurl.h" +#include "net/third_party/quiche/src/quic/core/quic_error_codes.h" +#include "net/third_party/quiche/src/quic/core/quic_stream.h" +#include "net/third_party/quiche/src/quic/core/quic_types.h" +#include "net/third_party/quiche/src/quic/platform/api/quic_str_cat.h" +#include "net/third_party/quiche/src/quic/platform/api/quic_string_piece.h" +#include "net/third_party/quiche/src/quic/quic_transport/quic_transport_client_session.h" +#include "net/third_party/quiche/src/quic/quic_transport/quic_transport_protocol.h" + +namespace quic { + +namespace { +class QuicTransportServerCryptoHelper : public QuicCryptoServerStream::Helper { + public: + bool CanAcceptClientHello(const CryptoHandshakeMessage& /*message*/, + const QuicSocketAddress& /*client_address*/, + const QuicSocketAddress& /*peer_address*/, + const QuicSocketAddress& /*self_address*/, + std::string* /*error_details*/) const override { + return true; + } +}; +} // namespace + +QuicTransportServerSession::QuicTransportServerSession( + QuicConnection* connection, + Visitor* owner, + const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + ServerVisitor* visitor) + : QuicSession(connection, + owner, + config, + supported_versions, + /*num_expected_unidirectional_static_streams*/ 0), + visitor_(visitor) { + for (const ParsedQuicVersion& version : supported_versions) { + QUIC_BUG_IF(version.handshake_protocol != PROTOCOL_TLS1_3) + << "QuicTransport requires TLS 1.3 handshake"; + } + + static QuicTransportServerCryptoHelper* helper = + new QuicTransportServerCryptoHelper(); + crypto_stream_ = std::make_unique<QuicCryptoServerStream>( + crypto_config, compressed_certs_cache, this, helper); +} + +QuicStream* QuicTransportServerSession::CreateIncomingStream(QuicStreamId id) { + if (id == ClientIndicationStream()) { + auto indication = std::make_unique<ClientIndication>(this); + ClientIndication* indication_ptr = indication.get(); + ActivateStream(std::move(indication)); + return indication_ptr; + } + + // TODO(vasilvv): implement incoming data streams. + QUIC_BUG << "Not implemented"; + return nullptr; +} + +QuicTransportServerSession::ClientIndication::ClientIndication( + QuicTransportServerSession* session) + : QuicStream(ClientIndicationStream(), + session, + /* is_static= */ false, + StreamType::READ_UNIDIRECTIONAL), + session_(session) {} + +void QuicTransportServerSession::ClientIndication::OnDataAvailable() { + sequencer()->Read(&buffer_); + if (buffer_.size() > ClientIndicationMaxSize()) { + session_->connection()->CloseConnection( + QUIC_TRANSPORT_INVALID_CLIENT_INDICATION, + QuicStrCat("Client indication size exceeds ", ClientIndicationMaxSize(), + " bytes"), + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + if (sequencer()->IsClosed()) { + session_->ProcessClientIndication(buffer_); + OnFinRead(); + } +} + +bool QuicTransportServerSession::ClientIndicationParser::Parse() { + bool origin_received = false; + while (!reader_.IsDoneReading()) { + uint16_t key; + if (!reader_.ReadUInt16(&key)) { + ParseError("Expected 16-bit key"); + return false; + } + + QuicStringPiece value; + if (!reader_.ReadStringPiece16(&value)) { + ParseError(QuicStrCat("Failed to read value for key ", key)); + return false; + } + + switch (static_cast<QuicTransportClientIndicationKeys>(key)) { + case QuicTransportClientIndicationKeys::kOrigin: { + GURL origin_url{std::string(value)}; + if (!origin_url.is_valid()) { + Error("Unable to parse the specified origin"); + return false; + } + + url::Origin origin = url::Origin::Create(origin_url); + QUIC_DLOG(INFO) << "QuicTransport server received origin " << origin; + if (!session_->visitor_->CheckOrigin(origin)) { + Error("Origin check failed"); + return false; + } + origin_received = true; + break; + } + + default: + QUIC_DLOG(INFO) << "Unknown client indication key: " << key; + break; + } + } + + if (!origin_received) { + Error("No origin received"); + return false; + } + + return true; +} + +void QuicTransportServerSession::ClientIndicationParser::Error( + const std::string& error_message) { + session_->connection()->CloseConnection( + QUIC_TRANSPORT_INVALID_CLIENT_INDICATION, error_message, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); +} + +void QuicTransportServerSession::ClientIndicationParser::ParseError( + QuicStringPiece error_message) { + Error(QuicStrCat("Failed to parse the client indication stream: ", + error_message, reader_.DebugString())); +} + +void QuicTransportServerSession::ProcessClientIndication( + QuicStringPiece indication) { + ClientIndicationParser parser(this, indication); + if (!parser.Parse()) { + return; + } + client_indication_processed_ = true; +} + +} // namespace quic
diff --git a/quic/quic_transport/quic_transport_server_session.h b/quic/quic_transport/quic_transport_server_session.h new file mode 100644 index 0000000..ee5cf69 --- /dev/null +++ b/quic/quic_transport/quic_transport_server_session.h
@@ -0,0 +1,103 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QUIC_TRANSPORT_QUIC_TRANSPORT_SERVER_SESSION_H_ +#define QUICHE_QUIC_QUIC_TRANSPORT_QUIC_TRANSPORT_SERVER_SESSION_H_ + +#include "url/origin.h" +#include "net/third_party/quiche/src/quic/core/quic_connection.h" +#include "net/third_party/quiche/src/quic/core/quic_crypto_server_stream.h" +#include "net/third_party/quiche/src/quic/core/quic_session.h" +#include "net/third_party/quiche/src/quic/platform/api/quic_string_piece.h" +#include "net/third_party/quiche/src/quic/quic_transport/quic_transport_protocol.h" + +namespace quic { + +// A server session for the QuicTransport protocol. +class QUIC_EXPORT QuicTransportServerSession : public QuicSession { + public: + class ServerVisitor { + public: + virtual ~ServerVisitor() {} + + virtual bool CheckOrigin(url::Origin origin) = 0; + }; + + QuicTransportServerSession(QuicConnection* connection, + Visitor* owner, + const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + ServerVisitor* visitor); + + std::vector<QuicStringPiece>::const_iterator SelectAlpn( + const std::vector<QuicStringPiece>& alpns) const override { + return std::find(alpns.cbegin(), alpns.cend(), QuicTransportAlpn()); + } + + bool ShouldKeepConnectionAlive() const override { return true; } + + QuicCryptoStream* GetMutableCryptoStream() override { + return crypto_stream_.get(); + } + const QuicCryptoStream* GetCryptoStream() const override { + return crypto_stream_.get(); + } + + bool IsSessionReady() const { + return IsCryptoHandshakeConfirmed() && client_indication_processed_ && + connection()->connected(); + } + + QuicStream* CreateIncomingStream(QuicStreamId id) override; + QuicStream* CreateIncomingStream(PendingStream* /*pending*/) override { + QUIC_BUG << "QuicTransportServerSession::CreateIncomingStream(" + "PendingStream) not implemented"; + return nullptr; + } + + protected: + class ClientIndication : public QuicStream { + public: + explicit ClientIndication(QuicTransportServerSession* session); + void OnDataAvailable() override; + + private: + QuicTransportServerSession* session_; + std::string buffer_; + }; + + // Utility class for parsing the client indication. + class ClientIndicationParser { + public: + ClientIndicationParser(QuicTransportServerSession* session, + QuicStringPiece indication) + : session_(session), reader_(indication) {} + + // Parses the specified indication. Automatically closes the connection + // with detailed error if parsing fails. Returns true on success, false on + // failure. + bool Parse(); + + private: + void Error(const std::string& error_message); + void ParseError(QuicStringPiece error_message); + + QuicTransportServerSession* session_; + QuicDataReader reader_; + }; + + // Parses and processes the client indication as described in + // https://vasilvv.github.io/webtransport/draft-vvv-webtransport-quic.html#rfc.section.3.2 + void ProcessClientIndication(QuicStringPiece indication); + + std::unique_ptr<QuicCryptoServerStream> crypto_stream_; + bool client_indication_processed_ = false; + ServerVisitor* visitor_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QUIC_TRANSPORT_QUIC_TRANSPORT_SERVER_SESSION_H_
diff --git a/quic/quic_transport/quic_transport_server_session_test.cc b/quic/quic_transport/quic_transport_server_session_test.cc new file mode 100644 index 0000000..818c08d --- /dev/null +++ b/quic/quic_transport/quic_transport_server_session_test.cc
@@ -0,0 +1,238 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/third_party/quiche/src/quic/quic_transport/quic_transport_server_session.h" + +#include <cstddef> +#include <memory> +#include <string> + +#include "url/gurl.h" +#include "url/origin.h" +#include "net/third_party/quiche/src/quic/core/crypto/quic_compressed_certs_cache.h" +#include "net/third_party/quiche/src/quic/core/crypto/quic_crypto_server_config.h" +#include "net/third_party/quiche/src/quic/core/frames/quic_stream_frame.h" +#include "net/third_party/quiche/src/quic/core/quic_data_writer.h" +#include "net/third_party/quiche/src/quic/core/quic_versions.h" +#include "net/third_party/quiche/src/quic/platform/api/quic_string_piece.h" +#include "net/third_party/quiche/src/quic/platform/api/quic_test.h" +#include "net/third_party/quiche/src/quic/platform/api/quic_text_utils.h" +#include "net/third_party/quiche/src/quic/quic_transport/quic_transport_protocol.h" +#include "net/third_party/quiche/src/quic/test_tools/crypto_test_utils.h" +#include "net/third_party/quiche/src/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { +namespace { + +using testing::_; +using testing::AnyNumber; +using testing::DoAll; +using testing::HasSubstr; +using testing::Return; +using testing::SaveArg; + +constexpr char kTestOrigin[] = "https://test-origin.test"; +constexpr char kTestOriginClientIndication[] = + "\0\0\0\x18https://test-origin.test"; +const url::Origin GetTestOrigin() { + return url::Origin::Create(GURL(kTestOrigin)); +} +const std::string GetTestOriginClientIndication() { + return std::string(kTestOriginClientIndication, + sizeof(kTestOriginClientIndication) - 1); +} + +ParsedQuicVersionVector GetVersions() { + return {ParsedQuicVersion{PROTOCOL_TLS1_3, QUIC_VERSION_99}}; +} + +class MockVisitor : public QuicTransportServerSession::ServerVisitor { + public: + MOCK_METHOD1(CheckOrigin, bool(url::Origin)); +}; + +class QuicTransportServerSessionTest : public QuicTest { + public: + QuicTransportServerSessionTest() + : connection_(&helper_, + &alarm_factory_, + Perspective::IS_SERVER, + GetVersions()), + crypto_config_(QuicCryptoServerConfig::TESTING, + QuicRandom::GetInstance(), + crypto_test_utils::ProofSourceForTesting(), + KeyExchangeSource::Default()), + compressed_certs_cache_( + QuicCompressedCertsCache::kQuicCompressedCertsCacheSize) { + SetQuicReloadableFlag(quic_supports_tls_handshake, true); + connection_.AdvanceTime(QuicTime::Delta::FromSeconds(100000)); + crypto_test_utils::SetupCryptoServerConfigForTest( + helper_.GetClock(), helper_.GetRandomGenerator(), &crypto_config_); + session_ = std::make_unique<QuicTransportServerSession>( + &connection_, nullptr, DefaultQuicConfig(), GetVersions(), + &crypto_config_, &compressed_certs_cache_, &visitor_); + session_->Initialize(); + crypto_stream_ = static_cast<QuicCryptoServerStream*>( + session_->GetMutableCryptoStream()); + crypto_stream_->OnSuccessfulVersionNegotiation(GetVersions()[0]); + } + + void Connect() { + crypto_test_utils::FakeClientOptions options; + options.only_tls_versions = true; + crypto_test_utils::HandshakeWithFakeClient( + &helper_, &alarm_factory_, &connection_, crypto_stream_, + QuicServerId("test.example.com", 443), options, QuicTransportAlpn()); + } + + void ReceiveIndication(QuicStringPiece indication) { + QUIC_LOG(INFO) << "Receiving indication: " + << QuicTextUtils::HexDump(indication); + constexpr size_t kChunkSize = 1024; + // Shard the indication, since some of the tests cause it to not fit into a + // single frame. + for (size_t i = 0; i < indication.size(); i += kChunkSize) { + QuicStreamFrame frame(ClientIndicationStream(), /*fin=*/false, i, + indication.substr(i, i + kChunkSize)); + session_->OnStreamFrame(frame); + } + session_->OnStreamFrame(QuicStreamFrame(ClientIndicationStream(), + /*fin=*/true, indication.size(), + QuicStringPiece())); + } + + protected: + MockAlarmFactory alarm_factory_; + MockQuicConnectionHelper helper_; + + PacketSavingConnection connection_; + QuicCryptoServerConfig crypto_config_; + std::unique_ptr<QuicTransportServerSession> session_; + QuicCompressedCertsCache compressed_certs_cache_; + testing::StrictMock<MockVisitor> visitor_; + QuicCryptoServerStream* crypto_stream_; +}; + +TEST_F(QuicTransportServerSessionTest, SuccessfulHandshake) { + Connect(); + + url::Origin origin; + EXPECT_CALL(visitor_, CheckOrigin(_)) + .WillOnce(DoAll(SaveArg<0>(&origin), Return(true))); + ReceiveIndication(GetTestOriginClientIndication()); + EXPECT_TRUE(session_->IsSessionReady()); + EXPECT_EQ(origin, GetTestOrigin()); +} + +TEST_F(QuicTransportServerSessionTest, PiecewiseClientIndication) { + Connect(); + size_t i = 0; + for (; i < sizeof(kTestOriginClientIndication) - 2; i++) { + QuicStreamFrame frame(ClientIndicationStream(), false, i, + QuicStringPiece(&kTestOriginClientIndication[i], 1)); + session_->OnStreamFrame(frame); + } + + EXPECT_CALL(visitor_, CheckOrigin(_)).WillOnce(Return(true)); + QuicStreamFrame last_frame( + ClientIndicationStream(), true, i, + QuicStringPiece(&kTestOriginClientIndication[i], 1)); + session_->OnStreamFrame(last_frame); + EXPECT_TRUE(session_->IsSessionReady()); +} + +TEST_F(QuicTransportServerSessionTest, OriginRejected) { + Connect(); + EXPECT_CALL(connection_, + CloseConnection(_, HasSubstr("Origin check failed"), _)); + EXPECT_CALL(visitor_, CheckOrigin(_)).WillOnce(Return(false)); + ReceiveIndication(GetTestOriginClientIndication()); + EXPECT_FALSE(session_->IsSessionReady()); +} + +std::string MakeUnknownField(QuicStringPiece payload) { + std::string buffer; + buffer.resize(payload.size() + 4); + QuicDataWriter writer(buffer.size(), &buffer[0]); + EXPECT_TRUE(writer.WriteUInt16(0xffff)); + EXPECT_TRUE(writer.WriteUInt16(payload.size())); + EXPECT_TRUE(writer.WriteStringPiece(payload)); + EXPECT_EQ(writer.remaining(), 0u); + return buffer; +} + +TEST_F(QuicTransportServerSessionTest, SkipUnusedFields) { + Connect(); + EXPECT_CALL(visitor_, CheckOrigin(_)).WillOnce(Return(true)); + ReceiveIndication(GetTestOriginClientIndication() + + MakeUnknownField("foobar")); + EXPECT_TRUE(session_->IsSessionReady()); +} + +TEST_F(QuicTransportServerSessionTest, SkipLongUnusedFields) { + const size_t bytes = + ClientIndicationMaxSize() - GetTestOriginClientIndication().size() - 4; + Connect(); + EXPECT_CALL(visitor_, CheckOrigin(_)).WillOnce(Return(true)); + ReceiveIndication(GetTestOriginClientIndication() + + MakeUnknownField(std::string(bytes, 'a'))); + EXPECT_TRUE(session_->IsSessionReady()); +} + +TEST_F(QuicTransportServerSessionTest, ClientIndicationTooLong) { + Connect(); + EXPECT_CALL( + connection_, + CloseConnection(_, HasSubstr("Client indication size exceeds"), _)) + .Times(AnyNumber()); + ReceiveIndication(GetTestOriginClientIndication() + + MakeUnknownField(std::string(65534, 'a'))); + EXPECT_FALSE(session_->IsSessionReady()); +} + +TEST_F(QuicTransportServerSessionTest, NoOrigin) { + Connect(); + EXPECT_CALL(connection_, CloseConnection(_, HasSubstr("No origin"), _)); + ReceiveIndication(MakeUnknownField("foobar")); + EXPECT_FALSE(session_->IsSessionReady()); +} + +TEST_F(QuicTransportServerSessionTest, EmptyClientIndication) { + Connect(); + EXPECT_CALL(connection_, CloseConnection(_, HasSubstr("No origin"), _)); + ReceiveIndication(""); + EXPECT_FALSE(session_->IsSessionReady()); +} + +TEST_F(QuicTransportServerSessionTest, MalformedIndicationHeader) { + Connect(); + EXPECT_CALL(connection_, + CloseConnection(_, HasSubstr("Expected 16-bit key"), _)); + ReceiveIndication("\xff"); + EXPECT_FALSE(session_->IsSessionReady()); +} + +TEST_F(QuicTransportServerSessionTest, FieldTooShort) { + Connect(); + EXPECT_CALL( + connection_, + CloseConnection(_, HasSubstr("Failed to read value for key 257"), _)); + ReceiveIndication("\x01\x01\x01\x01"); + EXPECT_FALSE(session_->IsSessionReady()); +} + +TEST_F(QuicTransportServerSessionTest, InvalidOrigin) { + const std::string kEmptyOriginIndication(4, '\0'); + Connect(); + EXPECT_CALL( + connection_, + CloseConnection(_, HasSubstr("Unable to parse the specified origin"), _)); + ReceiveIndication(kEmptyOriginIndication); + EXPECT_FALSE(session_->IsSessionReady()); +} + +} // namespace +} // namespace test +} // namespace quic
diff --git a/quic/test_tools/crypto_test_utils.cc b/quic/test_tools/crypto_test_utils.cc index 45872d2..98d65e8 100644 --- a/quic/test_tools/crypto_test_utils.cc +++ b/quic/test_tools/crypto_test_utils.cc
@@ -26,6 +26,7 @@ #include "net/third_party/quiche/src/quic/core/quic_crypto_stream.h" #include "net/third_party/quiche/src/quic/core/quic_server_id.h" #include "net/third_party/quiche/src/quic/core/quic_utils.h" +#include "net/third_party/quiche/src/quic/core/quic_versions.h" #include "net/third_party/quiche/src/quic/platform/api/quic_bug_tracker.h" #include "net/third_party/quiche/src/quic/platform/api/quic_clock.h" #include "net/third_party/quiche/src/quic/platform/api/quic_logging.h" @@ -257,7 +258,8 @@ PacketSavingConnection* server_conn, QuicCryptoServerStream* server, const QuicServerId& server_id, - const FakeClientOptions& options) { + const FakeClientOptions& options, + std::string alpn) { ParsedQuicVersionVector supported_versions = AllSupportedVersions(); if (options.only_tls_versions) { supported_versions.clear(); @@ -282,6 +284,14 @@ EXPECT_CALL(client_session, OnProofVerifyDetailsAvailable(testing::_)) .Times(testing::AnyNumber()); EXPECT_CALL(*client_conn, OnCanWrite()).Times(testing::AnyNumber()); + if (!alpn.empty()) { + EXPECT_CALL(client_session, GetAlpnsToOffer()) + .WillRepeatedly(testing::Return(std::vector<std::string>({alpn}))); + } else { + EXPECT_CALL(client_session, GetAlpnsToOffer()) + .WillRepeatedly(testing::Return(std::vector<std::string>( + {AlpnForVersion(client_conn->version())}))); + } client_session.GetMutableCryptoStream()->CryptoConnect(); CHECK_EQ(1u, client_conn->encrypted_packets_.size());
diff --git a/quic/test_tools/crypto_test_utils.h b/quic/test_tools/crypto_test_utils.h index 4cee641..6f87e90 100644 --- a/quic/test_tools/crypto_test_utils.h +++ b/quic/test_tools/crypto_test_utils.h
@@ -78,7 +78,8 @@ PacketSavingConnection* server_conn, QuicCryptoServerStream* server, const QuicServerId& server_id, - const FakeClientOptions& options); + const FakeClientOptions& options, + std::string alpn); // SetupCryptoServerConfigForTest configures |crypto_config| // with sensible defaults for testing.
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h index 2fc1b4a..fbadf13 100644 --- a/quic/test_tools/quic_test_utils.h +++ b/quic/test_tools/quic_test_utils.h
@@ -864,6 +864,7 @@ MOCK_METHOD1(ShouldCreateIncomingStream, bool(QuicStreamId id)); MOCK_METHOD0(ShouldCreateOutgoingBidirectionalStream, bool()); MOCK_METHOD0(ShouldCreateOutgoingUnidirectionalStream, bool()); + MOCK_CONST_METHOD0(GetAlpnsToOffer, std::vector<std::string>()); // Override to not send max header list size. void OnCryptoHandshakeEvent(CryptoHandshakeEvent event) override;