Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save history when running unfilter and filter #22

Merged
merged 11 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ S3method(dim,dust_model)
S3method(print,dust_model)
S3method(print,dust_model_generator)
export(dust_filter_create)
export(dust_filter_last_history)
export(dust_filter_rng_state)
export(dust_filter_run)
export(dust_model_compare_data)
Expand All @@ -20,5 +21,6 @@ export(dust_model_state)
export(dust_model_time)
export(dust_model_update_pars)
export(dust_unfilter_create)
export(dust_unfilter_last_history)
export(dust_unfilter_run)
useDynLib(dust2, .registration = TRUE)
24 changes: 16 additions & 8 deletions R/cpp11.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

69 changes: 59 additions & 10 deletions R/interface.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ dust_model <- function(name, env = parent.env(parent.frame())) {
has_compare = !is.null(methods$compare_data))

if (properties$has_compare) {
methods_unfilter <- c("alloc", "run")
methods_unfilter <- c("alloc", "run", "last_history")
methods$unfilter <-
get_methods(methods_unfilter, sprintf("%s_unfilter", name))
methods_filter <- c("alloc", "run", "rng_state")
methods_filter <- c("alloc", "run", "last_history", "rng_state")
methods$filter <-
get_methods(methods_filter, sprintf("%s_filter", name))
}
Expand Down Expand Up @@ -376,7 +376,7 @@ dust_model_compare_data <- function(model, data) {
##' @export
dust_unfilter_create <- function(generator, pars, time_start, time, data,
n_particles = 1, n_groups = 0,
dt = 1) {
dt = 1, index = NULL) {
check_is_dust_model_generator(generator)
if (!generator$properties$has_compare) {
## This moves into something general soon?
Expand All @@ -386,12 +386,13 @@ dust_unfilter_create <- function(generator, pars, time_start, time, data,
arg = "generator")
}
res <- generator$methods$unfilter$alloc(pars, time_start, time, dt, data,
n_particles, n_groups)
n_particles, n_groups, index)
res$name <- generator$name
res$n_particles <- as.integer(n_particles)
res$n_groups <- as.integer(max(n_groups), 1)
res$deterministic <- TRUE
res$methods <- generator$methods$unfilter
res$index <- index
class(res) <- "dust_unfilter"
res
}
Expand All @@ -401,6 +402,8 @@ dust_unfilter_create <- function(generator, pars, time_start, time, data,
##'
##' @title Run unfilter
##'
##' @inheritParams dust_filter_run
##'
##' @param unfilter A `dust_unfilter` object, created by
##' [dust_unfilter_create]
##'
Expand All @@ -410,9 +413,28 @@ dust_unfilter_create <- function(generator, pars, time_start, time, data,
##' there are groups.
##'
##' @export
dust_unfilter_run <- function(unfilter, pars = NULL, initial = NULL) {
dust_unfilter_run <- function(unfilter, pars = NULL, initial = NULL,
save_history = FALSE) {
check_is_dust_unfilter(unfilter)
unfilter$methods$run(unfilter$ptr, pars, initial, save_history,
unfilter$grouped)
}


##' Fetch the last history created by running an unfilter. This
##' errors if the last call to [dust_unfilter_run] did not use
##' `save_history = TRUE`.
##'
##' @title Fetch last unfilter history
##'
##' @inheritParams dust_unfilter_run
##'
##' @return An array
##'
##' @export
dust_unfilter_last_history <- function(unfilter) {
check_is_dust_unfilter(unfilter)
unfilter$methods$run(unfilter$ptr, pars, initial, unfilter$grouped)
unfilter$methods$last_history(unfilter$ptr, unfilter$grouped)
}


Expand Down Expand Up @@ -451,14 +473,15 @@ dust_unfilter_run <- function(unfilter, pars = NULL, initial = NULL) {
##' slowly.
##'
##' @inheritParams dust_model_create
##' @inheritParams dust_model_simulate
##'
##' @return A `dust_unfilter` object, which can be used with
##' [dust_unfilter_run]
##'
##' @export
dust_filter_create <- function(generator, pars, time_start, time, data,
n_particles, n_groups = 0, dt = 1,
seed = NULL) {
index = NULL, seed = NULL) {
check_is_dust_model_generator(generator)
if (!generator$properties$has_compare) {
## This moves into something general soon?
Expand All @@ -468,7 +491,7 @@ dust_filter_create <- function(generator, pars, time_start, time, data,
arg = "generator")
}
res <- generator$methods$filter$alloc(pars, time_start, time, dt, data,
n_particles, n_groups, seed)
n_particles, n_groups, index, seed)
res$name <- generator$name
res$n_particles <- as.integer(n_particles)
res$n_groups <- as.integer(max(n_groups), 1)
Expand All @@ -493,13 +516,39 @@ dust_filter_create <- function(generator, pars, time_start, time, data,
##' particle) or 3d array (state x particle x group). If not
##' provided, the model initial conditions are used.
##'
##' @param save_history Logical, indicating if the simulation history
##' should be saved while the simulation runs; this has a small
##' overhead in runtime and in memory. History (particle
##' trajectories) will be saved at each time in the filter. If the
##' filter was constructed using a non-`NULL` `index` parameter,
##' the history is restricted to these states.
##'
##' @return A vector of likelihood values, with as many elements as
##' there are groups.
##'
##' @export
dust_filter_run <- function(filter, pars = NULL, initial = NULL) {
dust_filter_run <- function(filter, pars = NULL, initial = NULL,
save_history = FALSE) {
check_is_dust_filter(filter)
filter$methods$run(filter$ptr, pars, initial, save_history,
filter$grouped)
}


##' Fetch the last history created by running a filter. This
##' errors if the last call to [dust_filter_run] did not use
##' `save_history = TRUE`.
##'
##' @title Fetch last filter history
##'
##' @inheritParams dust_filter_run
##'
##' @return An array
##'
##' @export
dust_filter_last_history <- function(filter) {
check_is_dust_filter(filter)
filter$methods$run(filter$ptr, pars, initial, filter$grouped)
filter$methods$last_history(filter$ptr, filter$grouped)
}


Expand Down
2 changes: 2 additions & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ reference:
contents:
- dust_unfilter_create
- dust_unfilter_run
- dust_unfilter_last_history
- subtitle: Particle filter
contents:
- dust_filter_create
- dust_filter_run
- dust_filter_last_history
- dust_filter_rng_state


Loading
Loading