Implement QUIC ALPN selection on the server side. This also fixes the bugs in client side found by a full custom ALPN test, and removes a tautological check from server-side ALPN parser. gfe-relnote: n/a (protected by disabled quic_tls flag) PiperOrigin-RevId: 266319592 Change-Id: I9e06b383abe187286f31d3cbce8be99e9370c9f2
diff --git a/quic/core/quic_session.cc b/quic/core/quic_session.cc index 98aa242..57de377 100644 --- a/quic/core/quic_session.cc +++ b/quic/core/quic_session.cc
@@ -1942,5 +1942,17 @@ return stream_id_manager_.max_open_incoming_streams(); } +std::vector<QuicStringPiece>::const_iterator QuicSession::SelectAlpn( + const std::vector<QuicStringPiece>& alpns) const { + const std::string alpn = AlpnForVersion(connection()->version()); + return std::find(alpns.cbegin(), alpns.cend(), alpn); +} + +void QuicSession::OnAlpnSelected(QuicStringPiece alpn) { + QUIC_DLOG(INFO) << (perspective() == Perspective::IS_SERVER ? "Server: " + : "Client: ") + << "ALPN selected: " << alpn; +} + #undef ENDPOINT // undef for jumbo builds } // namespace quic
diff --git a/quic/core/quic_session.h b/quic/core/quic_session.h index d5ddc69..16c7efc 100644 --- a/quic/core/quic_session.h +++ b/quic/core/quic_session.h
@@ -460,12 +460,21 @@ } // Returns the ALPN values to negotiate on this session. - virtual std::vector<std::string> GetAlpnsToOffer() { + virtual std::vector<std::string> GetAlpnsToOffer() const { // TODO(vasilvv): this currently sets HTTP/3 by default. Switch all // non-HTTP applications to appropriate ALPNs. return std::vector<std::string>({AlpnForVersion(connection()->version())}); } + // Provided a list of ALPNs offered by the client, selects an ALPN from the + // list, or alpns.end() if none of the ALPNs are acceptable. + virtual std::vector<QuicStringPiece>::const_iterator SelectAlpn( + const std::vector<QuicStringPiece>& alpns) const; + + // Called when the ALPN of the connection is established for a connection that + // uses TLS handshake. + virtual void OnAlpnSelected(QuicStringPiece alpn); + protected: using StreamMap = QuicSmallMap<QuicStreamId, std::unique_ptr<QuicStream>, 10>;
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc index a410f81..bb45602 100644 --- a/quic/core/tls_client_handshaker.cc +++ b/quic/core/tls_client_handshaker.cc
@@ -304,17 +304,17 @@ std::string received_alpn_string(reinterpret_cast<const char*>(alpn_data), alpn_length); - std::string sent_alpn_string = - AlpnForVersion(session()->connection()->version()); - if (received_alpn_string != sent_alpn_string) { + std::vector<std::string> offered_alpns = session()->GetAlpnsToOffer(); + if (std::find(offered_alpns.begin(), offered_alpns.end(), + received_alpn_string) == offered_alpns.end()) { QUIC_LOG(ERROR) << "Client: received mismatched ALPN '" - << received_alpn_string << "', expected '" - << sent_alpn_string << "'"; + << received_alpn_string; // TODO(b/130164908) this should send no_application_protocol // instead of QUIC_HANDSHAKE_FAILED. CloseConnection(QUIC_HANDSHAKE_FAILED, "Client received mismatched ALPN"); return; } + session()->OnAlpnSelected(received_alpn_string); QUIC_DLOG(INFO) << "Client: server selected ALPN: '" << received_alpn_string << "'";
diff --git a/quic/core/tls_handshaker_test.cc b/quic/core/tls_handshaker_test.cc index ef7f31b..a298089 100644 --- a/quic/core/tls_handshaker_test.cc +++ b/quic/core/tls_handshaker_test.cc
@@ -23,6 +23,7 @@ namespace { using ::testing::_; +using ::testing::ElementsAreArray; using ::testing::Return; class FakeProofVerifier : public ProofVerifier { @@ -305,9 +306,15 @@ EXPECT_FALSE(client_stream_->handshake_confirmed()); EXPECT_FALSE(server_stream_->encryption_established()); EXPECT_FALSE(server_stream_->handshake_confirmed()); + const std::string default_alpn = + AlpnForVersion(client_session_.connection()->version()); ON_CALL(client_session_, GetAlpnsToOffer()) - .WillByDefault(Return(std::vector<std::string>( - {AlpnForVersion(client_session_.connection()->version())}))); + .WillByDefault(Return(std::vector<std::string>({default_alpn}))); + ON_CALL(server_session_, SelectAlpn(_)) + .WillByDefault( + [default_alpn](const std::vector<QuicStringPiece>& alpns) { + return std::find(alpns.begin(), alpns.end(), default_alpn); + }); } MockQuicConnectionHelper conn_helper_; @@ -499,6 +506,59 @@ EXPECT_QUIC_BUG(client_stream_->CryptoConnect(), "Failed to set ALPN"); } +TEST_F(TlsHandshakerTest, ServerRequiresCustomALPN) { + static const std::string kTestAlpn = "An ALPN That Client Did Not Offer"; + EXPECT_CALL(server_session_, SelectAlpn(_)) + .WillOnce([](const std::vector<QuicStringPiece>& alpns) { + return std::find(alpns.cbegin(), alpns.cend(), kTestAlpn); + }); + EXPECT_CALL(*client_conn_, CloseConnection(QUIC_HANDSHAKE_FAILED, + "Server did not select ALPN", _)); + EXPECT_CALL(*server_conn_, + CloseConnection(QUIC_HANDSHAKE_FAILED, + "Server did not receive a known ALPN", _)); + client_stream_->CryptoConnect(); + ExchangeHandshakeMessages(client_stream_, server_stream_); + + EXPECT_FALSE(client_stream_->handshake_confirmed()); + EXPECT_FALSE(client_stream_->encryption_established()); + EXPECT_FALSE(server_stream_->handshake_confirmed()); + EXPECT_FALSE(server_stream_->encryption_established()); +} + +TEST_F(TlsHandshakerTest, CustomALPNNegotiation) { + EXPECT_CALL(*client_conn_, CloseConnection(_, _, _)).Times(0); + EXPECT_CALL(*server_conn_, CloseConnection(_, _, _)).Times(0); + EXPECT_CALL(client_session_, + OnCryptoHandshakeEvent(QuicSession::ENCRYPTION_ESTABLISHED)); + EXPECT_CALL(client_session_, + OnCryptoHandshakeEvent(QuicSession::HANDSHAKE_CONFIRMED)); + EXPECT_CALL(server_session_, + OnCryptoHandshakeEvent(QuicSession::HANDSHAKE_CONFIRMED)); + + static const std::string kTestAlpn = "A Custom ALPN Value"; + static const std::vector<std::string> kTestAlpns( + {"foo", "bar", kTestAlpn, "something else"}); + EXPECT_CALL(client_session_, GetAlpnsToOffer()) + .WillRepeatedly(Return(kTestAlpns)); + EXPECT_CALL(server_session_, SelectAlpn(_)) + .WillOnce([](const std::vector<QuicStringPiece>& alpns) { + EXPECT_THAT(alpns, ElementsAreArray(kTestAlpns)); + return std::find(alpns.cbegin(), alpns.cend(), kTestAlpn); + }); + EXPECT_CALL(client_session_, OnAlpnSelected(QuicStringPiece(kTestAlpn))); + EXPECT_CALL(server_session_, OnAlpnSelected(QuicStringPiece(kTestAlpn))); + client_stream_->CryptoConnect(); + ExchangeHandshakeMessages(client_stream_, server_stream_); + + EXPECT_TRUE(client_stream_->handshake_confirmed()); + EXPECT_TRUE(client_stream_->encryption_established()); + EXPECT_TRUE(server_stream_->handshake_confirmed()); + EXPECT_TRUE(server_stream_->encryption_established()); + EXPECT_TRUE(client_conn_->IsHandshakeConfirmed()); + EXPECT_FALSE(server_conn_->IsHandshakeConfirmed()); +} + } // namespace } // namespace test } // namespace quic
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc index a2ea397..0b4215d 100644 --- a/quic/core/tls_server_handshaker.cc +++ b/quic/core/tls_server_handshaker.cc
@@ -352,41 +352,38 @@ return SSL_TLSEXT_ERR_NOACK; } - std::string expected_alpn_string = - AlpnForVersion(session()->connection()->version()); - CBS all_alpns; CBS_init(&all_alpns, in, in_len); + std::vector<QuicStringPiece> alpns; while (CBS_len(&all_alpns) > 0) { CBS alpn; if (!CBS_get_u8_length_prefixed(&all_alpns, &alpn)) { QUIC_DLOG(ERROR) << "Failed to parse ALPN length"; return SSL_TLSEXT_ERR_NOACK; } + const size_t alpn_length = CBS_len(&alpn); - if (alpn_length > - static_cast<size_t>(std::numeric_limits<uint8_t>::max())) { - QUIC_BUG << "Parsed impossible ALPN length " << alpn_length; - return SSL_TLSEXT_ERR_NOACK; - } if (alpn_length == 0) { QUIC_DLOG(ERROR) << "Received invalid zero-length ALPN"; return SSL_TLSEXT_ERR_NOACK; } - std::string alpn_string(reinterpret_cast<const char*>(CBS_data(&alpn)), - alpn_length); - if (alpn_string == expected_alpn_string) { - QUIC_DLOG(INFO) << "Server selecting ALPN '" << alpn_string << "'"; - *out_len = static_cast<uint8_t>(alpn_length); - *out = CBS_data(&alpn); - valid_alpn_received_ = true; - return SSL_TLSEXT_ERR_OK; - } + + alpns.emplace_back(reinterpret_cast<const char*>(CBS_data(&alpn)), + alpn_length); } - QUIC_DLOG(ERROR) << "No known ALPN provided by client"; - return SSL_TLSEXT_ERR_NOACK; + auto selected_alpn = session()->SelectAlpn(alpns); + if (selected_alpn == alpns.end()) { + QUIC_DLOG(ERROR) << "No known ALPN provided by client"; + return SSL_TLSEXT_ERR_NOACK; + } + + session()->OnAlpnSelected(*selected_alpn); + valid_alpn_received_ = true; + *out_len = selected_alpn->size(); + *out = reinterpret_cast<const uint8_t*>(selected_alpn->data()); + return SSL_TLSEXT_ERR_OK; } } // namespace quic
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h index bbde859..5750b9b 100644 --- a/quic/test_tools/quic_test_utils.h +++ b/quic/test_tools/quic_test_utils.h
@@ -633,7 +633,11 @@ MOCK_CONST_METHOD0(ShouldKeepConnectionAlive, bool()); MOCK_METHOD2(SendStopSending, void(uint16_t code, QuicStreamId stream_id)); MOCK_METHOD1(OnCryptoHandshakeEvent, void(QuicSession::CryptoHandshakeEvent)); - MOCK_METHOD0(GetAlpnsToOffer, std::vector<std::string>()); + MOCK_CONST_METHOD0(GetAlpnsToOffer, std::vector<std::string>()); + MOCK_CONST_METHOD1(SelectAlpn, + std::vector<QuicStringPiece>::const_iterator( + const std::vector<QuicStringPiece>&)); + MOCK_METHOD1(OnAlpnSelected, void(QuicStringPiece)); using QuicSession::ActivateStream;