Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
P
pybind11
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
open
pybind11
Commits
f2a0ad58
Commit
f2a0ad58
authored
Sep 08, 2016
by
Ivan Smirnov
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
array: add direct data access and indexing methods
parent
91b3d681
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
102 additions
and
33 deletions
+102
-33
include/pybind11/numpy.h
+102
-33
No files found.
include/pybind11/numpy.h
View file @
f2a0ad58
...
...
@@ -26,8 +26,14 @@
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif
/* This will be true on all flat address space platforms and allows us to reduce the
whole npy_intp / size_t / Py_intptr_t business down to just size_t for all size
and dimension types (e.g. shape, strides, indexing), instead of inflicting this
upon the library user. */
static_assert
(
sizeof
(
size_t
)
==
sizeof
(
Py_intptr_t
),
"size_t != Py_intptr_t"
);
NAMESPACE_BEGIN
(
pybind11
)
namespace
detail
{
NAMESPACE_BEGIN
(
detail
)
template
<
typename
type
,
typename
SFINAE
=
void
>
struct
npy_format_descriptor
{
};
template
<
typename
type
>
struct
is_pod_struct
;
...
...
@@ -141,10 +147,12 @@ private:
return
api
;
}
};
}
NAMESPACE_END
(
detail
)
#define PyArray_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
#define PyArrayDescr_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
#define PyArray_GET_(ptr, attr) \
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
#define PyArrayDescr_GET_(ptr, attr) \
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
#define PyArray_CHKFLAGS_(ptr, flag) \
(flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag))
...
...
@@ -250,7 +258,7 @@ public:
};
array
(
const
pybind11
::
dtype
&
dt
,
const
std
::
vector
<
size_t
>&
shape
,
const
std
::
vector
<
size_t
>&
strides
,
void
*
ptr
=
nullptr
)
{
const
std
::
vector
<
size_t
>&
strides
,
const
void
*
ptr
=
nullptr
)
{
auto
&
api
=
detail
::
npy_api
::
get
();
auto
ndim
=
shape
.
size
();
if
(
shape
.
size
()
!=
strides
.
size
())
...
...
@@ -258,7 +266,7 @@ public:
auto
descr
=
dt
;
object
tmp
(
api
.
PyArray_NewFromDescr_
(
api
.
PyArray_Type_
,
descr
.
release
().
ptr
(),
(
int
)
ndim
,
(
Py_intptr_t
*
)
shape
.
data
(),
(
Py_intptr_t
*
)
strides
.
data
(),
ptr
,
0
,
nullptr
),
false
);
(
Py_intptr_t
*
)
strides
.
data
(),
const_cast
<
void
*>
(
ptr
)
,
0
,
nullptr
),
false
);
if
(
!
tmp
)
pybind11_fail
(
"NumPy: unable to create array!"
);
if
(
ptr
)
...
...
@@ -266,20 +274,20 @@ public:
m_ptr
=
tmp
.
release
().
ptr
();
}
array
(
const
pybind11
::
dtype
&
dt
,
const
std
::
vector
<
size_t
>&
shape
,
void
*
ptr
=
nullptr
)
array
(
const
pybind11
::
dtype
&
dt
,
const
std
::
vector
<
size_t
>&
shape
,
const
void
*
ptr
=
nullptr
)
:
array
(
dt
,
shape
,
default_strides
(
shape
,
dt
.
itemsize
()),
ptr
)
{
}
array
(
const
pybind11
::
dtype
&
dt
,
size_t
count
,
void
*
ptr
=
nullptr
)
array
(
const
pybind11
::
dtype
&
dt
,
size_t
count
,
const
void
*
ptr
=
nullptr
)
:
array
(
dt
,
std
::
vector
<
size_t
>
{
count
},
ptr
)
{
}
template
<
typename
T
>
array
(
const
std
::
vector
<
size_t
>&
shape
,
const
std
::
vector
<
size_t
>&
strides
,
T
*
ptr
)
const
std
::
vector
<
size_t
>&
strides
,
const
T
*
ptr
)
:
array
(
pybind11
::
dtype
::
of
<
T
>
(),
shape
,
strides
,
(
void
*
)
ptr
)
{
}
template
<
typename
T
>
array
(
const
std
::
vector
<
size_t
>&
shape
,
T
*
ptr
)
template
<
typename
T
>
array
(
const
std
::
vector
<
size_t
>&
shape
,
const
T
*
ptr
)
:
array
(
shape
,
default_strides
(
shape
,
sizeof
(
T
)),
ptr
)
{
}
template
<
typename
T
>
array
(
size_t
count
,
T
*
ptr
)
template
<
typename
T
>
array
(
size_t
count
,
const
T
*
ptr
)
:
array
(
std
::
vector
<
size_t
>
{
count
},
ptr
)
{
}
array
(
const
buffer_info
&
info
)
...
...
@@ -312,27 +320,25 @@ public:
/// Dimensions of the array
const
size_t
*
shape
()
const
{
static_assert
(
sizeof
(
size_t
)
==
sizeof
(
Py_intptr_t
),
"size_t != Py_intptr_t"
);
return
reinterpret_cast
<
const
size_t
*>
(
PyArray_GET_
(
m_ptr
,
dimensions
));
}
/// Dimension along a given axis
size_t
shape
(
size_t
dim
)
const
{
if
(
dim
>=
ndim
())
pybind11_fail
(
"NumPy: attempted to index shape beyond ndim
"
);
fail_dim_check
(
dim
,
"invalid axis
"
);
return
shape
()[
dim
];
}
/// Strides of the array
const
size_t
*
strides
()
const
{
static_assert
(
sizeof
(
size_t
)
==
sizeof
(
Py_intptr_t
),
"size_t != Py_intptr_t"
);
return
reinterpret_cast
<
const
size_t
*>
(
PyArray_GET_
(
m_ptr
,
strides
));
}
/// Stride along a given axis
size_t
strides
(
size_t
dim
)
const
{
if
(
dim
>=
ndim
())
pybind11_fail
(
"NumPy: attempted to index strides beyond ndim
"
);
fail_dim_check
(
dim
,
"invalid axis
"
);
return
strides
()[
dim
];
}
...
...
@@ -346,20 +352,61 @@ public:
return
PyArray_CHKFLAGS_
(
m_ptr
,
detail
::
npy_api
::
NPY_ARRAY_OWNDATA_
);
}
/// Direct pointer to contained buffer
const
void
*
data
()
const
{
return
reinterpret_cast
<
const
void
*>
(
PyArray_GET_
(
m_ptr
,
data
));
/// Pointer to the contained data. If index is not provided, points to the
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
template
<
typename
...
Ix
>
const
void
*
data
(
Ix
&&
...
index
)
const
{
return
static_cast
<
const
void
*>
(
PyArray_GET_
(
m_ptr
,
data
)
+
offset_at
(
index
...));
}
/// Direct mutable pointer to contained buffer (checks writeable flag)
void
*
mutable_data
()
{
if
(
!
writeable
())
pybind11_fail
(
"NumPy: cannot get mutable data of a read-only array"
);
return
reinterpret_cast
<
void
*>
(
PyArray_GET_
(
m_ptr
,
data
));
/// Mutable pointer to the contained data. If index is not provided, points to the
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
/// May throw if the array is not writeable.
template
<
typename
...
Ix
>
void
*
mutable_data
(
Ix
&&
...
index
)
{
check_writeable
();
return
static_cast
<
void
*>
(
PyArray_GET_
(
m_ptr
,
data
)
+
offset_at
(
index
...));
}
/// Byte offset from beginning of the array to a given index (full or partial).
/// May throw if the index would lead to out of bounds access.
template
<
typename
...
Ix
>
size_t
offset_at
(
Ix
&&
...
index
)
const
{
if
(
sizeof
...(
index
)
>
ndim
())
fail_dim_check
(
sizeof
...(
index
),
"too many indices for an array"
);
return
get_byte_offset
(
index
...);
}
size_t
offset_at
()
const
{
return
0
;
}
/// Item count from beginning of the array to a given index (full or partial).
/// May throw if the index would lead to out of bounds access.
template
<
typename
...
Ix
>
size_t
index_at
(
Ix
&&
...
index
)
const
{
return
offset_at
(
index
...)
/
itemsize
();
}
protected
:
template
<
typename
T
,
typename
SFINAE
>
friend
struct
detail
::
npy_format_descriptor
;
template
<
typename
,
typename
>
friend
struct
detail
::
npy_format_descriptor
;
void
fail_dim_check
(
size_t
dim
,
const
std
::
string
&
msg
)
const
{
throw
index_error
(
msg
+
": "
+
std
::
to_string
(
dim
)
+
" (ndim = "
+
std
::
to_string
(
ndim
())
+
")"
);
}
template
<
typename
...
Ix
>
size_t
get_byte_offset
(
Ix
&&
...
index
)
const
{
const
size_t
idx
[]
=
{
(
size_t
)
index
...
};
if
(
!
std
::
equal
(
idx
+
0
,
idx
+
sizeof
...(
index
),
shape
(),
std
::
less
<
size_t
>
{}))
{
auto
mismatch
=
std
::
mismatch
(
idx
+
0
,
idx
+
sizeof
...(
index
),
shape
(),
std
::
less
<
size_t
>
{});
throw
index_error
(
std
::
string
(
"index "
)
+
std
::
to_string
(
*
mismatch
.
first
)
+
" is out of bounds for axis "
+
std
::
to_string
(
mismatch
.
first
-
idx
)
+
" with size "
+
std
::
to_string
(
*
mismatch
.
second
));
}
return
std
::
inner_product
(
idx
+
0
,
idx
+
sizeof
...(
index
),
strides
(),
(
size_t
)
0
);
}
size_t
get_byte_offset
()
const
{
return
0
;
}
void
check_writeable
()
const
{
if
(
!
writeable
())
throw
std
::
runtime_error
(
"array is not writeable"
);
}
static
std
::
vector
<
size_t
>
default_strides
(
const
std
::
vector
<
size_t
>&
shape
,
size_t
itemsize
)
{
auto
ndim
=
shape
.
size
();
...
...
@@ -382,23 +429,45 @@ public:
array_t
(
const
buffer_info
&
info
)
:
array
(
info
)
{
}
array_t
(
const
std
::
vector
<
size_t
>&
shape
,
const
std
::
vector
<
size_t
>&
strides
,
T
*
ptr
=
nullptr
)
array_t
(
const
std
::
vector
<
size_t
>&
shape
,
const
std
::
vector
<
size_t
>&
strides
,
const
T
*
ptr
=
nullptr
)
:
array
(
shape
,
strides
,
ptr
)
{
}
array_t
(
const
std
::
vector
<
size_t
>&
shape
,
T
*
ptr
=
nullptr
)
array_t
(
const
std
::
vector
<
size_t
>&
shape
,
const
T
*
ptr
=
nullptr
)
:
array
(
shape
,
ptr
)
{
}
array_t
(
size_t
count
,
T
*
ptr
=
nullptr
)
array_t
(
size_t
count
,
const
T
*
ptr
=
nullptr
)
:
array
(
count
,
ptr
)
{
}
const
T
*
data
()
const
{
return
reinterpret_cast
<
const
T
*>
(
PyArray_GET_
(
m_ptr
,
data
)
);
const
expr
size_t
itemsize
()
const
{
return
sizeof
(
T
);
}
T
*
mutable_data
()
{
if
(
!
writeable
())
pybind11_fail
(
"NumPy: cannot get mutable data of a read-only array"
);
return
reinterpret_cast
<
T
*>
(
PyArray_GET_
(
m_ptr
,
data
));
template
<
typename
...
Ix
>
size_t
index_at
(
Ix
&
...
index
)
const
{
return
offset_at
(
index
...)
/
itemsize
();
}
template
<
typename
...
Ix
>
const
T
*
data
(
Ix
&&
...
index
)
const
{
return
static_cast
<
const
T
*>
(
array
::
data
(
index
...));
}
template
<
typename
...
Ix
>
T
*
mutable_data
(
Ix
&&
...
index
)
{
return
static_cast
<
T
*>
(
array
::
mutable_data
(
index
...));
}
// Reference to element at a given index
template
<
typename
...
Ix
>
const
T
&
at
(
Ix
&&
...
index
)
const
{
if
(
sizeof
...(
index
)
!=
ndim
())
fail_dim_check
(
sizeof
...(
index
),
"index dimension mismatch"
);
// not using offset_at() / index_at() here so as to avoid another dimension check
return
*
(
static_cast
<
const
T
*>
(
array
::
data
())
+
get_byte_offset
(
index
...)
/
itemsize
());
}
// Mutable reference to element at a given index
template
<
typename
...
Ix
>
T
&
mutable_at
(
Ix
&&
...
index
)
{
if
(
sizeof
...(
index
)
!=
ndim
())
fail_dim_check
(
sizeof
...(
index
),
"index dimension mismatch"
);
// not using offset_at() / index_at() here so as to avoid another dimension check
return
*
(
static_cast
<
T
*>
(
array
::
mutable_data
())
+
get_byte_offset
(
index
...)
/
itemsize
());
}
static
bool
is_non_null
(
PyObject
*
ptr
)
{
return
ptr
!=
nullptr
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment