Skip to content

Commit

Permalink
Allow all slices of state to be set
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Oct 21, 2024
1 parent e4d1e54 commit dfa1e41
Show file tree
Hide file tree
Showing 25 changed files with 738 additions and 156 deletions.
20 changes: 10 additions & 10 deletions R/cpp11.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion R/dust.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions R/interface-likelihood.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,17 @@ dust_likelihood_run <- function(obj, pars, initial = NULL,
## so that we can access `n_groups`
index_state <- check_index(index_state, max = obj$n_state,
unique = TRUE)
if (!is.null(initial)) {
initial <- prepare_state(initial,
NULL, # index_state
NULL, # index_particle
index_group,
obj$n_state,
obj$n_particles,
obj$n_groups,
obj$preserve_particle_dimension,
obj$preserve_group_dimension)
}
obj$methods$run(obj$ptr,
initial,
save_history,
Expand Down
166 changes: 160 additions & 6 deletions R/interface.R
Original file line number Diff line number Diff line change
Expand Up @@ -226,23 +226,94 @@ dust_system_state <- function(sys, index_state = NULL, index_particle = NULL,

##' Set system state. Takes a multidimensional array (2- or 3d
##' depending on if the system is grouped or not). Dimensions of
##' length 1 will be recycled as appropriate.
##' length 1 will be recycled as appropriate. For continuous time
##' systems, we will initialise the solver immediately after setting
##' state, which may cause errors if your initial state is invalid for
##' your system. There are many ways that you can use this function
##' to set different fractions of state (a subset of states, particles
##' or parameter groups, recycling over any dimensions that are
##' missing). Please see the Examples section for usage.
##'
##' @title Set system state
##'
##' @inheritParams dust_system_state
##'
##' @param state A matrix or array of state. If ungrouped, the
##' dimension order expected is state x particle. If grouped the
##' order is state x particle x group.
##' order is state x particle x group. If you have a grouped system
##' with 1 particle and `preserve_state_dimension = FALSE` then the
##' state has size state x group. You can omit higher dimensions,
##' so if you pass a vector it will be treated as if all higher
##' dimensions are length 1 (oif if you have a grouped system you
##' can provide a matrix and treat it as if the third dimension had
##' length 1). If you provide any `index_` argument then the length
##' of the corresponding state dimension must match the index
##' length.
##'
##' @param index_state An index to control which state variables we
##' set. You can use this to set a subset of state variables.
##'
##' @param index_particle An index to control which particles have
##' their state updated
##'
##' @param index_group An index to control which groups have their
##' state updated.
##'
##' @return Nothing, called for side effects only
##' @export
dust_system_set_state <- function(sys, state) {
##' @examples
##' # Consider a system with 3 particles and 1 group:
##' sir <- dust_example("sir")
##' sys <- dust_system_create(sir(), list(), n_particles = 3)
##' # The state for this system is packed as S, I, R, cases_cumul, cases_inc:
##' dust_unpack_index(sys)
##'
##' # Set all particles to the same state:
##' dust_system_set_state(sys, c(1000, 10, 0, 0, 0))
##' dust_system_state(sys)
##'
##' # We can set everything to different states by passing a vector
##' # with this shape:
##' m <- cbind(c(1000, 10, 0, 0, 0), c(999, 11, 0, 0, 0), c(998, 12, 0, 0, 0))
##' dust_system_set_state(sys, m)
##' dust_system_state(sys)
##'
##' # Or set the state for just one state:
##' dust_system_set_state(sys, 1, index_state = 4)
##' dust_system_state(sys)
##'
##' # If you want to set a different state across particles, you must
##' # provide a *matrix* (a vector always sets the same state into
##' # every particle)
##' dust_system_set_state(sys, rbind(c(1, 2, 3)), index_state = 4)
##' dust_system_state(sys)
##'
##' # This will not work as it can it can be ambiguous what you are
##' # trying to do:
##' #> dust_system_set_state(sys, c(1, 2, 3), index_state = 4)
##'
##' # State can be set for specific particles:
##' dust_system_set_state(sys, c(900, 100, 0, 0, 0), index_particle = 2)
##' dust_system_state(sys)
##'
##' # And you can combine 'index_particle' with 'index_state' to set
##' # small rectangles of state:
##' dust_system_set_state(sys, matrix(c(1, 2, 3, 4), 2, 2),
##' index_particle = 2:3, index_state = 4:5)
##' dust_system_state(sys)
dust_system_set_state <- function(sys, state, index_state = NULL,
index_particle = NULL, index_group = NULL) {
check_is_dust_system(sys)
## TODO: check rank etc here (mrc-5565), and support
## preserve_particle_dimension
sys$methods$set_state(sys$ptr, state, sys$preserve_group_dimension)
state <- prepare_state(state,
index_state,
index_particle,
index_group,
sys$n_state,
sys$n_particles,
sys$n_groups,
sys$preserve_particle_dimension,
sys$preserve_group_dimension)
sys$methods$set_state(sys$ptr, state)
invisible()
}

Expand Down Expand Up @@ -753,3 +824,86 @@ dust_package_env <- function(env, quiet = FALSE) {
env <- load_temporary_package(env$path, env$name, quiet)
}
}


## This does a bunch of bookkeeping to work out if we can set state
## into a system, and works out what we'll need to recycle when
## setting it.
prepare_state <- function(state,
index_state,
index_particle,
index_group,
n_state,
n_particles,
n_groups,
preserve_particle_dimension,
preserve_group_dimension,
name = deparse(substitute(state)),
call = parent.frame()) {
len_from_index <- function(n, idx, name_index = deparse(substitute(idx))) {
if (is.null(idx)) {
n
} else {
check_index(idx, n, unique = TRUE, name = name_index)
length(idx)
}
}
len_state <- len_from_index(n_state, index_state)
len_particles <- len_from_index(n_particles, index_particle)
len_groups <- len_from_index(n_groups, index_group)

stopifnot(preserve_particle_dimension || n_particles == 1)
stopifnot(preserve_group_dimension || n_groups == 1)

d <- dim2(state)
rank <- length(d)
expected <- c(state = len_state,
particle = if (preserve_particle_dimension) len_particles,
group = if (preserve_group_dimension) len_groups)
rank_expected <- length(expected)
if (rank > rank_expected) {
cli::cli_abort(
paste("Expected 'state' to be a {rank_description(rank_expected)}",
"but was given a {rank_description(rank)}"),
arg = name, call = call)
}
if (rank < rank_expected) {
d <- c(d, rep(1L, rank_expected - rank))
}

ok <- d == expected | c(FALSE, d[-1] == 1)
if (all(ok)) {
## We will access this by position from the C++ code but name it
## here for clarity.
return(list(state = state,
index_state = index_state,
index_particle = index_particle,
index_group = index_group,
recycle_particle = n_particles > 1 && d[[2]] == 1,
recycle_group = n_groups > 1 && last(d) == 1))
}

if (!ok[[1]]) {
if (rank == 1) {
msg <- "Expected '{name}' to have length {len_state}"
} else if (rank == 2) {
msg <- "Expected '{name}' to have {len_state} rows"
} else {
msg <- "Expected dimension 1 of '{name}' to be length {len_state}"
}
} else if (!ok[[2]]) {
expected_str <-
if (expected[[2]] == 1) "1" else sprintf("1 or %d", expected[[2]])
if (rank == 2) {
msg <- "Expected '{name}' to have {expected_str} columns"
} else {
msg <- "Expected dimension 2 of '{name}' to be length {expected_str}"
}
} else {
expected_str <-
if (expected[[3]] == 1) "1" else sprintf("1 or %d", expected[[3]])
msg <- "Expected dimension 3 of '{name}' to be length {expected_str}"
}

cli::cli_abort(msg, arg = name, call = call)
}
18 changes: 18 additions & 0 deletions R/util.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,21 @@ drop_last <- function(x, n = 1) {
len <- length(x)
x[if (len < n) integer(0) else seq_len(len - n)]
}


dim2 <- function(x) {
dim(x) %||% length(x)
}


rank_description <- function(rank) {
if (rank == 0) {
"scalar"
} else if (rank == 1) {
"vector"
} else if (rank == 2) {
"matrix"
} else {
sprintf("%d-dimensional array", rank)
}
}
64 changes: 36 additions & 28 deletions inst/include/dust2/continuous/system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <dust2/continuous/control.hpp>
#include <dust2/continuous/solver.hpp>
#include <dust2/errors.hpp>
#include <dust2/internals.hpp>
#include <dust2/packing.hpp>
#include <dust2/properties.hpp>
#include <dust2/tools.hpp>
Expand Down Expand Up @@ -179,35 +180,19 @@ class dust_continuous {
}

template <typename Iter>
void set_state(Iter iter, bool recycle_particle, bool recycle_group,
const std::vector<size_t>& index_group) {
void set_state(Iter iter,
const std::vector<size_t>& index_state,
const std::vector<size_t>& index_particle,
const std::vector<size_t>& index_group,
bool recycle_particle,
bool recycle_group) {
errors_.reset();
const auto offset_read_group = recycle_group ? 0 :
(n_state_ * (recycle_particle ? 1 : n_particles_));
const auto offset_read_particle = recycle_particle ? 0 : n_state_;

real_type * state_data = state_.data();
#ifdef _OPENMP
#pragma omp parallel for schedule(static) num_threads(n_threads_) collapse(2)
#endif
for (auto i : index_group) {
for (size_t j = 0; j < n_particles_; ++j) {
const auto k = n_particles_ * i + j;
const auto offset_read =
i * offset_read_group + j * offset_read_particle;
const auto offset_write = k * n_state_;
auto& internal_i = internal_[tools::thread_index() * n_groups_ + i];
real_type* y = state_data + offset_write;
std::copy_n(iter + offset_read, n_state_, y);
try {
solver_.initialise(time_, y, ode_internals_[k],
rhs_(shared_[i], internal_i));
} catch (std::exception const& e) {
errors_.capture(e, k);
}
}
}
errors_.report();
dust2::internals::set_state(state_, iter,
n_state_, n_particles_, n_groups_,
index_state, index_particle, index_group,
recycle_particle, recycle_group,
n_threads_);
initialise_solver_(index_group.empty() ? all_groups_ : index_group);
}

// iter here is an iterator to our *reordering index*, which will be
Expand Down Expand Up @@ -378,6 +363,29 @@ class dust_continuous {
T::rhs(t, y, shared, internal, dydt);
};
}

void initialise_solver_(std::vector<size_t> index_group) {
errors_.reset();
real_type * state_data = state_.data();
#ifdef _OPENMP
#pragma omp parallel for schedule(static) num_threads(n_threads_) collapse(2)
#endif
for (auto i : index_group) {
for (size_t j = 0; j < n_particles_; ++j) {
const auto k = n_particles_ * i + j;
const auto offset = k * n_state_;
auto& internal_i = internal_[tools::thread_index() * n_groups_ + i];
real_type * y = state_data + offset;
try {
solver_.initialise(time_, y, ode_internals_[k],
rhs_(shared_[i], internal_i));
} catch (std::exception const& e) {
errors_.capture(e, k);
}
}
}
errors_.report();
}
};

}
Loading

0 comments on commit dfa1e41

Please sign in to comment.