Commit 2812af91 authored by Abseil Team, committed by Copybara-Service

Avoid extra `& msbs` on every iteration over the mask for GroupPortableImpl.

PiperOrigin-RevId: 602974812
Change-Id: Ic35b41e321b9456a8ddd83470ee2eb07c51e3180
parent 0aefaf7f
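Context for the change (editorial aside, not part of the diff): GroupPortableImpl derives its masks with byte-wise arithmetic that leaves at most the most significant bit of each byte set, so re-normalizing with `& msbs` on every `operator++` is wasted work there. GroupAArch64Impl, by contrast, builds masks from NEON byte compares (`vceq_u8`, `vcge_s8`), which set all eight bits of a matching byte. The new `NullifyBitsOnIteration` template parameter makes the per-iteration masking opt-in for the NEON case only. A small sketch with assumed example values:

    #include <cstdint>

    // Assumed example masks for a group where slots 2 and 3 match; the values
    // are illustrative, not taken from the diff itself.
    constexpr uint64_t kMsbs8Bytes = 0x8080808080808080ULL;
    constexpr uint64_t kPortableMatchMask = 0x0000000080800000ULL;  // 0x80 bytes
    constexpr uint64_t kNeonMatchMask = 0x00000000ffff0000ULL;      // 0xff bytes

    // Reducing the NEON-style mask to byte MSBs recovers the portable-style
    // mask; this is exactly what the opt-in per-iteration `& kMsbs8Bytes` does.
    static_assert((kNeonMatchMask & kMsbs8Bytes) == kPortableMatchMask,
                  "nullifying extra bits yields the MSB-only form");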
absl/container/internal/raw_hash_set.h
@@ -374,6 +374,9 @@ uint32_t TrailingZeros(T x) {
   return static_cast<uint32_t>(countr_zero(x));
 }

+// 8 bytes bitmask with most significant bit set for every byte.
+constexpr uint64_t kMsbs8Bytes = 0x8080808080808080ULL;
+
 // An abstract bitmask, such as that emitted by a SIMD instruction.
 //
 // Specifically, this type implements a simple bitset whose representation is
@@ -423,27 +426,35 @@ class NonIterableBitMask {
 // an ordinary 16-bit bitset occupying the low 16 bits of `mask`. When
 // `SignificantBits` is 8 and `Shift` is 3, abstract bits are represented as
 // the bytes `0x00` and `0x80`, and it occupies all 64 bits of the bitmask.
+// If NullifyBitsOnIteration is true (only allowed for Shift == 3), a
+// non-zero abstract bit is allowed to have additional bits set
+// (e.g., `0xff`, `0x83` and `0x9c` are ok, but `0x6f` is not).
 //
 // For example:
 //   for (int i : BitMask<uint32_t, 16>(0b101)) -> yields 0, 2
 //   for (int i : BitMask<uint64_t, 8, 3>(0x0000000080800000)) -> yields 2, 3
-template <class T, int SignificantBits, int Shift = 0>
+template <class T, int SignificantBits, int Shift = 0,
+          bool NullifyBitsOnIteration = false>
 class BitMask : public NonIterableBitMask<T, SignificantBits, Shift> {
   using Base = NonIterableBitMask<T, SignificantBits, Shift>;
   static_assert(std::is_unsigned<T>::value, "");
   static_assert(Shift == 0 || Shift == 3, "");
+  static_assert(!NullifyBitsOnIteration || Shift == 3, "");

  public:
-  explicit BitMask(T mask) : Base(mask) {}
+  explicit BitMask(T mask) : Base(mask) {
+    if (Shift == 3 && !NullifyBitsOnIteration) {
+      assert(this->mask_ == (this->mask_ & kMsbs8Bytes));
+    }
+  }
+
   // BitMask is an iterator over the indices of its abstract bits.
   using value_type = int;
   using iterator = BitMask;
   using const_iterator = BitMask;

   BitMask& operator++() {
-    if (Shift == 3) {
-      constexpr uint64_t msbs = 0x8080808080808080ULL;
-      this->mask_ &= msbs;
+    if (Shift == 3 && NullifyBitsOnIteration) {
+      this->mask_ &= kMsbs8Bytes;
     }
     this->mask_ &= (this->mask_ - 1);
     return *this;
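Editorial aside: a minimal sketch of what Shift == 3 iteration does, using a hypothetical free function rather than Abseil's real iterator, and assuming the GCC/Clang `__builtin_ctzll` intrinsic:

    #include <cstdint>
    #include <vector>

    constexpr uint64_t kMsbs8Bytes = 0x8080808080808080ULL;

    // Collects the abstract-bit indices of `mask`, one byte per index,
    // mirroring BitMask<uint64_t, 8, 3, NullifyBitsOnIteration>. When
    // `nullify` is true, non-zero bytes may carry extra bits (e.g. 0xff from a
    // NEON compare) and are reduced to their MSBs on every step; when false,
    // the mask must already be MSB-only, as the new constructor asserts.
    std::vector<int> MatchedIndices(uint64_t mask, bool nullify) {
      std::vector<int> out;
      while (mask != 0) {
        if (nullify) mask &= kMsbs8Bytes;           // the per-iteration `& msbs`
        out.push_back(__builtin_ctzll(mask) >> 3);  // byte index of lowest bit
        mask &= mask - 1;                           // clear that abstract bit
      }
      return out;
    }

    // MatchedIndices(0x0000000080800000, /*nullify=*/false) yields {2, 3};
    // MatchedIndices(0x00000000ffff0000, /*nullify=*/true) also yields {2, 3}.

Without the nullification, a 0xff byte would be visited once per set bit, since `mask &= mask - 1` clears only the lowest bit; that is why the masking must happen on every step for NEON-produced masks.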
@@ -685,10 +696,11 @@ struct GroupAArch64Impl {
     ctrl = vld1_u8(reinterpret_cast<const uint8_t*>(pos));
   }

-  BitMask<uint64_t, kWidth, 3> Match(h2_t hash) const {
+  auto Match(h2_t hash) const {
     uint8x8_t dup = vdup_n_u8(hash);
     auto mask = vceq_u8(ctrl, dup);
-    return BitMask<uint64_t, kWidth, 3>(
+    return BitMask<uint64_t, kWidth, /*Shift=*/3,
+                   /*NullifyBitsOnIteration=*/true>(
         vget_lane_u64(vreinterpret_u64_u8(mask), 0));
   }
@@ -704,12 +716,13 @@ struct GroupAArch64Impl {
   // Returns a bitmask representing the positions of full slots.
   // Note: for `is_small()` tables group may contain the "same" slot twice:
   // original and mirrored.
-  BitMask<uint64_t, kWidth, 3> MaskFull() const {
+  auto MaskFull() const {
     uint64_t mask = vget_lane_u64(
         vreinterpret_u64_u8(vcge_s8(vreinterpret_s8_u8(ctrl),
                                     vdup_n_s8(static_cast<int8_t>(0)))),
         0);
-    return BitMask<uint64_t, kWidth, 3>(mask);
+    return BitMask<uint64_t, kWidth, /*Shift=*/3,
+                   /*NullifyBitsOnIteration=*/true>(mask);
   }

   NonIterableBitMask<uint64_t, kWidth, 3> MaskEmptyOrDeleted() const {
@@ -736,11 +749,10 @@ struct GroupAArch64Impl {
   void ConvertSpecialToEmptyAndFullToDeleted(ctrl_t* dst) const {
     uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(ctrl), 0);
-    constexpr uint64_t msbs = 0x8080808080808080ULL;
     constexpr uint64_t slsbs = 0x0202020202020202ULL;
     constexpr uint64_t midbs = 0x7e7e7e7e7e7e7e7eULL;
     auto x = slsbs & (mask >> 6);
-    auto res = (x + midbs) | msbs;
+    auto res = (x + midbs) | kMsbs8Bytes;
     little_endian::Store64(dst, res);
   }
@@ -768,30 +780,26 @@ struct GroupPortableImpl {
     //   v = 0x1716151413121110
     //   hash = 0x12
     //   retval = (v - lsbs) & ~v & msbs = 0x0000000080800000
-    constexpr uint64_t msbs = 0x8080808080808080ULL;
     constexpr uint64_t lsbs = 0x0101010101010101ULL;
     auto x = ctrl ^ (lsbs * hash);
-    return BitMask<uint64_t, kWidth, 3>((x - lsbs) & ~x & msbs);
+    return BitMask<uint64_t, kWidth, 3>((x - lsbs) & ~x & kMsbs8Bytes);
   }

   NonIterableBitMask<uint64_t, kWidth, 3> MaskEmpty() const {
-    constexpr uint64_t msbs = 0x8080808080808080ULL;
     return NonIterableBitMask<uint64_t, kWidth, 3>((ctrl & ~(ctrl << 6)) &
-                                                   msbs);
+                                                   kMsbs8Bytes);
   }

   // Returns a bitmask representing the positions of full slots.
   // Note: for `is_small()` tables group may contain the "same" slot twice:
   // original and mirrored.
   BitMask<uint64_t, kWidth, 3> MaskFull() const {
-    constexpr uint64_t msbs = 0x8080808080808080ULL;
-    return BitMask<uint64_t, kWidth, 3>((ctrl ^ msbs) & msbs);
+    return BitMask<uint64_t, kWidth, 3>((ctrl ^ kMsbs8Bytes) & kMsbs8Bytes);
   }

   NonIterableBitMask<uint64_t, kWidth, 3> MaskEmptyOrDeleted() const {
-    constexpr uint64_t msbs = 0x8080808080808080ULL;
     return NonIterableBitMask<uint64_t, kWidth, 3>((ctrl & ~(ctrl << 7)) &
-                                                   msbs);
+                                                   kMsbs8Bytes);
   }

   uint32_t CountLeadingEmptyOrDeleted() const {
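Editorial aside: the `(x - lsbs) & ~x & kMsbs8Bytes` expression in Match above is the standard SWAR "word contains byte n" test. The block below re-derives the hunk's own example (ctrl = 0x1716151413121110, hash = 0x12) and shows where the reported bit for slot 3 comes from: it is a borrow-induced false positive next to the real match at slot 2, which the table's subsequent key comparison filters out. The arithmetic is checkable at compile time:

    #include <cstdint>

    constexpr uint64_t kMsbs8Bytes = 0x8080808080808080ULL;
    constexpr uint64_t kLsbs8Bytes = 0x0101010101010101ULL;
    constexpr uint64_t kCtrl = 0x1716151413121110ULL;  // the hunk's example
    constexpr uint64_t kHash = 0x12;

    // XOR with the hash repeated into every byte zeroes the matching bytes:
    // byte 2 of x is 0x00 (a real match); byte 3 is 0x01.
    constexpr uint64_t kX = kCtrl ^ (kLsbs8Bytes * kHash);
    static_assert(kX == 0x0504070601000302ULL, "");

    // (x - lsbs) sets the MSB of every zero byte; the borrow out of byte 2
    // also turns byte 3 (0x01) into 0xff, producing the benign false positive.
    static_assert(((kX - kLsbs8Bytes) & ~kX & kMsbs8Bytes) ==
                      0x0000000080800000ULL,
                  "slots 2 and 3 are reported");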
@@ -803,9 +811,8 @@ struct GroupPortableImpl {
   }

   void ConvertSpecialToEmptyAndFullToDeleted(ctrl_t* dst) const {
-    constexpr uint64_t msbs = 0x8080808080808080ULL;
     constexpr uint64_t lsbs = 0x0101010101010101ULL;
-    auto x = ctrl & msbs;
+    auto x = ctrl & kMsbs8Bytes;
     auto res = (~x + (x >> 7)) & ~lsbs;
     little_endian::Store64(dst, res);
   }
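Editorial aside: a byte-level check of the portable conversion, assuming Abseil's control-byte encoding (full bytes are 0x00..0x7f; kEmpty = 0x80, kDeleted = 0xfe and kSentinel = 0xff all have the MSB set). The example control word below is made up for illustration:

    #include <cstdint>

    constexpr uint64_t kMsbs8Bytes = 0x8080808080808080ULL;
    constexpr uint64_t kLsbs8Bytes = 0x0101010101010101ULL;

    // Example group: byte 0 full (0x12), byte 1 empty (0x80), byte 2 deleted
    // (0xfe), remaining bytes full (0x00).
    constexpr uint64_t kCtrl = 0x0000000000fe8012ULL;
    constexpr uint64_t kX = kCtrl & kMsbs8Bytes;  // 0x0000000000808000

    // Special bytes: ~0x80 + 0x01 = 0x80 (kEmpty). Full bytes: ~0x00 + 0 =
    // 0xff, and clearing the low bit gives 0xfe (kDeleted). Since each byte
    // sums to at most 0x80 or 0xff, no carry ever crosses a byte boundary.
    static_assert(((~kX + (kX >> 7)) & ~kLsbs8Bytes) == 0xfefefefefe8080feULL,
                  "full -> kDeleted (0xfe), special -> kEmpty (0x80)");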
absl/container/internal/raw_hash_set_test.cc
@@ -15,6 +15,7 @@
 #include "absl/container/internal/raw_hash_set.h"

 #include <algorithm>
+#include <array>
 #include <atomic>
 #include <cmath>
 #include <cstddef>
@@ -75,6 +76,7 @@ struct RawHashSetTestOnlyAccess {
 namespace {

 using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
 using ::testing::Eq;
 using ::testing::Ge;
 using ::testing::Lt;
@@ -156,20 +158,66 @@ TEST(BitMask, Smoke) {
   EXPECT_THAT((BitMask<uint8_t, 8>(0xAA)), ElementsAre(1, 3, 5, 7));
 }

-TEST(BitMask, WithShift) {
+TEST(BitMask, WithShift_MatchPortable) {
   // See the non-SSE version of Group for details on what this math is for.
   uint64_t ctrl = 0x1716151413121110;
   uint64_t hash = 0x12;
-  constexpr uint64_t msbs = 0x8080808080808080ULL;
   constexpr uint64_t lsbs = 0x0101010101010101ULL;
   auto x = ctrl ^ (lsbs * hash);
-  uint64_t mask = (x - lsbs) & ~x & msbs;
+  uint64_t mask = (x - lsbs) & ~x & kMsbs8Bytes;
   EXPECT_EQ(0x0000000080800000, mask);

   BitMask<uint64_t, 8, 3> b(mask);
   EXPECT_EQ(*b, 2);
 }

+constexpr uint64_t kSome8BytesMask = /*  */ 0x8000808080008000ULL;
+constexpr uint64_t kSome8BytesMaskAllOnes = 0xff00ffffff00ff00ULL;
+constexpr auto kSome8BytesMaskBits = std::array<int, 5>{1, 3, 4, 5, 7};
+
+TEST(BitMask, WithShift_FullMask) {
+  EXPECT_THAT((BitMask<uint64_t, 8, 3>(kMsbs8Bytes)),
+              ElementsAre(0, 1, 2, 3, 4, 5, 6, 7));
+  EXPECT_THAT(
+      (BitMask<uint64_t, 8, 3, /*NullifyBitsOnIteration=*/true>(kMsbs8Bytes)),
+      ElementsAre(0, 1, 2, 3, 4, 5, 6, 7));
+  EXPECT_THAT(
+      (BitMask<uint64_t, 8, 3, /*NullifyBitsOnIteration=*/true>(~uint64_t{0})),
+      ElementsAre(0, 1, 2, 3, 4, 5, 6, 7));
+}
+
+TEST(BitMask, WithShift_EmptyMask) {
+  EXPECT_THAT((BitMask<uint64_t, 8, 3>(0)), ElementsAre());
+  EXPECT_THAT((BitMask<uint64_t, 8, 3, /*NullifyBitsOnIteration=*/true>(0)),
+              ElementsAre());
+}
+
+TEST(BitMask, WithShift_SomeMask) {
+  EXPECT_THAT((BitMask<uint64_t, 8, 3>(kSome8BytesMask)),
+              ElementsAreArray(kSome8BytesMaskBits));
+  EXPECT_THAT((BitMask<uint64_t, 8, 3, /*NullifyBitsOnIteration=*/true>(
+                  kSome8BytesMask)),
+              ElementsAreArray(kSome8BytesMaskBits));
+  EXPECT_THAT((BitMask<uint64_t, 8, 3, /*NullifyBitsOnIteration=*/true>(
+                  kSome8BytesMaskAllOnes)),
+              ElementsAreArray(kSome8BytesMaskBits));
+}
+
+TEST(BitMask, WithShift_SomeMaskExtraBitsForNullify) {
+  // Verify that adding extra bits into non-zero bytes is fine.
+  uint64_t extra_bits = 77;
+  for (int i = 0; i < 100; ++i) {
+    // Add extra bits, but keep zero bytes untouched.
+    uint64_t extra_mask = extra_bits & kSome8BytesMaskAllOnes;
+    EXPECT_THAT((BitMask<uint64_t, 8, 3, /*NullifyBitsOnIteration=*/true>(
+                    kSome8BytesMask | extra_mask)),
+                ElementsAreArray(kSome8BytesMaskBits))
+        << i << " " << extra_mask;
+    extra_bits = (extra_bits + 1) * 3;
+  }
+}
+
 TEST(BitMask, LeadingTrailing) {
   EXPECT_EQ((BitMask<uint32_t, 16>(0x00001a40).LeadingZeros()), 3);
   EXPECT_EQ((BitMask<uint32_t, 16>(0x00001a40).TrailingZeros()), 6);
...