Skip to content

Commit

Permalink
Merge pull request #108 from mrc-ide/mrc-5497
Browse files Browse the repository at this point in the history
Add support for output
  • Loading branch information
weshinsley authored Oct 25, 2024
2 parents dc14152 + 27d11bf commit e5aed27
Show file tree
Hide file tree
Showing 17 changed files with 326 additions and 58 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: dust2
Title: Next Generation dust
Version: 0.2.1
Version: 0.2.2
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Imperial College of Science, Technology and Medicine",
Expand Down
2 changes: 1 addition & 1 deletion R/dust.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 10 additions & 8 deletions inst/examples/logistic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@ class logistic {
using rng_state_type = monty::random::generator<real_type>;

static dust2::packing packing_state(const shared_state& shared) {
return dust2::packing{{"x", {shared.n}}};
return dust2::packing{{"x", {shared.n}}, {"total", {}}};
}

static size_t size_output(const shared_state& shared) {
// Based on packing_state; the *last* 'n' elements here are output.
// We might change this later but this is quite convenient from a
// data layout view and has the advantage that this can be computed
// without using any reference to shared.
static size_t size_output() {
return 1;
}

Expand All @@ -51,13 +55,11 @@ class logistic {
}

static void output(real_type time,
const real_type * state,
real_type * state,
const shared_state& shared,
internal_state& internal,
real_type * output) {
// We will change this to use a delay (e.g., growth over last
// period) to give the history a good workout right away.
output[0] = std::accumulate(state, state + shared.n, 0);
internal_state& internal) {
state[shared.n] = std::accumulate(state, state + shared.n,
static_cast<real_type>(0));
}

static shared_state build_shared(cpp11::list pars) {
Expand Down
79 changes: 72 additions & 7 deletions inst/include/dust2/continuous/system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@

namespace dust2 {

namespace {

}

template <typename T>
class dust_continuous {
public:
Expand All @@ -39,6 +43,8 @@ class dust_continuous {
packing_state_(T::packing_state(shared[0])),
packing_gradient_(do_packing_gradient<T>(shared[0])),
n_state_(packing_state_.size()),
n_state_output_(do_n_state_output<T>(packing_state_)),
n_state_ode_(n_state_ - n_state_output_),
n_particles_(n_particles),
n_groups_(shared.size()),
n_particles_total_(n_particles_ * n_groups_),
Expand All @@ -47,11 +53,11 @@ class dust_continuous {
control_(control),

state_(n_state_ * n_particles_total_),
ode_internals_(n_particles_total_, n_state_),
ode_internals_(n_particles_total_, n_state_ode_),

// For reordering to work:
state_other_(n_state_ * n_particles_total_),
ode_internals_other_(n_particles_total_, n_state_),
ode_internals_other_(n_particles_total_, n_state_ode_),

shared_(shared),
internal_(internal),
Expand All @@ -61,8 +67,9 @@ class dust_continuous {
zero_every_(zero_every_vec<T>(shared_)),
errors_(n_particles_total_),
rng_(n_particles_total_, seed, deterministic),
solver_(n_state_, control_),
n_threads_(n_threads) {
solver_(n_state_ode_, control_),
n_threads_(n_threads),
output_is_current_(n_groups_) {
// TODO: above, filter rng states need adding here too, or
// somewhere at least (we might move the filter elsewhere though,
// in which case that particular bit of weirdness goes away).
Expand Down Expand Up @@ -99,6 +106,7 @@ class dust_continuous {
}
errors_.report();
time_ = time;
update_output_is_current(index_group, false);
}

template <typename mixed_time = typename dust2::properties<T>::is_mixed_time>
Expand Down Expand Up @@ -130,7 +138,7 @@ class dust_continuous {
solver_.run(t0, t1, y, zero_every_[i],
ode_internals_[k],
rhs_(shared_[i], internal_i));
std::copy_n(y, n_state_, y_other);
std::copy_n(y, n_state_ode_, y_other);
T::update(t1, dt_, y, shared_[i], internal_i, rng_state, y_other);
std::swap(y, y_other);
solver_.initialise(time_, y, ode_internals_[k],
Expand All @@ -146,6 +154,7 @@ class dust_continuous {
std::swap(state_, state_other_);
}
time_ = time;
update_output_is_current(index_group, false);
}

void run_to_time(real_type time,
Expand Down Expand Up @@ -177,6 +186,8 @@ class dust_continuous {
}
}
errors_.report();
// Assume not current, because most models would want to call output here()
update_output_is_current(index_group, false);
}

template <typename Iter>
Expand All @@ -193,6 +204,13 @@ class dust_continuous {
recycle_particle, recycle_group,
n_threads_);
initialise_solver_(index_group.empty() ? all_groups_ : index_group);
// I'm not sure what is best here; this (and to a degree
// T::initial) are the two places where we might end up with
// inconsistent output (e.g., the user has set a state that
// includes output but the output is wrong). This approach gives
// them some flexibility at least.
bool update_is_current = tools::is_trivial_index(index_state, n_state_);
update_output_is_current(index_group, update_is_current);
}

// iter here is an iterator to our *reordering index*, which will be
Expand All @@ -217,8 +235,10 @@ class dust_continuous {
for (size_t j = 0; j < n_particles_; ++j) {
const auto k_to = n_particles_ * i + j;
const auto k_from = n_particles_ * i + *(iter + k_to);
const auto n_state_copy =
output_is_current_[i] ? n_state_ : n_state_ode_;
std::copy_n(state_.begin() + k_from * n_state_,
n_state_,
n_state_copy,
state_other_.begin() + k_to * n_state_);
ode_internals_other_[k_to] = ode_internals_[k_from];
}
Expand All @@ -227,7 +247,8 @@ class dust_continuous {
std::swap(ode_internals_, ode_internals_other_);
}

auto& state() const {
const auto& state() {
update_output();
return state_;
}

Expand Down Expand Up @@ -272,6 +293,10 @@ class dust_continuous {

void set_time(real_type time) {
time_ = time;
// We should set output_is_current here but I will wait until
// updating time_ to make it vary by group. Practically the next
// thing anyone does after this is to update initial conditions so
// this is fine for now.
}

auto rng_state() const {
Expand Down Expand Up @@ -334,6 +359,8 @@ class dust_continuous {
dust2::packing packing_state_;
dust2::packing packing_gradient_;
size_t n_state_;
size_t n_state_output_;
size_t n_state_ode_;
size_t n_particles_;
size_t n_groups_;
size_t n_particles_total_;
Expand All @@ -357,6 +384,7 @@ class dust_continuous {
monty::random::prng<rng_state_type> rng_;
ode::solver<real_type> solver_;
size_t n_threads_;
std::vector<bool> output_is_current_;

static auto rhs_(const shared_state& shared, internal_state& internal) {
return [&](real_type t, const real_type* y, real_type* dydt) {
Expand Down Expand Up @@ -386,6 +414,43 @@ class dust_continuous {
}
errors_.report();
}

// Default implementation does nothing
template <typename has_output = typename dust2::properties<T>::has_output>
typename std::enable_if<!has_output::value, void>::type
update_output() {
}

template <typename has_output = typename dust2::properties<T>::has_output>
typename std::enable_if<has_output::value, void>::type
update_output() {
real_type * state_data = state_.data();
#ifdef _OPENMP
#pragma omp parallel for schedule(static) num_threads(n_threads_) collapse(2)
#endif
for (size_t i = 0; i < n_groups_; ++i) {
for (size_t j = 0; j < n_particles_; ++j) {
const auto k = n_particles_ * i + j;
const auto offset = k * n_state_;
auto& internal_i = internal_[tools::thread_index() * n_groups_ + i];
real_type * y = state_data + offset;
if (!output_is_current_[i]) {
T::output(time_, y, shared_[i], internal_i);
}
}
}
update_output_is_current({}, true);
}

void update_output_is_current(std::vector<size_t> index_group, bool value) {
if (index_group.empty()) {
std::fill(output_is_current_.begin(), output_is_current_.end(), value);
} else {
for (auto i : index_group) {
output_is_current_[i] = value;
}
}
}
};

}
13 changes: 4 additions & 9 deletions inst/include/dust2/packing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,14 @@ class packing {
size_ = std::accumulate(len_.begin(), len_.end(), 0);
}

// We can support an incremental interface easily enough like this
// if the initializer list version proves too annoying:
//
// void add(mapping_type x) {
// data_.push_back(x);
// len_.push_back(tools::prod(data_.back().second));
// size_ += len_.back();
// }

auto size() const {
return size_;
}

auto& len() const {
return len_;
}

auto& data() const {
return data_;
}
Expand Down
22 changes: 22 additions & 0 deletions inst/include/dust2/properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ struct test_has_rhs: std::false_type {};
template <class T>
struct test_has_rhs<T, std::void_t<decltype(T::rhs)>>: std::true_type {};

template <class T, class = void>
struct test_has_output: std::false_type {};
template <class T>
struct test_has_output<T, std::void_t<decltype(T::output)>>: std::true_type {};

}

template<typename T>
Expand All @@ -41,6 +46,7 @@ struct properties {
using has_packing_gradient = internals::test_has_packing_gradient<T>;
using has_zero_every = internals::test_has_zero_every<T>;
using is_mixed_time = typename std::conditional<internals::test_has_rhs<T>::value && internals::test_has_update<T>::value, std::true_type, std::false_type>::type;
using has_output = typename std::conditional<internals::test_has_rhs<T>::value && internals::test_has_output<T>::value, std::true_type, std::false_type>::type;
};

// wrappers around some uses of member functions that may or may not
Expand All @@ -65,6 +71,19 @@ dust2::packing do_packing_gradient(const typename T::shared_state &shared) {
return dust2::packing{};
}


template <typename T, typename std::enable_if<properties<T>::has_output::value, T>::type* = nullptr>
size_t do_n_state_output(const dust2::packing& packing) {
return std::accumulate(packing.len().begin() + T::size_output(),
packing.len().end(),
0);
}

template <typename T, typename std::enable_if<!properties<T>::has_output::value, T>::type* = nullptr>
size_t do_n_state_output(const dust2::packing& packing) {
return 0;
}

template <typename T, typename std::enable_if<properties<T>::has_zero_every::value, T>::type* = nullptr>
auto zero_every_vec(const std::vector<typename T::shared_state>& shared) {
using real_type = typename T::real_type;
Expand All @@ -82,4 +101,7 @@ auto zero_every_vec(const std::vector<typename T::shared_state>& shared) {
return std::vector<zero_every_type<real_type>>(shared.size(), dust2::zero_every_type<real_type>());
}




}
2 changes: 1 addition & 1 deletion inst/include/dust2/tools.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ T prod(const std::vector<T>& x) {
}

inline bool is_trivial_index(const std::vector<size_t>& index, size_t n) {
if (index.empty() == 0) {
if (index.empty()) {
return true;
}
if (index.size() != n) {
Expand Down
4 changes: 2 additions & 2 deletions inst/include/dust2/trajectories.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class trajectories {
const std::vector<size_t>& index_group) {
index_state_ = index_state;
index_group_ = index_group;
use_index_state_ = tools::is_trivial_index(index_state_, n_state_total_);
use_index_group_ = tools::is_trivial_index(index_group_, n_groups_total_);
use_index_state_ = !tools::is_trivial_index(index_state_, n_state_total_);
use_index_group_ = !tools::is_trivial_index(index_group_, n_groups_total_);
n_state_ = use_index_state_ ? index_state.size() : n_state_total_;
n_groups_ = use_index_group_ ? index_group.size() : n_groups_total_;
len_state_ = n_state_ * n_particles_ * n_groups_;
Expand Down
20 changes: 11 additions & 9 deletions src/logistic.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion src/malaria.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion src/sir.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion src/sirode.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion src/walk.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit e5aed27

Please sign in to comment.