Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
P
pybind11_abseil
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
open
pybind11_abseil
Commits
67491a41
Commit
67491a41
authored
Dec 12, 2023
by
Ralf W. Grosse-Kunstleve
Committed by
Copybara-Service
Dec 12, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Enable passing `absl::Span<bool>` and `absl::Span<const bool>`
PiperOrigin-RevId: 590290022
parent
f37d4455
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
92 additions
and
18 deletions
+92
-18
pybind11_abseil/absl_casters.h
+51
-13
pybind11_abseil/tests/absl_example.cc
+14
-0
pybind11_abseil/tests/absl_test.py
+27
-5
No files found.
pybind11_abseil/absl_casters.h
View file @
67491a41
...
...
@@ -387,7 +387,7 @@ namespace internal {
template
<
typename
T
>
static
constexpr
bool
is_buffer_interface_compatible_type
=
detail
::
is_same_ignoring_cvref
<
T
,
PyObject
*>::
value
||
std
::
is_arithmetic
<
T
>::
value
||
std
::
is_arithmetic
<
std
::
remove_cv_t
<
T
>
>::
value
||
std
::
is_same
<
T
,
std
::
complex
<
float
>>::
value
||
std
::
is_same
<
T
,
std
::
complex
<
double
>>::
value
;
...
...
@@ -405,7 +405,8 @@ std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) {
if
(
PyObject_GetBuffer
(
src
.
ptr
(),
&
view
,
flags
)
==
0
)
{
auto
cleanup
=
absl
::
MakeCleanup
([
&
view
]
{
PyBuffer_Release
(
&
view
);
});
if
(
view
.
ndim
==
1
&&
view
.
strides
[
0
]
==
sizeof
(
T
)
&&
buffer_info
(
&
view
,
/*ownview=*/
false
).
item_type_is_equivalent_to
<
T
>
())
{
buffer_info
(
&
view
,
/*ownview=*/
false
)
.
item_type_is_equivalent_to
<
std
::
remove_cv_t
<
T
>>
())
{
return
{
true
,
absl
::
MakeSpan
(
static_cast
<
T
*>
(
view
.
buf
),
view
.
shape
[
0
])};
}
}
else
{
...
...
@@ -421,6 +422,29 @@ constexpr std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle /*src*/) {
return
{
false
,
absl
::
Span
<
T
>
()};
}
template
<
typename
T
,
typename
std
::
enable_if
<
!
std
::
is_same
<
std
::
remove_cv_t
<
T
>
,
bool
>::
value
,
int
>::
type
=
0
>
std
::
tuple
<
bool
,
absl
::
Span
<
T
>>
LoadSpanOpaqueVector
(
handle
src
)
{
// Attempt to unwrap an opaque std::vector.
using
value_type
=
std
::
remove_cv_t
<
T
>
;
type_caster_base
<
std
::
vector
<
value_type
>>
caster
;
if
(
caster
.
load
(
src
,
false
))
{
return
{
true
,
absl
::
MakeSpan
(
static_cast
<
std
::
vector
<
value_type
>&>
(
caster
))};
}
return
{
false
,
absl
::
Span
<
T
>
()};
}
template
<
typename
T
,
typename
std
::
enable_if
<
std
::
is_same
<
std
::
remove_cv_t
<
T
>
,
bool
>::
value
,
int
>::
type
=
0
>
std
::
tuple
<
bool
,
absl
::
Span
<
T
>>
LoadSpanOpaqueVector
(
handle
src
)
{
// std::vector<bool> is special and cannot directly be converted to a Span
// (see https://en.cppreference.com/w/cpp/container/vector_bool).
return
{
false
,
absl
::
Span
<
T
>
()};
}
// Helper to determine whether T is a span.
template
<
typename
T
>
struct
is_absl_span
:
std
::
false_type
{};
...
...
@@ -433,7 +457,7 @@ template <typename T>
struct
type_caster
<
absl
::
Span
<
T
>>
{
public
:
// The type referenced by the span, with const removed.
using
value_type
=
typename
std
::
remove_cv
<
T
>::
type
;
using
value_type
=
std
::
remove_cv_t
<
T
>
;
static_assert
(
!
is_absl_span
<
value_type
>::
value
,
"Nested absl spans are not supported."
);
...
...
@@ -479,19 +503,17 @@ struct type_caster<absl::Span<T>> {
std
::
tie
(
loaded
,
value_
)
=
LoadSpanFromBuffer
<
T
>
(
src
);
if
(
loaded
)
return
true
;
// Attempt to unwrap an opaque std::vector.
type_caster_base
<
std
::
vector
<
value_type
>>
caster
;
if
(
caster
.
load
(
src
,
false
))
{
value_
=
get_value
(
caster
);
return
true
;
}
std
::
tie
(
loaded
,
value_
)
=
LoadSpanOpaqueVector
<
T
>
(
src
);
if
(
loaded
)
return
true
;
// Attempt to convert a native sequence. If the is_base_of
_v
check passes,
// Attempt to convert a native sequence. If the is_base_of check passes,
// the elements do not require converting and pointers do not reference a
// temporary object owned by the element caster. Pointers to converted
// types are not allowed because they would result a dangling reference
// when the element caster is destroyed.
if
(
convert
&&
std
::
is_const
<
T
>::
value
&&
// See comment for ephemeral_storage_type below.
!
std
::
is_same
<
T
,
const
bool
>::
value
&&
(
!
std
::
is_pointer
<
T
>::
value
||
std
::
is_base_of
<
type_caster_generic
,
make_caster
<
T
>>::
value
))
{
list_caster_
.
emplace
();
...
...
@@ -512,12 +534,28 @@ struct type_caster<absl::Span<T>> {
}
private
:
template
<
typename
Caster
>
// Unfortunately using std::vector as ephemeral_storage_type creates
// complications for std::vector<bool>
// (https://en.cppreference.com/w/cpp/container/vector_bool).
using
ephemeral_storage_type
=
std
::
vector
<
value_type
>
;
template
<
typename
Caster
,
typename
VT
=
value_type
,
typename
std
::
enable_if
<!
std
::
is_same
<
VT
,
bool
>::
value
,
int
>::
type
=
0
>
absl
::
Span
<
T
>
get_value
(
Caster
&
caster
)
{
return
absl
::
MakeSpan
(
static_cast
<
std
::
vector
<
value_type
>&>
(
caster
));
return
absl
::
MakeSpan
(
static_cast
<
ephemeral_storage_type
&>
(
caster
));
}
// This template specialization is needed to avoid compilation errors.
// The conditions in load() make this code unreachable.
template
<
typename
Caster
,
typename
VT
=
value_type
,
typename
std
::
enable_if
<
std
::
is_same
<
VT
,
bool
>::
value
,
int
>::
type
=
0
>
absl
::
Span
<
T
>
get_value
(
Caster
&
)
{
throw
std
::
runtime_error
(
"Expected to be unreachable."
);
}
using
ListCaster
=
list_caster
<
std
::
vector
<
value_type
>
,
value_type
>
;
using
ListCaster
=
list_caster
<
ephemeral_storage_type
,
value_type
>
;
absl
::
optional
<
ListCaster
>
list_caster_
;
absl
::
Span
<
T
>
value_
;
};
...
...
pybind11_abseil/tests/absl_example.cc
View file @
67491a41
...
...
@@ -279,6 +279,18 @@ std::string PassSpanPyObjectPtr(absl::Span<PyObject*> input_span) {
return
result
;
}
std
::
string
PassSpanBool
(
absl
::
Span
<
bool
>
input_span
)
{
std
::
string
result
;
for
(
const
auto
&
i
:
input_span
)
result
+=
(
i
?
"t"
:
"f"
);
return
result
;
}
std
::
string
PassSpanConstBool
(
absl
::
Span
<
const
bool
>
input_span
)
{
std
::
string
result
;
for
(
const
auto
&
i
:
input_span
)
result
+=
(
i
?
"T"
:
"F"
);
return
result
;
}
struct
ObjectForSpan
{
explicit
ObjectForSpan
(
int
v
)
:
value
(
v
)
{}
int
value
;
...
...
@@ -404,6 +416,8 @@ PYBIND11_MODULE(absl_example, m) {
m
.
def
(
"sum_span_const_complex128"
,
&
SumSpanComplex
<
const
std
::
complex
<
double
>>
,
arg
(
"input_span"
));
m
.
def
(
"pass_span_pyobject_ptr"
,
&
PassSpanPyObjectPtr
,
arg
(
"span"
));
m
.
def
(
"pass_span_bool"
,
&
PassSpanBool
,
arg
(
"span"
));
m
.
def
(
"pass_span_const_bool"
,
&
PassSpanConstBool
,
arg
(
"span"
));
// Span of objects.
class_
<
ObjectForSpan
>
(
m
,
"ObjectForSpan"
)
...
...
pybind11_abseil/tests/absl_test.py
View file @
67491a41
...
...
@@ -312,7 +312,7 @@ def make_read_only_numpy_array():
return
values
def
make_srided_numpy_array
(
stride
):
def
make_s
t
rided_numpy_array
(
stride
):
return
np
.
zeros
(
10
,
dtype
=
np
.
int32
)[::
stride
]
...
...
@@ -373,10 +373,10 @@ class AbslNumericSpanTest(parameterized.TestCase):
@parameterized.named_parameters
(
(
'float_numpy'
,
np
.
zeros
(
5
,
dtype
=
float
)),
(
'two_d_numpy'
,
np
.
zeros
(
(
5
,
5
),
dtype
=
np
.
int32
)),
(
'read_only'
,
make_read_only_numpy_array
()),
(
'strided_skip'
,
make_srided_numpy_array
(
2
)),
(
'strided_reverse'
,
make_srided_numpy_array
(
-
1
)),
(
'two_d_numpy'
,
np
.
zeros
(
(
5
,
5
),
dtype
=
np
.
int32
)),
(
'read_only'
,
make_read_only_numpy_array
()),
(
'strided_skip'
,
make_s
t
rided_numpy_array
(
2
)),
(
'strided_reverse'
,
make_s
t
rided_numpy_array
(
-
1
)),
(
'non_supported_type'
,
np
.
zeros
(
5
,
dtype
=
np
.
unicode_
)),
(
'native_list'
,
[
0
]
*
5
))
def
test_fill_span_fails_from
(
self
,
values
):
...
...
@@ -397,6 +397,28 @@ class AbslNumericSpanTest(parameterized.TestCase):
arr
=
np
.
array
([
-
3
,
'four'
,
5.0
],
dtype
=
object
)
self
.
assertEqual
(
absl_example
.
pass_span_pyobject_ptr
(
arr
),
'-3four5.0'
)
@parameterized.parameters
(
([],
''
),
([
False
],
'f'
),
([
True
],
't'
),
([
False
,
True
,
True
,
False
],
'fttf'
),
)
def
test_pass_span_bool
(
self
,
bools
,
expected
):
arr
=
np
.
array
(
bools
,
dtype
=
bool
)
s
=
absl_example
.
pass_span_bool
(
arr
)
self
.
assertEqual
(
s
,
expected
)
@parameterized.parameters
(
([],
''
),
([
False
],
'F'
),
([
True
],
'T'
),
([
False
,
True
,
True
,
False
],
'FTTF'
),
)
def
test_pass_span_const_bool
(
self
,
bools
,
expected
):
arr
=
np
.
array
(
bools
,
dtype
=
bool
)
s
=
absl_example
.
pass_span_const_bool
(
arr
)
self
.
assertEqual
(
s
,
expected
)
def
make_native_list_of_objects
():
return
[
absl_example
.
ObjectForSpan
(
3
),
absl_example
.
ObjectForSpan
(
5
)]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment