Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
P
pybind11
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
open
pybind11
Commits
01f74095
Commit
01f74095
authored
Jul 23, 2016
by
Ivan Smirnov
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Initial implementation of py::dtype
parent
05cb58ad
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
110 additions
and
83 deletions
+110
-83
example/example-numpy-dtypes.cpp
+6
-9
include/pybind11/numpy.h
+104
-74
No files found.
example/example-numpy-dtypes.cpp
View file @
01f74095
...
@@ -158,15 +158,12 @@ void print_format_descriptors() {
...
@@ -158,15 +158,12 @@ void print_format_descriptors() {
}
}
void
print_dtypes
()
{
void
print_dtypes
()
{
auto
to_str
=
[](
py
::
object
obj
)
{
std
::
cout
<<
(
std
::
string
)
py
::
dtype
::
of
<
SimpleStruct
>
().
str
()
<<
std
::
endl
;
return
(
std
::
string
)
(
py
::
str
)
((
py
::
object
)
obj
.
attr
(
"__str__"
))();
std
::
cout
<<
(
std
::
string
)
py
::
dtype
::
of
<
PackedStruct
>
().
str
()
<<
std
::
endl
;
};
std
::
cout
<<
(
std
::
string
)
py
::
dtype
::
of
<
NestedStruct
>
().
str
()
<<
std
::
endl
;
std
::
cout
<<
to_str
(
py
::
dtype_of
<
SimpleStruct
>
())
<<
std
::
endl
;
std
::
cout
<<
(
std
::
string
)
py
::
dtype
::
of
<
PartialStruct
>
().
str
()
<<
std
::
endl
;
std
::
cout
<<
to_str
(
py
::
dtype_of
<
PackedStruct
>
())
<<
std
::
endl
;
std
::
cout
<<
(
std
::
string
)
py
::
dtype
::
of
<
PartialNestedStruct
>
().
str
()
<<
std
::
endl
;
std
::
cout
<<
to_str
(
py
::
dtype_of
<
NestedStruct
>
())
<<
std
::
endl
;
std
::
cout
<<
(
std
::
string
)
py
::
dtype
::
of
<
StringStruct
>
().
str
()
<<
std
::
endl
;
std
::
cout
<<
to_str
(
py
::
dtype_of
<
PartialStruct
>
())
<<
std
::
endl
;
std
::
cout
<<
to_str
(
py
::
dtype_of
<
PartialNestedStruct
>
())
<<
std
::
endl
;
std
::
cout
<<
to_str
(
py
::
dtype_of
<
StringStruct
>
())
<<
std
::
endl
;
}
}
void
init_ex_numpy_dtypes
(
py
::
module
&
m
)
{
void
init_ex_numpy_dtypes
(
py
::
module
&
m
)
{
...
...
include/pybind11/numpy.h
View file @
01f74095
...
@@ -52,7 +52,12 @@ struct npy_api {
...
@@ -52,7 +52,12 @@ struct npy_api {
return
api
;
return
api
;
}
}
bool
PyArray_Check_
(
PyObject
*
obj
)
const
{
return
(
bool
)
PyObject_TypeCheck
(
obj
,
PyArray_Type_
);
}
bool
PyArray_Check_
(
PyObject
*
obj
)
const
{
return
(
bool
)
PyObject_TypeCheck
(
obj
,
PyArray_Type_
);
}
bool
PyArrayDescr_Check_
(
PyObject
*
obj
)
const
{
return
(
bool
)
PyObject_TypeCheck
(
obj
,
PyArrayDescr_Type_
);
}
PyObject
*
(
*
PyArray_DescrFromType_
)(
int
);
PyObject
*
(
*
PyArray_DescrFromType_
)(
int
);
PyObject
*
(
*
PyArray_NewFromDescr_
)
PyObject
*
(
*
PyArray_NewFromDescr_
)
...
@@ -61,6 +66,7 @@ struct npy_api {
...
@@ -61,6 +66,7 @@ struct npy_api {
PyObject
*
(
*
PyArray_DescrNewFromType_
)(
int
);
PyObject
*
(
*
PyArray_DescrNewFromType_
)(
int
);
PyObject
*
(
*
PyArray_NewCopy_
)(
PyObject
*
,
int
);
PyObject
*
(
*
PyArray_NewCopy_
)(
PyObject
*
,
int
);
PyTypeObject
*
PyArray_Type_
;
PyTypeObject
*
PyArray_Type_
;
PyTypeObject
*
PyArrayDescr_Type_
;
PyObject
*
(
*
PyArray_FromAny_
)
(
PyObject
*
,
PyObject
*
,
int
,
int
,
int
,
PyObject
*
);
PyObject
*
(
*
PyArray_FromAny_
)
(
PyObject
*
,
PyObject
*
,
int
,
int
,
int
,
PyObject
*
);
int
(
*
PyArray_DescrConverter_
)
(
PyObject
*
,
PyObject
**
);
int
(
*
PyArray_DescrConverter_
)
(
PyObject
*
,
PyObject
**
);
bool
(
*
PyArray_EquivTypes_
)
(
PyObject
*
,
PyObject
*
);
bool
(
*
PyArray_EquivTypes_
)
(
PyObject
*
,
PyObject
*
);
...
@@ -69,6 +75,7 @@ struct npy_api {
...
@@ -69,6 +75,7 @@ struct npy_api {
private
:
private
:
enum
functions
{
enum
functions
{
API_PyArray_Type
=
2
,
API_PyArray_Type
=
2
,
API_PyArrayDescr_Type
=
3
,
API_PyArray_DescrFromType
=
45
,
API_PyArray_DescrFromType
=
45
,
API_PyArray_FromAny
=
69
,
API_PyArray_FromAny
=
69
,
API_PyArray_NewCopy
=
85
,
API_PyArray_NewCopy
=
85
,
...
@@ -90,6 +97,7 @@ private:
...
@@ -90,6 +97,7 @@ private:
npy_api
api
;
npy_api
api
;
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
DECL_NPY_API
(
PyArray_Type
);
DECL_NPY_API
(
PyArray_Type
);
DECL_NPY_API
(
PyArrayDescr_Type
);
DECL_NPY_API
(
PyArray_DescrFromType
);
DECL_NPY_API
(
PyArray_DescrFromType
);
DECL_NPY_API
(
PyArray_FromAny
);
DECL_NPY_API
(
PyArray_FromAny
);
DECL_NPY_API
(
PyArray_NewCopy
);
DECL_NPY_API
(
PyArray_NewCopy
);
...
@@ -104,56 +112,55 @@ private:
...
@@ -104,56 +112,55 @@ private:
};
};
}
}
class
array
:
public
buffer
{
class
dtype
:
public
object
{
public
:
public
:
PYBIND11_OBJECT_DEFAULT
(
array
,
buffer
,
detail
::
npy_api
::
get
().
PyArray_Check_
)
PYBIND11_OBJECT_DEFAULT
(
dtype
,
object
,
detail
::
npy_api
::
get
().
PyArrayDescr_Check_
);
enum
{
dtype
(
const
buffer_info
&
info
)
{
c_style
=
detail
::
npy_api
::
NPY_C_CONTIGUOUS_
,
dtype
descr
(
_dtype_from_pep3118
()(
pybind11
::
str
(
info
.
format
)));
f_style
=
detail
::
npy_api
::
NPY_F_CONTIGUOUS_
,
m_ptr
=
descr
.
strip_padding
().
release
().
ptr
();
forcecast
=
detail
::
npy_api
::
NPY_ARRAY_FORCECAST_
}
};
template
<
typename
Type
>
array
(
size_t
size
,
const
Type
*
ptr
)
{
dtype
(
std
::
string
format
)
{
auto
&
api
=
detail
::
npy_api
::
get
();
m_ptr
=
from_args
(
pybind11
::
str
(
format
)).
release
().
ptr
();
PyObject
*
descr
=
detail
::
npy_format_descriptor
<
Type
>::
dtype
().
release
().
ptr
();
Py_intptr_t
shape
=
(
Py_intptr_t
)
size
;
object
tmp
=
object
(
api
.
PyArray_NewFromDescr_
(
api
.
PyArray_Type_
,
descr
,
1
,
&
shape
,
nullptr
,
(
void
*
)
ptr
,
0
,
nullptr
),
false
);
if
(
!
tmp
)
pybind11_fail
(
"NumPy: unable to create array!"
);
if
(
ptr
)
tmp
=
object
(
api
.
PyArray_NewCopy_
(
tmp
.
ptr
(),
-
1
/* any order */
),
false
);
m_ptr
=
tmp
.
release
().
ptr
();
}
}
array
(
const
buffer_info
&
info
)
{
static
dtype
from_args
(
object
args
)
{
auto
&
api
=
detail
::
npy_api
::
get
();
// This is essentially the same as calling np.dtype() constructor in Python
PyObject
*
ptr
=
nullptr
;
if
(
!
detail
::
npy_api
::
get
().
PyArray_DescrConverter_
(
args
.
release
().
ptr
(),
&
ptr
)
||
!
ptr
)
pybind11_fail
(
"NumPy: failed to create structured dtype"
);
return
object
(
ptr
,
false
);
}
// _dtype_from_pep3118 returns dtypes with padding fields in, so we need to strip them
template
<
typename
T
>
static
dtype
of
()
{
auto
numpy_internal
=
module
::
import
(
"numpy.core._internal"
);
return
detail
::
npy_format_descriptor
<
T
>::
dtype
();
auto
dtype_from_fmt
=
(
object
)
numpy_internal
.
attr
(
"_dtype_from_pep3118"
);
}
auto
dtype
=
strip_padding_fields
(
dtype_from_fmt
(
pybind11
::
str
(
info
.
format
)));
object
tmp
(
api
.
PyArray_NewFromDescr_
(
size_t
itemsize
()
const
{
api
.
PyArray_Type_
,
dtype
.
release
().
ptr
(),
(
int
)
info
.
ndim
,
(
Py_intptr_t
*
)
&
info
.
shape
[
0
],
return
(
size_t
)
attr
(
"itemsize"
).
cast
<
int_
>
();
(
Py_intptr_t
*
)
&
info
.
strides
[
0
],
info
.
ptr
,
0
,
nullptr
),
false
);
if
(
!
tmp
)
pybind11_fail
(
"NumPy: unable to create array!"
);
if
(
info
.
ptr
)
tmp
=
object
(
api
.
PyArray_NewCopy_
(
tmp
.
ptr
(),
-
1
/* any order */
),
false
);
m_ptr
=
tmp
.
release
().
ptr
();
}
}
protected
:
bool
has_fields
()
const
{
template
<
typename
T
,
typename
SFINAE
>
friend
struct
detail
::
npy_format_descriptor
;
return
attr
(
"fields"
).
cast
<
object
>
().
ptr
()
!=
Py_None
;
}
std
::
string
kind
()
const
{
return
(
std
::
string
)
attr
(
"kind"
).
cast
<
pybind11
::
str
>
();
}
private
:
static
object
&
_dtype_from_pep3118
()
{
static
object
obj
=
module
::
import
(
"numpy.core._internal"
).
attr
(
"_dtype_from_pep3118"
);
return
obj
;
}
static
object
strip_padding_fields
(
object
dtype
)
{
dtype
strip_padding
(
)
{
// Recursively strip all void fields with empty names that are generated for
// Recursively strip all void fields with empty names that are generated for
// padding fields (as of NumPy v1.11).
// padding fields (as of NumPy v1.11).
auto
fields
=
dtype
.
attr
(
"fields"
).
cast
<
object
>
();
auto
fields
=
attr
(
"fields"
).
cast
<
object
>
();
if
(
fields
.
ptr
()
==
Py_None
)
if
(
fields
.
ptr
()
==
Py_None
)
return
dtype
;
return
*
this
;
struct
field_descr
{
pybind11
::
str
name
;
object
format
;
int_
offset
;
};
struct
field_descr
{
pybind11
::
str
name
;
object
format
;
int_
offset
;
};
std
::
vector
<
field_descr
>
field_descriptors
;
std
::
vector
<
field_descr
>
field_descriptors
;
...
@@ -162,11 +169,11 @@ protected:
...
@@ -162,11 +169,11 @@ protected:
for
(
auto
field
:
items
())
{
for
(
auto
field
:
items
())
{
auto
spec
=
object
(
field
,
true
).
cast
<
tuple
>
();
auto
spec
=
object
(
field
,
true
).
cast
<
tuple
>
();
auto
name
=
spec
[
0
].
cast
<
pybind11
::
str
>
();
auto
name
=
spec
[
0
].
cast
<
pybind11
::
str
>
();
auto
format
=
spec
[
1
].
cast
<
tuple
>
()[
0
].
cast
<
object
>
();
auto
format
=
spec
[
1
].
cast
<
tuple
>
()[
0
].
cast
<
dtype
>
();
auto
offset
=
spec
[
1
].
cast
<
tuple
>
()[
1
].
cast
<
int_
>
();
auto
offset
=
spec
[
1
].
cast
<
tuple
>
()[
1
].
cast
<
int_
>
();
if
(
!
len
(
name
)
&&
(
std
::
string
)
dtype
.
attr
(
"kind"
).
cast
<
pybind11
::
str
>
()
==
"V"
)
if
(
!
len
(
name
)
&&
format
.
kind
()
==
"V"
)
continue
;
continue
;
field_descriptors
.
push_back
({
name
,
strip_padding_fields
(
format
),
offset
});
field_descriptors
.
push_back
({
name
,
format
.
strip_padding
(
),
offset
});
}
}
std
::
sort
(
field_descriptors
.
begin
(),
field_descriptors
.
end
(),
std
::
sort
(
field_descriptors
.
begin
(),
field_descriptors
.
end
(),
...
@@ -176,19 +183,57 @@ protected:
...
@@ -176,19 +183,57 @@ protected:
list
names
,
formats
,
offsets
;
list
names
,
formats
,
offsets
;
for
(
auto
&
descr
:
field_descriptors
)
{
for
(
auto
&
descr
:
field_descriptors
)
{
names
.
append
(
descr
.
name
);
names
.
append
(
descr
.
name
);
formats
.
append
(
descr
.
format
);
offsets
.
append
(
descr
.
offset
);
formats
.
append
(
descr
.
format
);
offsets
.
append
(
descr
.
offset
);
}
}
auto
args
=
dict
();
auto
args
=
dict
();
args
[
"names"
]
=
names
;
args
[
"formats"
]
=
formats
;
args
[
"offsets"
]
=
offsets
;
args
[
"names"
]
=
names
;
args
[
"formats"
]
=
formats
;
args
[
"offsets"
]
=
offsets
;
args
[
"itemsize"
]
=
dtype
.
attr
(
"itemsize"
).
cast
<
int_
>
();
args
[
"itemsize"
]
=
(
int_
)
itemsize
();
return
dtype
::
from_args
(
args
);
}
};
PyObject
*
descr
=
nullptr
;
class
array
:
public
buffer
{
if
(
!
detail
::
npy_api
::
get
().
PyArray_DescrConverter_
(
args
.
release
().
ptr
(),
&
descr
)
||
!
descr
)
public
:
pybind11_fail
(
"NumPy: failed to create structured dtype"
);
PYBIND11_OBJECT_DEFAULT
(
array
,
buffer
,
detail
::
npy_api
::
get
().
PyArray_Check_
)
return
object
(
descr
,
false
);
enum
{
c_style
=
detail
::
npy_api
::
NPY_C_CONTIGUOUS_
,
f_style
=
detail
::
npy_api
::
NPY_F_CONTIGUOUS_
,
forcecast
=
detail
::
npy_api
::
NPY_ARRAY_FORCECAST_
};
template
<
typename
Type
>
array
(
size_t
size
,
const
Type
*
ptr
)
{
auto
&
api
=
detail
::
npy_api
::
get
();
auto
descr
=
pybind11
::
dtype
::
of
<
Type
>
().
release
().
ptr
();
Py_intptr_t
shape
=
(
Py_intptr_t
)
size
;
object
tmp
=
object
(
api
.
PyArray_NewFromDescr_
(
api
.
PyArray_Type_
,
descr
,
1
,
&
shape
,
nullptr
,
(
void
*
)
ptr
,
0
,
nullptr
),
false
);
if
(
!
tmp
)
pybind11_fail
(
"NumPy: unable to create array!"
);
if
(
ptr
)
tmp
=
object
(
api
.
PyArray_NewCopy_
(
tmp
.
ptr
(),
-
1
/* any order */
),
false
);
m_ptr
=
tmp
.
release
().
ptr
();
}
array
(
const
buffer_info
&
info
)
{
auto
&
api
=
detail
::
npy_api
::
get
();
auto
descr
=
pybind11
::
dtype
(
info
).
release
().
ptr
();
object
tmp
(
api
.
PyArray_NewFromDescr_
(
api
.
PyArray_Type_
,
descr
,
(
int
)
info
.
ndim
,
(
Py_intptr_t
*
)
&
info
.
shape
[
0
],
(
Py_intptr_t
*
)
&
info
.
strides
[
0
],
info
.
ptr
,
0
,
nullptr
),
false
);
if
(
!
tmp
)
pybind11_fail
(
"NumPy: unable to create array!"
);
if
(
info
.
ptr
)
tmp
=
object
(
api
.
PyArray_NewCopy_
(
tmp
.
ptr
(),
-
1
/* any order */
),
false
);
m_ptr
=
tmp
.
release
().
ptr
();
}
}
pybind11
::
dtype
dtype
()
{
return
attr
(
"dtype"
).
cast
<
pybind11
::
dtype
>
();
}
protected
:
template
<
typename
T
,
typename
SFINAE
>
friend
struct
detail
::
npy_format_descriptor
;
};
};
template
<
typename
T
,
int
ExtraFlags
=
array
::
forcecast
>
class
array_t
:
public
array
{
template
<
typename
T
,
int
ExtraFlags
=
array
::
forcecast
>
class
array_t
:
public
array
{
...
@@ -201,8 +246,7 @@ public:
...
@@ -201,8 +246,7 @@ public:
if
(
ptr
==
nullptr
)
if
(
ptr
==
nullptr
)
return
nullptr
;
return
nullptr
;
auto
&
api
=
detail
::
npy_api
::
get
();
auto
&
api
=
detail
::
npy_api
::
get
();
PyObject
*
descr
=
detail
::
npy_format_descriptor
<
T
>::
dtype
().
release
().
ptr
();
PyObject
*
result
=
api
.
PyArray_FromAny_
(
ptr
,
pybind11
::
dtype
::
of
<
T
>
().
release
().
ptr
(),
0
,
0
,
PyObject
*
result
=
api
.
PyArray_FromAny_
(
ptr
,
descr
,
0
,
0
,
detail
::
npy_api
::
NPY_ENSURE_ARRAY_
|
ExtraFlags
,
nullptr
);
detail
::
npy_api
::
NPY_ENSURE_ARRAY_
|
ExtraFlags
,
nullptr
);
if
(
!
result
)
if
(
!
result
)
PyErr_Clear
();
PyErr_Clear
();
...
@@ -223,11 +267,6 @@ template <size_t N> struct format_descriptor<std::array<char, N>> {
...
@@ -223,11 +267,6 @@ template <size_t N> struct format_descriptor<std::array<char, N>> {
static
const
char
*
format
()
{
PYBIND11_DESCR
s
=
detail
::
_
<
N
>
()
+
detail
::
_
(
"s"
);
return
s
.
text
();
}
static
const
char
*
format
()
{
PYBIND11_DESCR
s
=
detail
::
_
<
N
>
()
+
detail
::
_
(
"s"
);
return
s
.
text
();
}
};
};
template
<
typename
T
>
object
dtype_of
()
{
return
detail
::
npy_format_descriptor
<
T
>::
dtype
();
}
NAMESPACE_BEGIN
(
detail
)
NAMESPACE_BEGIN
(
detail
)
template
<
typename
T
>
struct
is_std_array
:
std
::
false_type
{
};
template
<
typename
T
>
struct
is_std_array
:
std
::
false_type
{
};
template
<
typename
T
,
size_t
N
>
struct
is_std_array
<
std
::
array
<
T
,
N
>>
:
std
::
true_type
{
};
template
<
typename
T
,
size_t
N
>
struct
is_std_array
<
std
::
array
<
T
,
N
>>
:
std
::
true_type
{
};
...
@@ -252,7 +291,7 @@ private:
...
@@ -252,7 +291,7 @@ private:
npy_api
::
NPY_INT_
,
npy_api
::
NPY_UINT_
,
npy_api
::
NPY_LONGLONG_
,
npy_api
::
NPY_ULONGLONG_
};
npy_api
::
NPY_INT_
,
npy_api
::
NPY_UINT_
,
npy_api
::
NPY_LONGLONG_
,
npy_api
::
NPY_ULONGLONG_
};
public
:
public
:
enum
{
value
=
values
[
detail
::
log2
(
sizeof
(
T
))
*
2
+
(
std
::
is_unsigned
<
T
>::
value
?
1
:
0
)]
};
enum
{
value
=
values
[
detail
::
log2
(
sizeof
(
T
))
*
2
+
(
std
::
is_unsigned
<
T
>::
value
?
1
:
0
)]
};
static
object
dtype
()
{
static
pybind11
::
dtype
dtype
()
{
if
(
auto
ptr
=
npy_api
::
get
().
PyArray_DescrFromType_
(
value
))
if
(
auto
ptr
=
npy_api
::
get
().
PyArray_DescrFromType_
(
value
))
return
object
(
ptr
,
true
);
return
object
(
ptr
,
true
);
pybind11_fail
(
"Unsupported buffer format!"
);
pybind11_fail
(
"Unsupported buffer format!"
);
...
@@ -267,7 +306,7 @@ template <typename T> constexpr const int npy_format_descriptor<
...
@@ -267,7 +306,7 @@ template <typename T> constexpr const int npy_format_descriptor<
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
enum { value = npy_api::NumPyName }; \
enum { value = npy_api::NumPyName }; \
static
object
dtype() { \
static
pybind11::dtype
dtype() { \
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \
return object(ptr, true); \
return object(ptr, true); \
pybind11_fail("Unsupported buffer format!"); \
pybind11_fail("Unsupported buffer format!"); \
...
@@ -282,14 +321,9 @@ DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
...
@@ -282,14 +321,9 @@ DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
#define DECL_CHAR_FMT \
#define DECL_CHAR_FMT \
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
static object dtype() { \
static pybind11::dtype dtype() { \
auto& api = npy_api::get(); \
PyObject *descr = nullptr; \
PYBIND11_DESCR fmt = _("S") + _<N>(); \
PYBIND11_DESCR fmt = _("S") + _<N>(); \
pybind11::str py_fmt(fmt.text()); \
return pybind11::dtype::from_args(pybind11::str(fmt.text())); \
if (!api.PyArray_DescrConverter_(py_fmt.release().ptr(), &descr) || !descr) \
pybind11_fail("NumPy: failed to create string dtype"); \
return object(descr, false); \
} \
} \
static const char *format() { PYBIND11_DESCR s = _<N>() + _("s"); return s.text(); }
static const char *format() { PYBIND11_DESCR s = _<N>() + _("s"); return s.text(); }
template
<
size_t
N
>
struct
npy_format_descriptor
<
char
[
N
]
>
{
DECL_CHAR_FMT
};
template
<
size_t
N
>
struct
npy_format_descriptor
<
char
[
N
]
>
{
DECL_CHAR_FMT
};
...
@@ -301,14 +335,14 @@ struct field_descriptor {
...
@@ -301,14 +335,14 @@ struct field_descriptor {
size_t
offset
;
size_t
offset
;
size_t
size
;
size_t
size
;
const
char
*
format
;
const
char
*
format
;
object
descr
;
dtype
descr
;
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
npy_format_descriptor
<
T
,
typename
std
::
enable_if
<
is_pod_struct
<
T
>::
value
>::
type
>
{
struct
npy_format_descriptor
<
T
,
typename
std
::
enable_if
<
is_pod_struct
<
T
>::
value
>::
type
>
{
static
PYBIND11_DESCR
name
()
{
return
_
(
"struct"
);
}
static
PYBIND11_DESCR
name
()
{
return
_
(
"struct"
);
}
static
object
dtype
()
{
static
pybind11
::
dtype
dtype
()
{
if
(
!
dtype_
())
if
(
!
dtype_
())
pybind11_fail
(
"NumPy: unsupported buffer format!"
);
pybind11_fail
(
"NumPy: unsupported buffer format!"
);
return
object
(
dtype_
(),
true
);
return
object
(
dtype_
(),
true
);
...
@@ -321,7 +355,6 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
...
@@ -321,7 +355,6 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
}
}
static
void
register_dtype
(
std
::
initializer_list
<
field_descriptor
>
fields
)
{
static
void
register_dtype
(
std
::
initializer_list
<
field_descriptor
>
fields
)
{
auto
&
api
=
npy_api
::
get
();
auto
args
=
dict
();
auto
args
=
dict
();
list
names
{
},
offsets
{
},
formats
{
};
list
names
{
},
offsets
{
},
formats
{
};
for
(
auto
field
:
fields
)
{
for
(
auto
field
:
fields
)
{
...
@@ -333,10 +366,7 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
...
@@ -333,10 +366,7 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
}
}
args
[
"names"
]
=
names
;
args
[
"offsets"
]
=
offsets
;
args
[
"formats"
]
=
formats
;
args
[
"names"
]
=
names
;
args
[
"offsets"
]
=
offsets
;
args
[
"formats"
]
=
formats
;
args
[
"itemsize"
]
=
int_
(
sizeof
(
T
));
args
[
"itemsize"
]
=
int_
(
sizeof
(
T
));
// This is essentially the same as calling np.dtype() constructor in Python and passing
dtype_
()
=
pybind11
::
dtype
::
from_args
(
args
).
release
().
ptr
();
// it a dict of the form {'names': ..., 'formats': ..., 'offsets': ...}.
if
(
!
api
.
PyArray_DescrConverter_
(
args
.
release
().
ptr
(),
&
dtype_
())
||
!
dtype_
())
pybind11_fail
(
"NumPy: failed to create structured dtype"
);
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
// not encoded explicitly into the format string. This will supposedly
// not encoded explicitly into the format string. This will supposedly
...
@@ -366,9 +396,9 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
...
@@ -366,9 +396,9 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
format_
()
=
oss
.
str
();
format_
()
=
oss
.
str
();
// Sanity check: verify that NumPy properly parses our buffer format string
// Sanity check: verify that NumPy properly parses our buffer format string
auto
&
api
=
npy_api
::
get
();
auto
arr
=
array
(
buffer_info
(
nullptr
,
sizeof
(
T
),
format
(),
1
,
{
0
},
{
sizeof
(
T
)
}));
auto
arr
=
array
(
buffer_info
(
nullptr
,
sizeof
(
T
),
format
(),
1
,
{
0
},
{
sizeof
(
T
)
}));
auto
fixed_dtype
=
array
::
strip_padding_fields
(
object
(
dtype_
(),
true
));
if
(
!
api
.
PyArray_EquivTypes_
(
dtype_
(),
arr
.
dtype
().
ptr
()))
if
(
!
api
.
PyArray_EquivTypes_
(
dtype_
(),
fixed_dtype
.
ptr
()))
pybind11_fail
(
"NumPy: invalid buffer descriptor!"
);
pybind11_fail
(
"NumPy: invalid buffer descriptor!"
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment