Commit 67491a41 by Ralf W. Grosse-Kunstleve Committed by Copybara-Service

Enable passing `absl::Span<bool>` and `absl::Span<const bool>`

PiperOrigin-RevId: 590290022
parent f37d4455
...@@ -387,7 +387,7 @@ namespace internal { ...@@ -387,7 +387,7 @@ namespace internal {
template <typename T> template <typename T>
static constexpr bool is_buffer_interface_compatible_type = static constexpr bool is_buffer_interface_compatible_type =
detail::is_same_ignoring_cvref<T, PyObject*>::value || detail::is_same_ignoring_cvref<T, PyObject*>::value ||
std::is_arithmetic<T>::value || std::is_arithmetic<std::remove_cv_t<T>>::value ||
std::is_same<T, std::complex<float>>::value || std::is_same<T, std::complex<float>>::value ||
std::is_same<T, std::complex<double>>::value; std::is_same<T, std::complex<double>>::value;
...@@ -405,7 +405,8 @@ std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) { ...@@ -405,7 +405,8 @@ std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) {
if (PyObject_GetBuffer(src.ptr(), &view, flags) == 0) { if (PyObject_GetBuffer(src.ptr(), &view, flags) == 0) {
auto cleanup = absl::MakeCleanup([&view] { PyBuffer_Release(&view); }); auto cleanup = absl::MakeCleanup([&view] { PyBuffer_Release(&view); });
if (view.ndim == 1 && view.strides[0] == sizeof(T) && if (view.ndim == 1 && view.strides[0] == sizeof(T) &&
buffer_info(&view, /*ownview=*/false).item_type_is_equivalent_to<T>()) { buffer_info(&view, /*ownview=*/false)
.item_type_is_equivalent_to<std::remove_cv_t<T>>()) {
return {true, absl::MakeSpan(static_cast<T*>(view.buf), view.shape[0])}; return {true, absl::MakeSpan(static_cast<T*>(view.buf), view.shape[0])};
} }
} else { } else {
...@@ -421,6 +422,29 @@ constexpr std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle /*src*/) { ...@@ -421,6 +422,29 @@ constexpr std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle /*src*/) {
return {false, absl::Span<T>()}; return {false, absl::Span<T>()};
} }
template <typename T,
typename std::enable_if<
!std::is_same<std::remove_cv_t<T>, bool>::value, int>::type = 0>
std::tuple<bool, absl::Span<T>> LoadSpanOpaqueVector(handle src) {
// Attempt to unwrap an opaque std::vector.
using value_type = std::remove_cv_t<T>;
type_caster_base<std::vector<value_type>> caster;
if (caster.load(src, false)) {
return {true,
absl::MakeSpan(static_cast<std::vector<value_type>&>(caster))};
}
return {false, absl::Span<T>()};
}
template <typename T,
typename std::enable_if<
std::is_same<std::remove_cv_t<T>, bool>::value, int>::type = 0>
std::tuple<bool, absl::Span<T>> LoadSpanOpaqueVector(handle src) {
// std::vector<bool> is special and cannot directly be converted to a Span
// (see https://en.cppreference.com/w/cpp/container/vector_bool).
return {false, absl::Span<T>()};
}
// Helper to determine whether T is a span. // Helper to determine whether T is a span.
template <typename T> template <typename T>
struct is_absl_span : std::false_type {}; struct is_absl_span : std::false_type {};
...@@ -433,7 +457,7 @@ template <typename T> ...@@ -433,7 +457,7 @@ template <typename T>
struct type_caster<absl::Span<T>> { struct type_caster<absl::Span<T>> {
public: public:
// The type referenced by the span, with const removed. // The type referenced by the span, with const removed.
using value_type = typename std::remove_cv<T>::type; using value_type = std::remove_cv_t<T>;
static_assert(!is_absl_span<value_type>::value, static_assert(!is_absl_span<value_type>::value,
"Nested absl spans are not supported."); "Nested absl spans are not supported.");
...@@ -479,19 +503,17 @@ struct type_caster<absl::Span<T>> { ...@@ -479,19 +503,17 @@ struct type_caster<absl::Span<T>> {
std::tie(loaded, value_) = LoadSpanFromBuffer<T>(src); std::tie(loaded, value_) = LoadSpanFromBuffer<T>(src);
if (loaded) return true; if (loaded) return true;
// Attempt to unwrap an opaque std::vector. std::tie(loaded, value_) = LoadSpanOpaqueVector<T>(src);
type_caster_base<std::vector<value_type>> caster; if (loaded) return true;
if (caster.load(src, false)) {
value_ = get_value(caster);
return true;
}
// Attempt to convert a native sequence. If the is_base_of_v check passes, // Attempt to convert a native sequence. If the is_base_of check passes,
// the elements do not require converting and pointers do not reference a // the elements do not require converting and pointers do not reference a
// temporary object owned by the element caster. Pointers to converted // temporary object owned by the element caster. Pointers to converted
// types are not allowed because they would result a dangling reference // types are not allowed because they would result a dangling reference
// when the element caster is destroyed. // when the element caster is destroyed.
if (convert && std::is_const<T>::value && if (convert && std::is_const<T>::value &&
// See comment for ephemeral_storage_type below.
!std::is_same<T, const bool>::value &&
(!std::is_pointer<T>::value || (!std::is_pointer<T>::value ||
std::is_base_of<type_caster_generic, make_caster<T>>::value)) { std::is_base_of<type_caster_generic, make_caster<T>>::value)) {
list_caster_.emplace(); list_caster_.emplace();
...@@ -512,12 +534,28 @@ struct type_caster<absl::Span<T>> { ...@@ -512,12 +534,28 @@ struct type_caster<absl::Span<T>> {
} }
private: private:
template <typename Caster> // Unfortunately using std::vector as ephemeral_storage_type creates
// complications for std::vector<bool>
// (https://en.cppreference.com/w/cpp/container/vector_bool).
using ephemeral_storage_type = std::vector<value_type>;
template <
typename Caster, typename VT = value_type,
typename std::enable_if<!std::is_same<VT, bool>::value, int>::type = 0>
absl::Span<T> get_value(Caster& caster) { absl::Span<T> get_value(Caster& caster) {
return absl::MakeSpan(static_cast<std::vector<value_type>&>(caster)); return absl::MakeSpan(static_cast<ephemeral_storage_type&>(caster));
}
// This template specialization is needed to avoid compilation errors.
// The conditions in load() make this code unreachable.
template <
typename Caster, typename VT = value_type,
typename std::enable_if<std::is_same<VT, bool>::value, int>::type = 0>
absl::Span<T> get_value(Caster&) {
throw std::runtime_error("Expected to be unreachable.");
} }
using ListCaster = list_caster<std::vector<value_type>, value_type>; using ListCaster = list_caster<ephemeral_storage_type, value_type>;
absl::optional<ListCaster> list_caster_; absl::optional<ListCaster> list_caster_;
absl::Span<T> value_; absl::Span<T> value_;
}; };
......
...@@ -279,6 +279,18 @@ std::string PassSpanPyObjectPtr(absl::Span<PyObject*> input_span) { ...@@ -279,6 +279,18 @@ std::string PassSpanPyObjectPtr(absl::Span<PyObject*> input_span) {
return result; return result;
} }
std::string PassSpanBool(absl::Span<bool> input_span) {
std::string result;
for (const auto& i : input_span) result += (i ? "t" : "f");
return result;
}
std::string PassSpanConstBool(absl::Span<const bool> input_span) {
std::string result;
for (const auto& i : input_span) result += (i ? "T" : "F");
return result;
}
struct ObjectForSpan { struct ObjectForSpan {
explicit ObjectForSpan(int v) : value(v) {} explicit ObjectForSpan(int v) : value(v) {}
int value; int value;
...@@ -404,6 +416,8 @@ PYBIND11_MODULE(absl_example, m) { ...@@ -404,6 +416,8 @@ PYBIND11_MODULE(absl_example, m) {
m.def("sum_span_const_complex128", m.def("sum_span_const_complex128",
&SumSpanComplex<const std::complex<double>>, arg("input_span")); &SumSpanComplex<const std::complex<double>>, arg("input_span"));
m.def("pass_span_pyobject_ptr", &PassSpanPyObjectPtr, arg("span")); m.def("pass_span_pyobject_ptr", &PassSpanPyObjectPtr, arg("span"));
m.def("pass_span_bool", &PassSpanBool, arg("span"));
m.def("pass_span_const_bool", &PassSpanConstBool, arg("span"));
// Span of objects. // Span of objects.
class_<ObjectForSpan>(m, "ObjectForSpan") class_<ObjectForSpan>(m, "ObjectForSpan")
......
...@@ -312,7 +312,7 @@ def make_read_only_numpy_array(): ...@@ -312,7 +312,7 @@ def make_read_only_numpy_array():
return values return values
def make_srided_numpy_array(stride): def make_strided_numpy_array(stride):
return np.zeros(10, dtype=np.int32)[::stride] return np.zeros(10, dtype=np.int32)[::stride]
...@@ -373,10 +373,10 @@ class AbslNumericSpanTest(parameterized.TestCase): ...@@ -373,10 +373,10 @@ class AbslNumericSpanTest(parameterized.TestCase):
@parameterized.named_parameters( @parameterized.named_parameters(
('float_numpy', np.zeros(5, dtype=float)), ('float_numpy', np.zeros(5, dtype=float)),
('two_d_numpy', np.zeros( ('two_d_numpy', np.zeros((5, 5), dtype=np.int32)),
(5, 5), dtype=np.int32)), ('read_only', make_read_only_numpy_array()), ('read_only', make_read_only_numpy_array()),
('strided_skip', make_srided_numpy_array(2)), ('strided_skip', make_strided_numpy_array(2)),
('strided_reverse', make_srided_numpy_array(-1)), ('strided_reverse', make_strided_numpy_array(-1)),
('non_supported_type', np.zeros(5, dtype=np.unicode_)), ('non_supported_type', np.zeros(5, dtype=np.unicode_)),
('native_list', [0] * 5)) ('native_list', [0] * 5))
def test_fill_span_fails_from(self, values): def test_fill_span_fails_from(self, values):
...@@ -397,6 +397,28 @@ class AbslNumericSpanTest(parameterized.TestCase): ...@@ -397,6 +397,28 @@ class AbslNumericSpanTest(parameterized.TestCase):
arr = np.array([-3, 'four', 5.0], dtype=object) arr = np.array([-3, 'four', 5.0], dtype=object)
self.assertEqual(absl_example.pass_span_pyobject_ptr(arr), '-3four5.0') self.assertEqual(absl_example.pass_span_pyobject_ptr(arr), '-3four5.0')
@parameterized.parameters(
([], ''),
([False], 'f'),
([True], 't'),
([False, True, True, False], 'fttf'),
)
def test_pass_span_bool(self, bools, expected):
arr = np.array(bools, dtype=bool)
s = absl_example.pass_span_bool(arr)
self.assertEqual(s, expected)
@parameterized.parameters(
([], ''),
([False], 'F'),
([True], 'T'),
([False, True, True, False], 'FTTF'),
)
def test_pass_span_const_bool(self, bools, expected):
arr = np.array(bools, dtype=bool)
s = absl_example.pass_span_const_bool(arr)
self.assertEqual(s, expected)
def make_native_list_of_objects(): def make_native_list_of_objects():
return [absl_example.ObjectForSpan(3), absl_example.ObjectForSpan(5)] return [absl_example.ObjectForSpan(3), absl_example.ObjectForSpan(5)]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment