Implement QUIC ALPN selection on the server side.
This also fixes client-side bugs found by a full custom-ALPN test, and removes a tautological check from the server-side ALPN parser.
gfe-relnote: n/a (protected by disabled quic_tls flag)
PiperOrigin-RevId: 266319592
Change-Id: I9e06b383abe187286f31d3cbce8be99e9370c9f2
diff --git a/quic/core/quic_session.cc b/quic/core/quic_session.cc
index 98aa242..57de377 100644
--- a/quic/core/quic_session.cc
+++ b/quic/core/quic_session.cc
@@ -1942,5 +1942,17 @@
return stream_id_manager_.max_open_incoming_streams();
}
+std::vector<QuicStringPiece>::const_iterator QuicSession::SelectAlpn(
+ const std::vector<QuicStringPiece>& alpns) const {
+ const std::string alpn = AlpnForVersion(connection()->version());
+ return std::find(alpns.cbegin(), alpns.cend(), alpn);
+}
+
+void QuicSession::OnAlpnSelected(QuicStringPiece alpn) {
+ QUIC_DLOG(INFO) << (perspective() == Perspective::IS_SERVER ? "Server: "
+ : "Client: ")
+ << "ALPN selected: " << alpn;
+}
+
#undef ENDPOINT // undef for jumbo builds
} // namespace quic
diff --git a/quic/core/quic_session.h b/quic/core/quic_session.h
index d5ddc69..16c7efc 100644
--- a/quic/core/quic_session.h
+++ b/quic/core/quic_session.h
@@ -460,12 +460,21 @@
}
// Returns the ALPN values to negotiate on this session.
- virtual std::vector<std::string> GetAlpnsToOffer() {
+ virtual std::vector<std::string> GetAlpnsToOffer() const {
// TODO(vasilvv): this currently sets HTTP/3 by default. Switch all
// non-HTTP applications to appropriate ALPNs.
return std::vector<std::string>({AlpnForVersion(connection()->version())});
}
+ // Provided a list of ALPNs offered by the client, selects an ALPN from the
+ // list, or alpns.end() if none of the ALPNs are acceptable.
+ virtual std::vector<QuicStringPiece>::const_iterator SelectAlpn(
+ const std::vector<QuicStringPiece>& alpns) const;
+
+ // Called when the ALPN of the connection is established for a connection that
+ // uses TLS handshake.
+ virtual void OnAlpnSelected(QuicStringPiece alpn);
+
protected:
using StreamMap = QuicSmallMap<QuicStreamId, std::unique_ptr<QuicStream>, 10>;
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index a410f81..bb45602 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -304,17 +304,17 @@
std::string received_alpn_string(reinterpret_cast<const char*>(alpn_data),
alpn_length);
- std::string sent_alpn_string =
- AlpnForVersion(session()->connection()->version());
- if (received_alpn_string != sent_alpn_string) {
+ std::vector<std::string> offered_alpns = session()->GetAlpnsToOffer();
+ if (std::find(offered_alpns.begin(), offered_alpns.end(),
+ received_alpn_string) == offered_alpns.end()) {
QUIC_LOG(ERROR) << "Client: received mismatched ALPN '"
- << received_alpn_string << "', expected '"
- << sent_alpn_string << "'";
+ << received_alpn_string << "'";
// TODO(b/130164908) this should send no_application_protocol
// instead of QUIC_HANDSHAKE_FAILED.
CloseConnection(QUIC_HANDSHAKE_FAILED, "Client received mismatched ALPN");
return;
}
+ session()->OnAlpnSelected(received_alpn_string);
QUIC_DLOG(INFO) << "Client: server selected ALPN: '" << received_alpn_string
<< "'";
diff --git a/quic/core/tls_handshaker_test.cc b/quic/core/tls_handshaker_test.cc
index ef7f31b..a298089 100644
--- a/quic/core/tls_handshaker_test.cc
+++ b/quic/core/tls_handshaker_test.cc
@@ -23,6 +23,7 @@
namespace {
using ::testing::_;
+using ::testing::ElementsAreArray;
using ::testing::Return;
class FakeProofVerifier : public ProofVerifier {
@@ -305,9 +306,15 @@
EXPECT_FALSE(client_stream_->handshake_confirmed());
EXPECT_FALSE(server_stream_->encryption_established());
EXPECT_FALSE(server_stream_->handshake_confirmed());
+ const std::string default_alpn =
+ AlpnForVersion(client_session_.connection()->version());
ON_CALL(client_session_, GetAlpnsToOffer())
- .WillByDefault(Return(std::vector<std::string>(
- {AlpnForVersion(client_session_.connection()->version())})));
+ .WillByDefault(Return(std::vector<std::string>({default_alpn})));
+ ON_CALL(server_session_, SelectAlpn(_))
+ .WillByDefault(
+ [default_alpn](const std::vector<QuicStringPiece>& alpns) {
+ return std::find(alpns.begin(), alpns.end(), default_alpn);
+ });
}
MockQuicConnectionHelper conn_helper_;
@@ -499,6 +506,59 @@
EXPECT_QUIC_BUG(client_stream_->CryptoConnect(), "Failed to set ALPN");
}
+TEST_F(TlsHandshakerTest, ServerRequiresCustomALPN) {
+ static const std::string kTestAlpn = "An ALPN That Client Did Not Offer";
+ EXPECT_CALL(server_session_, SelectAlpn(_))
+ .WillOnce([](const std::vector<QuicStringPiece>& alpns) {
+ return std::find(alpns.cbegin(), alpns.cend(), kTestAlpn);
+ });
+ EXPECT_CALL(*client_conn_, CloseConnection(QUIC_HANDSHAKE_FAILED,
+ "Server did not select ALPN", _));
+ EXPECT_CALL(*server_conn_,
+ CloseConnection(QUIC_HANDSHAKE_FAILED,
+ "Server did not receive a known ALPN", _));
+ client_stream_->CryptoConnect();
+ ExchangeHandshakeMessages(client_stream_, server_stream_);
+
+ EXPECT_FALSE(client_stream_->handshake_confirmed());
+ EXPECT_FALSE(client_stream_->encryption_established());
+ EXPECT_FALSE(server_stream_->handshake_confirmed());
+ EXPECT_FALSE(server_stream_->encryption_established());
+}
+
+TEST_F(TlsHandshakerTest, CustomALPNNegotiation) {
+ EXPECT_CALL(*client_conn_, CloseConnection(_, _, _)).Times(0);
+ EXPECT_CALL(*server_conn_, CloseConnection(_, _, _)).Times(0);
+ EXPECT_CALL(client_session_,
+ OnCryptoHandshakeEvent(QuicSession::ENCRYPTION_ESTABLISHED));
+ EXPECT_CALL(client_session_,
+ OnCryptoHandshakeEvent(QuicSession::HANDSHAKE_CONFIRMED));
+ EXPECT_CALL(server_session_,
+ OnCryptoHandshakeEvent(QuicSession::HANDSHAKE_CONFIRMED));
+
+ static const std::string kTestAlpn = "A Custom ALPN Value";
+ static const std::vector<std::string> kTestAlpns(
+ {"foo", "bar", kTestAlpn, "something else"});
+ EXPECT_CALL(client_session_, GetAlpnsToOffer())
+ .WillRepeatedly(Return(kTestAlpns));
+ EXPECT_CALL(server_session_, SelectAlpn(_))
+ .WillOnce([](const std::vector<QuicStringPiece>& alpns) {
+ EXPECT_THAT(alpns, ElementsAreArray(kTestAlpns));
+ return std::find(alpns.cbegin(), alpns.cend(), kTestAlpn);
+ });
+ EXPECT_CALL(client_session_, OnAlpnSelected(QuicStringPiece(kTestAlpn)));
+ EXPECT_CALL(server_session_, OnAlpnSelected(QuicStringPiece(kTestAlpn)));
+ client_stream_->CryptoConnect();
+ ExchangeHandshakeMessages(client_stream_, server_stream_);
+
+ EXPECT_TRUE(client_stream_->handshake_confirmed());
+ EXPECT_TRUE(client_stream_->encryption_established());
+ EXPECT_TRUE(server_stream_->handshake_confirmed());
+ EXPECT_TRUE(server_stream_->encryption_established());
+ EXPECT_TRUE(client_conn_->IsHandshakeConfirmed());
+ EXPECT_FALSE(server_conn_->IsHandshakeConfirmed());
+}
+
} // namespace
} // namespace test
} // namespace quic
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index a2ea397..0b4215d 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -352,41 +352,38 @@
return SSL_TLSEXT_ERR_NOACK;
}
- std::string expected_alpn_string =
- AlpnForVersion(session()->connection()->version());
-
CBS all_alpns;
CBS_init(&all_alpns, in, in_len);
+ std::vector<QuicStringPiece> alpns;
while (CBS_len(&all_alpns) > 0) {
CBS alpn;
if (!CBS_get_u8_length_prefixed(&all_alpns, &alpn)) {
QUIC_DLOG(ERROR) << "Failed to parse ALPN length";
return SSL_TLSEXT_ERR_NOACK;
}
+
const size_t alpn_length = CBS_len(&alpn);
- if (alpn_length >
- static_cast<size_t>(std::numeric_limits<uint8_t>::max())) {
- QUIC_BUG << "Parsed impossible ALPN length " << alpn_length;
- return SSL_TLSEXT_ERR_NOACK;
- }
if (alpn_length == 0) {
QUIC_DLOG(ERROR) << "Received invalid zero-length ALPN";
return SSL_TLSEXT_ERR_NOACK;
}
- std::string alpn_string(reinterpret_cast<const char*>(CBS_data(&alpn)),
- alpn_length);
- if (alpn_string == expected_alpn_string) {
- QUIC_DLOG(INFO) << "Server selecting ALPN '" << alpn_string << "'";
- *out_len = static_cast<uint8_t>(alpn_length);
- *out = CBS_data(&alpn);
- valid_alpn_received_ = true;
- return SSL_TLSEXT_ERR_OK;
- }
+
+ alpns.emplace_back(reinterpret_cast<const char*>(CBS_data(&alpn)),
+ alpn_length);
}
- QUIC_DLOG(ERROR) << "No known ALPN provided by client";
- return SSL_TLSEXT_ERR_NOACK;
+ auto selected_alpn = session()->SelectAlpn(alpns);
+ if (selected_alpn == alpns.end()) {
+ QUIC_DLOG(ERROR) << "No known ALPN provided by client";
+ return SSL_TLSEXT_ERR_NOACK;
+ }
+
+ session()->OnAlpnSelected(*selected_alpn);
+ valid_alpn_received_ = true;
+ *out_len = selected_alpn->size();
+ *out = reinterpret_cast<const uint8_t*>(selected_alpn->data());
+ return SSL_TLSEXT_ERR_OK;
}
} // namespace quic