QUIC-LB Server Pool: used by load balancers to map server IDs to routing records.

PiperOrigin-RevId: 435134481
diff --git a/quic/load_balancer/load_balancer_server_id_map.h b/quic/load_balancer/load_balancer_server_id_map.h
new file mode 100644
index 0000000..400ac68
--- /dev/null
+++ b/quic/load_balancer/load_balancer_server_id_map.h
@@ -0,0 +1,103 @@
+// Copyright (c) 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_MAP_H_
+#define QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_MAP_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "quic/load_balancer/load_balancer_server_id.h"
+#include "quic/platform/api/quic_bug_tracker.h"
+
+namespace quic {
+
+// This class wraps an absl::flat_hash_map which associates server IDs to an
+// arbitrary type T. It validates that all server ids are of the same fixed
+// length.
+template <typename T>
+class QUIC_EXPORT_PRIVATE LoadBalancerServerIdMap {
+ public:
+  // Returns a newly created pool for server IDs of length |server_id_len|, or
+  // nullptr if |server_id_len| is invalid.
+  static std::shared_ptr<LoadBalancerServerIdMap> Create(
+      const uint8_t server_id_len);
+
+  // Returns the entry associated with |server_id|, if present. For small |T|,
+  // use Lookup. For large |T|, use LookupNoCopy.
+  absl::optional<const T> Lookup(const LoadBalancerServerId server_id) const;
+  const T* LookupNoCopy(const LoadBalancerServerId server_id) const;
+
+  // Updates the table so that |value| is associated with |server_id|. Sets
+  // QUIC_BUG if the length is incorrect for this map.
+  void AddOrReplace(const LoadBalancerServerId server_id, T value);
+
+  // Removes the entry associated with |server_id|.
+  void Erase(const LoadBalancerServerId server_id) {
+    server_id_table_.erase(server_id);
+  }
+
+  uint8_t server_id_len() const { return server_id_len_; }
+
+ private:
+  LoadBalancerServerIdMap(uint8_t server_id_len)
+      : server_id_len_(server_id_len) {}
+
+  const uint8_t server_id_len_;  // All server IDs must be of this length.
+  absl::flat_hash_map<LoadBalancerServerId, T> server_id_table_;
+};
+
+template <typename T>
+std::shared_ptr<LoadBalancerServerIdMap<T>> LoadBalancerServerIdMap<T>::Create(
+    const uint8_t server_id_len) {
+  if (server_id_len == 0 || server_id_len > kLoadBalancerMaxServerIdLen) {
+    QUIC_BUG(quic_bug_434893339_01)
+        << "Tried to configure map with server ID length "
+        << static_cast<int>(server_id_len);
+    return nullptr;
+  }
+  return std::make_shared<LoadBalancerServerIdMap<T>>(
+      LoadBalancerServerIdMap(server_id_len));
+}
+
+template <typename T>
+absl::optional<const T> LoadBalancerServerIdMap<T>::Lookup(
+    const LoadBalancerServerId server_id) const {
+  if (server_id.length() != server_id_len_) {
+    QUIC_BUG(quic_bug_434893339_02)
+        << "Lookup with a " << static_cast<int>(server_id.length())
+        << " byte server ID, map requires " << static_cast<int>(server_id_len_);
+    return absl::optional<T>();
+  }
+  auto it = server_id_table_.find(server_id);
+  return (it != server_id_table_.end()) ? it->second
+                                        : absl::optional<const T>();
+}
+
+template <typename T>
+const T* LoadBalancerServerIdMap<T>::LookupNoCopy(
+    const LoadBalancerServerId server_id) const {
+  if (server_id.length() != server_id_len_) {
+    QUIC_BUG(quic_bug_434893339_02)
+        << "Lookup with a " << static_cast<int>(server_id.length())
+        << " byte server ID, map requires " << static_cast<int>(server_id_len_);
+    return nullptr;
+  }
+  auto it = server_id_table_.find(server_id);
+  return (it != server_id_table_.end()) ? &it->second : nullptr;
+}
+
+template <typename T>
+void LoadBalancerServerIdMap<T>::AddOrReplace(
+    const LoadBalancerServerId server_id, T value) {
+  if (server_id.length() == server_id_len_) {
+    server_id_table_[server_id] = value;
+  } else {
+    QUIC_BUG(quic_bug_434893339_03)
+        << "Server ID of " << static_cast<int>(server_id.length())
+        << " bytes; this map requires " << static_cast<int>(server_id_len_);
+  }
+}
+
+}  // namespace quic
+
+#endif  // QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_MAP_H_
diff --git a/quic/load_balancer/load_balancer_server_id_map_test.cc b/quic/load_balancer/load_balancer_server_id_map_test.cc
new file mode 100644
index 0000000..37f6a06
--- /dev/null
+++ b/quic/load_balancer/load_balancer_server_id_map_test.cc
@@ -0,0 +1,94 @@
+// Copyright (c) 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "quic/load_balancer/load_balancer_server_id_map.h"
+
+#include "quic/platform/api/quic_expect_bug.h"
+#include "quic/platform/api/quic_test.h"
+#include "quic/test_tools/quic_test_utils.h"
+
+namespace quic {
+
+namespace test {
+
+namespace {
+
+constexpr uint8_t kServerId[] = {0xed, 0x79, 0x3a, 0x51};
+
+class LoadBalancerServerIdMapTest : public QuicTest {
+ public:
+  const LoadBalancerServerId valid_server_id_ =
+      *LoadBalancerServerId::Create(kServerId);
+  const LoadBalancerServerId invalid_server_id_ =
+      *LoadBalancerServerId::Create(absl::Span<const uint8_t>(kServerId, 3));
+};
+
+TEST_F(LoadBalancerServerIdMapTest, CreateWithBadServerIdLength) {
+  EXPECT_QUIC_BUG(EXPECT_EQ(LoadBalancerServerIdMap<int>::Create(0), nullptr),
+                  "Tried to configure map with server ID length 0");
+  EXPECT_QUIC_BUG(EXPECT_EQ(LoadBalancerServerIdMap<int>::Create(16), nullptr),
+                  "Tried to configure map with server ID length 16");
+}
+
+TEST_F(LoadBalancerServerIdMapTest, AddOrReplaceWithBadServerIdLength) {
+  int record = 1;
+  auto pool = LoadBalancerServerIdMap<int>::Create(4);
+  EXPECT_NE(pool, nullptr);
+  EXPECT_QUIC_BUG(pool->AddOrReplace(invalid_server_id_, record),
+                  "Server ID of 3 bytes; this map requires 4");
+}
+
+TEST_F(LoadBalancerServerIdMapTest, LookupWithBadServerIdLength) {
+  int record = 1;
+  auto pool = LoadBalancerServerIdMap<int>::Create(4);
+  EXPECT_NE(pool, nullptr);
+  pool->AddOrReplace(valid_server_id_, record);
+  EXPECT_QUIC_BUG(EXPECT_FALSE(pool->Lookup(invalid_server_id_).has_value()),
+                  "Lookup with a 3 byte server ID, map requires 4");
+  EXPECT_QUIC_BUG(EXPECT_EQ(pool->LookupNoCopy(invalid_server_id_), nullptr),
+                  "Lookup with a 3 byte server ID, map requires 4");
+}
+
+TEST_F(LoadBalancerServerIdMapTest, LookupWhenEmpty) {
+  auto pool = LoadBalancerServerIdMap<int>::Create(4);
+  EXPECT_NE(pool, nullptr);
+  EXPECT_EQ(pool->LookupNoCopy(valid_server_id_), nullptr);
+  absl::optional<int> result = pool->Lookup(valid_server_id_);
+  EXPECT_FALSE(result.has_value());
+}
+
+TEST_F(LoadBalancerServerIdMapTest, AddLookup) {
+  int record1 = 1, record2 = 2;
+  auto pool = LoadBalancerServerIdMap<int>::Create(4);
+  EXPECT_NE(pool, nullptr);
+  auto other_server_id = LoadBalancerServerId::Create({0x01, 0x02, 0x03, 0x04});
+  EXPECT_TRUE(other_server_id.has_value());
+  pool->AddOrReplace(valid_server_id_, record1);
+  pool->AddOrReplace(*other_server_id, record2);
+  absl::optional<int> result = pool->Lookup(valid_server_id_);
+  EXPECT_TRUE(result.has_value());
+  EXPECT_EQ(*result, record1);
+  auto result_ptr = pool->LookupNoCopy(valid_server_id_);
+  EXPECT_NE(result_ptr, nullptr);
+  EXPECT_EQ(*result_ptr, record1);
+  result = pool->Lookup(*other_server_id);
+  EXPECT_TRUE(result.has_value());
+  EXPECT_EQ(*result, record2);
+}
+
+TEST_F(LoadBalancerServerIdMapTest, AddErase) {
+  int record = 1;
+  auto pool = LoadBalancerServerIdMap<int>::Create(4);
+  EXPECT_NE(pool, nullptr);
+  pool->AddOrReplace(valid_server_id_, record);
+  EXPECT_EQ(*pool->LookupNoCopy(valid_server_id_), record);
+  pool->Erase(valid_server_id_);
+  EXPECT_EQ(pool->LookupNoCopy(valid_server_id_), nullptr);
+}
+
+}  // namespace
+
+}  // namespace test
+
+}  // namespace quic