Skip to content

Commit

Permalink
Merge pull request #14 from mrc-ide/mrc-5334
Browse files Browse the repository at this point in the history
Compute likelihood via particle filter
  • Loading branch information
weshinsley authored May 21, 2024
2 parents fb628d3 + 920dc6c commit 49624b8
Show file tree
Hide file tree
Showing 12 changed files with 518 additions and 10 deletions.
16 changes: 16 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,22 @@ dust2_cpu_sir_unfilter_run <- function(ptr, r_pars, r_initial, grouped) {
.Call(`_dust2_dust2_cpu_sir_unfilter_run`, ptr, r_pars, r_initial, grouped)
}

# Thin R-level wrapper: passes its eight arguments, unchanged and in the
# same order, to the compiled entry point
# `_dust2_dust2_cpu_sir_filter_alloc` via .Call() and returns whatever
# that routine returns.
dust2_cpu_sir_filter_alloc <- function(r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups, r_seed) {
.Call(`_dust2_dust2_cpu_sir_filter_alloc`, r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups, r_seed)
}

# Thin R-level wrapper: passes `ptr`, `r_pars` and `grouped` unchanged to
# the compiled entry point `_dust2_dust2_cpu_sir_filter_run` via .Call()
# and returns its result.
dust2_cpu_sir_filter_run <- function(ptr, r_pars, grouped) {
.Call(`_dust2_dust2_cpu_sir_filter_run`, ptr, r_pars, grouped)
}

# Thin R-level wrapper: passes `ptr` unchanged to the compiled entry
# point `_dust2_dust2_cpu_sir_filter_rng_state` via .Call() and returns
# its result.
dust2_cpu_sir_filter_rng_state <- function(ptr) {
.Call(`_dust2_dust2_cpu_sir_filter_rng_state`, ptr)
}

# Thin R-level wrapper: passes `w` and `u` unchanged to the compiled
# entry point `_dust2_test_resample_weight` via .Call() and returns its
# result.
# NOTE(review): presumably a test-only hook around the C++
# `resample_weight` helper - confirm against the C++ registration.
test_resample_weight <- function(w, u) {
.Call(`_dust2_test_resample_weight`, w, u)
}

# Thin R-level wrapper: passes its seven arguments, unchanged and in the
# same order, to the compiled entry point `_dust2_dust2_cpu_walk_alloc`
# via .Call() and returns its result.
dust2_cpu_walk_alloc <- function(r_pars, r_time, r_dt, r_n_particles, r_n_groups, r_seed, r_deterministic) {
.Call(`_dust2_dust2_cpu_walk_alloc`, r_pars, r_time, r_dt, r_n_particles, r_n_groups, r_seed, r_deterministic)
}
Expand Down
99 changes: 99 additions & 0 deletions inst/include/dust2/filter.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

#include <dust2/cpu.hpp>
#include <dust2/filter_details.hpp>
#include <mcstate/random/random.hpp>

namespace dust2 {

Expand Down Expand Up @@ -69,4 +71,101 @@ class unfilter {
std::vector<real_type> ll_step_;
};

// Particle filter built around a `dust_cpu<T>` model.
//
// For every data time point the filter advances the model, compares
// the model state against the data to obtain one log-weight per
// particle, folds those weights into a running log likelihood per
// group, and then resamples the particles within each group in
// proportion to their weights.
template <typename T>
class filter {
public:
  using real_type = typename T::real_type;
  using data_type = typename T::data_type;
  using rng_state_type = typename T::rng_state_type;
  using rng_int_type = typename rng_state_type::int_type;

  // The model being filtered; public so that callers can update
  // parameters and query dimensions directly.
  dust_cpu<T> model;

  // model_      model to filter; taken by value and moved into `model`
  // time_start  time that the model is reset to at the start of `run()`
  // time        times at which data are available
  // data        data to compare against; `run()` consumes `n_groups`
  //             consecutive entries per time point
  // seed        seed for the filter's own random number streams (one
  //             stream per group), used when resampling
  filter(dust_cpu<T> model_,
         real_type time_start,
         std::vector<real_type> time,
         std::vector<data_type> data,
         const std::vector<rng_int_type>& seed) :
    model(std::move(model_)),
    time_start_(time_start),
    time_(std::move(time)),
    data_(std::move(data)),
    n_particles_(model.n_particles()),
    n_groups_(model.n_groups()),
    rng_(n_groups_, seed, false),
    // Only one accumulated log likelihood is kept per group...
    ll_(n_groups_, 0),
    // ...but one log-weight is needed per particle at each step.
    ll_step_(n_groups_ * n_particles_, 0) {
    // TODO: duplicated with the unfilter, can be done generically though
    // it's not a lot of code.
    //
    // Convert the gaps between successive times into whole numbers of
    // model steps of size dt.
    const auto dt = model.dt();
    step_.reserve(time_.size());
    for (size_t i = 0; i < time_.size(); i++) {
      const auto t0 = i == 0 ? time_start_ : time_[i - 1];
      const auto t1 = time_[i];
      step_.push_back(static_cast<size_t>(std::round((t1 - t0) / dt)));
    }
  }

  // Run the filter over the whole data set, starting again from
  // `time_start` and the model's initial state. The resulting
  // per-group log likelihoods are read with `last_log_likelihood()`.
  void run() {
    const auto n_times = step_.size();

    model.set_time(time_start_);
    model.set_state_initial();
    std::fill(ll_.begin(), ll_.end(), 0);

    // Just store this here; later once we have state to save we can
    // probably use that vector instead.
    std::vector<size_t> index(n_particles_ * n_groups_);

    auto it_data = data_.begin();
    for (size_t i = 0; i < n_times; ++i, it_data += n_groups_) {
      model.run_steps(step_[i]);
      model.compare_data(it_data, ll_step_.begin());

      // Normalise the weights in place and accumulate this step's
      // contribution to each group's log likelihood.
      for (size_t group = 0; group < n_groups_; ++group) {
        const auto offset = group * n_particles_;
        const auto w = ll_step_.begin() + offset;
        ll_[group] += details::scale_log_weights<real_type>(n_particles_, w);
      }

      // early exit here, once enabled, setting and checking a
      // threshold log likelihood below which we're uninterested;
      // this requires some fiddling.

      // Work out which particle each slot should be replaced by; one
      // random draw is used per group. This can be parallelised across
      // groups.
      for (size_t group = 0; group < n_groups_; ++group) {
        const auto offset = group * n_particles_;
        const auto w = ll_step_.begin() + offset;
        const auto idx = index.begin() + offset;
        const auto u = mcstate::random::random_real<real_type>(rng_.state(group));
        details::resample_weight(n_particles_, w, u, idx);
      }

      model.reorder(index.begin());

      // save trajectories (perhaps)
      // save snapshots (perhaps)
    }
  }

  // Copy the log likelihood computed by the most recent `run()`, one
  // value per group, to the output iterator `it`.
  template <typename It>
  void last_log_likelihood(It it) const {
    std::copy_n(ll_.begin(), n_groups_, it);
  }

  // Export the state of the filter's own random number streams (not
  // the model's).
  auto rng_state() { // TODO: should be const, error in mcstate2
    return rng_.export_state();
  }

private:
  real_type time_start_;
  std::vector<real_type> time_;
  std::vector<size_t> step_;        // model steps to take before each data point
  std::vector<data_type> data_;
  size_t n_particles_;
  size_t n_groups_;
  mcstate::random::prng<rng_state_type> rng_;
  std::vector<real_type> ll_;       // accumulated log likelihood, one per group
  std::vector<real_type> ll_step_;  // per-particle weights for the current step
};

}
78 changes: 78 additions & 0 deletions inst/include/dust2/filter_details.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#pragma once

#include <cmath>
#include <limits>
#include <numeric>
#include <vector>

namespace dust2 {
namespace details {

// Convert `n` log-weights, starting at `w`, into weights scaled so that
// the largest is 1, and return the log of the mean unscaled weight.
//
// The weights are modified in place, with these exceptions and
// special cases:
//
// * with a single weight (n == 1) the value is returned as-is and
//   nothing is modified
// * NaN log-weights are replaced by -infinity (i.e., zero weight)
// * if no weight is larger than -infinity, then -infinity is returned
//   and the weights are left on the log scale
template <typename real_type>
real_type scale_log_weights(size_t n, typename std::vector<real_type>::iterator w) {
  if (n == 1) {
    return *w;
  }

  constexpr auto neg_infinity = -std::numeric_limits<real_type>::infinity();
  const auto w_end = w + n;

  // First pass: scrub NaN values and find the largest log-weight.
  real_type largest = neg_infinity;
  for (auto it = w; it != w_end; ++it) {
    if (std::isnan(*it)) {
      *it = neg_infinity;
    } else if (largest < *it) {
      largest = *it;
    }
  }

  if (largest == neg_infinity) {
    return largest;
  }

  // Second pass: exponentiate relative to the largest value, which
  // avoids underflow, accumulating the total as we go.
  real_type total = 0.0;
  for (auto it = w; it != w_end; ++it) {
    const real_type scaled = std::exp(*it - largest);
    *it = scaled;
    total += scaled;
  }

  return std::log(total / n) + largest;
}


// Bookkeeping for the "systematic" resample, which is one of several
// possible resampling schemes (we may allow configuration of this
// later, which will present its own challenges of course).
//
// Given `n_particles` non-negative weights starting at `w` and a single
// random number `u`, write `n_particles` indices starting at `idx`; each
// index says which particle should be copied into that position. The
// sampling points are evenly spaced, `total / n_particles` apart,
// starting from `u * total / n_particles`.
template <typename real_type>
void resample_weight(size_t n_particles,
                     typename std::vector<real_type>::const_iterator w,
                     const real_type u,
                     typename std::vector<size_t>::iterator idx) {
  const auto total = std::accumulate(w, w + n_particles,
                                     static_cast<real_type>(0));
  const auto spacing = total / n_particles;
  const auto first = total * u / n_particles;

  real_type cumulative = 0.0;
  size_t consumed = 0;
  for (size_t k = 0; k < n_particles; ++k, ++idx) {
    const real_type target = first + k * spacing;
    // Walk along the weights until their running sum reaches the
    // current sampling point.
    //
    // The second clause (consumed < n_particles) should never be
    // decisive, but it prevents an invalid read if we have a
    // pathological 'u' that is within floating point eps of 1,
    // something that has only been seen in single precision. Passing
    // 'w' as a begin/end pair would be a cleaner fix. See:
    //
    // https://github.com/reside-ic/reside-ic.github.io/pull/60/files
    // https://github.com/mrc-ide/dust/pull/238
    for (; cumulative < target && consumed < n_particles; ++consumed, ++w) {
      cumulative += *w;
    }
    *idx = consumed > 0 ? consumed - 1 : 0;
  }
}


}
}
12 changes: 3 additions & 9 deletions inst/include/dust2/r/cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ SEXP dust2_cpu_alloc(cpp11::list r_pars,
const auto time = check_time(r_time, "time");
const auto dt = check_dt(r_dt);

auto n_particles = to_size(r_n_particles, "n_particles");
auto n_groups = to_size(r_n_groups, "n_groups");
const auto n_particles = to_size(r_n_particles, "n_particles");
const auto n_groups = to_size(r_n_groups, "n_groups");

const auto shared = build_shared<T>(r_pars, n_groups);
// Later, we need one of these per thread
Expand Down Expand Up @@ -144,14 +144,8 @@ SEXP dust2_cpu_reorder(cpp11::sexp ptr, cpp11::integers r_index) {

// Return the random number state of the model behind `ptr` as an R raw
// vector. The conversion to raw bytes is delegated to the shared
// `rng_state_as_raw` helper rather than being repeated here.
template <typename T>
SEXP dust2_cpu_rng_state(cpp11::sexp ptr) {
  auto *obj = cpp11::as_cpp<cpp11::external_pointer<dust_cpu<T>>>(ptr).get();
  return rng_state_as_raw(obj->rng_state());
}

template <typename T>
Expand Down
125 changes: 124 additions & 1 deletion inst/include/dust2/r/filter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ cpp11::sexp dust2_cpu_unfilter_alloc(cpp11::list r_pars,

auto n_particles = to_size(r_n_particles, "n_particles");
auto n_groups = to_size(r_n_groups, "n_groups");
const bool grouped = n_groups > 0;
const auto grouped = n_groups > 0;
const auto time_start = check_time(r_time_start, "time_start");
const auto time = check_time_sequence<real_type>(time_start, r_time, "time");
const auto dt = check_dt(r_dt);
Expand Down Expand Up @@ -78,5 +78,128 @@ cpp11::sexp dust2_cpu_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars,
return ret;
}

// Allocate a particle filter (a `filter<T>` wrapping a freshly built
// `dust_cpu<T>` model) from R inputs.
//
// r_pars         list of parameters, passed to `build_shared<T>`; when
//                grouped, its names are returned as the group names
// r_time_start   time at which the model starts
// r_time         times at which data are available, validated against
//                `time_start` by `check_time_sequence`
// r_dt           model time step
// r_data         data to compare against, validated by `check_data`
//                against the number of times and groups
// r_n_particles  number of particles
// r_n_groups     number of groups; zero means "not grouped"
// r_seed         seed for the random number streams
//
// Returns an R list with four elements: an external pointer to the
// filter, the number of state variables in the model, whether the
// filter is grouped, and the group names (NULL when not grouped).
template <typename T>
cpp11::sexp dust2_cpu_filter_alloc(cpp11::list r_pars,
cpp11::sexp r_time_start,
cpp11::sexp r_time,
cpp11::sexp r_dt,
cpp11::list r_data,
cpp11::sexp r_n_particles,
cpp11::sexp r_n_groups,
cpp11::sexp r_seed) {
using real_type = typename T::real_type;
using rng_state_type = typename T::rng_state_type;
using rng_seed_type = std::vector<typename rng_state_type::int_type>;

const auto n_particles = to_size(r_n_particles, "n_particles");
const auto n_groups = to_size(r_n_groups, "n_groups");
const auto grouped = n_groups > 0;
const auto time_start = check_time(r_time_start, "time_start");
const auto time = check_time_sequence<real_type>(time_start, r_time, "time");
const auto dt = check_dt(r_dt);
const auto shared = build_shared<T>(r_pars, n_groups);
const auto internal = build_internal<T>(shared);
const auto data = check_data<T>(r_data, time.size(), n_groups, "data");

// The `deterministic` flag is false here and is passed on to both the
// random number generator and the model below.
// NOTE(review): the previous comment at this point said "for now
// let's go fully deterministic", which contradicts the value set
// here - confirm that `false` is what is intended for the filter.
auto seed = mcstate::random::r::as_rng_seed<rng_state_type>(r_seed);
const auto deterministic = false;

// Create all the required rng states across the filter and the
// model, in a reasonable way. We need to make this slightly easier
// to do from mcstate really. Expand the state to give all the
// state required by the filter (n_groups streams worth) and the
// model (n_groups * n_particles worth, though the last bit of
// expansion could be done by the model itself instead?)
//
// There are two ways of sorting out the state here:
//
// 1. we could take the first n_groups states for the filter and the
// remaining for the models. This has the nice property that we can
// expand the model state later if we support growing models
// (mrc-5355). However, it has the undesirable consequence that a
// filter with multiple groups will stream differently to a filter
// containing the first group only.
//
// 2. we take each block of (1+n_particles) states for each group,
// giving the first to the filter and the rest to the model. This
// means that we can change the number of groups without affecting
// the results, though we can't change the number of particles as
// easily.
//
// The code below implements option 2.
const auto n_groups_effective = grouped ? n_groups : 1;
const auto n_streams = n_groups_effective * (n_particles + 1);
const auto rng_state = mcstate::random::prng<rng_state_type>(n_streams, seed, deterministic).export_state();
// Number of integers that make up the state of a single stream.
const auto rng_len = rng_state_type::size();
rng_seed_type seed_filter;
rng_seed_type seed_model;
for (size_t i = 0; i < n_groups_effective; ++i) {
// Start of the block of (1 + n_particles) streams for group i: the
// first stream goes to the filter, the remaining n_particles to the
// model.
const auto it = rng_state.begin() + i * rng_len * (n_particles + 1);
seed_filter.insert(seed_filter.end(),
it, it + rng_len);
seed_model.insert(seed_model.end(),
it + rng_len, it + rng_len * (n_particles + 1));
}

const auto model = dust2::dust_cpu<T>(shared, internal, time_start, dt, n_particles,
seed_model, deterministic);

// Ownership of `obj` passes to the external pointer.
// NOTE(review): per the cpp11 documentation the two flags are
// `use_deleter` and `finalize_on_exit`, so the filter is deleted when
// the pointer is garbage collected but not finalised at R exit -
// confirm against the cpp11 version in use.
auto obj = new filter<T>(model, time_start, time, data, seed_filter);
cpp11::external_pointer<filter<T>> ptr(obj, true, false);

cpp11::sexp r_n_state = cpp11::as_sexp(obj->model.n_state());
cpp11::sexp r_group_names = R_NilValue;
if (grouped) {
r_group_names = r_pars.attr("names");
}
cpp11::sexp r_grouped = cpp11::as_sexp(grouped);

return cpp11::writable::list{ptr, r_n_state, r_grouped, r_group_names};
}

// Run the particle filter behind `ptr` and return the log likelihood
// as a numeric vector with one element per group.
//
// If `r_pars` is not NULL, the model's parameters are updated from it
// first (with `grouped` forwarded to `update_pars`); otherwise the
// filter is run with the parameters it already holds.
template <typename T>
cpp11::sexp dust2_cpu_filter_run(cpp11::sexp ptr, cpp11::sexp r_pars,
                                 bool grouped) {
  using filter_ptr_type = cpp11::external_pointer<filter<T>>;
  auto *obj = cpp11::as_cpp<filter_ptr_type>(ptr).get();

  const bool has_new_pars = r_pars != R_NilValue;
  if (has_new_pars) {
    update_pars(obj->model, cpp11::as_cpp<cpp11::list>(r_pars), grouped);
  }

  obj->run();

  cpp11::writable::doubles ret(obj->model.n_groups());
  obj->last_log_likelihood(REAL(ret));
  return ret;
}

// Return the complete random number state of the filter behind `ptr`
// as a single R raw vector.
//
// The filter and the model hold their streams separately. Here they
// are interleaved again so that, for each group in turn, the output
// holds the filter's stream followed by that group's n_particles
// model streams - the same layout used when the filter is allocated.
template <typename T>
cpp11::sexp dust2_cpu_filter_rng_state(cpp11::sexp ptr) {
  using rng_state_type = typename T::rng_state_type;
  using rng_int_type = typename rng_state_type::int_type;
  auto *obj = cpp11::as_cpp<cpp11::external_pointer<filter<T>>>(ptr).get();

  const auto& state_filter = obj->rng_state();
  const auto& state_model = obj->model.rng_state();
  const auto n_particles = obj->model.n_particles();
  const auto n_groups = obj->model.n_groups();

  // Sizes: integers per stream, bytes per integer, bytes per stream
  // and bytes per group (one filter stream plus n_particles model
  // streams).
  const auto n_state = rng_state_type::size();
  const auto n_bytes = sizeof(rng_int_type);
  const auto n_bytes_state = n_bytes * n_state;
  const auto n_bytes_group = n_bytes_state * (n_particles + 1);

  cpp11::writable::raws ret(n_bytes * (state_filter.size() + state_model.size()));
  auto *dest = RAW(ret);
  for (size_t i = 0; i < n_groups; ++i) {
    auto *dest_group = dest + i * n_bytes_group;
    std::memcpy(dest_group,
                state_filter.data() + i * n_state,
                n_bytes_state);
    std::memcpy(dest_group + n_bytes_state,
                state_model.data() + i * n_state * n_particles,
                n_bytes_state * n_particles);
  }

  return ret;
}

}
}
9 changes: 9 additions & 0 deletions inst/include/dust2/r/helpers.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <cstring>
#include <numeric>
#include <vector>
#include <dust2/common.hpp>
Expand Down Expand Up @@ -277,5 +278,13 @@ void set_state(dust_cpu<T>& obj, cpp11::sexp r_state, bool grouped) {
obj.set_state(REAL(r_state), recycle_particle, recycle_group);
}

// Copy the contents of `state` byte-for-byte into a new R raw vector
// of length `sizeof(T) * state.size()` and return it.
template <typename T>
SEXP rng_state_as_raw(const std::vector<T>& state) {
  const auto n_bytes = sizeof(T) * state.size();
  cpp11::writable::raws bytes(n_bytes);
  std::memcpy(RAW(bytes), state.data(), n_bytes);
  return bytes;
}

}
}
Loading

0 comments on commit 49624b8

Please sign in to comment.