Commit 4a7c2ec6 by Justin Bassett Committed by Copybara-Service

Forbid absl::Uniform<absl::int128>(gen)

std::is_signed can't be specialized, so this actually lets through non-unsigned types where the types are not language primitives (i.e. it lets absl::int128 through). However, std::numeric_limits can be specialized, and is indeed specialized, so we can use that instead.
PiperOrigin-RevId: 636983590
Change-Id: Ic993518e9cac7c453b08deaf3784b6fba49f15d0
parent 0ef5bc61
...@@ -221,6 +221,8 @@ cc_test( ...@@ -221,6 +221,8 @@ cc_test(
deps = [ deps = [
":distributions", ":distributions",
":random", ":random",
"//absl/meta:type_traits",
"//absl/numeric:int128",
"//absl/random/internal:distribution_test_util", "//absl/random/internal:distribution_test_util",
"@com_google_googletest//:gtest", "@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main", "@com_google_googletest//:gtest_main",
......
...@@ -288,6 +288,8 @@ absl_cc_test( ...@@ -288,6 +288,8 @@ absl_cc_test(
DEPS DEPS
absl::random_distributions absl::random_distributions
absl::random_random absl::random_random
absl::type_traits
absl::int128
absl::random_internal_distribution_test_util absl::random_internal_distribution_test_util
GTest::gmock GTest::gmock
GTest::gtest_main GTest::gtest_main
......
...@@ -46,23 +46,23 @@ ...@@ -46,23 +46,23 @@
#ifndef ABSL_RANDOM_DISTRIBUTIONS_H_ #ifndef ABSL_RANDOM_DISTRIBUTIONS_H_
#define ABSL_RANDOM_DISTRIBUTIONS_H_ #define ABSL_RANDOM_DISTRIBUTIONS_H_
#include <algorithm>
#include <cmath>
#include <limits> #include <limits>
#include <random>
#include <type_traits> #include <type_traits>
#include "absl/base/config.h"
#include "absl/base/internal/inline_variable.h" #include "absl/base/internal/inline_variable.h"
#include "absl/meta/type_traits.h"
#include "absl/random/bernoulli_distribution.h" #include "absl/random/bernoulli_distribution.h"
#include "absl/random/beta_distribution.h" #include "absl/random/beta_distribution.h"
#include "absl/random/exponential_distribution.h" #include "absl/random/exponential_distribution.h"
#include "absl/random/gaussian_distribution.h" #include "absl/random/gaussian_distribution.h"
#include "absl/random/internal/distribution_caller.h" // IWYU pragma: export #include "absl/random/internal/distribution_caller.h" // IWYU pragma: export
#include "absl/random/internal/traits.h"
#include "absl/random/internal/uniform_helper.h" // IWYU pragma: export #include "absl/random/internal/uniform_helper.h" // IWYU pragma: export
#include "absl/random/log_uniform_int_distribution.h" #include "absl/random/log_uniform_int_distribution.h"
#include "absl/random/poisson_distribution.h" #include "absl/random/poisson_distribution.h"
#include "absl/random/uniform_int_distribution.h" #include "absl/random/uniform_int_distribution.h" // IWYU pragma: export
#include "absl/random/uniform_real_distribution.h" #include "absl/random/uniform_real_distribution.h" // IWYU pragma: export
#include "absl/random/zipf_distribution.h" #include "absl/random/zipf_distribution.h"
namespace absl { namespace absl {
...@@ -176,7 +176,7 @@ Uniform(TagType tag, ...@@ -176,7 +176,7 @@ Uniform(TagType tag,
return random_internal::DistributionCaller<gen_t>::template Call< return random_internal::DistributionCaller<gen_t>::template Call<
distribution_t>(&urbg, tag, static_cast<return_t>(lo), distribution_t>(&urbg, tag, static_cast<return_t>(lo),
static_cast<return_t>(hi)); static_cast<return_t>(hi));
} }
// absl::Uniform(bitgen, lo, hi) // absl::Uniform(bitgen, lo, hi)
...@@ -200,7 +200,7 @@ Uniform(URBG&& urbg, // NOLINT(runtime/references) ...@@ -200,7 +200,7 @@ Uniform(URBG&& urbg, // NOLINT(runtime/references)
return random_internal::DistributionCaller<gen_t>::template Call< return random_internal::DistributionCaller<gen_t>::template Call<
distribution_t>(&urbg, static_cast<return_t>(lo), distribution_t>(&urbg, static_cast<return_t>(lo),
static_cast<return_t>(hi)); static_cast<return_t>(hi));
} }
// absl::Uniform<unsigned T>(bitgen) // absl::Uniform<unsigned T>(bitgen)
...@@ -208,7 +208,7 @@ Uniform(URBG&& urbg, // NOLINT(runtime/references) ...@@ -208,7 +208,7 @@ Uniform(URBG&& urbg, // NOLINT(runtime/references)
// Overload of Uniform() using the minimum and maximum values of a given type // Overload of Uniform() using the minimum and maximum values of a given type
// `T` (which must be unsigned), returning a value of type `unsigned T` // `T` (which must be unsigned), returning a value of type `unsigned T`
template <typename R, typename URBG> template <typename R, typename URBG>
typename absl::enable_if_t<!std::is_signed<R>::value, R> // typename absl::enable_if_t<!std::numeric_limits<R>::is_signed, R> //
Uniform(URBG&& urbg) { // NOLINT(runtime/references) Uniform(URBG&& urbg) { // NOLINT(runtime/references)
using gen_t = absl::decay_t<URBG>; using gen_t = absl::decay_t<URBG>;
using distribution_t = random_internal::UniformDistributionWrapper<R>; using distribution_t = random_internal::UniformDistributionWrapper<R>;
......
...@@ -17,10 +17,14 @@ ...@@ -17,10 +17,14 @@
#include <cfloat> #include <cfloat>
#include <cmath> #include <cmath>
#include <cstdint> #include <cstdint>
#include <random> #include <limits>
#include <type_traits>
#include <utility>
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "absl/meta/type_traits.h"
#include "absl/numeric/int128.h"
#include "absl/random/internal/distribution_test_util.h" #include "absl/random/internal/distribution_test_util.h"
#include "absl/random/random.h" #include "absl/random/random.h"
...@@ -30,7 +34,6 @@ constexpr int kSize = 400000; ...@@ -30,7 +34,6 @@ constexpr int kSize = 400000;
class RandomDistributionsTest : public testing::Test {}; class RandomDistributionsTest : public testing::Test {};
struct Invalid {}; struct Invalid {};
template <typename A, typename B> template <typename A, typename B>
...@@ -93,17 +96,18 @@ void CheckArgsInferType() { ...@@ -93,17 +96,18 @@ void CheckArgsInferType() {
} }
template <typename A, typename B, typename ExplicitRet> template <typename A, typename B, typename ExplicitRet>
auto ExplicitUniformReturnT(int) -> decltype( auto ExplicitUniformReturnT(int) -> decltype(absl::Uniform<ExplicitRet>(
absl::Uniform<ExplicitRet>(*std::declval<absl::InsecureBitGen*>(), std::declval<absl::InsecureBitGen&>(),
std::declval<A>(), std::declval<B>())); std::declval<A>(), std::declval<B>()));
template <typename, typename, typename ExplicitRet> template <typename, typename, typename ExplicitRet>
Invalid ExplicitUniformReturnT(...); Invalid ExplicitUniformReturnT(...);
template <typename TagType, typename A, typename B, typename ExplicitRet> template <typename TagType, typename A, typename B, typename ExplicitRet>
auto ExplicitTaggedUniformReturnT(int) -> decltype(absl::Uniform<ExplicitRet>( auto ExplicitTaggedUniformReturnT(int)
std::declval<TagType>(), *std::declval<absl::InsecureBitGen*>(), -> decltype(absl::Uniform<ExplicitRet>(
std::declval<A>(), std::declval<B>())); std::declval<TagType>(), std::declval<absl::InsecureBitGen&>(),
std::declval<A>(), std::declval<B>()));
template <typename, typename, typename, typename ExplicitRet> template <typename, typename, typename, typename ExplicitRet>
Invalid ExplicitTaggedUniformReturnT(...); Invalid ExplicitTaggedUniformReturnT(...);
...@@ -135,6 +139,14 @@ void CheckArgsReturnExpectedType() { ...@@ -135,6 +139,14 @@ void CheckArgsReturnExpectedType() {
""); "");
} }
// Takes the type of `absl::Uniform<R>(gen)` if valid or `Invalid` otherwise.
template <typename R>
auto UniformNoBoundsReturnT(int)
-> decltype(absl::Uniform<R>(std::declval<absl::InsecureBitGen&>()));
template <typename>
Invalid UniformNoBoundsReturnT(...);
TEST_F(RandomDistributionsTest, UniformTypeInference) { TEST_F(RandomDistributionsTest, UniformTypeInference) {
// Infers common types. // Infers common types.
CheckArgsInferType<uint16_t, uint16_t, uint16_t>(); CheckArgsInferType<uint16_t, uint16_t, uint16_t>();
...@@ -221,6 +233,38 @@ TEST_F(RandomDistributionsTest, UniformNoBounds) { ...@@ -221,6 +233,38 @@ TEST_F(RandomDistributionsTest, UniformNoBounds) {
absl::Uniform<uint32_t>(gen); absl::Uniform<uint32_t>(gen);
absl::Uniform<uint64_t>(gen); absl::Uniform<uint64_t>(gen);
absl::Uniform<absl::uint128>(gen); absl::Uniform<absl::uint128>(gen);
// Compile-time validity tests.
// Allows unsigned ints.
testing::StaticAssertTypeEq<uint8_t,
decltype(UniformNoBoundsReturnT<uint8_t>(0))>();
testing::StaticAssertTypeEq<uint16_t,
decltype(UniformNoBoundsReturnT<uint16_t>(0))>();
testing::StaticAssertTypeEq<uint32_t,
decltype(UniformNoBoundsReturnT<uint32_t>(0))>();
testing::StaticAssertTypeEq<uint64_t,
decltype(UniformNoBoundsReturnT<uint64_t>(0))>();
testing::StaticAssertTypeEq<
absl::uint128, decltype(UniformNoBoundsReturnT<absl::uint128>(0))>();
// Disallows signed ints.
testing::StaticAssertTypeEq<Invalid,
decltype(UniformNoBoundsReturnT<int8_t>(0))>();
testing::StaticAssertTypeEq<Invalid,
decltype(UniformNoBoundsReturnT<int16_t>(0))>();
testing::StaticAssertTypeEq<Invalid,
decltype(UniformNoBoundsReturnT<int32_t>(0))>();
testing::StaticAssertTypeEq<Invalid,
decltype(UniformNoBoundsReturnT<int64_t>(0))>();
testing::StaticAssertTypeEq<
Invalid, decltype(UniformNoBoundsReturnT<absl::int128>(0))>();
// Disallows float types.
testing::StaticAssertTypeEq<Invalid,
decltype(UniformNoBoundsReturnT<float>(0))>();
testing::StaticAssertTypeEq<Invalid,
decltype(UniformNoBoundsReturnT<double>(0))>();
} }
TEST_F(RandomDistributionsTest, UniformNonsenseRanges) { TEST_F(RandomDistributionsTest, UniformNonsenseRanges) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment