Skip to content

Commit

Permalink
Simplify projections and custom projection
Browse files Browse the repository at this point in the history
  • Loading branch information
kataklinger committed Mar 12, 2024
1 parent 4a9ad74 commit ecb04d9
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 179 deletions.
18 changes: 10 additions & 8 deletions sample/simple/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ void nsga() {
return gal::euclidean_distance(c1, c2);
}})
.prune(prune::none{})
.project(project::factory<project::scale, int_rank_t>{})
.project(project::scale<int_rank_t>{})
.select(select::roulette_scaled{select::unique<4>, rng})
.couple(couple::parametrize<couple::exclusive, 0.8_fc, 0.2_fc, true>(rng))
.replace(replace::total{})
Expand Down Expand Up @@ -202,7 +202,7 @@ void nsga_ii() {
.cluster(cluster::none{})
.crowd(crowd::distance{})
.prune(prune::global_worst<int_rank_t>{})
.project(project::factory<project::merge, int_rank_t>{})
.project(project::merge<int_rank_t>{})
.select(
select::tournament_scaled{select::unique<2>, select::rounds<2>, rng})
.couple(couple::parametrize<couple::exclusive, 0.8_fc, 0.2_fc, true>(rng))
Expand Down Expand Up @@ -239,7 +239,9 @@ void spea() {
.cluster(cluster::linkage{})
.crowd(crowd::none{})
.prune(prune::cluster_edge{})
.project(project::factory<project::truncate, real_rank_t>{})
.project(project::custom{[](auto const& ind) {
return 1. + gal::get_tag<real_rank_t>(ind).get();
}})
.select(select::roulette_scaled{select::unique<4>, rng})
.couple(couple::parametrize<couple::exclusive, 0.8_fc, 0.2_fc, true>(rng))
.replace(replace::append{})
Expand Down Expand Up @@ -275,7 +277,7 @@ void spea_ii() {
.cluster(cluster::none{})
.crowd(crowd::neighbor{})
.prune(prune::global_worst<int_rank_t>{})
.project(project::factory<project::translate, int_rank_t>{})
.project(project::translate<int_rank_t>{})
.select(select::roulette_scaled{select::unique<4>, rng})
.couple(couple::parametrize<couple::exclusive, 0.8_fc, 0.2_fc, true>(rng))
.replace(replace::append{})
Expand Down Expand Up @@ -311,7 +313,7 @@ void rdga() {
.cluster(cluster::adaptive_hypergrid<10, 10>{})
.crowd(crowd::cluster{})
.prune(prune::none{})
.project(project::factory<project::alternate, int_rank_t>{})
.project(project::alternate<int_rank_t>{})
.select(select::roulette_scaled{select::unique<4>, rng})
.couple(couple::parametrize<couple::exclusive, 0.8_fc, 0.2_fc, true>(rng))
.replace(replace::nondominating_parents_raw{})
Expand Down Expand Up @@ -351,7 +353,7 @@ void pesa() {
.cluster(cluster::hypergrid<std::array<double, 2>, 0.1_dc, 0.1_dc>{})
.crowd(crowd::cluster{})
.prune(prune::cluster_random{rng})
.project(project::factory<project::truncate, crowd_density_t>{})
.project(project::truncate<crowd_density_t>{})
.select(select::roulette_scaled{select::unique<4>, rng})
.couple(couple::parametrize<couple::exclusive, 0.8_fc, 0.2_fc, true>(rng))
.replace(replace::append{})
Expand Down Expand Up @@ -387,7 +389,7 @@ void pesa_ii() {
.cluster(cluster::hypergrid<std::array<double, 2>, 0.1_dc, 0.1_dc>{})
.crowd(crowd::none{})
.prune(prune::cluster_random{rng})
.project(project::factory<project::none>{})
.project(project::none{})
.select(select::cluster{
gal::select::shared<cluster_label>, select::unique<4>, rng})
.couple(couple::parametrize<couple::exclusive, 0.8_fc, 0.2_fc, true>(rng))
Expand Down Expand Up @@ -429,7 +431,7 @@ void paes() {
.cluster(cluster::adaptive_hypergrid<10, 10>{})
.crowd(crowd::cluster{})
.prune(prune::cluster_random{rng})
.project(project::factory<project::truncate, crowd_density_t>{})
.project(project::truncate<crowd_density_t>{})
.select(select::lineal_scaled{})
.couple(couple::parametrize<couple::exclusive, 0.8_fc, 0.2_fc, true>(rng))
.replace(replace::append{})
Expand Down
19 changes: 9 additions & 10 deletions src/inc/configuration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,22 +634,22 @@ struct select_ptype : public details::ptype_base<Built, select_ptype> {
}
};

template<typename Factory, typename Context>
template<typename Projection>
class project_body {
public:
using projection_t = operation_factory_result_t<Factory, Context>;
using projection_t = Projection;

public:
inline constexpr explicit project_body(Factory const& projection)
inline constexpr explicit project_body(projection_t const& projection)
: projection_{projection} {
}

inline auto projection(Context& context) const {
return projection_(context);
inline auto const& projection() const {
return projection_;
}

private:
Factory projection_;
projection_t projection_;
};

template<typename Built>
Expand All @@ -662,10 +662,9 @@ struct project_ptype : public details::ptype_base<Built, project_ptype> {
: details::ptype_base<Built, project_ptype>{current} {
}

template<
projection_factory<population_context_t, pareto_preservance_t> Factory>
inline constexpr auto project(Factory const& projection) const {
return this->next(project_body<Factory, population_context_t>{projection});
template<projection<population_context_t, pareto_preservance_t> Projection>
inline constexpr auto project(Projection const& projection) const {
return this->next(project_body<Projection>{projection});
}
};

Expand Down
13 changes: 5 additions & 8 deletions src/inc/moo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,15 @@ concept algo_config = basic_algo_config<Config> && requires(Config c) {
typename Config::pareto_preservance_t>;
requires pruning<typename Config::pruning_t, typename Config::population_t>;
requires projection<typename Config::projection_t,
typename Config::population_t,
typename Config::population_context_t,
typename Config::pareto_preservance_t>;

{ c.ranking() } -> std::convertible_to<typename Config::ranking_t>;
{ c.elitism() } -> std::convertible_to<typename Config::elitism_t>;
{ c.clustering() } -> std::convertible_to<typename Config::clustering_t>;
{ c.crowding() } -> std::convertible_to<typename Config::crowding_t>;
{ c.pruning() } -> std::convertible_to<typename Config::pruning_t>;

{
c.projection(std::declval<typename Config::population_context_t&>())
} -> std::convertible_to<typename Config::projection_t>;
{ c.projection() } -> std::convertible_to<typename Config::projection_t>;
};

template<algo_config Config>
Expand Down Expand Up @@ -120,7 +117,6 @@ class algo {
config_.evaluator()};

auto coupler = config_.coupling(reproduction);
auto projector = config_.projection(ctx);

auto* statistics = &init();
while (!token.stop_requested() &&
Expand All @@ -139,7 +135,7 @@ class algo {

prune(config_.pruning(), *statistics);

project(projector, fronts, clusters, *statistics);
project(config_.projection(), ctx, fronts, clusters, *statistics);

auto selected = select(*statistics);
stats::count_range(*statistics, selection_count_tag, selected);
Expand Down Expand Up @@ -211,11 +207,12 @@ class algo {

template<typename Projection>
inline void project(Projection const& operation,
population_context_t& ctx,
pareto_t& sets,
cluster_set& clusters,
statistics_t& current) {
[[maybe_unused]] auto timer = stats::start_timer(current, project_time_tag);
return std::invoke(operation, sets, clusters);
return std::invoke(operation, ctx, sets, clusters);
}

inline auto select(statistics_t& current) {
Expand Down
14 changes: 4 additions & 10 deletions src/inc/multiobjective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,18 +421,12 @@ namespace details {
template<typename Population, typename Pruning>
struct pruning_traits : details::pruning_helper<Population, Pruning> {};

template<typename Operation, typename Population, typename Preserved>
template<typename Operation, typename Context, typename Preserved>
concept projection = std::invocable<
Operation,
std::add_lvalue_reference_t<population_pareto_t<Population, Preserved>>,
std::add_lvalue_reference_t<Context>,
std::add_lvalue_reference_t<
population_pareto_t<typename Context::population_t, Preserved>>,
cluster_set const&>;

template<typename Factory, typename Context, typename Preserved>
concept projection_factory =
std::is_invocable_v<std::add_const_t<Factory>,
std::add_lvalue_reference_t<Context>> &&
projection<operation_factory_result_t<Factory, Context>,
typename Context::population_t,
Preserved>;

} // namespace gal
Loading

0 comments on commit ecb04d9

Please sign in to comment.