Support session-specific ALPNs for clients, and specifying multiple ALPNs.

gfe-relnote: n/a (no functional change)
PiperOrigin-RevId: 266192245
Change-Id: Id0921d36b2c32c7df92a4a528ec19d9a28b826b0
diff --git a/quic/core/quic_session.h b/quic/core/quic_session.h
index a402057..d5ddc69 100644
--- a/quic/core/quic_session.h
+++ b/quic/core/quic_session.h
@@ -459,6 +459,13 @@
         max_stream + num_expected_unidirectional_static_streams_);
   }
 
+  // Returns the ALPN values to negotiate on this session.
+  virtual std::vector<std::string> GetAlpnsToOffer() {
+    // TODO(vasilvv): this currently sets HTTP/3 by default.  Switch all
+    // non-HTTP applications to appropriate ALPNs.
+    return std::vector<std::string>({AlpnForVersion(connection()->version())});
+  }
+
  protected:
   using StreamMap = QuicSmallMap<QuicStreamId, std::unique_ptr<QuicStream>, 10>;
 
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index 8e2d2bb..a410f81 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -15,8 +15,6 @@
 
 namespace quic {
 
-std::string* quic_alpn_override_on_client_for_tests = nullptr;
-
 TlsClientHandshaker::ProofVerifierCallbackImpl::ProofVerifierCallbackImpl(
     TlsClientHandshaker* parent)
     : parent_(parent) {}
@@ -78,32 +76,10 @@
     return false;
   }
 
-  std::string alpn_string = AlpnForVersion(session()->connection()->version());
-  if (quic_alpn_override_on_client_for_tests != nullptr) {
-    alpn_string = *quic_alpn_override_on_client_for_tests;
-  }
-  if (alpn_string.length() > std::numeric_limits<uint8_t>::max()) {
-    QUIC_BUG << "ALPN too long: '" << alpn_string << "'";
-    CloseConnection(QUIC_HANDSHAKE_FAILED,
-                    "Client configured ALPN is too long");
+  if (!SetAlpn()) {
+    CloseConnection(QUIC_HANDSHAKE_FAILED, "Client failed to set ALPN");
     return false;
   }
-  const uint8_t alpn_length = alpn_string.length();
-  if (alpn_length > 0) {
-    // SSL_set_alpn_protos expects a sequence of one-byte-length-prefixed
-    // strings so we copy alpn_string to a new buffer that has the length
-    // in alpn[0].
-    uint8_t alpn[std::numeric_limits<uint8_t>::max() + 1];
-    alpn[0] = alpn_length;
-    memcpy(reinterpret_cast<char*>(alpn + 1), alpn_string.data(), alpn_length);
-    if (SSL_set_alpn_protos(ssl(), alpn,
-                            static_cast<unsigned>(alpn_length) + 1) != 0) {
-      QUIC_BUG << "Failed to set ALPN: '" << alpn_string << "'";
-      CloseConnection(QUIC_HANDSHAKE_FAILED, "Client failed to set ALPN");
-      return false;
-    }
-  }
-  QUIC_DLOG(INFO) << "Client using ALPN: '" << alpn_string << "'";
 
   // Set the Transport Parameters to send in the ClientHello
   if (!SetTransportParameters()) {
@@ -117,6 +93,46 @@
   return session()->connection()->connected();
 }
 
+static bool IsValidAlpn(const std::string& alpn_string) {
+  return alpn_string.length() <= std::numeric_limits<uint8_t>::max();
+}
+
+bool TlsClientHandshaker::SetAlpn() {
+  std::vector<std::string> alpns = session()->GetAlpnsToOffer();
+  if (alpns.empty()) {
+    if (allow_empty_alpn_for_tests_) {
+      return true;
+    }
+
+    QUIC_BUG << "ALPN missing";
+    return false;
+  }
+  if (!std::all_of(alpns.begin(), alpns.end(), IsValidAlpn)) {
+    QUIC_BUG << "ALPN too long";
+    return false;
+  }
+
+  // SSL_set_alpn_protos expects a sequence of one-byte-length-prefixed
+  // strings.
+  uint8_t alpn[1024];
+  QuicDataWriter alpn_writer(sizeof(alpn), reinterpret_cast<char*>(alpn));
+  bool success = true;
+  for (const std::string& alpn_string : alpns) {
+    success = success && alpn_writer.WriteUInt8(alpn_string.size()) &&
+              alpn_writer.WriteStringPiece(alpn_string);
+  }
+  success =
+      success && (SSL_set_alpn_protos(ssl(), alpn, alpn_writer.length()) == 0);
+  if (!success) {
+    QUIC_BUG << "Failed to set ALPN: "
+             << QuicTextUtils::HexDump(
+                    QuicStringPiece(alpn_writer.data(), alpn_writer.length()));
+    return false;
+  }
+  QUIC_DLOG(INFO) << "Client using ALPN: '" << alpns[0] << "'";
+  return true;
+}
+
 bool TlsClientHandshaker::SetTransportParameters() {
   TransportParameters params;
   params.perspective = Perspective::IS_CLIENT;
diff --git a/quic/core/tls_client_handshaker.h b/quic/core/tls_client_handshaker.h
index 47faf81..f3e90ce 100644
--- a/quic/core/tls_client_handshaker.h
+++ b/quic/core/tls_client_handshaker.h
@@ -54,6 +54,8 @@
   CryptoMessageParser* crypto_message_parser() override;
   size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
 
+  void AllowEmptyAlpnForTests() { allow_empty_alpn_for_tests_ = true; }
+
  protected:
   const TlsConnection* tls_connection() const override {
     return &tls_connection_;
@@ -95,6 +97,7 @@
     STATE_CONNECTION_CLOSED,
   } state_ = STATE_IDLE;
 
+  bool SetAlpn();
   bool SetTransportParameters();
   bool ProcessTransportParameters(std::string* error_details);
   void FinishHandshake();
@@ -121,13 +124,11 @@
   QuicReferenceCountedPointer<QuicCryptoNegotiatedParameters>
       crypto_negotiated_params_;
 
+  bool allow_empty_alpn_for_tests_ = false;
+
   TlsClientConnection tls_connection_;
 };
 
-// Allows tests to override the ALPN used by clients.
-// DO NOT use outside of tests.
-QUIC_EXPORT_PRIVATE extern std::string* quic_alpn_override_on_client_for_tests;
-
 }  // namespace quic
 
 #endif  // QUICHE_QUIC_CORE_TLS_CLIENT_HANDSHAKER_H_
diff --git a/quic/core/tls_handshaker_test.cc b/quic/core/tls_handshaker_test.cc
index 00c62d6..ef7f31b 100644
--- a/quic/core/tls_handshaker_test.cc
+++ b/quic/core/tls_handshaker_test.cc
@@ -10,6 +10,7 @@
 #include "net/third_party/quiche/src/quic/core/tls_client_handshaker.h"
 #include "net/third_party/quiche/src/quic/core/tls_server_handshaker.h"
 #include "net/third_party/quiche/src/quic/platform/api/quic_arraysize.h"
+#include "net/third_party/quiche/src/quic/platform/api/quic_expect_bug.h"
 #include "net/third_party/quiche/src/quic/platform/api/quic_ptr_util.h"
 #include "net/third_party/quiche/src/quic/platform/api/quic_test.h"
 #include "net/third_party/quiche/src/quic/test_tools/crypto_test_utils.h"
@@ -22,6 +23,7 @@
 namespace {
 
 using ::testing::_;
+using ::testing::Return;
 
 class FakeProofVerifier : public ProofVerifier {
  public:
@@ -225,6 +227,7 @@
   ~TestQuicCryptoClientStream() override = default;
 
   TlsHandshaker* handshaker() const override { return handshaker_.get(); }
+  TlsClientHandshaker* client_handshaker() const { return handshaker_.get(); }
 
   bool CryptoConnect() { return handshaker_->CryptoConnect(); }
 
@@ -302,6 +305,9 @@
     EXPECT_FALSE(client_stream_->handshake_confirmed());
     EXPECT_FALSE(server_stream_->encryption_established());
     EXPECT_FALSE(server_stream_->handshake_confirmed());
+    ON_CALL(client_session_, GetAlpnsToOffer())
+        .WillByDefault(Return(std::vector<std::string>(
+            {AlpnForVersion(client_session_.connection()->version())})));
   }
 
   MockQuicConnectionHelper conn_helper_;
@@ -442,8 +448,9 @@
 }
 
 TEST_F(TlsHandshakerTest, ClientNotSendingALPN) {
-  static std::string kTestClientNoAlpn = "";
-  quic_alpn_override_on_client_for_tests = &kTestClientNoAlpn;
+  client_stream_->client_handshaker()->AllowEmptyAlpnForTests();
+  EXPECT_CALL(client_session_, GetAlpnsToOffer())
+      .WillOnce(Return(std::vector<std::string>()));
   EXPECT_CALL(*client_conn_, CloseConnection(QUIC_HANDSHAKE_FAILED,
                                              "Server did not select ALPN", _));
   EXPECT_CALL(*server_conn_,
@@ -456,12 +463,12 @@
   EXPECT_FALSE(client_stream_->encryption_established());
   EXPECT_FALSE(server_stream_->handshake_confirmed());
   EXPECT_FALSE(server_stream_->encryption_established());
-  quic_alpn_override_on_client_for_tests = nullptr;
 }
 
 TEST_F(TlsHandshakerTest, ClientSendingBadALPN) {
   static std::string kTestBadClientAlpn = "bad-client-alpn";
-  quic_alpn_override_on_client_for_tests = &kTestBadClientAlpn;
+  EXPECT_CALL(client_session_, GetAlpnsToOffer())
+      .WillOnce(Return(std::vector<std::string>({kTestBadClientAlpn})));
   EXPECT_CALL(*client_conn_, CloseConnection(QUIC_HANDSHAKE_FAILED,
                                              "Server did not select ALPN", _));
   EXPECT_CALL(*server_conn_,
@@ -474,7 +481,22 @@
   EXPECT_FALSE(client_stream_->encryption_established());
   EXPECT_FALSE(server_stream_->handshake_confirmed());
   EXPECT_FALSE(server_stream_->encryption_established());
-  quic_alpn_override_on_client_for_tests = nullptr;
+}
+
+TEST_F(TlsHandshakerTest, ClientSendingTooManyALPNs) {
+  std::string long_alpn(250, 'A');
+  EXPECT_CALL(client_session_, GetAlpnsToOffer())
+      .WillOnce(Return(std::vector<std::string>({
+          long_alpn + "1",
+          long_alpn + "2",
+          long_alpn + "3",
+          long_alpn + "4",
+          long_alpn + "5",
+          long_alpn + "6",
+          long_alpn + "7",
+          long_alpn + "8",
+      })));
+  EXPECT_QUIC_BUG(client_stream_->CryptoConnect(), "Failed to set ALPN");
 }
 
 }  // namespace
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h
index 40b8036..bbde859 100644
--- a/quic/test_tools/quic_test_utils.h
+++ b/quic/test_tools/quic_test_utils.h
@@ -633,6 +633,7 @@
   MOCK_CONST_METHOD0(ShouldKeepConnectionAlive, bool());
   MOCK_METHOD2(SendStopSending, void(uint16_t code, QuicStreamId stream_id));
   MOCK_METHOD1(OnCryptoHandshakeEvent, void(QuicSession::CryptoHandshakeEvent));
+  MOCK_METHOD0(GetAlpnsToOffer, std::vector<std::string>());
 
   using QuicSession::ActivateStream;