Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite normalize action #441

Merged
merged 2 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 67 additions & 80 deletions beluga/include/beluga/actions/normalize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <algorithm>
#include <execution>
#include <optional>

#include <range/v3/action/action.hpp>
#include <range/v3/numeric/accumulate.hpp>
Expand All @@ -34,120 +35,106 @@ namespace beluga::actions {

namespace detail {

/// Implementation detail for a normalize range adaptor object.
struct normalize_base_fn {
/// Overload that implements the normalize algorithm.
/**
* \tparam ExecutionPolicy An [execution policy](https://en.cppreference.com/w/cpp/algorithm/execution_policy_tag_t).
* \tparam Range An [input range](https://en.cppreference.com/w/cpp/ranges/input_range).
* \param policy The execution policy to use.
* \param range An existing range to apply this action to.
* \param factor The normalization factor.
*/
template <
class ExecutionPolicy,
class Range,
std::enable_if_t<std::is_execution_policy_v<std::decay_t<ExecutionPolicy>>, int> = 0,
std::enable_if_t<ranges::range<Range>, int> = 0>
constexpr auto operator()(ExecutionPolicy&& policy, Range& range, double factor) const -> Range& {
if (std::abs(factor - 1.0) < std::numeric_limits<double>::epsilon()) {
return range; // No change.
}
/// \cond detail

template <class ExecutionPolicy = std::execution::sequenced_policy>
struct normalize_closure {
public:
static_assert(std::is_execution_policy_v<ExecutionPolicy>);

constexpr normalize_closure() noexcept : policy_{std::execution::seq} {}

constexpr explicit normalize_closure(ExecutionPolicy policy) : policy_{std::move(policy)} {}

constexpr explicit normalize_closure(double factor) noexcept : policy_{std::execution::seq}, factor_{factor} {}

auto weights = [&range]() {
constexpr normalize_closure(ExecutionPolicy policy, double factor) : policy_{std::move(policy)}, factor_{factor} {}

template <class Range>
constexpr auto operator()(Range& range) const -> Range& {
static_assert(ranges::forward_range<Range>);

auto weights = std::invoke([&range]() {
if constexpr (beluga::is_particle_range_v<Range>) {
return range | beluga::views::weights | ranges::views::common;
} else {
return range | ranges::views::common;
}
}();
});

const double factor = std::invoke([this, weights]() {
if (factor_.has_value()) {
return factor_.value();
}

return ranges::accumulate(weights, 0.0); // The default normalization factor is the total sum of weights.
});

if (std::abs(factor - 1.0) < std::numeric_limits<double>::epsilon()) {
return range; // No change.
}

std::transform(
policy, //
std::begin(weights), //
std::end(weights), //
std::begin(weights), //
policy_, //
weights.begin(), //
weights.end(), //
weights.begin(), //
[factor](const auto w) { return w / factor; });

return range;
}

/// Overload that uses a default normalization factor.
/**
* The default normalization factor is the total sum of weights.
*/
private:
ExecutionPolicy policy_{};
std::optional<double> factor_;
};

struct normalize_fn {
template <
class ExecutionPolicy,
class Range,
std::enable_if_t<std::is_execution_policy_v<std::decay_t<ExecutionPolicy>>, int> = 0,
std::enable_if_t<ranges::range<Range>, int> = 0>
constexpr auto operator()(ExecutionPolicy&& policy, Range& range) const -> Range& {
auto weights = [&range]() {
if constexpr (beluga::is_particle_range_v<Range>) {
return range | beluga::views::weights | ranges::views::common;
} else {
return range | ranges::views::common;
}
}();

const double total_weight = ranges::accumulate(weights, 0.0);
return (*this)(std::forward<ExecutionPolicy>(policy), range, total_weight);
constexpr auto operator()(ExecutionPolicy&& policy, Range& range, double factor) const -> Range& {
return normalize_closure{std::forward<ExecutionPolicy>(policy), factor}(range);
}

/// Overload that re-orders arguments from an action closure.
template <
class Range,
class ExecutionPolicy,
std::enable_if_t<ranges::range<Range>, int> = 0,
std::enable_if_t<std::is_execution_policy_v<ExecutionPolicy>, int> = 0>
constexpr auto operator()(Range&& range, double factor, ExecutionPolicy policy) const -> Range& {
return (*this)(std::move(policy), std::forward<Range>(range), factor);
}

/// Overload that re-orders arguments from an action closure.
template <
class Range,
class ExecutionPolicy,
std::enable_if_t<ranges::range<Range>, int> = 0,
std::enable_if_t<std::is_execution_policy_v<ExecutionPolicy>, int> = 0>
constexpr auto operator()(Range&& range, ExecutionPolicy policy) const -> Range& {
return (*this)(std::move(policy), std::forward<Range>(range));
}

/// Overload that returns an action closure to compose with other actions.
template <class ExecutionPolicy, std::enable_if_t<std::is_execution_policy_v<ExecutionPolicy>, int> = 0>
constexpr auto operator()(ExecutionPolicy policy, double factor) const {
return ranges::make_action_closure(ranges::bind_back(normalize_base_fn{}, factor, std::move(policy)));
std::enable_if_t<std::is_execution_policy_v<std::decay_t<ExecutionPolicy>>, int> = 0,
std::enable_if_t<ranges::range<Range>, int> = 0>
constexpr auto operator()(ExecutionPolicy&& policy, Range& range) const -> Range& {
return normalize_closure{std::forward<ExecutionPolicy>(policy)}(range);
}

/// Overload that returns an action closure to compose with other actions.
template <class ExecutionPolicy, std::enable_if_t<std::is_execution_policy_v<ExecutionPolicy>, int> = 0>
constexpr auto operator()(ExecutionPolicy policy) const {
return ranges::make_action_closure(ranges::bind_back(normalize_base_fn{}, std::move(policy)));
template <class Range, std::enable_if_t<ranges::range<Range>, int> = 0>
constexpr auto operator()(Range& range, double factor) const -> Range& {
return normalize_closure{factor}(range);
}
};

/// Implementation detail for a normalize range adaptor object with a default execution policy.
struct normalize_fn : public normalize_base_fn {
using normalize_base_fn::operator();

/// Overload that defines a default execution policy.
template <class Range, std::enable_if_t<ranges::range<Range>, int> = 0>
constexpr auto operator()(Range&& range, double factor) const -> Range& {
return (*this)(std::execution::seq, std::forward<Range>(range), factor);
constexpr auto operator()(Range& range) const -> Range& {
return normalize_closure{}(range);
}

/// Overload that defines a default execution policy.
template <class Range, std::enable_if_t<ranges::range<Range>, int> = 0>
constexpr auto operator()(Range&& range) const -> Range& {
return (*this)(std::execution::seq, std::forward<Range>(range));
template <class ExecutionPolicy, std::enable_if_t<std::is_execution_policy_v<std::decay_t<ExecutionPolicy>>, int> = 0>
constexpr auto operator()(ExecutionPolicy&& policy, double factor) const {
return ranges::actions::action_closure{normalize_closure{std::forward<ExecutionPolicy>(policy), factor}};
}

/// Overload that returns an action closure to compose with other actions.
constexpr auto operator()(double factor) const {
return ranges::make_action_closure(ranges::bind_back(normalize_fn{}, factor));
template <class ExecutionPolicy, std::enable_if_t<std::is_execution_policy_v<std::decay_t<ExecutionPolicy>>, int> = 0>
constexpr auto operator()(ExecutionPolicy&& policy) const {
return ranges::actions::action_closure{normalize_closure{std::forward<ExecutionPolicy>(policy)}};
}

constexpr auto operator()(double factor) const { return ranges::actions::action_closure{normalize_closure{factor}}; }

constexpr auto operator()() const { return ranges::actions::action_closure{normalize_closure{}}; }
};

/// \endcond

} // namespace detail

/// [Range adaptor object](https://en.cppreference.com/w/cpp/named_req/RangeAdaptorObject) that
Expand Down
6 changes: 6 additions & 0 deletions beluga/test/beluga/actions/test_normalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ TEST(NormalizeAction, ZeroFactor) {
ASSERT_TRUE(std::isinf(beluga::weight(input.front())));
}

TEST(NormalizeAction, OneFactor) {
auto input = std::vector{std::make_tuple(5, beluga::Weight(4.0))};
input |= beluga::actions::normalize(1.0);
ASSERT_EQ(input.front(), std::make_tuple(5, beluga::Weight(4.0)));
}

TEST(NormalizeAction, NegativeFactor) {
auto input = std::vector{std::make_tuple(5, beluga::Weight(4.0))};
input |= beluga::actions::normalize(-2.0);
Expand Down
Loading