Commit b959683a by Ralf W. Grosse-Kunstleve Committed by Copybara-Service

Add `Status` `__eq__` and `__hash__` methods.

PiperOrigin-RevId: 476925555
parent 075bf549
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <pybind11/embed.h> #include <pybind11/embed.h>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <cstddef>
#include <exception> #include <exception>
#include <functional> #include <functional>
#include <string> #include <string>
...@@ -99,6 +100,19 @@ void def_status_factory( ...@@ -99,6 +100,19 @@ void def_status_factory(
arg("message")); arg("message"));
} }
absl::StatusOr<absl::Status*> StatusRawPtrFromCapsule(
const object& obj, bool enable_as_capsule_method = true) {
return pybind11_abseil::raw_ptr_from_capsule::RawPtrFromCapsule<absl::Status>(
obj.ptr(), "::absl::Status",
enable_as_capsule_method ? "as_absl_Status" : nullptr);
}
// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
std::size_t boost_hash_combine(std::size_t lhs, std::size_t rhs) {
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
} // namespace } // namespace
namespace internal { namespace internal {
...@@ -153,12 +167,10 @@ void RegisterStatusBindings(module m) { ...@@ -153,12 +167,10 @@ void RegisterStatusBindings(module m) {
switch (init_from_tag) { switch (init_from_tag) {
case InitFromTag::capsule: case InitFromTag::capsule:
case InitFromTag::capsule_direct_only: { case InitFromTag::capsule_direct_only: {
bool enable_as_capsule_method =
(init_from_tag == InitFromTag::capsule);
absl::StatusOr<absl::Status*> raw_ptr = absl::StatusOr<absl::Status*> raw_ptr =
pybind11_abseil::raw_ptr_from_capsule::RawPtrFromCapsule< StatusRawPtrFromCapsule(obj, enable_as_capsule_method);
absl::Status>(obj.ptr(), "::absl::Status",
init_from_tag == InitFromTag::capsule
? "as_absl_Status"
: nullptr);
if (!raw_ptr.ok()) { if (!raw_ptr.ok()) {
throw value_error(std::string(raw_ptr.status().message())); throw value_error(std::string(raw_ptr.status().message()));
} }
...@@ -258,6 +270,24 @@ void RegisterStatusBindings(module m) { ...@@ -258,6 +270,24 @@ void RegisterStatusBindings(module m) {
key_value_pairs.attr("sort")(); key_value_pairs.attr("sort")();
return tuple(key_value_pairs); return tuple(key_value_pairs);
}) })
.def("__eq__",
[](const absl::Status& self, const object& rhs) {
absl::StatusOr<absl::Status*> rhs_ptr = StatusRawPtrFromCapsule(
rhs, /*enable_as_capsule_method=*/true);
return rhs_ptr.ok() && *rhs_ptr.value() == self;
})
.def("__hash__",
[](const absl::Status& self) {
// Payload is ignored intentionally to minimize runtime.
return boost_hash_combine(
std::hash<int>{}(self.raw_code()),
#if defined(ABSL_USES_STD_STRING_VIEW)
std::hash<std::string_view>{}(self.message())
#else
std::hash<std::string>{}(std::string(self.message()))
#endif
);
})
.def( .def(
"__reduce_ex__", "__reduce_ex__",
[](const object& self, int) { [](const object& self, int) {
......
...@@ -248,10 +248,41 @@ class StatusTest(parameterized.TestCase): ...@@ -248,10 +248,41 @@ class StatusTest(parameterized.TestCase):
self.assertFalse(st.ErasePayload('UrlNeverExisted')) self.assertFalse(st.ErasePayload('UrlNeverExisted'))
self.assertEqual(st.AllPayloads(), ()) self.assertEqual(st.AllPayloads(), ())
def assertEqualStatus(self, a, b): def testDunderEqAndDunderHash(self):
self.assertEqual(a.code(), b.code()) s0 = status.Status(status.StatusCode.CANCELLED, 'A')
self.assertEqual(a.message_bytes(), b.message_bytes()) sb = status.Status(status.StatusCode.CANCELLED, 'A')
self.assertSequenceEqual(sorted(a.AllPayloads()), sorted(b.AllPayloads())) sp = status.Status(status.StatusCode.CANCELLED, 'A')
sp.SetPayload('Url1p', 'Payload1p')
sc = status.Status(status.StatusCode.UNKNOWN, 'A')
sm = status.Status(status.StatusCode.CANCELLED, 'B')
sx = status.Status(status.StatusCode.UNKNOWN, 'B')
self.assertTrue(bool(s0 == s0)) # pylint: disable=comparison-with-itself
self.assertTrue(bool(s0 == sb))
self.assertFalse(bool(s0 == sp))
self.assertFalse(bool(s0 == sc))
self.assertFalse(bool(s0 == sm))
self.assertFalse(bool(s0 == sx))
self.assertFalse(bool(s0 == 'AnyOtherType'))
self.assertEqual(hash(sb), hash(s0))
self.assertEqual(hash(sp), hash(s0)) # Payload ignored intentionally.
self.assertNotEqual(hash(sc), hash(s0))
self.assertNotEqual(hash(sm), hash(s0))
self.assertNotEqual(hash(sx), hash(s0))
st_set = {s0}
self.assertLen(st_set, 1)
st_set.add(sb)
self.assertLen(st_set, 1)
st_set.add(sp)
self.assertLen(st_set, 2)
st_set.add(sc)
self.assertLen(st_set, 3)
st_set.add(sm)
self.assertLen(st_set, 4)
st_set.add(sx)
self.assertLen(st_set, 5)
@parameterized.parameters(0, 1, 2) @parameterized.parameters(0, 1, 2)
def test_pickle(self, payload_size): def test_pickle(self, payload_size):
...@@ -282,7 +313,7 @@ class StatusTest(parameterized.TestCase): ...@@ -282,7 +313,7 @@ class StatusTest(parameterized.TestCase):
ser = pickle.dumps(orig) ser = pickle.dumps(orig)
deser = pickle.loads(ser) deser = pickle.loads(ser)
self.assertEqualStatus(deser, orig) self.assertEqual(deser, orig)
self.assertIs(deser.__class__, orig.__class__) self.assertIs(deser.__class__, orig.__class__)
def test_init_from_serialized_exception_unexpected_len_state(self): def test_init_from_serialized_exception_unexpected_len_state(self):
...@@ -302,12 +333,12 @@ class StatusTest(parameterized.TestCase): ...@@ -302,12 +333,12 @@ class StatusTest(parameterized.TestCase):
def test_init_from_capsule_direct_ok(self): def test_init_from_capsule_direct_ok(self):
orig = status.Status(status.StatusCode.CANCELLED, 'Direct.') orig = status.Status(status.StatusCode.CANCELLED, 'Direct.')
from_cap = status.Status(status.InitFromTag.capsule, orig.as_absl_Status()) from_cap = status.Status(status.InitFromTag.capsule, orig.as_absl_Status())
self.assertEqualStatus(from_cap, orig) self.assertEqual(from_cap, orig)
def test_init_from_capsule_as_capsule_method_ok(self): def test_init_from_capsule_as_capsule_method_ok(self):
orig = status.Status(status.StatusCode.CANCELLED, 'AsCapsuleMethod.') orig = status.Status(status.StatusCode.CANCELLED, 'AsCapsuleMethod.')
from_cap = status.Status(status.InitFromTag.capsule, orig) from_cap = status.Status(status.InitFromTag.capsule, orig)
self.assertEqualStatus(from_cap, orig) self.assertEqual(from_cap, orig)
@parameterized.parameters((False, 'NULL'), (True, '"NotGood"')) @parameterized.parameters((False, 'NULL'), (True, '"NotGood"'))
def test_init_from_capsule_direct_bad_capsule(self, pass_name, quoted_name): def test_init_from_capsule_direct_bad_capsule(self, pass_name, quoted_name):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment