Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
P
pybind11
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
open
pybind11
Commits
43398a85
Commit
43398a85
authored
Jul 28, 2015
by
Wenzel Jakob
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
complex number support
parent
d4258baf
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
88 additions
and
28 deletions
+88
-28
example/example10.cpp
+6
-1
example/example10.py
+3
-0
include/pybind/cast.h
+17
-0
include/pybind/common.h
+6
-4
include/pybind/numpy.h
+39
-14
include/pybind/pybind.h
+17
-9
No files found.
example/example10.cpp
View file @
43398a85
...
...
@@ -15,14 +15,19 @@ double my_func(int x, float y, double z) {
return
x
*
y
*
z
;
}
std
::
complex
<
double
>
my_func3
(
std
::
complex
<
double
>
c
)
{
return
c
*
std
::
complex
<
double
>
(
2.
f
);
}
void
init_ex10
(
py
::
module
&
m
)
{
// Vectorize all arguments (though non-vector arguments are also allowed)
m
.
def
(
"vectorized_func"
,
py
::
vectorize
(
my_func
));
// Vectorize a lambda function with a capture object (e.g. to exclude some arguments from the vectorization)
m
.
def
(
"vectorized_func2"
,
[](
py
::
array_dtype
<
int
>
x
,
py
::
array_dtype
<
float
>
y
,
float
z
)
{
return
py
::
vectorize
([
z
](
int
x
,
float
y
)
{
return
my_func
(
x
,
y
,
z
);
})(
x
,
y
);
}
);
// Vectorize all arguments (complex numbers)
m
.
def
(
"vectorized_func3"
,
py
::
vectorize
(
my_func3
));
}
example/example10.py
View file @
43398a85
...
...
@@ -7,6 +7,9 @@ import numpy as np
from
example
import
vectorized_func
from
example
import
vectorized_func2
from
example
import
vectorized_func3
print
(
vectorized_func3
(
np
.
array
(
3
+
7
j
)))
for
f
in
[
vectorized_func
,
vectorized_func2
]:
print
(
f
(
1
,
2
,
3
))
...
...
include/pybind/cast.h
View file @
43398a85
...
...
@@ -192,6 +192,23 @@ public:
PYBIND_TYPE_CASTER
(
bool
,
"bool"
);
};
template
<
typename
T
>
class
type_caster
<
std
::
complex
<
T
>>
{
public
:
bool
load
(
PyObject
*
src
,
bool
)
{
Py_complex
result
=
PyComplex_AsCComplex
(
src
);
if
(
result
.
real
==
-
1
.
0
&&
PyErr_Occurred
())
{
PyErr_Clear
();
return
false
;
}
value
=
std
::
complex
<
T
>
((
T
)
result
.
real
,
(
T
)
result
.
imag
);
return
true
;
}
static
PyObject
*
cast
(
const
std
::
complex
<
T
>
&
src
,
return_value_policy
/* policy */
,
PyObject
*
/* parent */
)
{
return
PyComplex_FromDoubles
((
double
)
src
.
real
(),
(
double
)
src
.
imag
());
}
PYBIND_TYPE_CASTER
(
std
::
complex
<
T
>
,
"complex"
);
};
template
<>
class
type_caster
<
std
::
string
>
{
public
:
bool
load
(
PyObject
*
src
,
bool
)
{
...
...
include/pybind/common.h
View file @
43398a85
...
...
@@ -33,7 +33,7 @@
#include <unordered_map>
#include <iostream>
#include <memory>
#include <
functional
>
#include <
complex
>
/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode
#if defined(_MSC_VER)
...
...
@@ -82,7 +82,8 @@ template <typename type> struct format_descriptor { };
#define DECL_FMT(t, n) template<> struct format_descriptor<t> { static std::string value() { return n; }; };
DECL_FMT
(
int8_t
,
"b"
);
DECL_FMT
(
uint8_t
,
"B"
);
DECL_FMT
(
int16_t
,
"h"
);
DECL_FMT
(
uint16_t
,
"H"
);
DECL_FMT
(
int32_t
,
"i"
);
DECL_FMT
(
uint32_t
,
"I"
);
DECL_FMT
(
int64_t
,
"q"
);
DECL_FMT
(
uint64_t
,
"Q"
);
DECL_FMT
(
float
,
"f"
);
DECL_FMT
(
double
,
"d"
);
DECL_FMT
(
float
,
"f"
);
DECL_FMT
(
double
,
"d"
);
DECL_FMT
(
bool
,
"?"
);
DECL_FMT
(
std
::
complex
<
float
>
,
"Zf"
);
DECL_FMT
(
std
::
complex
<
double
>
,
"Zd"
);
#undef DECL_FMT
/// Information record describing a Python buffer object
...
...
@@ -126,11 +127,12 @@ struct type_info {
PyTypeObject
*
type
;
size_t
type_size
;
void
(
*
init_holder
)(
PyObject
*
);
std
::
function
<
buffer_info
*
(
PyObject
*
)
>
get_buffer
;
std
::
vector
<
PyObject
*
(
*
)(
PyObject
*
,
PyTypeObject
*
)
>
implicit_conversions
;
buffer_info
*
(
*
get_buffer
)(
PyObject
*
,
void
*
)
=
nullptr
;
void
*
get_buffer_data
=
nullptr
;
};
/// Internal data struture used to track registered instances and types
/// Internal data struture used to track registered instances and types
struct
internals
{
std
::
unordered_map
<
std
::
string
,
type_info
>
registered_types
;
std
::
unordered_map
<
void
*
,
PyObject
*>
registered_instances
;
...
...
include/pybind/numpy.h
View file @
43398a85
...
...
@@ -17,8 +17,10 @@
NAMESPACE_BEGIN
(
pybind
)
template
<
typename
type
>
struct
npy_format_descriptor
{
};
class
array
:
public
buffer
{
p
rotected
:
p
ublic
:
struct
API
{
enum
Entries
{
API_PyArray_Type
=
2
,
...
...
@@ -26,10 +28,18 @@ protected:
API_PyArray_FromAny
=
69
,
API_PyArray_NewCopy
=
85
,
API_PyArray_NewFromDescr
=
94
,
API_NPY_C_CONTIGUOUS
=
0x0001
,
API_NPY_F_CONTIGUOUS
=
0x0002
,
API_NPY_NPY_ARRAY_FORCECAST
=
0x0010
,
API_NPY_ENSURE_ARRAY
=
0x0040
NPY_C_CONTIGUOUS
=
0x0001
,
NPY_F_CONTIGUOUS
=
0x0002
,
NPY_NPY_ARRAY_FORCECAST
=
0x0010
,
NPY_ENSURE_ARRAY
=
0x0040
,
NPY_BOOL
=
0
,
NPY_BYTE
,
NPY_UBYTE
,
NPY_SHORT
,
NPY_USHORT
,
NPY_INT
,
NPY_UINT
,
NPY_LONG
,
NPY_ULONG
,
NPY_LONGLONG
,
NPY_ULONGLONG
,
NPY_FLOAT
,
NPY_DOUBLE
,
NPY_LONGDOUBLE
,
NPY_CFLOAT
,
NPY_CDOUBLE
,
NPY_CLONGDOUBLE
};
static
API
lookup
()
{
...
...
@@ -59,13 +69,12 @@ protected:
PyTypeObject
*
PyArray_Type
;
PyObject
*
(
*
PyArray_FromAny
)
(
PyObject
*
,
PyObject
*
,
int
,
int
,
int
,
PyObject
*
);
};
public
:
PYBIND_OBJECT_DEFAULT
(
array
,
buffer
,
lookup_api
().
PyArray_Check
)
template
<
typename
Type
>
array
(
size_t
size
,
const
Type
*
ptr
)
{
API
&
api
=
lookup_api
();
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
(
int
)
format_descriptor
<
Type
>::
value
()[
0
]);
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
npy_format_descriptor
<
Type
>::
value
);
if
(
descr
==
nullptr
)
throw
std
::
runtime_error
(
"NumPy: unsupported buffer format!"
);
Py_intptr_t
shape
=
(
Py_intptr_t
)
size
;
...
...
@@ -83,7 +92,12 @@ public:
API
&
api
=
lookup_api
();
if
(
info
.
format
.
size
()
!=
1
)
throw
std
::
runtime_error
(
"Unsupported buffer format!"
);
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
info
.
format
[
0
]);
int
fmt
=
(
int
)
info
.
format
[
0
];
if
(
info
.
format
==
"Zd"
)
fmt
=
API
::
NPY_CDOUBLE
;
else
if
(
info
.
format
==
"Zf"
)
fmt
=
API
::
NPY_CFLOAT
;
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
fmt
);
if
(
descr
==
nullptr
)
throw
std
::
runtime_error
(
"NumPy: unsupported buffer format '"
+
info
.
format
+
"'!"
);
PyObject
*
tmp
=
api
.
PyArray_NewFromDescr
(
...
...
@@ -109,12 +123,12 @@ public:
PYBIND_OBJECT_CVT
(
array_dtype
,
array
,
is_non_null
,
m_ptr
=
ensure
(
m_ptr
));
array_dtype
()
:
array
()
{
}
static
bool
is_non_null
(
PyObject
*
ptr
)
{
return
ptr
!=
nullptr
;
}
static
PyObject
*
ensure
(
PyObject
*
ptr
)
{
PyObject
*
ensure
(
PyObject
*
ptr
)
{
API
&
api
=
lookup_api
();
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
format_descriptor
<
T
>::
value
()[
0
]
);
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
npy_format_descriptor
<
T
>::
value
);
return
api
.
PyArray_FromAny
(
ptr
,
descr
,
0
,
0
,
API
::
API_NPY_C_CONTIGUOUS
|
API
::
API_
NPY_ENSURE_ARRAY
|
API
::
API_
NPY_NPY_ARRAY_FORCECAST
,
nullptr
);
API
::
NPY_C_CONTIGUOUS
|
API
::
NPY_ENSURE_ARRAY
|
API
::
NPY_NPY_ARRAY_FORCECAST
,
nullptr
);
}
};
...
...
@@ -125,8 +139,19 @@ PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int16_t>) PYBIND_TYPE_CASTER_PYTYPE(array_
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
int32_t
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
uint32_t
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
int64_t
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
uint64_t
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
float
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
double
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
std
::
complex
<
float
>>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
std
::
complex
<
double
>>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
bool
>
)
NAMESPACE_END
(
detail
)
#define DECL_FMT(t, n) template<> struct npy_format_descriptor<t> { enum { value = array::API::n }; }
DECL_FMT
(
int8_t
,
NPY_BYTE
);
DECL_FMT
(
uint8_t
,
NPY_UBYTE
);
DECL_FMT
(
int16_t
,
NPY_SHORT
);
DECL_FMT
(
uint16_t
,
NPY_USHORT
);
DECL_FMT
(
int32_t
,
NPY_INT
);
DECL_FMT
(
uint32_t
,
NPY_UINT
);
DECL_FMT
(
int64_t
,
NPY_LONGLONG
);
DECL_FMT
(
uint64_t
,
NPY_ULONGLONG
);
DECL_FMT
(
float
,
NPY_FLOAT
);
DECL_FMT
(
double
,
NPY_DOUBLE
);
DECL_FMT
(
bool
,
NPY_BOOL
);
DECL_FMT
(
std
::
complex
<
float
>
,
NPY_CFLOAT
);
DECL_FMT
(
std
::
complex
<
double
>
,
NPY_CDOUBLE
);
#undef DECL_FMT
template
<
typename
func_type
,
typename
return_type
,
typename
...
args_type
,
size_t
...
Index
>
std
::
function
<
object
(
array_dtype
<
args_type
>
...)
>
vectorize
(
func_type
&&
f
,
return_type
(
*
)
(
args_type
...),
...
...
@@ -171,7 +196,7 @@ template <typename func_type, typename return_type, typename... args_type, size_
return
cast
(
result
[
0
]);
/* Return the result */
return
array
(
buffer_info
(
result
.
data
(),
sizeof
(
return_type
),
return
array
(
buffer_info
(
result
.
data
(),
sizeof
(
return_type
),
format_descriptor
<
return_type
>::
value
(),
ndim
,
shape
,
strides
));
};
...
...
include/pybind/pybind.h
View file @
43398a85
...
...
@@ -393,22 +393,27 @@ protected:
Py_TYPE
(
self
)
->
tp_free
((
PyObject
*
)
self
);
}
void
install_buffer_funcs
(
const
std
::
function
<
buffer_info
*
(
PyObject
*
)
>
&
func
)
{
void
install_buffer_funcs
(
buffer_info
*
(
*
get_buffer
)(
PyObject
*
,
void
*
),
void
*
get_buffer_data
)
{
PyHeapTypeObject
*
type
=
(
PyHeapTypeObject
*
)
m_ptr
;
type
->
ht_type
.
tp_as_buffer
=
&
type
->
as_buffer
;
type
->
as_buffer
.
bf_getbuffer
=
getbuffer
;
type
->
as_buffer
.
bf_releasebuffer
=
releasebuffer
;
((
detail
::
type_info
*
)
capsule
(
attr
(
"__pybind__"
)))
->
get_buffer
=
func
;
auto
info
=
((
detail
::
type_info
*
)
capsule
(
attr
(
"__pybind__"
)));
info
->
get_buffer
=
get_buffer
;
info
->
get_buffer_data
=
get_buffer_data
;
}
static
int
getbuffer
(
PyObject
*
obj
,
Py_buffer
*
view
,
int
flags
)
{
auto
const
&
info_func
=
((
detail
::
type_info
*
)
capsule
(
handle
(
obj
).
attr
(
"__pybind__"
)))
->
get_buffer
;
if
(
view
==
nullptr
||
obj
==
nullptr
||
!
info_func
)
{
auto
const
&
typeinfo
=
((
detail
::
type_info
*
)
capsule
(
handle
(
obj
).
attr
(
"__pybind__"
)));
if
(
view
==
nullptr
||
obj
==
nullptr
||
!
typeinfo
||
!
typeinfo
->
get_buffer
)
{
PyErr_SetString
(
PyExc_BufferError
,
"Internal error"
);
return
-
1
;
}
memset
(
view
,
0
,
sizeof
(
Py_buffer
));
buffer_info
*
info
=
info_func
(
obj
);
buffer_info
*
info
=
typeinfo
->
get_buffer
(
obj
,
typeinfo
->
get_buffer_data
);
view
->
obj
=
obj
;
view
->
ndim
=
1
;
view
->
internal
=
info
;
...
...
@@ -483,13 +488,16 @@ public:
return
*
this
;
}
class_
&
def_buffer
(
const
std
::
function
<
buffer_info
(
type
&
)
>
&
func
)
{
install_buffer_funcs
([
func
](
PyObject
*
obj
)
->
buffer_info
*
{
template
<
typename
Func
>
class_
&
def_buffer
(
Func
&&
func
)
{
struct
capture
{
Func
func
;
};
capture
*
ptr
=
new
capture
{
std
::
forward
<
Func
>
(
func
)
};
install_buffer_funcs
([](
PyObject
*
obj
,
void
*
ptr
)
->
buffer_info
*
{
detail
::
type_caster
<
type
>
caster
;
if
(
!
caster
.
load
(
obj
,
false
))
return
nullptr
;
return
new
buffer_info
(
func
(
caster
));
});
return
new
buffer_info
(
((
capture
*
)
ptr
)
->
func
(
caster
));
}
,
ptr
);
return
*
this
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment