QUIC-LB Server Pool: used by load balancers to map server IDs to routing records.
PiperOrigin-RevId: 435134481
diff --git a/quic/load_balancer/load_balancer_server_id_map.h b/quic/load_balancer/load_balancer_server_id_map.h
new file mode 100644
index 0000000..400ac68
--- /dev/null
+++ b/quic/load_balancer/load_balancer_server_id_map.h
@@ -0,0 +1,103 @@
+// Copyright (c) 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_MAP_H_
+#define QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_MAP_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "quic/load_balancer/load_balancer_server_id.h"
+#include "quic/platform/api/quic_bug_tracker.h"
+
+namespace quic {
+
+// This class wraps an absl::flat_hash_map which associates server IDs to an
+// arbitrary type T. It validates that all server ids are of the same fixed
+// length.
+template <typename T>
+class QUIC_EXPORT_PRIVATE LoadBalancerServerIdMap {
+ public:
+ // Returns a newly created pool for server IDs of length |server_id_len|, or
+ // nullptr if |server_id_len| is invalid.
+ static std::shared_ptr<LoadBalancerServerIdMap> Create(
+ const uint8_t server_id_len);
+
+ // Returns the entry associated with |server_id|, if present. For small |T|,
+ // use Lookup. For large |T|, use LookupNoCopy.
+ absl::optional<const T> Lookup(const LoadBalancerServerId server_id) const;
+ const T* LookupNoCopy(const LoadBalancerServerId server_id) const;
+
+ // Updates the table so that |value| is associated with |server_id|. Sets
+ // QUIC_BUG if the length is incorrect for this map.
+ void AddOrReplace(const LoadBalancerServerId server_id, T value);
+
+ // Removes the entry associated with |server_id|.
+ void Erase(const LoadBalancerServerId server_id) {
+ server_id_table_.erase(server_id);
+ }
+
+ uint8_t server_id_len() const { return server_id_len_; }
+
+ private:
+ LoadBalancerServerIdMap(uint8_t server_id_len)
+ : server_id_len_(server_id_len) {}
+
+ const uint8_t server_id_len_; // All server IDs must be of this length.
+ absl::flat_hash_map<LoadBalancerServerId, T> server_id_table_;
+};
+
+template <typename T>
+std::shared_ptr<LoadBalancerServerIdMap<T>> LoadBalancerServerIdMap<T>::Create(
+ const uint8_t server_id_len) {
+ if (server_id_len == 0 || server_id_len > kLoadBalancerMaxServerIdLen) {
+ QUIC_BUG(quic_bug_434893339_01)
+ << "Tried to configure map with server ID length "
+ << static_cast<int>(server_id_len);
+ return nullptr;
+ }
+ return std::make_shared<LoadBalancerServerIdMap<T>>(
+ LoadBalancerServerIdMap(server_id_len));
+}
+
+template <typename T>
+absl::optional<const T> LoadBalancerServerIdMap<T>::Lookup(
+ const LoadBalancerServerId server_id) const {
+ if (server_id.length() != server_id_len_) {
+ QUIC_BUG(quic_bug_434893339_02)
+ << "Lookup with a " << static_cast<int>(server_id.length())
+ << " byte server ID, map requires " << static_cast<int>(server_id_len_);
+ return absl::optional<T>();
+ }
+ auto it = server_id_table_.find(server_id);
+ return (it != server_id_table_.end()) ? it->second
+ : absl::optional<const T>();
+}
+
+template <typename T>
+const T* LoadBalancerServerIdMap<T>::LookupNoCopy(
+ const LoadBalancerServerId server_id) const {
+ if (server_id.length() != server_id_len_) {
+ QUIC_BUG(quic_bug_434893339_02)
+ << "Lookup with a " << static_cast<int>(server_id.length())
+ << " byte server ID, map requires " << static_cast<int>(server_id_len_);
+ return nullptr;
+ }
+ auto it = server_id_table_.find(server_id);
+ return (it != server_id_table_.end()) ? &it->second : nullptr;
+}
+
+template <typename T>
+void LoadBalancerServerIdMap<T>::AddOrReplace(
+ const LoadBalancerServerId server_id, T value) {
+ if (server_id.length() == server_id_len_) {
+ server_id_table_[server_id] = value;
+ } else {
+ QUIC_BUG(quic_bug_434893339_03)
+ << "Server ID of " << static_cast<int>(server_id.length())
+ << " bytes; this map requires " << static_cast<int>(server_id_len_);
+ }
+}
+
+} // namespace quic
+
+#endif // QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_MAP_H_
diff --git a/quic/load_balancer/load_balancer_server_id_map_test.cc b/quic/load_balancer/load_balancer_server_id_map_test.cc
new file mode 100644
index 0000000..37f6a06
--- /dev/null
+++ b/quic/load_balancer/load_balancer_server_id_map_test.cc
@@ -0,0 +1,94 @@
+// Copyright (c) 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "quic/load_balancer/load_balancer_server_id_map.h"
+
+#include "quic/platform/api/quic_expect_bug.h"
+#include "quic/platform/api/quic_test.h"
+#include "quic/test_tools/quic_test_utils.h"
+
+namespace quic {
+
+namespace test {
+
+namespace {
+
+constexpr uint8_t kServerId[] = {0xed, 0x79, 0x3a, 0x51};
+
+class LoadBalancerServerIdMapTest : public QuicTest {
+ public:
+ const LoadBalancerServerId valid_server_id_ =
+ *LoadBalancerServerId::Create(kServerId);
+ const LoadBalancerServerId invalid_server_id_ =
+ *LoadBalancerServerId::Create(absl::Span<const uint8_t>(kServerId, 3));
+};
+
+TEST_F(LoadBalancerServerIdMapTest, CreateWithBadServerIdLength) {
+ EXPECT_QUIC_BUG(EXPECT_EQ(LoadBalancerServerIdMap<int>::Create(0), nullptr),
+ "Tried to configure map with server ID length 0");
+ EXPECT_QUIC_BUG(EXPECT_EQ(LoadBalancerServerIdMap<int>::Create(16), nullptr),
+ "Tried to configure map with server ID length 16");
+}
+
+TEST_F(LoadBalancerServerIdMapTest, AddOrReplaceWithBadServerIdLength) {
+ int record = 1;
+ auto pool = LoadBalancerServerIdMap<int>::Create(4);
+ EXPECT_NE(pool, nullptr);
+ EXPECT_QUIC_BUG(pool->AddOrReplace(invalid_server_id_, record),
+ "Server ID of 3 bytes; this map requires 4");
+}
+
+TEST_F(LoadBalancerServerIdMapTest, LookupWithBadServerIdLength) {
+ int record = 1;
+ auto pool = LoadBalancerServerIdMap<int>::Create(4);
+ EXPECT_NE(pool, nullptr);
+ pool->AddOrReplace(valid_server_id_, record);
+ EXPECT_QUIC_BUG(EXPECT_FALSE(pool->Lookup(invalid_server_id_).has_value()),
+ "Lookup with a 3 byte server ID, map requires 4");
+ EXPECT_QUIC_BUG(EXPECT_EQ(pool->LookupNoCopy(invalid_server_id_), nullptr),
+ "Lookup with a 3 byte server ID, map requires 4");
+}
+
+TEST_F(LoadBalancerServerIdMapTest, LookupWhenEmpty) {
+ auto pool = LoadBalancerServerIdMap<int>::Create(4);
+ EXPECT_NE(pool, nullptr);
+ EXPECT_EQ(pool->LookupNoCopy(valid_server_id_), nullptr);
+ absl::optional<int> result = pool->Lookup(valid_server_id_);
+ EXPECT_FALSE(result.has_value());
+}
+
+TEST_F(LoadBalancerServerIdMapTest, AddLookup) {
+ int record1 = 1, record2 = 2;
+ auto pool = LoadBalancerServerIdMap<int>::Create(4);
+ EXPECT_NE(pool, nullptr);
+ auto other_server_id = LoadBalancerServerId::Create({0x01, 0x02, 0x03, 0x04});
+ EXPECT_TRUE(other_server_id.has_value());
+ pool->AddOrReplace(valid_server_id_, record1);
+ pool->AddOrReplace(*other_server_id, record2);
+ absl::optional<int> result = pool->Lookup(valid_server_id_);
+ EXPECT_TRUE(result.has_value());
+ EXPECT_EQ(*result, record1);
+ auto result_ptr = pool->LookupNoCopy(valid_server_id_);
+ EXPECT_NE(result_ptr, nullptr);
+ EXPECT_EQ(*result_ptr, record1);
+ result = pool->Lookup(*other_server_id);
+ EXPECT_TRUE(result.has_value());
+ EXPECT_EQ(*result, record2);
+}
+
+TEST_F(LoadBalancerServerIdMapTest, AddErase) {
+ int record = 1;
+ auto pool = LoadBalancerServerIdMap<int>::Create(4);
+ EXPECT_NE(pool, nullptr);
+ pool->AddOrReplace(valid_server_id_, record);
+ EXPECT_EQ(*pool->LookupNoCopy(valid_server_id_), record);
+ pool->Erase(valid_server_id_);
+ EXPECT_EQ(pool->LookupNoCopy(valid_server_id_), nullptr);
+}
+
+} // namespace
+
+} // namespace test
+
+} // namespace quic