Support session-specific ALPNs for clients, and specifying multiple ALPNs. gfe-relnote: n/a (no functional change) PiperOrigin-RevId: 266192245 Change-Id: Id0921d36b2c32c7df92a4a528ec19d9a28b826b0
diff --git a/quic/core/quic_session.h b/quic/core/quic_session.h index a402057..d5ddc69 100644 --- a/quic/core/quic_session.h +++ b/quic/core/quic_session.h
@@ -459,6 +459,13 @@ max_stream + num_expected_unidirectional_static_streams_); } + // Returns the ALPN values to negotiate on this session. + virtual std::vector<std::string> GetAlpnsToOffer() { + // TODO(vasilvv): this currently sets HTTP/3 by default. Switch all + // non-HTTP applications to appropriate ALPNs. + return std::vector<std::string>({AlpnForVersion(connection()->version())}); + } + protected: using StreamMap = QuicSmallMap<QuicStreamId, std::unique_ptr<QuicStream>, 10>;
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc index 8e2d2bb..a410f81 100644 --- a/quic/core/tls_client_handshaker.cc +++ b/quic/core/tls_client_handshaker.cc
@@ -15,8 +15,6 @@ namespace quic { -std::string* quic_alpn_override_on_client_for_tests = nullptr; - TlsClientHandshaker::ProofVerifierCallbackImpl::ProofVerifierCallbackImpl( TlsClientHandshaker* parent) : parent_(parent) {} @@ -78,32 +76,10 @@ return false; } - std::string alpn_string = AlpnForVersion(session()->connection()->version()); - if (quic_alpn_override_on_client_for_tests != nullptr) { - alpn_string = *quic_alpn_override_on_client_for_tests; - } - if (alpn_string.length() > std::numeric_limits<uint8_t>::max()) { - QUIC_BUG << "ALPN too long: '" << alpn_string << "'"; - CloseConnection(QUIC_HANDSHAKE_FAILED, - "Client configured ALPN is too long"); + if (!SetAlpn()) { + CloseConnection(QUIC_HANDSHAKE_FAILED, "Client failed to set ALPN"); return false; } - const uint8_t alpn_length = alpn_string.length(); - if (alpn_length > 0) { - // SSL_set_alpn_protos expects a sequence of one-byte-length-prefixed - // strings so we copy alpn_string to a new buffer that has the length - // in alpn[0]. - uint8_t alpn[std::numeric_limits<uint8_t>::max() + 1]; - alpn[0] = alpn_length; - memcpy(reinterpret_cast<char*>(alpn + 1), alpn_string.data(), alpn_length); - if (SSL_set_alpn_protos(ssl(), alpn, - static_cast<unsigned>(alpn_length) + 1) != 0) { - QUIC_BUG << "Failed to set ALPN: '" << alpn_string << "'"; - CloseConnection(QUIC_HANDSHAKE_FAILED, "Client failed to set ALPN"); - return false; - } - } - QUIC_DLOG(INFO) << "Client using ALPN: '" << alpn_string << "'"; // Set the Transport Parameters to send in the ClientHello if (!SetTransportParameters()) { @@ -117,6 +93,46 @@ return session()->connection()->connected(); } +static bool IsValidAlpn(const std::string& alpn_string) { + return alpn_string.length() <= std::numeric_limits<uint8_t>::max(); +} + +bool TlsClientHandshaker::SetAlpn() { + std::vector<std::string> alpns = session()->GetAlpnsToOffer(); + if (alpns.empty()) { + if (allow_empty_alpn_for_tests_) { + return true; + } + + QUIC_BUG << "ALPN missing"; + return false; + } + if (!std::all_of(alpns.begin(), alpns.end(), IsValidAlpn)) { + QUIC_BUG << "ALPN too long"; + return false; + } + + // SSL_set_alpn_protos expects a sequence of one-byte-length-prefixed + // strings. + uint8_t alpn[1024]; + QuicDataWriter alpn_writer(sizeof(alpn), reinterpret_cast<char*>(alpn)); + bool success = true; + for (const std::string& alpn_string : alpns) { + success = success && alpn_writer.WriteUInt8(alpn_string.size()) && + alpn_writer.WriteStringPiece(alpn_string); + } + success = + success && (SSL_set_alpn_protos(ssl(), alpn, alpn_writer.length()) == 0); + if (!success) { + QUIC_BUG << "Failed to set ALPN: " + << QuicTextUtils::HexDump( + QuicStringPiece(alpn_writer.data(), alpn_writer.length())); + return false; + } + QUIC_DLOG(INFO) << "Client using ALPN: '" << alpns[0] << "'"; + return true; +} + bool TlsClientHandshaker::SetTransportParameters() { TransportParameters params; params.perspective = Perspective::IS_CLIENT;
diff --git a/quic/core/tls_client_handshaker.h b/quic/core/tls_client_handshaker.h index 47faf81..f3e90ce 100644 --- a/quic/core/tls_client_handshaker.h +++ b/quic/core/tls_client_handshaker.h
@@ -54,6 +54,8 @@ CryptoMessageParser* crypto_message_parser() override; size_t BufferSizeLimitForLevel(EncryptionLevel level) const override; + void AllowEmptyAlpnForTests() { allow_empty_alpn_for_tests_ = true; } + protected: const TlsConnection* tls_connection() const override { return &tls_connection_; @@ -95,6 +97,7 @@ STATE_CONNECTION_CLOSED, } state_ = STATE_IDLE; + bool SetAlpn(); bool SetTransportParameters(); bool ProcessTransportParameters(std::string* error_details); void FinishHandshake(); @@ -121,13 +124,11 @@ QuicReferenceCountedPointer<QuicCryptoNegotiatedParameters> crypto_negotiated_params_; + bool allow_empty_alpn_for_tests_ = false; + TlsClientConnection tls_connection_; }; -// Allows tests to override the ALPN used by clients. -// DO NOT use outside of tests. -QUIC_EXPORT_PRIVATE extern std::string* quic_alpn_override_on_client_for_tests; - } // namespace quic #endif // QUICHE_QUIC_CORE_TLS_CLIENT_HANDSHAKER_H_
diff --git a/quic/core/tls_handshaker_test.cc b/quic/core/tls_handshaker_test.cc index 00c62d6..ef7f31b 100644 --- a/quic/core/tls_handshaker_test.cc +++ b/quic/core/tls_handshaker_test.cc
@@ -10,6 +10,7 @@ #include "net/third_party/quiche/src/quic/core/tls_client_handshaker.h" #include "net/third_party/quiche/src/quic/core/tls_server_handshaker.h" #include "net/third_party/quiche/src/quic/platform/api/quic_arraysize.h" +#include "net/third_party/quiche/src/quic/platform/api/quic_expect_bug.h" #include "net/third_party/quiche/src/quic/platform/api/quic_ptr_util.h" #include "net/third_party/quiche/src/quic/platform/api/quic_test.h" #include "net/third_party/quiche/src/quic/test_tools/crypto_test_utils.h" @@ -22,6 +23,7 @@ namespace { using ::testing::_; +using ::testing::Return; class FakeProofVerifier : public ProofVerifier { public: @@ -225,6 +227,7 @@ ~TestQuicCryptoClientStream() override = default; TlsHandshaker* handshaker() const override { return handshaker_.get(); } + TlsClientHandshaker* client_handshaker() const { return handshaker_.get(); } bool CryptoConnect() { return handshaker_->CryptoConnect(); } @@ -302,6 +305,9 @@ EXPECT_FALSE(client_stream_->handshake_confirmed()); EXPECT_FALSE(server_stream_->encryption_established()); EXPECT_FALSE(server_stream_->handshake_confirmed()); + ON_CALL(client_session_, GetAlpnsToOffer()) + .WillByDefault(Return(std::vector<std::string>( + {AlpnForVersion(client_session_.connection()->version())}))); } MockQuicConnectionHelper conn_helper_; @@ -442,8 +448,9 @@ } TEST_F(TlsHandshakerTest, ClientNotSendingALPN) { - static std::string kTestClientNoAlpn = ""; - quic_alpn_override_on_client_for_tests = &kTestClientNoAlpn; + client_stream_->client_handshaker()->AllowEmptyAlpnForTests(); + EXPECT_CALL(client_session_, GetAlpnsToOffer()) + .WillOnce(Return(std::vector<std::string>())); EXPECT_CALL(*client_conn_, CloseConnection(QUIC_HANDSHAKE_FAILED, "Server did not select ALPN", _)); EXPECT_CALL(*server_conn_, @@ -456,12 +463,12 @@ EXPECT_FALSE(client_stream_->encryption_established()); EXPECT_FALSE(server_stream_->handshake_confirmed()); EXPECT_FALSE(server_stream_->encryption_established()); - quic_alpn_override_on_client_for_tests = nullptr; } TEST_F(TlsHandshakerTest, ClientSendingBadALPN) { static std::string kTestBadClientAlpn = "bad-client-alpn"; - quic_alpn_override_on_client_for_tests = &kTestBadClientAlpn; + EXPECT_CALL(client_session_, GetAlpnsToOffer()) + .WillOnce(Return(std::vector<std::string>({kTestBadClientAlpn}))); EXPECT_CALL(*client_conn_, CloseConnection(QUIC_HANDSHAKE_FAILED, "Server did not select ALPN", _)); EXPECT_CALL(*server_conn_, @@ -474,7 +481,22 @@ EXPECT_FALSE(client_stream_->encryption_established()); EXPECT_FALSE(server_stream_->handshake_confirmed()); EXPECT_FALSE(server_stream_->encryption_established()); - quic_alpn_override_on_client_for_tests = nullptr; +} + +TEST_F(TlsHandshakerTest, ClientSendingTooManyALPNs) { + std::string long_alpn(250, 'A'); + EXPECT_CALL(client_session_, GetAlpnsToOffer()) + .WillOnce(Return(std::vector<std::string>({ + long_alpn + "1", + long_alpn + "2", + long_alpn + "3", + long_alpn + "4", + long_alpn + "5", + long_alpn + "6", + long_alpn + "7", + long_alpn + "8", + }))); + EXPECT_QUIC_BUG(client_stream_->CryptoConnect(), "Failed to set ALPN"); } } // namespace
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h index 40b8036..bbde859 100644 --- a/quic/test_tools/quic_test_utils.h +++ b/quic/test_tools/quic_test_utils.h
@@ -633,6 +633,7 @@ MOCK_CONST_METHOD0(ShouldKeepConnectionAlive, bool()); MOCK_METHOD2(SendStopSending, void(uint16_t code, QuicStreamId stream_id)); MOCK_METHOD1(OnCryptoHandshakeEvent, void(QuicSession::CryptoHandshakeEvent)); + MOCK_METHOD0(GetAlpnsToOffer, std::vector<std::string>()); using QuicSession::ActivateStream;