Commit 75fe331e by Victor Akabutu Committed by Copybara-Service

Enable StatusOr<T>* returns.

PiperOrigin-RevId: 357400716
parent 014b66ea
...@@ -146,6 +146,19 @@ either the `Status` object or a function returning a `Status` to ...@@ -146,6 +146,19 @@ either the `Status` object or a function returning a `Status` to
`pybind11::google::DoNotThrowStatus` before casting or binding. This works with `pybind11::google::DoNotThrowStatus` before casting or binding. This works with
references and pointers to `absl::Status` objects too. references and pointers to `absl::Status` objects too.
It isn't possible to specify separate return value policies for a `StatusOr`
object and its payload. Since `StatusOr` is processed and not ever actually
represented in Python, the return value policy applies to the payload. Eg, if
you return a StatusOr<MyObject*> (note the * is inside the `StatusOr`) with a
take_ownership return val policy and the status is OK (ie, it has a payload),
Python will take ownership of that payload and free it when it is garbage
collected.
However, if you return a StatusOr<MyObject>* (note the * is outside the
`StatusOr` rather than inside it now) with a take_ownership return val policy,
Python does not take ownership of the `StatusOr` and will not free it (because
again, that policy applies to `MyObject`, not `StatusOr`).
See `status_utils.cc` in this directory for details about what methods are See `status_utils.cc` in this directory for details about what methods are
available in wrapped `absl::Status` objects. available in wrapped `absl::Status` objects.
......
...@@ -121,7 +121,18 @@ struct type_caster<absl::Status> : public type_caster_base<absl::Status> { ...@@ -121,7 +121,18 @@ struct type_caster<absl::Status> : public type_caster_base<absl::Status> {
} }
}; };
// Convert an absl::StatusOr. // Convert absl::StatusOr<T>.
// It isn't possible to specify separate return value policies for the container
// (StatusOr) and the payload. Since StatusOr is processed and not ever actually
// represented in python, the return value policy applies to the payload. Eg, if
// you return a StatusOr<MyObject*> (note the * is inside the StatusOr) with a
// take_ownership return val policy and the status is ok (ie, it has a payload),
// python will take ownership of that payload and free it when it is garbage
// collected.
// However, if you return a StatusOr<MyObject>* (note the * is outside the
// StatusOr rather than inside it now) with a take_ownership return val policy,
// python does not take ownership of the StatusOr and will not free it (because
// again, that policy applies to MyObject, not StatusOr).
template <typename PayloadType> template <typename PayloadType>
struct type_caster<absl::StatusOr<PayloadType>> { struct type_caster<absl::StatusOr<PayloadType>> {
public: public:
...@@ -129,18 +140,38 @@ struct type_caster<absl::StatusOr<PayloadType>> { ...@@ -129,18 +140,38 @@ struct type_caster<absl::StatusOr<PayloadType>> {
using StatusCaster = make_caster<absl::Status>; using StatusCaster = make_caster<absl::Status>;
static constexpr auto name = _("StatusOr[") + PayloadCaster::name + _("]"); static constexpr auto name = _("StatusOr[") + PayloadCaster::name + _("]");
// Conversion part 2 (C++ -> Python). // Convert C++ -> Python.
static handle cast(const absl::StatusOr<PayloadType>* src,
return_value_policy policy, handle parent,
bool throw_exception = true) {
if (!src) return none().release();
return cast_impl(*src, policy, parent, throw_exception);
}
static handle cast(const absl::StatusOr<PayloadType>& src,
return_value_policy policy, handle parent,
bool throw_exception = true) {
return cast_impl(src, policy, parent, throw_exception);
}
static handle cast(absl::StatusOr<PayloadType>&& src, static handle cast(absl::StatusOr<PayloadType>&& src,
return_value_policy policy, handle parent, return_value_policy policy, handle parent,
bool throw_exception = true) { bool throw_exception = true) {
return cast_impl(std::move(src), policy, parent, throw_exception);
}
private:
template <typename CType>
static handle cast_impl(CType&& src, return_value_policy policy,
handle parent, bool throw_exception) {
google::CheckStatusModuleImported(); google::CheckStatusModuleImported();
if (src.ok()) { if (src.ok()) {
// Convert and return the payload. // Convert and return the payload.
return PayloadCaster::cast(std::forward<PayloadType>(*src), policy, return PayloadCaster::cast(std::forward<CType>(src).value(), policy,
parent); parent);
} else { } else {
// Convert and return the error. // Convert and return the error.
return StatusCaster::cast(std::move(src.status()), return StatusCaster::cast(std::forward<CType>(src).status(),
return_value_policy::move, parent, return_value_policy::move, parent,
throw_exception); throw_exception);
} }
......
...@@ -126,6 +126,15 @@ PYBIND11_MODULE(status_example, m) { ...@@ -126,6 +126,15 @@ PYBIND11_MODULE(status_example, m) {
"Return a reference in a status or to a static value.", "Return a reference in a status or to a static value.",
return_value_policy::reference); return_value_policy::reference);
m.def("return_unique_ptr_status_or", &ReturnUniquePtrStatusOr, arg("value")); m.def("return_unique_ptr_status_or", &ReturnUniquePtrStatusOr, arg("value"));
m.def("return_status_or_pointer", []() {
static absl::StatusOr<int>* ptr = new absl::StatusOr<int>(42);
return ptr;
});
m.def("return_failure_status_or_pointer", []() {
static absl::StatusOr<int>* ptr =
new absl::StatusOr<int>(absl::InvalidArgumentError("Uh oh!"));
return ptr;
});
} }
} // namespace test } // namespace test
......
...@@ -138,6 +138,17 @@ class StatusOrTest(absltest.TestCase): ...@@ -138,6 +138,17 @@ class StatusOrTest(absltest.TestCase):
status.StatusCode.CANCELLED) status.StatusCode.CANCELLED)
self.assertFalse(status.is_ok(failure_result)) self.assertFalse(status.is_ok(failure_result))
def test_return_status_or_pointer(self):
expected_result = 42
for _ in range(3):
result = status_example.return_status_or_pointer()
self.assertEqual(result, expected_result)
def test_return_failed_status_or_pointer(self):
for _ in range(3):
with self.assertRaises(status.StatusNotOk):
status_example.return_failure_status_or_pointer()
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment