Commit d9614e4e by pybind11_abseil authors Committed by Copybara-Service

Internal change

PiperOrigin-RevId: 373344235
parent 7cf76381
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <stdexcept> #include <stdexcept>
#include <type_traits>
#include <utility> #include <utility>
#include "absl/status/status.h" #include "absl/status/status.h"
...@@ -58,6 +59,20 @@ inline module ImportStatusModule() { ...@@ -58,6 +59,20 @@ inline module ImportStatusModule() {
namespace detail { namespace detail {
template <typename StatusType>
struct NoThrowStatusType {
// NoThrowStatus should only wrap absl::Status or absl::StatusOr.
using NoThrowAbslStatus = type_caster_base<absl::Status>;
static constexpr auto name = NoThrowAbslStatus::name;
};
template <typename PayloadType>
struct NoThrowStatusType<absl::StatusOr<PayloadType>> {
using NoThrowAbslStatus = type_caster_base<absl::Status>;
static constexpr auto name = _("Union[") + NoThrowAbslStatus::name + _(", ") +
make_caster<PayloadType>::name + _("]");
};
// Convert NoThrowStatus by dispatching to a caster for StatusType with the // Convert NoThrowStatus by dispatching to a caster for StatusType with the
// argument throw_exception = false. StatusType should be an absl::Status // argument throw_exception = false. StatusType should be an absl::Status
// (rvalue, lvalue, reference, or pointer), or an absl::StatusOr value. // (rvalue, lvalue, reference, or pointer), or an absl::StatusOr value.
...@@ -67,7 +82,7 @@ template <typename StatusType> ...@@ -67,7 +82,7 @@ template <typename StatusType>
struct type_caster<google::NoThrowStatus<StatusType>> { struct type_caster<google::NoThrowStatus<StatusType>> {
using InputType = google::NoThrowStatus<StatusType>; using InputType = google::NoThrowStatus<StatusType>;
using StatusCaster = make_caster<StatusType>; using StatusCaster = make_caster<StatusType>;
static constexpr auto name = StatusCaster::name; static constexpr auto name = NoThrowStatusType<StatusType>::name;
// Convert C++->Python. // Convert C++->Python.
static handle cast(const InputType& src, return_value_policy policy, static handle cast(const InputType& src, return_value_policy policy,
...@@ -85,7 +100,8 @@ struct type_caster<google::NoThrowStatus<StatusType>> { ...@@ -85,7 +100,8 @@ struct type_caster<google::NoThrowStatus<StatusType>> {
template <> template <>
struct type_caster<absl::Status> : public type_caster_base<absl::Status> { struct type_caster<absl::Status> : public type_caster_base<absl::Status> {
public: public:
// Convert C++ -> Python. static constexpr auto name = _("None");
// Convert C++ -> Python.
static handle cast(const absl::Status* src, return_value_policy policy, static handle cast(const absl::Status* src, return_value_policy policy,
handle parent, bool throw_exception = true) { handle parent, bool throw_exception = true) {
if (!src) return none().release(); if (!src) return none().release();
...@@ -138,8 +154,7 @@ struct type_caster<absl::StatusOr<PayloadType>> { ...@@ -138,8 +154,7 @@ struct type_caster<absl::StatusOr<PayloadType>> {
public: public:
using PayloadCaster = make_caster<PayloadType>; using PayloadCaster = make_caster<PayloadType>;
using StatusCaster = make_caster<absl::Status>; using StatusCaster = make_caster<absl::Status>;
static constexpr auto name = static constexpr auto name = PayloadCaster::name;
_("Union[") + StatusCaster::name + _(", ") + PayloadCaster::name + _("]");
// Convert C++ -> Python. // Convert C++ -> Python.
static handle cast(const absl::StatusOr<PayloadType>* src, static handle cast(const absl::StatusOr<PayloadType>* src,
......
...@@ -9,6 +9,11 @@ from pybind11_abseil import status ...@@ -9,6 +9,11 @@ from pybind11_abseil import status
from pybind11_abseil.tests import status_example from pybind11_abseil.tests import status_example
def docstring_signature(f):
"""Returns the first line from a docstring - the signature for a function."""
return f.__doc__.split('\n')[0]
class StatusTest(absltest.TestCase): class StatusTest(absltest.TestCase):
def test_pass_status(self): def test_pass_status(self):
...@@ -16,6 +21,10 @@ class StatusTest(absltest.TestCase): ...@@ -16,6 +21,10 @@ class StatusTest(absltest.TestCase):
self.assertTrue( self.assertTrue(
status_example.check_status(test_status, status.StatusCode.CANCELLED)) status_example.check_status(test_status, status.StatusCode.CANCELLED))
def test_return_status_return_type_from_doc(self):
self.assertEndsWith(
docstring_signature(status_example.return_status), ' -> None')
def test_return_ok(self): def test_return_ok(self):
# The return_status function should convert an ok status to None. # The return_status function should convert an ok status to None.
self.assertIsNone(status_example.return_status(status.StatusCode.OK)) self.assertIsNone(status_example.return_status(status.StatusCode.OK))
...@@ -39,6 +48,10 @@ class StatusTest(absltest.TestCase): ...@@ -39,6 +48,10 @@ class StatusTest(absltest.TestCase):
with self.assertRaises(Exception): with self.assertRaises(Exception):
status_example.return_status(status.StatusCode.CANCELLED, 'test') status_example.return_status(status.StatusCode.CANCELLED, 'test')
def test_make_status_return_type_from_doc(self):
self.assertRegex(
docstring_signature(status_example.make_status), r' -> .*\.Status')
def test_make_ok(self): def test_make_ok(self):
# The make_status function has been set up to return a status object # The make_status function has been set up to return a status object
# instead of raising an exception (this is done in status_example.cc). # instead of raising an exception (this is done in status_example.cc).
...@@ -107,6 +120,11 @@ class StatusTest(absltest.TestCase): ...@@ -107,6 +120,11 @@ class StatusTest(absltest.TestCase):
class StatusOrTest(absltest.TestCase): class StatusOrTest(absltest.TestCase):
def test_return_value_status_or_return_type_from_doc(self):
self.assertEndsWith(
docstring_signature(status_example.return_value_status_or),
' -> int')
def test_return_value(self): def test_return_value(self):
self.assertEqual(status_example.return_value_status_or(5), 5) self.assertEqual(status_example.return_value_status_or(5), 5)
...@@ -115,6 +133,11 @@ class StatusOrTest(absltest.TestCase): ...@@ -115,6 +133,11 @@ class StatusOrTest(absltest.TestCase):
status_example.return_failure_status_or(status.StatusCode.NOT_FOUND) status_example.return_failure_status_or(status.StatusCode.NOT_FOUND)
self.assertEqual(cm.exception.status.code(), status.StatusCode.NOT_FOUND) self.assertEqual(cm.exception.status.code(), status.StatusCode.NOT_FOUND)
def test_make_failure_status_or_return_type_from_doc(self):
self.assertRegex(
docstring_signature(status_example.make_failure_status_or),
r' -> Union\[.*\.Status, int\]')
def test_make_not_ok(self): def test_make_not_ok(self):
self.assertEqual( self.assertEqual(
status_example.make_failure_status_or( status_example.make_failure_status_or(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment