Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
P
pybind11_abseil
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
open
pybind11_abseil
Commits
b96bc8ee
Commit
b96bc8ee
authored
Sep 17, 2022
by
Ralf W. Grosse-Kunstleve
Committed by
Copybara-Service
Sep 17, 2022
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add Payload Management APIs and pickle support.
PiperOrigin-RevId: 475084993
parent
b863d63b
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
168 additions
and
6 deletions
+168
-6
pybind11_abseil/register_status_bindings.cc
+78
-2
pybind11_abseil/tests/status_test.py
+90
-4
No files found.
pybind11_abseil/register_status_bindings.cc
View file @
b96bc8ee
...
@@ -19,6 +19,8 @@ namespace pybind11 {
...
@@ -19,6 +19,8 @@ namespace pybind11 {
namespace
google
{
namespace
google
{
namespace
{
namespace
{
enum
struct
InitFromTag
{
capsule
,
serialized
};
// Returns false if status_or represents a non-ok status object, and true in all
// Returns false if status_or represents a non-ok status object, and true in all
// other cases (including the case that this is passed a non-status object).
// other cases (including the case that this is passed a non-status object).
bool
IsOk
(
handle
status_or
)
{
bool
IsOk
(
handle
status_or
)
{
...
@@ -110,6 +112,10 @@ str decode_utf8_replace(absl::string_view s) {
...
@@ -110,6 +112,10 @@ str decode_utf8_replace(absl::string_view s) {
namespace
internal
{
namespace
internal
{
void
RegisterStatusBindings
(
module
m
)
{
void
RegisterStatusBindings
(
module
m
)
{
enum_
<
InitFromTag
>
(
m
,
"InitFromTag"
)
.
value
(
"capsule"
,
InitFromTag
::
capsule
)
.
value
(
"serialized"
,
InitFromTag
::
serialized
);
enum_
<
absl
::
StatusCode
>
(
m
,
"StatusCode"
)
enum_
<
absl
::
StatusCode
>
(
m
,
"StatusCode"
)
.
value
(
"OK"
,
absl
::
StatusCode
::
kOk
)
.
value
(
"OK"
,
absl
::
StatusCode
::
kOk
)
.
value
(
"CANCELLED"
,
absl
::
StatusCode
::
kCancelled
)
.
value
(
"CANCELLED"
,
absl
::
StatusCode
::
kCancelled
)
...
@@ -131,7 +137,45 @@ void RegisterStatusBindings(module m) {
...
@@ -131,7 +137,45 @@ void RegisterStatusBindings(module m) {
class_
<
absl
::
Status
>
(
m
,
"Status"
)
class_
<
absl
::
Status
>
(
m
,
"Status"
)
.
def
(
init
())
.
def
(
init
())
.
def
(
init
<
absl
::
StatusCode
,
std
::
string
>
())
.
def
(
init
([](
InitFromTag
init_from_tag
,
const
object
&
obj
)
{
switch
(
init_from_tag
)
{
case
InitFromTag
:
:
capsule
:
{
PyErr_SetString
(
PyExc_NotImplementedError
,
"Implemented in pending child cl/474244219."
);
throw
error_already_set
();
}
case
InitFromTag
:
:
serialized
:
{
auto
state
=
cast
<
tuple
>
(
obj
);
if
(
len
(
state
)
!=
3
)
{
throw
value_error
(
absl
::
StrCat
(
"Unexpected len(state) == "
,
len
(
state
),
" ["
,
__FILE__
,
":"
,
__LINE__
,
"]"
));
}
auto
code
=
cast
<
absl
::
StatusCode
>
(
state
[
0
]);
auto
message
=
cast
<
std
::
string
>
(
state
[
1
]);
auto
all_payloads
=
cast
<
tuple
>
(
state
[
2
]);
auto
status
=
std
::
unique_ptr
<
absl
::
Status
>
{
new
absl
::
Status
{
code
,
message
}};
for
(
auto
ap_item_obj
:
all_payloads
)
{
auto
ap_item_tup
=
cast
<
tuple
>
(
ap_item_obj
);
if
(
len
(
ap_item_tup
)
!=
2
)
{
throw
value_error
(
absl
::
StrCat
(
"Unexpected len(tuple) == "
,
len
(
ap_item_tup
),
" where (type_url, payload) is expected ["
,
__FILE__
,
":"
,
__LINE__
,
"]"
));
}
auto
type_url
=
cast
<
absl
::
string_view
>
(
ap_item_tup
[
0
]);
auto
payload
=
cast
<
absl
::
string_view
>
(
ap_item_tup
[
1
]);
status
->
SetPayload
(
type_url
,
absl
::
Cord
(
payload
));
}
return
status
;
}
}
throw
std
::
runtime_error
(
absl
::
StrCat
(
"Meant to be unreachable ["
,
__FILE__
,
":"
,
__LINE__
,
"]"
));
}),
arg
(
"init_from_tag"
),
arg
(
"obj"
))
.
def
(
init
<
absl
::
StatusCode
,
std
::
string
>
(),
arg
(
"code"
),
arg
(
"msg"
))
.
def
(
"ok"
,
&
absl
::
Status
::
ok
)
.
def
(
"ok"
,
&
absl
::
Status
::
ok
)
.
def
(
"code"
,
&
absl
::
Status
::
code
)
.
def
(
"code"
,
&
absl
::
Status
::
code
)
.
def
(
"code_int"
,
.
def
(
"code_int"
,
...
@@ -171,7 +215,39 @@ void RegisterStatusBindings(module m) {
...
@@ -171,7 +215,39 @@ void RegisterStatusBindings(module m) {
[](
const
absl
::
Status
&
self
)
{
[](
const
absl
::
Status
&
self
)
{
return
decode_utf8_replace
(
self
.
message
());
return
decode_utf8_replace
(
self
.
message
());
})
})
.
def
(
"IgnoreError"
,
&
absl
::
Status
::
IgnoreError
);
.
def
(
"IgnoreError"
,
&
absl
::
Status
::
IgnoreError
)
.
def
(
"SetPayload"
,
[](
absl
::
Status
&
self
,
absl
::
string_view
type_url
,
absl
::
string_view
payload
)
{
self
.
SetPayload
(
type_url
,
absl
::
Cord
(
payload
));
})
.
def
(
"ErasePayload"
,
[](
absl
::
Status
&
self
,
absl
::
string_view
type_url
)
{
return
self
.
ErasePayload
(
type_url
);
})
.
def
(
"AllPayloads"
,
[](
const
absl
::
Status
&
s
)
{
list
key_value_pairs
;
s
.
ForEachPayload
([
&
key_value_pairs
](
absl
::
string_view
key
,
const
absl
::
Cord
&
value
)
{
key_value_pairs
.
append
(
make_tuple
(
bytes
(
std
::
string
(
key
)),
bytes
(
std
::
string
(
value
))));
});
// Make the order deterministic, especially long-term.
key_value_pairs
.
attr
(
"sort"
)();
return
tuple
(
key_value_pairs
);
})
.
def
(
"__reduce_ex__"
,
[](
const
object
&
self
,
int
)
{
return
make_tuple
(
self
.
attr
(
"__class__"
),
make_tuple
(
InitFromTag
::
serialized
,
make_tuple
(
self
.
attr
(
"code"
)(),
self
.
attr
(
"message_bytes"
)(),
self
.
attr
(
"AllPayloads"
)())));
},
arg
(
"protocol"
)
=
-
1
);
m
.
def
(
"is_ok"
,
&
IsOk
,
arg
(
"status_or"
),
m
.
def
(
"is_ok"
,
&
IsOk
,
arg
(
"status_or"
),
"Returns false only if passed a non-ok status; otherwise returns true. "
"Returns false only if passed a non-ok status; otherwise returns true. "
...
...
pybind11_abseil/tests/status_test.py
View file @
b96bc8ee
"""Tests for google3.third_party.pybind11_abseil.status_casters."""
"""Tests for google3.third_party.pybind11_abseil.status_casters."""
from
__future__
import
absolute_import
import
pickle
from
__future__
import
division
from
__future__
import
print_function
from
absl.testing
import
absltest
from
absl.testing
import
absltest
from
absl.testing
import
parameterized
from
pybind11_abseil
import
status
from
pybind11_abseil
import
status
from
pybind11_abseil.tests
import
status_example
from
pybind11_abseil.tests
import
status_example
...
@@ -14,7 +13,7 @@ def docstring_signature(f):
...
@@ -14,7 +13,7 @@ def docstring_signature(f):
return
f
.
__doc__
.
split
(
'
\n
'
)[
0
]
return
f
.
__doc__
.
split
(
'
\n
'
)[
0
]
class
StatusTest
(
absltest
.
TestCase
):
class
StatusTest
(
parameterized
.
TestCase
):
def
test_pass_status
(
self
):
def
test_pass_status
(
self
):
test_status
=
status
.
Status
(
status
.
StatusCode
.
CANCELLED
,
'test'
)
test_status
=
status
.
Status
(
status
.
StatusCode
.
CANCELLED
,
'test'
)
...
@@ -183,6 +182,93 @@ class StatusTest(absltest.TestCase):
...
@@ -183,6 +182,93 @@ class StatusTest(absltest.TestCase):
self
.
assertEqual
(
st500
.
raw_code
(),
500
)
self
.
assertEqual
(
st500
.
raw_code
(),
500
)
self
.
assertEqual
(
st500
.
code
(),
status
.
StatusCode
.
UNKNOWN
)
self
.
assertEqual
(
st500
.
code
(),
status
.
StatusCode
.
UNKNOWN
)
def
test_payload_management_apis
(
self
):
st
=
status
.
Status
(
status
.
StatusCode
.
CANCELLED
,
''
)
self
.
assertEqual
(
st
.
AllPayloads
(),
())
st
.
SetPayload
(
'Url1'
,
'Payload1'
)
self
.
assertEqual
(
st
.
AllPayloads
(),
((
b
'Url1'
,
b
'Payload1'
),))
st
.
SetPayload
(
'Url0'
,
'Payload0'
)
self
.
assertEqual
(
st
.
AllPayloads
(),
((
b
'Url0'
,
b
'Payload0'
),
(
b
'Url1'
,
b
'Payload1'
)))
st
.
SetPayload
(
'Url2'
,
'Payload2'
)
self
.
assertEqual
(
st
.
AllPayloads
(),
((
b
'Url0'
,
b
'Payload0'
),
(
b
'Url1'
,
b
'Payload1'
),
(
b
'Url2'
,
b
'Payload2'
)))
st
.
SetPayload
(
'Url2'
,
'Payload2B'
)
self
.
assertEqual
(
st
.
AllPayloads
(),
((
b
'Url0'
,
b
'Payload0'
),
(
b
'Url1'
,
b
'Payload1'
),
(
b
'Url2'
,
b
'Payload2B'
)))
self
.
assertTrue
(
st
.
ErasePayload
(
'Url1'
))
self
.
assertEqual
(
st
.
AllPayloads
(),
((
b
'Url0'
,
b
'Payload0'
),
(
b
'Url2'
,
b
'Payload2B'
)))
self
.
assertFalse
(
st
.
ErasePayload
(
'Url1'
))
self
.
assertEqual
(
st
.
AllPayloads
(),
((
b
'Url0'
,
b
'Payload0'
),
(
b
'Url2'
,
b
'Payload2B'
)))
self
.
assertFalse
(
st
.
ErasePayload
(
'UrlNeverExisted'
))
self
.
assertEqual
(
st
.
AllPayloads
(),
((
b
'Url0'
,
b
'Payload0'
),
(
b
'Url2'
,
b
'Payload2B'
)))
self
.
assertTrue
(
st
.
ErasePayload
(
'Url0'
))
self
.
assertEqual
(
st
.
AllPayloads
(),
((
b
'Url2'
,
b
'Payload2B'
),))
self
.
assertTrue
(
st
.
ErasePayload
(
'Url2'
))
self
.
assertEqual
(
st
.
AllPayloads
(),
())
self
.
assertFalse
(
st
.
ErasePayload
(
'UrlNeverExisted'
))
self
.
assertEqual
(
st
.
AllPayloads
(),
())
def
assertEqualStatus
(
self
,
a
,
b
):
self
.
assertEqual
(
a
.
code
(),
b
.
code
())
self
.
assertEqual
(
a
.
message_bytes
(),
b
.
message_bytes
())
self
.
assertSequenceEqual
(
sorted
(
a
.
AllPayloads
()),
sorted
(
b
.
AllPayloads
()))
@parameterized.parameters
(
0
,
1
,
2
)
def
test_pickle
(
self
,
payload_size
):
orig
=
status
.
Status
(
status
.
StatusCode
.
CANCELLED
,
'Cucumber.'
)
expected_all_payloads
=
[]
for
i
in
range
(
payload_size
):
type_url
=
f
'Url{i}'
payload
=
f
'Payload{i}'
orig
.
SetPayload
(
type_url
,
payload
)
expected_all_payloads
.
append
((
type_url
.
encode
(),
payload
.
encode
()))
expected_all_payloads
=
tuple
(
expected_all_payloads
)
# Redundant with other tests, but here to reassure that the preconditions
# for the tests below to be meaningful are met.
self
.
assertEqual
(
orig
.
code
(),
status
.
StatusCode
.
CANCELLED
)
self
.
assertEqual
(
orig
.
message_bytes
(),
b
'Cucumber.'
)
self
.
assertEqual
(
orig
.
AllPayloads
(),
expected_all_payloads
)
# Exercises implementation details, but is simple and might be useful to
# narrow down root causes for regressions.
redx
=
orig
.
__reduce_ex__
()
self
.
assertLen
(
redx
,
2
)
self
.
assertIs
(
redx
[
0
],
status
.
Status
)
self
.
assertEqual
(
redx
[
1
],
(
status
.
InitFromTag
.
serialized
,
(
status
.
StatusCode
.
CANCELLED
,
b
'Cucumber.'
,
expected_all_payloads
)))
ser
=
pickle
.
dumps
(
orig
)
deser
=
pickle
.
loads
(
ser
)
self
.
assertEqualStatus
(
deser
,
orig
)
self
.
assertIs
(
deser
.
__class__
,
orig
.
__class__
)
def
test_init_from_serialized_exception_unexpected_len_state
(
self
):
with
self
.
assertRaisesRegex
(
ValueError
,
r'Unexpected len\(state\) == 4'
r' \[.*register_status_bindings\.cc:[0-9]+\]'
):
status
.
Status
(
status
.
InitFromTag
.
serialized
,
(
0
,
0
,
0
,
0
))
def
test_init_from_serialized_exception_unexpected_len_ap_item_tup
(
self
):
with
self
.
assertRaisesRegex
(
ValueError
,
r'Unexpected len\(tuple\) == 3 where \(type_url, payload\) is expected'
r' \[.*register_status_bindings\.cc:[0-9]+\]'
):
status
.
Status
(
status
.
InitFromTag
.
serialized
,
(
status
.
StatusCode
.
CANCELLED
,
''
,
((
0
,
0
,
0
),)))
def
test_init_from_capsule_not_implemented_error
(
self
):
with
self
.
assertRaises
(
NotImplementedError
):
status
.
Status
(
status
.
InitFromTag
.
capsule
,
())
class
IntGetter
(
status_example
.
IntGetter
):
class
IntGetter
(
status_example
.
IntGetter
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment