Commit ac267be5 by Evan Brown Committed by Copybara-Service

Add debug mode checks that element constructors/destructors don't make reentrant…

Add debug mode checks that element constructors/destructors don't make reentrant calls to raw_hash_set member functions.

PiperOrigin-RevId: 660889825
Change-Id: I02e0e364a5215431eddeeabde66531a95aa03f22
parent 9bd9a2d6
...@@ -732,6 +732,7 @@ cc_test( ...@@ -732,6 +732,7 @@ cc_test(
"//absl/memory", "//absl/memory",
"//absl/meta:type_traits", "//absl/meta:type_traits",
"//absl/strings", "//absl/strings",
"//absl/types:optional",
"@com_google_googletest//:gtest", "@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main", "@com_google_googletest//:gtest_main",
], ],
......
...@@ -791,6 +791,7 @@ absl_cc_test( ...@@ -791,6 +791,7 @@ absl_cc_test(
absl::log absl::log
absl::memory absl::memory
absl::node_hash_set absl::node_hash_set
absl::optional
absl::prefetch absl::prefetch
absl::raw_hash_set absl::raw_hash_set
absl::strings absl::strings
......
...@@ -536,6 +536,14 @@ static_assert(ctrl_t::kDeleted == static_cast<ctrl_t>(-2), ...@@ -536,6 +536,14 @@ static_assert(ctrl_t::kDeleted == static_cast<ctrl_t>(-2),
// See definition comment for why this is size 32. // See definition comment for why this is size 32.
ABSL_DLL extern const ctrl_t kEmptyGroup[32]; ABSL_DLL extern const ctrl_t kEmptyGroup[32];
// We use these sentinel capacity values in debug mode to indicate different
// classes of bugs.
enum InvalidCapacity : size_t {
kAboveMaxValidCapacity = ~size_t{} - 100,
// Used for reentrancy assertions.
kInvalidReentrance,
};
// Returns a pointer to a control byte group that can be used by empty tables. // Returns a pointer to a control byte group that can be used by empty tables.
inline ctrl_t* EmptyGroup() { inline ctrl_t* EmptyGroup() {
// Const must be cast away here; no uses of this function will actually write // Const must be cast away here; no uses of this function will actually write
...@@ -1376,7 +1384,8 @@ class CommonFields : public CommonFieldsGenerationInfo { ...@@ -1376,7 +1384,8 @@ class CommonFields : public CommonFieldsGenerationInfo {
// The total number of available slots. // The total number of available slots.
size_t capacity() const { return capacity_; } size_t capacity() const { return capacity_; }
void set_capacity(size_t c) { void set_capacity(size_t c) {
assert(c == 0 || IsValidCapacity(c)); // We allow setting above the max valid capacity for debugging purposes.
assert(c == 0 || IsValidCapacity(c) || c > kAboveMaxValidCapacity);
capacity_ = c; capacity_ = c;
} }
...@@ -1444,6 +1453,20 @@ class CommonFields : public CommonFieldsGenerationInfo { ...@@ -1444,6 +1453,20 @@ class CommonFields : public CommonFieldsGenerationInfo {
std::count(control(), control() + capacity(), ctrl_t::kDeleted)); std::count(control(), control() + capacity(), ctrl_t::kDeleted));
} }
// Helper to enable sanitizer mode validation to protect against reentrant
// calls during element constructor/destructor.
template <typename F>
void RunWithReentrancyGuard(F f) {
#ifdef NDEBUG
f();
return;
#endif
const size_t cap = capacity();
set_capacity(kInvalidReentrance);
f();
set_capacity(cap);
}
private: private:
// We store the has_infoz bit in the lowest bit of size_. // We store the has_infoz bit in the lowest bit of size_.
static constexpr size_t HasInfozShift() { return 1; } static constexpr size_t HasInfozShift() { return 1; }
...@@ -2874,6 +2897,7 @@ class raw_hash_set { ...@@ -2874,6 +2897,7 @@ class raw_hash_set {
size_t max_size() const { return (std::numeric_limits<size_t>::max)(); } size_t max_size() const { return (std::numeric_limits<size_t>::max)(); }
ABSL_ATTRIBUTE_REINITIALIZES void clear() { ABSL_ATTRIBUTE_REINITIALIZES void clear() {
AssertValidCapacity();
// Iterating over this container is O(bucket_count()). When bucket_count() // Iterating over this container is O(bucket_count()). When bucket_count()
// is much greater than size(), iteration becomes prohibitively expensive. // is much greater than size(), iteration becomes prohibitively expensive.
// For clear() it is more important to reuse the allocated array when the // For clear() it is more important to reuse the allocated array when the
...@@ -3127,6 +3151,7 @@ class raw_hash_set { ...@@ -3127,6 +3151,7 @@ class raw_hash_set {
// This overload is necessary because otherwise erase<K>(const K&) would be // This overload is necessary because otherwise erase<K>(const K&) would be
// a better match if non-const iterator is passed as an argument. // a better match if non-const iterator is passed as an argument.
void erase(iterator it) { void erase(iterator it) {
AssertValidCapacity();
AssertIsFull(it.control(), it.generation(), it.generation_ptr(), "erase()"); AssertIsFull(it.control(), it.generation(), it.generation_ptr(), "erase()");
destroy(it.slot()); destroy(it.slot());
if (is_soo()) { if (is_soo()) {
...@@ -3138,6 +3163,7 @@ class raw_hash_set { ...@@ -3138,6 +3163,7 @@ class raw_hash_set {
iterator erase(const_iterator first, iterator erase(const_iterator first,
const_iterator last) ABSL_ATTRIBUTE_LIFETIME_BOUND { const_iterator last) ABSL_ATTRIBUTE_LIFETIME_BOUND {
AssertValidCapacity();
// We check for empty first because ClearBackingArray requires that // We check for empty first because ClearBackingArray requires that
// capacity() > 0 as a precondition. // capacity() > 0 as a precondition.
if (empty()) return end(); if (empty()) return end();
...@@ -3193,6 +3219,7 @@ class raw_hash_set { ...@@ -3193,6 +3219,7 @@ class raw_hash_set {
} }
node_type extract(const_iterator position) { node_type extract(const_iterator position) {
AssertValidCapacity();
AssertIsFull(position.control(), position.inner_.generation(), AssertIsFull(position.control(), position.inner_.generation(),
position.inner_.generation_ptr(), "extract()"); position.inner_.generation_ptr(), "extract()");
auto node = CommonAccess::Transfer<node_type>(alloc_ref(), position.slot()); auto node = CommonAccess::Transfer<node_type>(alloc_ref(), position.slot());
...@@ -3325,13 +3352,13 @@ class raw_hash_set { ...@@ -3325,13 +3352,13 @@ class raw_hash_set {
template <class K = key_type> template <class K = key_type>
iterator find(const key_arg<K>& key, iterator find(const key_arg<K>& key,
size_t hash) ABSL_ATTRIBUTE_LIFETIME_BOUND { size_t hash) ABSL_ATTRIBUTE_LIFETIME_BOUND {
AssertHashEqConsistent(key); AssertOnFind(key);
if (is_soo()) return find_soo(key); if (is_soo()) return find_soo(key);
return find_non_soo(key, hash); return find_non_soo(key, hash);
} }
template <class K = key_type> template <class K = key_type>
iterator find(const key_arg<K>& key) ABSL_ATTRIBUTE_LIFETIME_BOUND { iterator find(const key_arg<K>& key) ABSL_ATTRIBUTE_LIFETIME_BOUND {
AssertHashEqConsistent(key); AssertOnFind(key);
if (is_soo()) return find_soo(key); if (is_soo()) return find_soo(key);
prefetch_heap_block(); prefetch_heap_block();
return find_non_soo(key, hash_ref()(key)); return find_non_soo(key, hash_ref()(key));
...@@ -3476,16 +3503,19 @@ class raw_hash_set { ...@@ -3476,16 +3503,19 @@ class raw_hash_set {
slot_type&& slot; slot_type&& slot;
}; };
// TODO(b/303305702): re-enable reentrant validation.
template <typename... Args> template <typename... Args>
inline void construct(slot_type* slot, Args&&... args) { inline void construct(slot_type* slot, Args&&... args) {
common().RunWithReentrancyGuard([&] {
PolicyTraits::construct(&alloc_ref(), slot, std::forward<Args>(args)...); PolicyTraits::construct(&alloc_ref(), slot, std::forward<Args>(args)...);
});
} }
inline void destroy(slot_type* slot) { inline void destroy(slot_type* slot) {
PolicyTraits::destroy(&alloc_ref(), slot); common().RunWithReentrancyGuard(
[&] { PolicyTraits::destroy(&alloc_ref(), slot); });
} }
inline void transfer(slot_type* to, slot_type* from) { inline void transfer(slot_type* to, slot_type* from) {
PolicyTraits::transfer(&alloc_ref(), to, from); common().RunWithReentrancyGuard(
[&] { PolicyTraits::transfer(&alloc_ref(), to, from); });
} }
// TODO(b/289225379): consider having a helper class that has the impls for // TODO(b/289225379): consider having a helper class that has the impls for
...@@ -3690,15 +3720,23 @@ class raw_hash_set { ...@@ -3690,15 +3720,23 @@ class raw_hash_set {
static slot_type* to_slot(void* buf) { return static_cast<slot_type*>(buf); } static slot_type* to_slot(void* buf) { return static_cast<slot_type*>(buf); }
// Requires that lhs does not have a full SOO slot. // Requires that lhs does not have a full SOO slot.
static void move_common(bool that_is_full_soo, allocator_type& rhs_alloc, static void move_common(bool rhs_is_full_soo, allocator_type& rhs_alloc,
CommonFields& lhs, CommonFields&& rhs) { CommonFields& lhs, CommonFields&& rhs) {
if (PolicyTraits::transfer_uses_memcpy() || !that_is_full_soo) { if (PolicyTraits::transfer_uses_memcpy() || !rhs_is_full_soo) {
lhs = std::move(rhs); lhs = std::move(rhs);
} else { } else {
lhs.move_non_heap_or_soo_fields(rhs); lhs.move_non_heap_or_soo_fields(rhs);
// TODO(b/303305702): add reentrancy guard. #ifndef NDEBUG
const size_t rhs_capacity = rhs.capacity();
rhs.set_capacity(kInvalidReentrance);
#endif
lhs.RunWithReentrancyGuard([&] {
PolicyTraits::transfer(&rhs_alloc, to_slot(lhs.soo_data()), PolicyTraits::transfer(&rhs_alloc, to_slot(lhs.soo_data()),
to_slot(rhs.soo_data())); to_slot(rhs.soo_data()));
});
#ifndef NDEBUG
rhs.set_capacity(rhs_capacity);
#endif
} }
} }
...@@ -3831,11 +3869,28 @@ class raw_hash_set { ...@@ -3831,11 +3869,28 @@ class raw_hash_set {
} }
protected: protected:
// Asserts for correctness that we run on find/find_or_prepare_insert.
template <class K>
void AssertOnFind(ABSL_ATTRIBUTE_UNUSED const K& key) {
#ifdef NDEBUG
return;
#endif
AssertHashEqConsistent(key);
AssertValidCapacity();
}
// Asserts that the capacity is not a sentinel invalid value.
// TODO(b/296061262): also add asserts for moved-from and destroyed states.
void AssertValidCapacity() const {
assert(capacity() != kInvalidReentrance &&
"reentrant container access during element construction/destruction "
"is not allowed.");
}
// Asserts that hash and equal functors provided by the user are consistent, // Asserts that hash and equal functors provided by the user are consistent,
// meaning that `eq(k1, k2)` implies `hash(k1)==hash(k2)`. // meaning that `eq(k1, k2)` implies `hash(k1)==hash(k2)`.
template <class K> template <class K>
void AssertHashEqConsistent(ABSL_ATTRIBUTE_UNUSED const K& key) { void AssertHashEqConsistent(const K& key) {
#ifndef NDEBUG
if (empty()) return; if (empty()) return;
const size_t hash_of_arg = hash_ref()(key); const size_t hash_of_arg = hash_ref()(key);
...@@ -3852,13 +3907,13 @@ class raw_hash_set { ...@@ -3852,13 +3907,13 @@ class raw_hash_set {
// In this case, we're going to crash. Do a couple of other checks for // In this case, we're going to crash. Do a couple of other checks for
// idempotence issues. Recalculating hash/eq here is also convenient for // idempotence issues. Recalculating hash/eq here is also convenient for
// debugging with gdb/lldb. // debugging with gdb/lldb.
const size_t once_more_hash_arg = hash_ref()(key); ABSL_ATTRIBUTE_UNUSED const size_t once_more_hash_arg = hash_ref()(key);
assert(hash_of_arg == once_more_hash_arg && "hash is not idempotent."); assert(hash_of_arg == once_more_hash_arg && "hash is not idempotent.");
const size_t once_more_hash_slot = ABSL_ATTRIBUTE_UNUSED const size_t once_more_hash_slot =
PolicyTraits::apply(HashElement{hash_ref()}, element); PolicyTraits::apply(HashElement{hash_ref()}, element);
assert(hash_of_slot == once_more_hash_slot && assert(hash_of_slot == once_more_hash_slot &&
"hash is not idempotent."); "hash is not idempotent.");
const bool once_more_eq = ABSL_ATTRIBUTE_UNUSED const bool once_more_eq =
PolicyTraits::apply(EqualElement<K>{key, eq_ref()}, element); PolicyTraits::apply(EqualElement<K>{key, eq_ref()}, element);
assert(is_key_equal == once_more_eq && "equality is not idempotent."); assert(is_key_equal == once_more_eq && "equality is not idempotent.");
} }
...@@ -3874,7 +3929,6 @@ class raw_hash_set { ...@@ -3874,7 +3929,6 @@ class raw_hash_set {
// We only do validation for small tables so that it's constant time. // We only do validation for small tables so that it's constant time.
if (capacity() > 16) return; if (capacity() > 16) return;
IterateOverFullSlots(common(), slot_array(), assert_consistent); IterateOverFullSlots(common(), slot_array(), assert_consistent);
#endif
} }
// Attempts to find `key` in the table; if it isn't found, returns an iterator // Attempts to find `key` in the table; if it isn't found, returns an iterator
...@@ -3882,7 +3936,7 @@ class raw_hash_set { ...@@ -3882,7 +3936,7 @@ class raw_hash_set {
// `key`'s H2. Returns a bool indicating whether an insertion can take place. // `key`'s H2. Returns a bool indicating whether an insertion can take place.
template <class K> template <class K>
std::pair<iterator, bool> find_or_prepare_insert(const K& key) { std::pair<iterator, bool> find_or_prepare_insert(const K& key) {
AssertHashEqConsistent(key); AssertOnFind(key);
if (is_soo()) return find_or_prepare_insert_soo(key); if (is_soo()) return find_or_prepare_insert_soo(key);
return find_or_prepare_insert_non_soo(key); return find_or_prepare_insert_non_soo(key);
} }
......
...@@ -62,6 +62,7 @@ ...@@ -62,6 +62,7 @@
#include "absl/meta/type_traits.h" #include "absl/meta/type_traits.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "absl/types/optional.h"
namespace absl { namespace absl {
ABSL_NAMESPACE_BEGIN ABSL_NAMESPACE_BEGIN
...@@ -3594,6 +3595,72 @@ TEST(Iterator, InconsistentHashEqFunctorsValidation) { ...@@ -3594,6 +3595,72 @@ TEST(Iterator, InconsistentHashEqFunctorsValidation) {
"hash/eq functors are inconsistent."); "hash/eq functors are inconsistent.");
} }
struct ConstructCaller {
explicit ConstructCaller(int v) : val(v) {}
ConstructCaller(int v, absl::FunctionRef<void()> func) : val(v) { func(); }
template <typename H>
friend H AbslHashValue(H h, const ConstructCaller& d) {
return H::combine(std::move(h), d.val);
}
bool operator==(const ConstructCaller& c) const { return val == c.val; }
int val;
};
struct DestroyCaller {
explicit DestroyCaller(int v) : val(v) {}
DestroyCaller(int v, absl::FunctionRef<void()> func)
: val(v), destroy_func(func) {}
DestroyCaller(DestroyCaller&& that)
: val(that.val), destroy_func(std::move(that.destroy_func)) {
that.Deactivate();
}
~DestroyCaller() {
if (destroy_func) (*destroy_func)();
}
void Deactivate() { destroy_func = absl::nullopt; }
template <typename H>
friend H AbslHashValue(H h, const DestroyCaller& d) {
return H::combine(std::move(h), d.val);
}
bool operator==(const DestroyCaller& d) const { return val == d.val; }
int val;
absl::optional<absl::FunctionRef<void()>> destroy_func;
};
TEST(Table, ReentrantCallsFail) {
#ifdef NDEBUG
GTEST_SKIP() << "Reentrant checks only enabled in debug mode.";
#else
{
ValueTable<ConstructCaller> t;
t.insert(ConstructCaller{0});
auto erase_begin = [&] { t.erase(t.begin()); };
EXPECT_DEATH_IF_SUPPORTED(t.emplace(1, erase_begin), "");
}
{
ValueTable<DestroyCaller> t;
t.insert(DestroyCaller{0});
auto find_0 = [&] { t.find(DestroyCaller{0}); };
t.insert(DestroyCaller{1, find_0});
for (int i = 10; i < 20; ++i) t.insert(DestroyCaller{i});
EXPECT_DEATH_IF_SUPPORTED(t.clear(), "");
for (auto& elem : t) elem.Deactivate();
}
{
ValueTable<DestroyCaller> t;
t.insert(DestroyCaller{0});
auto insert_1 = [&] { t.insert(DestroyCaller{1}); };
t.insert(DestroyCaller{1, insert_1});
for (int i = 10; i < 20; ++i) t.insert(DestroyCaller{i});
EXPECT_DEATH_IF_SUPPORTED(t.clear(), "");
for (auto& elem : t) elem.Deactivate();
}
#endif
}
} // namespace } // namespace
} // namespace container_internal } // namespace container_internal
ABSL_NAMESPACE_END ABSL_NAMESPACE_END
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment