Support session-specific ALPNs for clients, and specifying multiple ALPNs.
gfe-relnote: n/a (no functional change)
PiperOrigin-RevId: 266192245
Change-Id: Id0921d36b2c32c7df92a4a528ec19d9a28b826b0
diff --git a/quic/core/quic_session.h b/quic/core/quic_session.h
index a402057..d5ddc69 100644
--- a/quic/core/quic_session.h
+++ b/quic/core/quic_session.h
@@ -459,6 +459,13 @@
max_stream + num_expected_unidirectional_static_streams_);
}
+ // Returns the ALPN values to negotiate on this session.
+ virtual std::vector<std::string> GetAlpnsToOffer() {
+ // TODO(vasilvv): this currently sets HTTP/3 by default. Switch all
+ // non-HTTP applications to appropriate ALPNs.
+ return std::vector<std::string>({AlpnForVersion(connection()->version())});
+ }
+
protected:
using StreamMap = QuicSmallMap<QuicStreamId, std::unique_ptr<QuicStream>, 10>;
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index 8e2d2bb..a410f81 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -15,8 +15,6 @@
namespace quic {
-std::string* quic_alpn_override_on_client_for_tests = nullptr;
-
TlsClientHandshaker::ProofVerifierCallbackImpl::ProofVerifierCallbackImpl(
TlsClientHandshaker* parent)
: parent_(parent) {}
@@ -78,32 +76,10 @@
return false;
}
- std::string alpn_string = AlpnForVersion(session()->connection()->version());
- if (quic_alpn_override_on_client_for_tests != nullptr) {
- alpn_string = *quic_alpn_override_on_client_for_tests;
- }
- if (alpn_string.length() > std::numeric_limits<uint8_t>::max()) {
- QUIC_BUG << "ALPN too long: '" << alpn_string << "'";
- CloseConnection(QUIC_HANDSHAKE_FAILED,
- "Client configured ALPN is too long");
+ if (!SetAlpn()) {
+ CloseConnection(QUIC_HANDSHAKE_FAILED, "Client failed to set ALPN");
return false;
}
- const uint8_t alpn_length = alpn_string.length();
- if (alpn_length > 0) {
- // SSL_set_alpn_protos expects a sequence of one-byte-length-prefixed
- // strings so we copy alpn_string to a new buffer that has the length
- // in alpn[0].
- uint8_t alpn[std::numeric_limits<uint8_t>::max() + 1];
- alpn[0] = alpn_length;
- memcpy(reinterpret_cast<char*>(alpn + 1), alpn_string.data(), alpn_length);
- if (SSL_set_alpn_protos(ssl(), alpn,
- static_cast<unsigned>(alpn_length) + 1) != 0) {
- QUIC_BUG << "Failed to set ALPN: '" << alpn_string << "'";
- CloseConnection(QUIC_HANDSHAKE_FAILED, "Client failed to set ALPN");
- return false;
- }
- }
- QUIC_DLOG(INFO) << "Client using ALPN: '" << alpn_string << "'";
// Set the Transport Parameters to send in the ClientHello
if (!SetTransportParameters()) {
@@ -117,6 +93,46 @@
return session()->connection()->connected();
}
+static bool IsValidAlpn(const std::string& alpn_string) {
+ return alpn_string.length() <= std::numeric_limits<uint8_t>::max();
+}
+
+bool TlsClientHandshaker::SetAlpn() {
+ std::vector<std::string> alpns = session()->GetAlpnsToOffer();
+ if (alpns.empty()) {
+ if (allow_empty_alpn_for_tests_) {
+ return true;
+ }
+
+ QUIC_BUG << "ALPN missing";
+ return false;
+ }
+ if (!std::all_of(alpns.begin(), alpns.end(), IsValidAlpn)) {
+ QUIC_BUG << "ALPN too long";
+ return false;
+ }
+
+ // SSL_set_alpn_protos expects a sequence of one-byte-length-prefixed
+ // strings.
+ uint8_t alpn[1024];
+ QuicDataWriter alpn_writer(sizeof(alpn), reinterpret_cast<char*>(alpn));
+ bool success = true;
+ for (const std::string& alpn_string : alpns) {
+ success = success && alpn_writer.WriteUInt8(alpn_string.size()) &&
+ alpn_writer.WriteStringPiece(alpn_string);
+ }
+ success =
+ success && (SSL_set_alpn_protos(ssl(), alpn, alpn_writer.length()) == 0);
+ if (!success) {
+ QUIC_BUG << "Failed to set ALPN: "
+ << QuicTextUtils::HexDump(
+ QuicStringPiece(alpn_writer.data(), alpn_writer.length()));
+ return false;
+ }
+ QUIC_DLOG(INFO) << "Client using ALPN: '" << alpns[0] << "'";
+ return true;
+}
+
bool TlsClientHandshaker::SetTransportParameters() {
TransportParameters params;
params.perspective = Perspective::IS_CLIENT;
diff --git a/quic/core/tls_client_handshaker.h b/quic/core/tls_client_handshaker.h
index 47faf81..f3e90ce 100644
--- a/quic/core/tls_client_handshaker.h
+++ b/quic/core/tls_client_handshaker.h
@@ -54,6 +54,8 @@
CryptoMessageParser* crypto_message_parser() override;
size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
+ void AllowEmptyAlpnForTests() { allow_empty_alpn_for_tests_ = true; }
+
protected:
const TlsConnection* tls_connection() const override {
return &tls_connection_;
@@ -95,6 +97,7 @@
STATE_CONNECTION_CLOSED,
} state_ = STATE_IDLE;
+ bool SetAlpn();
bool SetTransportParameters();
bool ProcessTransportParameters(std::string* error_details);
void FinishHandshake();
@@ -121,13 +124,11 @@
QuicReferenceCountedPointer<QuicCryptoNegotiatedParameters>
crypto_negotiated_params_;
+ bool allow_empty_alpn_for_tests_ = false;
+
TlsClientConnection tls_connection_;
};
-// Allows tests to override the ALPN used by clients.
-// DO NOT use outside of tests.
-QUIC_EXPORT_PRIVATE extern std::string* quic_alpn_override_on_client_for_tests;
-
} // namespace quic
#endif // QUICHE_QUIC_CORE_TLS_CLIENT_HANDSHAKER_H_
diff --git a/quic/core/tls_handshaker_test.cc b/quic/core/tls_handshaker_test.cc
index 00c62d6..ef7f31b 100644
--- a/quic/core/tls_handshaker_test.cc
+++ b/quic/core/tls_handshaker_test.cc
@@ -10,6 +10,7 @@
#include "net/third_party/quiche/src/quic/core/tls_client_handshaker.h"
#include "net/third_party/quiche/src/quic/core/tls_server_handshaker.h"
#include "net/third_party/quiche/src/quic/platform/api/quic_arraysize.h"
+#include "net/third_party/quiche/src/quic/platform/api/quic_expect_bug.h"
#include "net/third_party/quiche/src/quic/platform/api/quic_ptr_util.h"
#include "net/third_party/quiche/src/quic/platform/api/quic_test.h"
#include "net/third_party/quiche/src/quic/test_tools/crypto_test_utils.h"
@@ -22,6 +23,7 @@
namespace {
using ::testing::_;
+using ::testing::Return;
class FakeProofVerifier : public ProofVerifier {
public:
@@ -225,6 +227,7 @@
~TestQuicCryptoClientStream() override = default;
TlsHandshaker* handshaker() const override { return handshaker_.get(); }
+ TlsClientHandshaker* client_handshaker() const { return handshaker_.get(); }
bool CryptoConnect() { return handshaker_->CryptoConnect(); }
@@ -302,6 +305,9 @@
EXPECT_FALSE(client_stream_->handshake_confirmed());
EXPECT_FALSE(server_stream_->encryption_established());
EXPECT_FALSE(server_stream_->handshake_confirmed());
+ ON_CALL(client_session_, GetAlpnsToOffer())
+ .WillByDefault(Return(std::vector<std::string>(
+ {AlpnForVersion(client_session_.connection()->version())})));
}
MockQuicConnectionHelper conn_helper_;
@@ -442,8 +448,9 @@
}
TEST_F(TlsHandshakerTest, ClientNotSendingALPN) {
- static std::string kTestClientNoAlpn = "";
- quic_alpn_override_on_client_for_tests = &kTestClientNoAlpn;
+ client_stream_->client_handshaker()->AllowEmptyAlpnForTests();
+ EXPECT_CALL(client_session_, GetAlpnsToOffer())
+ .WillOnce(Return(std::vector<std::string>()));
EXPECT_CALL(*client_conn_, CloseConnection(QUIC_HANDSHAKE_FAILED,
"Server did not select ALPN", _));
EXPECT_CALL(*server_conn_,
@@ -456,12 +463,12 @@
EXPECT_FALSE(client_stream_->encryption_established());
EXPECT_FALSE(server_stream_->handshake_confirmed());
EXPECT_FALSE(server_stream_->encryption_established());
- quic_alpn_override_on_client_for_tests = nullptr;
}
TEST_F(TlsHandshakerTest, ClientSendingBadALPN) {
static std::string kTestBadClientAlpn = "bad-client-alpn";
- quic_alpn_override_on_client_for_tests = &kTestBadClientAlpn;
+ EXPECT_CALL(client_session_, GetAlpnsToOffer())
+ .WillOnce(Return(std::vector<std::string>({kTestBadClientAlpn})));
EXPECT_CALL(*client_conn_, CloseConnection(QUIC_HANDSHAKE_FAILED,
"Server did not select ALPN", _));
EXPECT_CALL(*server_conn_,
@@ -474,7 +481,22 @@
EXPECT_FALSE(client_stream_->encryption_established());
EXPECT_FALSE(server_stream_->handshake_confirmed());
EXPECT_FALSE(server_stream_->encryption_established());
- quic_alpn_override_on_client_for_tests = nullptr;
+}
+
+TEST_F(TlsHandshakerTest, ClientSendingTooManyALPNs) {
+ std::string long_alpn(250, 'A');
+ EXPECT_CALL(client_session_, GetAlpnsToOffer())
+ .WillOnce(Return(std::vector<std::string>({
+ long_alpn + "1",
+ long_alpn + "2",
+ long_alpn + "3",
+ long_alpn + "4",
+ long_alpn + "5",
+ long_alpn + "6",
+ long_alpn + "7",
+ long_alpn + "8",
+ })));
+ EXPECT_QUIC_BUG(client_stream_->CryptoConnect(), "Failed to set ALPN");
}
} // namespace
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h
index 40b8036..bbde859 100644
--- a/quic/test_tools/quic_test_utils.h
+++ b/quic/test_tools/quic_test_utils.h
@@ -633,6 +633,7 @@
MOCK_CONST_METHOD0(ShouldKeepConnectionAlive, bool());
MOCK_METHOD2(SendStopSending, void(uint16_t code, QuicStreamId stream_id));
MOCK_METHOD1(OnCryptoHandshakeEvent, void(QuicSession::CryptoHandshakeEvent));
+ MOCK_METHOD0(GetAlpnsToOffer, std::vector<std::string>());
using QuicSession::ActivateStream;