Skip to content

Commit

Permalink
Merge pull request #16 from mrc-ide/mrc-5389
Browse files Browse the repository at this point in the history
Proof-of-concept interface
  • Loading branch information
richfitz authored May 22, 2024
2 parents 49624b8 + cb8fe62 commit a214e63
Show file tree
Hide file tree
Showing 15 changed files with 595 additions and 199 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@
\.*gcov$
^.*\.Rproj$
^\.Rproj\.user$
^_pkgdown\.yml$
6 changes: 4 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
Language: en-GB
Config/testthat/edition: 3
URL: https://github.com/mrc-ide/dust2
URL: https://github.com/mrc-ide/dust2, https://mrc-ide.github.io/dust2
BugReports: https://github.com/mrc-ide/dust2/issues
Imports:
mcstate2
cli,
mcstate2,
rlang
LinkingTo:
cpp11,
mcstate2
Expand Down
7 changes: 7 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# Generated by roxygen2: do not edit by hand

S3method(dim,dust_model)
S3method(print,dust_model)
S3method(print,dust_model_generator)
export(dust_model_create)
export(dust_model_set_state)
export(dust_model_set_state_initial)
export(dust_model_state)
useDynLib(dust2, .registration = TRUE)
8 changes: 8 additions & 0 deletions R/dust.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
sir <- function() {
dust_model("sir")
}


walk <- function() {
dust_model("walk")
}
202 changes: 202 additions & 0 deletions R/interface.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
dust_model <- function(name, env = parent.env(parent.frame())) {
prefix <- sprintf("dust2_cpu_%s", name)
## I don't love that this requires running through sprintf() each
## time we create a model, but using a function for the model (see
## sir()), rather than an object, means that it's easier to think
## about the dependencies among packages. This is also essentially
## how DBI works.

methods_nms <- c("alloc",
"state", "set_state", "set_state_initial")

methods <- lapply(sprintf("dust2_cpu_%s_%s", name, methods_nms),
function(x) env[[x]])
names(methods) <- methods_nms
ret <- list(name = name,
methods = methods)
## TODO: check that alloc exists, then go through and add
## properties.
class(ret) <- "dust_model_generator"
ret
}


##' Create a dust object from a model generator. This allocates a
##' model and sets an initial set of parameters. Once created you can
##' use other dust functions to interact with it.
##'
##' @title Create a dust object
##'
##' @param generator A model generator object, with class
##' `dust_model_generator`
##'
##' @param pars A list of parameters. The format of this will depend
##' on the model. If `n_groups` is 1 or more, then this must be a
##' list of length `n_groups` where each element
##' is a list of parameters for your model.
##'
##' @param n_particles The number of particles to create.
##'
##' @param n_groups Optionally, the number of parameter groups
##'
##' @param time The initial time, defaults to 0
##'
##' @param dt The time step for the model, defaults to 1
##'
##' @param seed Optionally, a seed. Otherwise we respond to R's RNG
##' seed on initialisation.
##'
##' @param deterministic Logical, indicating if the model should be
##' allocated in deterministic mode.
##'
##' @return A `dust_model` object, with opaque format.
##'
##' @export
dust_model_create <- function(generator, pars, n_particles, n_groups = 0,
time = 0, dt = 1,
seed = NULL, deterministic = FALSE) {
check_is_dust_model_generator(generator, substitute(generator))
res <- generator$methods$alloc(pars, time, dt, n_particles, n_groups,
seed, deterministic)
## Here, we augment things slightly
res$name <- generator$name
res$n_particles <- as.integer(n_particles)
res$n_groups <- as.integer(max(n_groups), 1)
res$deterministic <- deterministic
res$methods <- generator$methods
res$properties <- generator$properties
class(res) <- "dust_model"
res
}


##' Extract model state
##'
##' @title Extract model state
##'
##' @param model A `dust_model` object
##'
##' @return An array of model state. If your model is ungrouped, then
##' this has two dimensions (state, particle). If grouped, this has
##' three dimensions (state, particle, group)
##'
##' @seealso [dust_model_set_state()] for setting state and
##' [dust_model_set_state_initial()] for setting state to the
##' model-specific initial conditions.
##'
##' @export
dust_model_state <- function(model) {
check_is_dust_model(model)
model$methods$state(model$ptr, model$grouped)
}


##' Set model state. Takes a multidimensional array (2- or 3d
##' depending on if the model is grouped or not). Dimensions of
##' length 1 will be recycled as appropriate.
##'
##' @title Set model state
##'
##' @inheritParams dust_model_state
##'
##' @param state A matrix or array of state. If ungrouped, the
##' dimension order expected is state x particle. If grouped the
##' order is state x particle x group.
##'
##' @return Nothing, called for side effects only
##' @export
dust_model_set_state <- function(model, state) {
check_is_dust_model(model)
model$methods$set_state(model$ptr, state, model$grouped)
invisible()
}


##' Set model state from a model's initial conditions. This may depend
##' on the current time.
##'
##' @title Set model state to initial conditions
##'
##' @inheritParams dust_model_state
##'
##' @return Nothing, called for side effects only
##' @export
dust_model_set_state_initial <- function(model) {
check_is_dust_model(model)
model$methods$set_state_initial(model$ptr)
invisible()
}


##' @export
print.dust_model_generator <- function(x, ...) {
cli::cli_h1("<dust_model_generator: {x$name}>")
## Later, we might print some additional capabilities of the model
## here, such as if it can be used with a filter, a summary of its
## parameters (once we know how to access that), etc.
cli::cli_alert_info(
"Use 'dust2::dust_model_create()' to create a model with this generator")
invisible(x)
}


##' @export
print.dust_model <- function(x, ...) {
cli::cli_h1("<dust_model: {x$name}>")
if (x$grouped) {
cli::cli_alert_info(paste(
"{x$n_state} state x {x$n_particles} particle{?s} x",
"{x$n_groups} group{?s}"))
} else {
cli::cli_alert_info("{x$n_state} state x {x$n_particles} particle{?s}")
}
if (x$deterministic) {
cli::cli_bullets(c(
i = "This model is deterministic"))
}
## Later, we might print some additional capabilities of the model
## here, such as if it can be used with a filter, a summary of its
## parameters (once we know how to access that), etc.
invisible(x)
}


##' @export
dim.dust_model <- function(x, ...) {
c(x$n_state, x$n_particles, if (x$grouped) x$n_groups)
}


check_is_dust_model_generator <- function(generator, called_as,
call = parent.frame()) {
if (!inherits(generator, "dust_model_generator")) {
hint <- NULL
if (is_uncalled_generator(generator) && is.symbol(called_as)) {
hint <- c(
i = "Did you mean '{deparse(called_as)}()' (i.e., with parentheses)")
}
cli::cli_abort(
c("Expected 'generator' to be a 'dust_model_generator' object",
hint),
arg = "generator")
}
}


check_is_dust_model <- function(model, call = parent.frame()) {
if (!inherits(model, "dust_model")) {
cli::cli_abort("Expected 'model' to be a 'dust_model' object",
arg = "model", call = call)
}
}


is_uncalled_generator <- function(model) {
if (!is.function(model)) {
return(FALSE)
}
code <- body(model)
rlang::is_call(code, "{") &&
length(code) == 2 &&
rlang::is_call(code[[2]], "dust_model")
}
12 changes: 12 additions & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
url: https://mrc-ide.github.io/dust2

template:
bootstrap: 5

reference:
- title: Models
contents:
- dust_model_create
- dust_model_state
- dust_model_set_state
- dust_model_set_state_initial
7 changes: 6 additions & 1 deletion inst/include/dust2/r/cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ SEXP dust2_cpu_alloc(cpp11::list r_pars,
cpp11::sexp r_n_state = cpp11::as_sexp(obj->n_state());
cpp11::sexp r_grouped = cpp11::as_sexp(grouped);

return cpp11::writable::list{ptr, r_n_state, r_grouped, r_group_names};
using namespace cpp11::literals;
return cpp11::writable::list{"ptr"_nm = ptr,
"n_state"_nm = r_n_state,
"grouped"_nm = r_grouped,
"group_names"_nm = r_group_names
};
}

template <typename T>
Expand Down
48 changes: 48 additions & 0 deletions man/dust_model_create.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions man/dust_model_set_state.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions man/dust_model_set_state_initial.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions man/dust_model_state.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit a214e63

Please sign in to comment.