Skip to content

Commit

Permalink
Fix sample view (#288)
Browse files Browse the repository at this point in the history
Make sure to set weight equal to 1 after sampling from a particle set.

Signed-off-by: Nahuel Espinosa <nespinosa@ekumenlabs.com>
Co-authored-by: Michel Hidalgo <michel@ekumenlabs.com>
  • Loading branch information
nahueespinosa and hidmic authored Jan 15, 2024
1 parent a177517 commit da5b462
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 29 deletions.
14 changes: 3 additions & 11 deletions beluga/include/beluga/tuple_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,19 +203,11 @@ class TupleContainer<InternalContainer, std::tuple<Types...>> {
}

[[nodiscard]] constexpr auto all() const {
return std::apply(
[](auto&... containers) { //
return beluga::views::zip(containers...) | ranges::views::const_;
},
sequences_);
return std::apply([](auto&... containers) { return beluga::views::zip(containers...); }, sequences_);
}

[[nodiscard]] constexpr auto all() {
return std::apply(
[](auto&... containers) { //
return beluga::views::zip(containers...);
},
sequences_);
return std::apply([](auto&... containers) { return beluga::views::zip(containers...); }, sequences_);
}
};

Expand All @@ -235,7 +227,7 @@ class TupleVector : public TupleContainer<Vector, T> {

/// Deduction guide to construct from iterators.
template <class I, class S, typename = std::enable_if_t<ranges::input_iterator<I> && ranges::input_iterator<S>>>
TupleVector(I, S) -> TupleVector<std_tuple_decay_t<ranges::iter_value_t<I>>>;
TupleVector(I, S) -> TupleVector<decay_tuple_like_t<ranges::iter_value_t<I>>>;

} // namespace beluga

Expand Down
17 changes: 14 additions & 3 deletions beluga/include/beluga/type_traits/particle_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#define BELUGA_TYPE_TRAITS_PARTICLE_TRAITS_HPP

#include <beluga/primitives.hpp>
#include <beluga/type_traits/tuple_traits.hpp>
#include <range/v3/range/traits.hpp>

/**
Expand Down Expand Up @@ -79,15 +80,25 @@ inline constexpr bool is_particle_range_v = is_particle_v<ranges::range_value_t<
/// Returns a new particle from the given state.
/**
* \tparam Particle The particle type to be used.
* \tparam T The particle state type.
* T must be convertible to `state_t<Particle>`.
* \tparam T The particle state type. T must be convertible to `state_t<Particle>`.
* \param value The state to make the particle from.
* \return The new particle, created from the given state.
*
* The new particle will have a weight equal to 1.
*/
template <class Particle, class T = state_t<Particle>>
constexpr auto make_from_state(T value) {
static_assert(is_particle_v<Particle>);
auto particle = Particle{};
auto particle = []() {
if constexpr (is_tuple_like_v<Particle>) {
// Support for zipped ranges composed with views that don't
// propagate the tuple value type of the original range
// (ranges::views::const_).
return decay_tuple_like_t<Particle>{};
} else {
return Particle{};
}
}();
beluga::state(particle) = std::move(value);
beluga::weight(particle) = 1.0;
return particle;
Expand Down
14 changes: 7 additions & 7 deletions beluga/include/beluga/type_traits/tuple_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,20 +133,20 @@ constexpr decltype(auto) element(TupleLike&& tuple) noexcept {
return std::get<kIndex>(std::forward<TupleLike>(tuple));
}

/// Meta-function that returns a std::tuple with decayed types given a tuple-like type.
/// Meta-function that decays a tuple like type and its members.
template <class T, class = void>
struct std_tuple_decay;
struct decay_tuple_like;

/// `std_tuple_decay` specialization for tuples.
/// `decay_tuple_like` specialization for tuples.
template <template <class...> class TupleLike, class... Args>
struct std_tuple_decay<TupleLike<Args...>, std::enable_if_t<is_tuple_like_v<std::decay_t<TupleLike<Args...>>>>> {
struct decay_tuple_like<TupleLike<Args...>, std::enable_if_t<is_tuple_like_v<std::decay_t<TupleLike<Args...>>>>> {
/// Return type.
using type = std::tuple<std::decay_t<Args>...>;
using type = std::decay_t<TupleLike<std::decay_t<Args>...>>;
};

/// Convenience template type alias for `std_tuple_decay`.
/// Convenience template type alias for `decay_tuple_like`.
template <class T>
using std_tuple_decay_t = typename std_tuple_decay<T>::type;
using decay_tuple_like_t = typename decay_tuple_like<T>::type;

} // namespace beluga

Expand Down
12 changes: 11 additions & 1 deletion beluga/include/beluga/views/sample.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,19 @@ struct sample_fn {
}

/// Overload that handles particle ranges.
/**
* The new particles will all have a weight equal to 1, since, after resampling, the probability
* will be represented by the number of particles rather than their individual weight.
*/
template <
class Range,
class URNG = typename ranges::detail::default_random_engine,
std::enable_if_t<ranges::range<Range>, int> = 0,
std::enable_if_t<is_particle_range_v<Range>, int> = 0,
std::enable_if_t<!ranges::range<URNG>, int> = 0>
constexpr auto operator()(Range&& range, URNG& engine = ranges::detail::get_random_engine()) const {
return (*this)(ranges::views::all(range), beluga::views::weights(range), engine);
return (*this)(beluga::views::states(range), beluga::views::weights(range), engine) |
ranges::views::transform(beluga::make_from_state<ranges::range_value_t<Range>>);
}

/// Overload that unwraps the engine reference from a view closure.
Expand All @@ -183,6 +188,11 @@ struct sample_fn {
* To use this, the input range must model the
* [random_access_range](https://en.cppreference.com/w/cpp/ranges/random_access_range)
* and [sized_range](https://en.cppreference.com/w/cpp/ranges/sized_range) concepts.
*
* This view implements multinomial resampling for a given range of particles.
* The core idea is to draw random indices / iterators to the input particle range
* from a [multinomial distribution](https://en.wikipedia.org/wiki/Multinomial_distribution)
* parameterized after particle weights (and assumed uniform for non-weighted particle ranges).
*/
inline constexpr ranges::views::view_closure<detail::sample_fn> sample;

Expand Down
17 changes: 17 additions & 0 deletions beluga/test/beluga/test_tuple_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ TEST(TupleVectorTest, ConceptChecks) {

TEST(TupleVectorTest, TraitConsistency) {
using Container = beluga::TupleVector<std::tuple<float, int>>;
using ConstContainer = decltype(std::declval<const Container&>());
using ConstView = decltype(std::declval<const Container&>() | ranges::views::const_);
using Iterator = decltype(ranges::begin(std::declval<Container&>()));

// Expected types
Expand All @@ -177,6 +179,21 @@ TEST(TupleVectorTest, TraitConsistency) {
static_assert(std::is_same_v<
ranges::range_rvalue_reference_t<Container>, //
ranges::common_tuple<float&&, int&&>>);
static_assert(std::is_same_v<
ranges::range_value_t<ConstContainer>, //
std::tuple<float, int>>);
static_assert(std::is_same_v<
ranges::range_reference_t<ConstContainer>, //
ranges::common_tuple<const float&, const int&>>);
static_assert(std::is_same_v<
ranges::range_rvalue_reference_t<ConstContainer>, //
ranges::common_tuple<const float&&, const int&&>>);

// Expected value type of a const view would be the same as the value type of the
// adapted container (std::tuple<float, int>)... This is not the case. :(
static_assert(std::is_same_v<
ranges::range_value_t<ConstView>, //
ranges::common_tuple<const float&, const int&>>);

// Consistency
static_assert(std::is_same_v<
Expand Down
25 changes: 18 additions & 7 deletions beluga/test/beluga/views/test_sample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@

#include "beluga/views/sample.hpp"

#include <beluga/tuple_vector.hpp>

#include <range/v3/algorithm/count.hpp>
#include <range/v3/algorithm/find.hpp>
#include <range/v3/view/const.hpp>
#include <range/v3/view/take_exactly.hpp>

namespace {
Expand Down Expand Up @@ -66,9 +69,10 @@ TEST(SampleView, DiscreteDistributionSingleElement) {
}

TEST(SampleView, DiscreteDistributionSingleElementFromParticleRange) {
auto input = std::array{std::make_tuple(5, beluga::Weight(1.0))};
auto output = input | beluga::views::sample | ranges::views::take_exactly(20) | beluga::views::states;
ASSERT_EQ(ranges::count(output, 5), 20);
auto input = std::array{std::make_tuple(5, beluga::Weight(5.0))};
auto output = input | beluga::views::sample | ranges::views::take_exactly(20);
ASSERT_EQ(ranges::count(output | beluga::views::states, 5), 20);
ASSERT_EQ(ranges::count(output | beluga::views::weights, beluga::Weight(1.0)), 20);
}

TEST(SampleView, DiscreteDistributionWeightZero) {
Expand Down Expand Up @@ -122,14 +126,21 @@ TEST(SampleView, EngineArgument) {
TEST(SampleView, DiscreteDistributionProbability) {
const auto size = 100'000;

auto input = std::array{1, 2, 3, 4};
auto weights = std::array{0.3, 0.1, 0.4, 0.2};
const auto input = beluga::TupleVector<std::tuple<int, beluga::Weight>>{
std::make_tuple(1, beluga::Weight(0.3)), //
std::make_tuple(2, beluga::Weight(0.1)), //
std::make_tuple(3, beluga::Weight(0.4)), //
std::make_tuple(4, beluga::Weight(0.2))};

auto output = beluga::views::sample(input, weights) | ranges::views::take_exactly(size);
auto output = input | //
ranges::views::const_ | //
beluga::views::sample | //
ranges::views::take_exactly(size);

std::unordered_map<int, std::size_t> buckets;
for (auto value : output) {
for (auto [value, weight] : output) {
++buckets[value];
ASSERT_EQ(weight, 1.0);
}

ASSERT_EQ(ranges::size(buckets), 4);
Expand Down

0 comments on commit da5b462

Please sign in to comment.