Commit c4baef68 by Ralf W. Grosse-Kunstleve Committed by Copybara-Service

Add `.as_absl_Status()` raw pointer capsule feature to `type_caster<absl::Status>::load()`

For interoperability with other Python binding systems (e.g. SWIG).

PiperOrigin-RevId: 487000705
parent 0bb93688
...@@ -50,6 +50,7 @@ pybind_library( ...@@ -50,6 +50,7 @@ pybind_library(
deps = [ deps = [
":check_status_module_imported", ":check_status_module_imported",
":no_throw_status", ":no_throw_status",
":raw_ptr_from_capsule",
":status_not_ok_exception", ":status_not_ok_exception",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "absl/status/status.h" #include "absl/status/status.h"
#include "pybind11_abseil/check_status_module_imported.h" #include "pybind11_abseil/check_status_module_imported.h"
#include "pybind11_abseil/no_throw_status.h" #include "pybind11_abseil/no_throw_status.h"
#include "pybind11_abseil/raw_ptr_from_capsule.h"
#include "pybind11_abseil/status_not_ok_exception.h" #include "pybind11_abseil/status_not_ok_exception.h"
namespace pybind11 { namespace pybind11 {
...@@ -68,6 +69,22 @@ struct type_caster<absl::Status> : public type_caster_base<absl::Status> { ...@@ -68,6 +69,22 @@ struct type_caster<absl::Status> : public type_caster_base<absl::Status> {
return cast_impl(std::move(src), policy, parent, throw_exception); return cast_impl(std::move(src), policy, parent, throw_exception);
} }
bool load(handle src, bool convert) {
if (type_caster_base<absl::Status>::load(src, convert)) {
return true;
}
if (convert) {
absl::StatusOr<void*> raw_ptr =
pybind11_abseil::raw_ptr_from_capsule::RawPtrFromCapsule<void>(
src.ptr(), "::absl::Status", "as_absl_Status");
if (raw_ptr.ok()) {
value = raw_ptr.value();
return true;
}
}
return false;
}
private: private:
template <typename CType> template <typename CType>
static handle cast_impl(CType&& src, return_value_policy policy, static handle cast_impl(CType&& src, return_value_policy policy,
......
...@@ -105,6 +105,20 @@ PYBIND11_MODULE(status_example, m) { ...@@ -105,6 +105,20 @@ PYBIND11_MODULE(status_example, m) {
auto status_module = pybind11::google::ImportStatusModule(); auto status_module = pybind11::google::ImportStatusModule();
m.attr("StatusNotOk") = status_module.attr("StatusNotOk"); m.attr("StatusNotOk") = status_module.attr("StatusNotOk");
m.def("make_absl_status_capsule", [](bool return_ok_status) {
static absl::Status ok_status;
static absl::Status not_ok_status(absl::StatusCode::kAlreadyExists,
"Made by make_absl_status_capsule.");
if (return_ok_status) {
return capsule(static_cast<void*>(&ok_status), "::absl::Status");
}
return capsule(static_cast<void*>(&not_ok_status), "::absl::Status");
});
m.def("extract_code_message", [](const absl::Status& status) {
return pybind11::make_tuple(status.code(), std::string(status.message()));
});
m.def("make_bad_capsule", [](bool pass_name) { m.def("make_bad_capsule", [](bool pass_name) {
// https://docs.python.org/3/c-api/capsule.html: // https://docs.python.org/3/c-api/capsule.html:
// The pointer argument may not be NULL. // The pointer argument may not be NULL.
......
...@@ -13,6 +13,15 @@ def docstring_signature(f): ...@@ -13,6 +13,15 @@ def docstring_signature(f):
return f.__doc__.split('\n')[0] return f.__doc__.split('\n')[0]
class AbslStatusCapsule:
def __init__(self, return_ok_status):
self.return_ok_status = return_ok_status
def as_absl_Status(self): # pylint: disable=invalid-name
return status_example.make_absl_status_capsule(self.return_ok_status)
class BadCapsule: class BadCapsule:
def __init__(self, pass_name): def __init__(self, pass_name):
...@@ -378,6 +387,34 @@ class StatusTest(parameterized.TestCase): ...@@ -378,6 +387,34 @@ class StatusTest(parameterized.TestCase):
f'NotACapsule.as_absl_Status() returned an object ({nm})' f'NotACapsule.as_absl_Status() returned an object ({nm})'
f' that is not a capsule.') f' that is not a capsule.')
@parameterized.parameters(False, True)
def test_status_caster_load_as_absl_status_success(self, return_ok_status):
code, msg = status_example.extract_code_message(
AbslStatusCapsule(return_ok_status))
if return_ok_status:
self.assertEqual(code, status.StatusCode.OK)
self.assertEqual(msg, '')
else:
self.assertEqual(code, status.StatusCode.ALREADY_EXISTS)
self.assertEqual(msg, 'Made by make_absl_status_capsule.')
@parameterized.parameters(False, True)
def test_status_caster_load_as_absl_status_bad_capsule(self, pass_name):
cap = BadCapsule(pass_name)
with self.assertRaises(TypeError):
status_example.extract_code_message(cap)
@parameterized.parameters(None, '', 0)
def test_status_caster_load_as_absl_status_not_a_capsule(self, not_a_capsule):
cap = NotACapsule(not_a_capsule)
with self.assertRaises(TypeError):
status_example.extract_code_message(cap)
@parameterized.parameters(None, '', 0)
def test_status_caster_load_no_as_absl_status(self, something_random):
with self.assertRaises(TypeError):
status_example.extract_code_message(something_random)
class IntGetter(status_example.IntGetter): class IntGetter(status_example.IntGetter):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment