Commit 1fac1b9f by Dean Moldovan Committed by Wenzel Jakob

Make py::iterator compatible with std algorithms

The added type aliases are required by `std::iterator_traits`.
Python iterators satisfy the `InputIterator` concept in C++.
parent f7685826
...@@ -602,6 +602,12 @@ NAMESPACE_END(detail) ...@@ -602,6 +602,12 @@ NAMESPACE_END(detail)
\endrst */ \endrst */
class iterator : public object { class iterator : public object {
public: public:
using iterator_category = std::input_iterator_tag;
using difference_type = ssize_t;
using value_type = handle;
using reference = const handle;
using pointer = const handle *;
PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check) PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check)
iterator& operator++() { iterator& operator++() {
...@@ -615,7 +621,7 @@ public: ...@@ -615,7 +621,7 @@ public:
return rv; return rv;
} }
handle operator*() const { reference operator*() const {
if (m_ptr && !value.ptr()) { if (m_ptr && !value.ptr()) {
auto& self = const_cast<iterator &>(*this); auto& self = const_cast<iterator &>(*this);
self.advance(); self.advance();
...@@ -623,7 +629,7 @@ public: ...@@ -623,7 +629,7 @@ public:
return value; return value;
} }
const handle *operator->() const { operator*(); return &value; } pointer operator->() const { operator*(); return &value; }
/** \rst /** \rst
The value which marks the end of the iteration. ``it == iterator::sentinel()`` The value which marks the end of the iteration. ``it == iterator::sentinel()``
......
...@@ -290,4 +290,14 @@ test_initializer sequences_and_iterators([](py::module &pm) { ...@@ -290,4 +290,14 @@ test_initializer sequences_and_iterators([](py::module &pm) {
} }
return l; return l;
}); });
// Make sure that py::iterator works with std algorithms
m.def("count_none", [](py::object o) {
return std::count_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); });
});
m.def("find_none", [](py::object o) {
auto it = std::find_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); });
return it->is_none();
});
}); });
...@@ -113,3 +113,7 @@ def test_python_iterator_in_cpp(): ...@@ -113,3 +113,7 @@ def test_python_iterator_in_cpp():
with pytest.raises(RuntimeError) as excinfo: with pytest.raises(RuntimeError) as excinfo:
m.iterator_to_list(iter(bad_next_call, None)) m.iterator_to_list(iter(bad_next_call, None))
assert str(excinfo.value) == "py::iterator::advance() should propagate errors" assert str(excinfo.value) == "py::iterator::advance() should propagate errors"
l = [1, None, 0, None]
assert m.count_none(l) == 2
assert m.find_none(l) is True
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment