Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
P
pybind11
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
open
pybind11
Commits
00488a3e
Commit
00488a3e
authored
Oct 13, 2016
by
Wenzel Jakob
Committed by
GitHub
Oct 13, 2016
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #440 from wjakob/master
Permit creation of NumPy arrays with a "base" object that owns the data
parents
43f6aa68
fac7c094
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
185 additions
and
28 deletions
+185
-28
include/pybind11/common.h
+1
-0
include/pybind11/numpy.h
+62
-23
include/pybind11/pybind11.h
+17
-5
tests/test_numpy_array.cpp
+25
-0
tests/test_numpy_array.py
+80
-0
No files found.
include/pybind11/common.h
View file @
00488a3e
...
@@ -455,6 +455,7 @@ PYBIND11_RUNTIME_EXCEPTION(stop_iteration, PyExc_StopIteration)
...
@@ -455,6 +455,7 @@ PYBIND11_RUNTIME_EXCEPTION(stop_iteration, PyExc_StopIteration)
PYBIND11_RUNTIME_EXCEPTION
(
index_error
,
PyExc_IndexError
)
PYBIND11_RUNTIME_EXCEPTION
(
index_error
,
PyExc_IndexError
)
PYBIND11_RUNTIME_EXCEPTION
(
key_error
,
PyExc_KeyError
)
PYBIND11_RUNTIME_EXCEPTION
(
key_error
,
PyExc_KeyError
)
PYBIND11_RUNTIME_EXCEPTION
(
value_error
,
PyExc_ValueError
)
PYBIND11_RUNTIME_EXCEPTION
(
value_error
,
PyExc_ValueError
)
PYBIND11_RUNTIME_EXCEPTION
(
import_error
,
PyExc_ImportError
)
PYBIND11_RUNTIME_EXCEPTION
(
type_error
,
PyExc_TypeError
)
PYBIND11_RUNTIME_EXCEPTION
(
type_error
,
PyExc_TypeError
)
PYBIND11_RUNTIME_EXCEPTION
(
cast_error
,
PyExc_RuntimeError
)
/// Thrown when pybind11::cast or handle::call fail due to a type casting error
PYBIND11_RUNTIME_EXCEPTION
(
cast_error
,
PyExc_RuntimeError
)
/// Thrown when pybind11::cast or handle::call fail due to a type casting error
PYBIND11_RUNTIME_EXCEPTION
(
reference_cast_error
,
PyExc_RuntimeError
)
/// Used internally
PYBIND11_RUNTIME_EXCEPTION
(
reference_cast_error
,
PyExc_RuntimeError
)
/// Used internally
...
...
include/pybind11/numpy.h
View file @
00488a3e
...
@@ -22,8 +22,8 @@
...
@@ -22,8 +22,8 @@
#include <functional>
#include <functional>
#if defined(_MSC_VER)
#if defined(_MSC_VER)
#pragma warning(push)
#
pragma warning(push)
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#
pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif
#endif
/* This will be true on all flat address space platforms and allows us to reduce the
/* This will be true on all flat address space platforms and allows us to reduce the
...
@@ -156,8 +156,10 @@ NAMESPACE_END(detail)
...
@@ -156,8 +156,10 @@ NAMESPACE_END(detail)
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
#define PyArrayDescr_GET_(ptr, attr) \
#define PyArrayDescr_GET_(ptr, attr) \
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
#define PyArray_FLAGS_(ptr) \
PyArray_GET_(ptr, flags)
#define PyArray_CHKFLAGS_(ptr, flag) \
#define PyArray_CHKFLAGS_(ptr, flag) \
(flag == (
reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags
& flag))
(flag == (
PyArray_FLAGS_(ptr)
& flag))
class
dtype
:
public
object
{
class
dtype
:
public
object
{
public
:
public
:
...
@@ -258,38 +260,62 @@ public:
...
@@ -258,38 +260,62 @@ public:
forcecast
=
detail
::
npy_api
::
NPY_ARRAY_FORCECAST_
forcecast
=
detail
::
npy_api
::
NPY_ARRAY_FORCECAST_
};
};
array
(
const
pybind11
::
dtype
&
dt
,
const
std
::
vector
<
size_t
>&
shape
,
array
(
const
pybind11
::
dtype
&
dt
,
const
std
::
vector
<
size_t
>
&
shape
,
const
std
::
vector
<
size_t
>&
strides
,
const
void
*
ptr
=
nullptr
)
{
const
std
::
vector
<
size_t
>
&
strides
,
const
void
*
ptr
=
nullptr
,
handle
base
=
handle
())
{
auto
&
api
=
detail
::
npy_api
::
get
();
auto
&
api
=
detail
::
npy_api
::
get
();
auto
ndim
=
shape
.
size
();
auto
ndim
=
shape
.
size
();
if
(
shape
.
size
()
!=
strides
.
size
())
if
(
shape
.
size
()
!=
strides
.
size
())
pybind11_fail
(
"NumPy: shape ndim doesn't match strides ndim"
);
pybind11_fail
(
"NumPy: shape ndim doesn't match strides ndim"
);
auto
descr
=
dt
;
auto
descr
=
dt
;
int
flags
=
0
;
if
(
base
&&
ptr
)
{
array
base_array
(
base
,
true
);
if
(
base_array
.
check
())
/* Copy flags from base (except baseship bit) */
flags
=
base_array
.
flags
()
&
~
detail
::
npy_api
::
NPY_ARRAY_OWNDATA_
;
else
/* Writable by default, easy to downgrade later on if needed */
flags
=
detail
::
npy_api
::
NPY_ARRAY_WRITEABLE_
;
}
object
tmp
(
api
.
PyArray_NewFromDescr_
(
object
tmp
(
api
.
PyArray_NewFromDescr_
(
api
.
PyArray_Type_
,
descr
.
release
().
ptr
(),
(
int
)
ndim
,
(
Py_intptr_t
*
)
shape
.
data
(),
api
.
PyArray_Type_
,
descr
.
release
().
ptr
(),
(
int
)
ndim
,
(
Py_intptr_t
*
)
shape
.
data
(),
(
Py_intptr_t
*
)
strides
.
data
(),
const_cast
<
void
*>
(
ptr
),
0
,
nullptr
),
false
);
(
Py_intptr_t
*
)
strides
.
data
(),
const_cast
<
void
*>
(
ptr
),
flags
,
nullptr
),
false
);
if
(
!
tmp
)
if
(
!
tmp
)
pybind11_fail
(
"NumPy: unable to create array!"
);
pybind11_fail
(
"NumPy: unable to create array!"
);
if
(
ptr
)
if
(
ptr
)
{
if
(
base
)
{
PyArray_GET_
(
tmp
.
ptr
(),
base
)
=
base
.
inc_ref
().
ptr
();
}
else
{
tmp
=
object
(
api
.
PyArray_NewCopy_
(
tmp
.
ptr
(),
-
1
/* any order */
),
false
);
tmp
=
object
(
api
.
PyArray_NewCopy_
(
tmp
.
ptr
(),
-
1
/* any order */
),
false
);
}
}
m_ptr
=
tmp
.
release
().
ptr
();
m_ptr
=
tmp
.
release
().
ptr
();
}
}
array
(
const
pybind11
::
dtype
&
dt
,
const
std
::
vector
<
size_t
>&
shape
,
const
void
*
ptr
=
nullptr
)
array
(
const
pybind11
::
dtype
&
dt
,
const
std
::
vector
<
size_t
>
&
shape
,
:
array
(
dt
,
shape
,
default_strides
(
shape
,
dt
.
itemsize
()),
ptr
)
{
}
const
void
*
ptr
=
nullptr
,
handle
base
=
handle
())
:
array
(
dt
,
shape
,
default_strides
(
shape
,
dt
.
itemsize
()),
ptr
,
base
)
{
}
array
(
const
pybind11
::
dtype
&
dt
,
size_t
count
,
const
void
*
ptr
=
nullptr
)
array
(
const
pybind11
::
dtype
&
dt
,
size_t
count
,
const
void
*
ptr
=
nullptr
,
:
array
(
dt
,
std
::
vector
<
size_t
>
{
count
},
ptr
)
{
}
handle
base
=
handle
())
:
array
(
dt
,
std
::
vector
<
size_t
>
{
count
},
ptr
,
base
)
{
}
template
<
typename
T
>
array
(
const
std
::
vector
<
size_t
>&
shape
,
template
<
typename
T
>
array
(
const
std
::
vector
<
size_t
>&
shape
,
const
std
::
vector
<
size_t
>&
strides
,
const
T
*
ptr
)
const
std
::
vector
<
size_t
>&
strides
,
:
array
(
pybind11
::
dtype
::
of
<
T
>
(),
shape
,
strides
,
(
void
*
)
ptr
)
{
}
const
T
*
ptr
,
handle
base
=
handle
())
:
array
(
pybind11
::
dtype
::
of
<
T
>
(),
shape
,
strides
,
(
void
*
)
ptr
,
base
)
{
}
template
<
typename
T
>
array
(
const
std
::
vector
<
size_t
>&
shape
,
const
T
*
ptr
)
template
<
typename
T
>
:
array
(
shape
,
default_strides
(
shape
,
sizeof
(
T
)),
ptr
)
{
}
array
(
const
std
::
vector
<
size_t
>
&
shape
,
const
T
*
ptr
,
handle
base
=
handle
())
:
array
(
shape
,
default_strides
(
shape
,
sizeof
(
T
)),
ptr
,
base
)
{
}
template
<
typename
T
>
array
(
size_t
count
,
const
T
*
ptr
)
template
<
typename
T
>
:
array
(
std
::
vector
<
size_t
>
{
count
},
ptr
)
{
}
array
(
size_t
count
,
const
T
*
ptr
,
handle
base
=
handle
())
:
array
(
std
::
vector
<
size_t
>
{
count
},
ptr
,
base
)
{
}
array
(
const
buffer_info
&
info
)
array
(
const
buffer_info
&
info
)
:
array
(
pybind11
::
dtype
(
info
),
info
.
shape
,
info
.
strides
,
info
.
ptr
)
{
}
:
array
(
pybind11
::
dtype
(
info
),
info
.
shape
,
info
.
strides
,
info
.
ptr
)
{
}
...
@@ -319,6 +345,11 @@ public:
...
@@ -319,6 +345,11 @@ public:
return
(
size_t
)
PyArray_GET_
(
m_ptr
,
nd
);
return
(
size_t
)
PyArray_GET_
(
m_ptr
,
nd
);
}
}
/// Base object
object
base
()
const
{
return
object
(
PyArray_GET_
(
m_ptr
,
base
),
true
);
}
/// Dimensions of the array
/// Dimensions of the array
const
size_t
*
shape
()
const
{
const
size_t
*
shape
()
const
{
return
reinterpret_cast
<
const
size_t
*>
(
PyArray_GET_
(
m_ptr
,
dimensions
));
return
reinterpret_cast
<
const
size_t
*>
(
PyArray_GET_
(
m_ptr
,
dimensions
));
...
@@ -343,6 +374,11 @@ public:
...
@@ -343,6 +374,11 @@ public:
return
strides
()[
dim
];
return
strides
()[
dim
];
}
}
/// Return the NumPy array flags
int
flags
()
const
{
return
PyArray_FLAGS_
(
m_ptr
);
}
/// If set, the array is writeable (otherwise the buffer is read-only)
/// If set, the array is writeable (otherwise the buffer is read-only)
bool
writeable
()
const
{
bool
writeable
()
const
{
return
PyArray_CHKFLAGS_
(
m_ptr
,
detail
::
npy_api
::
NPY_ARRAY_WRITEABLE_
);
return
PyArray_CHKFLAGS_
(
m_ptr
,
detail
::
npy_api
::
NPY_ARRAY_WRITEABLE_
);
...
@@ -436,14 +472,17 @@ public:
...
@@ -436,14 +472,17 @@ public:
array_t
(
const
buffer_info
&
info
)
:
array
(
info
)
{
}
array_t
(
const
buffer_info
&
info
)
:
array
(
info
)
{
}
array_t
(
const
std
::
vector
<
size_t
>&
shape
,
const
std
::
vector
<
size_t
>&
strides
,
const
T
*
ptr
=
nullptr
)
array_t
(
const
std
::
vector
<
size_t
>
&
shape
,
:
array
(
shape
,
strides
,
ptr
)
{
}
const
std
::
vector
<
size_t
>
&
strides
,
const
T
*
ptr
=
nullptr
,
handle
base
=
handle
())
:
array
(
shape
,
strides
,
ptr
,
base
)
{
}
array_t
(
const
std
::
vector
<
size_t
>&
shape
,
const
T
*
ptr
=
nullptr
)
array_t
(
const
std
::
vector
<
size_t
>
&
shape
,
const
T
*
ptr
=
nullptr
,
:
array
(
shape
,
ptr
)
{
}
handle
base
=
handle
())
:
array
(
shape
,
ptr
,
base
)
{
}
array_t
(
size_t
count
,
const
T
*
ptr
=
nullptr
)
array_t
(
size_t
count
,
const
T
*
ptr
=
nullptr
,
handle
base
=
handle
()
)
:
array
(
count
,
ptr
)
{
}
:
array
(
count
,
ptr
,
base
)
{
}
constexpr
size_t
itemsize
()
const
{
constexpr
size_t
itemsize
()
const
{
return
sizeof
(
T
);
return
sizeof
(
T
);
...
...
include/pybind11/pybind11.h
View file @
00488a3e
...
@@ -567,7 +567,7 @@ public:
...
@@ -567,7 +567,7 @@ public:
static
module
import
(
const
char
*
name
)
{
static
module
import
(
const
char
*
name
)
{
PyObject
*
obj
=
PyImport_ImportModule
(
name
);
PyObject
*
obj
=
PyImport_ImportModule
(
name
);
if
(
!
obj
)
if
(
!
obj
)
pybind11_fail
(
"Module
\"
"
+
std
::
string
(
name
)
+
"
\"
not found!"
);
throw
import_error
(
"Module
\"
"
+
std
::
string
(
name
)
+
"
\"
not found!"
);
return
module
(
obj
,
false
);
return
module
(
obj
,
false
);
}
}
};
};
...
@@ -1344,15 +1344,27 @@ PYBIND11_NOINLINE inline void print(tuple args, dict kwargs) {
...
@@ -1344,15 +1344,27 @@ PYBIND11_NOINLINE inline void print(tuple args, dict kwargs) {
auto
sep
=
kwargs
.
contains
(
"sep"
)
?
kwargs
[
"sep"
]
:
cast
(
" "
);
auto
sep
=
kwargs
.
contains
(
"sep"
)
?
kwargs
[
"sep"
]
:
cast
(
" "
);
auto
line
=
sep
.
attr
(
"join"
)(
strings
);
auto
line
=
sep
.
attr
(
"join"
)(
strings
);
auto
file
=
kwargs
.
contains
(
"file"
)
?
kwargs
[
"file"
].
cast
<
object
>
()
object
file
;
:
module
::
import
(
"sys"
).
attr
(
"stdout"
);
if
(
kwargs
.
contains
(
"file"
))
{
file
=
kwargs
[
"file"
].
cast
<
object
>
();
}
else
{
try
{
file
=
module
::
import
(
"sys"
).
attr
(
"stdout"
);
}
catch
(
const
import_error
&
)
{
/* If print() is called from code that is executed as
part of garbage collection during interpreter shutdown,
importing 'sys' can fail. Give up rather than crashing the
interpreter in this case. */
return
;
}
}
auto
write
=
file
.
attr
(
"write"
);
auto
write
=
file
.
attr
(
"write"
);
write
(
line
);
write
(
line
);
write
(
kwargs
.
contains
(
"end"
)
?
kwargs
[
"end"
]
:
cast
(
"
\n
"
));
write
(
kwargs
.
contains
(
"end"
)
?
kwargs
[
"end"
]
:
cast
(
"
\n
"
));
if
(
kwargs
.
contains
(
"flush"
)
&&
kwargs
[
"flush"
].
cast
<
bool
>
())
{
if
(
kwargs
.
contains
(
"flush"
)
&&
kwargs
[
"flush"
].
cast
<
bool
>
())
file
.
attr
(
"flush"
)();
file
.
attr
(
"flush"
)();
}
}
}
NAMESPACE_END
(
detail
)
NAMESPACE_END
(
detail
)
...
...
tests/test_numpy_array.cpp
View file @
00488a3e
...
@@ -99,4 +99,29 @@ test_initializer numpy_array([](py::module &m) {
...
@@ -99,4 +99,29 @@ test_initializer numpy_array([](py::module &m) {
sm
.
def
(
"make_c_array"
,
[]
{
sm
.
def
(
"make_c_array"
,
[]
{
return
py
::
array_t
<
float
>
({
2
,
2
},
{
8
,
4
});
return
py
::
array_t
<
float
>
({
2
,
2
},
{
8
,
4
});
});
});
sm
.
def
(
"wrap"
,
[](
py
::
array
a
)
{
return
py
::
array
(
a
.
dtype
(),
std
::
vector
<
size_t
>
(
a
.
shape
(),
a
.
shape
()
+
a
.
ndim
()),
std
::
vector
<
size_t
>
(
a
.
strides
(),
a
.
strides
()
+
a
.
ndim
()),
a
.
data
(),
a
);
});
struct
ArrayClass
{
int
data
[
2
]
=
{
1
,
2
};
ArrayClass
()
{
py
::
print
(
"ArrayClass()"
);
}
~
ArrayClass
()
{
py
::
print
(
"~ArrayClass()"
);
}
};
py
::
class_
<
ArrayClass
>
(
sm
,
"ArrayClass"
)
.
def
(
py
::
init
<>
())
.
def
(
"numpy_view"
,
[](
py
::
object
&
obj
)
{
py
::
print
(
"ArrayClass::numpy_view()"
);
ArrayClass
&
a
=
obj
.
cast
<
ArrayClass
&>
();
return
py
::
array_t
<
int
>
({
2
},
{
4
},
a
.
data
,
obj
);
}
);
});
});
tests/test_numpy_array.py
View file @
00488a3e
import
pytest
import
pytest
import
gc
with
pytest
.
suppress
(
ImportError
):
with
pytest
.
suppress
(
ImportError
):
import
numpy
as
np
import
numpy
as
np
...
@@ -149,6 +150,7 @@ def test_bounds_check(arr):
...
@@ -149,6 +150,7 @@ def test_bounds_check(arr):
index_at
(
arr
,
0
,
4
)
index_at
(
arr
,
0
,
4
)
assert
str
(
excinfo
.
value
)
==
'index 4 is out of bounds for axis 1 with size 3'
assert
str
(
excinfo
.
value
)
==
'index 4 is out of bounds for axis 1 with size 3'
@pytest.requires_numpy
@pytest.requires_numpy
def
test_make_c_f_array
():
def
test_make_c_f_array
():
from
pybind11_tests.array
import
(
from
pybind11_tests.array
import
(
...
@@ -158,3 +160,81 @@ def test_make_c_f_array():
...
@@ -158,3 +160,81 @@ def test_make_c_f_array():
assert
not
make_c_array
()
.
flags
.
f_contiguous
assert
not
make_c_array
()
.
flags
.
f_contiguous
assert
make_f_array
()
.
flags
.
f_contiguous
assert
make_f_array
()
.
flags
.
f_contiguous
assert
not
make_f_array
()
.
flags
.
c_contiguous
assert
not
make_f_array
()
.
flags
.
c_contiguous
@pytest.requires_numpy
def
test_wrap
():
from
pybind11_tests.array
import
wrap
def
assert_references
(
A
,
B
):
assert
A
is
not
B
assert
A
.
__array_interface__
[
'data'
][
0
]
==
\
B
.
__array_interface__
[
'data'
][
0
]
assert
A
.
shape
==
B
.
shape
assert
A
.
strides
==
B
.
strides
assert
A
.
flags
.
c_contiguous
==
B
.
flags
.
c_contiguous
assert
A
.
flags
.
f_contiguous
==
B
.
flags
.
f_contiguous
assert
A
.
flags
.
writeable
==
B
.
flags
.
writeable
assert
A
.
flags
.
aligned
==
B
.
flags
.
aligned
assert
A
.
flags
.
updateifcopy
==
B
.
flags
.
updateifcopy
assert
np
.
all
(
A
==
B
)
assert
not
B
.
flags
.
owndata
assert
B
.
base
is
A
if
A
.
flags
.
writeable
and
A
.
ndim
==
2
:
A
[
0
,
0
]
=
1234
assert
B
[
0
,
0
]
==
1234
A1
=
np
.
array
([
1
,
2
],
dtype
=
np
.
int16
)
assert
A1
.
flags
.
owndata
and
A1
.
base
is
None
A2
=
wrap
(
A1
)
assert_references
(
A1
,
A2
)
A1
=
np
.
array
([[
1
,
2
],
[
3
,
4
]],
dtype
=
np
.
float32
,
order
=
'F'
)
assert
A1
.
flags
.
owndata
and
A1
.
base
is
None
A2
=
wrap
(
A1
)
assert_references
(
A1
,
A2
)
A1
=
np
.
array
([[
1
,
2
],
[
3
,
4
]],
dtype
=
np
.
float32
,
order
=
'C'
)
A1
.
flags
.
writeable
=
False
A2
=
wrap
(
A1
)
assert_references
(
A1
,
A2
)
A1
=
np
.
random
.
random
((
4
,
4
,
4
))
A2
=
wrap
(
A1
)
assert_references
(
A1
,
A2
)
A1
=
A1
.
transpose
()
A2
=
wrap
(
A1
)
assert_references
(
A1
,
A2
)
A1
=
A1
.
diagonal
()
A2
=
wrap
(
A1
)
assert_references
(
A1
,
A2
)
@pytest.requires_numpy
def
test_numpy_view
(
capture
):
from
pybind11_tests.array
import
ArrayClass
with
capture
:
ac
=
ArrayClass
()
ac_view_1
=
ac
.
numpy_view
()
ac_view_2
=
ac
.
numpy_view
()
assert
np
.
all
(
ac_view_1
==
np
.
array
([
1
,
2
],
dtype
=
np
.
int32
))
del
ac
gc
.
collect
()
assert
capture
==
"""
ArrayClass()
ArrayClass::numpy_view()
ArrayClass::numpy_view()
"""
ac_view_1
[
0
]
=
4
ac_view_1
[
1
]
=
3
assert
ac_view_2
[
0
]
==
4
assert
ac_view_2
[
1
]
==
3
with
capture
:
del
ac_view_1
del
ac_view_2
gc
.
collect
()
assert
capture
==
"""
~ArrayClass()
"""
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment