Commit 67491a41 by Ralf W. Grosse-Kunstleve Committed by Copybara-Service

Enable passing `absl::Span<bool>` and `absl::Span<const bool>`

PiperOrigin-RevId: 590290022
parent f37d4455
......@@ -387,7 +387,7 @@ namespace internal {
template <typename T>
static constexpr bool is_buffer_interface_compatible_type =
detail::is_same_ignoring_cvref<T, PyObject*>::value ||
std::is_arithmetic<T>::value ||
std::is_arithmetic<std::remove_cv_t<T>>::value ||
std::is_same<T, std::complex<float>>::value ||
std::is_same<T, std::complex<double>>::value;
......@@ -405,7 +405,8 @@ std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) {
if (PyObject_GetBuffer(src.ptr(), &view, flags) == 0) {
auto cleanup = absl::MakeCleanup([&view] { PyBuffer_Release(&view); });
if (view.ndim == 1 && view.strides[0] == sizeof(T) &&
buffer_info(&view, /*ownview=*/false).item_type_is_equivalent_to<T>()) {
buffer_info(&view, /*ownview=*/false)
.item_type_is_equivalent_to<std::remove_cv_t<T>>()) {
return {true, absl::MakeSpan(static_cast<T*>(view.buf), view.shape[0])};
}
} else {
......@@ -421,6 +422,29 @@ constexpr std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle /*src*/) {
return {false, absl::Span<T>()};
}
template <typename T,
typename std::enable_if<
!std::is_same<std::remove_cv_t<T>, bool>::value, int>::type = 0>
std::tuple<bool, absl::Span<T>> LoadSpanOpaqueVector(handle src) {
// Attempt to unwrap an opaque std::vector.
using value_type = std::remove_cv_t<T>;
type_caster_base<std::vector<value_type>> caster;
if (caster.load(src, false)) {
return {true,
absl::MakeSpan(static_cast<std::vector<value_type>&>(caster))};
}
return {false, absl::Span<T>()};
}
template <typename T,
typename std::enable_if<
std::is_same<std::remove_cv_t<T>, bool>::value, int>::type = 0>
std::tuple<bool, absl::Span<T>> LoadSpanOpaqueVector(handle src) {
// std::vector<bool> is special and cannot directly be converted to a Span
// (see https://en.cppreference.com/w/cpp/container/vector_bool).
return {false, absl::Span<T>()};
}
// Helper to determine whether T is a span.
template <typename T>
struct is_absl_span : std::false_type {};
......@@ -433,7 +457,7 @@ template <typename T>
struct type_caster<absl::Span<T>> {
public:
// The type referenced by the span, with const removed.
using value_type = typename std::remove_cv<T>::type;
using value_type = std::remove_cv_t<T>;
static_assert(!is_absl_span<value_type>::value,
"Nested absl spans are not supported.");
......@@ -479,19 +503,17 @@ struct type_caster<absl::Span<T>> {
std::tie(loaded, value_) = LoadSpanFromBuffer<T>(src);
if (loaded) return true;
// Attempt to unwrap an opaque std::vector.
type_caster_base<std::vector<value_type>> caster;
if (caster.load(src, false)) {
value_ = get_value(caster);
return true;
}
std::tie(loaded, value_) = LoadSpanOpaqueVector<T>(src);
if (loaded) return true;
// Attempt to convert a native sequence. If the is_base_of_v check passes,
// Attempt to convert a native sequence. If the is_base_of check passes,
// the elements do not require converting and pointers do not reference a
// temporary object owned by the element caster. Pointers to converted
// types are not allowed because they would result a dangling reference
// when the element caster is destroyed.
if (convert && std::is_const<T>::value &&
// See comment for ephemeral_storage_type below.
!std::is_same<T, const bool>::value &&
(!std::is_pointer<T>::value ||
std::is_base_of<type_caster_generic, make_caster<T>>::value)) {
list_caster_.emplace();
......@@ -512,12 +534,28 @@ struct type_caster<absl::Span<T>> {
}
private:
template <typename Caster>
// Unfortunately using std::vector as ephemeral_storage_type creates
// complications for std::vector<bool>
// (https://en.cppreference.com/w/cpp/container/vector_bool).
using ephemeral_storage_type = std::vector<value_type>;
template <
typename Caster, typename VT = value_type,
typename std::enable_if<!std::is_same<VT, bool>::value, int>::type = 0>
absl::Span<T> get_value(Caster& caster) {
return absl::MakeSpan(static_cast<std::vector<value_type>&>(caster));
return absl::MakeSpan(static_cast<ephemeral_storage_type&>(caster));
}
// This template specialization is needed to avoid compilation errors.
// The conditions in load() make this code unreachable.
template <
typename Caster, typename VT = value_type,
typename std::enable_if<std::is_same<VT, bool>::value, int>::type = 0>
absl::Span<T> get_value(Caster&) {
throw std::runtime_error("Expected to be unreachable.");
}
using ListCaster = list_caster<std::vector<value_type>, value_type>;
using ListCaster = list_caster<ephemeral_storage_type, value_type>;
absl::optional<ListCaster> list_caster_;
absl::Span<T> value_;
};
......
......@@ -279,6 +279,18 @@ std::string PassSpanPyObjectPtr(absl::Span<PyObject*> input_span) {
return result;
}
std::string PassSpanBool(absl::Span<bool> input_span) {
std::string result;
for (const auto& i : input_span) result += (i ? "t" : "f");
return result;
}
std::string PassSpanConstBool(absl::Span<const bool> input_span) {
std::string result;
for (const auto& i : input_span) result += (i ? "T" : "F");
return result;
}
struct ObjectForSpan {
explicit ObjectForSpan(int v) : value(v) {}
int value;
......@@ -404,6 +416,8 @@ PYBIND11_MODULE(absl_example, m) {
m.def("sum_span_const_complex128",
&SumSpanComplex<const std::complex<double>>, arg("input_span"));
m.def("pass_span_pyobject_ptr", &PassSpanPyObjectPtr, arg("span"));
m.def("pass_span_bool", &PassSpanBool, arg("span"));
m.def("pass_span_const_bool", &PassSpanConstBool, arg("span"));
// Span of objects.
class_<ObjectForSpan>(m, "ObjectForSpan")
......
......@@ -312,7 +312,7 @@ def make_read_only_numpy_array():
return values
def make_srided_numpy_array(stride):
def make_strided_numpy_array(stride):
return np.zeros(10, dtype=np.int32)[::stride]
......@@ -373,10 +373,10 @@ class AbslNumericSpanTest(parameterized.TestCase):
@parameterized.named_parameters(
('float_numpy', np.zeros(5, dtype=float)),
('two_d_numpy', np.zeros(
(5, 5), dtype=np.int32)), ('read_only', make_read_only_numpy_array()),
('strided_skip', make_srided_numpy_array(2)),
('strided_reverse', make_srided_numpy_array(-1)),
('two_d_numpy', np.zeros((5, 5), dtype=np.int32)),
('read_only', make_read_only_numpy_array()),
('strided_skip', make_strided_numpy_array(2)),
('strided_reverse', make_strided_numpy_array(-1)),
('non_supported_type', np.zeros(5, dtype=np.unicode_)),
('native_list', [0] * 5))
def test_fill_span_fails_from(self, values):
......@@ -397,6 +397,28 @@ class AbslNumericSpanTest(parameterized.TestCase):
arr = np.array([-3, 'four', 5.0], dtype=object)
self.assertEqual(absl_example.pass_span_pyobject_ptr(arr), '-3four5.0')
@parameterized.parameters(
([], ''),
([False], 'f'),
([True], 't'),
([False, True, True, False], 'fttf'),
)
def test_pass_span_bool(self, bools, expected):
arr = np.array(bools, dtype=bool)
s = absl_example.pass_span_bool(arr)
self.assertEqual(s, expected)
@parameterized.parameters(
([], ''),
([False], 'F'),
([True], 'T'),
([False, True, True, False], 'FTTF'),
)
def test_pass_span_const_bool(self, bools, expected):
arr = np.array(bools, dtype=bool)
s = absl_example.pass_span_const_bool(arr)
self.assertEqual(s, expected)
def make_native_list_of_objects():
return [absl_example.ObjectForSpan(3), absl_example.ObjectForSpan(5)]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment