Divert all QUIC key derivation to use heapless functions.

This allows us to eliminate the versions that use the heap.

Protected by FLAGS_gfe2_reloadable_flag_heapless_key_derivation.

PiperOrigin-RevId: 761109514
diff --git a/quiche/common/quiche_feature_flags_list.h b/quiche/common/quiche_feature_flags_list.h
index 0c70217..2af1ea6 100755
--- a/quiche/common/quiche_feature_flags_list.h
+++ b/quiche/common/quiche_feature_flags_list.h
@@ -36,6 +36,7 @@
 QUICHE_FLAG(bool, quiche_reloadable_flag_quic_enable_version_rfcv2, false, false, "When true, support RFC9369.")
 QUICHE_FLAG(bool, quiche_reloadable_flag_quic_fin_before_completed_http_headers, false, true, "If true, close the connection with error if FIN is received before finish receiving the whole HTTP headers.")
 QUICHE_FLAG(bool, quiche_reloadable_flag_quic_fix_timeouts, true, true, "If true, postpone setting handshake timeout to infinite to handshake complete.")
+QUICHE_FLAG(bool, quiche_reloadable_flag_quic_heapless_key_derivation, false, false, "If true, QUIC key derivation uses heapless crypto utils.")
 QUICHE_FLAG(bool, quiche_reloadable_flag_quic_heapless_obfuscator, false, true, "If true, generates QUIC initial obfuscators with no heap allocations.")
 QUICHE_FLAG(bool, quiche_reloadable_flag_quic_heapless_static_parser, false, false, "If true, stops parsing immediately on unknown version, to avoid a potential malloc when parsing the connection ID")
 QUICHE_FLAG(bool, quiche_reloadable_flag_quic_ignore_gquic_probing, true, true, "If true, QUIC server will not respond to gQUIC probing packet(PING + PADDING) but treat it as a regular packet.")
diff --git a/quiche/quic/core/crypto/crypto_utils.cc b/quiche/quic/core/crypto/crypto_utils.cc
index 51ee57b..8aecaef 100644
--- a/quiche/quic/core/crypto/crypto_utils.cc
+++ b/quiche/quic/core/crypto/crypto_utils.cc
@@ -50,6 +50,9 @@
 
 namespace {
 
+inline constexpr size_t kMaxKeySize = 32;
+inline constexpr size_t kMaxIVSize = 12;
+
 // Implements the HKDF-Expand-Label function as defined in section 7.1 of RFC
 // 8446. The HKDF-Expand-Label function takes 4 explicit arguments (Secret,
 // Label, Context, and Length), as well as implicit PRF which is the hash
@@ -97,6 +100,10 @@
 std::vector<uint8_t> HkdfExpandLabel(const EVP_MD* prf,
                                      absl::Span<const uint8_t> secret,
                                      const std::string& label, size_t out_len) {
+  if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+    // This value should be zero; the flag should eliminate all paths to here.
+    QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 1, 10);
+  }
   bssl::ScopedCBB quic_hkdf_label;
   CBB inner_label;
   const char label_prefix[] = "tls13 ";
@@ -158,6 +165,10 @@
 // TODO(martinduke): Delete this.
 std::string getLabelForVersion(const ParsedQuicVersion& version,
                                const absl::string_view& predicate) {
+  if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+    // This value should be zero; the flag should eliminate all paths to here.
+    QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 2, 10);
+  }
   static_assert(SupportedVersions().size() == 4u,
                 "Supported versions out of sync with HKDF labels");
   if (version == ParsedQuicVersion::RFCv2()) {
@@ -174,7 +185,7 @@
                                            QuicCrypter* crypter) {
   QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_obfuscator, 4, 7);
   SetKeyAndIVHeapless(prf, pp_secret, version, crypter);
-  uint8_t header_protection_key[16];
+  uint8_t header_protection_key[kMaxKeySize];
   QUIC_BUG_IF(quic_bug_hp_length_mismatch,
               crypter->GetKeySize() > sizeof(header_protection_key))
       << "HP length does not match crypter";
@@ -189,6 +200,10 @@
 void CryptoUtils::InitializeCrypterSecrets(
     const EVP_MD* prf, const std::vector<uint8_t>& pp_secret,
     const ParsedQuicVersion& version, QuicCrypter* crypter) {
+  if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+    // This value should be zero; the flag should eliminate all paths to here.
+    QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 3, 10);
+  }
   SetKeyAndIV(prf, pp_secret, version, crypter);
   std::vector<uint8_t> header_protection_key = GenerateHeaderProtectionKey(
       prf, pp_secret, version, crypter->GetKeySize());
@@ -203,7 +218,7 @@
                                       const ParsedQuicVersion& version,
                                       QuicCrypter* crypter) {
   QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_obfuscator, 5, 7);
-  uint8_t key[16];
+  uint8_t key[kMaxKeySize];
   QUIC_BUG_IF(quic_bug_key_length_mismatch, crypter->GetKeySize() > sizeof(key))
       << "Key length does not match crypter";
 
@@ -214,7 +229,7 @@
       absl::Span<char>(version_label_raw, kMaxVersionLabelLength));
   HkdfExpandLabel(prf, pp_secret, version_label,
                   absl::Span<uint8_t>(key, crypter->GetKeySize()));
-  uint8_t iv[12];
+  uint8_t iv[kMaxIVSize];
   QUIC_BUG_IF(quic_bug_iv_length_mismatch, crypter->GetIVSize() > sizeof(iv))
       << "IV length does not match crypter";
   constexpr absl::string_view kIvPredicate = "iv";
@@ -233,6 +248,10 @@
                               absl::Span<const uint8_t> pp_secret,
                               const ParsedQuicVersion& version,
                               QuicCrypter* crypter) {
+  if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+    // This value should be zero; the flag should eliminate all paths to here.
+    QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 4, 10);
+  }
   std::vector<uint8_t> key =
       HkdfExpandLabel(prf, pp_secret, getLabelForVersion(version, "key"),
                       crypter->GetKeySize());
@@ -260,6 +279,10 @@
 std::vector<uint8_t> CryptoUtils::GenerateHeaderProtectionKey(
     const EVP_MD* prf, absl::Span<const uint8_t> pp_secret,
     const ParsedQuicVersion& version, size_t out_len) {
+  if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+    // This value should be zero; the flag should eliminate all paths to here.
+    QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 5, 10);
+  }
   return HkdfExpandLabel(prf, pp_secret, getLabelForVersion(version, "hp"),
                          out_len);
 }
@@ -280,6 +303,10 @@
 std::vector<uint8_t> CryptoUtils::GenerateNextKeyPhaseSecret(
     const EVP_MD* prf, const ParsedQuicVersion& version,
     const std::vector<uint8_t>& current_secret) {
+  if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+    // This value should be zero; the flag should eliminate all paths to here.
+    QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 6, 10);
+  }
   return HkdfExpandLabel(prf, current_secret, getLabelForVersion(version, "ku"),
                          current_secret.size());
 }
diff --git a/quiche/quic/core/crypto/crypto_utils_test.cc b/quiche/quic/core/crypto/crypto_utils_test.cc
index e585d57..39862b7 100644
--- a/quiche/quic/core/crypto/crypto_utils_test.cc
+++ b/quiche/quic/core/crypto/crypto_utils_test.cc
@@ -4,19 +4,20 @@
 
 #include "quiche/quic/core/crypto/crypto_utils.h"
 
+#include <cstdint>
 #include <memory>
 #include <string>
+#include <vector>
 
-#include "absl/base/macros.h"
-#include "absl/strings/escaping.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/span.h"
 #include "openssl/err.h"
 #include "openssl/ssl.h"
-#include "quiche/quic/core/quic_utils.h"
+#include "quiche/quic/core/crypto/crypto_protocol.h"
+#include "quiche/quic/core/crypto/quic_encrypter.h"
+#include "quiche/quic/core/quic_versions.h"
 #include "quiche/quic/platform/api/quic_test.h"
-#include "quiche/quic/test_tools/quic_test_utils.h"
-#include "quiche/common/test_tools/quiche_test_utils.h"
 
 namespace quic {
 namespace test {
@@ -364,6 +365,26 @@
   ERR_clear_error();
 }
 
+// The heapless version of GenerateNextKeyPhaseSecret is sometimes called so
+// that the output is written into the memory that contains the input. This
+// test verifies that the result is no different than if the output goes
+// elsewhere.
+TEST_F(CryptoUtilsTest, NextKeyPhaseOnItself) {
+  std::vector<uint8_t> start_secret = {1, 2,  3,  4,  5,  6,  7,  8,
+                                       9, 10, 11, 12, 13, 14, 15, 16};
+  std::vector<uint8_t> separate_result;
+  separate_result.resize(start_secret.size());
+  CryptoUtils::GenerateNextKeyPhaseSecret(
+      EVP_sha256(), ParsedQuicVersion::RFCv1(),
+      absl::Span<const uint8_t>(start_secret),
+      absl::Span<uint8_t>(separate_result));
+  CryptoUtils::GenerateNextKeyPhaseSecret(
+      EVP_sha256(), ParsedQuicVersion::RFCv1(),
+      absl::Span<const uint8_t>(start_secret),
+      absl::Span<uint8_t>(start_secret));
+  EXPECT_EQ(separate_result, start_secret);
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic
diff --git a/quiche/quic/core/tls_handshaker.cc b/quiche/quic/core/tls_handshaker.cc
index 4a64c7c..9529230 100644
--- a/quiche/quic/core/tls_handshaker.cc
+++ b/quiche/quic/core/tls_handshaker.cc
@@ -4,19 +4,21 @@
 
 #include "quiche/quic/core/tls_handshaker.h"
 
+#include <cstdint>
 #include <memory>
 #include <string>
 #include <utility>
 #include <vector>
 
-#include "absl/base/macros.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/span.h"
 #include "openssl/crypto.h"
 #include "openssl/ssl.h"
+#include "quiche/quic/core/crypto/crypto_utils.h"
+#include "quiche/quic/core/crypto/quic_decrypter.h"
 #include "quiche/quic/core/quic_crypto_stream.h"
 #include "quiche/quic/platform/api/quic_bug_tracker.h"
-#include "quiche/quic/platform/api/quic_stack_trace.h"
 
 namespace quic {
 
@@ -283,13 +285,24 @@
   std::unique_ptr<QuicEncrypter> encrypter =
       QuicEncrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
   const EVP_MD* prf = Prf(cipher);
-  CryptoUtils::SetKeyAndIV(prf, write_secret,
-                           handshaker_delegate_->parsed_version(),
-                           encrypter.get());
-  std::vector<uint8_t> header_protection_key =
-      CryptoUtils::GenerateHeaderProtectionKey(
-          prf, write_secret, handshaker_delegate_->parsed_version(),
-          encrypter->GetKeySize());
+  std::vector<uint8_t> header_protection_key;
+  if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+    QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 7, 10);
+    CryptoUtils::SetKeyAndIVHeapless(prf, write_secret,
+                                     handshaker_delegate_->parsed_version(),
+                                     encrypter.get());
+    header_protection_key.resize(encrypter->GetKeySize());
+    CryptoUtils::GenerateHeaderProtectionKey(
+        prf, write_secret, handshaker_delegate_->parsed_version(),
+        absl::Span<uint8_t>(header_protection_key));
+  } else {
+    CryptoUtils::SetKeyAndIV(prf, write_secret,
+                             handshaker_delegate_->parsed_version(),
+                             encrypter.get());
+    header_protection_key = CryptoUtils::GenerateHeaderProtectionKey(
+        prf, write_secret, handshaker_delegate_->parsed_version(),
+        encrypter->GetKeySize());
+  }
   encrypter->SetHeaderProtectionKey(
       absl::string_view(reinterpret_cast<char*>(header_protection_key.data()),
                         header_protection_key.size()));
@@ -315,13 +328,24 @@
   std::unique_ptr<QuicDecrypter> decrypter =
       QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
   const EVP_MD* prf = Prf(cipher);
-  CryptoUtils::SetKeyAndIV(prf, read_secret,
-                           handshaker_delegate_->parsed_version(),
-                           decrypter.get());
-  std::vector<uint8_t> header_protection_key =
-      CryptoUtils::GenerateHeaderProtectionKey(
-          prf, read_secret, handshaker_delegate_->parsed_version(),
-          decrypter->GetKeySize());
+  std::vector<uint8_t> header_protection_key;
+  if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+    QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 8, 10);
+    CryptoUtils::SetKeyAndIVHeapless(prf, read_secret,
+                                     handshaker_delegate_->parsed_version(),
+                                     decrypter.get());
+    header_protection_key.resize(decrypter->GetKeySize());
+    CryptoUtils::GenerateHeaderProtectionKey(
+        prf, read_secret, handshaker_delegate_->parsed_version(),
+        absl::Span<uint8_t>(header_protection_key));
+  } else {
+    CryptoUtils::SetKeyAndIV(prf, read_secret,
+                             handshaker_delegate_->parsed_version(),
+                             decrypter.get());
+    header_protection_key = CryptoUtils::GenerateHeaderProtectionKey(
+        prf, read_secret, handshaker_delegate_->parsed_version(),
+        decrypter->GetKeySize());
+  }
   decrypter->SetHeaderProtectionKey(
       absl::string_view(reinterpret_cast<char*>(header_protection_key.data()),
                         header_protection_key.size()));
@@ -348,16 +372,30 @@
   }
   const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl());
   const EVP_MD* prf = Prf(cipher);
-  latest_read_secret_ = CryptoUtils::GenerateNextKeyPhaseSecret(
-      prf, handshaker_delegate_->parsed_version(), latest_read_secret_);
-  latest_write_secret_ = CryptoUtils::GenerateNextKeyPhaseSecret(
-      prf, handshaker_delegate_->parsed_version(), latest_write_secret_);
+  std::unique_ptr<QuicDecrypter> decrypter;
+  if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+    QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 9, 10);
+    CryptoUtils::GenerateNextKeyPhaseSecret(
+        prf, handshaker_delegate_->parsed_version(), latest_read_secret_,
+        absl::Span<uint8_t>(latest_read_secret_));
+    CryptoUtils::GenerateNextKeyPhaseSecret(
+        prf, handshaker_delegate_->parsed_version(), latest_write_secret_,
+        absl::Span<uint8_t>(latest_write_secret_));
+    decrypter = QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
+    CryptoUtils::SetKeyAndIVHeapless(prf, latest_read_secret_,
+                                     handshaker_delegate_->parsed_version(),
+                                     decrypter.get());
+  } else {
+    latest_read_secret_ = CryptoUtils::GenerateNextKeyPhaseSecret(
+        prf, handshaker_delegate_->parsed_version(), latest_read_secret_);
+    latest_write_secret_ = CryptoUtils::GenerateNextKeyPhaseSecret(
+        prf, handshaker_delegate_->parsed_version(), latest_write_secret_);
+    decrypter = QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
+    CryptoUtils::SetKeyAndIV(prf, latest_read_secret_,
+                             handshaker_delegate_->parsed_version(),
+                             decrypter.get());
+  }
 
-  std::unique_ptr<QuicDecrypter> decrypter =
-      QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
-  CryptoUtils::SetKeyAndIV(prf, latest_read_secret_,
-                           handshaker_delegate_->parsed_version(),
-                           decrypter.get());
   decrypter->SetHeaderProtectionKey(absl::string_view(
       reinterpret_cast<char*>(one_rtt_read_header_protection_key_.data()),
       one_rtt_read_header_protection_key_.size()));
@@ -376,9 +414,16 @@
   const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl());
   std::unique_ptr<QuicEncrypter> encrypter =
       QuicEncrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
-  CryptoUtils::SetKeyAndIV(Prf(cipher), latest_write_secret_,
-                           handshaker_delegate_->parsed_version(),
-                           encrypter.get());
+  if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+    QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 10, 10);
+    CryptoUtils::SetKeyAndIVHeapless(Prf(cipher), latest_write_secret_,
+                                     handshaker_delegate_->parsed_version(),
+                                     encrypter.get());
+  } else {
+    CryptoUtils::SetKeyAndIV(Prf(cipher), latest_write_secret_,
+                             handshaker_delegate_->parsed_version(),
+                             encrypter.get());
+  }
   encrypter->SetHeaderProtectionKey(absl::string_view(
       reinterpret_cast<char*>(one_rtt_write_header_protection_key_.data()),
       one_rtt_write_header_protection_key_.size()));