Commit a8fed755 by Ralf W. Grosse-Kunstleve Committed by Copybara-Service

Add `absl::Span` `PyObject*` support in absl_casters.h.

PiperOrigin-RevId: 532586991
parent 57401927
......@@ -33,6 +33,7 @@
#include <pybind11/cast.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/type_caster_pyobject_ptr.h>
// Must NOT appear before at least one pybind11 include.
#include <datetime.h> // Python datetime builtin.
......@@ -392,11 +393,24 @@ namespace internal {
template <typename T>
static constexpr bool is_buffer_interface_compatible_type =
detail::is_same_ignoring_cvref<T, PyObject*>::value ||
std::is_arithmetic<T>::value ||
std::is_same<T, std::complex<float>>::value ||
std::is_same<T, std::complex<double>>::value;
template <typename T, typename SFINAE = void>
struct format_descriptor_char1 : format_descriptor<T> {};
template <typename T>
struct format_descriptor_char1<
T,
detail::enable_if_t<detail::is_same_ignoring_cvref<T, PyObject*>::value>> {
static constexpr const char c = 'O';
static constexpr const char value[2] = {c, '\0'};
static std::string format() { return std::string(1, c); }
};
template <typename T, typename SFINAE = void>
struct format_descriptor_char2 {
static constexpr const char c = '\0';
};
......@@ -406,7 +420,7 @@ struct format_descriptor_char2<std::complex<T>> : format_descriptor<T> {};
template <typename T>
inline bool buffer_view_matches_format_descriptor(const char* view_format) {
return view_format[0] == format_descriptor<T>::c ||
return view_format[0] == format_descriptor_char1<T>::c ||
(view_format[0] == 'Z' &&
view_format[1] == format_descriptor_char2<T>::c);
}
......
......@@ -273,6 +273,12 @@ NonConstCmplxType SumSpanComplex(absl::Span<CmplxType> input_span) {
return sum;
}
std::string PassSpanPyObjectPtr(absl::Span<PyObject*> input_span) {
std::string result;
for (auto& i : input_span) result += str(i);
return result;
}
struct ObjectForSpan {
explicit ObjectForSpan(int v) : value(v) {}
int value;
......@@ -397,6 +403,7 @@ PYBIND11_MODULE(absl_example, m) {
m.def("sum_span_complex128", &SumSpanComplex<std::complex<double>>);
m.def("sum_span_const_complex128",
&SumSpanComplex<const std::complex<double>>, arg("input_span"));
m.def("pass_span_pyobject_ptr", &PassSpanPyObjectPtr, arg("span"));
// Span of objects.
class_<ObjectForSpan>(m, "ObjectForSpan")
......
......@@ -378,6 +378,10 @@ class AbslNumericSpanTest(parameterized.TestCase):
xs = np.array([x * 1j for x in range(10)]).astype(numpy_type)
self.assertEqual(sum_span_fn(xs), 45j)
def test_pass_span_pyobject_ptr(self):
arr = np.array([-3, 'four', 5.0], dtype=object)
self.assertEqual(absl_example.pass_span_pyobject_ptr(arr), '-3four5.0')
def make_native_list_of_objects():
return [absl_example.ObjectForSpan(3), absl_example.ObjectForSpan(5)]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment