Commit b96bc8ee by Ralf W. Grosse-Kunstleve Committed by Copybara-Service

Add Payload Management APIs and pickle support.

PiperOrigin-RevId: 475084993
parent b863d63b
...@@ -19,6 +19,8 @@ namespace pybind11 { ...@@ -19,6 +19,8 @@ namespace pybind11 {
namespace google { namespace google {
namespace { namespace {
enum struct InitFromTag { capsule, serialized };
// Returns false if status_or represents a non-ok status object, and true in all // Returns false if status_or represents a non-ok status object, and true in all
// other cases (including the case that this is passed a non-status object). // other cases (including the case that this is passed a non-status object).
bool IsOk(handle status_or) { bool IsOk(handle status_or) {
...@@ -110,6 +112,10 @@ str decode_utf8_replace(absl::string_view s) { ...@@ -110,6 +112,10 @@ str decode_utf8_replace(absl::string_view s) {
namespace internal { namespace internal {
void RegisterStatusBindings(module m) { void RegisterStatusBindings(module m) {
enum_<InitFromTag>(m, "InitFromTag")
.value("capsule", InitFromTag::capsule)
.value("serialized", InitFromTag::serialized);
enum_<absl::StatusCode>(m, "StatusCode") enum_<absl::StatusCode>(m, "StatusCode")
.value("OK", absl::StatusCode::kOk) .value("OK", absl::StatusCode::kOk)
.value("CANCELLED", absl::StatusCode::kCancelled) .value("CANCELLED", absl::StatusCode::kCancelled)
...@@ -131,7 +137,45 @@ void RegisterStatusBindings(module m) { ...@@ -131,7 +137,45 @@ void RegisterStatusBindings(module m) {
class_<absl::Status>(m, "Status") class_<absl::Status>(m, "Status")
.def(init()) .def(init())
.def(init<absl::StatusCode, std::string>()) .def(init([](InitFromTag init_from_tag, const object& obj) {
switch (init_from_tag) {
case InitFromTag::capsule: {
PyErr_SetString(PyExc_NotImplementedError,
"Implemented in pending child cl/474244219.");
throw error_already_set();
}
case InitFromTag::serialized: {
auto state = cast<tuple>(obj);
if (len(state) != 3) {
throw value_error(
absl::StrCat("Unexpected len(state) == ", len(state),
" [", __FILE__, ":", __LINE__, "]"));
}
auto code = cast<absl::StatusCode>(state[0]);
auto message = cast<std::string>(state[1]);
auto all_payloads = cast<tuple>(state[2]);
auto status = std::unique_ptr<absl::Status>{
new absl::Status{code, message}};
for (auto ap_item_obj : all_payloads) {
auto ap_item_tup = cast<tuple>(ap_item_obj);
if (len(ap_item_tup) != 2) {
throw value_error(absl::StrCat(
"Unexpected len(tuple) == ", len(ap_item_tup),
" where (type_url, payload) is expected [", __FILE__,
":", __LINE__, "]"));
}
auto type_url = cast<absl::string_view>(ap_item_tup[0]);
auto payload = cast<absl::string_view>(ap_item_tup[1]);
status->SetPayload(type_url, absl::Cord(payload));
}
return status;
}
}
throw std::runtime_error(absl::StrCat(
"Meant to be unreachable [", __FILE__, ":", __LINE__, "]"));
}),
arg("init_from_tag"), arg("obj"))
.def(init<absl::StatusCode, std::string>(), arg("code"), arg("msg"))
.def("ok", &absl::Status::ok) .def("ok", &absl::Status::ok)
.def("code", &absl::Status::code) .def("code", &absl::Status::code)
.def("code_int", .def("code_int",
...@@ -171,7 +215,39 @@ void RegisterStatusBindings(module m) { ...@@ -171,7 +215,39 @@ void RegisterStatusBindings(module m) {
[](const absl::Status& self) { [](const absl::Status& self) {
return decode_utf8_replace(self.message()); return decode_utf8_replace(self.message());
}) })
.def("IgnoreError", &absl::Status::IgnoreError); .def("IgnoreError", &absl::Status::IgnoreError)
.def("SetPayload",
[](absl::Status& self, absl::string_view type_url,
absl::string_view payload) {
self.SetPayload(type_url, absl::Cord(payload));
})
.def("ErasePayload",
[](absl::Status& self, absl::string_view type_url) {
return self.ErasePayload(type_url);
})
.def("AllPayloads",
[](const absl::Status& s) {
list key_value_pairs;
s.ForEachPayload([&key_value_pairs](absl::string_view key,
const absl::Cord& value) {
key_value_pairs.append(make_tuple(bytes(std::string(key)),
bytes(std::string(value))));
});
// Make the order deterministic, especially long-term.
key_value_pairs.attr("sort")();
return tuple(key_value_pairs);
})
.def(
"__reduce_ex__",
[](const object& self, int) {
return make_tuple(
self.attr("__class__"),
make_tuple(InitFromTag::serialized,
make_tuple(self.attr("code")(),
self.attr("message_bytes")(),
self.attr("AllPayloads")())));
},
arg("protocol") = -1);
m.def("is_ok", &IsOk, arg("status_or"), m.def("is_ok", &IsOk, arg("status_or"),
"Returns false only if passed a non-ok status; otherwise returns true. " "Returns false only if passed a non-ok status; otherwise returns true. "
......
"""Tests for google3.third_party.pybind11_abseil.status_casters.""" """Tests for google3.third_party.pybind11_abseil.status_casters."""
from __future__ import absolute_import import pickle
from __future__ import division
from __future__ import print_function
from absl.testing import absltest from absl.testing import absltest
from absl.testing import parameterized
from pybind11_abseil import status from pybind11_abseil import status
from pybind11_abseil.tests import status_example from pybind11_abseil.tests import status_example
...@@ -14,7 +13,7 @@ def docstring_signature(f): ...@@ -14,7 +13,7 @@ def docstring_signature(f):
return f.__doc__.split('\n')[0] return f.__doc__.split('\n')[0]
class StatusTest(absltest.TestCase): class StatusTest(parameterized.TestCase):
def test_pass_status(self): def test_pass_status(self):
test_status = status.Status(status.StatusCode.CANCELLED, 'test') test_status = status.Status(status.StatusCode.CANCELLED, 'test')
...@@ -183,6 +182,93 @@ class StatusTest(absltest.TestCase): ...@@ -183,6 +182,93 @@ class StatusTest(absltest.TestCase):
self.assertEqual(st500.raw_code(), 500) self.assertEqual(st500.raw_code(), 500)
self.assertEqual(st500.code(), status.StatusCode.UNKNOWN) self.assertEqual(st500.code(), status.StatusCode.UNKNOWN)
def test_payload_management_apis(self):
st = status.Status(status.StatusCode.CANCELLED, '')
self.assertEqual(st.AllPayloads(), ())
st.SetPayload('Url1', 'Payload1')
self.assertEqual(st.AllPayloads(), ((b'Url1', b'Payload1'),))
st.SetPayload('Url0', 'Payload0')
self.assertEqual(st.AllPayloads(),
((b'Url0', b'Payload0'), (b'Url1', b'Payload1')))
st.SetPayload('Url2', 'Payload2')
self.assertEqual(st.AllPayloads(),
((b'Url0', b'Payload0'), (b'Url1', b'Payload1'),
(b'Url2', b'Payload2')))
st.SetPayload('Url2', 'Payload2B')
self.assertEqual(st.AllPayloads(),
((b'Url0', b'Payload0'), (b'Url1', b'Payload1'),
(b'Url2', b'Payload2B')))
self.assertTrue(st.ErasePayload('Url1'))
self.assertEqual(st.AllPayloads(),
((b'Url0', b'Payload0'), (b'Url2', b'Payload2B')))
self.assertFalse(st.ErasePayload('Url1'))
self.assertEqual(st.AllPayloads(),
((b'Url0', b'Payload0'), (b'Url2', b'Payload2B')))
self.assertFalse(st.ErasePayload('UrlNeverExisted'))
self.assertEqual(st.AllPayloads(),
((b'Url0', b'Payload0'), (b'Url2', b'Payload2B')))
self.assertTrue(st.ErasePayload('Url0'))
self.assertEqual(st.AllPayloads(), ((b'Url2', b'Payload2B'),))
self.assertTrue(st.ErasePayload('Url2'))
self.assertEqual(st.AllPayloads(), ())
self.assertFalse(st.ErasePayload('UrlNeverExisted'))
self.assertEqual(st.AllPayloads(), ())
def assertEqualStatus(self, a, b):
self.assertEqual(a.code(), b.code())
self.assertEqual(a.message_bytes(), b.message_bytes())
self.assertSequenceEqual(sorted(a.AllPayloads()), sorted(b.AllPayloads()))
@parameterized.parameters(0, 1, 2)
def test_pickle(self, payload_size):
orig = status.Status(status.StatusCode.CANCELLED, 'Cucumber.')
expected_all_payloads = []
for i in range(payload_size):
type_url = f'Url{i}'
payload = f'Payload{i}'
orig.SetPayload(type_url, payload)
expected_all_payloads.append((type_url.encode(), payload.encode()))
expected_all_payloads = tuple(expected_all_payloads)
# Redundant with other tests, but here to reassure that the preconditions
# for the tests below to be meaningful are met.
self.assertEqual(orig.code(), status.StatusCode.CANCELLED)
self.assertEqual(orig.message_bytes(), b'Cucumber.')
self.assertEqual(orig.AllPayloads(), expected_all_payloads)
# Exercises implementation details, but is simple and might be useful to
# narrow down root causes for regressions.
redx = orig.__reduce_ex__()
self.assertLen(redx, 2)
self.assertIs(redx[0], status.Status)
self.assertEqual(
redx[1],
(status.InitFromTag.serialized,
(status.StatusCode.CANCELLED, b'Cucumber.', expected_all_payloads)))
ser = pickle.dumps(orig)
deser = pickle.loads(ser)
self.assertEqualStatus(deser, orig)
self.assertIs(deser.__class__, orig.__class__)
def test_init_from_serialized_exception_unexpected_len_state(self):
with self.assertRaisesRegex(
ValueError, r'Unexpected len\(state\) == 4'
r' \[.*register_status_bindings\.cc:[0-9]+\]'):
status.Status(status.InitFromTag.serialized, (0, 0, 0, 0))
def test_init_from_serialized_exception_unexpected_len_ap_item_tup(self):
with self.assertRaisesRegex(
ValueError,
r'Unexpected len\(tuple\) == 3 where \(type_url, payload\) is expected'
r' \[.*register_status_bindings\.cc:[0-9]+\]'):
status.Status(status.InitFromTag.serialized,
(status.StatusCode.CANCELLED, '', ((0, 0, 0),)))
def test_init_from_capsule_not_implemented_error(self):
with self.assertRaises(NotImplementedError):
status.Status(status.InitFromTag.capsule, ())
class IntGetter(status_example.IntGetter): class IntGetter(status_example.IntGetter):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment