Commit f217c041 by Wenzel Jakob Committed by GitHub

Merge pull request #402 from aldanor/feature/numpy-c-api

Add array methods via C API
parents 720136bf aca6bcae
......@@ -83,12 +83,11 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
static constexpr bool isVector = Type::IsVectorAtCompileTime;
bool load(handle src, bool) {
array_t<Scalar> buffer(src, true);
if (!buffer.check())
array_t<Scalar> buf(src, true);
if (!buf.check())
return false;
auto info = buffer.request();
if (info.ndim == 1) {
if (buf.ndim() == 1) {
typedef Eigen::InnerStride<> Strides;
if (!isVector &&
!(Type::RowsAtCompileTime == Eigen::Dynamic &&
......@@ -96,31 +95,32 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
return false;
if (Type::SizeAtCompileTime != Eigen::Dynamic &&
info.shape[0] != (size_t) Type::SizeAtCompileTime)
buf.shape(0) != (size_t) Type::SizeAtCompileTime)
return false;
auto strides = Strides(info.strides[0] / sizeof(Scalar));
Strides::Index n_elts = (Strides::Index) info.shape[0];
Strides::Index n_elts = (Strides::Index) buf.shape(0);
Strides::Index unity = 1;
value = Eigen::Map<Type, 0, Strides>(
(Scalar *) info.ptr, rowMajor ? unity : n_elts, rowMajor ? n_elts : unity, strides);
} else if (info.ndim == 2) {
buf.mutable_data(),
rowMajor ? unity : n_elts,
rowMajor ? n_elts : unity,
Strides(buf.strides(0) / sizeof(Scalar))
);
} else if (buf.ndim() == 2) {
typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides;
if ((Type::RowsAtCompileTime != Eigen::Dynamic && info.shape[0] != (size_t) Type::RowsAtCompileTime) ||
(Type::ColsAtCompileTime != Eigen::Dynamic && info.shape[1] != (size_t) Type::ColsAtCompileTime))
if ((Type::RowsAtCompileTime != Eigen::Dynamic && buf.shape(0) != (size_t) Type::RowsAtCompileTime) ||
(Type::ColsAtCompileTime != Eigen::Dynamic && buf.shape(1) != (size_t) Type::ColsAtCompileTime))
return false;
auto strides = Strides(
info.strides[rowMajor ? 0 : 1] / sizeof(Scalar),
info.strides[rowMajor ? 1 : 0] / sizeof(Scalar));
value = Eigen::Map<Type, 0, Strides>(
(Scalar *) info.ptr,
typename Strides::Index(info.shape[0]),
typename Strides::Index(info.shape[1]), strides);
buf.mutable_data(),
typename Strides::Index(buf.shape(0)),
typename Strides::Index(buf.shape(1)),
Strides(buf.strides(rowMajor ? 0 : 1) / sizeof(Scalar),
buf.strides(rowMajor ? 1 : 0) / sizeof(Scalar))
);
} else {
return false;
}
......@@ -222,28 +222,18 @@ struct type_caster<Type, typename std::enable_if<is_eigen_sparse<Type>::value>::
}
}
auto valuesArray = array_t<Scalar>((object) obj.attr("data"));
auto innerIndicesArray = array_t<StorageIndex>((object) obj.attr("indices"));
auto outerIndicesArray = array_t<StorageIndex>((object) obj.attr("indptr"));
auto values = array_t<Scalar>((object) obj.attr("data"));
auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
auto nnz = obj.attr("nnz").cast<Index>();
if (!valuesArray.check() || !innerIndicesArray.check() ||
!outerIndicesArray.check())
if (!values.check() || !innerIndices.check() || !outerIndices.check())
return false;
auto outerIndices = outerIndicesArray.request();
auto innerIndices = innerIndicesArray.request();
auto values = valuesArray.request();
value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
shape[0].cast<Index>(),
shape[1].cast<Index>(),
nnz,
static_cast<StorageIndex *>(outerIndices.ptr),
static_cast<StorageIndex *>(innerIndices.ptr),
static_cast<Scalar *>(values.ptr)
);
shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());
return true;
}
......
......@@ -19,6 +19,7 @@ set(PYBIND11_TEST_FILES
test_kwargs_and_defaults.cpp
test_methods_and_attributes.cpp
test_modules.cpp
test_numpy_array.cpp
test_numpy_dtypes.cpp
test_numpy_vectorize.cpp
test_opaque_types.cpp
......
/*
tests/test_numpy_array.cpp -- test core array functionality
Copyright (c) 2016 Ivan Smirnov <i.s.smirnov@gmail.com>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#include "pybind11_tests.h"
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <cstdint>
#include <vector>
using arr = py::array;
using arr_t = py::array_t<uint16_t, 0>;
template<typename... Ix> arr data(const arr& a, Ix&&... index) {
return arr(a.nbytes() - a.offset_at(index...), (const uint8_t *) a.data(index...));
}
template<typename... Ix> arr data_t(const arr_t& a, Ix&&... index) {
return arr(a.size() - a.index_at(index...), a.data(index...));
}
arr& mutate_data(arr& a) {
auto ptr = (uint8_t *) a.mutable_data();
for (size_t i = 0; i < a.nbytes(); i++)
ptr[i] = (uint8_t) (ptr[i] * 2);
return a;
}
arr_t& mutate_data_t(arr_t& a) {
auto ptr = a.mutable_data();
for (size_t i = 0; i < a.size(); i++)
ptr[i]++;
return a;
}
template<typename... Ix> arr& mutate_data(arr& a, Ix&&... index) {
auto ptr = (uint8_t *) a.mutable_data(index...);
for (size_t i = 0; i < a.nbytes() - a.offset_at(index...); i++)
ptr[i] = (uint8_t) (ptr[i] * 2);
return a;
}
template<typename... Ix> arr_t& mutate_data_t(arr_t& a, Ix&&... index) {
auto ptr = a.mutable_data(index...);
for (size_t i = 0; i < a.size() - a.index_at(index...); i++)
ptr[i]++;
return a;
}
template<typename... Ix> size_t index_at(const arr& a, Ix&&... idx) { return a.index_at(idx...); }
template<typename... Ix> size_t index_at_t(const arr_t& a, Ix&&... idx) { return a.index_at(idx...); }
template<typename... Ix> size_t offset_at(const arr& a, Ix&&... idx) { return a.offset_at(idx...); }
template<typename... Ix> size_t offset_at_t(const arr_t& a, Ix&&... idx) { return a.offset_at(idx...); }
template<typename... Ix> size_t at_t(const arr_t& a, Ix&&... idx) { return a.at(idx...); }
template<typename... Ix> arr_t& mutate_at_t(arr_t& a, Ix&&... idx) { a.mutable_at(idx...)++; return a; }
#define def_index_fn(name, type) \
sm.def(#name, [](type a) { return name(a); }); \
sm.def(#name, [](type a, int i) { return name(a, i); }); \
sm.def(#name, [](type a, int i, int j) { return name(a, i, j); }); \
sm.def(#name, [](type a, int i, int j, int k) { return name(a, i, j, k); });
test_initializer numpy_array([](py::module &m) {
auto sm = m.def_submodule("array");
sm.def("ndim", [](const arr& a) { return a.ndim(); });
sm.def("shape", [](const arr& a) { return arr(a.ndim(), a.shape()); });
sm.def("shape", [](const arr& a, size_t dim) { return a.shape(dim); });
sm.def("strides", [](const arr& a) { return arr(a.ndim(), a.strides()); });
sm.def("strides", [](const arr& a, size_t dim) { return a.strides(dim); });
sm.def("writeable", [](const arr& a) { return a.writeable(); });
sm.def("size", [](const arr& a) { return a.size(); });
sm.def("itemsize", [](const arr& a) { return a.itemsize(); });
sm.def("nbytes", [](const arr& a) { return a.nbytes(); });
sm.def("owndata", [](const arr& a) { return a.owndata(); });
def_index_fn(data, const arr&);
def_index_fn(data_t, const arr_t&);
def_index_fn(index_at, const arr&);
def_index_fn(index_at_t, const arr_t&);
def_index_fn(offset_at, const arr&);
def_index_fn(offset_at_t, const arr_t&);
def_index_fn(mutate_data, arr&);
def_index_fn(mutate_data_t, arr_t&);
def_index_fn(at_t, const arr_t&);
def_index_fn(mutate_at_t, arr_t&);
});
import pytest
with pytest.suppress(ImportError):
import numpy as np
@pytest.fixture(scope='function')
def arr():
return np.array([[1, 2, 3], [4, 5, 6]], '<u2')
@pytest.requires_numpy
def test_array_attributes():
from pybind11_tests.array import (
ndim, shape, strides, writeable, size, itemsize, nbytes, owndata
)
a = np.array(0, 'f8')
assert ndim(a) == 0
assert all(shape(a) == [])
assert all(strides(a) == [])
with pytest.raises(IndexError) as excinfo:
shape(a, 0)
assert str(excinfo.value) == 'invalid axis: 0 (ndim = 0)'
with pytest.raises(IndexError) as excinfo:
strides(a, 0)
assert str(excinfo.value) == 'invalid axis: 0 (ndim = 0)'
assert writeable(a)
assert size(a) == 1
assert itemsize(a) == 8
assert nbytes(a) == 8
assert owndata(a)
a = np.array([[1, 2, 3], [4, 5, 6]], 'u2').view()
a.flags.writeable = False
assert ndim(a) == 2
assert all(shape(a) == [2, 3])
assert shape(a, 0) == 2
assert shape(a, 1) == 3
assert all(strides(a) == [6, 2])
assert strides(a, 0) == 6
assert strides(a, 1) == 2
with pytest.raises(IndexError) as excinfo:
shape(a, 2)
assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)'
with pytest.raises(IndexError) as excinfo:
strides(a, 2)
assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)'
assert not writeable(a)
assert size(a) == 6
assert itemsize(a) == 2
assert nbytes(a) == 12
assert not owndata(a)
@pytest.requires_numpy
@pytest.mark.parametrize('args, ret', [([], 0), ([0], 0), ([1], 3), ([0, 1], 1), ([1, 2], 5)])
def test_index_offset(arr, args, ret):
from pybind11_tests.array import index_at, index_at_t, offset_at, offset_at_t
assert index_at(arr, *args) == ret
assert index_at_t(arr, *args) == ret
assert offset_at(arr, *args) == ret * arr.dtype.itemsize
assert offset_at_t(arr, *args) == ret * arr.dtype.itemsize
@pytest.requires_numpy
def test_dim_check_fail(arr):
from pybind11_tests.array import (index_at, index_at_t, offset_at, offset_at_t, data, data_t,
mutate_data, mutate_data_t)
for func in (index_at, index_at_t, offset_at, offset_at_t, data, data_t,
mutate_data, mutate_data_t):
with pytest.raises(IndexError) as excinfo:
func(arr, 1, 2, 3)
assert str(excinfo.value) == 'too many indices for an array: 3 (ndim = 2)'
@pytest.requires_numpy
@pytest.mark.parametrize('args, ret',
[([], [1, 2, 3, 4, 5, 6]),
([1], [4, 5, 6]),
([0, 1], [2, 3, 4, 5, 6]),
([1, 2], [6])])
def test_data(arr, args, ret):
from pybind11_tests.array import data, data_t
assert all(data_t(arr, *args) == ret)
assert all(data(arr, *args)[::2] == ret)
assert all(data(arr, *args)[1::2] == 0)
@pytest.requires_numpy
def test_mutate_readonly(arr):
from pybind11_tests.array import mutate_data, mutate_data_t, mutate_at_t
arr.flags.writeable = False
for func, args in (mutate_data, ()), (mutate_data_t, ()), (mutate_at_t, (0, 0)):
with pytest.raises(RuntimeError) as excinfo:
func(arr, *args)
assert str(excinfo.value) == 'array is not writeable'
@pytest.requires_numpy
@pytest.mark.parametrize('dim', [0, 1, 3])
def test_at_fail(arr, dim):
from pybind11_tests.array import at_t, mutate_at_t
for func in at_t, mutate_at_t:
with pytest.raises(IndexError) as excinfo:
func(arr, *([0] * dim))
assert str(excinfo.value) == 'index dimension mismatch: {} (ndim = 2)'.format(dim)
@pytest.requires_numpy
def test_at(arr):
from pybind11_tests.array import at_t, mutate_at_t
assert at_t(arr, 0, 2) == 3
assert at_t(arr, 1, 0) == 4
assert all(mutate_at_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6])
assert all(mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])
@pytest.requires_numpy
def test_mutate_data(arr):
from pybind11_tests.array import mutate_data, mutate_data_t
assert all(mutate_data(arr).ravel() == [2, 4, 6, 8, 10, 12])
assert all(mutate_data(arr).ravel() == [4, 8, 12, 16, 20, 24])
assert all(mutate_data(arr, 1).ravel() == [4, 8, 12, 32, 40, 48])
assert all(mutate_data(arr, 0, 1).ravel() == [4, 16, 24, 64, 80, 96])
assert all(mutate_data(arr, 1, 2).ravel() == [4, 16, 24, 64, 80, 192])
assert all(mutate_data_t(arr).ravel() == [5, 17, 25, 65, 81, 193])
assert all(mutate_data_t(arr).ravel() == [6, 18, 26, 66, 82, 194])
assert all(mutate_data_t(arr, 1).ravel() == [6, 18, 26, 67, 83, 195])
assert all(mutate_data_t(arr, 0, 1).ravel() == [6, 19, 27, 68, 84, 196])
assert all(mutate_data_t(arr, 1, 2).ravel() == [6, 19, 27, 68, 84, 197])
@pytest.requires_numpy
def test_bounds_check(arr):
from pybind11_tests.array import (index_at, index_at_t, data, data_t,
mutate_data, mutate_data_t, at_t, mutate_at_t)
funcs = (index_at, index_at_t, data, data_t,
mutate_data, mutate_data_t, at_t, mutate_at_t)
for func in funcs:
with pytest.raises(IndexError) as excinfo:
index_at(arr, 2, 0)
assert str(excinfo.value) == 'index 2 is out of bounds for axis 0 with size 2'
with pytest.raises(IndexError) as excinfo:
index_at(arr, 0, 4)
assert str(excinfo.value) == 'index 4 is out of bounds for axis 1 with size 3'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment