Divert all QUIC key derivation to use heapless functions.
This allows us to eliminate the versions that use the heap.
Protected by FLAGS_gfe2_reloadable_flag_heapless_key_derivation.
PiperOrigin-RevId: 761109514
diff --git a/quiche/common/quiche_feature_flags_list.h b/quiche/common/quiche_feature_flags_list.h
index 0c70217..2af1ea6 100755
--- a/quiche/common/quiche_feature_flags_list.h
+++ b/quiche/common/quiche_feature_flags_list.h
@@ -36,6 +36,7 @@
QUICHE_FLAG(bool, quiche_reloadable_flag_quic_enable_version_rfcv2, false, false, "When true, support RFC9369.")
QUICHE_FLAG(bool, quiche_reloadable_flag_quic_fin_before_completed_http_headers, false, true, "If true, close the connection with error if FIN is received before finish receiving the whole HTTP headers.")
QUICHE_FLAG(bool, quiche_reloadable_flag_quic_fix_timeouts, true, true, "If true, postpone setting handshake timeout to infinite to handshake complete.")
+QUICHE_FLAG(bool, quiche_reloadable_flag_quic_heapless_key_derivation, false, false, "If true, QUIC key derivation uses heapless crypto utils.")
QUICHE_FLAG(bool, quiche_reloadable_flag_quic_heapless_obfuscator, false, true, "If true, generates QUIC initial obfuscators with no heap allocations.")
QUICHE_FLAG(bool, quiche_reloadable_flag_quic_heapless_static_parser, false, false, "If true, stops parsing immediately on unknown version, to avoid a potential malloc when parsing the connection ID")
QUICHE_FLAG(bool, quiche_reloadable_flag_quic_ignore_gquic_probing, true, true, "If true, QUIC server will not respond to gQUIC probing packet(PING + PADDING) but treat it as a regular packet.")
diff --git a/quiche/quic/core/crypto/crypto_utils.cc b/quiche/quic/core/crypto/crypto_utils.cc
index 51ee57b..8aecaef 100644
--- a/quiche/quic/core/crypto/crypto_utils.cc
+++ b/quiche/quic/core/crypto/crypto_utils.cc
@@ -50,6 +50,9 @@
namespace {
+inline constexpr size_t kMaxKeySize = 32;
+inline constexpr size_t kMaxIVSize = 12;
+
// Implements the HKDF-Expand-Label function as defined in section 7.1 of RFC
// 8446. The HKDF-Expand-Label function takes 4 explicit arguments (Secret,
// Label, Context, and Length), as well as implicit PRF which is the hash
@@ -97,6 +100,10 @@
std::vector<uint8_t> HkdfExpandLabel(const EVP_MD* prf,
absl::Span<const uint8_t> secret,
const std::string& label, size_t out_len) {
+ if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+ // This value should be zero; the flag should eliminate all paths to here.
+ QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 1, 10);
+ }
bssl::ScopedCBB quic_hkdf_label;
CBB inner_label;
const char label_prefix[] = "tls13 ";
@@ -158,6 +165,10 @@
// TODO(martinduke): Delete this.
std::string getLabelForVersion(const ParsedQuicVersion& version,
const absl::string_view& predicate) {
+ if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+ // This value should be zero; the flag should eliminate all paths to here.
+ QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 2, 10);
+ }
static_assert(SupportedVersions().size() == 4u,
"Supported versions out of sync with HKDF labels");
if (version == ParsedQuicVersion::RFCv2()) {
@@ -174,7 +185,7 @@
QuicCrypter* crypter) {
QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_obfuscator, 4, 7);
SetKeyAndIVHeapless(prf, pp_secret, version, crypter);
- uint8_t header_protection_key[16];
+ uint8_t header_protection_key[kMaxKeySize];
QUIC_BUG_IF(quic_bug_hp_length_mismatch,
crypter->GetKeySize() > sizeof(header_protection_key))
<< "HP length does not match crypter";
@@ -189,6 +200,10 @@
void CryptoUtils::InitializeCrypterSecrets(
const EVP_MD* prf, const std::vector<uint8_t>& pp_secret,
const ParsedQuicVersion& version, QuicCrypter* crypter) {
+ if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+ // This value should be zero; the flag should eliminate all paths to here.
+ QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 3, 10);
+ }
SetKeyAndIV(prf, pp_secret, version, crypter);
std::vector<uint8_t> header_protection_key = GenerateHeaderProtectionKey(
prf, pp_secret, version, crypter->GetKeySize());
@@ -203,7 +218,7 @@
const ParsedQuicVersion& version,
QuicCrypter* crypter) {
QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_obfuscator, 5, 7);
- uint8_t key[16];
+ uint8_t key[kMaxKeySize];
QUIC_BUG_IF(quic_bug_key_length_mismatch, crypter->GetKeySize() > sizeof(key))
<< "Key length does not match crypter";
@@ -214,7 +229,7 @@
absl::Span<char>(version_label_raw, kMaxVersionLabelLength));
HkdfExpandLabel(prf, pp_secret, version_label,
absl::Span<uint8_t>(key, crypter->GetKeySize()));
- uint8_t iv[12];
+ uint8_t iv[kMaxIVSize];
QUIC_BUG_IF(quic_bug_iv_length_mismatch, crypter->GetIVSize() > sizeof(iv))
<< "IV length does not match crypter";
constexpr absl::string_view kIvPredicate = "iv";
@@ -233,6 +248,10 @@
absl::Span<const uint8_t> pp_secret,
const ParsedQuicVersion& version,
QuicCrypter* crypter) {
+ if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+ // This value should be zero; the flag should eliminate all paths to here.
+ QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 4, 10);
+ }
std::vector<uint8_t> key =
HkdfExpandLabel(prf, pp_secret, getLabelForVersion(version, "key"),
crypter->GetKeySize());
@@ -260,6 +279,10 @@
std::vector<uint8_t> CryptoUtils::GenerateHeaderProtectionKey(
const EVP_MD* prf, absl::Span<const uint8_t> pp_secret,
const ParsedQuicVersion& version, size_t out_len) {
+ if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+ // This value should be zero; the flag should eliminate all paths to here.
+ QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 5, 10);
+ }
return HkdfExpandLabel(prf, pp_secret, getLabelForVersion(version, "hp"),
out_len);
}
@@ -280,6 +303,10 @@
std::vector<uint8_t> CryptoUtils::GenerateNextKeyPhaseSecret(
const EVP_MD* prf, const ParsedQuicVersion& version,
const std::vector<uint8_t>& current_secret) {
+ if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+ // This value should be zero; the flag should eliminate all paths to here.
+ QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 6, 10);
+ }
return HkdfExpandLabel(prf, current_secret, getLabelForVersion(version, "ku"),
current_secret.size());
}
diff --git a/quiche/quic/core/crypto/crypto_utils_test.cc b/quiche/quic/core/crypto/crypto_utils_test.cc
index e585d57..39862b7 100644
--- a/quiche/quic/core/crypto/crypto_utils_test.cc
+++ b/quiche/quic/core/crypto/crypto_utils_test.cc
@@ -4,19 +4,20 @@
#include "quiche/quic/core/crypto/crypto_utils.h"
+#include <cstdint>
#include <memory>
#include <string>
+#include <vector>
-#include "absl/base/macros.h"
-#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "openssl/err.h"
#include "openssl/ssl.h"
-#include "quiche/quic/core/quic_utils.h"
+#include "quiche/quic/core/crypto/crypto_protocol.h"
+#include "quiche/quic/core/crypto/quic_encrypter.h"
+#include "quiche/quic/core/quic_versions.h"
#include "quiche/quic/platform/api/quic_test.h"
-#include "quiche/quic/test_tools/quic_test_utils.h"
-#include "quiche/common/test_tools/quiche_test_utils.h"
namespace quic {
namespace test {
@@ -364,6 +365,26 @@
ERR_clear_error();
}
+// The heapless version of GenerateNextKeyPhaseSecret is sometimes called so
+// that the output is written into the memory that contains the input. This
+// test verifies that the result is no different than if the output goes
+// elsewhere.
+TEST_F(CryptoUtilsTest, NextKeyPhaseOnItself) {
+ std::vector<uint8_t> start_secret = {1, 2, 3, 4, 5, 6, 7, 8,
+ 9, 10, 11, 12, 13, 14, 15, 16};
+ std::vector<uint8_t> separate_result;
+ separate_result.resize(start_secret.size());
+ CryptoUtils::GenerateNextKeyPhaseSecret(
+ EVP_sha256(), ParsedQuicVersion::RFCv1(),
+ absl::Span<const uint8_t>(start_secret),
+ absl::Span<uint8_t>(separate_result));
+ CryptoUtils::GenerateNextKeyPhaseSecret(
+ EVP_sha256(), ParsedQuicVersion::RFCv1(),
+ absl::Span<const uint8_t>(start_secret),
+ absl::Span<uint8_t>(start_secret));
+ EXPECT_EQ(separate_result, start_secret);
+}
+
} // namespace
} // namespace test
} // namespace quic
diff --git a/quiche/quic/core/tls_handshaker.cc b/quiche/quic/core/tls_handshaker.cc
index 4a64c7c..9529230 100644
--- a/quiche/quic/core/tls_handshaker.cc
+++ b/quiche/quic/core/tls_handshaker.cc
@@ -4,19 +4,21 @@
#include "quiche/quic/core/tls_handshaker.h"
+#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
-#include "absl/base/macros.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "openssl/crypto.h"
#include "openssl/ssl.h"
+#include "quiche/quic/core/crypto/crypto_utils.h"
+#include "quiche/quic/core/crypto/quic_decrypter.h"
#include "quiche/quic/core/quic_crypto_stream.h"
#include "quiche/quic/platform/api/quic_bug_tracker.h"
-#include "quiche/quic/platform/api/quic_stack_trace.h"
namespace quic {
@@ -283,13 +285,24 @@
std::unique_ptr<QuicEncrypter> encrypter =
QuicEncrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
const EVP_MD* prf = Prf(cipher);
- CryptoUtils::SetKeyAndIV(prf, write_secret,
- handshaker_delegate_->parsed_version(),
- encrypter.get());
- std::vector<uint8_t> header_protection_key =
- CryptoUtils::GenerateHeaderProtectionKey(
- prf, write_secret, handshaker_delegate_->parsed_version(),
- encrypter->GetKeySize());
+ std::vector<uint8_t> header_protection_key;
+ if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+ QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 7, 10);
+ CryptoUtils::SetKeyAndIVHeapless(prf, write_secret,
+ handshaker_delegate_->parsed_version(),
+ encrypter.get());
+ header_protection_key.resize(encrypter->GetKeySize());
+ CryptoUtils::GenerateHeaderProtectionKey(
+ prf, write_secret, handshaker_delegate_->parsed_version(),
+ absl::Span<uint8_t>(header_protection_key));
+ } else {
+ CryptoUtils::SetKeyAndIV(prf, write_secret,
+ handshaker_delegate_->parsed_version(),
+ encrypter.get());
+ header_protection_key = CryptoUtils::GenerateHeaderProtectionKey(
+ prf, write_secret, handshaker_delegate_->parsed_version(),
+ encrypter->GetKeySize());
+ }
encrypter->SetHeaderProtectionKey(
absl::string_view(reinterpret_cast<char*>(header_protection_key.data()),
header_protection_key.size()));
@@ -315,13 +328,24 @@
std::unique_ptr<QuicDecrypter> decrypter =
QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
const EVP_MD* prf = Prf(cipher);
- CryptoUtils::SetKeyAndIV(prf, read_secret,
- handshaker_delegate_->parsed_version(),
- decrypter.get());
- std::vector<uint8_t> header_protection_key =
- CryptoUtils::GenerateHeaderProtectionKey(
- prf, read_secret, handshaker_delegate_->parsed_version(),
- decrypter->GetKeySize());
+ std::vector<uint8_t> header_protection_key;
+ if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+ QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 8, 10);
+ CryptoUtils::SetKeyAndIVHeapless(prf, read_secret,
+ handshaker_delegate_->parsed_version(),
+ decrypter.get());
+ header_protection_key.resize(decrypter->GetKeySize());
+ CryptoUtils::GenerateHeaderProtectionKey(
+ prf, read_secret, handshaker_delegate_->parsed_version(),
+ absl::Span<uint8_t>(header_protection_key));
+ } else {
+ CryptoUtils::SetKeyAndIV(prf, read_secret,
+ handshaker_delegate_->parsed_version(),
+ decrypter.get());
+ header_protection_key = CryptoUtils::GenerateHeaderProtectionKey(
+ prf, read_secret, handshaker_delegate_->parsed_version(),
+ decrypter->GetKeySize());
+ }
decrypter->SetHeaderProtectionKey(
absl::string_view(reinterpret_cast<char*>(header_protection_key.data()),
header_protection_key.size()));
@@ -348,16 +372,30 @@
}
const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl());
const EVP_MD* prf = Prf(cipher);
- latest_read_secret_ = CryptoUtils::GenerateNextKeyPhaseSecret(
- prf, handshaker_delegate_->parsed_version(), latest_read_secret_);
- latest_write_secret_ = CryptoUtils::GenerateNextKeyPhaseSecret(
- prf, handshaker_delegate_->parsed_version(), latest_write_secret_);
+ std::unique_ptr<QuicDecrypter> decrypter;
+ if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+ QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 9, 10);
+ CryptoUtils::GenerateNextKeyPhaseSecret(
+ prf, handshaker_delegate_->parsed_version(), latest_read_secret_,
+ absl::Span<uint8_t>(latest_read_secret_));
+ CryptoUtils::GenerateNextKeyPhaseSecret(
+ prf, handshaker_delegate_->parsed_version(), latest_write_secret_,
+ absl::Span<uint8_t>(latest_write_secret_));
+ decrypter = QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
+ CryptoUtils::SetKeyAndIVHeapless(prf, latest_read_secret_,
+ handshaker_delegate_->parsed_version(),
+ decrypter.get());
+ } else {
+ latest_read_secret_ = CryptoUtils::GenerateNextKeyPhaseSecret(
+ prf, handshaker_delegate_->parsed_version(), latest_read_secret_);
+ latest_write_secret_ = CryptoUtils::GenerateNextKeyPhaseSecret(
+ prf, handshaker_delegate_->parsed_version(), latest_write_secret_);
+ decrypter = QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
+ CryptoUtils::SetKeyAndIV(prf, latest_read_secret_,
+ handshaker_delegate_->parsed_version(),
+ decrypter.get());
+ }
- std::unique_ptr<QuicDecrypter> decrypter =
- QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
- CryptoUtils::SetKeyAndIV(prf, latest_read_secret_,
- handshaker_delegate_->parsed_version(),
- decrypter.get());
decrypter->SetHeaderProtectionKey(absl::string_view(
reinterpret_cast<char*>(one_rtt_read_header_protection_key_.data()),
one_rtt_read_header_protection_key_.size()));
@@ -376,9 +414,16 @@
const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl());
std::unique_ptr<QuicEncrypter> encrypter =
QuicEncrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher));
- CryptoUtils::SetKeyAndIV(Prf(cipher), latest_write_secret_,
- handshaker_delegate_->parsed_version(),
- encrypter.get());
+ if (GetQuicReloadableFlag(quic_heapless_key_derivation)) {
+ QUIC_RELOADABLE_FLAG_COUNT_N(quic_heapless_key_derivation, 10, 10);
+ CryptoUtils::SetKeyAndIVHeapless(Prf(cipher), latest_write_secret_,
+ handshaker_delegate_->parsed_version(),
+ encrypter.get());
+ } else {
+ CryptoUtils::SetKeyAndIV(Prf(cipher), latest_write_secret_,
+ handshaker_delegate_->parsed_version(),
+ encrypter.get());
+ }
encrypter->SetHeaderProtectionKey(absl::string_view(
reinterpret_cast<char*>(one_rtt_write_header_protection_key_.data()),
one_rtt_write_header_protection_key_.size()));