Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
P
pybind11
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
open
pybind11
Commits
2184f6d4
Commit
2184f6d4
authored
Oct 31, 2016
by
Ivan Smirnov
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
NumPy dtypes are now shared across extensions
parent
a743ead4
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
27 deletions
+54
-27
include/pybind11/common.h
+1
-0
include/pybind11/numpy.h
+52
-26
tests/test_numpy_dtypes.py
+1
-1
No files found.
include/pybind11/common.h
View file @
2184f6d4
...
@@ -323,6 +323,7 @@ struct internals {
...
@@ -323,6 +323,7 @@ struct internals {
std
::
unordered_set
<
std
::
pair
<
const
PyObject
*
,
const
char
*>
,
overload_hash
>
inactive_overload_cache
;
std
::
unordered_set
<
std
::
pair
<
const
PyObject
*
,
const
char
*>
,
overload_hash
>
inactive_overload_cache
;
std
::
unordered_map
<
std
::
type_index
,
std
::
vector
<
bool
(
*
)(
PyObject
*
,
void
*&
)
>>
direct_conversions
;
std
::
unordered_map
<
std
::
type_index
,
std
::
vector
<
bool
(
*
)(
PyObject
*
,
void
*&
)
>>
direct_conversions
;
std
::
forward_list
<
void
(
*
)
(
std
::
exception_ptr
)
>
registered_exception_translators
;
std
::
forward_list
<
void
(
*
)
(
std
::
exception_ptr
)
>
registered_exception_translators
;
std
::
unordered_map
<
std
::
string
,
void
*>
shared_data
;
#if defined(WITH_THREAD)
#if defined(WITH_THREAD)
decltype
(
PyThread_create_key
())
tstate
=
0
;
// Usually an int but a long on Cygwin64 with Python 3.x
decltype
(
PyThread_create_key
())
tstate
=
0
;
// Usually an int but a long on Cygwin64 with Python 3.x
PyInterpreterState
*
istate
=
nullptr
;
PyInterpreterState
*
istate
=
nullptr
;
...
...
include/pybind11/numpy.h
View file @
2184f6d4
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
#include <initializer_list>
#include <initializer_list>
#include <functional>
#include <functional>
#include <utility>
#include <utility>
#include <typeindex>
#if defined(_MSC_VER)
#if defined(_MSC_VER)
# pragma warning(push)
# pragma warning(push)
...
@@ -72,6 +73,39 @@ struct PyVoidScalarObject_Proxy {
...
@@ -72,6 +73,39 @@ struct PyVoidScalarObject_Proxy {
PyObject
*
base
;
PyObject
*
base
;
};
};
struct
numpy_type_info
{
PyObject
*
dtype_ptr
;
std
::
string
format_str
;
};
struct
numpy_internals
{
std
::
unordered_map
<
std
::
type_index
,
numpy_type_info
>
registered_dtypes
;
template
<
typename
T
>
numpy_type_info
*
get_type_info
(
bool
throw_if_missing
=
true
)
{
auto
it
=
registered_dtypes
.
find
(
std
::
type_index
(
typeid
(
T
)));
if
(
it
!=
registered_dtypes
.
end
())
return
&
(
it
->
second
);
if
(
throw_if_missing
)
pybind11_fail
(
std
::
string
(
"NumPy type info missing for "
)
+
typeid
(
T
).
name
());
return
nullptr
;
}
};
inline
PYBIND11_NOINLINE
numpy_internals
*
load_numpy_internals
()
{
auto
&
shared_data
=
detail
::
get_internals
().
shared_data
;
auto
it
=
shared_data
.
find
(
"numpy_internals"
);
if
(
it
!=
shared_data
.
end
())
return
(
numpy_internals
*
)
it
->
second
;
auto
ptr
=
new
numpy_internals
();
shared_data
[
"numpy_internals"
]
=
ptr
;
return
ptr
;
}
inline
numpy_internals
&
get_numpy_internals
()
{
static
numpy_internals
*
ptr
=
load_numpy_internals
();
return
*
ptr
;
}
struct
npy_api
{
struct
npy_api
{
enum
constants
{
enum
constants
{
NPY_C_CONTIGUOUS_
=
0x0001
,
NPY_C_CONTIGUOUS_
=
0x0001
,
...
@@ -661,30 +695,29 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
...
@@ -661,30 +695,29 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
static
PYBIND11_DESCR
name
()
{
return
_
(
"struct"
);
}
static
PYBIND11_DESCR
name
()
{
return
_
(
"struct"
);
}
static
pybind11
::
dtype
dtype
()
{
static
pybind11
::
dtype
dtype
()
{
if
(
!
dtype_ptr
)
return
object
(
dtype_ptr
(),
true
);
pybind11_fail
(
"NumPy: unsupported buffer format!"
);
return
object
(
dtype_ptr
,
true
);
}
}
static
std
::
string
format
()
{
static
std
::
string
format
()
{
if
(
!
dtype_ptr
)
static
auto
format_str
=
get_numpy_internals
().
get_type_info
<
T
>
(
true
)
->
format_str
;
pybind11_fail
(
"NumPy: unsupported buffer format!"
);
return
format_str
;
return
format_str
;
}
}
static
void
register_dtype
(
std
::
initializer_list
<
field_descriptor
>
fields
)
{
static
void
register_dtype
(
std
::
initializer_list
<
field_descriptor
>
fields
)
{
if
(
dtype_ptr
)
auto
&
numpy_internals
=
get_numpy_internals
();
if
(
numpy_internals
.
get_type_info
<
T
>
(
false
))
pybind11_fail
(
"NumPy: dtype is already registered"
);
pybind11_fail
(
"NumPy: dtype is already registered"
);
list
names
,
formats
,
offsets
;
list
names
,
formats
,
offsets
;
for
(
auto
field
:
fields
)
{
for
(
auto
field
:
fields
)
{
if
(
!
field
.
descr
)
if
(
!
field
.
descr
)
pybind11_fail
(
"NumPy: unsupported field dtype"
);
pybind11_fail
(
std
::
string
(
"NumPy: unsupported field dtype: `"
)
+
field
.
name
+
"` @ "
+
typeid
(
T
).
name
());
names
.
append
(
PYBIND11_STR_TYPE
(
field
.
name
));
names
.
append
(
PYBIND11_STR_TYPE
(
field
.
name
));
formats
.
append
(
field
.
descr
);
formats
.
append
(
field
.
descr
);
offsets
.
append
(
pybind11
::
int_
(
field
.
offset
));
offsets
.
append
(
pybind11
::
int_
(
field
.
offset
));
}
}
dtype_ptr
=
pybind11
::
dtype
(
names
,
formats
,
offsets
,
sizeof
(
T
)).
release
().
ptr
();
auto
dtype_ptr
=
pybind11
::
dtype
(
names
,
formats
,
offsets
,
sizeof
(
T
)).
release
().
ptr
();
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
// not encoded explicitly into the format string. This will supposedly
// not encoded explicitly into the format string. This will supposedly
...
@@ -695,9 +728,7 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
...
@@ -695,9 +728,7 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
// strings and will just do it ourselves.
// strings and will just do it ourselves.
std
::
vector
<
field_descriptor
>
ordered_fields
(
fields
);
std
::
vector
<
field_descriptor
>
ordered_fields
(
fields
);
std
::
sort
(
ordered_fields
.
begin
(),
ordered_fields
.
end
(),
std
::
sort
(
ordered_fields
.
begin
(),
ordered_fields
.
end
(),
[](
const
field_descriptor
&
a
,
const
field_descriptor
&
b
)
{
[](
const
field_descriptor
&
a
,
const
field_descriptor
&
b
)
{
return
a
.
offset
<
b
.
offset
;
});
return
a
.
offset
<
b
.
offset
;
});
size_t
offset
=
0
;
size_t
offset
=
0
;
std
::
ostringstream
oss
;
std
::
ostringstream
oss
;
oss
<<
"T{"
;
oss
<<
"T{"
;
...
@@ -711,44 +742,39 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
...
@@ -711,44 +742,39 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
if
(
sizeof
(
T
)
>
offset
)
if
(
sizeof
(
T
)
>
offset
)
oss
<<
(
sizeof
(
T
)
-
offset
)
<<
'x'
;
oss
<<
(
sizeof
(
T
)
-
offset
)
<<
'x'
;
oss
<<
'}'
;
oss
<<
'}'
;
format_str
=
oss
.
str
();
auto
format_str
=
oss
.
str
();
// Sanity check: verify that NumPy properly parses our buffer format string
// Sanity check: verify that NumPy properly parses our buffer format string
auto
&
api
=
npy_api
::
get
();
auto
&
api
=
npy_api
::
get
();
auto
arr
=
array
(
buffer_info
(
nullptr
,
sizeof
(
T
),
format
()
,
1
));
auto
arr
=
array
(
buffer_info
(
nullptr
,
sizeof
(
T
),
format
_str
,
1
));
if
(
!
api
.
PyArray_EquivTypes_
(
dtype_ptr
,
arr
.
dtype
().
ptr
()))
if
(
!
api
.
PyArray_EquivTypes_
(
dtype_ptr
,
arr
.
dtype
().
ptr
()))
pybind11_fail
(
"NumPy: invalid buffer descriptor!"
);
pybind11_fail
(
"NumPy: invalid buffer descriptor!"
);
register_direct_converter
();
auto
tindex
=
std
::
type_index
(
typeid
(
T
));
numpy_internals
.
registered_dtypes
[
tindex
]
=
{
dtype_ptr
,
format_str
};
get_internals
().
direct_conversions
[
tindex
].
push_back
(
direct_converter
);
}
}
private
:
private
:
static
std
::
string
format_str
;
static
PyObject
*
dtype_ptr
()
{
static
PyObject
*
dtype_ptr
;
static
PyObject
*
ptr
=
get_numpy_internals
().
get_type_info
<
T
>
(
true
)
->
dtype_ptr
;
return
ptr
;
}
static
bool
direct_converter
(
PyObject
*
obj
,
void
*&
value
)
{
static
bool
direct_converter
(
PyObject
*
obj
,
void
*&
value
)
{
auto
&
api
=
npy_api
::
get
();
auto
&
api
=
npy_api
::
get
();
if
(
!
PyObject_TypeCheck
(
obj
,
api
.
PyVoidArrType_Type_
))
if
(
!
PyObject_TypeCheck
(
obj
,
api
.
PyVoidArrType_Type_
))
return
false
;
return
false
;
if
(
auto
descr
=
object
(
api
.
PyArray_DescrFromScalar_
(
obj
),
false
))
{
if
(
auto
descr
=
object
(
api
.
PyArray_DescrFromScalar_
(
obj
),
false
))
{
if
(
api
.
PyArray_EquivTypes_
(
dtype_ptr
,
descr
.
ptr
()))
{
if
(
api
.
PyArray_EquivTypes_
(
dtype_ptr
()
,
descr
.
ptr
()))
{
value
=
((
PyVoidScalarObject_Proxy
*
)
obj
)
->
obval
;
value
=
((
PyVoidScalarObject_Proxy
*
)
obj
)
->
obval
;
return
true
;
return
true
;
}
}
}
}
return
false
;
return
false
;
}
}
static
void
register_direct_converter
()
{
get_internals
().
direct_conversions
[
std
::
type_index
(
typeid
(
T
))].
push_back
(
direct_converter
);
}
};
};
template
<
typename
T
>
std
::
string
npy_format_descriptor
<
T
,
enable_if_t
<
is_pod_struct
<
T
>::
value
>>::
format_str
;
template
<
typename
T
>
PyObject
*
npy_format_descriptor
<
T
,
enable_if_t
<
is_pod_struct
<
T
>::
value
>>::
dtype_ptr
=
nullptr
;
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
::pybind11::detail::field_descriptor { \
::pybind11::detail::field_descriptor { \
Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
...
...
tests/test_numpy_dtypes.py
View file @
2184f6d4
...
@@ -18,7 +18,7 @@ def test_format_descriptors():
...
@@ -18,7 +18,7 @@ def test_format_descriptors():
with
pytest
.
raises
(
RuntimeError
)
as
excinfo
:
with
pytest
.
raises
(
RuntimeError
)
as
excinfo
:
get_format_unbound
()
get_format_unbound
()
assert
'unsupported buffer format'
in
str
(
excinfo
.
value
)
assert
re
.
match
(
'^NumPy type info missing for .*UnboundStruct.*$'
,
str
(
excinfo
.
value
)
)
assert
print_format_descriptors
()
==
[
assert
print_format_descriptors
()
==
[
"T{=?:x:3x=I:y:=f:z:}"
,
"T{=?:x:3x=I:y:=f:z:}"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment