Skip to content

Commit

Permalink
Merge pull request #4 from mrc-ide/mrc-5327
Browse files Browse the repository at this point in the history
Allow fractional time
  • Loading branch information
richfitz authored May 10, 2024
2 parents 7b6c142 + c690ed8 commit 8db8a6a
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 13 deletions.
4 changes: 4 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ dust2_cpu_walk_set_state <- function(ptr, r_state) {
dust2_cpu_walk_rng_state <- function(ptr) {
.Call(`_dust2_dust2_cpu_walk_rng_state`, ptr)
}

dust2_cpu_walk_set_time <- function(ptr, r_time) {
.Call(`_dust2_dust2_cpu_walk_set_time`, ptr, r_time)
}
12 changes: 5 additions & 7 deletions inst/include/dust2/cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ class dust_cpu {
// We don't check that the size is the same across all states;
// this should be done by the caller (similarly, we don't check
// that shared and internal have the same size).
if (dt != 1) {
throw std::runtime_error("Requiring dt = 1 for now");
}
}

auto run_steps(size_t n_steps) {
Expand Down Expand Up @@ -97,13 +94,14 @@ class dust_cpu {
return state_;
}

// Fairly useless getter/setter - we might be better exposing time
// directly as a field. However, for the MPI and GPU version this
// will almost certainly do something.
auto time() const {
return time_;
}

void set_time(real_type time) {
// TODO: some validation needs to be done here to deal with
// offsets relative to dt.
time_ = time;
}

Expand Down Expand Up @@ -133,11 +131,11 @@ class dust_cpu {
real_type * state, rng_state_type& rng_state,
real_type * state_next) {
for (size_t i = 0; i < n_steps; ++i) {
T::update(time, dt, state, shared, internal, rng_state, state_next);
T::update(time + i * dt, dt, state, shared, internal, rng_state,
state_next);
std::swap(state, state_next);
}
}
};


}
13 changes: 11 additions & 2 deletions inst/include/dust2/r/cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ SEXP dust2_cpu_alloc(cpp11::list r_pars,
using internal_state = typename T::internal_state;
using rng_state_type = typename T::rng_state_type;

auto time = to_double(r_time, "time");
auto dt = to_double(r_dt, "r_dt");
const auto time = check_time(r_time);
const auto dt = check_dt(r_dt);

auto n_particles = to_size(r_n_particles, "n_particles");
auto n_groups = to_size(r_n_groups, "n_groups");

Expand Down Expand Up @@ -114,5 +115,13 @@ SEXP dust2_cpu_rng_state(cpp11::sexp ptr) {
return ret;
}

template <typename T>
SEXP dust2_cpu_set_time(cpp11::sexp ptr, cpp11::sexp r_time) {
auto *obj = cpp11::as_cpp<cpp11::external_pointer<dust_cpu<T>>>(ptr).get();
const auto time = check_time(r_time);
obj->set_time(time);
return R_NilValue;
}

}
}
31 changes: 31 additions & 0 deletions inst/include/dust2/r/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,37 @@ inline bool to_bool(cpp11::sexp x, const char * name) {
cpp11::stop("'%s' must be scalar logical", name);
}

template <typename T>
bool is_integer_like(T x, T eps) {
return std::abs(x - round(x)) <= eps;
}

inline double check_time(cpp11::sexp r_time) {
const auto time = to_double(r_time, "time");
const auto eps = 1e-8;
// We can relax this later and carefully align time onto a grid
if (!is_integer_like(time, eps)) {
throw std::runtime_error("Expected 'time' to be integer-like");
}
return time;
}

inline double check_dt(cpp11::sexp r_dt) {
const auto dt = to_double(r_dt, "dt");
const auto eps = 1e-8;
if (dt <= 0) {
cpp11::stop("Expected 'dt' to be greater than 0");
}
if (dt > 1) {
cpp11::stop("Expected 'dt' to be at most 1");
}
const auto inv_dt = 1 / dt;
if (!is_integer_like(inv_dt, eps)) {
throw std::runtime_error("Expected 'dt' to be the inverse of an integer");
}
return dt;
}

// template <typename real_type>
// inline cpp11::sexp to_matrix(std::vector<real_type> x, size_t nr, size_t nc) {
// cpp11::writable::integers dim{static_cast<int>(nr), static_cast<int>(nc)};
Expand Down
8 changes: 8 additions & 0 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ extern "C" SEXP _dust2_dust2_cpu_walk_rng_state(SEXP ptr) {
return cpp11::as_sexp(dust2_cpu_walk_rng_state(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr)));
END_CPP11
}
// walk.cpp
SEXP dust2_cpu_walk_set_time(cpp11::sexp ptr, cpp11::sexp r_time);
extern "C" SEXP _dust2_dust2_cpu_walk_set_time(SEXP ptr, SEXP r_time) {
BEGIN_CPP11
return cpp11::as_sexp(dust2_cpu_walk_set_time(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_time)));
END_CPP11
}

extern "C" {
static const R_CallMethodDef CallEntries[] = {
Expand All @@ -62,6 +69,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_dust2_dust2_cpu_walk_run_steps", (DL_FUNC) &_dust2_dust2_cpu_walk_run_steps, 2},
{"_dust2_dust2_cpu_walk_set_state", (DL_FUNC) &_dust2_dust2_cpu_walk_set_state, 2},
{"_dust2_dust2_cpu_walk_set_state_initial", (DL_FUNC) &_dust2_dust2_cpu_walk_set_state_initial, 1},
{"_dust2_dust2_cpu_walk_set_time", (DL_FUNC) &_dust2_dust2_cpu_walk_set_time, 2},
{"_dust2_dust2_cpu_walk_state", (DL_FUNC) &_dust2_dust2_cpu_walk_state, 1},
{"_dust2_dust2_cpu_walk_time", (DL_FUNC) &_dust2_dust2_cpu_walk_time, 1},
{NULL, NULL, 0}
Expand Down
5 changes: 5 additions & 0 deletions src/walk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,8 @@ SEXP dust2_cpu_walk_set_state(cpp11::sexp ptr, cpp11::sexp r_state) {
SEXP dust2_cpu_walk_rng_state(cpp11::sexp ptr) {
return dust2::r::dust2_cpu_rng_state<walk>(ptr);
}

[[cpp11::register]]
SEXP dust2_cpu_walk_set_time(cpp11::sexp ptr, cpp11::sexp r_time) {
return dust2::r::dust2_cpu_set_time<walk>(ptr, r_time);
}
50 changes: 46 additions & 4 deletions tests/testthat/test-walk.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,40 @@ test_that("can run deterministically", {
})


test_that("require that dt is 1 for now", {
test_that("Allow fractional dt", {
pars <- list(sd = 1, random_initial = TRUE)
obj <- dust2_cpu_walk_alloc(pars, 0, 0.5, 10, 0, 42, FALSE)
ptr <- obj[[1]]
expect_equal(dust2_cpu_walk_time(ptr), 0)
expect_null(dust2_cpu_walk_run_steps(ptr, 3))
expect_equal(dust2_cpu_walk_time(ptr), 1.5)
expect_null(dust2_cpu_walk_run_steps(ptr, 5))
expect_equal(dust2_cpu_walk_time(ptr), 4)
})


test_that("provided dt is reasonable", {
pars <- list(sd = 1, random_initial = TRUE)
expect_error(
dust2_cpu_walk_alloc(pars, 0, 0, 10, 0, 42, FALSE),
"Expected 'dt' to be greater than 0")
expect_error(
dust2_cpu_walk_alloc(pars, 0, -1, 10, 0, 42, FALSE),
"Expected 'dt' to be greater than 0")
expect_error(
dust2_cpu_walk_alloc(pars, 0, 1.5, 10, 0, 42, FALSE),
"Expected 'dt' to be at most 1")
expect_error(
dust2_cpu_walk_alloc(pars, 0, sqrt(2) / 2, 10, 0, 42, FALSE),
"Expected 'dt' to be the inverse of an integer")
})


test_that("time starts as an integer", {
pars <- list(sd = 1, random_initial = TRUE)
expect_error(
dust2_cpu_walk_alloc(pars, 0, 0.5, 10, 0, 42, FALSE),
"Requiring dt = 1 for now",
fixed = TRUE)
dust2_cpu_walk_alloc(pars, 1.5, 1, 10, 0, 42, FALSE),
"Expected 'time' to be integer-like")
})


Expand Down Expand Up @@ -183,3 +211,17 @@ test_that("require that parameter length matches requested number of groups", {
dust2_cpu_walk_alloc(pars, 0, 1, 10, 3, 42, FALSE),
"Expected 'pars' to have length 3 to match 'n_groups'")
})


test_that("can set time", {
pars <- list(sd = 1, random_initial = TRUE)
obj <- dust2_cpu_walk_alloc(pars, 0, 1, 10, 0, 42, FALSE)
ptr <- obj[[1]]
expect_equal(dust2_cpu_walk_time(ptr), 0)
expect_null(dust2_cpu_walk_set_time(ptr, 4))
expect_equal(dust2_cpu_walk_time(ptr), 4)
expect_null(dust2_cpu_walk_set_time(ptr, 0))
expect_equal(dust2_cpu_walk_time(ptr), 0)
expect_error(dust2_cpu_walk_set_time(ptr, 0.5),
"Expected 'time' to be integer-like")
})

0 comments on commit 8db8a6a

Please sign in to comment.