Commit d9614e4e by pybind11_abseil authors Committed by Copybara-Service

Internal change

PiperOrigin-RevId: 373344235
parent 7cf76381
......@@ -21,6 +21,7 @@
#include <pybind11/pybind11.h>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include "absl/status/status.h"
......@@ -58,6 +59,20 @@ inline module ImportStatusModule() {
namespace detail {
template <typename StatusType>
struct NoThrowStatusType {
// NoThrowStatus should only wrap absl::Status or absl::StatusOr.
using NoThrowAbslStatus = type_caster_base<absl::Status>;
static constexpr auto name = NoThrowAbslStatus::name;
};
template <typename PayloadType>
struct NoThrowStatusType<absl::StatusOr<PayloadType>> {
using NoThrowAbslStatus = type_caster_base<absl::Status>;
static constexpr auto name = _("Union[") + NoThrowAbslStatus::name + _(", ") +
make_caster<PayloadType>::name + _("]");
};
// Convert NoThrowStatus by dispatching to a caster for StatusType with the
// argument throw_exception = false. StatusType should be an absl::Status
// (rvalue, lvalue, reference, or pointer), or an absl::StatusOr value.
......@@ -67,7 +82,7 @@ template <typename StatusType>
struct type_caster<google::NoThrowStatus<StatusType>> {
using InputType = google::NoThrowStatus<StatusType>;
using StatusCaster = make_caster<StatusType>;
static constexpr auto name = StatusCaster::name;
static constexpr auto name = NoThrowStatusType<StatusType>::name;
// Convert C++->Python.
static handle cast(const InputType& src, return_value_policy policy,
......@@ -85,7 +100,8 @@ struct type_caster<google::NoThrowStatus<StatusType>> {
template <>
struct type_caster<absl::Status> : public type_caster_base<absl::Status> {
public:
// Convert C++ -> Python.
static constexpr auto name = _("None");
// Convert C++ -> Python.
static handle cast(const absl::Status* src, return_value_policy policy,
handle parent, bool throw_exception = true) {
if (!src) return none().release();
......@@ -138,8 +154,7 @@ struct type_caster<absl::StatusOr<PayloadType>> {
public:
using PayloadCaster = make_caster<PayloadType>;
using StatusCaster = make_caster<absl::Status>;
static constexpr auto name =
_("Union[") + StatusCaster::name + _(", ") + PayloadCaster::name + _("]");
static constexpr auto name = PayloadCaster::name;
// Convert C++ -> Python.
static handle cast(const absl::StatusOr<PayloadType>* src,
......
......@@ -9,6 +9,11 @@ from pybind11_abseil import status
from pybind11_abseil.tests import status_example
def docstring_signature(f):
"""Returns the first line from a docstring - the signature for a function."""
return f.__doc__.split('\n')[0]
class StatusTest(absltest.TestCase):
def test_pass_status(self):
......@@ -16,6 +21,10 @@ class StatusTest(absltest.TestCase):
self.assertTrue(
status_example.check_status(test_status, status.StatusCode.CANCELLED))
def test_return_status_return_type_from_doc(self):
self.assertEndsWith(
docstring_signature(status_example.return_status), ' -> None')
def test_return_ok(self):
# The return_status function should convert an ok status to None.
self.assertIsNone(status_example.return_status(status.StatusCode.OK))
......@@ -39,6 +48,10 @@ class StatusTest(absltest.TestCase):
with self.assertRaises(Exception):
status_example.return_status(status.StatusCode.CANCELLED, 'test')
def test_make_status_return_type_from_doc(self):
self.assertRegex(
docstring_signature(status_example.make_status), r' -> .*\.Status')
def test_make_ok(self):
# The make_status function has been set up to return a status object
# instead of raising an exception (this is done in status_example.cc).
......@@ -107,6 +120,11 @@ class StatusTest(absltest.TestCase):
class StatusOrTest(absltest.TestCase):
def test_return_value_status_or_return_type_from_doc(self):
self.assertEndsWith(
docstring_signature(status_example.return_value_status_or),
' -> int')
def test_return_value(self):
self.assertEqual(status_example.return_value_status_or(5), 5)
......@@ -115,6 +133,11 @@ class StatusOrTest(absltest.TestCase):
status_example.return_failure_status_or(status.StatusCode.NOT_FOUND)
self.assertEqual(cm.exception.status.code(), status.StatusCode.NOT_FOUND)
def test_make_failure_status_or_return_type_from_doc(self):
self.assertRegex(
docstring_signature(status_example.make_failure_status_or),
r' -> Union\[.*\.Status, int\]')
def test_make_not_ok(self):
self.assertEqual(
status_example.make_failure_status_or(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment