Divert all QUIC key derivation to use heapless functions. This allows us to eliminate the versions that use the heap. Protected by FLAGS_gfe2_reloadable_flag_heapless_key_derivation. PiperOrigin-RevId: 761109514
diff --git a/quiche/common/quiche_feature_flags_list.h b/quiche/common/quiche_feature_flags_list.h index 0c70217..2af1ea6 100755 --- a/quiche/common/quiche_feature_flags_list.h +++ b/quiche/common/quiche_feature_flags_list.h
@@ -36,6 +36,7 @@ QUICHE_FLAG(bool, quiche_reloadable_flag_quic_enable_version_rfcv2, false, false, "When true, support RFC9369.") QUICHE_FLAG(bool, quiche_reloadable_flag_quic_fin_before_completed_http_headers, false, true, "If true, close the connection with error if FIN is received before finish receiving the whole HTTP headers.") QUICHE_FLAG(bool, quiche_reloadable_flag_quic_fix_timeouts, true, true, "If true, postpone setting handshake timeout to infinite to handshake complete.") +QUICHE_FLAG(bool, quiche_reloadable_flag_quic_heapless_key_derivation, false, false, "If true, QUIC key derivation uses heapless crypto utils.") QUICHE_FLAG(bool, quiche_reloadable_flag_quic_heapless_obfuscator, false, true, "If true, generates QUIC initial obfuscators with no heap allocations.") QUICHE_FLAG(bool, quiche_reloadable_flag_quic_heapless_static_parser, false, false, "If true, stops parsing immediately on unknown version, to avoid a potential malloc when parsing the connection ID") QUICHE_FLAG(bool, quiche_reloadable_flag_quic_ignore_gquic_probing, true, true, "If true, QUIC server will not respond to gQUIC probing packet(PING + PADDING) but treat it as a regular packet.")
diff --git a/quiche/quic/core/crypto/crypto_utils.cc b/quiche/quic/core/crypto/crypto_utils.cc index 51ee57b..8aecaef 100644 --- a/quiche/quic/core/crypto/crypto_utils.cc +++ b/quiche/quic/core/crypto/crypto_utils.cc
@@ -50,6 +50,9 @@ namespace { +inline constexpr size_t kMaxKeySize = 32; +inline constexpr size_t kMaxIVSize = 12; + // Implements the HKDF-Expand-Label function as defined in section 7.1 of RFC // 8446. The HKDF-Expand-Label function takes 4 explicit arguments (Secret, // Label, Context, and Length), as well as implicit PRF which is the hash @@ -97,6 +100,10 @@ std::vector<uint8_t> HkdfExpandLabel(const EVP_MD* prf, absl::Span<const uint8_t> secret, const std::string& label, size_t out_len) { + if (GetQuicReloadableFlag(quic_heapless_key_derivation)) { + // This value should be zero; the flag should eliminate all paths to here. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 1, 10); + } bssl::ScopedCBB quic_hkdf_label; CBB inner_label; const char label_prefix[] = "tls13 "; @@ -158,6 +165,10 @@ // TODO(martinduke): Delete this. std::string getLabelForVersion(const ParsedQuicVersion& version, const absl::string_view& predicate) { + if (GetQuicReloadableFlag(quic_heapless_key_derivation)) { + // This value should be zero; the flag should eliminate all paths to here. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 2, 10); + } static_assert(SupportedVersions().size() == 4u, "Supported versions out of sync with HKDF labels"); if (version == ParsedQuicVersion::RFCv2()) { @@ -174,7 +185,7 @@ QuicCrypter* crypter) { QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_obfuscator, 4, 7); SetKeyAndIVHeapless(prf, pp_secret, version, crypter); - uint8_t header_protection_key[16]; + uint8_t header_protection_key[kMaxKeySize]; QUIC_BUG_IF(quic_bug_hp_length_mismatch, crypter->GetKeySize() > sizeof(header_protection_key)) << "HP length does not match crypter"; @@ -189,6 +200,10 @@ void CryptoUtils::InitializeCrypterSecrets( const EVP_MD* prf, const std::vector<uint8_t>& pp_secret, const ParsedQuicVersion& version, QuicCrypter* crypter) { + if (GetQuicReloadableFlag(quic_heapless_key_derivation)) { + // This value should be zero; the flag should eliminate all paths to here. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 3, 10); + } SetKeyAndIV(prf, pp_secret, version, crypter); std::vector<uint8_t> header_protection_key = GenerateHeaderProtectionKey( prf, pp_secret, version, crypter->GetKeySize()); @@ -203,7 +218,7 @@ const ParsedQuicVersion& version, QuicCrypter* crypter) { QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_obfuscator, 5, 7); - uint8_t key[16]; + uint8_t key[kMaxKeySize]; QUIC_BUG_IF(quic_bug_key_length_mismatch, crypter->GetKeySize() > sizeof(key)) << "Key length does not match crypter"; @@ -214,7 +229,7 @@ absl::Span<char>(version_label_raw, kMaxVersionLabelLength)); HkdfExpandLabel(prf, pp_secret, version_label, absl::Span<uint8_t>(key, crypter->GetKeySize())); - uint8_t iv[12]; + uint8_t iv[kMaxIVSize]; QUIC_BUG_IF(quic_bug_iv_length_mismatch, crypter->GetIVSize() > sizeof(iv)) << "IV length does not match crypter"; constexpr absl::string_view kIvPredicate = "iv"; @@ -233,6 +248,10 @@ absl::Span<const uint8_t> pp_secret, const ParsedQuicVersion& version, QuicCrypter* crypter) { + if (GetQuicReloadableFlag(quic_heapless_key_derivation)) { + // This value should be zero; the flag should eliminate all paths to here. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 4, 10); + } std::vector<uint8_t> key = HkdfExpandLabel(prf, pp_secret, getLabelForVersion(version, "key"), crypter->GetKeySize()); @@ -260,6 +279,10 @@ std::vector<uint8_t> CryptoUtils::GenerateHeaderProtectionKey( const EVP_MD* prf, absl::Span<const uint8_t> pp_secret, const ParsedQuicVersion& version, size_t out_len) { + if (GetQuicReloadableFlag(quic_heapless_key_derivation)) { + // This value should be zero; the flag should eliminate all paths to here. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 5, 10); + } return HkdfExpandLabel(prf, pp_secret, getLabelForVersion(version, "hp"), out_len); } @@ -280,6 +303,10 @@ std::vector<uint8_t> CryptoUtils::GenerateNextKeyPhaseSecret( const EVP_MD* prf, const ParsedQuicVersion& version, const std::vector<uint8_t>& current_secret) { + if (GetQuicReloadableFlag(quic_heapless_key_derivation)) { + // This value should be zero; the flag should eliminate all paths to here. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 6, 10); + } return HkdfExpandLabel(prf, current_secret, getLabelForVersion(version, "ku"), current_secret.size()); }
diff --git a/quiche/quic/core/crypto/crypto_utils_test.cc b/quiche/quic/core/crypto/crypto_utils_test.cc index e585d57..39862b7 100644 --- a/quiche/quic/core/crypto/crypto_utils_test.cc +++ b/quiche/quic/core/crypto/crypto_utils_test.cc
@@ -4,19 +4,20 @@ #include "quiche/quic/core/crypto/crypto_utils.h" +#include <cstdint> #include <memory> #include <string> +#include <vector> -#include "absl/base/macros.h" -#include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "openssl/err.h" #include "openssl/ssl.h" -#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/quic_versions.h" #include "quiche/quic/platform/api/quic_test.h" -#include "quiche/quic/test_tools/quic_test_utils.h" -#include "quiche/common/test_tools/quiche_test_utils.h" namespace quic { namespace test { @@ -364,6 +365,26 @@ ERR_clear_error(); } +// The heapless version of GenerateNextKeyPhaseSecret is sometimes called so +// that the output is written into the memory that contains the input. This +// test verifies that the result is no different than if the output goes +// elsewhere. +TEST_F(CryptoUtilsTest, NextKeyPhaseOnItself) { + std::vector<uint8_t> start_secret = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; + std::vector<uint8_t> separate_result; + separate_result.resize(start_secret.size()); + CryptoUtils::GenerateNextKeyPhaseSecret( + EVP_sha256(), ParsedQuicVersion::RFCv1(), + absl::Span<const uint8_t>(start_secret), + absl::Span<uint8_t>(separate_result)); + CryptoUtils::GenerateNextKeyPhaseSecret( + EVP_sha256(), ParsedQuicVersion::RFCv1(), + absl::Span<const uint8_t>(start_secret), + absl::Span<uint8_t>(start_secret)); + EXPECT_EQ(separate_result, start_secret); +} + } // namespace } // namespace test } // namespace quic
diff --git a/quiche/quic/core/tls_handshaker.cc b/quiche/quic/core/tls_handshaker.cc index 4a64c7c..9529230 100644 --- a/quiche/quic/core/tls_handshaker.cc +++ b/quiche/quic/core/tls_handshaker.cc
@@ -4,19 +4,21 @@ #include "quiche/quic/core/tls_handshaker.h" +#include <cstdint> #include <memory> #include <string> #include <utility> #include <vector> -#include "absl/base/macros.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "openssl/crypto.h" #include "openssl/ssl.h" +#include "quiche/quic/core/crypto/crypto_utils.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" #include "quiche/quic/core/quic_crypto_stream.h" #include "quiche/quic/platform/api/quic_bug_tracker.h" -#include "quiche/quic/platform/api/quic_stack_trace.h" namespace quic { @@ -283,13 +285,24 @@ std::unique_ptr<QuicEncrypter> encrypter = QuicEncrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher)); const EVP_MD* prf = Prf(cipher); - CryptoUtils::SetKeyAndIV(prf, write_secret, - handshaker_delegate_->parsed_version(), - encrypter.get()); - std::vector<uint8_t> header_protection_key = - CryptoUtils::GenerateHeaderProtectionKey( - prf, write_secret, handshaker_delegate_->parsed_version(), - encrypter->GetKeySize()); + std::vector<uint8_t> header_protection_key; + if (GetQuicReloadableFlag(quic_heapless_key_derivation)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 7, 10); + CryptoUtils::SetKeyAndIVHeapless(prf, write_secret, + handshaker_delegate_->parsed_version(), + encrypter.get()); + header_protection_key.resize(encrypter->GetKeySize()); + CryptoUtils::GenerateHeaderProtectionKey( + prf, write_secret, handshaker_delegate_->parsed_version(), + absl::Span<uint8_t>(header_protection_key)); + } else { + CryptoUtils::SetKeyAndIV(prf, write_secret, + handshaker_delegate_->parsed_version(), + encrypter.get()); + header_protection_key = CryptoUtils::GenerateHeaderProtectionKey( + prf, write_secret, handshaker_delegate_->parsed_version(), + encrypter->GetKeySize()); + } encrypter->SetHeaderProtectionKey( absl::string_view(reinterpret_cast<char*>(header_protection_key.data()), header_protection_key.size())); @@ -315,13 +328,24 @@ std::unique_ptr<QuicDecrypter> decrypter = QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher)); const EVP_MD* prf = Prf(cipher); - CryptoUtils::SetKeyAndIV(prf, read_secret, - handshaker_delegate_->parsed_version(), - decrypter.get()); - std::vector<uint8_t> header_protection_key = - CryptoUtils::GenerateHeaderProtectionKey( - prf, read_secret, handshaker_delegate_->parsed_version(), - decrypter->GetKeySize()); + std::vector<uint8_t> header_protection_key; + if (GetQuicReloadableFlag(quic_heapless_key_derivation)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 8, 10); + CryptoUtils::SetKeyAndIVHeapless(prf, read_secret, + handshaker_delegate_->parsed_version(), + decrypter.get()); + header_protection_key.resize(decrypter->GetKeySize()); + CryptoUtils::GenerateHeaderProtectionKey( + prf, read_secret, handshaker_delegate_->parsed_version(), + absl::Span<uint8_t>(header_protection_key)); + } else { + CryptoUtils::SetKeyAndIV(prf, read_secret, + handshaker_delegate_->parsed_version(), + decrypter.get()); + header_protection_key = CryptoUtils::GenerateHeaderProtectionKey( + prf, read_secret, handshaker_delegate_->parsed_version(), + decrypter->GetKeySize()); + } decrypter->SetHeaderProtectionKey( absl::string_view(reinterpret_cast<char*>(header_protection_key.data()), header_protection_key.size())); @@ -348,16 +372,30 @@ } const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl()); const EVP_MD* prf = Prf(cipher); - latest_read_secret_ = CryptoUtils::GenerateNextKeyPhaseSecret( - prf, handshaker_delegate_->parsed_version(), latest_read_secret_); - latest_write_secret_ = CryptoUtils::GenerateNextKeyPhaseSecret( - prf, handshaker_delegate_->parsed_version(), latest_write_secret_); + std::unique_ptr<QuicDecrypter> decrypter; + if (GetQuicReloadableFlag(quic_heapless_key_derivation)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 9, 10); + CryptoUtils::GenerateNextKeyPhaseSecret( + prf, handshaker_delegate_->parsed_version(), latest_read_secret_, + absl::Span<uint8_t>(latest_read_secret_)); + CryptoUtils::GenerateNextKeyPhaseSecret( + prf, handshaker_delegate_->parsed_version(), latest_write_secret_, + absl::Span<uint8_t>(latest_write_secret_)); + decrypter = QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher)); + CryptoUtils::SetKeyAndIVHeapless(prf, latest_read_secret_, + handshaker_delegate_->parsed_version(), + decrypter.get()); + } else { + latest_read_secret_ = CryptoUtils::GenerateNextKeyPhaseSecret( + prf, handshaker_delegate_->parsed_version(), latest_read_secret_); + latest_write_secret_ = CryptoUtils::GenerateNextKeyPhaseSecret( + prf, handshaker_delegate_->parsed_version(), latest_write_secret_); + decrypter = QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher)); + CryptoUtils::SetKeyAndIV(prf, latest_read_secret_, + handshaker_delegate_->parsed_version(), + decrypter.get()); + } - std::unique_ptr<QuicDecrypter> decrypter = - QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher)); - CryptoUtils::SetKeyAndIV(prf, latest_read_secret_, - handshaker_delegate_->parsed_version(), - decrypter.get()); decrypter->SetHeaderProtectionKey(absl::string_view( reinterpret_cast<char*>(one_rtt_read_header_protection_key_.data()), one_rtt_read_header_protection_key_.size())); @@ -376,9 +414,16 @@ const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl()); std::unique_ptr<QuicEncrypter> encrypter = QuicEncrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher)); - CryptoUtils::SetKeyAndIV(Prf(cipher), latest_write_secret_, - handshaker_delegate_->parsed_version(), - encrypter.get()); + if (GetQuicReloadableFlag(quic_heapless_key_derivation)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 10, 10); + CryptoUtils::SetKeyAndIVHeapless(Prf(cipher), latest_write_secret_, + handshaker_delegate_->parsed_version(), + encrypter.get()); + } else { + CryptoUtils::SetKeyAndIV(Prf(cipher), latest_write_secret_, + handshaker_delegate_->parsed_version(), + encrypter.get()); + } encrypter->SetHeaderProtectionKey(absl::string_view( reinterpret_cast<char*>(one_rtt_write_header_protection_key_.data()), one_rtt_write_header_protection_key_.size()));