Make quic::BitMask parametrized by the type of the enum used. Currently, we only parametrize individual BitMask methods, which allows the bit mask to be accessed via different enums (that is either a mistake, or something that should be explicitly called out by using static_cast). Parametrizing it by the enum improves type safety, and also allows us to add methods that would return an enum value (e.g. a max() method that I want to add in a follow-up CL). This CL also has some minor clean ups, notably making everything constexpr, and turning the constructor into an initializer list (thus making it consistent with, e.g., STL container types). PiperOrigin-RevId: 538841710
diff --git a/quiche/quic/core/batch_writer/quic_batch_writer_test.h b/quiche/quic/core/batch_writer/quic_batch_writer_test.h index d42867f..08610f0 100644 --- a/quiche/quic/core/batch_writer/quic_batch_writer_test.h +++ b/quiche/quic/core/batch_writer/quic_batch_writer_test.h
@@ -233,9 +233,9 @@ result.control_buffer = {&control_buffer_[0], sizeof(control_buffer_)}; QuicUdpSocketApi().ReadPacket( peer_socket_, - quic::BitMask64(QuicUdpPacketInfoBit::V4_SELF_IP, - QuicUdpPacketInfoBit::V6_SELF_IP, - QuicUdpPacketInfoBit::PEER_ADDRESS), + quic::QuicUdpPacketInfoBitMask({QuicUdpPacketInfoBit::V4_SELF_IP, + QuicUdpPacketInfoBit::V6_SELF_IP, + QuicUdpPacketInfoBit::PEER_ADDRESS}), &result); ASSERT_TRUE(result.ok); ASSERT_TRUE(
diff --git a/quiche/quic/core/quic_packet_reader.cc b/quiche/quic/core/quic_packet_reader.cc index eaa3441..8dd52a1 100644 --- a/quiche/quic/core/quic_packet_reader.cc +++ b/quiche/quic/core/quic_packet_reader.cc
@@ -49,13 +49,11 @@ // arriving at the host and now is considered part of the network delay. QuicTime now = clock.Now(); - BitMask64 info_bits{QuicUdpPacketInfoBit::DROPPED_PACKETS, - QuicUdpPacketInfoBit::PEER_ADDRESS, - QuicUdpPacketInfoBit::V4_SELF_IP, - QuicUdpPacketInfoBit::V6_SELF_IP, - QuicUdpPacketInfoBit::RECV_TIMESTAMP, - QuicUdpPacketInfoBit::TTL, - QuicUdpPacketInfoBit::GOOGLE_PACKET_HEADER}; + QuicUdpPacketInfoBitMask info_bits( + {QuicUdpPacketInfoBit::DROPPED_PACKETS, + QuicUdpPacketInfoBit::PEER_ADDRESS, QuicUdpPacketInfoBit::V4_SELF_IP, + QuicUdpPacketInfoBit::V6_SELF_IP, QuicUdpPacketInfoBit::RECV_TIMESTAMP, + QuicUdpPacketInfoBit::TTL, QuicUdpPacketInfoBit::GOOGLE_PACKET_HEADER}); if (GetQuicRestartFlag(quic_receive_ecn)) { QUIC_RESTART_FLAG_COUNT_N(quic_receive_ecn, 3, 3); info_bits.Set(QuicUdpPacketInfoBit::ECN);
diff --git a/quiche/quic/core/quic_udp_socket.cc b/quiche/quic/core/quic_udp_socket.cc index 1a35b72..8e35a2c 100644 --- a/quiche/quic/core/quic_udp_socket.cc +++ b/quiche/quic/core/quic_udp_socket.cc
@@ -32,7 +32,7 @@ void PopulatePacketInfoFromControlMessageBase( PlatformCmsghdr* cmsg, QuicUdpPacketInfo* packet_info, - BitMask64 packet_info_interested) { + QuicUdpPacketInfoBitMask packet_info_interested) { if (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_PKTINFO) { if (packet_info_interested.IsSet(QuicUdpPacketInfoBit::V6_SELF_IP)) { const in6_pktinfo* info = reinterpret_cast<in6_pktinfo*>(CMSG_DATA(cmsg));
diff --git a/quiche/quic/core/quic_udp_socket.h b/quiche/quic/core/quic_udp_socket.h index 08a2595..2541fb0 100644 --- a/quiche/quic/core/quic_udp_socket.h +++ b/quiche/quic/core/quic_udp_socket.h
@@ -35,12 +35,15 @@ TTL, // Read & Write ECN, // Read GOOGLE_PACKET_HEADER, // Read - NUM_BITS, - IS_GRO, // Read + IS_GRO, // Read + + // Must be the last value. + NUM_BITS }; +using QuicUdpPacketInfoBitMask = BitMask<QuicUdpPacketInfoBit>; static_assert(static_cast<size_t>(QuicUdpPacketInfoBit::NUM_BITS) <= - BitMask64::NumBits(), - "BitMask64 not wide enough to hold all bits."); + QuicUdpPacketInfoBitMask::NumBits(), + "QuicUdpPacketInfoBitMask not wide enough to hold all bits."); // BufferSpan points to an unowned buffer, copying this structure only copies // the pointer and length, not the buffer itself. @@ -60,7 +63,7 @@ // receiving. class QUIC_EXPORT_PRIVATE QuicUdpPacketInfo { public: - BitMask64 bitmask() const { return bitmask_; } + QuicUdpPacketInfoBitMask bitmask() const { return bitmask_; } void Reset() { bitmask_.ClearAll(); } @@ -159,7 +162,7 @@ } private: - BitMask64 bitmask_; + QuicUdpPacketInfoBitMask bitmask_; QuicPacketCount dropped_packets_; QuicIpAddress self_v4_ip_; QuicIpAddress self_v6_ip_; @@ -238,7 +241,8 @@ // // If |*result| is reused for subsequent ReadPacket() calls, caller needs to // call result->Reset() before each ReadPacket(). - void ReadPacket(QuicUdpSocketFd fd, BitMask64 packet_info_interested, + void ReadPacket(QuicUdpSocketFd fd, + QuicUdpPacketInfoBitMask packet_info_interested, ReadPacketResult* result); using ReadPacketResults = std::vector<ReadPacketResult>; @@ -247,7 +251,7 @@ // Return the number of elements populated into |*results|, note it is // possible for some of the populated elements to have ok=false. size_t ReadMultiplePackets(QuicUdpSocketFd fd, - BitMask64 packet_info_interested, + QuicUdpPacketInfoBitMask packet_info_interested, ReadPacketResults* results); // Write a packet to |fd|.
diff --git a/quiche/quic/core/quic_udp_socket_posix.inc b/quiche/quic/core/quic_udp_socket_posix.inc index deb406c..4d371db 100644 --- a/quiche/quic/core/quic_udp_socket_posix.inc +++ b/quiche/quic/core/quic_udp_socket_posix.inc
@@ -86,9 +86,9 @@ memcpy(&pktinfo->ipi6_addr, address_string.c_str(), address_string.length()); } -void PopulatePacketInfoFromControlMessage(struct cmsghdr* cmsg, - QuicUdpPacketInfo* packet_info, - BitMask64 packet_info_interested) { +void PopulatePacketInfoFromControlMessage( + struct cmsghdr* cmsg, QuicUdpPacketInfo* packet_info, + QuicUdpPacketInfoBitMask packet_info_interested) { #ifdef SOL_UDP if (packet_info_interested.IsSet(QuicUdpPacketInfoBit::IS_GRO) && cmsg->cmsg_level == SOL_UDP && cmsg->cmsg_type == UDP_GRO) { @@ -249,9 +249,9 @@ return 1 == select(1 + fd, &read_fds, nullptr, nullptr, &select_timeout); } -void QuicUdpSocketApi::ReadPacket(QuicUdpSocketFd fd, - BitMask64 packet_info_interested, - ReadPacketResult* result) { +void QuicUdpSocketApi::ReadPacket( + QuicUdpSocketFd fd, QuicUdpPacketInfoBitMask packet_info_interested, + ReadPacketResult* result) { result->ok = false; BufferSpan& packet_buffer = result->packet_buffer; BufferSpan& control_buffer = result->control_buffer; @@ -318,7 +318,7 @@ if (hdr.msg_controllen > 0) { for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&hdr); cmsg != nullptr; cmsg = CMSG_NXTHDR(&hdr, cmsg)) { - BitMask64 prior_bitmask = packet_info->bitmask(); + QuicUdpPacketInfoBitMask prior_bitmask = packet_info->bitmask(); PopulatePacketInfoFromControlMessage(cmsg, packet_info, packet_info_interested); if (packet_info->bitmask() == prior_bitmask) { @@ -331,9 +331,9 @@ result->ok = true; } -size_t QuicUdpSocketApi::ReadMultiplePackets(QuicUdpSocketFd fd, - BitMask64 packet_info_interested, - ReadPacketResults* results) { +size_t QuicUdpSocketApi::ReadMultiplePackets( + QuicUdpSocketFd fd, QuicUdpPacketInfoBitMask packet_info_interested, + ReadPacketResults* results) { #if defined(__linux__) && !defined(__ANDROID__) if (packet_info_interested.IsSet(QuicUdpPacketInfoBit::IS_GRO)) { size_t num_packets = 0;
diff --git a/quiche/quic/core/quic_udp_socket_win.inc b/quiche/quic/core/quic_udp_socket_win.inc index 429b8c7..3f8f89d 100644 --- a/quiche/quic/core/quic_udp_socket_win.inc +++ b/quiche/quic/core/quic_udp_socket_win.inc
@@ -111,9 +111,9 @@ return true; } -void QuicUdpSocketApi::ReadPacket(QuicUdpSocketFd fd, - BitMask64 packet_info_interested, - ReadPacketResult* result) { +void QuicUdpSocketApi::ReadPacket( + QuicUdpSocketFd fd, QuicUdpPacketInfoBitMask packet_info_interested, + ReadPacketResult* result) { result->ok = false; // WSARecvMsg is an extension to Windows Socket API that requires us to fetch @@ -175,7 +175,7 @@ if (hdr.Control.len > 0) { for (WSACMSGHDR* cmsg = WSA_CMSG_FIRSTHDR(&hdr); cmsg != nullptr; cmsg = WSA_CMSG_NXTHDR(&hdr, cmsg)) { - BitMask64 prior_bitmask = packet_info->bitmask(); + QuicUdpPacketInfoBitMask prior_bitmask = packet_info->bitmask(); PopulatePacketInfoFromControlMessageBase(cmsg, packet_info, packet_info_interested); if (packet_info->bitmask() == prior_bitmask) { @@ -188,9 +188,9 @@ result->ok = true; } -size_t QuicUdpSocketApi::ReadMultiplePackets(QuicUdpSocketFd fd, - BitMask64 packet_info_interested, - ReadPacketResults* results) { +size_t QuicUdpSocketApi::ReadMultiplePackets( + QuicUdpSocketFd fd, QuicUdpPacketInfoBitMask packet_info_interested, + ReadPacketResults* results) { size_t num_packets = 0; for (ReadPacketResult& result : *results) { result.ok = false;
diff --git a/quiche/quic/core/quic_utils.h b/quiche/quic/core/quic_utils.h index 79212fd..52108aa 100644 --- a/quiche/quic/core/quic_utils.h +++ b/quiche/quic/core/quic_utils.h
@@ -7,11 +7,12 @@ #include <cstddef> #include <cstdint> -#include <sstream> +#include <initializer_list> #include <string> #include <type_traits> #include "absl/numeric/int128.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "quiche/quic/core/crypto/quic_random.h" @@ -226,30 +227,31 @@ // Computes a SHA-256 hash and returns the raw bytes of the hash. QUIC_EXPORT_PRIVATE std::string RawSha256(absl::string_view input); -template <typename Mask> +// BitMask<Index, Mask> is a set of elements of type `Index` represented as a +// bitmask of an underlying integer type `Mask` (uint64_t by default). The +// underlying type has to be large enough to fit all possible values of `Index`. +template <typename Index, typename Mask = uint64_t> class QUIC_EXPORT_PRIVATE BitMask { public: - // explicit to prevent (incorrect) usage like "BitMask bitmask = 0;". - template <typename... Bits> - explicit BitMask(Bits... bits) { - mask_ = MakeMask(bits...); + explicit constexpr BitMask(std::initializer_list<Index> bits) { + for (Index bit : bits) { + mask_ |= MakeMask(bit); + } } BitMask() = default; BitMask(const BitMask& other) = default; BitMask& operator=(const BitMask& other) = default; - template <typename... Bits> - void Set(Bits... bits) { - mask_ |= MakeMask(bits...); + constexpr void Set(Index bit) { mask_ |= MakeMask(bit); } + + constexpr void Set(std::initializer_list<Index> bits) { + mask_ |= BitMask(bits).mask(); } - template <typename Bit> - bool IsSet(Bit bit) const { - return (MakeMask(bit) & mask_) != 0; - } + constexpr bool IsSet(Index bit) const { return (MakeMask(bit) & mask_) != 0; } - void ClearAll() { mask_ = 0; } + constexpr void ClearAll() { mask_ = 0; } static constexpr size_t NumBits() { return 8 * sizeof(Mask); } @@ -258,32 +260,35 @@ } std::string DebugString() const { - std::ostringstream oss; - oss << "0x" << std::hex << mask_; - return oss.str(); + return absl::StrCat("0x", absl::Hex(mask_)); } + constexpr Mask mask() const { return mask_; } + private: template <typename Bit> - static std::enable_if_t<std::is_enum<Bit>::value, Mask> MakeMask(Bit bit) { + static constexpr std::enable_if_t<std::is_enum_v<Bit>, Mask> MakeMask( + Bit bit) { using IntType = typename std::underlying_type<Bit>::type; - return Mask(1) << static_cast<IntType>(bit); + return MakeMask(static_cast<IntType>(bit)); } template <typename Bit> - static std::enable_if_t<!std::is_enum<Bit>::value, Mask> MakeMask(Bit bit) { + static constexpr std::enable_if_t<!std::is_enum_v<Bit>, Mask> MakeMask( + Bit bit) { + // We can't use QUICHE_DCHECK_LT here, since it doesn't work with constexpr. + QUICHE_DCHECK(bit < static_cast<Bit>(NumBits())); + if constexpr (std::is_signed_v<Bit>) { + QUICHE_DCHECK(bit >= 0); + } return Mask(1) << bit; } - template <typename Bit, typename... Bits> - static Mask MakeMask(Bit first_bit, Bits... other_bits) { - return MakeMask(first_bit) | MakeMask(other_bits...); - } - Mask mask_ = 0; }; -using BitMask64 = BitMask<uint64_t>; +// Ensure that the BitMask constructor can be evaluated as constexpr. +static_assert(BitMask<int>({1, 2, 3}).mask() == 0x0e); } // namespace quic
diff --git a/quiche/quic/core/quic_utils_test.cc b/quiche/quic/core/quic_utils_test.cc index a391cc2..29401fb 100644 --- a/quiche/quic/core/quic_utils_test.cc +++ b/quiche/quic/core/quic_utils_test.cc
@@ -9,7 +9,6 @@ #include "absl/base/macros.h" #include "absl/numeric/int128.h" #include "absl/strings/string_view.h" -#include "quiche/quic/core/crypto/crypto_protocol.h" #include "quiche/quic/core/quic_connection_id.h" #include "quiche/quic/core/quic_types.h" #include "quiche/quic/platform/api/quic_test.h" @@ -253,7 +252,8 @@ }; TEST(QuicBitMaskTest, EnumClass) { - BitMask64 mask(TestEnumClassBit::BIT_ZERO, TestEnumClassBit::BIT_TWO); + BitMask<TestEnumClassBit> mask( + {TestEnumClassBit::BIT_ZERO, TestEnumClassBit::BIT_TWO}); EXPECT_TRUE(mask.IsSet(TestEnumClassBit::BIT_ZERO)); EXPECT_FALSE(mask.IsSet(TestEnumClassBit::BIT_ONE)); EXPECT_TRUE(mask.IsSet(TestEnumClassBit::BIT_TWO)); @@ -265,7 +265,7 @@ } TEST(QuicBitMaskTest, Enum) { - BitMask64 mask(TEST_BIT_1, TEST_BIT_2); + BitMask<TestEnumBit> mask({TEST_BIT_1, TEST_BIT_2}); EXPECT_FALSE(mask.IsSet(TEST_BIT_0)); EXPECT_TRUE(mask.IsSet(TEST_BIT_1)); EXPECT_TRUE(mask.IsSet(TEST_BIT_2)); @@ -277,9 +277,9 @@ } TEST(QuicBitMaskTest, Integer) { - BitMask64 mask(1, 3); + BitMask<int> mask({1, 3}); mask.Set(3); - mask.Set(5, 7, 9); + mask.Set({5, 7, 9}); EXPECT_FALSE(mask.IsSet(0)); EXPECT_TRUE(mask.IsSet(1)); EXPECT_FALSE(mask.IsSet(2)); @@ -293,26 +293,26 @@ } TEST(QuicBitMaskTest, NumBits) { - EXPECT_EQ(64u, BitMask64::NumBits()); - EXPECT_EQ(32u, BitMask<uint32_t>::NumBits()); + EXPECT_EQ(64u, BitMask<int>::NumBits()); + EXPECT_EQ(32u, (BitMask<int, uint32_t>::NumBits())); } TEST(QuicBitMaskTest, Constructor) { - BitMask64 empty_mask; + BitMask<int> empty_mask; for (size_t bit = 0; bit < empty_mask.NumBits(); ++bit) { EXPECT_FALSE(empty_mask.IsSet(bit)); } - BitMask64 mask(1, 3); - BitMask64 mask2 = mask; - BitMask64 mask3(mask2); + BitMask<int> mask({1, 3}); + BitMask<int> mask2 = mask; + BitMask<int> mask3(mask2); for (size_t bit = 0; bit < mask.NumBits(); ++bit) { EXPECT_EQ(mask.IsSet(bit), mask2.IsSet(bit)); EXPECT_EQ(mask.IsSet(bit), mask3.IsSet(bit)); } - EXPECT_TRUE(std::is_trivially_copyable<BitMask64>::value); + EXPECT_TRUE(std::is_trivially_copyable<BitMask<int>>::value); } } // namespace
diff --git a/quiche/quic/masque/masque_server_session.cc b/quiche/quic/masque/masque_server_session.cc index 5a5f75b..c3005cf 100644 --- a/quiche/quic/masque/masque_server_session.cc +++ b/quiche/quic/masque/masque_server_session.cc
@@ -357,7 +357,8 @@ << ") stream ID " << it->stream()->id() << " server " << expected_target_server_address; QuicUdpSocketApi socket_api; - BitMask64 packet_info_interested(QuicUdpPacketInfoBit::PEER_ADDRESS); + QuicUdpPacketInfoBitMask packet_info_interested( + {QuicUdpPacketInfoBit::PEER_ADDRESS}); char packet_buffer[1 + kMaxIncomingPacketSize]; packet_buffer[0] = 0; // context ID. char control_buffer[kDefaultUdpPacketControlBufferSize];