Commit 2a18ba75 by Evan Brown Committed by Copybara-Service

Add sanitizer mode checks that element constructors/destructors don't make…

Add sanitizer mode checks that element constructors/destructors don't make reentrant calls to raw_hash_set member functions.

PiperOrigin-RevId: 573897598
Change-Id: If40c23ac3cd9fff315ee18774e27c480cbca3a81
parent d368d3d6
...@@ -686,8 +686,10 @@ cc_test( ...@@ -686,8 +686,10 @@ cc_test(
"//absl/base:config", "//absl/base:config",
"//absl/base:core_headers", "//absl/base:core_headers",
"//absl/base:prefetch", "//absl/base:prefetch",
"//absl/functional:function_ref",
"//absl/log", "//absl/log",
"//absl/strings", "//absl/strings",
"//absl/types:optional",
"@com_google_googletest//:gtest", "@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main", "@com_google_googletest//:gtest_main",
], ],
......
...@@ -743,10 +743,12 @@ absl_cc_test( ...@@ -743,10 +743,12 @@ absl_cc_test(
absl::core_headers absl::core_headers
absl::flat_hash_map absl::flat_hash_map
absl::flat_hash_set absl::flat_hash_set
absl::function_ref
absl::hash_function_defaults absl::hash_function_defaults
absl::hash_policy_testing absl::hash_policy_testing
absl::hashtable_debug absl::hashtable_debug
absl::log absl::log
absl::optional
absl::prefetch absl::prefetch
absl::raw_hash_set absl::raw_hash_set
absl::strings absl::strings
......
...@@ -249,6 +249,14 @@ inline void SanitizerUnpoisonObject(const T* object) { ...@@ -249,6 +249,14 @@ inline void SanitizerUnpoisonObject(const T* object) {
SanitizerUnpoisonMemoryRegion(object, sizeof(T)); SanitizerUnpoisonMemoryRegion(object, sizeof(T));
} }
template <typename Container, typename Alloc, typename F>
void RunWithReentrancyGuard(Container& c, Alloc& a, F f) {
SanitizerPoisonObject(&c);
if (!std::is_empty<Alloc>()) SanitizerUnpoisonObject(&a);
f();
SanitizerUnpoisonObject(&c);
}
namespace memory_internal { namespace memory_internal {
// If Pair is a standard-layout type, OffsetOf<Pair>::kFirst and // If Pair is a standard-layout type, OffsetOf<Pair>::kFirst and
......
...@@ -2143,7 +2143,7 @@ class raw_hash_set { ...@@ -2143,7 +2143,7 @@ class raw_hash_set {
alignas(slot_type) unsigned char raw[sizeof(slot_type)]; alignas(slot_type) unsigned char raw[sizeof(slot_type)];
slot_type* slot = reinterpret_cast<slot_type*>(&raw); slot_type* slot = reinterpret_cast<slot_type*>(&raw);
PolicyTraits::construct(&alloc_ref(), slot, std::forward<Args>(args)...); construct(slot, std::forward<Args>(args)...);
const auto& elem = PolicyTraits::element(slot); const auto& elem = PolicyTraits::element(slot);
return PolicyTraits::apply(InsertSlot<true>{*this, std::move(*slot)}, elem); return PolicyTraits::apply(InsertSlot<true>{*this, std::move(*slot)}, elem);
} }
...@@ -2248,7 +2248,7 @@ class raw_hash_set { ...@@ -2248,7 +2248,7 @@ class raw_hash_set {
// a better match if non-const iterator is passed as an argument. // a better match if non-const iterator is passed as an argument.
void erase(iterator it) { void erase(iterator it) {
AssertIsFull(it.ctrl_, it.generation(), it.generation_ptr(), "erase()"); AssertIsFull(it.ctrl_, it.generation(), it.generation_ptr(), "erase()");
PolicyTraits::destroy(&alloc_ref(), it.slot_); destroy(it.slot_);
erase_meta_only(it); erase_meta_only(it);
} }
...@@ -2541,10 +2541,9 @@ class raw_hash_set { ...@@ -2541,10 +2541,9 @@ class raw_hash_set {
std::pair<iterator, bool> operator()(const K& key, Args&&...) && { std::pair<iterator, bool> operator()(const K& key, Args&&...) && {
auto res = s.find_or_prepare_insert(key); auto res = s.find_or_prepare_insert(key);
if (res.second) { if (res.second) {
PolicyTraits::transfer(&s.alloc_ref(), s.slot_array() + res.first, s.transfer(s.slot_array() + res.first, &slot);
&slot);
} else if (do_destroy) { } else if (do_destroy) {
PolicyTraits::destroy(&s.alloc_ref(), &slot); s.destroy(&slot);
} }
return {s.iterator_at(res.first), res.second}; return {s.iterator_at(res.first), res.second};
} }
...@@ -2553,13 +2552,31 @@ class raw_hash_set { ...@@ -2553,13 +2552,31 @@ class raw_hash_set {
slot_type&& slot; slot_type&& slot;
}; };
// Helpers to enable sanitizer mode validation to protect against reentrant
// calls during element constructor/destructor.
template <typename... Args>
inline void construct(slot_type* slot, Args&&... args) {
RunWithReentrancyGuard(*this, alloc_ref(), [&] {
PolicyTraits::construct(&alloc_ref(), slot, std::forward<Args>(args)...);
});
}
inline void destroy(slot_type* slot) {
RunWithReentrancyGuard(*this, alloc_ref(),
[&] { PolicyTraits::destroy(&alloc_ref(), slot); });
}
inline void transfer(slot_type* to, slot_type* from) {
RunWithReentrancyGuard(*this, alloc_ref(), [&] {
PolicyTraits::transfer(&alloc_ref(), to, from);
});
}
inline void destroy_slots() { inline void destroy_slots() {
const size_t cap = capacity(); const size_t cap = capacity();
const ctrl_t* ctrl = control(); const ctrl_t* ctrl = control();
slot_type* slot = slot_array(); slot_type* slot = slot_array();
for (size_t i = 0; i != cap; ++i) { for (size_t i = 0; i != cap; ++i) {
if (IsFull(ctrl[i])) { if (IsFull(ctrl[i])) {
PolicyTraits::destroy(&alloc_ref(), slot + i); destroy(slot + i);
} }
} }
} }
...@@ -2622,7 +2639,7 @@ class raw_hash_set { ...@@ -2622,7 +2639,7 @@ class raw_hash_set {
size_t new_i = target.offset; size_t new_i = target.offset;
total_probe_length += target.probe_length; total_probe_length += target.probe_length;
SetCtrl(common(), new_i, H2(hash), sizeof(slot_type)); SetCtrl(common(), new_i, H2(hash), sizeof(slot_type));
PolicyTraits::transfer(&alloc_ref(), new_slots + new_i, old_slots + i); transfer(new_slots + new_i, old_slots + i);
} }
} }
if (old_capacity) { if (old_capacity) {
...@@ -2725,7 +2742,7 @@ class raw_hash_set { ...@@ -2725,7 +2742,7 @@ class raw_hash_set {
reserve(size); reserve(size);
for (iterator it = that.begin(); it != that.end(); ++it) { for (iterator it = that.begin(); it != that.end(); ++it) {
insert(std::move(PolicyTraits::element(it.slot_))); insert(std::move(PolicyTraits::element(it.slot_)));
PolicyTraits::destroy(&that.alloc_ref(), it.slot_); that.destroy(it.slot_);
} }
that.dealloc(); that.dealloc();
that.common() = CommonFields{}; that.common() = CommonFields{};
...@@ -2816,8 +2833,7 @@ class raw_hash_set { ...@@ -2816,8 +2833,7 @@ class raw_hash_set {
// POSTCONDITION: *m.iterator_at(i) == value_type(forward<Args>(args)...). // POSTCONDITION: *m.iterator_at(i) == value_type(forward<Args>(args)...).
template <class... Args> template <class... Args>
void emplace_at(size_t i, Args&&... args) { void emplace_at(size_t i, Args&&... args) {
PolicyTraits::construct(&alloc_ref(), slot_array() + i, construct(slot_array() + i, std::forward<Args>(args)...);
std::forward<Args>(args)...);
assert(PolicyTraits::apply(FindElement{*this}, *iterator_at(i)) == assert(PolicyTraits::apply(FindElement{*this}, *iterator_at(i)) ==
iterator_at(i) && iterator_at(i) &&
...@@ -2883,8 +2899,7 @@ class raw_hash_set { ...@@ -2883,8 +2899,7 @@ class raw_hash_set {
} }
static void transfer_slot_fn(void* set, void* dst, void* src) { static void transfer_slot_fn(void* set, void* dst, void* src) {
auto* h = static_cast<raw_hash_set*>(set); auto* h = static_cast<raw_hash_set*>(set);
PolicyTraits::transfer(&h->alloc_ref(), static_cast<slot_type*>(dst), h->transfer(static_cast<slot_type*>(dst), static_cast<slot_type*>(src));
static_cast<slot_type*>(src));
} }
// Note: dealloc_fn will only be used if we have a non-standard allocator. // Note: dealloc_fn will only be used if we have a non-standard allocator.
static void dealloc_fn(CommonFields& common, const PolicyFunctions&) { static void dealloc_fn(CommonFields& common, const PolicyFunctions&) {
......
...@@ -49,8 +49,10 @@ ...@@ -49,8 +49,10 @@
#include "absl/container/internal/hash_policy_testing.h" #include "absl/container/internal/hash_policy_testing.h"
#include "absl/container/internal/hashtable_debug.h" #include "absl/container/internal/hashtable_debug.h"
#include "absl/container/internal/test_allocator.h" #include "absl/container/internal/test_allocator.h"
#include "absl/functional/function_ref.h"
#include "absl/log/log.h" #include "absl/log/log.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "absl/types/optional.h"
namespace absl { namespace absl {
ABSL_NAMESPACE_BEGIN ABSL_NAMESPACE_BEGIN
...@@ -409,19 +411,15 @@ struct StringTable ...@@ -409,19 +411,15 @@ struct StringTable
using Base::Base; using Base::Base;
}; };
struct IntTable template <typename T>
: raw_hash_set<IntPolicy, hash_default_hash<int64_t>, struct ValueTable : raw_hash_set<ValuePolicy<T>, hash_default_hash<T>,
std::equal_to<int64_t>, std::allocator<int64_t>> { std::equal_to<T>, std::allocator<T>> {
using Base = typename IntTable::raw_hash_set; using Base = typename ValueTable::raw_hash_set;
using Base::Base; using Base::Base;
}; };
struct Uint8Table using IntTable = ValueTable<int64_t>;
: raw_hash_set<Uint8Policy, hash_default_hash<uint8_t>, using Uint8Table = ValueTable<uint8_t>;
std::equal_to<uint8_t>, std::allocator<uint8_t>> {
using Base = typename Uint8Table::raw_hash_set;
using Base::Base;
};
template <typename T> template <typename T>
struct CustomAlloc : std::allocator<T> { struct CustomAlloc : std::allocator<T> {
...@@ -2489,6 +2487,72 @@ using RawHashSetAlloc = raw_hash_set<IntPolicy, hash_default_hash<int64_t>, ...@@ -2489,6 +2487,72 @@ using RawHashSetAlloc = raw_hash_set<IntPolicy, hash_default_hash<int64_t>,
TEST(Table, AllocatorPropagation) { TestAllocPropagation<RawHashSetAlloc>(); } TEST(Table, AllocatorPropagation) { TestAllocPropagation<RawHashSetAlloc>(); }
struct ConstructCaller {
explicit ConstructCaller(int v) : val(v) {}
ConstructCaller(int v, absl::FunctionRef<void()> func) : val(v) { func(); }
template <typename H>
friend H AbslHashValue(H h, const ConstructCaller& d) {
return H::combine(std::move(h), d.val);
}
bool operator==(const ConstructCaller& c) const { return val == c.val; }
int val;
};
struct DestroyCaller {
explicit DestroyCaller(int v) : val(v) {}
DestroyCaller(int v, absl::FunctionRef<void()> func)
: val(v), destroy_func(func) {}
DestroyCaller(DestroyCaller&& that)
: val(that.val), destroy_func(std::move(that.destroy_func)) {
that.Deactivate();
}
~DestroyCaller() {
if (destroy_func) (*destroy_func)();
}
void Deactivate() { destroy_func = absl::nullopt; }
template <typename H>
friend H AbslHashValue(H h, const DestroyCaller& d) {
return H::combine(std::move(h), d.val);
}
bool operator==(const DestroyCaller& d) const { return val == d.val; }
int val;
absl::optional<absl::FunctionRef<void()>> destroy_func;
};
#if defined(ABSL_HAVE_ADDRESS_SANITIZER) || defined(ABSL_HAVE_MEMORY_SANITIZER)
TEST(Table, ReentrantCallsFail) {
constexpr const char* kDeathMessage =
"use-after-poison|use-of-uninitialized-value";
{
ValueTable<ConstructCaller> t;
t.insert(ConstructCaller{0});
auto erase_begin = [&] { t.erase(t.begin()); };
EXPECT_DEATH_IF_SUPPORTED(t.emplace(1, erase_begin), kDeathMessage);
}
{
ValueTable<DestroyCaller> t;
t.insert(DestroyCaller{0});
auto find_0 = [&] { t.find(DestroyCaller{0}); };
t.insert(DestroyCaller{1, find_0});
for (int i = 10; i < 20; ++i) t.insert(DestroyCaller{i});
EXPECT_DEATH_IF_SUPPORTED(t.clear(), kDeathMessage);
for (auto& elem : t) elem.Deactivate();
}
{
ValueTable<DestroyCaller> t;
t.insert(DestroyCaller{0});
auto insert_1 = [&] { t.insert(DestroyCaller{1}); };
t.insert(DestroyCaller{1, insert_1});
for (int i = 10; i < 20; ++i) t.insert(DestroyCaller{i});
EXPECT_DEATH_IF_SUPPORTED(t.clear(), kDeathMessage);
for (auto& elem : t) elem.Deactivate();
}
}
#endif
} // namespace } // namespace
} // namespace container_internal } // namespace container_internal
ABSL_NAMESPACE_END ABSL_NAMESPACE_END
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment