diff --git a/DESCRIPTION b/DESCRIPTION index 2a48ebda..dade1817 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: dust2 Title: Next Generation dust -Version: 0.1.19 +Version: 0.1.20 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Imperial College of Science, Technology and Medicine", diff --git a/R/cpp11.R b/R/cpp11.R index 58a2e2cb..1f4be9c6 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -24,8 +24,8 @@ dust2_system_logistic_set_state_initial <- function(ptr) { .Call(`_dust2_dust2_system_logistic_set_state_initial`, ptr) } -dust2_system_logistic_set_state <- function(ptr, r_state, preserve_group_dimension) { - .Call(`_dust2_dust2_system_logistic_set_state`, ptr, r_state, preserve_group_dimension) +dust2_system_logistic_set_state <- function(ptr, r_state) { + .Call(`_dust2_dust2_system_logistic_set_state`, ptr, r_state) } dust2_system_logistic_reorder <- function(ptr, r_index) { @@ -76,8 +76,8 @@ dust2_system_malaria_set_state_initial <- function(ptr) { .Call(`_dust2_dust2_system_malaria_set_state_initial`, ptr) } -dust2_system_malaria_set_state <- function(ptr, r_state, preserve_group_dimension) { - .Call(`_dust2_dust2_system_malaria_set_state`, ptr, r_state, preserve_group_dimension) +dust2_system_malaria_set_state <- function(ptr, r_state) { + .Call(`_dust2_dust2_system_malaria_set_state`, ptr, r_state) } dust2_system_malaria_reorder <- function(ptr, r_index) { @@ -180,8 +180,8 @@ dust2_system_sir_set_state_initial <- function(ptr) { .Call(`_dust2_dust2_system_sir_set_state_initial`, ptr) } -dust2_system_sir_set_state <- function(ptr, r_state, preserve_group_dimension) { - .Call(`_dust2_dust2_system_sir_set_state`, ptr, r_state, preserve_group_dimension) +dust2_system_sir_set_state <- function(ptr, r_state) { + .Call(`_dust2_dust2_system_sir_set_state`, ptr, r_state) } dust2_system_sir_reorder <- function(ptr, r_index) { @@ -288,8 +288,8 @@ dust2_system_sirode_set_state_initial <- function(ptr) { .Call(`_dust2_dust2_system_sirode_set_state_initial`, ptr) } -dust2_system_sirode_set_state <- function(ptr, r_state, preserve_group_dimension) { - .Call(`_dust2_dust2_system_sirode_set_state`, ptr, r_state, preserve_group_dimension) +dust2_system_sirode_set_state <- function(ptr, r_state) { + .Call(`_dust2_dust2_system_sirode_set_state`, ptr, r_state) } dust2_system_sirode_reorder <- function(ptr, r_index) { @@ -444,8 +444,8 @@ dust2_system_walk_set_state_initial <- function(ptr) { .Call(`_dust2_dust2_system_walk_set_state_initial`, ptr) } -dust2_system_walk_set_state <- function(ptr, r_state, preserve_group_dimension) { - .Call(`_dust2_dust2_system_walk_set_state`, ptr, r_state, preserve_group_dimension) +dust2_system_walk_set_state <- function(ptr, r_state) { + .Call(`_dust2_dust2_system_walk_set_state`, ptr, r_state) } dust2_system_walk_reorder <- function(ptr, r_index) { diff --git a/R/dust.R b/R/dust.R index 46c5e1ee..86324829 100644 --- a/R/dust.R +++ b/R/dust.R @@ -1,4 +1,4 @@ -## Generated by dust2 (version 0.1.18) - do not edit +## Generated by dust2 (version 0.1.19) - do not edit logistic <- function() { dust_system_generator("logistic", "continuous", NULL) } diff --git a/R/interface-likelihood.R b/R/interface-likelihood.R index cb77711c..67536d3f 100644 --- a/R/interface-likelihood.R +++ b/R/interface-likelihood.R @@ -80,6 +80,17 @@ dust_likelihood_run <- function(obj, pars, initial = NULL, ## so that we can access `n_groups` index_state <- check_index(index_state, max = obj$n_state, unique = TRUE) + if (!is.null(initial)) { + initial <- prepare_state(initial, + NULL, # index_state + NULL, # index_particle + index_group, + obj$n_state, + obj$n_particles, + obj$n_groups, + obj$preserve_particle_dimension, + obj$preserve_group_dimension) + } obj$methods$run(obj$ptr, initial, save_history, diff --git a/R/interface.R b/R/interface.R index d802976d..9c2b5ba6 100644 --- a/R/interface.R +++ b/R/interface.R @@ -226,7 +226,13 @@ dust_system_state <- function(sys, index_state = NULL, index_particle = NULL, ##' Set system state. Takes a multidimensional array (2- or 3d ##' depending on if the system is grouped or not). Dimensions of -##' length 1 will be recycled as appropriate. +##' length 1 will be recycled as appropriate. For continuous time +##' systems, we will initialise the solver immediately after setting +##' state, which may cause errors if your initial state is invalid for +##' your system. There are many ways that you can use this function +##' to set different fractions of state (a subset of states, particles +##' or parameter groups, recycling over any dimensions that are +##' missing). Please see the Examples section for usage. ##' ##' @title Set system state ##' @@ -234,15 +240,80 @@ dust_system_state <- function(sys, index_state = NULL, index_particle = NULL, ##' ##' @param state A matrix or array of state. If ungrouped, the ##' dimension order expected is state x particle. If grouped the -##' order is state x particle x group. +##' order is state x particle x group. If you have a grouped system +##' with 1 particle and `preserve_state_dimension = FALSE` then the +##' state has size state x group. You can omit higher dimensions, +##' so if you pass a vector it will be treated as if all higher +##' dimensions are length 1 (or if you have a grouped system you +##' can provide a matrix and treat it as if the third dimension had +##' length 1). If you provide any `index_` argument then the length +##' of the corresponding state dimension must match the index +##' length. +##' +##' @param index_state An index to control which state variables we +##' set. You can use this to set a subset of state variables. +##' +##' @param index_particle An index to control which particles have +##' their state updated +##' +##' @param index_group An index to control which groups have their +##' state updated. ##' ##' @return Nothing, called for side effects only ##' @export -dust_system_set_state <- function(sys, state) { +##' @examples +##' # Consider a system with 3 particles and 1 group: +##' sir <- dust_example("sir") +##' sys <- dust_system_create(sir(), list(), n_particles = 3) +##' # The state for this system is packed as S, I, R, cases_cumul, cases_inc: +##' dust_unpack_index(sys) +##' +##' # Set all particles to the same state: +##' dust_system_set_state(sys, c(1000, 10, 0, 0, 0)) +##' dust_system_state(sys) +##' +##' # We can set everything to different states by passing a vector +##' # with this shape: +##' m <- cbind(c(1000, 10, 0, 0, 0), c(999, 11, 0, 0, 0), c(998, 12, 0, 0, 0)) +##' dust_system_set_state(sys, m) +##' dust_system_state(sys) +##' +##' # Or set the state for just one state: +##' dust_system_set_state(sys, 1, index_state = 4) +##' dust_system_state(sys) +##' +##' # If you want to set a different state across particles, you must +##' # provide a *matrix* (a vector always sets the same state into +##' # every particle) +##' dust_system_set_state(sys, rbind(c(1, 2, 3)), index_state = 4) +##' dust_system_state(sys) +##' +##' # This will not work as it can be ambiguous what you are +##' # trying to do: +##' #> dust_system_set_state(sys, c(1, 2, 3), index_state = 4) +##' +##' # State can be set for specific particles: +##' dust_system_set_state(sys, c(900, 100, 0, 0, 0), index_particle = 2) +##' dust_system_state(sys) +##' +##' # And you can combine 'index_particle' with 'index_state' to set +##' # small rectangles of state: +##' dust_system_set_state(sys, matrix(c(1, 2, 3, 4), 2, 2), +##' index_particle = 2:3, index_state = 4:5) +##' dust_system_state(sys) +dust_system_set_state <- function(sys, state, index_state = NULL, + index_particle = NULL, index_group = NULL) { check_is_dust_system(sys) - ## TODO: check rank etc here (mrc-5565), and support - ## preserve_particle_dimension - sys$methods$set_state(sys$ptr, state, sys$preserve_group_dimension) + state <- prepare_state(state, + index_state, + index_particle, + index_group, + sys$n_state, + sys$n_particles, + sys$n_groups, + sys$preserve_particle_dimension, + sys$preserve_group_dimension) + sys$methods$set_state(sys$ptr, state) invisible() } @@ -753,3 +824,86 @@ dust_package_env <- function(env, quiet = FALSE) { env <- load_temporary_package(env$path, env$name, quiet) } } + + +## This does a bunch of bookkeeping to work out if we can set state +## into a system, and works out what we'll need to recycle when +## setting it. +prepare_state <- function(state, + index_state, + index_particle, + index_group, + n_state, + n_particles, + n_groups, + preserve_particle_dimension, + preserve_group_dimension, + name = deparse(substitute(state)), + call = parent.frame()) { + len_from_index <- function(n, idx, name_index = deparse(substitute(idx))) { + if (is.null(idx)) { + n + } else { + check_index(idx, n, unique = TRUE, name = name_index) + length(idx) + } + } + len_state <- len_from_index(n_state, index_state) + len_particles <- len_from_index(n_particles, index_particle) + len_groups <- len_from_index(n_groups, index_group) + + stopifnot(preserve_particle_dimension || n_particles == 1) + stopifnot(preserve_group_dimension || n_groups == 1) + + d <- dim2(state) + rank <- length(d) + expected <- c(state = len_state, + particle = if (preserve_particle_dimension) len_particles, + group = if (preserve_group_dimension) len_groups) + rank_expected <- length(expected) + if (rank > rank_expected) { + cli::cli_abort( + paste("Expected 'state' to be a {rank_description(rank_expected)}", + "but was given a {rank_description(rank)}"), + arg = name, call = call) + } + if (rank < rank_expected) { + d <- c(d, rep(1L, rank_expected - rank)) + } + + ok <- d == expected | c(FALSE, d[-1] == 1) + if (all(ok)) { + ## We will access this by position from the C++ code but name it + ## here for clarity. + return(list(state = state, + index_state = index_state, + index_particle = index_particle, + index_group = index_group, + recycle_particle = n_particles > 1 && d[[2]] == 1, + recycle_group = n_groups > 1 && last(d) == 1)) + } + + if (!ok[[1]]) { + if (rank == 1) { + msg <- "Expected '{name}' to have length {len_state}" + } else if (rank == 2) { + msg <- "Expected '{name}' to have {len_state} rows" + } else { + msg <- "Expected dimension 1 of '{name}' to be length {len_state}" + } + } else if (!ok[[2]]) { + expected_str <- + if (expected[[2]] == 1) "1" else sprintf("1 or %d", expected[[2]]) + if (rank == 2) { + msg <- "Expected '{name}' to have {expected_str} columns" + } else { + msg <- "Expected dimension 2 of '{name}' to be length {expected_str}" + } + } else { + expected_str <- + if (expected[[3]] == 1) "1" else sprintf("1 or %d", expected[[3]]) + msg <- "Expected dimension 3 of '{name}' to be length {expected_str}" + } + + cli::cli_abort(msg, arg = name, call = call) +} diff --git a/R/util.R b/R/util.R index eb5c85c5..5148dbe6 100644 --- a/R/util.R +++ b/R/util.R @@ -130,3 +130,21 @@ drop_last <- function(x, n = 1) { len <- length(x) x[if (len < n) integer(0) else seq_len(len - n)] } + + +dim2 <- function(x) { + dim(x) %||% length(x) +} + + +rank_description <- function(rank) { + if (rank == 0) { + "scalar" + } else if (rank == 1) { + "vector" + } else if (rank == 2) { + "matrix" + } else { + sprintf("%d-dimensional array", rank) + } +} diff --git a/inst/include/dust2/continuous/system.hpp b/inst/include/dust2/continuous/system.hpp index 06511ad7..a6661d33 100644 --- a/inst/include/dust2/continuous/system.hpp +++ b/inst/include/dust2/continuous/system.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -179,35 +180,19 @@ class dust_continuous { } template - void set_state(Iter iter, bool recycle_particle, bool recycle_group, - const std::vector& index_group) { + void set_state(Iter iter, + const std::vector& index_state, + const std::vector& index_particle, + const std::vector& index_group, + bool recycle_particle, + bool recycle_group) { errors_.reset(); - const auto offset_read_group = recycle_group ? 0 : - (n_state_ * (recycle_particle ? 1 : n_particles_)); - const auto offset_read_particle = recycle_particle ? 0 : n_state_; - - real_type * state_data = state_.data(); -#ifdef _OPENMP -#pragma omp parallel for schedule(static) num_threads(n_threads_) collapse(2) -#endif - for (auto i : index_group) { - for (size_t j = 0; j < n_particles_; ++j) { - const auto k = n_particles_ * i + j; - const auto offset_read = - i * offset_read_group + j * offset_read_particle; - const auto offset_write = k * n_state_; - auto& internal_i = internal_[tools::thread_index() * n_groups_ + i]; - real_type* y = state_data + offset_write; - std::copy_n(iter + offset_read, n_state_, y); - try { - solver_.initialise(time_, y, ode_internals_[k], - rhs_(shared_[i], internal_i)); - } catch (std::exception const& e) { - errors_.capture(e, k); - } - } - } - errors_.report(); + dust2::internals::set_state(state_, iter, + n_state_, n_particles_, n_groups_, + index_state, index_particle, index_group, + recycle_particle, recycle_group, + n_threads_); + initialise_solver_(index_group.empty() ? all_groups_ : index_group); } // iter here is an iterator to our *reordering index*, which will be @@ -378,6 +363,29 @@ class dust_continuous { T::rhs(t, y, shared, internal, dydt); }; } + + void initialise_solver_(std::vector index_group) { + errors_.reset(); + real_type * state_data = state_.data(); +#ifdef _OPENMP +#pragma omp parallel for schedule(static) num_threads(n_threads_) collapse(2) +#endif + for (auto i : index_group) { + for (size_t j = 0; j < n_particles_; ++j) { + const auto k = n_particles_ * i + j; + const auto offset = k * n_state_; + auto& internal_i = internal_[tools::thread_index() * n_groups_ + i]; + real_type * y = state_data + offset; + try { + solver_.initialise(time_, y, ode_internals_[k], + rhs_(shared_[i], internal_i)); + } catch (std::exception const& e) { + errors_.capture(e, k); + } + } + } + errors_.report(); + } }; } diff --git a/inst/include/dust2/discrete/system.hpp b/inst/include/dust2/discrete/system.hpp index b611b00a..c17d3273 100644 --- a/inst/include/dust2/discrete/system.hpp +++ b/inst/include/dust2/discrete/system.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -144,27 +145,18 @@ class dust_discrete { } template - void set_state(Iter iter, bool recycle_particle, bool recycle_group, - const std::vector& index_group) { + void set_state(Iter iter, + const std::vector& index_state, + const std::vector& index_particle, + const std::vector& index_group, + bool recycle_particle, + bool recycle_group) { errors_.reset(); - const auto offset_read_group = recycle_group ? 0 : - (n_state_ * (recycle_particle ? 1 : n_particles_)); - const auto offset_read_particle = recycle_particle ? 0 : n_state_; - - real_type * state_data = state_.data(); -#ifdef _OPENMP -#pragma omp parallel for schedule(static) num_threads(n_threads_) collapse(2) -#endif - for (auto i : index_group) { - for (size_t j = 0; j < n_particles_; ++j) { - const auto offset_read = - i * offset_read_group + j * offset_read_particle; - const auto offset_write = (n_particles_ * i + j) * n_state_; - std::copy_n(iter + offset_read, - n_state_, - state_data + offset_write); - } - } + dust2::internals::set_state(state_, iter, + n_state_, n_particles_, n_groups_, + index_state, index_particle, index_group, + recycle_particle, recycle_group, + n_threads_); } template diff --git a/inst/include/dust2/internals.hpp b/inst/include/dust2/internals.hpp new file mode 100644 index 00000000..4e028ded --- /dev/null +++ b/inst/include/dust2/internals.hpp @@ -0,0 +1,65 @@ +#pragma once + +#include +#include + +namespace dust2 { +namespace internals { + +template +void set_state(std::vector& state, + Iter iter, + const size_t n_state, + const size_t n_particles, + const size_t n_groups, + const std::vector& index_state, + const std::vector& index_particle, + const std::vector& index_group, + const bool recycle_particle, + const bool recycle_group, + const size_t n_threads) { + const bool use_index_state = !index_state.empty(); + const bool use_index_particle = !index_particle.empty(); + const bool use_index_group = !index_group.empty(); + + bool do_simple_copy = + !use_index_state && !use_index_particle && !use_index_group && + !recycle_particle && !recycle_group; + if (do_simple_copy) { + std::copy_n(iter, state.size(), state.begin()); + } else { + const auto n_state_in = + use_index_state ? index_state.size() : n_state; + const auto n_particles_in = + use_index_particle ? index_particle.size() : n_particles; + const auto n_groups_in = + use_index_group ? index_group.size() : n_groups; + const auto offset_src_particle = recycle_particle ? 0 : n_state_in; + const auto offset_src_group = recycle_group ? 0 : + (n_state_in * (recycle_particle ? 1 : n_particles_in)); +#ifdef _OPENMP +#pragma omp parallel for schedule(static) num_threads(n_threads) collapse(2) +#endif + for (size_t i = 0; i < n_groups_in; ++i) { + for (size_t j = 0; j < n_particles_in; ++j) { + const auto offset_src = + i * offset_src_group + j * offset_src_particle; + const auto i_dst = use_index_group ? index_group[i] : i; + const auto j_dst = use_index_particle ? index_particle[j] : j; + const auto offset_dst = (n_particles * i_dst + j_dst) * n_state; + const auto iter_src = iter + offset_src; + auto iter_dst = state.begin() + offset_dst; + if (use_index_state) { + for (size_t k = 0; k < n_state_in; ++k) { + *(iter_dst + index_state[k]) = *(iter_src + k); + } + } else { + std::copy_n(iter_src, n_state, iter_dst); + } + } + } + } +} + +} +} diff --git a/inst/include/dust2/r/filter.hpp b/inst/include/dust2/r/filter.hpp index 5f69a912..309138df 100644 --- a/inst/include/dust2/r/filter.hpp +++ b/inst/include/dust2/r/filter.hpp @@ -34,7 +34,7 @@ cpp11::sexp dust2_filter_run(cpp11::sexp ptr, cpp11::sexp r_initial, check_index(r_index_group, obj->sys.n_groups(), "index_group"); if (r_initial != R_NilValue) { - set_state(obj->sys, r_initial, preserve_group_dimension, index_group); + set_state(obj->sys, cpp11::as_cpp(r_initial)); } obj->run(r_initial == R_NilValue, save_history, index_state, index_group); diff --git a/inst/include/dust2/r/helpers.hpp b/inst/include/dust2/r/helpers.hpp index 04dbdcc7..ed312359 100644 --- a/inst/include/dust2/r/helpers.hpp +++ b/inst/include/dust2/r/helpers.hpp @@ -397,43 +397,21 @@ std::vector check_data(cpp11::list r_data, } template -void set_state(T& obj, cpp11::sexp r_state, bool preserve_group_dimension, - const std::vector& index_group) { - // Suppose that we have a n_state x n_particles x n_groups grouped - // system, we then require that we have a state array with rank 3; - // for an ungrouped system this will be rank 2 array. - // - // TODO: these checks would be nicer in R, just do it there and here - // we can just accept what we are given? (mrc-5565) - auto dim = cpp11::as_cpp(r_state.attr("dim")); - const auto rank = dim.size(); - const auto rank_expected = preserve_group_dimension ? 3 : 2; - if (rank != rank_expected) { - cpp11::stop("Expected 'state' to be a %dd array", rank_expected); - } - const int n_state = obj.n_state(); - const int n_particles = - preserve_group_dimension ? obj.n_particles() : - obj.n_particles() * obj.n_groups(); - const int n_groups = index_group.size(); - if (dim[0] != n_state) { - cpp11::stop("Expected the first dimension of 'state' to have size %d", - n_state); - } - const auto recycle_particle = n_particles > 1 && dim[1] == 1; - if (dim[1] != n_particles && dim[1] != 1) { - cpp11::stop("Expected the second dimension of 'state' to have size %d or 1", - n_particles); - } - - const auto recycle_group = - !preserve_group_dimension || (n_groups > 1 && dim[2] == 1); - if (preserve_group_dimension && dim[2] != n_groups && dim[2] != 1) { - cpp11::stop("Expected the third dimension of 'state' to have size %d or 1", - n_groups); - } - obj.set_state(REAL(r_state), recycle_particle, recycle_group, - index_group); +void set_state(T& obj, cpp11::list r_state) { + cpp11::doubles r_value = cpp11::as_doubles(r_state[0]); + cpp11::sexp r_index_state = r_state[1]; + cpp11::sexp r_index_particle = r_state[2]; + cpp11::sexp r_index_group = r_state[3]; + bool recycle_particle = cpp11::as_cpp(r_state[4]); + bool recycle_group = cpp11::as_cpp(r_state[5]); + const auto index_state = + check_index(r_index_state, obj.n_state(), "index_state"); + const auto index_particle = + check_index(r_index_particle, obj.n_particles(), "index_particle"); + const auto index_group = + check_index(r_index_group, obj.n_groups(), "index_group"); + obj.set_state(REAL(r_value), index_state, index_particle, index_group, + recycle_particle, recycle_group); } template diff --git a/inst/include/dust2/r/system.hpp b/inst/include/dust2/r/system.hpp index a7df5a58..aeba5cc2 100644 --- a/inst/include/dust2/r/system.hpp +++ b/inst/include/dust2/r/system.hpp @@ -130,10 +130,10 @@ SEXP dust2_system_set_state_initial(cpp11::sexp ptr) { } template -SEXP dust2_system_set_state(cpp11::sexp ptr, cpp11::sexp r_state, - bool preserve_group_dimension) { +SEXP dust2_system_set_state(cpp11::sexp ptr, + cpp11::list r_state) { auto *obj = cpp11::as_cpp>(ptr).get(); - set_state(*obj, r_state, preserve_group_dimension, obj->all_groups()); + set_state(*obj, r_state); return R_NilValue; } diff --git a/inst/include/dust2/r/unfilter.hpp b/inst/include/dust2/r/unfilter.hpp index 1b692a7b..60dad25b 100644 --- a/inst/include/dust2/r/unfilter.hpp +++ b/inst/include/dust2/r/unfilter.hpp @@ -34,7 +34,7 @@ cpp11::sexp dust2_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_initial, check_index(r_index_group, obj->sys.n_groups(), "index_group"); if (r_initial != R_NilValue) { - set_state(obj->sys, r_initial, preserve_group_dimension, index_group); + set_state(obj->sys, cpp11::as_cpp(r_initial)); } if (adjoint) { obj->run_adjoint(r_initial == R_NilValue, save_history, index_state, diff --git a/inst/template/system.cpp b/inst/template/system.cpp index c4bc0b2b..67d560be 100644 --- a/inst/template/system.cpp +++ b/inst/template/system.cpp @@ -19,8 +19,8 @@ SEXP dust2_system_{{name}}_set_state_initial(cpp11::sexp ptr) { } [[cpp11::register]] -SEXP dust2_system_{{name}}_set_state(cpp11::sexp ptr, cpp11::sexp r_state, bool preserve_group_dimension) { - return dust2::r::dust2_system_set_state>(ptr, r_state, preserve_group_dimension); +SEXP dust2_system_{{name}}_set_state(cpp11::sexp ptr, cpp11::list r_state) { + return dust2::r::dust2_system_set_state>(ptr, r_state); } [[cpp11::register]] diff --git a/man/dust_system_set_state.Rd b/man/dust_system_set_state.Rd index f2579fee..3d27e901 100644 --- a/man/dust_system_set_state.Rd +++ b/man/dust_system_set_state.Rd @@ -4,14 +4,37 @@ \alias{dust_system_set_state} \title{Set system state} \usage{ -dust_system_set_state(sys, state) +dust_system_set_state( + sys, + state, + index_state = NULL, + index_particle = NULL, + index_group = NULL +) } \arguments{ \item{sys}{A \code{dust_system} object} \item{state}{A matrix or array of state. If ungrouped, the dimension order expected is state x particle. If grouped the -order is state x particle x group.} +order is state x particle x group. If you have a grouped system +with 1 particle and \code{preserve_state_dimension = FALSE} then the +state has size state x group. You can omit higher dimensions, +so if you pass a vector it will be treated as if all higher +dimensions are length 1 (or if you have a grouped system you +can provide a matrix and treat it as if the third dimension had +length 1). If you provide any \code{index_} argument then the length +of the corresponding state dimension must match the index +length.} + +\item{index_state}{An index to control which state variables we +set. You can use this to set a subset of state variables.} + +\item{index_particle}{An index to control which particles have +their state updated} + +\item{index_group}{An index to control which groups have their +state updated.} } \value{ Nothing, called for side effects only @@ -19,5 +42,52 @@ Nothing, called for side effects only \description{ Set system state. Takes a multidimensional array (2- or 3d depending on if the system is grouped or not). Dimensions of -length 1 will be recycled as appropriate. +length 1 will be recycled as appropriate. For continuous time +systems, we will initialise the solver immediately after setting +state, which may cause errors if your initial state is invalid for +your system. There are many ways that you can use this function +to set different fractions of state (a subset of states, particles +or parameter groups, recycling over any dimensions that are +missing). Please see the Examples section for usage. +} +\examples{ +# Consider a system with 3 particles and 1 group: +sir <- dust_example("sir") +sys <- dust_system_create(sir(), list(), n_particles = 3) +# The state for this system is packed as S, I, R, cases_cumul, cases_inc: +dust_unpack_index(sys) + +# Set all particles to the same state: +dust_system_set_state(sys, c(1000, 10, 0, 0, 0)) +dust_system_state(sys) + +# We can set everything to different states by passing a vector +# with this shape: +m <- cbind(c(1000, 10, 0, 0, 0), c(999, 11, 0, 0, 0), c(998, 12, 0, 0, 0)) +dust_system_set_state(sys, m) +dust_system_state(sys) + +# Or set the state for just one state: +dust_system_set_state(sys, 1, index_state = 4) +dust_system_state(sys) + +# If you want to set a different state across particles, you must +# provide a *matrix* (a vector always sets the same state into +# every particle) +dust_system_set_state(sys, rbind(c(1, 2, 3)), index_state = 4) +dust_system_state(sys) + +# This will not work as it can be ambiguous what you are +# trying to do: +#> dust_system_set_state(sys, c(1, 2, 3), index_state = 4) + +# State can be set for specific particles: +dust_system_set_state(sys, c(900, 100, 0, 0, 0), index_particle = 2) +dust_system_state(sys) + +# And you can combine 'index_particle' with 'index_state' to set +# small rectangles of state: +dust_system_set_state(sys, matrix(c(1, 2, 3, 4), 2, 2), + index_particle = 2:3, index_state = 4:5) +dust_system_state(sys) } diff --git a/scripts/update_example b/scripts/update_example index 407ce902..60c958dd 100755 --- a/scripts/update_example +++ b/scripts/update_example @@ -4,6 +4,7 @@ examples <- c("logistic.cpp", "malaria.cpp", "sir.cpp", "sirode.cpp", setwd(here::here()) unlink(file.path("src", basename(examples))) +unlink("R/dust.R") unlink("inst/dust", recursive = TRUE) dir.create("inst/dust", FALSE) diff --git a/src/cpp11.cpp b/src/cpp11.cpp index e5915361..9f7eb080 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -48,10 +48,10 @@ extern "C" SEXP _dust2_dust2_system_logistic_set_state_initial(SEXP ptr) { END_CPP11 } // logistic.cpp -SEXP dust2_system_logistic_set_state(cpp11::sexp ptr, cpp11::sexp r_state, bool preserve_group_dimension); -extern "C" SEXP _dust2_dust2_system_logistic_set_state(SEXP ptr, SEXP r_state, SEXP preserve_group_dimension) { +SEXP dust2_system_logistic_set_state(cpp11::sexp ptr, cpp11::list r_state); +extern "C" SEXP _dust2_dust2_system_logistic_set_state(SEXP ptr, SEXP r_state) { BEGIN_CPP11 - return cpp11::as_sexp(dust2_system_logistic_set_state(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_state), cpp11::as_cpp>(preserve_group_dimension))); + return cpp11::as_sexp(dust2_system_logistic_set_state(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_state))); END_CPP11 } // logistic.cpp @@ -139,10 +139,10 @@ extern "C" SEXP _dust2_dust2_system_malaria_set_state_initial(SEXP ptr) { END_CPP11 } // malaria.cpp -SEXP dust2_system_malaria_set_state(cpp11::sexp ptr, cpp11::sexp r_state, bool preserve_group_dimension); -extern "C" SEXP _dust2_dust2_system_malaria_set_state(SEXP ptr, SEXP r_state, SEXP preserve_group_dimension) { +SEXP dust2_system_malaria_set_state(cpp11::sexp ptr, cpp11::list r_state); +extern "C" SEXP _dust2_dust2_system_malaria_set_state(SEXP ptr, SEXP r_state) { BEGIN_CPP11 - return cpp11::as_sexp(dust2_system_malaria_set_state(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_state), cpp11::as_cpp>(preserve_group_dimension))); + return cpp11::as_sexp(dust2_system_malaria_set_state(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_state))); END_CPP11 } // malaria.cpp @@ -321,10 +321,10 @@ extern "C" SEXP _dust2_dust2_system_sir_set_state_initial(SEXP ptr) { END_CPP11 } // sir.cpp -SEXP dust2_system_sir_set_state(cpp11::sexp ptr, cpp11::sexp r_state, bool preserve_group_dimension); -extern "C" SEXP _dust2_dust2_system_sir_set_state(SEXP ptr, SEXP r_state, SEXP preserve_group_dimension) { +SEXP dust2_system_sir_set_state(cpp11::sexp ptr, cpp11::list r_state); +extern "C" SEXP _dust2_dust2_system_sir_set_state(SEXP ptr, SEXP r_state) { BEGIN_CPP11 - return cpp11::as_sexp(dust2_system_sir_set_state(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_state), cpp11::as_cpp>(preserve_group_dimension))); + return cpp11::as_sexp(dust2_system_sir_set_state(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_state))); END_CPP11 } // sir.cpp @@ -510,10 +510,10 @@ extern "C" SEXP _dust2_dust2_system_sirode_set_state_initial(SEXP ptr) { END_CPP11 } // sirode.cpp -SEXP dust2_system_sirode_set_state(cpp11::sexp ptr, cpp11::sexp r_state, bool preserve_group_dimension); -extern "C" SEXP _dust2_dust2_system_sirode_set_state(SEXP ptr, SEXP r_state, SEXP preserve_group_dimension) { +SEXP dust2_system_sirode_set_state(cpp11::sexp ptr, cpp11::list r_state); +extern "C" SEXP _dust2_dust2_system_sirode_set_state(SEXP ptr, SEXP r_state) { BEGIN_CPP11 - return cpp11::as_sexp(dust2_system_sirode_set_state(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_state), cpp11::as_cpp>(preserve_group_dimension))); + return cpp11::as_sexp(dust2_system_sirode_set_state(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_state))); END_CPP11 } // sirode.cpp @@ -783,10 +783,10 @@ extern "C" SEXP _dust2_dust2_system_walk_set_state_initial(SEXP ptr) { END_CPP11 } // walk.cpp -SEXP dust2_system_walk_set_state(cpp11::sexp ptr, cpp11::sexp r_state, bool preserve_group_dimension); -extern "C" SEXP _dust2_dust2_system_walk_set_state(SEXP ptr, SEXP r_state, SEXP preserve_group_dimension) { +SEXP dust2_system_walk_set_state(cpp11::sexp ptr, cpp11::list r_state); +extern "C" SEXP _dust2_dust2_system_walk_set_state(SEXP ptr, SEXP r_state) { BEGIN_CPP11 - return cpp11::as_sexp(dust2_system_walk_set_state(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_state), cpp11::as_cpp>(preserve_group_dimension))); + return cpp11::as_sexp(dust2_system_walk_set_state(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_state))); END_CPP11 } // walk.cpp @@ -862,7 +862,7 @@ static const R_CallMethodDef CallEntries[] = { {"_dust2_dust2_system_logistic_rng_state", (DL_FUNC) &_dust2_dust2_system_logistic_rng_state, 1}, {"_dust2_dust2_system_logistic_run_to_time", (DL_FUNC) &_dust2_dust2_system_logistic_run_to_time, 2}, {"_dust2_dust2_system_logistic_set_rng_state", (DL_FUNC) &_dust2_dust2_system_logistic_set_rng_state, 2}, - {"_dust2_dust2_system_logistic_set_state", (DL_FUNC) &_dust2_dust2_system_logistic_set_state, 3}, + {"_dust2_dust2_system_logistic_set_state", (DL_FUNC) &_dust2_dust2_system_logistic_set_state, 2}, {"_dust2_dust2_system_logistic_set_state_initial", (DL_FUNC) &_dust2_dust2_system_logistic_set_state_initial, 1}, {"_dust2_dust2_system_logistic_set_time", (DL_FUNC) &_dust2_dust2_system_logistic_set_time, 2}, {"_dust2_dust2_system_logistic_simulate", (DL_FUNC) &_dust2_dust2_system_logistic_simulate, 5}, @@ -876,7 +876,7 @@ static const R_CallMethodDef CallEntries[] = { {"_dust2_dust2_system_malaria_rng_state", (DL_FUNC) &_dust2_dust2_system_malaria_rng_state, 1}, {"_dust2_dust2_system_malaria_run_to_time", (DL_FUNC) &_dust2_dust2_system_malaria_run_to_time, 2}, {"_dust2_dust2_system_malaria_set_rng_state", (DL_FUNC) &_dust2_dust2_system_malaria_set_rng_state, 2}, - {"_dust2_dust2_system_malaria_set_state", (DL_FUNC) &_dust2_dust2_system_malaria_set_state, 3}, + {"_dust2_dust2_system_malaria_set_state", (DL_FUNC) &_dust2_dust2_system_malaria_set_state, 2}, {"_dust2_dust2_system_malaria_set_state_initial", (DL_FUNC) &_dust2_dust2_system_malaria_set_state_initial, 1}, {"_dust2_dust2_system_malaria_set_time", (DL_FUNC) &_dust2_dust2_system_malaria_set_time, 2}, {"_dust2_dust2_system_malaria_simulate", (DL_FUNC) &_dust2_dust2_system_malaria_simulate, 5}, @@ -889,7 +889,7 @@ static const R_CallMethodDef CallEntries[] = { {"_dust2_dust2_system_sir_rng_state", (DL_FUNC) &_dust2_dust2_system_sir_rng_state, 1}, {"_dust2_dust2_system_sir_run_to_time", (DL_FUNC) &_dust2_dust2_system_sir_run_to_time, 2}, {"_dust2_dust2_system_sir_set_rng_state", (DL_FUNC) &_dust2_dust2_system_sir_set_rng_state, 2}, - {"_dust2_dust2_system_sir_set_state", (DL_FUNC) &_dust2_dust2_system_sir_set_state, 3}, + {"_dust2_dust2_system_sir_set_state", (DL_FUNC) &_dust2_dust2_system_sir_set_state, 2}, {"_dust2_dust2_system_sir_set_state_initial", (DL_FUNC) &_dust2_dust2_system_sir_set_state_initial, 1}, {"_dust2_dust2_system_sir_set_time", (DL_FUNC) &_dust2_dust2_system_sir_set_time, 2}, {"_dust2_dust2_system_sir_simulate", (DL_FUNC) &_dust2_dust2_system_sir_simulate, 5}, @@ -903,7 +903,7 @@ static const R_CallMethodDef CallEntries[] = { {"_dust2_dust2_system_sirode_rng_state", (DL_FUNC) &_dust2_dust2_system_sirode_rng_state, 1}, {"_dust2_dust2_system_sirode_run_to_time", (DL_FUNC) &_dust2_dust2_system_sirode_run_to_time, 2}, {"_dust2_dust2_system_sirode_set_rng_state", (DL_FUNC) &_dust2_dust2_system_sirode_set_rng_state, 2}, - {"_dust2_dust2_system_sirode_set_state", (DL_FUNC) &_dust2_dust2_system_sirode_set_state, 3}, + {"_dust2_dust2_system_sirode_set_state", (DL_FUNC) &_dust2_dust2_system_sirode_set_state, 2}, {"_dust2_dust2_system_sirode_set_state_initial", (DL_FUNC) &_dust2_dust2_system_sirode_set_state_initial, 1}, {"_dust2_dust2_system_sirode_set_time", (DL_FUNC) &_dust2_dust2_system_sirode_set_time, 2}, {"_dust2_dust2_system_sirode_simulate", (DL_FUNC) &_dust2_dust2_system_sirode_simulate, 5}, @@ -915,7 +915,7 @@ static const R_CallMethodDef CallEntries[] = { {"_dust2_dust2_system_walk_rng_state", (DL_FUNC) &_dust2_dust2_system_walk_rng_state, 1}, {"_dust2_dust2_system_walk_run_to_time", (DL_FUNC) &_dust2_dust2_system_walk_run_to_time, 2}, {"_dust2_dust2_system_walk_set_rng_state", (DL_FUNC) &_dust2_dust2_system_walk_set_rng_state, 2}, - {"_dust2_dust2_system_walk_set_state", (DL_FUNC) &_dust2_dust2_system_walk_set_state, 3}, + {"_dust2_dust2_system_walk_set_state", (DL_FUNC) &_dust2_dust2_system_walk_set_state, 2}, {"_dust2_dust2_system_walk_set_state_initial", (DL_FUNC) &_dust2_dust2_system_walk_set_state_initial, 1}, {"_dust2_dust2_system_walk_set_time", (DL_FUNC) &_dust2_dust2_system_walk_set_time, 2}, {"_dust2_dust2_system_walk_simulate", (DL_FUNC) &_dust2_dust2_system_walk_simulate, 5}, diff --git a/src/logistic.cpp b/src/logistic.cpp index f9adf5bb..657282de 100644 --- a/src/logistic.cpp +++ b/src/logistic.cpp @@ -1,4 +1,4 @@ -// Generated by dust2 (version 0.1.18) - do not edit +// Generated by dust2 (version 0.1.19) - do not edit #include #include @@ -112,8 +112,8 @@ SEXP dust2_system_logistic_set_state_initial(cpp11::sexp ptr) { } [[cpp11::register]] -SEXP dust2_system_logistic_set_state(cpp11::sexp ptr, cpp11::sexp r_state, bool preserve_group_dimension) { - return dust2::r::dust2_system_set_state>(ptr, r_state, preserve_group_dimension); +SEXP dust2_system_logistic_set_state(cpp11::sexp ptr, cpp11::list r_state) { + return dust2::r::dust2_system_set_state>(ptr, r_state); } [[cpp11::register]] diff --git a/src/malaria.cpp b/src/malaria.cpp index 852123ae..14d130e6 100644 --- a/src/malaria.cpp +++ b/src/malaria.cpp @@ -1,4 +1,4 @@ -// Generated by dust2 (version 0.1.18) - do not edit +// Generated by dust2 (version 0.1.19) - do not edit #include @@ -173,8 +173,8 @@ SEXP dust2_system_malaria_set_state_initial(cpp11::sexp ptr) { } [[cpp11::register]] -SEXP dust2_system_malaria_set_state(cpp11::sexp ptr, cpp11::sexp r_state, bool preserve_group_dimension) { - return dust2::r::dust2_system_set_state>(ptr, r_state, preserve_group_dimension); +SEXP dust2_system_malaria_set_state(cpp11::sexp ptr, cpp11::list r_state) { + return dust2::r::dust2_system_set_state>(ptr, r_state); } [[cpp11::register]] diff --git a/src/sir.cpp b/src/sir.cpp index b3149e1b..9df48157 100644 --- a/src/sir.cpp +++ b/src/sir.cpp @@ -1,4 +1,4 @@ -// Generated by dust2 (version 0.1.18) - do not edit +// Generated by dust2 (version 0.1.19) - do not edit #include @@ -229,8 +229,8 @@ SEXP dust2_system_sir_set_state_initial(cpp11::sexp ptr) { } [[cpp11::register]] -SEXP dust2_system_sir_set_state(cpp11::sexp ptr, cpp11::sexp r_state, bool preserve_group_dimension) { - return dust2::r::dust2_system_set_state>(ptr, r_state, preserve_group_dimension); +SEXP dust2_system_sir_set_state(cpp11::sexp ptr, cpp11::list r_state) { + return dust2::r::dust2_system_set_state>(ptr, r_state); } [[cpp11::register]] diff --git a/src/sirode.cpp b/src/sirode.cpp index 4f601674..af70d059 100644 --- a/src/sirode.cpp +++ b/src/sirode.cpp @@ -1,4 +1,4 @@ -// Generated by dust2 (version 0.1.18) - do not edit +// Generated by dust2 (version 0.1.19) - do not edit #include @@ -142,8 +142,8 @@ SEXP dust2_system_sirode_set_state_initial(cpp11::sexp ptr) { } [[cpp11::register]] -SEXP dust2_system_sirode_set_state(cpp11::sexp ptr, cpp11::sexp r_state, bool preserve_group_dimension) { - return dust2::r::dust2_system_set_state>(ptr, r_state, preserve_group_dimension); +SEXP dust2_system_sirode_set_state(cpp11::sexp ptr, cpp11::list r_state) { + return dust2::r::dust2_system_set_state>(ptr, r_state); } [[cpp11::register]] diff --git a/src/walk.cpp b/src/walk.cpp index aeb962f3..d037ff4b 100644 --- a/src/walk.cpp +++ b/src/walk.cpp @@ -1,4 +1,4 @@ -// Generated by dust2 (version 0.1.18) - do not edit +// Generated by dust2 (version 0.1.19) - do not edit #include @@ -116,8 +116,8 @@ SEXP dust2_system_walk_set_state_initial(cpp11::sexp ptr) { } [[cpp11::register]] -SEXP dust2_system_walk_set_state(cpp11::sexp ptr, cpp11::sexp r_state, bool preserve_group_dimension) { - return dust2::r::dust2_system_set_state>(ptr, r_state, preserve_group_dimension); +SEXP dust2_system_walk_set_state(cpp11::sexp ptr, cpp11::list r_state) { + return dust2::r::dust2_system_set_state>(ptr, r_state); } [[cpp11::register]] diff --git a/tests/testthat/test-interface-state.R b/tests/testthat/test-interface-state.R new file mode 100644 index 00000000..15be93e2 --- /dev/null +++ b/tests/testthat/test-interface-state.R @@ -0,0 +1,275 @@ +## There are many tests for checking that state conforms to a system, +## so they're all here so that we can have a nice set of labelled +## tests. This is quite repetitive (as is the implementation) but the +## aim is to convey some hint to the user about what they have done +## wrong. + +test_that("can provide a vector to set into a vector system", { + expect_equal( + prepare_state(1:3, NULL, NULL, NULL, 3, 1, 1, FALSE, FALSE), + list(state = 1:3, + index_state = NULL, + index_particle = NULL, + index_group = NULL, + recycle_particle = FALSE, + recycle_group = FALSE)) +}) + +test_that("can provide a vector with index to set into vector system", { + expect_equal( + prepare_state(1:3, 3:5, NULL, NULL, 5, 1, 1, FALSE, FALSE), + list(state = 1:3, + index_state = 3:5, + index_particle = NULL, + index_group = NULL, + recycle_particle = FALSE, + recycle_group = FALSE)) +}) + + +test_that("error if incorrect data provided to vector system", { + expect_error( + prepare_state(cbind(1:3), NULL, NULL, NULL, 3, 1, 1, FALSE, FALSE), + "Expected 'state' to be a vector but was given a matrix") + expect_error( + prepare_state(array(1:3, c(3, 1, 1)), NULL, NULL, NULL, 3, 1, 1, + FALSE, FALSE), + "Expected 'state' to be a vector but was given a 3-dimensional array") +}) + + +test_that("validate that the index provided for state is reasonable", { + expect_error( + prepare_state(1:4, 3:6, NULL, NULL, 5, 1, 1, FALSE, FALSE, name = "state"), + "All elements of 'index_state' must be at most 5") +}) + + +## TODO: hint here about if we have provided an index, and what we +## did recieve. +test_that("error if incorrect length data provided to vector system", { + expect_error( + prepare_state(1:3, NULL, NULL, NULL, 5, 1, 1, FALSE, FALSE, name = "state"), + "Expected 'state' to have length 5") + expect_error( + prepare_state(1:3, 3:6, NULL, NULL, 6, 1, 1, FALSE, FALSE, name = "state"), + "Expected 'state' to have length 4") +}) + + +test_that("can provide a vector to set into matrix (s x p) system", { + expect_equal( + prepare_state(1:3, NULL, NULL, NULL, 3, 5, 1, TRUE, FALSE), + list(state = 1:3, + index_state = NULL, + index_particle = NULL, + index_group = NULL, + recycle_particle = TRUE, + recycle_group = FALSE)) +}) + + +test_that("can provide a matrix to set into a matrix (s x p) sytem", { + expect_equal( + prepare_state(cbind(1:3), NULL, NULL, NULL, 3, 5, 1, TRUE, FALSE), + list(state = cbind(1:3), + index_state = NULL, + index_particle = NULL, + index_group = NULL, + recycle_particle = TRUE, + recycle_group = FALSE)) + state <- matrix(1:15, 3) + expect_equal( + prepare_state(matrix(1:15, 3), NULL, NULL, NULL, 3, 5, 1, TRUE, FALSE), + list(state = state, + index_state = NULL, + index_particle = NULL, + index_group = NULL, + recycle_particle = FALSE, + recycle_group = FALSE)) +}) + + +test_that("can provide state with state index into matrix system", { + expect_equal( + prepare_state(1:3, 3:5, NULL, NULL, 7, 5, 1, TRUE, FALSE), + list(state = 1:3, + index_state = 3:5, + index_particle = NULL, + index_group = NULL, + recycle_particle = TRUE, + recycle_group = FALSE)) + state <- matrix(1:15, 3:5) + expect_equal( + prepare_state(state, 3:5, NULL, NULL, 7, 5, 1, TRUE, FALSE), + list(state = state, + index_state = 3:5, + index_particle = NULL, + index_group = NULL, + recycle_particle = FALSE, + recycle_group = FALSE)) +}) + + +test_that("can provide state with particle index into matrix system", { + expect_equal( + prepare_state(1:3, NULL, 2:3, NULL, 3, 5, 1, TRUE, FALSE), + list(state = 1:3, + index_state = NULL, + index_particle = 2:3, + index_group = NULL, + recycle_particle = TRUE, + recycle_group = FALSE)) + state <- matrix(1:6, 3) + expect_equal( + prepare_state(state, NULL, 2:3, NULL, 3, 5, 1, TRUE, FALSE), + list(state = state, + index_state = NULL, + index_particle = 2:3, + index_group = NULL, + recycle_particle = FALSE, + recycle_group = FALSE)) +}) + + +test_that("validate that the index provided for particle is reasonable", { + expect_error( + prepare_state(1:4, NULL, 5:6, NULL, 4, 2, 1, TRUE, FALSE), + "All elements of 'index_particle' must be at most 2") + expect_error( + prepare_state(1:4, NULL, c(-1, 0), NULL, 4, 2, 1, TRUE, FALSE), + "All elements of 'index_particle' must be at least 1") + expect_error( + prepare_state(1:4, NULL, c(1, 1, 2), NULL, 4, 2, 1, TRUE, FALSE), + "All elements of 'index_particle' must be distinct") +}) + + +test_that("can set state from vector", { + sys <- dust_system_create(sir(), list(), n_particles = 3) + dust_system_set_state(sys, as.numeric(1:5)) + expect_equal(dust_system_state(sys), matrix(1:5, 5, 3)) +}) + + +test_that("can set state from matrix", { + sys <- dust_system_create(sir(), list(), n_particles = 3) + m <- matrix(runif(15), 5, 3) + dust_system_set_state(sys, m) + expect_equal(dust_system_state(sys), m) +}) + + +test_that("can set a fraction of states", { + sys <- dust_system_create(sir(), list(), n_particles = 3) + m <- matrix(runif(15), 5, 3) + dust_system_set_state(sys, m) + m2 <- matrix(seq(16, length.out = 6), 2, 3) + dust_system_set_state(sys, m2, index_state = c(2, 4)) + m[c(2, 4), ] <- m2 + expect_equal(dust_system_state(sys), m) +}) + + +test_that("can set a fraction of states from a vector", { + sys <- dust_system_create(sir(), list(), n_particles = 3) + m <- matrix(runif(15), 5, 3) + dust_system_set_state(sys, m) + m2 <- c(16, 17) + dust_system_set_state(sys, m2, index_state = c(2, 4)) + m[c(2, 4), ] <- m2 + expect_equal(dust_system_state(sys), m) +}) + + +test_that("can set a fraction of states from a scalar", { + sys <- dust_system_create(sir(), list(), n_particles = 3) + m <- matrix(runif(15), 5, 3) + dust_system_set_state(sys, m) + m2 <- 16 + dust_system_set_state(sys, m2, index_state = 2) + m[2, ] <- m2 + expect_equal(dust_system_state(sys), m) +}) + + +test_that("can set a fraction of particles", { + sys <- dust_system_create(sir(), list(), n_particles = 6) + m <- matrix(as.numeric(1:30), 5, 6) + dust_system_set_state(sys, m) + m2 <- matrix(seq(31, length.out = 15), 5, 3) + dust_system_set_state(sys, m2, index_particle = c(2, 4, 6)) + m[, c(2, 4, 6)] <- m2 + expect_equal(dust_system_state(sys), m) +}) + + +test_that("can set a fraction of particles from a vector", { + sys <- dust_system_create(sir(), list(), n_particles = 6) + m <- matrix(as.numeric(1:30), 5, 6) + dust_system_set_state(sys, m) + m2 <- seq(31, length.out = 5) + dust_system_set_state(sys, m2, index_particle = c(2, 4, 6)) + m[, c(2, 4, 6)] <- m2 + expect_equal(dust_system_state(sys), m) +}) + + +test_that("can set a fraction of both particles and states", { + sys <- dust_system_create(sir(), list(), n_particles = 6) + m <- matrix(as.numeric(1:30), 5, 6) + dust_system_set_state(sys, m) + m2 <- matrix(seq(31, length.out = 6), 2, 3) + dust_system_set_state(sys, m2, + index_state = c(2, 4), + index_particle = c(2, 4, 6)) + m[c(2, 4), c(2, 4, 6)] <- m2 + expect_equal(dust_system_state(sys), m) +}) + + +test_that("can set vector into grouped sytem", { + sys <- dust_system_create(walk(), rep(list(list(len = 5, sd = 1)), 3), + n_particles = 4, + n_groups = 3) + dust_system_set_state(sys, 1:5) + expect_equal(dust_system_state(sys), array(1:5, c(5, 4, 3))) +}) + + +test_that("can set matrix into grouped sytem", { + sys <- dust_system_create(walk(), rep(list(list(len = 5, sd = 1)), 3), + n_particles = 4, + n_groups = 3) + m <- matrix(1:20, 5, 4) + dust_system_set_state(sys, m) + expect_equal(dust_system_state(sys), array(m, c(5, 4, 3))) +}) + + +test_that("can set array into grouped sytem", { + sys <- dust_system_create(walk(), rep(list(list(len = 5, sd = 1)), 3), + n_particles = 4, + n_groups = 3) + m <- array(1:60, c(5, 4, 3)) + dust_system_set_state(sys, m) + expect_equal(dust_system_state(sys), m) +}) + + +test_that("can set subset into grouped system", { + sys <- dust_system_create(walk(), rep(list(list(len = 5, sd = 1)), 3), + n_particles = 4, + n_groups = 3) + m1 <- array(1:60, c(5, 4, 3)) + dust_system_set_state(sys, m1) + + i <- c(2, 4) + j <- c(1, 3) + k <- c(2, 3) + m2 <- array(100 + seq_len(8), c(2, 2, 2)) + dust_system_set_state(sys, m2, index_state = i, index_particle = j, + index_group = k) + m1[i, j, k] <- m2 + expect_equal(dust_system_state(sys), m1) +}) diff --git a/tests/testthat/test-sir.R b/tests/testthat/test-sir.R index 32353e96..654de6e3 100644 --- a/tests/testthat/test-sir.R +++ b/tests/testthat/test-sir.R @@ -296,3 +296,13 @@ test_that("can unpack state", { expect_equal(obj$packer_state$unpack(m), list(S = 990, I = 10, R = 0, cases_cumul = 0, cases_inc = 0)) }) + + +test_that("can set state from a vector", { + set.seed(1) + pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6) + sys <- dust_system_create(sir(), pars, n_particles = 10) + s <- c(1000, 10, 0, 0, 0) + dust_system_set_state(sys, s) + expect_equal(dust_system_state(sys), matrix(s, 5, 10)) +}) diff --git a/tests/testthat/test-unfilter.R b/tests/testthat/test-unfilter.R index 0253e80b..eb7428a0 100644 --- a/tests/testthat/test-unfilter.R +++ b/tests/testthat/test-unfilter.R @@ -95,7 +95,7 @@ test_that("can get partial unfilter history", { test_that("can run an unfilter with manually set state", { pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6) - state <- matrix(c(1000 - 17, 17, 0, 0, 0), ncol = 1) + state <- c(1000 - 17, 17, 0, 0, 0) time_start <- 0 data <- data.frame(time = c(4, 8, 12, 16), incidence = 1:4) diff --git a/tests/testthat/test-walk.R b/tests/testthat/test-walk.R index 8a8be44b..b98b17af 100644 --- a/tests/testthat/test-walk.R +++ b/tests/testthat/test-walk.R @@ -309,13 +309,13 @@ test_that("can set state where n_state > 1", { ## Appropriate errors: expect_error( dust_system_set_state(obj, matrix(0, 2, 10)), - "Expected the first dimension of 'state' to have size 3") + "Expected 'state' to have 3 rows") expect_error( dust_system_set_state(obj, matrix(0, 3, 15)), - "Expected the second dimension of 'state' to have size 10 or 1") + "Expected 'state' to have 1 or 10 columns") expect_error( dust_system_set_state(obj, array(0, c(3, 10, 1))), - "Expected 'state' to be a 2d array") + "Expected 'state' to be a matrix") }) @@ -351,16 +351,16 @@ test_that("can set state where n_state > 1 and groups are present", { ## Appropriate errors: expect_error( dust_system_set_state(obj, array(0, c(2, 10, 4))), - "Expected the first dimension of 'state' to have size 3") + "Expected dimension 1 of 'state' to be length 3") expect_error( dust_system_set_state(obj, array(0, c(3, 15, 4))), - "Expected the second dimension of 'state' to have size 10 or 1") + "Expected dimension 2 of 'state' to be length 1 or 10") expect_error( dust_system_set_state(obj, array(0, c(3, 10, 2))), - "Expected the third dimension of 'state' to have size 4 or 1") + "Expected dimension 3 of 'state' to be length 1 or 4") expect_error( - dust_system_set_state(obj, array(0, c(3, 10))), - "Expected 'state' to be a 3d array") + dust_system_set_state(obj, array(0, c(3, 10, 2, 1))), + "Expected 'state' to be a 3-dimensional array") })