Commit 54679795 by Patrick Stewart Committed by Wenzel Jakob

Add the buffer interface for wrapped STL vectors

Allows use of vectors as python buffers, so for example they can be adopted without a copy by numpy.asarray
Allows faster conversion of buffers to vectors by copying instead of individually casting the elements
parent 16afbcef
...@@ -326,6 +326,49 @@ template <typename Vector, typename Class_> auto vector_if_insertion_operator(Cl ...@@ -326,6 +326,49 @@ template <typename Vector, typename Class_> auto vector_if_insertion_operator(Cl
); );
} }
// Provide the buffer interface for vectors if we have data() and we have a format for it
// GCC seems to have "void std::vector<bool>::data()" - doing SFINAE on the existence of data() is insufficient, we need to check it returns an appropriate pointer
template <typename Vector, typename = void>
struct vector_has_data_and_format : std::false_type {};
template <typename Vector>
struct vector_has_data_and_format<Vector, enable_if_t<std::is_same<decltype(py::format_descriptor<typename Vector::value_type>::format(), std::declval<Vector>().data()), typename Vector::value_type*>::value>> : std::true_type {};
// Add the buffer interface to a vector
template <typename Vector, typename Class_, typename... Args>
enable_if_t<detail::any_of<std::is_same<Args, py::buffer_protocol>...>::value>
vector_buffer(Class_& cl) {
using T = typename Vector::value_type;
static_assert(vector_has_data_and_format<Vector>::value, "There is not an appropriate format descriptor for this vector");
// numpy.h declares this for arbitrary types, but it may raise an exception and crash hard at runtime if PYBIND11_NUMPY_DTYPE hasn't been called, so check here
py::format_descriptor<T>::format();
cl.def_buffer([](Vector& v) -> py::buffer_info {
return py::buffer_info(v.data(), sizeof(T), py::format_descriptor<T>::format(), 1, {v.size()}, {sizeof(T)});
});
cl.def("__init__", [](Vector& vec, py::buffer buf) {
auto info = buf.request();
if (info.ndim != 1 || info.strides[0] <= 0 || info.strides[0] % sizeof(T))
throw pybind11::type_error("Only valid 1D buffers can be copied to a vector");
if (!detail::compare_buffer_info<T>::compare(info) || sizeof(T) != info.itemsize)
throw pybind11::type_error("Format mismatch (Python: " + info.format + " C++: " + py::format_descriptor<T>::format() + ")");
new (&vec) Vector();
vec.reserve(info.shape[0]);
T *p = static_cast<T*>(info.ptr);
auto step = info.strides[0] / sizeof(T);
T *end = p + info.shape[0] * step;
for (; p < end; p += step)
vec.push_back(*p);
});
return;
}
template <typename Vector, typename Class_, typename... Args>
enable_if_t<!detail::any_of<std::is_same<Args, py::buffer_protocol>...>::value> vector_buffer(Class_&) {}
NAMESPACE_END(detail) NAMESPACE_END(detail)
// //
...@@ -337,6 +380,9 @@ pybind11::class_<Vector, holder_type> bind_vector(pybind11::module &m, std::stri ...@@ -337,6 +380,9 @@ pybind11::class_<Vector, holder_type> bind_vector(pybind11::module &m, std::stri
Class_ cl(m, name.c_str(), std::forward<Args>(args)...); Class_ cl(m, name.c_str(), std::forward<Args>(args)...);
// Declare the buffer interface if a py::buffer_protocol() is passed in
detail::vector_buffer<Vector, Class_, Args...>(cl);
cl.def(pybind11::init<>()); cl.def(pybind11::init<>());
// Register copy constructor (if possible) // Register copy constructor (if possible)
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "pybind11_tests.h" #include "pybind11_tests.h"
#include <pybind11/stl_bind.h> #include <pybind11/stl_bind.h>
#include <pybind11/numpy.h>
#include <map> #include <map>
#include <deque> #include <deque>
#include <unordered_map> #include <unordered_map>
...@@ -58,17 +59,45 @@ template <class Map> Map *times_ten(int n) { ...@@ -58,17 +59,45 @@ template <class Map> Map *times_ten(int n) {
return m; return m;
} }
struct VStruct {
bool w;
uint32_t x;
double y;
bool z;
};
struct VUndeclStruct { //dtype not declared for this version
bool w;
uint32_t x;
double y;
bool z;
};
test_initializer stl_binder_vector([](py::module &m) { test_initializer stl_binder_vector([](py::module &m) {
py::class_<El>(m, "El") py::class_<El>(m, "El")
.def(py::init<int>()); .def(py::init<int>());
py::bind_vector<std::vector<unsigned int>>(m, "VectorInt"); py::bind_vector<std::vector<unsigned char>>(m, "VectorUChar", py::buffer_protocol());
py::bind_vector<std::vector<unsigned int>>(m, "VectorInt", py::buffer_protocol());
py::bind_vector<std::vector<bool>>(m, "VectorBool"); py::bind_vector<std::vector<bool>>(m, "VectorBool");
py::bind_vector<std::vector<El>>(m, "VectorEl"); py::bind_vector<std::vector<El>>(m, "VectorEl");
py::bind_vector<std::vector<std::vector<El>>>(m, "VectorVectorEl"); py::bind_vector<std::vector<std::vector<El>>>(m, "VectorVectorEl");
m.def("create_undeclstruct", [m] () mutable {
py::bind_vector<std::vector<VUndeclStruct>>(m, "VectorUndeclStruct", py::buffer_protocol());
});
try {
py::module::import("numpy");
} catch (...) {
return;
}
PYBIND11_NUMPY_DTYPE(VStruct, w, x, y, z);
py::class_<VStruct>(m, "VStruct").def_readwrite("x", &VStruct::x);
py::bind_vector<std::vector<VStruct>>(m, "VectorStruct", py::buffer_protocol());
m.def("get_vectorstruct", [] {return std::vector<VStruct> {{0, 5, 3.0, 1}, {1, 30, -1e4, 0}};});
}); });
test_initializer stl_binder_map([](py::module &m) { test_initializer stl_binder_map([](py::module &m) {
...@@ -97,4 +126,3 @@ test_initializer stl_binder_noncopyable([](py::module &m) { ...@@ -97,4 +126,3 @@ test_initializer stl_binder_noncopyable([](py::module &m) {
py::bind_map<std::unordered_map<int, E_nc>>(m, "UmapENC"); py::bind_map<std::unordered_map<int, E_nc>>(m, "UmapENC");
m.def("get_umnc", &times_ten<std::unordered_map<int, E_nc>>, py::return_value_policy::reference); m.def("get_umnc", &times_ten<std::unordered_map<int, E_nc>>, py::return_value_policy::reference);
}); });
import pytest
import sys
with pytest.suppress(ImportError):
import numpy as np
def test_vector_int(): def test_vector_int():
from pybind11_tests import VectorInt from pybind11_tests import VectorInt
...@@ -26,6 +33,53 @@ def test_vector_int(): ...@@ -26,6 +33,53 @@ def test_vector_int():
assert v_int2 == VectorInt([0, 99, 2, 3]) assert v_int2 == VectorInt([0, 99, 2, 3])
@pytest.unsupported_on_pypy
def test_vector_buffer():
from pybind11_tests import VectorUChar, create_undeclstruct
b = bytearray([1, 2, 3, 4])
v = VectorUChar(b)
assert v[1] == 2
v[2] = 5
m = memoryview(v) # We expose the buffer interface
if sys.version_info.major > 2:
assert m[2] == 5
m[2] = 6
else:
assert m[2] == '\x05'
m[2] = '\x06'
assert v[2] == 6
with pytest.raises(RuntimeError):
create_undeclstruct() # Undeclared struct contents, no buffer interface
@pytest.requires_numpy
def test_vector_buffer_numpy():
from pybind11_tests import VectorInt, get_vectorstruct
a = np.array([1, 2, 3, 4], dtype=np.int32)
with pytest.raises(TypeError):
VectorInt(a)
a = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.uintc)
v = VectorInt(a[0, :])
assert len(v) == 4
assert v[2] == 3
m = np.asarray(v)
m[2] = 5
assert v[2] == 5
v = VectorInt(a[:, 1])
assert len(v) == 3
assert v[2] == 10
v = get_vectorstruct()
assert v[0].x == 5
m = np.asarray(v)
m[1]['x'] = 99
assert v[1].x == 99
def test_vector_custom(): def test_vector_custom():
from pybind11_tests import El, VectorEl, VectorVectorEl from pybind11_tests import El, VectorEl, VectorVectorEl
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment