Commit a8fed755 by Ralf W. Grosse-Kunstleve Committed by Copybara-Service

Add `absl::Span` `PyObject*` support in absl_casters.h.

PiperOrigin-RevId: 532586991
parent 57401927
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include <pybind11/cast.h> #include <pybind11/cast.h>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <pybind11/type_caster_pyobject_ptr.h>
// Must NOT appear before at least one pybind11 include. // Must NOT appear before at least one pybind11 include.
#include <datetime.h> // Python datetime builtin. #include <datetime.h> // Python datetime builtin.
...@@ -392,11 +393,24 @@ namespace internal { ...@@ -392,11 +393,24 @@ namespace internal {
template <typename T> template <typename T>
static constexpr bool is_buffer_interface_compatible_type = static constexpr bool is_buffer_interface_compatible_type =
detail::is_same_ignoring_cvref<T, PyObject*>::value ||
std::is_arithmetic<T>::value || std::is_arithmetic<T>::value ||
std::is_same<T, std::complex<float>>::value || std::is_same<T, std::complex<float>>::value ||
std::is_same<T, std::complex<double>>::value; std::is_same<T, std::complex<double>>::value;
template <typename T, typename SFINAE = void> template <typename T, typename SFINAE = void>
struct format_descriptor_char1 : format_descriptor<T> {};
template <typename T>
struct format_descriptor_char1<
T,
detail::enable_if_t<detail::is_same_ignoring_cvref<T, PyObject*>::value>> {
static constexpr const char c = 'O';
static constexpr const char value[2] = {c, '\0'};
static std::string format() { return std::string(1, c); }
};
template <typename T, typename SFINAE = void>
struct format_descriptor_char2 { struct format_descriptor_char2 {
static constexpr const char c = '\0'; static constexpr const char c = '\0';
}; };
...@@ -406,7 +420,7 @@ struct format_descriptor_char2<std::complex<T>> : format_descriptor<T> {}; ...@@ -406,7 +420,7 @@ struct format_descriptor_char2<std::complex<T>> : format_descriptor<T> {};
template <typename T> template <typename T>
inline bool buffer_view_matches_format_descriptor(const char* view_format) { inline bool buffer_view_matches_format_descriptor(const char* view_format) {
return view_format[0] == format_descriptor<T>::c || return view_format[0] == format_descriptor_char1<T>::c ||
(view_format[0] == 'Z' && (view_format[0] == 'Z' &&
view_format[1] == format_descriptor_char2<T>::c); view_format[1] == format_descriptor_char2<T>::c);
} }
......
...@@ -273,6 +273,12 @@ NonConstCmplxType SumSpanComplex(absl::Span<CmplxType> input_span) { ...@@ -273,6 +273,12 @@ NonConstCmplxType SumSpanComplex(absl::Span<CmplxType> input_span) {
return sum; return sum;
} }
std::string PassSpanPyObjectPtr(absl::Span<PyObject*> input_span) {
std::string result;
for (auto& i : input_span) result += str(i);
return result;
}
struct ObjectForSpan { struct ObjectForSpan {
explicit ObjectForSpan(int v) : value(v) {} explicit ObjectForSpan(int v) : value(v) {}
int value; int value;
...@@ -397,6 +403,7 @@ PYBIND11_MODULE(absl_example, m) { ...@@ -397,6 +403,7 @@ PYBIND11_MODULE(absl_example, m) {
m.def("sum_span_complex128", &SumSpanComplex<std::complex<double>>); m.def("sum_span_complex128", &SumSpanComplex<std::complex<double>>);
m.def("sum_span_const_complex128", m.def("sum_span_const_complex128",
&SumSpanComplex<const std::complex<double>>, arg("input_span")); &SumSpanComplex<const std::complex<double>>, arg("input_span"));
m.def("pass_span_pyobject_ptr", &PassSpanPyObjectPtr, arg("span"));
// Span of objects. // Span of objects.
class_<ObjectForSpan>(m, "ObjectForSpan") class_<ObjectForSpan>(m, "ObjectForSpan")
......
...@@ -378,6 +378,10 @@ class AbslNumericSpanTest(parameterized.TestCase): ...@@ -378,6 +378,10 @@ class AbslNumericSpanTest(parameterized.TestCase):
xs = np.array([x * 1j for x in range(10)]).astype(numpy_type) xs = np.array([x * 1j for x in range(10)]).astype(numpy_type)
self.assertEqual(sum_span_fn(xs), 45j) self.assertEqual(sum_span_fn(xs), 45j)
def test_pass_span_pyobject_ptr(self):
arr = np.array([-3, 'four', 5.0], dtype=object)
self.assertEqual(absl_example.pass_span_pyobject_ptr(arr), '-3four5.0')
def make_native_list_of_objects(): def make_native_list_of_objects():
return [absl_example.ObjectForSpan(3), absl_example.ObjectForSpan(5)] return [absl_example.ObjectForSpan(3), absl_example.ObjectForSpan(5)]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment