Reduce memcpy in four-pass connection ID encryption/decryption. Not used in production. PiperOrigin-RevId: 597301121
diff --git a/quiche/quic/load_balancer/load_balancer_config.cc b/quiche/quic/load_balancer/load_balancer_config.cc index 6f25c90..00dd5de 100644 --- a/quiche/quic/load_balancer/load_balancer_config.cc +++ b/quiche/quic/load_balancer/load_balancer_config.cc
@@ -8,10 +8,13 @@ #include <cstring> #include <optional> +#include "absl/numeric/int128.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "openssl/aes.h" #include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/load_balancer/load_balancer_server_id.h" #include "quiche/quic/platform/api/quic_bug_tracker.h" namespace quic { @@ -54,66 +57,6 @@ return raw_key; } -// Functions to handle 4-pass encryption/decryption. -// TakePlaintextFrom{Left,Right}() reads the left or right half of 'from' and -// expands it into a full encryption block ('to') in accordance with the -// internet-draft. -void TakePlaintextFromLeft(const uint8_t *from, const uint8_t plaintext_len, - const uint8_t index, uint8_t *to) { - uint8_t half = plaintext_len / 2; - - to[0] = plaintext_len; - to[1] = index; - memcpy(to + 2, from, half); - if (plaintext_len % 2) { - to[2 + half] = from[half] & 0xf0; - half++; - } - memset(to + 2 + half, 0, kLoadBalancerBlockSize - 2 - half); -} - -void TakePlaintextFromRight(const uint8_t *from, const uint8_t plaintext_len, - const uint8_t index, uint8_t *to) { - uint8_t half = plaintext_len / 2; - - to[0] = plaintext_len; - to[1] = index; - memcpy(to + 2, from + half, half + (plaintext_len % 2)); - if (plaintext_len % 2) { - to[2] &= 0x0f; - half++; - } - memset(to + 2 + half, 0, kLoadBalancerBlockSize - 2 - half); -} - -// CiphertextXorWith{Left,Right}() takes the relevant end of the ciphertext in -// 'from' and XORs it with half of the ConnectionId stored at 'to', in -// accordance with the internet-draft. -void CiphertextXorWithLeft(const uint8_t *from, const uint8_t plaintext_len, - uint8_t *to) { - uint8_t half = plaintext_len / 2; - for (int i = 0; i < half; i++) { - to[i] ^= from[i]; - } - if (plaintext_len % 2) { - to[half] ^= (from[half] & 0xf0); - } -} - -void CiphertextXorWithRight(const uint8_t *from, const uint8_t plaintext_len, - uint8_t *to) { - uint8_t half = plaintext_len / 2; - int i = 0; - if (plaintext_len % 2) { - to[half] ^= (from[0] & 0x0f); - i++; - } - while ((half + i) < plaintext_len) { - to[half + i] ^= from[i]; - i++; - } -} - } // namespace std::optional<LoadBalancerConfig> LoadBalancerConfig::Create( @@ -148,47 +91,103 @@ : std::optional<LoadBalancerConfig>(); } -bool LoadBalancerConfig::EncryptionPass(absl::Span<uint8_t> target, - const uint8_t index) const { - uint8_t buf[kLoadBalancerBlockSize]; - if (!key_.has_value() || target.size() < plaintext_len()) { - return false; +LoadBalancerServerId LoadBalancerConfig::Decrypt( + absl::Span<const uint8_t> ciphertext) const { + if (ciphertext.length() < total_len()) { + return LoadBalancerServerId(); } - if (index % 2) { // Odd indices go from left to right - TakePlaintextFromLeft(target.data(), plaintext_len(), index, buf); - } else { - TakePlaintextFromRight(target.data(), plaintext_len(), index, buf); - } - if (!BlockEncrypt(buf, buf)) { - return false; - } - // XOR bits over the correct half. - if (index % 2) { - CiphertextXorWithRight(buf, plaintext_len(), target.data()); - } else { - CiphertextXorWithLeft(buf, plaintext_len(), target.data()); - } - return true; -} - -bool LoadBalancerConfig::BlockEncrypt( - const uint8_t plaintext[kLoadBalancerBlockSize], - uint8_t ciphertext[kLoadBalancerBlockSize]) const { if (!key_.has_value()) { - return false; + return LoadBalancerServerId( + absl::Span<const uint8_t>(ciphertext.data() + 1, server_id_len_)); } - AES_encrypt(plaintext, ciphertext, &*key_); - return true; + if (plaintext_len() == kLoadBalancerBlockSize) { + if (!block_decrypt_key_.has_value()) { + QUIC_BUG(quic_bug_596735037_01) << "Block decrypt key is not set."; + return LoadBalancerServerId(); + } + uint8_t plaintext[kLoadBalancerBlockSize]; + AES_decrypt(ciphertext.subspan(1, kLoadBalancerBlockSize).data(), plaintext, + &*block_decrypt_key_); + return LoadBalancerServerId( + absl::Span<const uint8_t>(plaintext, server_id_len_)); + } + // Do 3 or 4 passes. Only 3 are necessary if the server_id is short enough + // to fit in the first half of the connection ID (the decoder doesn't need + // to extract the nonce). + uint8_t left[kLoadBalancerBlockSize]; + uint8_t right[kLoadBalancerBlockSize]; + uint8_t half_len; // half the length of the plaintext, rounded up + bool is_length_odd = + InitializeFourPass(ciphertext.data(), left, right, &half_len); + uint8_t end_index = (server_id_len_ > nonce_len_) ? 1 : 2; + for (uint8_t index = kNumLoadBalancerCryptoPasses; index >= end_index; + --index) { + // Encrypt left/right and xor the result with right/left, respectively. + EncryptionPass(index, half_len, is_length_odd, left, right); + } + // Consolidate left and right into a server ID with minimum copying. + if (server_id_len_ < half_len || + (server_id_len_ == half_len && !is_length_odd)) { + // There is no half-byte to handle + return LoadBalancerServerId(absl::Span<uint8_t>(&left[2], server_id_len_)); + } + if (is_length_odd) { + right[2] |= left[half_len-- + 1]; // Combine the halves of the odd byte. + } + return LoadBalancerServerId( + absl::Span<uint8_t>(&left[2], half_len), + absl::Span<uint8_t>(&right[2], server_id_len_ - half_len)); } -bool LoadBalancerConfig::BlockDecrypt( - const uint8_t ciphertext[kLoadBalancerBlockSize], - uint8_t plaintext[kLoadBalancerBlockSize]) const { - if (!block_decrypt_key_.has_value()) { - return false; +QuicConnectionId LoadBalancerConfig::Encrypt( + absl::Span<uint8_t> connection_id) const { + if (connection_id.length() < total_len()) { + return QuicConnectionId(); } - AES_decrypt(ciphertext, plaintext, &*block_decrypt_key_); - return true; + if (!key_.has_value()) { // Plaintext connection ID + // Fill the nonce field with a hash of the Connection ID to avoid the nonce + // visibly increasing by one. This would allow observers to correlate + // connection IDs as being sequential and likely from the same connection, + // not just the same server. + absl::uint128 nonce_hash = QuicUtils::FNV1a_128_Hash(absl::string_view( + reinterpret_cast<char*>(connection_id.data()), connection_id.length())); + const uint64_t lo = absl::Uint128Low64(nonce_hash); + if (nonce_len_ <= sizeof(uint64_t)) { + memcpy(connection_id.data() + 1 + server_id_len_, &lo, nonce_len_); + return QuicConnectionId(connection_id); + } + memcpy(connection_id.data() + 1 + server_id_len_, &lo, sizeof(uint64_t)); + const uint64_t hi = absl::Uint128High64(nonce_hash); + memcpy(connection_id.data() + 1 + server_id_len_ + sizeof(uint64_t), &hi, + nonce_len_ - sizeof(uint64_t)); + return QuicConnectionId(connection_id); + } + if (plaintext_len() == kLoadBalancerBlockSize) { + AES_encrypt(connection_id.subspan(1, plaintext_len()).data(), + connection_id.data() + 1, &*key_); + return QuicConnectionId(connection_id); + } + // 4 Pass Encryption + uint8_t left[kLoadBalancerBlockSize]; + uint8_t right[kLoadBalancerBlockSize]; + uint8_t half_len; // half the length of the plaintext, rounded up + bool is_length_odd = + InitializeFourPass(connection_id.data(), left, right, &half_len); + for (uint8_t index = 1; index <= kNumLoadBalancerCryptoPasses; ++index) { + EncryptionPass(index, half_len, is_length_odd, left, right); + } + // Consolidate left and right into a server ID with minimum copying. + if (is_length_odd) { + // Combine the halves of the odd byte. + left[half_len + 1] |= right[2]; + } + memcpy(connection_id.data() + 1, &left[2], half_len); + if (is_length_odd) { + memcpy(connection_id.data() + 1 + half_len, &right[3], half_len - 1); + } else { + memcpy(connection_id.data() + 1 + half_len, &right[2], half_len); + } + return QuicConnectionId(connection_id); } LoadBalancerConfig::LoadBalancerConfig(const uint8_t config_id, @@ -203,4 +202,62 @@ ? BuildKey(key, /* encrypt = */ false) : std::optional<AES_KEY>()) {} +bool LoadBalancerConfig::InitializeFourPass(const uint8_t* input, uint8_t* left, + uint8_t* right, + uint8_t* half_len) const { + *half_len = plaintext_len() / 2; + bool is_length_odd; + if (plaintext_len() % 2 == 1) { + ++(*half_len); + is_length_odd = true; + } else { + is_length_odd = false; + } + bzero(left, kLoadBalancerBlockSize); + bzero(right, kLoadBalancerBlockSize); + // The first byte is the plaintext/ciphertext length, the second byte will be + // the index of the pass. Half the plaintext or ciphertext follows. + left[0] = plaintext_len(); + right[0] = plaintext_len(); + // Leave left_[1], right_[1] as zero. It will be set for each pass. + memcpy(&left[2], input + 1, *half_len); + // If is_length_odd, then both left and right will have part of the middle + // byte. Then that middle byte will be split in half via the bitmask in the + // next step. + memcpy(&right[2], input + (plaintext_len() / 2) + 1, *half_len); + if (is_length_odd) { + left[*half_len + 1] &= 0xf0; + right[2] &= 0x0f; + } + return is_length_odd; +} + +void LoadBalancerConfig::EncryptionPass(uint8_t index, uint8_t half_len, + bool is_length_odd, uint8_t* left, + uint8_t* right) const { + uint8_t ciphertext[kLoadBalancerBlockSize]; + if (index % 2 == 0) { // Go right to left. + right[1] = index; + AES_encrypt(right, ciphertext, &*key_); + for (int i = 0; i < half_len; ++i) { + // Skip over the first two bytes, which have the plaintext_len and the + // index. The CID bits are in [2, half_len - 1]. + left[2 + i] ^= ciphertext[i]; + } + if (is_length_odd) { + left[half_len + 1] &= 0xf0; + } + return; + } + // Go left to right. + left[1] = index; + AES_encrypt(left, ciphertext, &*key_); + for (int i = 0; i < half_len; ++i) { + right[2 + i] ^= ciphertext[i]; + } + if (is_length_odd) { + right[2] &= 0x0f; + } +} + } // namespace quic
diff --git a/quiche/quic/load_balancer/load_balancer_config.h b/quiche/quic/load_balancer/load_balancer_config.h index f3ca75c..a3a5ca9 100644 --- a/quiche/quic/load_balancer/load_balancer_config.h +++ b/quiche/quic/load_balancer/load_balancer_config.h
@@ -11,6 +11,8 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "openssl/aes.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/load_balancer/load_balancer_server_id.h" #include "quiche/quic/platform/api/quic_export.h" namespace quic { @@ -55,21 +57,13 @@ static std::optional<LoadBalancerConfig> CreateUnencrypted( uint8_t config_id, uint8_t server_id_len, uint8_t nonce_len); - // Handles one pass of 4-pass encryption. Encoder and decoder use of this - // function varies substantially, so they are not implemented here. - // Returns false if the config is not encrypted, or if |target| isn't long - // enough. - ABSL_MUST_USE_RESULT bool EncryptionPass(absl::Span<uint8_t> target, - uint8_t index) const; - // Use the key to do a block encryption, which is used both in all cases of - // encrypted configs. Returns false if there's no key. - ABSL_MUST_USE_RESULT bool BlockEncrypt( - const uint8_t plaintext[kLoadBalancerBlockSize], - uint8_t ciphertext[kLoadBalancerBlockSize]) const; - // Returns false if the config does not require block decryption. - ABSL_MUST_USE_RESULT bool BlockDecrypt( - const uint8_t ciphertext[kLoadBalancerBlockSize], - uint8_t plaintext[kLoadBalancerBlockSize]) const; + // Returns an invalid Server ID if ciphertext is too small, or needed keys are + // missing. |ciphertext| contains the full connection ID. + LoadBalancerServerId Decrypt(absl::Span<const uint8_t> ciphertext) const; + // Encrypts |connection_id|, which must be of the form first byte, + // server ID, nonce. Returns empty if plaintext is not long enough. The + // argument is NOT const, and will be overwritten. + QuicConnectionId Encrypt(absl::Span<uint8_t> connection_id) const; uint8_t config_id() const { return config_id_; } uint8_t server_id_len() const { return server_id_len_; } @@ -85,6 +79,15 @@ LoadBalancerConfig(uint8_t config_id, uint8_t server_id_len, uint8_t nonce_len, absl::string_view key); + // Initialize state for 4-pass encryption passes, using the connection ID + // provided in |input|. Returns true if the plaintext is an odd number of + // bytes. |half_len| is half the length of the plaintext, rounded up. + bool InitializeFourPass(const uint8_t* input, uint8_t* left, uint8_t* right, + uint8_t* half_len) const; + // Handles one pass of 4-pass encryption for both encrypt and decrypt. + void EncryptionPass(uint8_t index, uint8_t half_len, bool is_length_odd, + uint8_t* left, uint8_t* right) const; + uint8_t config_id_; uint8_t server_id_len_; uint8_t nonce_len_;
diff --git a/quiche/quic/load_balancer/load_balancer_config_test.cc b/quiche/quic/load_balancer/load_balancer_config_test.cc index 071815e..2c52b98 100644 --- a/quiche/quic/load_balancer/load_balancer_config_test.cc +++ b/quiche/quic/load_balancer/load_balancer_config_test.cc
@@ -4,12 +4,13 @@ #include "quiche/quic/load_balancer/load_balancer_config.h" -#include <array> #include <cstdint> #include <cstring> #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/load_balancer/load_balancer_server_id.h" #include "quiche/quic/platform/api/quic_expect_bug.h" #include "quiche/quic/platform/api/quic_test.h" @@ -86,103 +87,69 @@ EXPECT_TRUE(config2->IsEncrypted()); } -// Compare EncryptionPass() results to the example in -// draft-ietf-quic-load-balancers-15, Section 4.3.2. -TEST_F(LoadBalancerConfigTest, TestEncryptionPassExample) { - auto config = - LoadBalancerConfig::Create(0, 3, 4, absl::string_view(raw_key, 16)); - EXPECT_TRUE(config.has_value()); - EXPECT_TRUE(config->IsEncrypted()); - std::array<uint8_t, 7> bytes = {0x31, 0x44, 0x1a, 0x9c, 0x69, 0xc2, 0x75}; - std::array<uint8_t, 7> pass1 = {0x31, 0x44, 0x1a, 0x9f, 0x1a, 0x5b, 0x6b}; - std::array<uint8_t, 7> pass2 = {0x02, 0x8e, 0x1b, 0x5f, 0x1a, 0x5b, 0x6b}; - std::array<uint8_t, 7> pass3 = {0x02, 0x8e, 0x1b, 0x54, 0x94, 0x97, 0x62}; - std::array<uint8_t, 7> pass4 = {0x8e, 0x9a, 0x91, 0xf4, 0x94, 0x97, 0x62}; - - // Input is too short. - EXPECT_FALSE(config->EncryptionPass(absl::Span<uint8_t>(bytes.data(), 6), 0)); - EXPECT_TRUE(config->EncryptionPass(absl::Span<uint8_t>(bytes), 1)); - EXPECT_EQ(bytes, pass1); - EXPECT_TRUE(config->EncryptionPass(absl::Span<uint8_t>(bytes), 2)); - EXPECT_EQ(bytes, pass2); - EXPECT_TRUE(config->EncryptionPass(absl::Span<uint8_t>(bytes), 3)); - EXPECT_EQ(bytes, pass3); - EXPECT_TRUE(config->EncryptionPass(absl::Span<uint8_t>(bytes), 4)); - EXPECT_EQ(bytes, pass4); -} - -TEST_F(LoadBalancerConfigTest, EncryptionPassPlaintext) { - auto config = LoadBalancerConfig::CreateUnencrypted(0, 3, 4); - std::array<uint8_t, 7> bytes = {0x31, 0x44, 0x1a, 0x9c, 0x69, 0xc2, 0x75}; - EXPECT_FALSE(config->EncryptionPass(absl::Span<uint8_t>(bytes), 1)); -} - -// Check that the encryption pass code can decode its own ciphertext. Various -// pointer errors could cause the code to overwrite bits that contain -// important information. -TEST_F(LoadBalancerConfigTest, EncryptionPassesAreReversible) { - auto config = - LoadBalancerConfig::Create(0, 3, 4, absl::string_view(raw_key, 16)); - std::array<uint8_t, 7> bytes = { - 0x31, 0x44, 0x1a, 0x9c, 0x69, 0xc2, 0x75, - }; - std::array<uint8_t, 7> orig_bytes; - memcpy(orig_bytes.data(), bytes.data(), bytes.size()); - // Work left->right and right->left passes. - EXPECT_TRUE(config->EncryptionPass(absl::Span<uint8_t>(bytes), 1)); - EXPECT_TRUE(config->EncryptionPass(absl::Span<uint8_t>(bytes), 2)); - EXPECT_TRUE(config->EncryptionPass(absl::Span<uint8_t>(bytes), 2)); - EXPECT_TRUE(config->EncryptionPass(absl::Span<uint8_t>(bytes), 1)); - EXPECT_EQ(bytes, orig_bytes); -} - -TEST_F(LoadBalancerConfigTest, InvalidBlockEncryption) { - uint8_t pt[kLoadBalancerBlockSize], ct[kLoadBalancerBlockSize]; - auto pt_config = LoadBalancerConfig::CreateUnencrypted(0, 8, 8); - EXPECT_FALSE(pt_config->BlockEncrypt(pt, ct)); - EXPECT_FALSE(pt_config->BlockDecrypt(ct, pt)); - EXPECT_FALSE(pt_config->EncryptionPass(absl::Span<uint8_t>(pt), 0)); - auto small_cid_config = - LoadBalancerConfig::Create(0, 3, 4, absl::string_view(raw_key, 16)); - EXPECT_TRUE(small_cid_config->BlockEncrypt(pt, ct)); - EXPECT_FALSE(small_cid_config->BlockDecrypt(ct, pt)); - auto block_config = - LoadBalancerConfig::Create(0, 8, 8, absl::string_view(raw_key, 16)); - EXPECT_TRUE(block_config->BlockEncrypt(pt, ct)); - EXPECT_TRUE(block_config->BlockDecrypt(ct, pt)); -} - -// Block decrypt test from the Test Vector in -// draft-ietf-quic-load-balancers-15, Appendix B. -TEST_F(LoadBalancerConfigTest, BlockEncryptionExample) { - const uint8_t ptext[] = {0xed, 0x79, 0x3a, 0x51, 0xd4, 0x9b, 0x8f, 0x5f, - 0xee, 0x08, 0x0d, 0xbf, 0x48, 0xc0, 0xd1, 0xe5}; - const uint8_t ctext[] = {0x4d, 0xd2, 0xd0, 0x5a, 0x7b, 0x0d, 0xe9, 0xb2, - 0xb9, 0x90, 0x7a, 0xfb, 0x5e, 0xcf, 0x8c, 0xc3}; - const char key[] = {0x8f, 0x95, 0xf0, 0x92, 0x45, 0x76, 0x5f, 0x80, - 0x25, 0x69, 0x34, 0xe5, 0x0c, 0x66, 0x20, 0x7f}; - uint8_t result[sizeof(ptext)]; - auto config = LoadBalancerConfig::Create(0, 8, 8, absl::string_view(key, 16)); - EXPECT_TRUE(config->BlockEncrypt(ptext, result)); - EXPECT_EQ(memcmp(result, ctext, sizeof(ctext)), 0); - EXPECT_TRUE(config->BlockDecrypt(ctext, result)); - EXPECT_EQ(memcmp(result, ptext, sizeof(ptext)), 0); -} +// Tests for Encrypt() and Decrypt() are in LoadBalancerEncoderTest and +// LoadBalancerDecoderTest, respectively. TEST_F(LoadBalancerConfigTest, ConfigIsCopyable) { - const uint8_t ptext[] = {0xed, 0x79, 0x3a, 0x51, 0xd4, 0x9b, 0x8f, 0x5f, + const uint8_t ptext[] = {0x00, 0xed, 0x79, 0x3a, 0x51, 0xd4, 0x9b, 0x8f, 0x5f, 0xee, 0x08, 0x0d, 0xbf, 0x48, 0xc0, 0xd1, 0xe5}; - const uint8_t ctext[] = {0x4d, 0xd2, 0xd0, 0x5a, 0x7b, 0x0d, 0xe9, 0xb2, - 0xb9, 0x90, 0x7a, 0xfb, 0x5e, 0xcf, 0x8c, 0xc3}; + uint8_t ctext[] = {0x00, 0x4d, 0xd2, 0xd0, 0x5a, 0x7b, 0x0d, 0xe9, 0xb2, + 0xb9, 0x90, 0x7a, 0xfb, 0x5e, 0xcf, 0x8c, 0xc3}; const char key[] = {0x8f, 0x95, 0xf0, 0x92, 0x45, 0x76, 0x5f, 0x80, 0x25, 0x69, 0x34, 0xe5, 0x0c, 0x66, 0x20, 0x7f}; - uint8_t result[sizeof(ptext)]; auto config = LoadBalancerConfig::Create(0, 8, 8, absl::string_view(key, 16)); + ASSERT_TRUE(config.has_value()); auto config2 = config; - EXPECT_TRUE(config->BlockEncrypt(ptext, result)); - EXPECT_EQ(memcmp(result, ctext, sizeof(ctext)), 0); - EXPECT_TRUE(config2->BlockEncrypt(ptext, result)); - EXPECT_EQ(memcmp(result, ctext, sizeof(ctext)), 0); + ASSERT_TRUE(config2.has_value()); + uint8_t temp_ptext[sizeof(ptext)]; // the input will be overwritten, so copy + memcpy(temp_ptext, ptext, sizeof(ptext)); + QuicConnectionId cid1 = + config->Encrypt(absl::Span<uint8_t>(temp_ptext, sizeof(ptext))); + EXPECT_EQ(cid1.length(), sizeof(ctext)); + EXPECT_EQ(memcmp(cid1.data(), ctext, sizeof(ctext)), 0); + memcpy(temp_ptext, ptext, sizeof(ptext)); + QuicConnectionId cid2 = + config2->Encrypt(absl::Span<uint8_t>(temp_ptext, sizeof(ptext))); + EXPECT_EQ(cid2.length(), sizeof(ctext)); + EXPECT_EQ(memcmp(cid2.data(), ctext, sizeof(ctext)), 0); +} + +TEST_F(LoadBalancerConfigTest, OnePassEncryptAndDecryptIgnoreAdditionalBytes) { + uint8_t ptext[] = {0x00, 0xed, 0x79, 0x3a, 0x51, 0xd4, 0x9b, 0x8f, 0x5f, 0xee, + 0x08, 0x0d, 0xbf, 0x48, 0xc0, 0xd1, 0xe5, 0xda, 0x41}; + uint8_t ctext[] = {0x00, 0x4d, 0xd2, 0xd0, 0x5a, 0x7b, 0x0d, 0xe9, 0xb2, 0xb9, + 0x90, 0x7a, 0xfb, 0x5e, 0xcf, 0x8c, 0xc3, 0xda, 0x41}; + const char key[] = {0x8f, 0x95, 0xf0, 0x92, 0x45, 0x76, 0x5f, 0x80, + 0x25, 0x69, 0x34, 0xe5, 0x0c, 0x66, 0x20, 0x7f}; + auto config = LoadBalancerConfig::Create(0, 8, 8, absl::string_view(key, 16)); + ASSERT_TRUE(config.has_value()); + LoadBalancerServerId original_server_id(absl::Span<uint8_t>(&ptext[1], 8)); + QuicConnectionId cid = + config->Encrypt(absl::Span<uint8_t>(ptext, sizeof(ptext))); + EXPECT_EQ(cid.length(), sizeof(ctext)); + EXPECT_EQ(memcmp(cid.data(), ctext, sizeof(ctext)), 0); + LoadBalancerServerId server_id = config->Decrypt(absl::Span<const uint8_t>( + reinterpret_cast<const uint8_t *>(cid.data()), cid.length())); + EXPECT_EQ(server_id, original_server_id); +} + +TEST_F(LoadBalancerConfigTest, FourPassEncryptAndDecryptIgnoreAdditionalBytes) { + uint8_t ptext[] = {0x00, 0xed, 0x79, 0x3a, 0xee, + 0x08, 0x0d, 0xbf, 0xda, 0x41}; + uint8_t ctext[] = {0x00, 0x41, 0x26, 0xee, 0x38, + 0xbf, 0x54, 0x54, 0xda, 0x41}; + const char key[] = {0x8f, 0x95, 0xf0, 0x92, 0x45, 0x76, 0x5f, 0x80, + 0x25, 0x69, 0x34, 0xe5, 0x0c, 0x66, 0x20, 0x7f}; + auto config = LoadBalancerConfig::Create(0, 3, 4, absl::string_view(key, 16)); + ASSERT_TRUE(config.has_value()); + LoadBalancerServerId original_server_id(absl::Span<uint8_t>(&ptext[1], 3)); + QuicConnectionId cid = + config->Encrypt(absl::Span<uint8_t>(ptext, sizeof(ptext))); + EXPECT_EQ(cid.length(), sizeof(ctext)); + EXPECT_EQ(memcmp(cid.data(), ctext, sizeof(ctext)), 0); + LoadBalancerServerId server_id = config->Decrypt(absl::Span<const uint8_t>( + reinterpret_cast<const uint8_t *>(cid.data()), cid.length())); + EXPECT_EQ(server_id, original_server_id); } } // namespace
diff --git a/quiche/quic/load_balancer/load_balancer_decoder.cc b/quiche/quic/load_balancer/load_balancer_decoder.cc index 3cdfd2b..8a564f1 100644 --- a/quiche/quic/load_balancer/load_balancer_decoder.cc +++ b/quiche/quic/load_balancer/load_balancer_decoder.cc
@@ -5,7 +5,6 @@ #include "quiche/quic/load_balancer/load_balancer_decoder.h" #include <cstdint> -#include <cstring> #include <optional> #include "absl/types/span.h" @@ -46,36 +45,9 @@ if (!config.has_value()) { return LoadBalancerServerId(); } - if (connection_id.length() < config->total_len()) { - // Connection ID wasn't long enough - return LoadBalancerServerId(); - } - // The first byte is complete. Finish the rest. - const uint8_t* data = - reinterpret_cast<const uint8_t*>(connection_id.data()) + 1; - if (!config->IsEncrypted()) { // It's a Plaintext CID. - return LoadBalancerServerId( - absl::Span<const uint8_t>(data, config->server_id_len())); - } - uint8_t result[kQuicMaxConnectionIdWithLengthPrefixLength]; - if (config->plaintext_len() == kLoadBalancerKeyLen) { // single pass - if (!config->BlockDecrypt(data, result)) { - return LoadBalancerServerId(); - } - } else { - // Do 3 or 4 passes. Only 3 are necessary if the server_id is short enough - // to fit in the first half of the connection ID (the decoder doesn't need - // to extract the nonce). - memcpy(result, data, config->plaintext_len()); - uint8_t end = (config->server_id_len() > config->nonce_len()) ? 1 : 2; - for (uint8_t i = kNumLoadBalancerCryptoPasses; i >= end; i--) { - if (!config->EncryptionPass(absl::Span<uint8_t>(result), i)) { - return LoadBalancerServerId(); - } - } - } - return LoadBalancerServerId( - absl::Span<const uint8_t>(result, config->server_id_len())); + return config->Decrypt(absl::MakeConstSpan( + reinterpret_cast<const uint8_t*>(connection_id.data()), + connection_id.length())); } std::optional<uint8_t> LoadBalancerDecoder::GetConfigId(
diff --git a/quiche/quic/load_balancer/load_balancer_encoder.cc b/quiche/quic/load_balancer/load_balancer_encoder.cc index e956df2..0f72da3 100644 --- a/quiche/quic/load_balancer/load_balancer_encoder.cc +++ b/quiche/quic/load_balancer/load_balancer_encoder.cc
@@ -7,16 +7,16 @@ #include <cstdint> #include <optional> +#include "absl/cleanup/cleanup.h" #include "absl/numeric/int128.h" +#include "absl/types/span.h" #include "quiche/quic/core/crypto/quic_random.h" #include "quiche/quic/core/quic_connection_id.h" #include "quiche/quic/core/quic_data_writer.h" -#include "quiche/quic/core/quic_utils.h" #include "quiche/quic/core/quic_versions.h" #include "quiche/quic/load_balancer/load_balancer_config.h" #include "quiche/quic/load_balancer/load_balancer_server_id.h" #include "quiche/quic/platform/api/quic_bug_tracker.h" -#include "quiche/quic/platform/api/quic_logging.h" #include "quiche/common/quiche_endian.h" namespace quic { @@ -105,6 +105,11 @@ } QuicConnectionId LoadBalancerEncoder::GenerateConnectionId() { + absl::Cleanup cleanup = [&] { + if (num_nonces_left_ == 0) { + DeleteConfig(); + } + }; uint8_t config_id = config_.has_value() ? config_->config_id() : kLoadBalancerUnroutableConfigId; uint8_t shifted_config_id = config_id << kConnectionIdLengthBits; @@ -125,9 +130,9 @@ if (!config_.has_value()) { return MakeUnroutableConnectionId(first_byte); } - QuicConnectionId id; - id.set_length(length); - QuicDataWriter writer(length, id.mutable_data(), quiche::HOST_BYTE_ORDER); + uint8_t result[kQuicMaxConnectionIdWithLengthPrefixLength]; + QuicDataWriter writer(length, reinterpret_cast<char *>(result), + quiche::HOST_BYTE_ORDER); writer.WriteUInt8(first_byte); absl::uint128 next_nonce = (seed_ + num_nonces_left_--) % NumberOfNonces(config_->nonce_len()); @@ -135,39 +140,7 @@ if (!WriteUint128(next_nonce, config_->nonce_len(), writer)) { return QuicConnectionId(); } - uint8_t *block_start = reinterpret_cast<uint8_t *>(writer.data() + 1); - if (!config_->IsEncrypted()) { - // Fill the nonce field with a hash of the Connection ID to avoid the nonce - // visibly increasing by one. This would allow observers to correlate - // connection IDs as being sequential and likely from the same connection, - // not just the same server. - absl::uint128 nonce_hash = - QuicUtils::FNV1a_128_Hash(absl::string_view(writer.data(), length)); - QuicDataWriter rewriter(config_->nonce_len(), - id.mutable_data() + config_->server_id_len() + 1, - quiche::HOST_BYTE_ORDER); - if (!WriteUint128(nonce_hash, config_->nonce_len(), rewriter)) { - return QuicConnectionId(); - } - } else if (config_->plaintext_len() == kLoadBalancerBlockSize) { - // Use one encryption pass. - if (!config_->BlockEncrypt(block_start, block_start)) { - QUIC_LOG(ERROR) << "Block encryption failed"; - return QuicConnectionId(); - } - } else { - for (uint8_t i = 1; i <= kNumLoadBalancerCryptoPasses; i++) { - if (!config_->EncryptionPass(absl::Span<uint8_t>(block_start, length - 1), - i)) { - QUIC_LOG(ERROR) << "Block encryption failed"; - return QuicConnectionId(); - } - } - } - if (num_nonces_left_ == 0) { - DeleteConfig(); - } - return id; + return config_->Encrypt(absl::Span<uint8_t>(result, config_->total_len())); } std::optional<QuicConnectionId> LoadBalancerEncoder::GenerateNextConnectionId(
diff --git a/quiche/quic/load_balancer/load_balancer_server_id.cc b/quiche/quic/load_balancer/load_balancer_server_id.cc index da95c92..9a4685c 100644 --- a/quiche/quic/load_balancer/load_balancer_server_id.cc +++ b/quiche/quic/load_balancer/load_balancer_server_id.cc
@@ -7,7 +7,6 @@ #include <array> #include <cstdint> #include <cstring> -#include <optional> #include <string> #include "absl/strings/escaping.h" @@ -18,19 +17,29 @@ namespace quic { LoadBalancerServerId::LoadBalancerServerId(absl::string_view data) - : LoadBalancerServerId(absl::MakeSpan( - reinterpret_cast<const uint8_t*>(data.data()), data.length())) {} + : LoadBalancerServerId( + absl::MakeSpan(reinterpret_cast<const uint8_t*>(data.data()), + data.length()), + absl::Span<const uint8_t>()) {} -LoadBalancerServerId::LoadBalancerServerId(absl::Span<const uint8_t> data) { - if (data.length() == 0 || data.length() > kLoadBalancerMaxServerIdLen) { - QUIC_BUG(quic_bug_433312504_01) +LoadBalancerServerId::LoadBalancerServerId(absl::Span<const uint8_t> data) + : LoadBalancerServerId(data, absl::Span<const uint8_t>()) {} + +LoadBalancerServerId::LoadBalancerServerId(absl::Span<const uint8_t> data1, + absl::Span<const uint8_t> data2) + : length_(data1.length() + data2.length()) { + if (length_ == 0 || length_ > kLoadBalancerMaxServerIdLen) { + QUIC_BUG(quic_bug_433312504_02) << "Attempted to create LoadBalancerServerId with length " - << data.length(); + << static_cast<int>(length_); length_ = 0; return; } - length_ = data.length(); - memcpy(data_.data(), data.data(), data.length()); + memcpy(data_.data(), data1.data(), data1.length()); + if (data2.empty()) { + return; + } + memcpy(data_.data() + data1.length(), data2.data(), data2.length()); } std::string LoadBalancerServerId::ToString() const {
diff --git a/quiche/quic/load_balancer/load_balancer_server_id.h b/quiche/quic/load_balancer/load_balancer_server_id.h index a7b1d71..d60d308 100644 --- a/quiche/quic/load_balancer/load_balancer_server_id.h +++ b/quiche/quic/load_balancer/load_balancer_server_id.h
@@ -7,7 +7,6 @@ #include <array> #include <cstdint> -#include <optional> #include <string> #include "absl/strings/string_view.h" @@ -34,6 +33,10 @@ // Copies all the bytes from |data| into a new LoadBalancerServerId. explicit LoadBalancerServerId(absl::Span<const uint8_t> data); explicit LoadBalancerServerId(absl::string_view data); + // Concatenates |data1| and |data2| into a single LoadBalancerServerId. This + // is useful to reduce copying for certain decoder configurations. + explicit LoadBalancerServerId(absl::Span<const uint8_t> data1, + absl::Span<const uint8_t> data2); // Server IDs are opaque bytes, but defining these operators allows us to sort // them into a tree and define ranges.
diff --git a/quiche/quic/load_balancer/load_balancer_server_id_test.cc b/quiche/quic/load_balancer/load_balancer_server_id_test.cc index a882795..08b5e6d 100644 --- a/quiche/quic/load_balancer/load_balancer_server_id_test.cc +++ b/quiche/quic/load_balancer/load_balancer_server_id_test.cc
@@ -30,11 +30,27 @@ absl::Span<const uint8_t>(kRawServerId, 16)) .IsValid()), "Attempted to create LoadBalancerServerId with length 16"); + EXPECT_QUIC_BUG(EXPECT_FALSE(LoadBalancerServerId( + absl::Span<const uint8_t>(kRawServerId, 9), + absl::Span<const uint8_t>(kRawServerId, 7)) + .IsValid()), + "Attempted to create LoadBalancerServerId with length 16"); EXPECT_QUIC_BUG( EXPECT_FALSE(LoadBalancerServerId(absl::Span<const uint8_t>()).IsValid()), "Attempted to create LoadBalancerServerId with length 0"); } +TEST_F(LoadBalancerServerIdTest, TwoPartConstructor) { + LoadBalancerServerId server_id1(absl::Span<const uint8_t>(kRawServerId, 15)); + ASSERT_TRUE(server_id1.IsValid()); + LoadBalancerServerId server_id2( + absl::Span<const uint8_t>(kRawServerId, 8), + absl::Span<const uint8_t>(&kRawServerId[8], 7)); + ASSERT_TRUE(server_id2.IsValid()); + EXPECT_TRUE(server_id1 == server_id2); + ; +} + TEST_F(LoadBalancerServerIdTest, CompareIdenticalExceptLength) { LoadBalancerServerId server_id(absl::Span<const uint8_t>(kRawServerId, 15)); ASSERT_TRUE(server_id.IsValid());