Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow all slices of state to be set #104

Merged
merged 4 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: dust2
Title: Next Generation dust
Version: 0.1.19
Version: 0.1.20
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Imperial College of Science, Technology and Medicine",
Expand Down
20 changes: 10 additions & 10 deletions R/cpp11.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion R/dust.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions R/interface-likelihood.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,17 @@ dust_likelihood_run <- function(obj, pars, initial = NULL,
## so that we can access `n_groups`
index_state <- check_index(index_state, max = obj$n_state,
unique = TRUE)
if (!is.null(initial)) {
initial <- prepare_state(initial,
NULL, # index_state
NULL, # index_particle
index_group,
obj$n_state,
obj$n_particles,
obj$n_groups,
obj$preserve_particle_dimension,
obj$preserve_group_dimension)
}
obj$methods$run(obj$ptr,
initial,
save_history,
Expand Down
166 changes: 160 additions & 6 deletions R/interface.R
Original file line number Diff line number Diff line change
Expand Up @@ -226,23 +226,94 @@ dust_system_state <- function(sys, index_state = NULL, index_particle = NULL,

##' Set system state. Takes a multidimensional array (2- or 3d
##' depending on if the system is grouped or not). Dimensions of
##' length 1 will be recycled as appropriate.
##' length 1 will be recycled as appropriate. For continuous time
##' systems, we will initialise the solver immediately after setting
##' state, which may cause errors if your initial state is invalid for
##' your system. There are many ways that you can use this function
##' to set different fractions of state (a subset of states, particles
##' or parameter groups, recycling over any dimensions that are
##' missing). Please see the Examples section for usage.
##'
##' @title Set system state
##'
##' @inheritParams dust_system_state
##'
##' @param state A matrix or array of state. If ungrouped, the
##' dimension order expected is state x particle. If grouped the
##' order is state x particle x group.
##' order is state x particle x group. If you have a grouped system
##' with 1 particle and `preserve_state_dimension = FALSE` then the
##' state has size state x group. You can omit higher dimensions,
##' so if you pass a vector it will be treated as if all higher
##' dimensions are length 1 (or if you have a grouped system you
##' can provide a matrix and treat it as if the third dimension had
##' length 1). If you provide any `index_` argument then the length
##' of the corresponding state dimension must match the index
##' length.
##'
##' @param index_state An index to control which state variables we
##' set. You can use this to set a subset of state variables.
##'
##' @param index_particle An index to control which particles have
##' their state updated
##'
##' @param index_group An index to control which groups have their
##' state updated.
##'
##' @return Nothing, called for side effects only
##' @export
dust_system_set_state <- function(sys, state) {
##' @examples
##' # Consider a system with 3 particles and 1 group:
##' sir <- dust_example("sir")
##' sys <- dust_system_create(sir(), list(), n_particles = 3)
##' # The state for this system is packed as S, I, R, cases_cumul, cases_inc:
##' dust_unpack_index(sys)
##'
##' # Set all particles to the same state:
##' dust_system_set_state(sys, c(1000, 10, 0, 0, 0))
##' dust_system_state(sys)
##'
##' # We can set everything to different states by passing a vector
##' # with this shape:
##' m <- cbind(c(1000, 10, 0, 0, 0), c(999, 11, 0, 0, 0), c(998, 12, 0, 0, 0))
##' dust_system_set_state(sys, m)
##' dust_system_state(sys)
##'
##' # Or set the state for just one state:
##' dust_system_set_state(sys, 1, index_state = 4)
##' dust_system_state(sys)
##'
##' # If you want to set a different state across particles, you must
##' # provide a *matrix* (a vector always sets the same state into
##' # every particle)
##' dust_system_set_state(sys, rbind(c(1, 2, 3)), index_state = 4)
##' dust_system_state(sys)
##'
##' # This will not work as it can be ambiguous what you are
##' # trying to do:
##' #> dust_system_set_state(sys, c(1, 2, 3), index_state = 4)
##'
##' # State can be set for specific particles:
##' dust_system_set_state(sys, c(900, 100, 0, 0, 0), index_particle = 2)
##' dust_system_state(sys)
##'
##' # And you can combine 'index_particle' with 'index_state' to set
##' # small rectangles of state:
##' dust_system_set_state(sys, matrix(c(1, 2, 3, 4), 2, 2),
##' index_particle = 2:3, index_state = 4:5)
##' dust_system_state(sys)
dust_system_set_state <- function(sys, state, index_state = NULL,
index_particle = NULL, index_group = NULL) {
check_is_dust_system(sys)
## TODO: check rank etc here (mrc-5565), and support
## preserve_particle_dimension
sys$methods$set_state(sys$ptr, state, sys$preserve_group_dimension)
state <- prepare_state(state,
index_state,
index_particle,
index_group,
sys$n_state,
sys$n_particles,
sys$n_groups,
sys$preserve_particle_dimension,
sys$preserve_group_dimension)
sys$methods$set_state(sys$ptr, state)
invisible()
}

Expand Down Expand Up @@ -753,3 +824,86 @@ dust_package_env <- function(env, quiet = FALSE) {
env <- load_temporary_package(env$path, env$name, quiet)
}
}


## This does a bunch of bookkeeping to work out if we can set state
## into a system, and works out what we'll need to recycle when
## setting it.
prepare_state <- function(state,
index_state,
index_particle,
index_group,
n_state,
n_particles,
n_groups,
preserve_particle_dimension,
preserve_group_dimension,
name = deparse(substitute(state)),
call = parent.frame()) {
len_from_index <- function(n, idx, name_index = deparse(substitute(idx))) {
if (is.null(idx)) {
n
} else {
check_index(idx, n, unique = TRUE, name = name_index)
length(idx)
}
}
len_state <- len_from_index(n_state, index_state)
len_particles <- len_from_index(n_particles, index_particle)
len_groups <- len_from_index(n_groups, index_group)

stopifnot(preserve_particle_dimension || n_particles == 1)
stopifnot(preserve_group_dimension || n_groups == 1)

d <- dim2(state)
rank <- length(d)
expected <- c(state = len_state,
particle = if (preserve_particle_dimension) len_particles,
group = if (preserve_group_dimension) len_groups)
rank_expected <- length(expected)
if (rank > rank_expected) {
cli::cli_abort(
paste("Expected 'state' to be a {rank_description(rank_expected)}",
"but was given a {rank_description(rank)}"),
arg = name, call = call)
}
if (rank < rank_expected) {
d <- c(d, rep(1L, rank_expected - rank))
}

ok <- d == expected | c(FALSE, d[-1] == 1)
if (all(ok)) {
## We will access this by position from the C++ code but name it
## here for clarity.
return(list(state = state,
index_state = index_state,
index_particle = index_particle,
index_group = index_group,
recycle_particle = n_particles > 1 && d[[2]] == 1,
recycle_group = n_groups > 1 && last(d) == 1))
}

if (!ok[[1]]) {
if (rank == 1) {
msg <- "Expected '{name}' to have length {len_state}"
} else if (rank == 2) {
msg <- "Expected '{name}' to have {len_state} rows"
} else {
msg <- "Expected dimension 1 of '{name}' to be length {len_state}"
}
} else if (!ok[[2]]) {
expected_str <-
if (expected[[2]] == 1) "1" else sprintf("1 or %d", expected[[2]])
if (rank == 2) {
msg <- "Expected '{name}' to have {expected_str} columns"
} else {
msg <- "Expected dimension 2 of '{name}' to be length {expected_str}"
}
} else {
expected_str <-
if (expected[[3]] == 1) "1" else sprintf("1 or %d", expected[[3]])
msg <- "Expected dimension 3 of '{name}' to be length {expected_str}"
}

cli::cli_abort(msg, arg = name, call = call)
}
18 changes: 18 additions & 0 deletions R/util.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,21 @@ drop_last <- function(x, n = 1) {
len <- length(x)
x[if (len < n) integer(0) else seq_len(len - n)]
}


dim2 <- function(x) {
dim(x) %||% length(x)
}


rank_description <- function(rank) {
if (rank == 0) {
"scalar"
} else if (rank == 1) {
"vector"
} else if (rank == 2) {
"matrix"
} else {
sprintf("%d-dimensional array", rank)
}
}
64 changes: 36 additions & 28 deletions inst/include/dust2/continuous/system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <dust2/continuous/control.hpp>
#include <dust2/continuous/solver.hpp>
#include <dust2/errors.hpp>
#include <dust2/internals.hpp>
#include <dust2/packing.hpp>
#include <dust2/properties.hpp>
#include <dust2/tools.hpp>
Expand Down Expand Up @@ -179,35 +180,19 @@ class dust_continuous {
}

template <typename Iter>
void set_state(Iter iter, bool recycle_particle, bool recycle_group,
const std::vector<size_t>& index_group) {
void set_state(Iter iter,
const std::vector<size_t>& index_state,
const std::vector<size_t>& index_particle,
const std::vector<size_t>& index_group,
bool recycle_particle,
bool recycle_group) {
errors_.reset();
const auto offset_read_group = recycle_group ? 0 :
(n_state_ * (recycle_particle ? 1 : n_particles_));
const auto offset_read_particle = recycle_particle ? 0 : n_state_;

real_type * state_data = state_.data();
#ifdef _OPENMP
#pragma omp parallel for schedule(static) num_threads(n_threads_) collapse(2)
#endif
for (auto i : index_group) {
for (size_t j = 0; j < n_particles_; ++j) {
const auto k = n_particles_ * i + j;
const auto offset_read =
i * offset_read_group + j * offset_read_particle;
const auto offset_write = k * n_state_;
auto& internal_i = internal_[tools::thread_index() * n_groups_ + i];
real_type* y = state_data + offset_write;
std::copy_n(iter + offset_read, n_state_, y);
try {
solver_.initialise(time_, y, ode_internals_[k],
rhs_(shared_[i], internal_i));
} catch (std::exception const& e) {
errors_.capture(e, k);
}
}
}
errors_.report();
dust2::internals::set_state(state_, iter,
n_state_, n_particles_, n_groups_,
index_state, index_particle, index_group,
recycle_particle, recycle_group,
n_threads_);
initialise_solver_(index_group.empty() ? all_groups_ : index_group);
}

// iter here is an iterator to our *reordering index*, which will be
Expand Down Expand Up @@ -378,6 +363,29 @@ class dust_continuous {
T::rhs(t, y, shared, internal, dydt);
};
}

void initialise_solver_(std::vector<size_t> index_group) {
errors_.reset();
real_type * state_data = state_.data();
#ifdef _OPENMP
#pragma omp parallel for schedule(static) num_threads(n_threads_) collapse(2)
#endif
for (auto i : index_group) {
for (size_t j = 0; j < n_particles_; ++j) {
const auto k = n_particles_ * i + j;
const auto offset = k * n_state_;
auto& internal_i = internal_[tools::thread_index() * n_groups_ + i];
real_type * y = state_data + offset;
try {
solver_.initialise(time_, y, ode_internals_[k],
rhs_(shared_[i], internal_i));
} catch (std::exception const& e) {
errors_.capture(e, k);
}
}
}
errors_.report();
}
};

}
Loading
Loading