Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse attribute-style comments from source code #20

Merged
merged 6 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ URL: https://github.com/mrc-ide/dust2, https://mrc-ide.github.io/dust2
BugReports: https://github.com/mrc-ide/dust2/issues
Imports:
cli,
decor,
mcstate2,
rlang
LinkingTo:
cpp11,
mcstate2
Suggests:
testthat (>= 3.0.0)
testthat (>= 3.0.0),
withr
Remotes: mrc-ide/mcstate2
117 changes: 117 additions & 0 deletions R/metadata.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
## Each model will be marked up with decor-style comments, this allows
## us to describe a model without needing to try and parse anything,
## which is fraught at best with C++.
##
## The main things we need to know are:
##
## * what is the name of the class we are wrapping?
## * what should the name be of the model exported to R (if different)
## * what parameters does it accept (optional?)
## * does the model support comparison with data?
## * what data entries does it accept (optional?)
##
## Later, we'll want to do the same thing with GPU support if that
## looks different, and possibly with MPI support, though with those
## it might depend a bit on if we actually need much special there -
## it might be worth looking at the GPU stuff again fairly soon
## actually.
parse_metadata <- function(filename, call = NULL) {
data <- decor::cpp_decorations(files = filename)

class <- parse_metadata_class(data, call)
list(class = class,
name = parse_metadata_name(data, call) %||% class,
has_compare = parse_metadata_has_compare(data, call),
parameters = parse_metadata_parameters(data, call))
}


## All of the errors here could benefit from line number and context
## information, but it's not that important as users should never see
## these errors - models will mostly be written by odin
parse_metadata_class <- function(data, call = NULL) {
data <- find_attribute_value_single(data, "dust2::class", required = TRUE,
call = call)
if (length(data) != 1 || nzchar(names(data))) {
cli::cli_abort(
"Expected a single unnamed argument to '[[dust2::class()]]'",
call = call)
}
if (!is.symbol(data[[1]])) {
cli::cli_abort(
"Expected an unquoted string argument to '[[dust2::class()]]'",
call = call)
}
deparse(data[[1]])
}


parse_metadata_name <- function(data, call = NULL) {
data <- find_attribute_value_single(data, "dust2::name", required = FALSE,
call = call)
if (is.null(data)) {
return(NULL)
}
if (length(data) != 1 || nzchar(names(data))) {
cli::cli_abort(
"Expected a single unnamed argument to '[[dust2::name()]]'",
call = call)
}
if (!is.symbol(data[[1]])) {
cli::cli_abort(
"Expected an unquoted string argument to '[[dust2::name()]]'",
call = call)
}
deparse(data[[1]])
}


parse_metadata_has_compare <- function(data, call = NULL) {
data <- find_attribute_value_single(data, "dust2::has_compare",
required = FALSE, call = call)
if (is.null(data)) {
return(FALSE)
}
if (length(data) != 0) {
cli::cli_abort(
"Expected no arguments to '[[dust2::has_compare()]]'",
call = call)
}
TRUE
}


parse_metadata_parameters <- function(data, call = NULL) {
res <- data$params[data$decoration == "dust2::parameter"]
ok <- vlapply(res, function(x) {
length(x) == 1 && !nzchar(names(x)[[1]]) && is.symbol(x[[1]])
})
if (!all(ok)) {
cli::cli_abort(
paste("Expected an unnamed unquoted string argument to",
"'[[dust2::parameter()]]'"),
call = call)
}
data_frame(name = vcapply(res, function(x) deparse(x[[1]])))
}


find_attribute_value_single <- function(data, name, required, call = NULL) {
i <- data$decoration == name
if (!any(i)) {
if (required) {
cli::cli_abort(
"Attribute '[[{name}()]]' is required, but was not found",
call = call)
}
return(NULL)
}

if (sum(i) > 1) {
cli::cli_abort(
"More than one '[[{name}()]]' attribute found",
call = call)
}

data$params[[which(i)]]
}
20 changes: 20 additions & 0 deletions R/util.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,23 @@ set_names <- function(x, nms) {
names(x) <- nms
x
}


vlapply <- function(...) {
vapply(..., FUN.VALUE = TRUE)
}


vcapply <- function(...) {
vapply(..., FUN.VALUE = "")
}


data_frame <- function(...) {
data.frame(..., stringsAsFactors = FALSE, check.names = FALSE)
}


dust2_file <- function(path) {
system.file(path, mustWork = TRUE, package = "dust2")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That mustWork arg name always amuses me. I might write some args not_that_bothered_if_it_works in some of my functions.

}
37 changes: 18 additions & 19 deletions inst/examples/sir.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#include <dust2/common.hpp>

namespace {
inline double with_default(double default_value, cpp11::sexp value) {
return value == R_NilValue ? default_value : cpp11::as_cpp<double>(value);
}
}

// [[dust2::class(sir)]]
// [[dust2::has_compare()]]
// [[dust2::parameter(I0)]]
// [[dust2::parameter(N)]]
// [[dust2::parameter(beta)]]
// [[dust2::parameter(gamma)]]
// [[dust2::parameter(exp_noise)]]
class sir {
public:
sir() = delete;
Expand Down Expand Up @@ -70,14 +71,11 @@ class sir {
}

static shared_state build_shared(cpp11::list pars) {
const real_type I0 = with_default(10, pars["I0"]);
const real_type N = with_default(1000, pars["N"]);

const real_type beta = with_default(0.2, pars["beta"]);
const real_type gamma = with_default(0.1, pars["gamma"]);

const real_type exp_noise = with_default(1e6, pars["exp_noise"]);

const real_type I0 = dust2::r::read_real(pars, "I0", 10);
const real_type N = dust2::r::read_real(pars, "N", 1000);
const real_type beta = dust2::r::read_real(pars, "beta", 0.2);
const real_type gamma = dust2::r::read_real(pars, "gamma", 0.1);
const real_type exp_noise = dust2::r::read_real(pars, "exp_noise", 1e6);
return shared_state{N, I0, beta, gamma, exp_noise};
}

Expand All @@ -88,9 +86,9 @@ class sir {
// This is the bit that we'll use to do fast parameter updating, and
// we'll guarantee somewhere that the size does not change.
static void update_shared(cpp11::list pars, shared_state& shared) {
shared.I0 = with_default(10, pars["I0"]);
shared.beta = with_default(0.2, pars["beta"]);
shared.gamma = with_default(0.1, pars["gamma"]);
shared.I0 = dust2::r::read_real(pars, "I0", shared.I0);
shared.beta = dust2::r::read_real(pars, "beta", shared.beta);
shared.gamma = dust2::r::read_real(pars, "gamma", shared.gamma);
}

// This is a reasonable default implementation in the no-internal
Expand All @@ -99,9 +97,10 @@ class sir {
internal_state& internal) {
}

static data_type build_data(cpp11::sexp r_data) {
static data_type build_data(cpp11::list r_data) {
auto data = static_cast<cpp11::list>(r_data);
return data_type{cpp11::as_cpp<real_type>(data["incidence"])};
auto incidence = dust2::r::read_real(data, "incidence");
return data_type{incidence};
}

static real_type compare_data(const real_type time,
Expand Down
21 changes: 9 additions & 12 deletions inst/examples/walk.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#include <dust2/common.hpp>

// [[dust2::class(walk)]]
// [[dust2::parameter(sd)]]
// [[dust2::parameter(len)]]
// [[dust2::parameter(random_initial)]]
class walk {
public:
// No constructor - turning this off is optional
Expand Down Expand Up @@ -41,10 +45,7 @@ class walk {
// This is the bit that we'll use to do fast parameter updating, and
// we'll guarantee somewhere that the size does not change.
static void update_shared(cpp11::list pars, shared_state& shared) {
const cpp11::sexp r_sd = pars["sd"];
if (r_sd != R_NilValue) {
shared.sd = cpp11::as_cpp<walk::real_type>(pars["sd"]);
}
shared.sd = dust2::r::read_real(pars, "sd", shared.sd);
}

// This is a reasonable default implementation in the no-internal
Expand Down Expand Up @@ -87,14 +88,10 @@ class walk {

// Then, rather than a constructor we have some converters:
static shared_state build_shared(cpp11::list pars) {
size_t len = 1;
const cpp11::sexp r_len = pars["len"];
if (r_len != R_NilValue) {
len = cpp11::as_cpp<int>(r_len);
}
const walk::real_type sd = cpp11::as_cpp<walk::real_type>(pars["sd"]);
const bool random_initial = pars["random_initial"] == R_NilValue ? false :
cpp11::as_cpp<bool>(pars["random_initial"]);
const auto len = dust2::r::read_size(pars, "len", 1);
const auto sd = dust2::r::read_real(pars, "sd", 1);
const auto random_initial =
dust2::r::read_bool(pars, "random_initial", false);
return shared_state{len, sd, random_initial};
}

Expand Down
22 changes: 22 additions & 0 deletions inst/include/dust2/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,26 @@ struct no_data {};
struct no_internal_state {};
struct no_shared_state {};

namespace r {

// The actual definitions of these are elsewhere, but these are
// functions that models may use so we declare them here.
inline double read_real(cpp11::list pars, const char * name);
inline double read_real(cpp11::list pars, const char * name,
double default_value);

inline int read_int(cpp11::list pars, const char * name);
inline int read_int(cpp11::list pars, const char * name,
int default_value);

inline size_t read_size(cpp11::list pars, const char * name);
inline size_t read_size(cpp11::list pars, const char * name,
size_t default_value);

inline bool read_bool(cpp11::list pars, const char * name);
inline bool read_bool(cpp11::list pars, const char * name,
bool default_value);

}

}
7 changes: 4 additions & 3 deletions inst/include/dust2/r/cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,18 @@ SEXP dust2_cpu_compare_data(cpp11::sexp ptr,
auto *obj = cpp11::as_cpp<cpp11::external_pointer<dust_cpu<T>>>(ptr).get();
const auto n_groups = obj->n_groups();
std::vector<data_type> data;
auto r_data_list = cpp11::as_cpp<cpp11::list>(r_data);
if (grouped) {
auto r_data_list = cpp11::as_cpp<cpp11::list>(r_data);
check_length(r_data_list, n_groups, "data");
for (size_t i = 0; i < n_groups; ++i) {
data.push_back(T::build_data(r_data_list[i]));
auto r_data_list_i = cpp11::as_cpp<cpp11::list>(r_data_list[i]);
data.push_back(T::build_data(r_data_list_i));
}
} else {
if (n_groups > 1) {
cpp11::stop("Can't compare with grouped = FALSE with more than one group");
}
data.push_back(T::build_data(r_data));
data.push_back(T::build_data(r_data_list));
}

cpp11::writable::doubles ret(obj->n_particles() * obj->n_groups());
Expand Down
62 changes: 60 additions & 2 deletions inst/include/dust2/r/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,14 @@ std::vector<typename T::data_type> check_data(cpp11::list r_data,
auto r_data_i = cpp11::as_cpp<cpp11::list>(r_data[i]);
check_length(r_data_i, n_groups, "data[i]"); // can do better with sstream
for (size_t j = 0; j < n_groups; ++j) {
data.push_back(T::build_data(r_data_i[j]));
auto r_data_ij = cpp11::as_cpp<cpp11::list>(r_data_i[j]);
data.push_back(T::build_data(r_data_ij));
}
}
} else {
for (size_t i = 0; i < n_time; ++i) {
data.push_back(T::build_data(r_data[i]));
auto r_data_i = cpp11::as_cpp<cpp11::list>(r_data[i]);
data.push_back(T::build_data(r_data_i));
}
}

Expand Down Expand Up @@ -286,5 +288,61 @@ SEXP rng_state_as_raw(const std::vector<T>& state) {
return ret;
}

inline double read_real(cpp11::list args, const char * name) {
cpp11::sexp value = args[name];
if (value == R_NilValue) {
cpp11::stop("A value is expected for '%s'", name);
}
return to_double(value, name);
}

inline double read_real(cpp11::list args, const char * name,
double default_value) {
cpp11::sexp value = args[name];
return value == R_NilValue ? default_value : to_double(value, name);
}

inline int read_int(cpp11::list args, const char * name) {
cpp11::sexp value = args[name];
if (value == R_NilValue) {
cpp11::stop("A value is expected for '%s'", name);
}
return to_int(value, name);
}

inline int read_int(cpp11::list args, const char * name,
int default_value) {
cpp11::sexp value = args[name];
return value == R_NilValue ? default_value : to_int(value, name);
}

inline size_t read_size(cpp11::list args, const char * name) {
cpp11::sexp value = args[name];
if (value == R_NilValue) {
cpp11::stop("A value is expected for '%s'", name);
}
return to_size(value, name);
}

inline size_t read_size(cpp11::list args, const char * name,
size_t default_value) {
cpp11::sexp value = args[name];
return value == R_NilValue ? default_value : to_size(value, name);
}

inline bool read_bool(cpp11::list args, const char * name) {
cpp11::sexp value = args[name];
if (value == R_NilValue) {
cpp11::stop("A value is expected for '%s'", name);
}
return to_bool(value, name);
}

inline bool read_bool(cpp11::list args, const char * name,
bool default_value) {
cpp11::sexp value = args[name];
return value == R_NilValue ? default_value : to_bool(value, name);
}

}
}
Loading
Loading