Make quic::BitMask parametrized by the type of the enum used.

Currently, we only parametrize individual BitMask methods, which allows the bit mask to be accessed via different enums (that is either a mistake, or something that should be explicitly called out by using static_cast).  Parametrizing it by the enum improves type safety, and also allows us to add methods that would return an enum value (e.g. a max() method that I want to add in a follow-up CL).

This CL also has some minor clean ups, notably making everything constexpr, and turning the constructor into an initializer list (thus making it consistent with, e.g., STL container types).

PiperOrigin-RevId: 538841710
diff --git a/quiche/quic/core/batch_writer/quic_batch_writer_test.h b/quiche/quic/core/batch_writer/quic_batch_writer_test.h
index d42867f..08610f0 100644
--- a/quiche/quic/core/batch_writer/quic_batch_writer_test.h
+++ b/quiche/quic/core/batch_writer/quic_batch_writer_test.h
@@ -233,9 +233,9 @@
       result.control_buffer = {&control_buffer_[0], sizeof(control_buffer_)};
       QuicUdpSocketApi().ReadPacket(
           peer_socket_,
-          quic::BitMask64(QuicUdpPacketInfoBit::V4_SELF_IP,
-                          QuicUdpPacketInfoBit::V6_SELF_IP,
-                          QuicUdpPacketInfoBit::PEER_ADDRESS),
+          quic::QuicUdpPacketInfoBitMask({QuicUdpPacketInfoBit::V4_SELF_IP,
+                                          QuicUdpPacketInfoBit::V6_SELF_IP,
+                                          QuicUdpPacketInfoBit::PEER_ADDRESS}),
           &result);
       ASSERT_TRUE(result.ok);
       ASSERT_TRUE(
diff --git a/quiche/quic/core/quic_packet_reader.cc b/quiche/quic/core/quic_packet_reader.cc
index eaa3441..8dd52a1 100644
--- a/quiche/quic/core/quic_packet_reader.cc
+++ b/quiche/quic/core/quic_packet_reader.cc
@@ -49,13 +49,11 @@
   // arriving at the host and now is considered part of the network delay.
   QuicTime now = clock.Now();
 
-  BitMask64 info_bits{QuicUdpPacketInfoBit::DROPPED_PACKETS,
-                      QuicUdpPacketInfoBit::PEER_ADDRESS,
-                      QuicUdpPacketInfoBit::V4_SELF_IP,
-                      QuicUdpPacketInfoBit::V6_SELF_IP,
-                      QuicUdpPacketInfoBit::RECV_TIMESTAMP,
-                      QuicUdpPacketInfoBit::TTL,
-                      QuicUdpPacketInfoBit::GOOGLE_PACKET_HEADER};
+  QuicUdpPacketInfoBitMask info_bits(
+      {QuicUdpPacketInfoBit::DROPPED_PACKETS,
+       QuicUdpPacketInfoBit::PEER_ADDRESS, QuicUdpPacketInfoBit::V4_SELF_IP,
+       QuicUdpPacketInfoBit::V6_SELF_IP, QuicUdpPacketInfoBit::RECV_TIMESTAMP,
+       QuicUdpPacketInfoBit::TTL, QuicUdpPacketInfoBit::GOOGLE_PACKET_HEADER});
   if (GetQuicRestartFlag(quic_receive_ecn)) {
     QUIC_RESTART_FLAG_COUNT_N(quic_receive_ecn, 3, 3);
     info_bits.Set(QuicUdpPacketInfoBit::ECN);
diff --git a/quiche/quic/core/quic_udp_socket.cc b/quiche/quic/core/quic_udp_socket.cc
index 1a35b72..8e35a2c 100644
--- a/quiche/quic/core/quic_udp_socket.cc
+++ b/quiche/quic/core/quic_udp_socket.cc
@@ -32,7 +32,7 @@
 
 void PopulatePacketInfoFromControlMessageBase(
     PlatformCmsghdr* cmsg, QuicUdpPacketInfo* packet_info,
-    BitMask64 packet_info_interested) {
+    QuicUdpPacketInfoBitMask packet_info_interested) {
   if (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_PKTINFO) {
     if (packet_info_interested.IsSet(QuicUdpPacketInfoBit::V6_SELF_IP)) {
       const in6_pktinfo* info = reinterpret_cast<in6_pktinfo*>(CMSG_DATA(cmsg));
diff --git a/quiche/quic/core/quic_udp_socket.h b/quiche/quic/core/quic_udp_socket.h
index 08a2595..2541fb0 100644
--- a/quiche/quic/core/quic_udp_socket.h
+++ b/quiche/quic/core/quic_udp_socket.h
@@ -35,12 +35,15 @@
   TTL,                   // Read & Write
   ECN,                   // Read
   GOOGLE_PACKET_HEADER,  // Read
-  NUM_BITS,
-  IS_GRO,  // Read
+  IS_GRO,                // Read
+
+  // Must be the last value.
+  NUM_BITS
 };
+using QuicUdpPacketInfoBitMask = BitMask<QuicUdpPacketInfoBit>;
 static_assert(static_cast<size_t>(QuicUdpPacketInfoBit::NUM_BITS) <=
-                  BitMask64::NumBits(),
-              "BitMask64 not wide enough to hold all bits.");
+                  QuicUdpPacketInfoBitMask::NumBits(),
+              "QuicUdpPacketInfoBitMask not wide enough to hold all bits.");
 
 // BufferSpan points to an unowned buffer, copying this structure only copies
 // the pointer and length, not the buffer itself.
@@ -60,7 +63,7 @@
 // receiving.
 class QUIC_EXPORT_PRIVATE QuicUdpPacketInfo {
  public:
-  BitMask64 bitmask() const { return bitmask_; }
+  QuicUdpPacketInfoBitMask bitmask() const { return bitmask_; }
 
   void Reset() { bitmask_.ClearAll(); }
 
@@ -159,7 +162,7 @@
   }
 
  private:
-  BitMask64 bitmask_;
+  QuicUdpPacketInfoBitMask bitmask_;
   QuicPacketCount dropped_packets_;
   QuicIpAddress self_v4_ip_;
   QuicIpAddress self_v6_ip_;
@@ -238,7 +241,8 @@
   //
   // If |*result| is reused for subsequent ReadPacket() calls, caller needs to
   // call result->Reset() before each ReadPacket().
-  void ReadPacket(QuicUdpSocketFd fd, BitMask64 packet_info_interested,
+  void ReadPacket(QuicUdpSocketFd fd,
+                  QuicUdpPacketInfoBitMask packet_info_interested,
                   ReadPacketResult* result);
 
   using ReadPacketResults = std::vector<ReadPacketResult>;
@@ -247,7 +251,7 @@
   // Return the number of elements populated into |*results|, note it is
   // possible for some of the populated elements to have ok=false.
   size_t ReadMultiplePackets(QuicUdpSocketFd fd,
-                             BitMask64 packet_info_interested,
+                             QuicUdpPacketInfoBitMask packet_info_interested,
                              ReadPacketResults* results);
 
   // Write a packet to |fd|.
diff --git a/quiche/quic/core/quic_udp_socket_posix.inc b/quiche/quic/core/quic_udp_socket_posix.inc
index deb406c..4d371db 100644
--- a/quiche/quic/core/quic_udp_socket_posix.inc
+++ b/quiche/quic/core/quic_udp_socket_posix.inc
@@ -86,9 +86,9 @@
   memcpy(&pktinfo->ipi6_addr, address_string.c_str(), address_string.length());
 }
 
-void PopulatePacketInfoFromControlMessage(struct cmsghdr* cmsg,
-                                          QuicUdpPacketInfo* packet_info,
-                                          BitMask64 packet_info_interested) {
+void PopulatePacketInfoFromControlMessage(
+    struct cmsghdr* cmsg, QuicUdpPacketInfo* packet_info,
+    QuicUdpPacketInfoBitMask packet_info_interested) {
 #ifdef SOL_UDP
   if (packet_info_interested.IsSet(QuicUdpPacketInfoBit::IS_GRO) &&
       cmsg->cmsg_level == SOL_UDP && cmsg->cmsg_type == UDP_GRO) {
@@ -249,9 +249,9 @@
   return 1 == select(1 + fd, &read_fds, nullptr, nullptr, &select_timeout);
 }
 
-void QuicUdpSocketApi::ReadPacket(QuicUdpSocketFd fd,
-                                  BitMask64 packet_info_interested,
-                                  ReadPacketResult* result) {
+void QuicUdpSocketApi::ReadPacket(
+    QuicUdpSocketFd fd, QuicUdpPacketInfoBitMask packet_info_interested,
+    ReadPacketResult* result) {
   result->ok = false;
   BufferSpan& packet_buffer = result->packet_buffer;
   BufferSpan& control_buffer = result->control_buffer;
@@ -318,7 +318,7 @@
   if (hdr.msg_controllen > 0) {
     for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&hdr); cmsg != nullptr;
          cmsg = CMSG_NXTHDR(&hdr, cmsg)) {
-      BitMask64 prior_bitmask = packet_info->bitmask();
+      QuicUdpPacketInfoBitMask prior_bitmask = packet_info->bitmask();
       PopulatePacketInfoFromControlMessage(cmsg, packet_info,
                                            packet_info_interested);
       if (packet_info->bitmask() == prior_bitmask) {
@@ -331,9 +331,9 @@
   result->ok = true;
 }
 
-size_t QuicUdpSocketApi::ReadMultiplePackets(QuicUdpSocketFd fd,
-                                             BitMask64 packet_info_interested,
-                                             ReadPacketResults* results) {
+size_t QuicUdpSocketApi::ReadMultiplePackets(
+    QuicUdpSocketFd fd, QuicUdpPacketInfoBitMask packet_info_interested,
+    ReadPacketResults* results) {
 #if defined(__linux__) && !defined(__ANDROID__)
   if (packet_info_interested.IsSet(QuicUdpPacketInfoBit::IS_GRO)) {
     size_t num_packets = 0;
diff --git a/quiche/quic/core/quic_udp_socket_win.inc b/quiche/quic/core/quic_udp_socket_win.inc
index 429b8c7..3f8f89d 100644
--- a/quiche/quic/core/quic_udp_socket_win.inc
+++ b/quiche/quic/core/quic_udp_socket_win.inc
@@ -111,9 +111,9 @@
   return true;
 }
 
-void QuicUdpSocketApi::ReadPacket(QuicUdpSocketFd fd,
-                                  BitMask64 packet_info_interested,
-                                  ReadPacketResult* result) {
+void QuicUdpSocketApi::ReadPacket(
+    QuicUdpSocketFd fd, QuicUdpPacketInfoBitMask packet_info_interested,
+    ReadPacketResult* result) {
   result->ok = false;
 
   // WSARecvMsg is an extension to Windows Socket API that requires us to fetch
@@ -175,7 +175,7 @@
   if (hdr.Control.len > 0) {
     for (WSACMSGHDR* cmsg = WSA_CMSG_FIRSTHDR(&hdr); cmsg != nullptr;
          cmsg = WSA_CMSG_NXTHDR(&hdr, cmsg)) {
-      BitMask64 prior_bitmask = packet_info->bitmask();
+      QuicUdpPacketInfoBitMask prior_bitmask = packet_info->bitmask();
       PopulatePacketInfoFromControlMessageBase(cmsg, packet_info,
                                                packet_info_interested);
       if (packet_info->bitmask() == prior_bitmask) {
@@ -188,9 +188,9 @@
   result->ok = true;
 }
 
-size_t QuicUdpSocketApi::ReadMultiplePackets(QuicUdpSocketFd fd,
-                                             BitMask64 packet_info_interested,
-                                             ReadPacketResults* results) {
+size_t QuicUdpSocketApi::ReadMultiplePackets(
+    QuicUdpSocketFd fd, QuicUdpPacketInfoBitMask packet_info_interested,
+    ReadPacketResults* results) {
   size_t num_packets = 0;
   for (ReadPacketResult& result : *results) {
     result.ok = false;
diff --git a/quiche/quic/core/quic_utils.h b/quiche/quic/core/quic_utils.h
index 79212fd..52108aa 100644
--- a/quiche/quic/core/quic_utils.h
+++ b/quiche/quic/core/quic_utils.h
@@ -7,11 +7,12 @@
 
 #include <cstddef>
 #include <cstdint>
-#include <sstream>
+#include <initializer_list>
 #include <string>
 #include <type_traits>
 
 #include "absl/numeric/int128.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "quiche/quic/core/crypto/quic_random.h"
@@ -226,30 +227,31 @@
 // Computes a SHA-256 hash and returns the raw bytes of the hash.
 QUIC_EXPORT_PRIVATE std::string RawSha256(absl::string_view input);
 
-template <typename Mask>
+// BitMask<Index, Mask> is a set of elements of type `Index` represented as a
+// bitmask of an underlying integer type `Mask` (uint64_t by default). The
+// underlying type has to be large enough to fit all possible values of `Index`.
+template <typename Index, typename Mask = uint64_t>
 class QUIC_EXPORT_PRIVATE BitMask {
  public:
-  // explicit to prevent (incorrect) usage like "BitMask bitmask = 0;".
-  template <typename... Bits>
-  explicit BitMask(Bits... bits) {
-    mask_ = MakeMask(bits...);
+  explicit constexpr BitMask(std::initializer_list<Index> bits) {
+    for (Index bit : bits) {
+      mask_ |= MakeMask(bit);
+    }
   }
 
   BitMask() = default;
   BitMask(const BitMask& other) = default;
   BitMask& operator=(const BitMask& other) = default;
 
-  template <typename... Bits>
-  void Set(Bits... bits) {
-    mask_ |= MakeMask(bits...);
+  constexpr void Set(Index bit) { mask_ |= MakeMask(bit); }
+
+  constexpr void Set(std::initializer_list<Index> bits) {
+    mask_ |= BitMask(bits).mask();
   }
 
-  template <typename Bit>
-  bool IsSet(Bit bit) const {
-    return (MakeMask(bit) & mask_) != 0;
-  }
+  constexpr bool IsSet(Index bit) const { return (MakeMask(bit) & mask_) != 0; }
 
-  void ClearAll() { mask_ = 0; }
+  constexpr void ClearAll() { mask_ = 0; }
 
   static constexpr size_t NumBits() { return 8 * sizeof(Mask); }
 
@@ -258,32 +260,35 @@
   }
 
   std::string DebugString() const {
-    std::ostringstream oss;
-    oss << "0x" << std::hex << mask_;
-    return oss.str();
+    return absl::StrCat("0x", absl::Hex(mask_));
   }
 
+  constexpr Mask mask() const { return mask_; }
+
  private:
   template <typename Bit>
-  static std::enable_if_t<std::is_enum<Bit>::value, Mask> MakeMask(Bit bit) {
+  static constexpr std::enable_if_t<std::is_enum_v<Bit>, Mask> MakeMask(
+      Bit bit) {
     using IntType = typename std::underlying_type<Bit>::type;
-    return Mask(1) << static_cast<IntType>(bit);
+    return MakeMask(static_cast<IntType>(bit));
   }
 
   template <typename Bit>
-  static std::enable_if_t<!std::is_enum<Bit>::value, Mask> MakeMask(Bit bit) {
+  static constexpr std::enable_if_t<!std::is_enum_v<Bit>, Mask> MakeMask(
+      Bit bit) {
+    // We can't use QUICHE_DCHECK_LT here, since it doesn't work with constexpr.
+    QUICHE_DCHECK(bit < static_cast<Bit>(NumBits()));
+    if constexpr (std::is_signed_v<Bit>) {
+      QUICHE_DCHECK(bit >= 0);
+    }
     return Mask(1) << bit;
   }
 
-  template <typename Bit, typename... Bits>
-  static Mask MakeMask(Bit first_bit, Bits... other_bits) {
-    return MakeMask(first_bit) | MakeMask(other_bits...);
-  }
-
   Mask mask_ = 0;
 };
 
-using BitMask64 = BitMask<uint64_t>;
+// Ensure that the BitMask constructor can be evaluated as constexpr.
+static_assert(BitMask<int>({1, 2, 3}).mask() == 0x0e);
 
 }  // namespace quic
 
diff --git a/quiche/quic/core/quic_utils_test.cc b/quiche/quic/core/quic_utils_test.cc
index a391cc2..29401fb 100644
--- a/quiche/quic/core/quic_utils_test.cc
+++ b/quiche/quic/core/quic_utils_test.cc
@@ -9,7 +9,6 @@
 #include "absl/base/macros.h"
 #include "absl/numeric/int128.h"
 #include "absl/strings/string_view.h"
-#include "quiche/quic/core/crypto/crypto_protocol.h"
 #include "quiche/quic/core/quic_connection_id.h"
 #include "quiche/quic/core/quic_types.h"
 #include "quiche/quic/platform/api/quic_test.h"
@@ -253,7 +252,8 @@
 };
 
 TEST(QuicBitMaskTest, EnumClass) {
-  BitMask64 mask(TestEnumClassBit::BIT_ZERO, TestEnumClassBit::BIT_TWO);
+  BitMask<TestEnumClassBit> mask(
+      {TestEnumClassBit::BIT_ZERO, TestEnumClassBit::BIT_TWO});
   EXPECT_TRUE(mask.IsSet(TestEnumClassBit::BIT_ZERO));
   EXPECT_FALSE(mask.IsSet(TestEnumClassBit::BIT_ONE));
   EXPECT_TRUE(mask.IsSet(TestEnumClassBit::BIT_TWO));
@@ -265,7 +265,7 @@
 }
 
 TEST(QuicBitMaskTest, Enum) {
-  BitMask64 mask(TEST_BIT_1, TEST_BIT_2);
+  BitMask<TestEnumBit> mask({TEST_BIT_1, TEST_BIT_2});
   EXPECT_FALSE(mask.IsSet(TEST_BIT_0));
   EXPECT_TRUE(mask.IsSet(TEST_BIT_1));
   EXPECT_TRUE(mask.IsSet(TEST_BIT_2));
@@ -277,9 +277,9 @@
 }
 
 TEST(QuicBitMaskTest, Integer) {
-  BitMask64 mask(1, 3);
+  BitMask<int> mask({1, 3});
   mask.Set(3);
-  mask.Set(5, 7, 9);
+  mask.Set({5, 7, 9});
   EXPECT_FALSE(mask.IsSet(0));
   EXPECT_TRUE(mask.IsSet(1));
   EXPECT_FALSE(mask.IsSet(2));
@@ -293,26 +293,26 @@
 }
 
 TEST(QuicBitMaskTest, NumBits) {
-  EXPECT_EQ(64u, BitMask64::NumBits());
-  EXPECT_EQ(32u, BitMask<uint32_t>::NumBits());
+  EXPECT_EQ(64u, BitMask<int>::NumBits());
+  EXPECT_EQ(32u, (BitMask<int, uint32_t>::NumBits()));
 }
 
 TEST(QuicBitMaskTest, Constructor) {
-  BitMask64 empty_mask;
+  BitMask<int> empty_mask;
   for (size_t bit = 0; bit < empty_mask.NumBits(); ++bit) {
     EXPECT_FALSE(empty_mask.IsSet(bit));
   }
 
-  BitMask64 mask(1, 3);
-  BitMask64 mask2 = mask;
-  BitMask64 mask3(mask2);
+  BitMask<int> mask({1, 3});
+  BitMask<int> mask2 = mask;
+  BitMask<int> mask3(mask2);
 
   for (size_t bit = 0; bit < mask.NumBits(); ++bit) {
     EXPECT_EQ(mask.IsSet(bit), mask2.IsSet(bit));
     EXPECT_EQ(mask.IsSet(bit), mask3.IsSet(bit));
   }
 
-  EXPECT_TRUE(std::is_trivially_copyable<BitMask64>::value);
+  EXPECT_TRUE(std::is_trivially_copyable<BitMask<int>>::value);
 }
 
 }  // namespace
diff --git a/quiche/quic/masque/masque_server_session.cc b/quiche/quic/masque/masque_server_session.cc
index 5a5f75b..c3005cf 100644
--- a/quiche/quic/masque/masque_server_session.cc
+++ b/quiche/quic/masque/masque_server_session.cc
@@ -357,7 +357,8 @@
                 << ") stream ID " << it->stream()->id() << " server "
                 << expected_target_server_address;
   QuicUdpSocketApi socket_api;
-  BitMask64 packet_info_interested(QuicUdpPacketInfoBit::PEER_ADDRESS);
+  QuicUdpPacketInfoBitMask packet_info_interested(
+      {QuicUdpPacketInfoBit::PEER_ADDRESS});
   char packet_buffer[1 + kMaxIncomingPacketSize];
   packet_buffer[0] = 0;  // context ID.
   char control_buffer[kDefaultUdpPacketControlBufferSize];