Commit c18f6323 by Ralf W. Grosse-Kunstleve Committed by Copybara-Service

Add `StatusNotOk` `__eq__` and pickle support.

PiperOrigin-RevId: 477876254
parent 13d4f99d
# Pybind11 bindings for the Abseil C++ Common Libraries # Pybind11 bindings for the Abseil C++ Common Libraries
# LOAD(pytype_pybind_extension)
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library") load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library")
package(default_visibility = ["//visibility:public"]) package(default_visibility = ["//visibility:public"])
...@@ -124,11 +125,6 @@ pybind_library( ...@@ -124,11 +125,6 @@ pybind_library(
name = "import_status_module", name = "import_status_module",
srcs = ["import_status_module.cc"], srcs = ["import_status_module.cc"],
hdrs = ["import_status_module.h"], hdrs = ["import_status_module.h"],
deps = [
":check_status_module_imported",
":register_status_bindings",
"@com_google_absl//absl/status",
],
) )
pybind_library( pybind_library(
......
...@@ -2,10 +2,6 @@ ...@@ -2,10 +2,6 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include "absl/status/status.h"
#include "pybind11_abseil/check_status_module_imported.h"
#include "pybind11_abseil/register_status_bindings.h"
namespace pybind11 { namespace pybind11 {
namespace google { namespace google {
...@@ -14,13 +10,10 @@ module_ ImportStatusModule(bool bypass_regular_import) { ...@@ -14,13 +10,10 @@ module_ ImportStatusModule(bool bypass_regular_import) {
pybind11_fail("ImportStatusModule() PyGILState_Check() failure."); pybind11_fail("ImportStatusModule() PyGILState_Check() failure.");
} }
if (bypass_regular_import) { if (bypass_regular_import) {
auto m = reinterpret_borrow<module_>(PyImport_AddModule( throw std::runtime_error(
PYBIND11_TOSTRING(PYBIND11_ABSEIL_STATUS_MODULE_PATH))); "ImportStatusModule(bypass_regular_import=true) is no longer supported."
if (!internal::IsStatusModuleImported()) { " Please change the calling code to"
internal::RegisterStatusBindings(m); " call this function without arguments.");
}
// else no-op because bindings are already loaded.
return m;
} }
return module_::import(PYBIND11_TOSTRING(PYBIND11_ABSEIL_STATUS_MODULE_PATH)); return module_::import(PYBIND11_TOSTRING(PYBIND11_ABSEIL_STATUS_MODULE_PATH));
} }
......
...@@ -18,7 +18,9 @@ namespace google { ...@@ -18,7 +18,9 @@ namespace google {
// Imports the bindings for the status types. This is meant to only be called // Imports the bindings for the status types. This is meant to only be called
// from a PYBIND11_MODULE definition. The Python GIL must be held when calling // from a PYBIND11_MODULE definition. The Python GIL must be held when calling
// this function (enforced). // this function (enforced).
module_ ImportStatusModule(bool bypass_regular_import = true); // TODO(b/225205409): Remove bypass_regular_import.
// bypass_regular_import is deprecated and can only be false (enforced).
module_ ImportStatusModule(bool bypass_regular_import = false);
} // namespace google } // namespace google
} // namespace pybind11 } // namespace pybind11
......
...@@ -161,7 +161,8 @@ void RegisterStatusBindings(module m) { ...@@ -161,7 +161,8 @@ void RegisterStatusBindings(module m) {
[](const absl::StatusCode& code) { return static_cast<int>(code); }, [](const absl::StatusCode& code) { return static_cast<int>(code); },
arg("code")); arg("code"));
class_<absl::Status>(m, "Status") class_<absl::Status> py_class_status(m, "Status");
py_class_status.def(init())
.def(init()) .def(init())
.def(init([](InitFromTag init_from_tag, const object& obj) { .def(init([](InitFromTag init_from_tag, const object& obj) {
switch (init_from_tag) { switch (init_from_tag) {
...@@ -230,14 +231,6 @@ void RegisterStatusBindings(module m) { ...@@ -230,14 +231,6 @@ void RegisterStatusBindings(module m) {
[](const absl::Status& s) { [](const absl::Status& s) {
return decode_utf8_replace(s.ToString()); return decode_utf8_replace(s.ToString());
}) })
.def("__str__",
[](const absl::Status& s) {
return decode_utf8_replace(s.ToString());
})
.def("to_string_status_not_ok",
[](const absl::Status& s) {
return decode_utf8_replace(s.ToString());
})
.def_static("OkStatus", DoNotThrowStatus(&absl::OkStatus)) .def_static("OkStatus", DoNotThrowStatus(&absl::OkStatus))
.def("raw_code", &absl::Status::raw_code) .def("raw_code", &absl::Status::raw_code)
.def("CanonicalCode", .def("CanonicalCode",
...@@ -304,6 +297,11 @@ void RegisterStatusBindings(module m) { ...@@ -304,6 +297,11 @@ void RegisterStatusBindings(module m) {
PyCapsule_New(static_cast<void*>(self), "::absl::Status", nullptr)); PyCapsule_New(static_cast<void*>(self), "::absl::Status", nullptr));
}); });
py_class_status.def("__str__",
[](const absl::Status& s) {
return decode_utf8_replace(s.ToString());
});
m.def("is_ok", &IsOk, arg("status_or"), m.def("is_ok", &IsOk, arg("status_or"),
"Returns false only if passed a non-ok status; otherwise returns true. " "Returns false only if passed a non-ok status; otherwise returns true. "
"This can be used on the return value of a function which returns a " "This can be used on the return value of a function which returns a "
...@@ -330,13 +328,22 @@ void RegisterStatusBindings(module m) { ...@@ -330,13 +328,22 @@ void RegisterStatusBindings(module m) {
def_status_factory(m, "unimplemented_error", WrapUnimplementedError); def_status_factory(m, "unimplemented_error", WrapUnimplementedError);
def_status_factory(m, "unknown_error", WrapUnknownError); def_status_factory(m, "unknown_error", WrapUnknownError);
// Having Python code embedded here is a compromise solution:
// * Ideally the code would live in a .py file, where it is accessible to
// tooling such as linters or source code indexing systems.
// * But that makes the dependency management in build systems much more
// complex.
// * The embedded code fragment is small and expected to always stay small,
// having it here is most practical given current technologies (the lesser
// of two evils).
pybind11::exec(R"( pybind11::exec(R"(
class StatusNotOk(Exception): class StatusNotOk(Exception):
def __init__(self, status): def __init__(self, status):
assert status is not None assert status is not None
assert not status.ok() assert not status.ok()
self._status = status self._status = status
Exception.__init__(self, status.to_string_status_not_ok()) Exception.__init__(self, str(self))
@property @property
def status(self): def status(self):
...@@ -351,8 +358,25 @@ void RegisterStatusBindings(module m) { ...@@ -351,8 +358,25 @@ void RegisterStatusBindings(module m) {
@property @property
def message(self): def message(self):
return self._status.message() return self._status.message()
def __str__(self):
return self._status.to_string()
def __eq__(self, other):
if not isinstance(other, StatusNotOk):
return NotImplemented
lhs = Status(InitFromTag.capsule, self._status)
rhs = Status(InitFromTag.capsule, other._status)
return lhs == rhs
# NOTE: The absl::SourceLocation is lost.
# It is impossible to serialize-deserialize.
def __reduce_ex__(self, protocol):
del protocol
return (type(self), (self._status,))
)", )",
m.attr("__dict__"), m.attr("__dict__")); m.attr("__dict__"), m.attr("__dict__"));
static pybind11::object PyStatusNotOk = m.attr("StatusNotOk"); static pybind11::object PyStatusNotOk = m.attr("StatusNotOk");
// Register a custom handler which converts a C++ StatusNotOk to a // Register a custom handler which converts a C++ StatusNotOk to a
......
# Tests and examples for pybind11_abseil. # Tests and examples for pybind11_abseil.
# load("//devtools/python/blaze:strict.bzl", "py_strict_test") # LOAD(pytype_strict_contrib_test)
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
licenses(["notice"]) licenses(["notice"])
......
...@@ -45,6 +45,10 @@ class StatusCodeTest(absltest.TestCase): ...@@ -45,6 +45,10 @@ class StatusCodeTest(absltest.TestCase):
def test_status_code_as_int(self): def test_status_code_as_int(self):
self.assertEqual(status.StatusCodeAsInt(status.StatusCode.UNAVAILABLE), 14) self.assertEqual(status.StatusCodeAsInt(status.StatusCode.UNAVAILABLE), 14)
def test_repr(self):
self.assertEqual(
repr(status.StatusCode.NOT_FOUND), '<StatusCode.NOT_FOUND: 5>')
class StatusTest(parameterized.TestCase): class StatusTest(parameterized.TestCase):
...@@ -70,17 +74,6 @@ class StatusTest(parameterized.TestCase): ...@@ -70,17 +74,6 @@ class StatusTest(parameterized.TestCase):
self.assertEqual(cm.exception.code, int(status.StatusCode.CANCELLED)) self.assertEqual(cm.exception.code, int(status.StatusCode.CANCELLED))
self.assertEqual(cm.exception.message, 'test') self.assertEqual(cm.exception.message, 'test')
def test_build_status_not_ok_enum(self):
e = status.BuildStatusNotOk(status.StatusCode.INVALID_ARGUMENT, 'Msg enum.')
self.assertEqual(e.status.code(), status.StatusCode.INVALID_ARGUMENT)
self.assertEqual(e.code, int(status.StatusCode.INVALID_ARGUMENT))
self.assertEqual(e.message, 'Msg enum.')
def test_build_status_not_ok_int(self):
with self.assertRaises(TypeError) as cm:
status.BuildStatusNotOk(1, 'Msg int.') # pytype: disable=wrong-arg-types
self.assertIn('incompatible function arguments', str(cm.exception))
def test_status_not_ok_status(self): def test_status_not_ok_status(self):
e = status.StatusNotOk(status.Status(status.StatusCode.CANCELLED, 'Cnclld')) e = status.StatusNotOk(status.Status(status.StatusCode.CANCELLED, 'Cnclld'))
self.assertEqual(e.code, int(status.StatusCode.CANCELLED)) self.assertEqual(e.code, int(status.StatusCode.CANCELLED))
...@@ -246,7 +239,7 @@ class StatusTest(parameterized.TestCase): ...@@ -246,7 +239,7 @@ class StatusTest(parameterized.TestCase):
self.assertFalse(st.ErasePayload('UrlNeverExisted')) self.assertFalse(st.ErasePayload('UrlNeverExisted'))
self.assertEqual(st.AllPayloads(), ()) self.assertEqual(st.AllPayloads(), ())
def testDunderEqAndDunderHash(self): def test_eq_and_hash(self):
s0 = status.Status(status.StatusCode.CANCELLED, 'A') s0 = status.Status(status.StatusCode.CANCELLED, 'A')
sb = status.Status(status.StatusCode.CANCELLED, 'A') sb = status.Status(status.StatusCode.CANCELLED, 'A')
sp = status.Status(status.StatusCode.CANCELLED, 'A') sp = status.Status(status.StatusCode.CANCELLED, 'A')
...@@ -476,5 +469,38 @@ class StatusOrTest(absltest.TestCase): ...@@ -476,5 +469,38 @@ class StatusOrTest(absltest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
status_example.call_get_redirect_to_python(int_getter, 100) status_example.call_get_redirect_to_python(int_getter, 100)
class StatusNotOkTest(absltest.TestCase):
def test_build_status_not_ok_enum(self):
e = status.BuildStatusNotOk(status.StatusCode.INVALID_ARGUMENT, 'Msg enum.')
self.assertEqual(e.status.code(), status.StatusCode.INVALID_ARGUMENT)
self.assertEqual(e.code, int(status.StatusCode.INVALID_ARGUMENT))
self.assertEqual(e.message, 'Msg enum.')
def test_build_status_not_ok_int(self):
with self.assertRaises(TypeError) as cm:
status.BuildStatusNotOk(1, 'Msg int.') # pytype: disable=wrong-arg-types
self.assertIn('incompatible function arguments', str(cm.exception))
def test_eq(self):
sa1 = status.BuildStatusNotOk(status.StatusCode.UNKNOWN, 'sa')
sa2 = status.BuildStatusNotOk(status.StatusCode.UNKNOWN, 'sa')
sb = status.BuildStatusNotOk(status.StatusCode.UNKNOWN, 'sb')
self.assertTrue(bool(sa1 == sa1)) # pylint: disable=comparison-with-itself
self.assertTrue(bool(sa1 == sa2))
self.assertFalse(bool(sa1 == sb))
self.assertFalse(bool(sa1 == 'x'))
self.assertFalse(bool('x' == sa1))
def test_pickle(self):
orig = status.BuildStatusNotOk(status.StatusCode.UNKNOWN, 'Cabbage')
ser = pickle.dumps(orig)
deser = pickle.loads(ser)
self.assertEqual(deser.message, 'Cabbage')
self.assertEqual(deser, orig)
self.assertIs(deser.__class__, orig.__class__)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment