Commit 57401927 by Ralf W. Grosse-Kunstleve Committed by Copybara-Service

Add `absl::Span` `std::complex` support in absl_casters.h.

PiperOrigin-RevId: 532576894
parent 64a813b5
...@@ -38,6 +38,7 @@ ...@@ -38,6 +38,7 @@
#include <datetime.h> // Python datetime builtin. #include <datetime.h> // Python datetime builtin.
#include <cmath> #include <cmath>
#include <complex>
#include <cstdint> #include <cstdint>
#include <tuple> #include <tuple>
#include <type_traits> #include <type_traits>
...@@ -385,10 +386,38 @@ template <> ...@@ -385,10 +386,38 @@ template <>
struct type_caster<absl::CivilYear> struct type_caster<absl::CivilYear>
: public absl_civil_date_caster<absl::CivilYear> {}; : public absl_civil_date_caster<absl::CivilYear> {};
// Using internal namespace to avoid name collisons in case this code is
// accepted upsteam (pybind11).
namespace internal {
template <typename T>
static constexpr bool is_buffer_interface_compatible_type =
std::is_arithmetic<T>::value ||
std::is_same<T, std::complex<float>>::value ||
std::is_same<T, std::complex<double>>::value;
template <typename T, typename SFINAE = void>
struct format_descriptor_char2 {
static constexpr const char c = '\0';
};
template <typename T>
struct format_descriptor_char2<std::complex<T>> : format_descriptor<T> {};
template <typename T>
inline bool buffer_view_matches_format_descriptor(const char* view_format) {
return view_format[0] == format_descriptor<T>::c ||
(view_format[0] == 'Z' &&
view_format[1] == format_descriptor_char2<T>::c);
}
} // namespace internal
// Returns {true, a span referencing the data contained by src} without copying // Returns {true, a span referencing the data contained by src} without copying
// or converting the data if possible. Otherwise returns {false, an empty span}. // or converting the data if possible. Otherwise returns {false, an empty span}.
template <typename T, typename std::enable_if<std::is_arithmetic<T>::value, template <typename T, typename std::enable_if<
bool>::type = true> internal::is_buffer_interface_compatible_type<T>,
bool>::type = true>
std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) { std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) {
Py_buffer view; Py_buffer view;
int flags = PyBUF_STRIDES | PyBUF_FORMAT; int flags = PyBUF_STRIDES | PyBUF_FORMAT;
...@@ -396,7 +425,7 @@ std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) { ...@@ -396,7 +425,7 @@ std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) {
if (PyObject_GetBuffer(src.ptr(), &view, flags) == 0) { if (PyObject_GetBuffer(src.ptr(), &view, flags) == 0) {
auto cleanup = absl::MakeCleanup([&view] { PyBuffer_Release(&view); }); auto cleanup = absl::MakeCleanup([&view] { PyBuffer_Release(&view); });
if (view.ndim == 1 && view.strides[0] == sizeof(T) && if (view.ndim == 1 && view.strides[0] == sizeof(T) &&
view.format[0] == format_descriptor<T>::c) { internal::buffer_view_matches_format_descriptor<T>(view.format)) {
return {true, absl::MakeSpan(static_cast<T*>(view.buf), view.shape[0])}; return {true, absl::MakeSpan(static_cast<T*>(view.buf), view.shape[0])};
} }
} else { } else {
...@@ -405,9 +434,9 @@ std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) { ...@@ -405,9 +434,9 @@ std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) {
} }
return {false, absl::Span<T>()}; return {false, absl::Span<T>()};
} }
// If T is not a numeric type, the buffer interface cannot be used. template <typename T, typename std::enable_if<
template <typename T, typename std::enable_if<!std::is_arithmetic<T>::value, !internal::is_buffer_interface_compatible_type<T>,
bool>::type = true> bool>::type = true>
constexpr std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) { constexpr std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) {
return {false, absl::Span<T>()}; return {false, absl::Span<T>()};
} }
......
...@@ -3,10 +3,12 @@ ...@@ -3,10 +3,12 @@
// All rights reserved. Use of this source code is governed by a // All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file. // BSD-style license that can be found in the LICENSE file.
#include <pybind11/complex.h>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <pybind11/stl_bind.h> #include <pybind11/stl_bind.h>
#include <complex>
#include <cstddef> #include <cstddef>
#include <vector> #include <vector>
...@@ -263,6 +265,14 @@ void FillSpan(int value, absl::Span<int> output_span) { ...@@ -263,6 +265,14 @@ void FillSpan(int value, absl::Span<int> output_span) {
for (auto& i : output_span) i = value; for (auto& i : output_span) i = value;
} }
template <typename CmplxType, typename NonConstCmplxType =
typename std::remove_const<CmplxType>::type>
NonConstCmplxType SumSpanComplex(absl::Span<CmplxType> input_span) {
NonConstCmplxType sum = 0;
for (auto& i : input_span) sum += i;
return sum;
}
struct ObjectForSpan { struct ObjectForSpan {
explicit ObjectForSpan(int v) : value(v) {} explicit ObjectForSpan(int v) : value(v) {}
int value; int value;
...@@ -382,6 +392,11 @@ PYBIND11_MODULE(absl_example, m) { ...@@ -382,6 +392,11 @@ PYBIND11_MODULE(absl_example, m) {
// Non-const spans can never be converted, so `output_span` could be marked as // Non-const spans can never be converted, so `output_span` could be marked as
// `noconvert`, but that would be redundant (so test that it is not needed). // `noconvert`, but that would be redundant (so test that it is not needed).
m.def("fill_span", &FillSpan, arg("value"), arg("output_span")); m.def("fill_span", &FillSpan, arg("value"), arg("output_span"));
m.def("sum_span_complex64", &SumSpanComplex<std::complex<float>>);
m.def("sum_span_const_complex64", &SumSpanComplex<const std::complex<float>>);
m.def("sum_span_complex128", &SumSpanComplex<std::complex<double>>);
m.def("sum_span_const_complex128",
&SumSpanComplex<const std::complex<double>>, arg("input_span"));
// Span of objects. // Span of objects.
class_<ObjectForSpan>(m, "ObjectForSpan") class_<ObjectForSpan>(m, "ObjectForSpan")
......
...@@ -368,6 +368,16 @@ class AbslNumericSpanTest(parameterized.TestCase): ...@@ -368,6 +368,16 @@ class AbslNumericSpanTest(parameterized.TestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
absl_example.fill_span(42, values) absl_example.fill_span(42, values)
@parameterized.parameters(
('complex64', absl_example.sum_span_complex64),
('complex64', absl_example.sum_span_const_complex64),
('complex128', absl_example.sum_span_complex128),
('complex128', absl_example.sum_span_const_complex128),
)
def test_complex(self, numpy_type, sum_span_fn):
xs = np.array([x * 1j for x in range(10)]).astype(numpy_type)
self.assertEqual(sum_span_fn(xs), 45j)
def make_native_list_of_objects(): def make_native_list_of_objects():
return [absl_example.ObjectForSpan(3), absl_example.ObjectForSpan(5)] return [absl_example.ObjectForSpan(3), absl_example.ObjectForSpan(5)]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment