diff --git a/beluga/include/beluga/actions/normalize.hpp b/beluga/include/beluga/actions/normalize.hpp index f47a7ee9d..14c54a52a 100644 --- a/beluga/include/beluga/actions/normalize.hpp +++ b/beluga/include/beluga/actions/normalize.hpp @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -34,120 +35,106 @@ namespace beluga::actions { namespace detail { -/// Implementation detail for a normalize range adaptor object. -struct normalize_base_fn { - /// Overload that implements the normalize algorithm. - /** - * \tparam ExecutionPolicy An [execution policy](https://en.cppreference.com/w/cpp/algorithm/execution_policy_tag_t). - * \tparam Range An [input range](https://en.cppreference.com/w/cpp/ranges/input_range). - * \param policy The execution policy to use. - * \param range An existing range to apply this action to. - * \param factor The normalization factor. - */ - template < - class ExecutionPolicy, - class Range, - std::enable_if_t>, int> = 0, - std::enable_if_t, int> = 0> - constexpr auto operator()(ExecutionPolicy&& policy, Range& range, double factor) const -> Range& { - if (std::abs(factor - 1.0) < std::numeric_limits::epsilon()) { - return range; // No change. - } +/// \cond detail + +template +struct normalize_closure { + public: + static_assert(std::is_execution_policy_v); + + constexpr normalize_closure() noexcept : policy_{std::execution::seq} {} + + constexpr explicit normalize_closure(ExecutionPolicy policy) : policy_{std::move(policy)} {} + + constexpr explicit normalize_closure(double factor) noexcept : policy_{std::execution::seq}, factor_{factor} {} - auto weights = [&range]() { + constexpr normalize_closure(ExecutionPolicy policy, double factor) : policy_{std::move(policy)}, factor_{factor} {} + + template + constexpr auto operator()(Range& range) const -> Range& { + static_assert(ranges::forward_range); + + auto weights = std::invoke([&range]() { if constexpr (beluga::is_particle_range_v) { return range | beluga::views::weights | ranges::views::common; } else { return range | ranges::views::common; } - }(); + }); + + const double factor = std::invoke([this, weights]() { + if (factor_.has_value()) { + return factor_.value(); + } + + return ranges::accumulate(weights, 0.0); // The default normalization factor is the total sum of weights. + }); + + if (std::abs(factor - 1.0) < std::numeric_limits::epsilon()) { + return range; // No change. + } std::transform( - policy, // - std::begin(weights), // - std::end(weights), // - std::begin(weights), // + policy_, // + weights.begin(), // + weights.end(), // + weights.begin(), // [factor](const auto w) { return w / factor; }); + return range; } - /// Overload that uses a default normalization factor. - /** - * The default normalization factor is the total sum of weights. - */ + private: + ExecutionPolicy policy_{}; + std::optional factor_; +}; + +struct normalize_fn { template < class ExecutionPolicy, class Range, std::enable_if_t>, int> = 0, std::enable_if_t, int> = 0> - constexpr auto operator()(ExecutionPolicy&& policy, Range& range) const -> Range& { - auto weights = [&range]() { - if constexpr (beluga::is_particle_range_v) { - return range | beluga::views::weights | ranges::views::common; - } else { - return range | ranges::views::common; - } - }(); - - const double total_weight = ranges::accumulate(weights, 0.0); - return (*this)(std::forward(policy), range, total_weight); + constexpr auto operator()(ExecutionPolicy&& policy, Range& range, double factor) const -> Range& { + return normalize_closure{std::forward(policy), factor}(range); } - /// Overload that re-orders arguments from an action closure. template < - class Range, class ExecutionPolicy, - std::enable_if_t, int> = 0, - std::enable_if_t, int> = 0> - constexpr auto operator()(Range&& range, double factor, ExecutionPolicy policy) const -> Range& { - return (*this)(std::move(policy), std::forward(range), factor); - } - - /// Overload that re-orders arguments from an action closure. - template < class Range, - class ExecutionPolicy, - std::enable_if_t, int> = 0, - std::enable_if_t, int> = 0> - constexpr auto operator()(Range&& range, ExecutionPolicy policy) const -> Range& { - return (*this)(std::move(policy), std::forward(range)); - } - - /// Overload that returns an action closure to compose with other actions. - template , int> = 0> - constexpr auto operator()(ExecutionPolicy policy, double factor) const { - return ranges::make_action_closure(ranges::bind_back(normalize_base_fn{}, factor, std::move(policy))); + std::enable_if_t>, int> = 0, + std::enable_if_t, int> = 0> + constexpr auto operator()(ExecutionPolicy&& policy, Range& range) const -> Range& { + return normalize_closure{std::forward(policy)}(range); } - /// Overload that returns an action closure to compose with other actions. - template , int> = 0> - constexpr auto operator()(ExecutionPolicy policy) const { - return ranges::make_action_closure(ranges::bind_back(normalize_base_fn{}, std::move(policy))); + template , int> = 0> + constexpr auto operator()(Range& range, double factor) const -> Range& { + return normalize_closure{factor}(range); } -}; - -/// Implementation detail for a normalize range adaptor object with a default execution policy. -struct normalize_fn : public normalize_base_fn { - using normalize_base_fn::operator(); - /// Overload that defines a default execution policy. template , int> = 0> - constexpr auto operator()(Range&& range, double factor) const -> Range& { - return (*this)(std::execution::seq, std::forward(range), factor); + constexpr auto operator()(Range& range) const -> Range& { + return normalize_closure{}(range); } - /// Overload that defines a default execution policy. - template , int> = 0> - constexpr auto operator()(Range&& range) const -> Range& { - return (*this)(std::execution::seq, std::forward(range)); + template >, int> = 0> + constexpr auto operator()(ExecutionPolicy&& policy, double factor) const { + return ranges::actions::action_closure{normalize_closure{std::forward(policy), factor}}; } - /// Overload that returns an action closure to compose with other actions. - constexpr auto operator()(double factor) const { - return ranges::make_action_closure(ranges::bind_back(normalize_fn{}, factor)); + template >, int> = 0> + constexpr auto operator()(ExecutionPolicy&& policy) const { + return ranges::actions::action_closure{normalize_closure{std::forward(policy)}}; } + + constexpr auto operator()(double factor) const { return ranges::actions::action_closure{normalize_closure{factor}}; } + + constexpr auto operator()() const { return ranges::actions::action_closure{normalize_closure{}}; } }; +/// \endcond + } // namespace detail /// [Range adaptor object](https://en.cppreference.com/w/cpp/named_req/RangeAdaptorObject) that diff --git a/beluga/test/beluga/actions/test_normalize.cpp b/beluga/test/beluga/actions/test_normalize.cpp index 1e529553c..49525db4b 100644 --- a/beluga/test/beluga/actions/test_normalize.cpp +++ b/beluga/test/beluga/actions/test_normalize.cpp @@ -87,6 +87,12 @@ TEST(NormalizeAction, ZeroFactor) { ASSERT_TRUE(std::isinf(beluga::weight(input.front()))); } +TEST(NormalizeAction, OneFactor) { + auto input = std::vector{std::make_tuple(5, beluga::Weight(4.0))}; + input |= beluga::actions::normalize(1.0); + ASSERT_EQ(input.front(), std::make_tuple(5, beluga::Weight(4.0))); +} + TEST(NormalizeAction, NegativeFactor) { auto input = std::vector{std::make_tuple(5, beluga::Weight(4.0))}; input |= beluga::actions::normalize(-2.0);