Commit 254b3a53 by Justin Bassett Committed by Copybara-Service

Add (unused) validation to absl::MockingBitGen

`absl::Uniform(tag, rng, a, b)` has some restrictions on the values it can produce in that it will always be in the range specified by `a` and `b`, but these restrictions can be violated by `absl::MockingBitGen`. This makes it easier than necessary to introduce a bug in tests using a mock RNG.

We can fix this by making `MockingBitGen` emit a runtime error if the value produced is out of bounds.

Immediately fixing all the internal buggy uses of `MockingBitGen` is currently infeasible, so the plan is this:

 1. Add turned-off validation to `MockingBitGen` to avoid the costs of maintaining unsubmitted code.
 2. Temporarily migrate the internal buggy use cases to keep the current behavior, to be fixed later.
 3. Turn on validation for `MockingBitGen`.
 4. Fix the internal buggy use cases over time.

---

A few of the different categories of errors I found:

 - `Call(tag, rng, a, b) -> a or b`, for open/half-open intervals (i.e. incorrect boundary condition). This case happens quite a lot, e.g. by specifying `absl::Uniform<double>(rng, 0, 1)` to return `1.0`.
 - `Call(tag, rng, 0, 1) -> 42` (i.e. return an arbitrary value). These may be straightforward to fix by just returning an in-range value, or sometimes they are difficult to fix because other data structures depend on those values.

PiperOrigin-RevId: 635503223
Change-Id: I9293ab78e79450e2b7b682dcb05149f238ecc550
parent 93ac3a4f
......@@ -140,9 +140,9 @@ cc_library(
deps = [
":distributions",
":mocking_bit_gen",
"//absl/meta:type_traits",
"//absl/base:config",
"//absl/random/internal:mock_overload_set",
"@com_google_googletest//:gtest",
"//absl/random/internal:mock_validators",
],
)
......@@ -154,15 +154,13 @@ cc_library(
],
linkopts = ABSL_DEFAULT_LINKOPTS,
deps = [
":distributions",
":random",
"//absl/base:config",
"//absl/base:core_headers",
"//absl/base:fast_type_id",
"//absl/container:flat_hash_map",
"//absl/meta:type_traits",
"//absl/random/internal:distribution_caller",
"//absl/strings",
"//absl/types:span",
"//absl/types:variant",
"//absl/random/internal:mock_helpers",
"//absl/utility",
"@com_google_googletest//:gtest",
],
......@@ -481,6 +479,7 @@ cc_test(
"no_test_wasm",
],
deps = [
":distributions",
":mock_distributions",
":mocking_bit_gen",
":random",
......
......@@ -77,6 +77,7 @@ absl_cc_library(
LINKOPTS
${ABSL_DEFAULT_LINKOPTS}
DEPS
absl::config
absl::fast_type_id
absl::optional
)
......@@ -92,6 +93,7 @@ absl_cc_library(
LINKOPTS
${ABSL_DEFAULT_LINKOPTS}
DEPS
absl::config
absl::random_mocking_bit_gen
absl::random_internal_mock_helpers
TESTONLY
......@@ -108,17 +110,15 @@ absl_cc_library(
LINKOPTS
${ABSL_DEFAULT_LINKOPTS}
DEPS
absl::config
absl::core_headers
absl::flat_hash_map
absl::raw_logging_internal
absl::random_distributions
absl::random_internal_distribution_caller
absl::random_internal_mock_helpers
absl::random_internal_mock_overload_set
absl::random_internal_mock_validators
absl::random_random
absl::strings
absl::span
absl::type_traits
absl::utility
absl::variant
GTest::gmock
GTest::gtest
PUBLIC
......@@ -135,6 +135,7 @@ absl_cc_test(
LINKOPTS
${ABSL_DEFAULT_LINKOPTS}
DEPS
absl::random_distributions
absl::random_mocking_bit_gen
absl::random_random
GTest::gmock
......@@ -1173,6 +1174,26 @@ absl_cc_library(
)
# Internal-only target, do not depend on directly.
absl_cc_library(
NAME
random_internal_mock_validators
HDRS
"internal/mock_validators.h"
COPTS
${ABSL_DEFAULT_COPTS}
LINKOPTS
${ABSL_DEFAULT_LINKOPTS}
DEPS
absl::random_internal_iostream_state_saver
absl::random_internal_uniform_helper
absl::config
absl::raw_logging_internal
absl::strings
absl::string_view
TESTONLY
)
# Internal-only target, do not depend on directly.
absl_cc_test(
NAME
random_internal_uniform_helper_test
......
......@@ -527,6 +527,7 @@ cc_library(
hdrs = ["mock_helpers.h"],
linkopts = ABSL_DEFAULT_LINKOPTS,
deps = [
"//absl/base:config",
"//absl/base:fast_type_id",
"//absl/types:optional",
],
......@@ -539,6 +540,7 @@ cc_library(
linkopts = ABSL_DEFAULT_LINKOPTS,
deps = [
":mock_helpers",
"//absl/base:config",
"//absl/random:mocking_bit_gen",
"@com_google_googletest//:gtest",
],
......@@ -712,7 +714,19 @@ cc_library(
":traits",
"//absl/base:config",
"//absl/meta:type_traits",
"//absl/numeric:int128",
],
)
cc_library(
name = "mock_validators",
hdrs = ["mock_validators.h"],
deps = [
":iostream_state_saver",
":uniform_helper",
"//absl/base:config",
"//absl/base:raw_logging_internal",
"//absl/strings",
"//absl/strings:string_view",
],
)
......
......@@ -16,10 +16,9 @@
#ifndef ABSL_RANDOM_INTERNAL_MOCK_HELPERS_H_
#define ABSL_RANDOM_INTERNAL_MOCK_HELPERS_H_
#include <tuple>
#include <type_traits>
#include <utility>
#include "absl/base/config.h"
#include "absl/base/internal/fast_type_id.h"
#include "absl/types/optional.h"
......@@ -27,6 +26,16 @@ namespace absl {
ABSL_NAMESPACE_BEGIN
namespace random_internal {
// A no-op validator meeting the ValidatorT requirements for MockHelpers.
//
// Custom validators should follow a similar structure, passing the type to
// MockHelpers::MockFor<KeyT>(m, CustomValidatorT()).
struct NoOpValidator {
// Default validation: do nothing.
template <typename ResultT, typename... Args>
static void Validate(ResultT, Args&&...) {}
};
// MockHelpers works in conjunction with MockOverloadSet, MockingBitGen, and
// BitGenRef to enable the mocking capability for absl distribution functions.
//
......@@ -109,22 +118,39 @@ class MockHelpers {
0, urbg, std::forward<Args>(args)...);
}
// Acquire a mock for the KeyT (may or may not be a signature).
// Acquire a mock for the KeyT (may or may not be a signature), set up to use
// the ValidatorT to verify that the result is in the range of the RNG
// function.
//
// KeyT is used to generate a typeid-based lookup for the mock.
// KeyT is a signature of the form:
// result_type(discriminator_type, std::tuple<args...>)
// The mocked function signature will be composed from KeyT as:
// result_type(args...)
template <typename KeyT, typename MockURBG>
static auto MockFor(MockURBG& m)
// ValidatorT::Validate will be called after the result of the RNG. The
// signature is expected to be of the form:
// ValidatorT::Validate(result, args...)
template <typename KeyT, typename ValidatorT, typename MockURBG>
static auto MockFor(MockURBG& m, ValidatorT)
-> decltype(m.template RegisterMock<
typename KeySignature<KeyT>::result_type,
typename KeySignature<KeyT>::arg_tuple_type>(
m, std::declval<IdType>())) {
m, std::declval<IdType>(), ValidatorT())) {
return m.template RegisterMock<typename KeySignature<KeyT>::result_type,
typename KeySignature<KeyT>::arg_tuple_type>(
m, ::absl::base_internal::FastTypeId<KeyT>());
m, ::absl::base_internal::FastTypeId<KeyT>(), ValidatorT());
}
// Acquire a mock for the KeyT (may or may not be a signature).
//
// KeyT is used to generate a typeid-based lookup for the mock.
// KeyT is a signature of the form:
// result_type(discriminator_type, std::tuple<args...>)
// The mocked function signature will be composed from KeyT as:
// result_type(args...)
template <typename KeyT, typename MockURBG>
static decltype(auto) MockFor(MockURBG& m) {
return MockFor<KeyT>(m, NoOpValidator());
}
};
......
......@@ -16,9 +16,11 @@
#ifndef ABSL_RANDOM_INTERNAL_MOCK_OVERLOAD_SET_H_
#define ABSL_RANDOM_INTERNAL_MOCK_OVERLOAD_SET_H_
#include <tuple>
#include <type_traits>
#include "gmock/gmock.h"
#include "absl/base/config.h"
#include "absl/random/internal/mock_helpers.h"
#include "absl/random/mocking_bit_gen.h"
......@@ -26,7 +28,7 @@ namespace absl {
ABSL_NAMESPACE_BEGIN
namespace random_internal {
template <typename DistrT, typename Fn>
template <typename DistrT, typename ValidatorT, typename Fn>
struct MockSingleOverload;
// MockSingleOverload
......@@ -38,8 +40,8 @@ struct MockSingleOverload;
// arguments to MockingBitGen::Register.
//
// The underlying KeyT must match the KeyT constructed by DistributionCaller.
template <typename DistrT, typename Ret, typename... Args>
struct MockSingleOverload<DistrT, Ret(MockingBitGen&, Args...)> {
template <typename DistrT, typename ValidatorT, typename Ret, typename... Args>
struct MockSingleOverload<DistrT, ValidatorT, Ret(MockingBitGen&, Args...)> {
static_assert(std::is_same<typename DistrT::result_type, Ret>::value,
"Overload signature must have return type matching the "
"distribution result_type.");
......@@ -47,15 +49,21 @@ struct MockSingleOverload<DistrT, Ret(MockingBitGen&, Args...)> {
template <typename MockURBG>
auto gmock_Call(MockURBG& gen, const ::testing::Matcher<Args>&... matchers)
-> decltype(MockHelpers::MockFor<KeyT>(gen).gmock_Call(matchers...)) {
static_assert(std::is_base_of<MockingBitGen, MockURBG>::value,
"Mocking requires an absl::MockingBitGen");
return MockHelpers::MockFor<KeyT>(gen).gmock_Call(matchers...);
-> decltype(MockHelpers::MockFor<KeyT>(gen, ValidatorT())
.gmock_Call(matchers...)) {
static_assert(
std::is_base_of<MockingBitGenImpl<true>, MockURBG>::value ||
std::is_base_of<MockingBitGenImpl<false>, MockURBG>::value,
"Mocking requires an absl::MockingBitGen");
return MockHelpers::MockFor<KeyT>(gen, ValidatorT())
.gmock_Call(matchers...);
}
};
template <typename DistrT, typename Ret, typename Arg, typename... Args>
struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> {
template <typename DistrT, typename ValidatorT, typename Ret, typename Arg,
typename... Args>
struct MockSingleOverload<DistrT, ValidatorT,
Ret(Arg, MockingBitGen&, Args...)> {
static_assert(std::is_same<typename DistrT::result_type, Ret>::value,
"Overload signature must have return type matching the "
"distribution result_type.");
......@@ -64,14 +72,44 @@ struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> {
template <typename MockURBG>
auto gmock_Call(const ::testing::Matcher<Arg>& matcher, MockURBG& gen,
const ::testing::Matcher<Args>&... matchers)
-> decltype(MockHelpers::MockFor<KeyT>(gen).gmock_Call(matcher,
matchers...)) {
static_assert(std::is_base_of<MockingBitGen, MockURBG>::value,
"Mocking requires an absl::MockingBitGen");
return MockHelpers::MockFor<KeyT>(gen).gmock_Call(matcher, matchers...);
-> decltype(MockHelpers::MockFor<KeyT>(gen, ValidatorT())
.gmock_Call(matcher, matchers...)) {
static_assert(
std::is_base_of<MockingBitGenImpl<true>, MockURBG>::value ||
std::is_base_of<MockingBitGenImpl<false>, MockURBG>::value,
"Mocking requires an absl::MockingBitGen");
return MockHelpers::MockFor<KeyT>(gen, ValidatorT())
.gmock_Call(matcher, matchers...);
}
};
// MockOverloadSetWithValidator
//
// MockOverloadSetWithValidator is a wrapper around MockOverloadSet which takes
// an additional Validator parameter, allowing for customization of the mock
// behavior.
//
// `ValidatorT::Validate(result, args...)` will be called after the mock
// distribution returns a value in `result`, allowing for validation against the
// args.
template <typename DistrT, typename ValidatorT, typename... Fns>
struct MockOverloadSetWithValidator;
template <typename DistrT, typename ValidatorT, typename Sig>
struct MockOverloadSetWithValidator<DistrT, ValidatorT, Sig>
: public MockSingleOverload<DistrT, ValidatorT, Sig> {
using MockSingleOverload<DistrT, ValidatorT, Sig>::gmock_Call;
};
template <typename DistrT, typename ValidatorT, typename FirstSig,
typename... Rest>
struct MockOverloadSetWithValidator<DistrT, ValidatorT, FirstSig, Rest...>
: public MockSingleOverload<DistrT, ValidatorT, FirstSig>,
public MockOverloadSetWithValidator<DistrT, ValidatorT, Rest...> {
using MockSingleOverload<DistrT, ValidatorT, FirstSig>::gmock_Call;
using MockOverloadSetWithValidator<DistrT, ValidatorT, Rest...>::gmock_Call;
};
// MockOverloadSet
//
// MockOverloadSet takes a distribution and a collection of signatures and
......@@ -79,20 +117,8 @@ struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> {
// `EXPECT_CALL(mock_overload_set, Call(...))` expand and do overload resolution
// correctly.
template <typename DistrT, typename... Signatures>
struct MockOverloadSet;
template <typename DistrT, typename Sig>
struct MockOverloadSet<DistrT, Sig> : public MockSingleOverload<DistrT, Sig> {
using MockSingleOverload<DistrT, Sig>::gmock_Call;
};
template <typename DistrT, typename FirstSig, typename... Rest>
struct MockOverloadSet<DistrT, FirstSig, Rest...>
: public MockSingleOverload<DistrT, FirstSig>,
public MockOverloadSet<DistrT, Rest...> {
using MockSingleOverload<DistrT, FirstSig>::gmock_Call;
using MockOverloadSet<DistrT, Rest...>::gmock_Call;
};
using MockOverloadSet =
MockOverloadSetWithValidator<DistrT, NoOpValidator, Signatures...>;
} // namespace random_internal
ABSL_NAMESPACE_END
......
// Copyright 2024 The Abseil Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_
#define ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_
#include <type_traits>
#include "absl/base/config.h"
#include "absl/base/internal/raw_logging.h"
#include "absl/random/internal/iostream_state_saver.h"
#include "absl/random/internal/uniform_helper.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
namespace absl {
ABSL_NAMESPACE_BEGIN
namespace random_internal {
template <typename NumType>
class UniformDistributionValidator {
public:
template <typename TagType>
static void Validate(NumType x, TagType tag, NumType lo, NumType hi) {
// For invalid ranges, absl::Uniform() simply returns one of the bounds.
if (x == lo && lo == hi) return;
ValidateImpl(std::is_floating_point<NumType>{}, x, tag, lo, hi);
}
static void Validate(NumType x, NumType lo, NumType hi) {
Validate(x, IntervalClosedOpenTag(), lo, hi);
}
template <typename NumType_ = NumType>
static void Validate(NumType) {
// absl::Uniform<NumType>(gen) spans the entire range of `NumType`, so any
// value is okay.
static_assert(std::is_integral<NumType_>{},
"Non-integer types may have valid values outside of the full "
"range (e.g. floating point NaN).");
}
private:
static absl::string_view TagLbBound(IntervalClosedOpenTag) { return "["; }
static absl::string_view TagLbBound(IntervalOpenOpenTag) { return "("; }
static absl::string_view TagLbBound(IntervalClosedClosedTag) { return "["; }
static absl::string_view TagLbBound(IntervalOpenClosedTag) { return "("; }
static absl::string_view TagUbBound(IntervalClosedOpenTag) { return ")"; }
static absl::string_view TagUbBound(IntervalOpenOpenTag) { return ")"; }
static absl::string_view TagUbBound(IntervalClosedClosedTag) { return "]"; }
static absl::string_view TagUbBound(IntervalOpenClosedTag) { return "]"; }
template <typename TagType>
static void ValidateImpl(std::true_type /* is_floating_point */, NumType x,
TagType tag, NumType lo, NumType hi) {
UniformDistributionWrapper<NumType> dist(tag, lo, hi);
NumType lb = dist.a();
NumType ub = dist.b();
// uniform_real_distribution is always closed-open, so the upper bound is
// always non-inclusive.
ABSL_INTERNAL_CHECK(lb <= x && x < ub,
absl::StrCat(x, " is not in ", TagLbBound(tag), lo,
", ", hi, TagUbBound(tag)));
}
template <typename TagType>
static void ValidateImpl(std::false_type /* is_floating_point */, NumType x,
TagType tag, NumType lo, NumType hi) {
using stream_type =
typename random_internal::stream_format_type<NumType>::type;
UniformDistributionWrapper<NumType> dist(tag, lo, hi);
NumType lb = dist.a();
NumType ub = dist.b();
ABSL_INTERNAL_CHECK(
lb <= x && x <= ub,
absl::StrCat(stream_type{x}, " is not in ", TagLbBound(tag),
stream_type{lo}, ", ", stream_type{hi}, TagUbBound(tag)));
}
};
} // namespace random_internal
ABSL_NAMESPACE_END
} // namespace absl
#endif // ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_
......@@ -46,16 +46,18 @@
#ifndef ABSL_RANDOM_MOCK_DISTRIBUTIONS_H_
#define ABSL_RANDOM_MOCK_DISTRIBUTIONS_H_
#include <limits>
#include <type_traits>
#include <utility>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/meta/type_traits.h"
#include "absl/base/config.h"
#include "absl/random/bernoulli_distribution.h"
#include "absl/random/beta_distribution.h"
#include "absl/random/distributions.h"
#include "absl/random/exponential_distribution.h"
#include "absl/random/gaussian_distribution.h"
#include "absl/random/internal/mock_overload_set.h"
#include "absl/random/internal/mock_validators.h"
#include "absl/random/log_uniform_int_distribution.h"
#include "absl/random/mocking_bit_gen.h"
#include "absl/random/poisson_distribution.h"
#include "absl/random/zipf_distribution.h"
namespace absl {
ABSL_NAMESPACE_BEGIN
......@@ -80,8 +82,9 @@ ABSL_NAMESPACE_BEGIN
// assert(x == 123456)
//
template <typename R>
using MockUniform = random_internal::MockOverloadSet<
using MockUniform = random_internal::MockOverloadSetWithValidator<
random_internal::UniformDistributionWrapper<R>,
random_internal::UniformDistributionValidator<R>,
R(IntervalClosedOpenTag, MockingBitGen&, R, R),
R(IntervalClosedClosedTag, MockingBitGen&, R, R),
R(IntervalOpenOpenTag, MockingBitGen&, R, R),
......
......@@ -14,7 +14,12 @@
#include "absl/random/mock_distributions.h"
#include <cmath>
#include <limits>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/random/distributions.h"
#include "absl/random/mocking_bit_gen.h"
#include "absl/random/random.h"
......@@ -69,4 +74,205 @@ TEST(MockDistributions, Examples) {
EXPECT_EQ(absl::LogUniform<int>(gen, 0, 1000000, 2), 2040);
}
TEST(MockUniform, OutOfBoundsIsAllowed) {
absl::MockingBitGen gen;
EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 100)).WillOnce(Return(0));
EXPECT_EQ(absl::Uniform<int>(gen, 1, 100), 0);
}
TEST(ValidatedMockDistributions, UniformDoubleBoundaryCases) {
absl::random_internal::MockingBitGenImpl<true> gen;
EXPECT_CALL(absl::MockUniform<double>(), Call(gen, 1.0, 10.0))
.WillOnce(Return(
std::nextafter(10.0, -std::numeric_limits<double>::infinity())));
EXPECT_EQ(absl::Uniform<double>(gen, 1.0, 10.0),
std::nextafter(10.0, -std::numeric_limits<double>::infinity()));
EXPECT_CALL(absl::MockUniform<double>(),
Call(absl::IntervalOpen, gen, 1.0, 10.0))
.WillOnce(Return(
std::nextafter(10.0, -std::numeric_limits<double>::infinity())));
EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 10.0),
std::nextafter(10.0, -std::numeric_limits<double>::infinity()));
EXPECT_CALL(absl::MockUniform<double>(),
Call(absl::IntervalOpen, gen, 1.0, 10.0))
.WillOnce(
Return(std::nextafter(1.0, std::numeric_limits<double>::infinity())));
EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 10.0),
std::nextafter(1.0, std::numeric_limits<double>::infinity()));
}
TEST(ValidatedMockDistributions, UniformDoubleEmptyRangeCases) {
absl::random_internal::MockingBitGenImpl<true> gen;
ON_CALL(absl::MockUniform<double>(), Call(absl::IntervalOpen, gen, 1.0, 1.0))
.WillByDefault(Return(1.0));
EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 1.0), 1.0);
ON_CALL(absl::MockUniform<double>(),
Call(absl::IntervalOpenClosed, gen, 1.0, 1.0))
.WillByDefault(Return(1.0));
EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpenClosed, gen, 1.0, 1.0),
1.0);
ON_CALL(absl::MockUniform<double>(),
Call(absl::IntervalClosedOpen, gen, 1.0, 1.0))
.WillByDefault(Return(1.0));
EXPECT_EQ(absl::Uniform<double>(absl::IntervalClosedOpen, gen, 1.0, 1.0),
1.0);
}
TEST(ValidatedMockDistributions, UniformIntEmptyRangeCases) {
absl::random_internal::MockingBitGenImpl<true> gen;
ON_CALL(absl::MockUniform<int>(), Call(absl::IntervalOpen, gen, 1, 1))
.WillByDefault(Return(1));
EXPECT_EQ(absl::Uniform<int>(absl::IntervalOpen, gen, 1, 1), 1);
ON_CALL(absl::MockUniform<int>(), Call(absl::IntervalOpenClosed, gen, 1, 1))
.WillByDefault(Return(1));
EXPECT_EQ(absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 1), 1);
ON_CALL(absl::MockUniform<int>(), Call(absl::IntervalClosedOpen, gen, 1, 1))
.WillByDefault(Return(1));
EXPECT_EQ(absl::Uniform<int>(absl::IntervalClosedOpen, gen, 1, 1), 1);
}
TEST(ValidatedMockUniformDeathTest, Examples) {
absl::random_internal::MockingBitGenImpl<true> gen;
EXPECT_DEATH_IF_SUPPORTED(
{
EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 100))
.WillOnce(Return(0));
absl::Uniform<int>(gen, 1, 100);
},
" 0 is not in \\[1, 100\\)");
EXPECT_DEATH_IF_SUPPORTED(
{
EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 100))
.WillOnce(Return(101));
absl::Uniform<int>(gen, 1, 100);
},
" 101 is not in \\[1, 100\\)");
EXPECT_DEATH_IF_SUPPORTED(
{
EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 100))
.WillOnce(Return(100));
absl::Uniform<int>(gen, 1, 100);
},
" 100 is not in \\[1, 100\\)");
EXPECT_DEATH_IF_SUPPORTED(
{
EXPECT_CALL(absl::MockUniform<int>(),
Call(absl::IntervalOpen, gen, 1, 100))
.WillOnce(Return(1));
absl::Uniform<int>(absl::IntervalOpen, gen, 1, 100);
},
" 1 is not in \\(1, 100\\)");
EXPECT_DEATH_IF_SUPPORTED(
{
EXPECT_CALL(absl::MockUniform<int>(),
Call(absl::IntervalOpen, gen, 1, 100))
.WillOnce(Return(101));
absl::Uniform<int>(absl::IntervalOpen, gen, 1, 100);
},
" 101 is not in \\(1, 100\\)");
EXPECT_DEATH_IF_SUPPORTED(
{
EXPECT_CALL(absl::MockUniform<int>(),
Call(absl::IntervalOpen, gen, 1, 100))
.WillOnce(Return(100));
absl::Uniform<int>(absl::IntervalOpen, gen, 1, 100);
},
" 100 is not in \\(1, 100\\)");
EXPECT_DEATH_IF_SUPPORTED(
{
EXPECT_CALL(absl::MockUniform<int>(),
Call(absl::IntervalOpenClosed, gen, 1, 100))
.WillOnce(Return(1));
absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 100);
},
" 1 is not in \\(1, 100\\]");
EXPECT_DEATH_IF_SUPPORTED(
{
EXPECT_CALL(absl::MockUniform<int>(),
Call(absl::IntervalOpenClosed, gen, 1, 100))
.WillOnce(Return(101));
absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 100);
},
" 101 is not in \\(1, 100\\]");
EXPECT_DEATH_IF_SUPPORTED(
{
EXPECT_CALL(absl::MockUniform<int>(),
Call(absl::IntervalOpenClosed, gen, 1, 100))
.WillOnce(Return(0));
absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 100);
},
" 0 is not in \\(1, 100\\]");
EXPECT_DEATH_IF_SUPPORTED(
{
EXPECT_CALL(absl::MockUniform<int>(),
Call(absl::IntervalOpenClosed, gen, 1, 100))
.WillOnce(Return(101));
absl::Uniform<int>(absl::IntervalOpenClosed, gen, 1, 100);
},
" 101 is not in \\(1, 100\\]");
EXPECT_DEATH_IF_SUPPORTED(
{
EXPECT_CALL(absl::MockUniform<int>(),
Call(absl::IntervalClosed, gen, 1, 100))
.WillOnce(Return(0));
absl::Uniform<int>(absl::IntervalClosed, gen, 1, 100);
},
" 0 is not in \\[1, 100\\]");
EXPECT_DEATH_IF_SUPPORTED(
{
EXPECT_CALL(absl::MockUniform<int>(),
Call(absl::IntervalClosed, gen, 1, 100))
.WillOnce(Return(101));
absl::Uniform<int>(absl::IntervalClosed, gen, 1, 100);
},
" 101 is not in \\[1, 100\\]");
}
TEST(ValidatedMockUniformDeathTest, DoubleBoundaryCases) {
absl::random_internal::MockingBitGenImpl<true> gen;
EXPECT_DEATH_IF_SUPPORTED(
{
EXPECT_CALL(absl::MockUniform<double>(), Call(gen, 1.0, 10.0))
.WillOnce(Return(10.0));
EXPECT_EQ(absl::Uniform<double>(gen, 1.0, 10.0), 10.0);
},
" 10 is not in \\[1, 10\\)");
EXPECT_DEATH_IF_SUPPORTED(
{
EXPECT_CALL(absl::MockUniform<double>(),
Call(absl::IntervalOpen, gen, 1.0, 10.0))
.WillOnce(Return(10.0));
EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 10.0),
10.0);
},
" 10 is not in \\(1, 10\\)");
EXPECT_DEATH_IF_SUPPORTED(
{
EXPECT_CALL(absl::MockUniform<double>(),
Call(absl::IntervalOpen, gen, 1.0, 10.0))
.WillOnce(Return(1.0));
EXPECT_EQ(absl::Uniform<double>(absl::IntervalOpen, gen, 1.0, 10.0),
1.0);
},
" 1 is not in \\(1, 10\\)");
}
} // namespace
......@@ -28,83 +28,37 @@
#ifndef ABSL_RANDOM_MOCKING_BIT_GEN_H_
#define ABSL_RANDOM_MOCKING_BIT_GEN_H_
#include <iterator>
#include <limits>
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/base/attributes.h"
#include "absl/base/config.h"
#include "absl/base/internal/fast_type_id.h"
#include "absl/container/flat_hash_map.h"
#include "absl/meta/type_traits.h"
#include "absl/random/distributions.h"
#include "absl/random/internal/distribution_caller.h"
#include "absl/random/internal/mock_helpers.h"
#include "absl/random/random.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "absl/utility/utility.h"
namespace absl {
ABSL_NAMESPACE_BEGIN
class BitGenRef;
namespace random_internal {
template <typename>
struct DistributionCaller;
class MockHelpers;
} // namespace random_internal
class BitGenRef;
// MockingBitGen
//
// `absl::MockingBitGen` is a mock Uniform Random Bit Generator (URBG) class
// which can act in place of an `absl::BitGen` URBG within tests using the
// Googletest testing framework.
//
// Usage:
//
// Use an `absl::MockingBitGen` along with a mock distribution object (within
// mock_distributions.h) inside Googletest constructs such as ON_CALL(),
// EXPECT_TRUE(), etc. to produce deterministic results conforming to the
// distribution's API contract.
//
// Example:
//
// // Mock a call to an `absl::Bernoulli` distribution using Googletest
// absl::MockingBitGen bitgen;
//
// ON_CALL(absl::MockBernoulli(), Call(bitgen, 0.5))
// .WillByDefault(testing::Return(true));
// EXPECT_TRUE(absl::Bernoulli(bitgen, 0.5));
//
// // Mock a call to an `absl::Uniform` distribution within Googletest
// absl::MockingBitGen bitgen;
//
// ON_CALL(absl::MockUniform<int>(), Call(bitgen, testing::_, testing::_))
// .WillByDefault([] (int low, int high) {
// return low + (high - low) / 2;
// });
//
// EXPECT_EQ(absl::Uniform<int>(gen, 0, 10), 5);
// EXPECT_EQ(absl::Uniform<int>(gen, 30, 40), 35);
//
// At this time, only mock distributions supplied within the Abseil random
// library are officially supported.
//
// EXPECT_CALL and ON_CALL need to be made within the same DLL component as
// the call to absl::Uniform and related methods, otherwise mocking will fail
// since the underlying implementation creates a type-specific pointer which
// will be distinct across different DLL boundaries.
//
class MockingBitGen {
// Implements MockingBitGen with an option to turn on extra validation.
template <bool EnableValidation>
class MockingBitGenImpl {
public:
MockingBitGen() = default;
~MockingBitGen() = default;
MockingBitGenImpl() = default;
~MockingBitGenImpl() = default;
// URBG interface
using result_type = absl::BitGen::result_type;
......@@ -125,15 +79,19 @@ class MockingBitGen {
// NOTE: MockFnCaller is essentially equivalent to the lambda:
// [fn](auto... args) { return fn->Call(std::move(args)...)}
// however that fails to build on some supported platforms.
template <typename MockFnType, typename ResultT, typename Tuple>
template <typename MockFnType, typename ValidatorT, typename ResultT,
typename Tuple>
struct MockFnCaller;
// specialization for std::tuple.
template <typename MockFnType, typename ResultT, typename... Args>
struct MockFnCaller<MockFnType, ResultT, std::tuple<Args...>> {
template <typename MockFnType, typename ValidatorT, typename ResultT,
typename... Args>
struct MockFnCaller<MockFnType, ValidatorT, ResultT, std::tuple<Args...>> {
MockFnType* fn;
inline ResultT operator()(Args... args) {
return fn->Call(std::move(args)...);
ResultT result = fn->Call(args...);
ValidatorT::Validate(result, args...);
return result;
}
};
......@@ -150,16 +108,17 @@ class MockingBitGen {
/*ResultT*/ void* result) = 0;
};
template <typename MockFnType, typename ResultT, typename ArgTupleT>
template <typename MockFnType, typename ValidatorT, typename ResultT,
typename ArgTupleT>
class FunctionHolderImpl final : public FunctionHolder {
public:
void Apply(void* args_tuple, void* result) override {
void Apply(void* args_tuple, void* result) final {
// Requires tuple_args to point to a ArgTupleT, which is a
// std::tuple<Args...> used to invoke the mock function. Requires result
// to point to a ResultT, which is the result of the call.
*static_cast<ResultT*>(result) =
absl::apply(MockFnCaller<MockFnType, ResultT, ArgTupleT>{&mock_fn_},
*static_cast<ArgTupleT*>(args_tuple));
*static_cast<ResultT*>(result) = absl::apply(
MockFnCaller<MockFnType, ValidatorT, ResultT, ArgTupleT>{&mock_fn_},
*static_cast<ArgTupleT*>(args_tuple));
}
MockFnType mock_fn_;
......@@ -175,26 +134,29 @@ class MockingBitGen {
//
// The returned MockFunction<...> type can be used to setup additional
// distribution parameters of the expectation.
template <typename ResultT, typename ArgTupleT, typename SelfT>
auto RegisterMock(SelfT&, base_internal::FastTypeIdType type)
template <typename ResultT, typename ArgTupleT, typename SelfT,
typename ValidatorT>
auto RegisterMock(SelfT&, base_internal::FastTypeIdType type, ValidatorT)
-> decltype(GetMockFnType(std::declval<ResultT>(),
std::declval<ArgTupleT>()))& {
using ActualValidatorT =
std::conditional_t<EnableValidation, ValidatorT, NoOpValidator>;
using MockFnType = decltype(GetMockFnType(std::declval<ResultT>(),
std::declval<ArgTupleT>()));
using WrappedFnType = absl::conditional_t<
std::is_same<SelfT, ::testing::NiceMock<absl::MockingBitGen>>::value,
std::is_same<SelfT, ::testing::NiceMock<MockingBitGenImpl>>::value,
::testing::NiceMock<MockFnType>,
absl::conditional_t<
std::is_same<SelfT,
::testing::NaggyMock<absl::MockingBitGen>>::value,
std::is_same<SelfT, ::testing::NaggyMock<MockingBitGenImpl>>::value,
::testing::NaggyMock<MockFnType>,
absl::conditional_t<
std::is_same<SelfT,
::testing::StrictMock<absl::MockingBitGen>>::value,
::testing::StrictMock<MockingBitGenImpl>>::value,
::testing::StrictMock<MockFnType>, MockFnType>>>;
using ImplT = FunctionHolderImpl<WrappedFnType, ResultT, ArgTupleT>;
using ImplT =
FunctionHolderImpl<WrappedFnType, ActualValidatorT, ResultT, ArgTupleT>;
auto& mock = mocks_[type];
if (!mock) {
mock = absl::make_unique<ImplT>();
......@@ -234,6 +196,58 @@ class MockingBitGen {
// InvokeMock
};
} // namespace random_internal
// MockingBitGen
//
// `absl::MockingBitGen` is a mock Uniform Random Bit Generator (URBG) class
// which can act in place of an `absl::BitGen` URBG within tests using the
// Googletest testing framework.
//
// Usage:
//
// Use an `absl::MockingBitGen` along with a mock distribution object (within
// mock_distributions.h) inside Googletest constructs such as ON_CALL(),
// EXPECT_TRUE(), etc. to produce deterministic results conforming to the
// distribution's API contract.
//
// Example:
//
// // Mock a call to an `absl::Bernoulli` distribution using Googletest
// absl::MockingBitGen bitgen;
//
// ON_CALL(absl::MockBernoulli(), Call(bitgen, 0.5))
// .WillByDefault(testing::Return(true));
// EXPECT_TRUE(absl::Bernoulli(bitgen, 0.5));
//
// // Mock a call to an `absl::Uniform` distribution within Googletest
// absl::MockingBitGen bitgen;
//
// ON_CALL(absl::MockUniform<int>(), Call(bitgen, testing::_, testing::_))
// .WillByDefault([] (int low, int high) {
// return low + (high - low) / 2;
// });
//
// EXPECT_EQ(absl::Uniform<int>(gen, 0, 10), 5);
// EXPECT_EQ(absl::Uniform<int>(gen, 30, 40), 35);
//
// At this time, only mock distributions supplied within the Abseil random
// library are officially supported.
//
// EXPECT_CALL and ON_CALL need to be made within the same DLL component as
// the call to absl::Uniform and related methods, otherwise mocking will fail
// since the underlying implementation creates a type-specific pointer which
// will be distinct across different DLL boundaries.
//
using MockingBitGen = random_internal::MockingBitGenImpl<false>;
// UnvalidatedMockingBitGen
//
// UnvalidatedMockingBitGen is a variant of MockingBitGen which does no extra
// validation.
using UnvalidatedMockingBitGen ABSL_DEPRECATED("Use MockingBitGen instead") =
random_internal::MockingBitGenImpl<false>;
ABSL_NAMESPACE_END
} // namespace absl
......
......@@ -16,9 +16,11 @@
#include "absl/random/mocking_bit_gen.h"
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>
#include <random>
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest-spi.h"
......@@ -246,33 +248,33 @@ TEST(WillOnce, DistinctCounters) {
absl::MockingBitGen gen;
EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 1000000))
.Times(3)
.WillRepeatedly(Return(0));
.WillRepeatedly(Return(1));
EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1000001, 2000000))
.Times(3)
.WillRepeatedly(Return(1));
EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1);
EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 0);
EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1);
EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 0);
EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1);
EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 0);
.WillRepeatedly(Return(1000001));
EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1000001);
EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 1);
EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1000001);
EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 1);
EXPECT_EQ(absl::Uniform(gen, 1000001, 2000000), 1000001);
EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 1);
}
TEST(TimesModifier, ModifierSaturatesAndExpires) {
EXPECT_NONFATAL_FAILURE(
[]() {
absl::MockingBitGen gen;
EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 1, 1000000))
EXPECT_CALL(absl::MockUniform<int>(), Call(gen, 0, 1000000))
.Times(3)
.WillRepeatedly(Return(15))
.RetiresOnSaturation();
EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 15);
EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 15);
EXPECT_EQ(absl::Uniform(gen, 1, 1000000), 15);
EXPECT_EQ(absl::Uniform(gen, 0, 1000000), 15);
EXPECT_EQ(absl::Uniform(gen, 0, 1000000), 15);
EXPECT_EQ(absl::Uniform(gen, 0, 1000000), 15);
// Times(3) has expired - Should get a different value now.
EXPECT_NE(absl::Uniform(gen, 1, 1000000), 15);
EXPECT_NE(absl::Uniform(gen, 0, 1000000), 15);
}(),
"");
}
......@@ -394,7 +396,7 @@ TEST(MockingBitGen, StrictMock_TooMany) {
EXPECT_EQ(absl::Uniform(gen, 1, 1000), 145);
EXPECT_NONFATAL_FAILURE(
[&]() { EXPECT_EQ(absl::Uniform(gen, 10, 1000), 0); }(),
[&]() { EXPECT_EQ(absl::Uniform(gen, 0, 1000), 0); }(),
"over-saturated and active");
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment