Make QUIC enforce ALPN when using TLS handshake

gfe-relnote: enforce ALPN when using TLS, protected by disabled quic_tls flag
PiperOrigin-RevId: 261159061
Change-Id: I9ccdd221e92beae2b83677e692e3c6d084351731
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index 735fd3b..2e015c8 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -15,6 +15,8 @@
 
 namespace quic {
 
+std::string* quic_alpn_override_on_client_for_tests = nullptr;  // Tests only.
+
 TlsClientHandshaker::ProofVerifierCallbackImpl::ProofVerifierCallbackImpl(
     TlsClientHandshaker* parent)
     : parent_(parent) {}
@@ -77,29 +79,36 @@
   }
 
   std::string alpn_string = AlpnForVersion(session()->connection()->version());
+  if (quic_alpn_override_on_client_for_tests != nullptr) {
+    alpn_string = *quic_alpn_override_on_client_for_tests;  // May be empty.
+  }
   if (alpn_string.length() > std::numeric_limits<uint8_t>::max()) {
     QUIC_BUG << "ALPN too long: '" << alpn_string << "'";
-    CloseConnection(QUIC_HANDSHAKE_FAILED, "ALPN too long");
+    CloseConnection(QUIC_HANDSHAKE_FAILED,
+                    "Client configured ALPN is too long");
     return false;
   }
   const uint8_t alpn_length = alpn_string.length();
-  // SSL_set_alpn_protos expects a sequence of one-byte-length-prefixed strings
-  // so we copy alpn_string to a new buffer that has the length in alpn[0].
-  uint8_t alpn[std::numeric_limits<uint8_t>::max() + 1];
-  alpn[0] = alpn_length;
-  memcpy(reinterpret_cast<char*>(alpn + 1), alpn_string.data(), alpn_length);
-  if (SSL_set_alpn_protos(ssl(), alpn,
-                          static_cast<unsigned>(alpn_length) + 1) != 0) {
-    QUIC_BUG << "Failed to set ALPN: '" << alpn_string << "'";
-    CloseConnection(QUIC_HANDSHAKE_FAILED, "Failed to set ALPN");
-    return false;
+  if (alpn_length > 0) {  // Empty: skip SSL_set_alpn_protos.
+    // SSL_set_alpn_protos expects a list of one-byte-length-prefixed strings
+    // (and, unusually, returns 0 on success), so build a single-entry list:
+    // alpn[0] holds the length, followed by the ALPN bytes.
+    uint8_t alpn[std::numeric_limits<uint8_t>::max() + 1];  // 1 + 255 bytes.
+    alpn[0] = alpn_length;
+    memcpy(reinterpret_cast<char*>(alpn + 1), alpn_string.data(), alpn_length);
+    if (SSL_set_alpn_protos(ssl(), alpn,
+                            static_cast<unsigned>(alpn_length) + 1) != 0) {
+      QUIC_BUG << "Failed to set ALPN: '" << alpn_string << "'";
+      CloseConnection(QUIC_HANDSHAKE_FAILED, "Client failed to set ALPN");
+      return false;
+    }
   }
   QUIC_DLOG(INFO) << "Client using ALPN: '" << alpn_string << "'";
 
   // Set the Transport Parameters to send in the ClientHello
   if (!SetTransportParameters()) {
     CloseConnection(QUIC_HANDSHAKE_FAILED,
-                    "Failed to set Transport Parameters");
+                    "Client failed to set Transport Parameters");
     return false;
   }
 
@@ -206,7 +215,8 @@
     return;
   }
   if (state_ == STATE_IDLE) {
-    CloseConnection(QUIC_HANDSHAKE_FAILED, "TLS handshake failed");
+    CloseConnection(QUIC_HANDSHAKE_FAILED,
+                    "Client observed TLS handshake idle failure");
     return;
   }
   if (state_ == STATE_HANDSHAKE_COMPLETE) {
@@ -236,7 +246,8 @@
     // TODO(nharper): Surface error details from the error queue when ssl_error
     // is SSL_ERROR_SSL.
     QUIC_LOG(WARNING) << "SSL_do_handshake failed; closing connection";
-    CloseConnection(QUIC_HANDSHAKE_FAILED, "TLS handshake failed");
+    CloseConnection(QUIC_HANDSHAKE_FAILED,
+                    "Client observed TLS handshake failure");
   }
 }
 
@@ -261,25 +272,31 @@
   const uint8_t* alpn_data = nullptr;
   unsigned alpn_length = 0;
   SSL_get0_alpn_selected(ssl(), &alpn_data, &alpn_length);
-  // TODO(b/130164908) Act on ALPN.
-  if (alpn_length != 0) {
-    std::string received_alpn_string(reinterpret_cast<const char*>(alpn_data),
-                                     alpn_length);
-    std::string sent_alpn_string =
-        AlpnForVersion(session()->connection()->version());
-    if (received_alpn_string != sent_alpn_string) {
-      QUIC_LOG(ERROR) << "Client: received mismatched ALPN '"
-                      << received_alpn_string << "', expected '"
-                      << sent_alpn_string << "'";
-      CloseConnection(QUIC_HANDSHAKE_FAILED, "Mismatched ALPN");
-      return;
-    }
-    QUIC_DLOG(INFO) << "Client: server selected ALPN: '" << received_alpn_string
-                    << "'";
-  } else {
-    QUIC_DLOG(INFO) << "Client: server did not select ALPN";
+
+  if (alpn_length == 0) {  // ALPN is now required.
+    QUIC_DLOG(ERROR) << "Client: server did not select ALPN";
+    // TODO(b/130164908) this should send a no_application_protocol alert
+    // instead of closing with QUIC_HANDSHAKE_FAILED.
+    CloseConnection(QUIC_HANDSHAKE_FAILED, "Server did not select ALPN");
+    return;
   }
 
+  std::string received_alpn_string(reinterpret_cast<const char*>(alpn_data),
+                                   alpn_length);
+  std::string sent_alpn_string =  // NOTE: ignores the test ALPN override.
+      AlpnForVersion(session()->connection()->version());
+  if (received_alpn_string != sent_alpn_string) {
+    QUIC_LOG(ERROR) << "Client: received mismatched ALPN '"
+                    << received_alpn_string << "', expected '"
+                    << sent_alpn_string << "'";
+    // TODO(b/130164908) this should send a no_application_protocol alert
+    // instead of closing with QUIC_HANDSHAKE_FAILED.
+    CloseConnection(QUIC_HANDSHAKE_FAILED, "Client received mismatched ALPN");
+    return;
+  }
+  QUIC_DLOG(INFO) << "Client: server selected ALPN: '" << received_alpn_string
+                  << "'";
+
   session()->connection()->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
   session()->NeuterUnencryptedData();
   encryption_established_ = true;
diff --git a/quic/core/tls_client_handshaker.h b/quic/core/tls_client_handshaker.h
index c22fbe9..d8d8499 100644
--- a/quic/core/tls_client_handshaker.h
+++ b/quic/core/tls_client_handshaker.h
@@ -121,6 +121,10 @@
   TlsClientConnection tls_connection_;
 };
 
+// Test-only: when non-null, clients send this ALPN instead of the version's
+// default from AlpnForVersion (empty: no ALPN). DO NOT use outside tests.
+extern std::string* quic_alpn_override_on_client_for_tests;  // Not owned.
+
 }  // namespace quic
 
 #endif  // QUICHE_QUIC_CORE_TLS_CLIENT_HANDSHAKER_H_
diff --git a/quic/core/tls_handshaker_test.cc b/quic/core/tls_handshaker_test.cc
index f15a266..67dd47f 100644
--- a/quic/core/tls_handshaker_test.cc
+++ b/quic/core/tls_handshaker_test.cc
@@ -430,6 +430,42 @@
   EXPECT_FALSE(server_stream_->handshake_confirmed());
 }
 
+TEST_F(TlsHandshakerTest, ClientNotSendingALPN) {  // Both peers must close.
+  static std::string kTestClientNoAlpn = "";  // Client sends no ALPN.
+  quic_alpn_override_on_client_for_tests = &kTestClientNoAlpn;
+  EXPECT_CALL(*client_conn_, CloseConnection(QUIC_HANDSHAKE_FAILED,
+                                             "Server did not select ALPN", _));
+  EXPECT_CALL(*server_conn_,  // From TlsServerHandshaker::FinishHandshake.
+              CloseConnection(QUIC_HANDSHAKE_FAILED,
+                              "Server did not receive a known ALPN", _));
+  client_stream_->CryptoConnect();
+  ExchangeHandshakeMessages(client_stream_, server_stream_);
+
+  EXPECT_FALSE(client_stream_->handshake_confirmed());
+  EXPECT_FALSE(client_stream_->encryption_established());
+  EXPECT_FALSE(server_stream_->handshake_confirmed());
+  EXPECT_FALSE(server_stream_->encryption_established());
+  quic_alpn_override_on_client_for_tests = nullptr;  // Restore default.
+}
+
+TEST_F(TlsHandshakerTest, ClientSendingBadALPN) {  // Both peers must close.
+  static std::string kTestBadClientAlpn = "bad-client-alpn";  // Unknown ALPN.
+  quic_alpn_override_on_client_for_tests = &kTestBadClientAlpn;
+  EXPECT_CALL(*client_conn_, CloseConnection(QUIC_HANDSHAKE_FAILED,
+                                             "Server did not select ALPN", _));
+  EXPECT_CALL(*server_conn_,  // From TlsServerHandshaker::FinishHandshake.
+              CloseConnection(QUIC_HANDSHAKE_FAILED,
+                              "Server did not receive a known ALPN", _));
+  client_stream_->CryptoConnect();
+  ExchangeHandshakeMessages(client_stream_, server_stream_);
+
+  EXPECT_FALSE(client_stream_->handshake_confirmed());
+  EXPECT_FALSE(client_stream_->encryption_established());
+  EXPECT_FALSE(server_stream_->handshake_confirmed());
+  EXPECT_FALSE(server_stream_->encryption_established());
+  quic_alpn_override_on_client_for_tests = nullptr;  // Restore default.
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index 0e7d9e5..0f2c0d2 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -61,7 +61,7 @@
 
   if (!SetTransportParameters()) {
     CloseConnection(QUIC_HANDSHAKE_FAILED,
-                    "Failed to set Transport Parameters");
+                    "Server failed to set Transport Parameters");
   }
 }
 
@@ -170,7 +170,8 @@
     QUIC_LOG(WARNING) << "SSL_do_handshake failed; SSL_get_error returns "
                       << ssl_error << ", state_ = " << state_;
     ERR_print_errors_fp(stderr);
-    CloseConnection(QUIC_HANDSHAKE_FAILED, "TLS handshake failed");
+    CloseConnection(QUIC_HANDSHAKE_FAILED,
+                    "Server observed TLS handshake failure");
   }
 }
 
@@ -239,6 +240,16 @@
 }
 
 void TlsServerHandshaker::FinishHandshake() {
+  if (!valid_alpn_received_) {  // Set only in SelectAlpn.
+    QUIC_DLOG(ERROR)
+        << "Server: handshake finished without receiving a known ALPN";
+    // TODO(b/130164908) this should send a no_application_protocol alert
+    // instead of closing with QUIC_HANDSHAKE_FAILED.
+    CloseConnection(QUIC_HANDSHAKE_FAILED,
+                    "Server did not receive a known ALPN");
+    return;
+  }
+
   QUIC_LOG(INFO) << "Server: handshake finished";
   state_ = STATE_HANDSHAKE_COMPLETE;
 
@@ -328,26 +339,48 @@
                                     const uint8_t* in,
                                     unsigned in_len) {
   // |in| contains a sequence of 1-byte-length-prefixed values.
-  // We currently simply return the first provided ALPN value.
-  // TODO(b/130164908) Act on ALPN.
+  *out_len = 0;  // Default: no ALPN selected.
+  *out = nullptr;
   if (in_len == 0) {
-    *out_len = 0;
-    *out = nullptr;
-    QUIC_DLOG(INFO) << "No ALPN provided";
-    return SSL_TLSEXT_ERR_OK;
+    QUIC_DLOG(ERROR) << "No ALPN provided by client";
+    return SSL_TLSEXT_ERR_NOACK;
   }
-  const uint8_t first_alpn_length = in[0];
-  if (static_cast<unsigned>(first_alpn_length) > in_len - 1) {
-    QUIC_LOG(ERROR) << "Failed to parse ALPN";
-    return SSL_TLSEXT_ERR_ALERT_FATAL;
+
+  std::string expected_alpn_string =
+      AlpnForVersion(session()->connection()->version());
+
+  CBS all_alpns;
+  CBS_init(&all_alpns, in, in_len);  // Borrows |in|; no copy.
+
+  while (CBS_len(&all_alpns) > 0) {
+    CBS alpn;
+    if (!CBS_get_u8_length_prefixed(&all_alpns, &alpn)) {
+      QUIC_DLOG(ERROR) << "Failed to parse ALPN length";
+      return SSL_TLSEXT_ERR_NOACK;
+    }
+    const size_t alpn_length = CBS_len(&alpn);
+    if (alpn_length >  // Cannot happen: u8 prefix.
+        static_cast<size_t>(std::numeric_limits<uint8_t>::max())) {
+      QUIC_BUG << "Parsed impossible ALPN length " << alpn_length;
+      return SSL_TLSEXT_ERR_NOACK;
+    }
+    if (alpn_length == 0) {  // RFC 7301 forbids empty names.
+      QUIC_DLOG(ERROR) << "Received invalid zero-length ALPN";
+      return SSL_TLSEXT_ERR_NOACK;
+    }
+    std::string alpn_string(reinterpret_cast<const char*>(CBS_data(&alpn)),
+                            alpn_length);
+    if (alpn_string == expected_alpn_string) {
+      QUIC_DLOG(INFO) << "Server selecting ALPN '" << alpn_string << "'";
+      *out_len = static_cast<uint8_t>(alpn_length);
+      *out = CBS_data(&alpn);  // Points into |in|.
+      valid_alpn_received_ = true;
+      return SSL_TLSEXT_ERR_OK;
+    }
   }
-  *out_len = first_alpn_length;
-  *out = in + 1;
-  QUIC_DLOG(INFO) << "Server selecting ALPN '"
-                  << QuicStringPiece(reinterpret_cast<const char*>(*out),
-                                     *out_len)
-                  << "'";
-  return SSL_TLSEXT_ERR_OK;
+
+  QUIC_DLOG(ERROR) << "No known ALPN provided by client";
+  return SSL_TLSEXT_ERR_NOACK;  // FinishHandshake will close.
 }
 
 }  // namespace quic
diff --git a/quic/core/tls_server_handshaker.h b/quic/core/tls_server_handshaker.h
index 71fe6dd..5ce699f 100644
--- a/quic/core/tls_server_handshaker.h
+++ b/quic/core/tls_server_handshaker.h
@@ -123,6 +123,7 @@
 
   bool encryption_established_ = false;
   bool handshake_confirmed_ = false;
+  bool valid_alpn_received_ = false;  // Set once SelectAlpn accepts an ALPN.
   QuicReferenceCountedPointer<QuicCryptoNegotiatedParameters>
       crypto_negotiated_params_;
   TlsServerConnection tls_connection_;