Commit 3132b38e by Jean-Baptiste Lespiau Committed by Copybara-Service

Add absl::variant type_caster.

PiperOrigin-RevId: 348646910
parent 22974484
......@@ -38,6 +38,7 @@
#include <stdexcept>
#include <type_traits>
#include <typeinfo>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
......@@ -337,6 +338,15 @@ template <>
struct type_caster<absl::nullopt_t> : public void_caster<absl::nullopt_t> {};
#endif
// This is a simple port of the pybind11 std::variant type_caster, applied to
// absl::variant. See pybind11 stl.h.
#ifndef ABSL_HAVE_STD_VARIANT
template <typename... Ts>
struct type_caster<absl::variant<Ts...>>
: variant_caster<absl::variant<Ts...>> {};
#endif
} // namespace detail
} // namespace pybind11
......
......@@ -196,6 +196,33 @@ void DefineNonConstSpan(module* py_m, absl::string_view type_name) {
&FillNonConstSpan<T>, arg("value"), arg("output_span").noconvert());
}
// absl::variant
struct A {
int a;
};
struct B {
int b;
};
typedef absl::variant<A, B> AOrB;
int VariantToInt(AOrB value) {
if (absl::holds_alternative<A>(value)) {
return absl::get<A>(value).a;
} else if (absl::holds_alternative<B>(value)) {
return absl::get<B>(value).b;
} else {
throw std::exception();
}
}
std::vector<AOrB> IdentityWithCopy(const std::vector<AOrB>& value) {
return value;
}
std::vector<absl::variant<A*, B*>> Identity(
const std::vector<absl::variant<A*, B*>>& value) {
return value;
}
PYBIND11_MODULE(absl_example, m) {
// absl::Time/Duration bindings.
m.def("make_duration", &MakeDuration, arg("secs"));
......@@ -257,6 +284,13 @@ PYBIND11_MODULE(absl_example, m) {
// absl::flat_hash_set bindings
m.def("make_set", &MakeSet, arg("values"));
m.def("check_set", &CheckSet, arg("set"), arg("values"));
// absl::variant
class_<A>(m, "A").def(init<int>()).def_readonly("a", &A::a);
class_<B>(m, "B").def(init<int>()).def_readonly("b", &B::b);
m.def("VariantToInt", &VariantToInt);
m.def("Identity", &Identity);
m.def("IdentityWithCopy", &IdentityWithCopy);
}
} // namespace test
......
......@@ -365,5 +365,26 @@ class AbslOptionalTest(absltest.TestCase):
self.assertIsNone(absl_example.make_optional())
class AbslVariantTest(absltest.TestCase):
def test_variant(self):
assert absl_example.VariantToInt(absl_example.A(3)) == 3
assert absl_example.VariantToInt(absl_example.B(5)) == 5
for identity_f, should_be_equal in [(absl_example.Identity, True),
(absl_example.IdentityWithCopy, False)]:
objs = [absl_example.A(3), absl_example.B(5)]
vector = identity_f(objs)
self.assertLen(vector, 2)
self.assertIsInstance(vector[0], absl_example.A)
self.assertEqual(vector[0].a, 3)
self.assertIsInstance(vector[1], absl_example.B)
self.assertEqual(vector[1].b, 5)
if should_be_equal:
self.assertEqual(objs, vector)
else:
self.assertNotEqual(objs, vector)
if __name__ == '__main__':
absltest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment