Commit 64e85ee3 by Ralf W. Grosse-Kunstleve Committed by Copybara-Service

Add `StatusCodeFromInt()`, `Status(InitFromTag.capsule, obj)`, `Status.as_absl_Status()`.

This enables universal interoperability with other Python-C/C++ binding systems, including the raw Python C API.

PiperOrigin-RevId: 475633368
parent b96bc8ee
...@@ -67,6 +67,34 @@ pybind_library( ...@@ -67,6 +67,34 @@ pybind_library(
], ],
) )
cc_library(
name = "init_from_tag",
hdrs = ["init_from_tag.h"],
visibility = ["//visibility:private"],
)
cc_library(
name = "raw_ptr_from_capsule",
hdrs = ["raw_ptr_from_capsule.h"],
visibility = ["//visibility:private"],
deps = [
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@local_config_python//:python_headers", # buildcleaner: keep
],
)
pybind_library(
name = "utils_pybind11_absl",
srcs = ["utils_pybind11_absl.cc"],
hdrs = ["utils_pybind11_absl.h"],
visibility = ["//visibility:private"],
deps = [
"@com_google_absl//absl/strings",
],
)
pybind_library( pybind_library(
name = "register_status_bindings", name = "register_status_bindings",
srcs = ["register_status_bindings.cc"], srcs = ["register_status_bindings.cc"],
...@@ -74,10 +102,14 @@ pybind_library( ...@@ -74,10 +102,14 @@ pybind_library(
visibility = ["//visibility:private"], visibility = ["//visibility:private"],
deps = [ deps = [
":absl_casters", ":absl_casters",
":init_from_tag",
":no_throw_status", ":no_throw_status",
":raw_ptr_from_capsule",
":status_caster", ":status_caster",
":status_not_ok_exception", ":status_not_ok_exception",
":utils_pybind11_absl",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
......
#ifndef PYBIND11_ABSEIL_INIT_FROM_TAG_H_
#define PYBIND11_ABSEIL_INIT_FROM_TAG_H_
namespace pybind11 {
namespace google {
enum struct InitFromTag { capsule, capsule_direct_only, serialized };
} // namespace google
} // namespace pybind11
#endif // PYBIND11_ABSEIL_INIT_FROM_TAG_H_
#ifndef PYBIND11_ABSEIL_RAW_PTR_FROM_CAPSULE_H_
#define PYBIND11_ABSEIL_RAW_PTR_FROM_CAPSULE_H_
// Must be first include (https://docs.python.org/3/c-api/intro.html).
#include <Python.h>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
namespace pybind11_abseil {
namespace raw_ptr_from_capsule {
// Copied from pybind11/pytypes.h, to decouple from exception handling
// requirement.
// Equivalent to obj.__class__.__name__ (or obj.__name__ if obj is a class).
inline const char* obj_class_name(PyObject* obj) {
if (Py_TYPE(obj) == &PyType_Type) {
return reinterpret_cast<PyTypeObject*>(obj)->tp_name;
}
return Py_TYPE(obj)->tp_name;
}
inline std::string quoted_name_or_null_indicator(
const char* name, const char* quote = "\"",
const char* null_indicator = "NULL") {
return (name == nullptr ? null_indicator : absl::StrCat(quote, name, quote));
}
template <typename T>
absl::StatusOr<T*> RawPtrFromCapsule(PyObject* py_obj, const char* name,
const char* as_capsule_method_name) {
// Note: https://docs.python.org/3/c-api/capsule.html:
// The pointer argument may not be NULL.
if (PyCapsule_CheckExact(py_obj)) {
void* void_ptr = PyCapsule_GetPointer(py_obj, name);
if (PyErr_Occurred()) {
PyErr_Clear();
return absl::InvalidArgumentError(absl::StrCat(
"obj is a capsule with name ",
quoted_name_or_null_indicator(PyCapsule_GetName(py_obj)), " but ",
quoted_name_or_null_indicator(name), " is expected."));
}
return static_cast<T*>(void_ptr);
}
if (as_capsule_method_name == nullptr) {
return absl::InvalidArgumentError(
absl::StrCat(obj_class_name(py_obj), " object is not a capsule."));
}
PyObject* from_method =
PyObject_CallMethod(py_obj, as_capsule_method_name, nullptr);
if (from_method == nullptr) {
PyObject *ptype = nullptr, *pvalue = nullptr, *ptraceback = nullptr;
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
PyErr_NormalizeException(&ptype, &pvalue, &ptraceback);
PyObject* err_msg_str = PyObject_Str(pvalue);
std::string err_msg;
if (err_msg_str == nullptr) {
PyErr_Clear();
err_msg = "<message unavailable>";
} else {
PyObject* err_msg_bytes =
PyUnicode_AsEncodedString(err_msg_str, "UTF-8", "replace");
Py_DECREF(err_msg_str);
if (err_msg_bytes == nullptr) {
PyErr_Clear();
err_msg = "<message unavailable>";
} else {
const char* err_msg_char_ptr = PyBytes_AsString(err_msg_bytes);
if (err_msg_char_ptr == nullptr) {
PyErr_Clear();
err_msg = "<message unavailable>";
} else {
err_msg = err_msg_char_ptr;
}
Py_DECREF(err_msg_bytes);
}
}
return absl::InvalidArgumentError(
absl::StrCat(obj_class_name(py_obj), ".", as_capsule_method_name,
"() call failed: ", obj_class_name(ptype), ": ", err_msg));
}
if (!PyCapsule_CheckExact(from_method)) {
std::string returned_obj_type = obj_class_name(from_method);
Py_DECREF(from_method);
return absl::InvalidArgumentError(
absl::StrCat(obj_class_name(py_obj), ".", as_capsule_method_name,
"() returned an object (", returned_obj_type,
") that is not a capsule."));
}
void* void_ptr = PyCapsule_GetPointer(from_method, name);
if (!PyErr_Occurred()) {
Py_DECREF(from_method);
return static_cast<T*>(void_ptr);
}
PyErr_Clear();
std::string capsule_name =
quoted_name_or_null_indicator(PyCapsule_GetName(from_method));
Py_DECREF(from_method);
return absl::InvalidArgumentError(
absl::StrCat(obj_class_name(py_obj), ".", as_capsule_method_name,
"() returned a capsule with name ", capsule_name, " but ",
quoted_name_or_null_indicator(name), " is expected."));
}
} // namespace raw_ptr_from_capsule
} // namespace pybind11_abseil
#endif // PYBIND11_ABSEIL_RAW_PTR_FROM_CAPSULE_H_
...@@ -9,18 +9,21 @@ ...@@ -9,18 +9,21 @@
#include <utility> #include <utility>
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "pybind11_abseil/absl_casters.h" #include "pybind11_abseil/absl_casters.h"
#include "pybind11_abseil/init_from_tag.h"
#include "pybind11_abseil/no_throw_status.h" #include "pybind11_abseil/no_throw_status.h"
#include "pybind11_abseil/raw_ptr_from_capsule.h"
#include "pybind11_abseil/status_caster.h" #include "pybind11_abseil/status_caster.h"
#include "pybind11_abseil/status_not_ok_exception.h" #include "pybind11_abseil/status_not_ok_exception.h"
#include "pybind11_abseil/utils_pybind11_absl.h"
namespace pybind11 { namespace pybind11 {
namespace google { namespace google {
namespace { namespace {
enum struct InitFromTag { capsule, serialized };
// Returns false if status_or represents a non-ok status object, and true in all // Returns false if status_or represents a non-ok status object, and true in all
// other cases (including the case that this is passed a non-status object). // other cases (including the case that this is passed a non-status object).
bool IsOk(handle status_or) { bool IsOk(handle status_or) {
...@@ -96,17 +99,6 @@ void def_status_factory( ...@@ -96,17 +99,6 @@ void def_status_factory(
arg("message")); arg("message"));
} }
// TODO(b/225205409): Move to utility library.
// To avoid clobbering potentially critical error messages with
// `UnicodeDecodeError`.
str decode_utf8_replace(absl::string_view s) {
PyObject* u = PyUnicode_DecodeUTF8(s.data(), s.size(), "replace");
if (u == nullptr) {
throw error_already_set();
}
return reinterpret_steal<str>(u);
}
} // namespace } // namespace
namespace internal { namespace internal {
...@@ -114,6 +106,7 @@ namespace internal { ...@@ -114,6 +106,7 @@ namespace internal {
void RegisterStatusBindings(module m) { void RegisterStatusBindings(module m) {
enum_<InitFromTag>(m, "InitFromTag") enum_<InitFromTag>(m, "InitFromTag")
.value("capsule", InitFromTag::capsule) .value("capsule", InitFromTag::capsule)
.value("capsule_direct_only", InitFromTag::capsule_direct_only)
.value("serialized", InitFromTag::serialized); .value("serialized", InitFromTag::serialized);
enum_<absl::StatusCode>(m, "StatusCode") enum_<absl::StatusCode>(m, "StatusCode")
...@@ -135,14 +128,37 @@ void RegisterStatusBindings(module m) { ...@@ -135,14 +128,37 @@ void RegisterStatusBindings(module m) {
.value("DATA_LOSS", absl::StatusCode::kDataLoss) .value("DATA_LOSS", absl::StatusCode::kDataLoss)
.value("UNAUTHENTICATED", absl::StatusCode::kUnauthenticated); .value("UNAUTHENTICATED", absl::StatusCode::kUnauthenticated);
m.def(
"StatusCodeFromInt",
[](int code_int) {
auto code = absl::StatusCode(code_int);
if (absl::StatusCodeToString(code).empty()) {
// StatusCodeToString() seems to be the best available method for
// this purpose.
throw value_error(absl::StrCat("code_int=", code_int,
" is not a valid absl::StatusCode"));
}
return code;
},
arg("code_int"));
class_<absl::Status>(m, "Status") class_<absl::Status>(m, "Status")
.def(init()) .def(init())
.def(init([](InitFromTag init_from_tag, const object& obj) { .def(init([](InitFromTag init_from_tag, const object& obj) {
switch (init_from_tag) { switch (init_from_tag) {
case InitFromTag::capsule: { case InitFromTag::capsule:
PyErr_SetString(PyExc_NotImplementedError, case InitFromTag::capsule_direct_only: {
"Implemented in pending child cl/474244219."); absl::StatusOr<absl::Status*> raw_ptr =
throw error_already_set(); pybind11_abseil::raw_ptr_from_capsule::RawPtrFromCapsule<
absl::Status>(obj.ptr(), "::absl::Status",
init_from_tag == InitFromTag::capsule
? "as_absl_Status"
: nullptr);
if (!raw_ptr.ok()) {
throw value_error(std::string(raw_ptr.status().message()));
}
return std::unique_ptr<absl::Status>{
new absl::Status{*raw_ptr.value()}};
} }
case InitFromTag::serialized: { case InitFromTag::serialized: {
auto state = cast<tuple>(obj); auto state = cast<tuple>(obj);
...@@ -247,7 +263,11 @@ void RegisterStatusBindings(module m) { ...@@ -247,7 +263,11 @@ void RegisterStatusBindings(module m) {
self.attr("message_bytes")(), self.attr("message_bytes")(),
self.attr("AllPayloads")()))); self.attr("AllPayloads")())));
}, },
arg("protocol") = -1); arg("protocol") = -1)
.def("as_absl_Status", [](absl::Status* self) -> object {
return reinterpret_steal<object>(
PyCapsule_New(static_cast<void*>(self), "::absl::Status", nullptr));
});
m.def("is_ok", &IsOk, arg("status_or"), m.def("is_ok", &IsOk, arg("status_or"),
"Returns false only if passed a non-ok status; otherwise returns true. " "Returns false only if passed a non-ok status; otherwise returns true. "
......
...@@ -105,6 +105,13 @@ PYBIND11_MODULE(status_example, m) { ...@@ -105,6 +105,13 @@ PYBIND11_MODULE(status_example, m) {
auto status_module = pybind11::google::ImportStatusModule(); auto status_module = pybind11::google::ImportStatusModule();
m.attr("StatusNotOk") = status_module.attr("StatusNotOk"); m.attr("StatusNotOk") = status_module.attr("StatusNotOk");
m.def("make_bad_capsule", [](bool pass_name) {
// https://docs.python.org/3/c-api/capsule.html:
// The pointer argument may not be NULL.
return capsule(static_cast<void*>(static_cast<int*>(nullptr) + 1),
pass_name ? "NotGood" : nullptr);
});
class_<IntValue>(m, "IntValue").def_readonly("value", &IntValue::value); class_<IntValue>(m, "IntValue").def_readonly("value", &IntValue::value);
class_<TestClass>(m, "TestClass") class_<TestClass>(m, "TestClass")
......
...@@ -13,6 +13,36 @@ def docstring_signature(f): ...@@ -13,6 +13,36 @@ def docstring_signature(f):
return f.__doc__.split('\n')[0] return f.__doc__.split('\n')[0]
class BadCapsule:
def __init__(self, pass_name):
self.pass_name = pass_name
def as_absl_Status(self): # pylint: disable=invalid-name
return status_example.make_bad_capsule(self.pass_name)
class NotACapsule:
def __init__(self, not_a_capsule):
self.not_a_capsule = not_a_capsule
def as_absl_Status(self): # pylint: disable=invalid-name
return self.not_a_capsule
class StatusCodeTest(absltest.TestCase):
def test_status_code_from_int_valid(self):
self.assertEqual(status.StatusCodeFromInt(13), status.StatusCode.INTERNAL)
def test_status_code_from_int_invalid(self):
with self.assertRaises(ValueError) as ctx:
status.StatusCodeFromInt(9876)
self.assertEqual(
str(ctx.exception), 'code_int=9876 is not a valid absl::StatusCode')
class StatusTest(parameterized.TestCase): class StatusTest(parameterized.TestCase):
def test_pass_status(self): def test_pass_status(self):
...@@ -265,9 +295,62 @@ class StatusTest(parameterized.TestCase): ...@@ -265,9 +295,62 @@ class StatusTest(parameterized.TestCase):
status.Status(status.InitFromTag.serialized, status.Status(status.InitFromTag.serialized,
(status.StatusCode.CANCELLED, '', ((0, 0, 0),))) (status.StatusCode.CANCELLED, '', ((0, 0, 0),)))
def test_init_from_capsule_not_implemented_error(self): def test_init_from_capsule_direct_ok(self):
with self.assertRaises(NotImplementedError): orig = status.Status(status.StatusCode.CANCELLED, 'Direct.')
status.Status(status.InitFromTag.capsule, ()) from_cap = status.Status(status.InitFromTag.capsule, orig.as_absl_Status())
self.assertEqualStatus(from_cap, orig)
def test_init_from_capsule_as_capsule_method_ok(self):
orig = status.Status(status.StatusCode.CANCELLED, 'AsCapsuleMethod.')
from_cap = status.Status(status.InitFromTag.capsule, orig)
self.assertEqualStatus(from_cap, orig)
@parameterized.parameters((False, 'NULL'), (True, '"NotGood"'))
def test_init_from_capsule_direct_bad_capsule(self, pass_name, quoted_name):
with self.assertRaises(ValueError) as ctx:
status.Status(status.InitFromTag.capsule,
status_example.make_bad_capsule(pass_name))
self.assertEqual(
str(ctx.exception),
f'obj is a capsule with name {quoted_name} but "::absl::Status"'
f' is expected.')
@parameterized.parameters((False, 'NULL'), (True, '"NotGood"'))
def test_init_from_capsule_correct_method_bad_capsule(self, pass_name,
quoted_name):
with self.assertRaises(ValueError) as ctx:
status.Status(status.InitFromTag.capsule, BadCapsule(pass_name))
self.assertEqual(
str(ctx.exception),
f'BadCapsule.as_absl_Status() returned a capsule with name'
f' {quoted_name} but "::absl::Status" is expected.')
@parameterized.parameters(None, '', 0)
def test_init_from_capsule_direct_only_not_a_capsule(self, not_a_capsule):
with self.assertRaises(ValueError) as ctx:
status.Status(status.InitFromTag.capsule_direct_only, not_a_capsule)
nm = not_a_capsule.__class__.__name__
self.assertEqual(str(ctx.exception), f'{nm} object is not a capsule.')
@parameterized.parameters(None, '', 0)
def test_init_from_capsule_direct_not_a_capsule(self, not_a_capsule):
with self.assertRaises(ValueError) as ctx:
status.Status(status.InitFromTag.capsule, not_a_capsule)
nm = not_a_capsule.__class__.__name__
self.assertEqual(
str(ctx.exception),
f"{nm}.as_absl_Status() call failed: AttributeError: '{nm}' object"
f" has no attribute 'as_absl_Status'")
@parameterized.parameters(None, '', 0)
def test_init_from_capsule_correct_method_not_a_capsule(self, not_a_capsule):
with self.assertRaises(ValueError) as ctx:
status.Status(status.InitFromTag.capsule, NotACapsule(not_a_capsule))
nm = not_a_capsule.__class__.__name__
self.assertEqual(
str(ctx.exception),
f'NotACapsule.as_absl_Status() returned an object ({nm})'
f' that is not a capsule.')
class IntGetter(status_example.IntGetter): class IntGetter(status_example.IntGetter):
......
#include "pybind11_abseil/utils_pybind11_absl.h"
#include <Python.h>
#include <pybind11/pybind11.h>
#include "absl/strings/string_view.h"
namespace pybind11 {
namespace google {
str decode_utf8_replace(absl::string_view s) {
PyObject* u = PyUnicode_DecodeUTF8(s.data(), s.size(), "replace");
if (u == nullptr) {
throw error_already_set();
}
return reinterpret_steal<str>(u);
}
} // namespace google
} // namespace pybind11
#ifndef PYBIND11_ABSEIL_UTILS_PYBIND11_ABSL_H_
#define PYBIND11_ABSEIL_UTILS_PYBIND11_ABSL_H_
// Note: This is meant to only depend on pybind11 and absl headers.
// DO NOT add other dependencies.
#include <pybind11/pybind11.h>
#include "absl/strings/string_view.h"
namespace pybind11 {
namespace google {
// To avoid clobbering potentially critical error messages with
// `UnicodeDecodeError`.
str decode_utf8_replace(absl::string_view s);
} // namespace google
} // namespace pybind11
#endif // PYBIND11_ABSEIL_UTILS_PYBIND11_ABSL_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment