sebffischer committed Jan 23, 2024
1 parent 6990a2b commit a60abb6
Showing 5 changed files with 63 additions and 51 deletions.
8 changes: 4 additions & 4 deletions R/DataDescriptor.R
@@ -56,12 +56,12 @@ DataDescriptor = R6Class("DataDescriptor",
initialize = function(dataset, dataset_shapes, graph = NULL, input_map = NULL, pointer = NULL,
pointer_shape = NULL, pointer_shape_predict = NULL, clone_graph = TRUE) {
assert_class(dataset, "dataset")
assert_flag(clone_graph)
# If the dataset implements a .getbatch() method the shape must be specified, as it should be the same for
# all batches
# For simplicity we here require the first dimension of the shape to be NA so we don't have to deal with it,
# e.g. during subsetting
assert_shapes(dataset_shapes, null_ok = is.null(dataset$.getbatch), unknown_batch = TRUE, named = TRUE)
assert_shape(pointer_shape_predict, null_ok = TRUE, unknown_batch = TRUE)

# prevent user from e.g. forgetting to wrap the return in a list
example = if (is.null(dataset$.getbatch)) {
@@ -82,7 +82,7 @@ DataDescriptor = R6Class("DataDescriptor",
})
}
if (is.null(graph)) {
# avoid name conflcts
# avoid name conflicts
if (is.null(input_map)) {
assert_true(length(dataset_shapes) == 1L)
input_map = names(dataset_shapes)
@@ -170,10 +170,10 @@ DataDescriptor = R6Class("DataDescriptor",
#' @field input_map (`character()`)\cr
#' The input map from the dataset to the preprocessing graph.
input_map = NULL,
#' @field pointer (`character(2)` | `NULL`)\cr
#' @field pointer (`character(2)`)\cr
#' The output pointer.
pointer = NULL,
#' @field pointer_shape (`integer` | `NULL`)\cr
#' @field pointer_shape (`integer()` | `NULL`)\cr
#' The shape of the output indicated by `pointer`.
pointer_shape = NULL,
#' @field dataset_hash (`character(1)`)\cr
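The comments in `initialize()` above pin down a convention worth seeing end to end: when a dataset implements a `.getbatch()` method, its shapes must be declared, named, and given with `NA` as the first (batch) dimension. A minimal sketch, assuming the torch R package's `dataset()` constructor; the dataset name, field names, and random data are illustrative only:

library(torch)

# Toy dataset with a .getbatch() method: shapes must then be declared up
# front, since every batch is expected to have the same (non-batch) shape.
toy = dataset("toy",
  initialize = function() {
    self$x = torch_randn(10, 3, 8, 8)  # 10 observations, each of shape 3x8x8
  },
  .getbatch = function(i) {
    # i is a vector of row indices; drop = FALSE keeps the batch dimension
    # even when a single index is passed
    list(x = self$x[i, , , , drop = FALSE])
  },
  .length = function() {
    10L
  }
)()

# dataset_shapes must be named, and the first dimension is NA because the
# batch size varies; this is the convention the assertions above enforce.
dd = DataDescriptor$new(toy, dataset_shapes = list(x = c(NA, 3, 8, 8)))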
89 changes: 45 additions & 44 deletions R/LearnerTorch.R
@@ -80,6 +80,9 @@
#' or `"cb."`, as these are preserved for the dynamically constructed parameters of the optimizer, the loss function,
#' and the callbacks.
#'
#' To perform additional input checks on the task, the private methods `.verify_train_task(task, param_vals)` and
#' `.verify_predict_task(task, param_vals)` can be overwritten.
#'
#' @family Learner
#' @export
LearnerTorch = R6Class("LearnerTorch",
@@ -237,8 +240,32 @@ LearnerTorch = R6Class("LearnerTorch",
),
private = list(
.train = function(task) {
private$.verify_train_task(task)
param_vals = self$param_set$get_values(tags = "train")
first_row = task$head(1)
measures = c(normalize_to_list(param_vals$measures_train), normalize_to_list(param_vals$measures_valid))
available_predict_types = mlr_reflections$learner_predict_types[[self$task_type]][[self$predict_type]]
walk(measures, function(m) {
if (m$predict_type %nin% available_predict_types) {
stopf(paste0("Measure '%s' requires predict type '%s' but learner has '%s'.\n",
"Change the predict type or select other measures."),
m$id, m$predict_type, self$predict_type)
}
})

iwalk(first_row, function(x, nm) {
if (!is_lazy_tensor(x)) return(NULL)
predict_shape = dd(x)$pointer_shape_predict
train_shape = dd(x)$pointer_shape
if (is.null(train_shape) || is.null(predict_shape)) {
return(NULL)
}
if (!isTRUE(all.equal(train_shape, predict_shape))) {
stopf("Lazy tensor column '%s' has a different shape during training (%s) and prediction (%s).",
nm, paste0(train_shape, collapse = "x"), paste0(predict_shape, collapse = "x"))
}
})
private$.verify_train_task(task, param_vals)

param_vals$device = auto_device(param_vals$device)
if (param_vals$seed == "random") param_vals$seed = sample.int(10000000L, 1L)

@@ -249,13 +276,26 @@ LearnerTorch = R6Class("LearnerTorch",
return(model)
},
.predict = function(task) {
private$.verify_predict_task(task)
cols = c(task$feature_names, task$target_names)
ci_predict = task$col_info[get("id") %in% cols, c("id", "type", "levels")]
ci_train = self$model$task_col_info[get("id") %in% cols, c("id", "type", "levels")]
# permuted factor levels cause issues, because we are converting fct -> int
if (!test_equal_col_info(ci_train, ci_predict)) { # nolint
stopf(paste0(
"Predict task's column info does not match the train task's column info.\n",
"This migth be handled more gracefully in the future.\n",
"Training column info:\n'%s'\n",
"Prediction column info:\n'%s'"),
paste0(capture.output(ci_train), collapse = "\n"),
paste0(capture.output(ci_predict), collapse = "\n"))
}
param_vals = self$param_set$get_values(tags = "predict")
private$.verify_predict_task(task, param_vals)
# FIXME: https://github.com/mlr-org/mlr3/issues/946
# This addresses the issue with the factor levels and is only a temporary fix
# Should be handled outside of mlr3torch
# Ideally we could rely on state$train_task, but there is this bug
# https://github.com/mlr-org/mlr3/issues/947
param_vals = self$param_set$get_values(tags = "predict")
param_vals$device = auto_device(param_vals$device)

with_torch_settings(seed = self$model$seed, num_threads = param_vals$num_threads, {
@@ -289,47 +329,8 @@ LearnerTorch = R6Class("LearnerTorch",
.loss = NULL,
.param_set_base = NULL,
.callbacks = NULL,
.verify_train_task = function(task, row_ids) {
first_row = task$head(1)
pv = self$param_set$values
measures = c(normalize_to_list(pv$measures_train), normalize_to_list(pv$measures_valid))
available_predict_types = mlr_reflections$learner_predict_types[[self$task_type]][[self$predict_type]]
walk(measures, function(m) {
if (m$predict_type %nin% available_predict_types) {
stopf(paste0("Measure '%s' requires predict type '%s' but learner has '%s'.\n",
"Change the predict type or select other measures."),
m$id, m$predict_type, self$predict_type)
}
})

iwalk(first_row, function(x, nm) {
if (!is_lazy_tensor(x)) return(NULL)
predict_shape = dd(x)$pointer_shape_predict
train_shape = dd(x)$pointer_shape
if (is.null(train_shape) || is.null(predict_shape)) {
return(NULL)
}
if (!isTRUE(all.equal(train_shape, predict_shape))) {
stopf("Lazy tensor column '%s' has a different shape during training (%s) and prediction (%s).",
nm, paste0(train_shape, collapse = "x"), paste0(predict_shape, collapse = "x"))
}
})
},
.verify_predict_task = function(task, row_ids) {
cols = c(task$feature_names, task$target_names)
ci_predict = task$col_info[get("id") %in% cols, c("id", "type", "levels")]
ci_train = self$model$task_col_info[get("id") %in% cols, c("id", "type", "levels")]
# permuted factor levels cause issues, because we are converting fct -> int
if (!test_equal_col_info(ci_train, ci_predict)) { # nolint
stopf(paste0(
"Predict task's `$col_info` does not match the train task's column info.\n",
"This migth be handled more gracefully in the future.\n",
"Training column info:\n'%s'\n",
"Prediction column info:\n'%s'"),
paste0(capture.output(ci_train), collapse = "\n"),
paste0(capture.output(ci_predict), collapse = "\n"))
}
},
.verify_train_task = function(task, param_vals) NULL,
.verify_predict_task = function(task, param_vals) NULL,
deep_clone = function(name, value) deep_clone(self, private, super, name, value)
)
)
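The `.verify_train_task()` and `.verify_predict_task()` hooks introduced above default to no-ops returning `NULL`; subclasses override them to run learner-specific input checks. A minimal sketch of such an override; the subclass and the concrete check are hypothetical, not part of this commit:

# Hypothetical subclass using the new extension point: override the private
# .verify_train_task() hook to reject tasks this learner cannot handle.
LearnerTorchCustom = R6::R6Class("LearnerTorchCustom",
  inherit = LearnerTorch,
  private = list(
    .verify_train_task = function(task, param_vals) {
      # illustrative check: refuse tasks that contain missing values
      if (any(task$missings() > 0L)) {
        stopf("Learner '%s' does not support missing values.", self$id)
      }
    }
  )
)

Because `.train()` resolves `param_vals` before invoking the hook, the override sees the same parameter values as the rest of training.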
4 changes: 2 additions & 2 deletions man/DataDescriptor.Rd

Generated file; diff not rendered by default.

3 changes: 3 additions & 0 deletions man/mlr_learners_torch.Rd

Generated file; diff not rendered by default.

10 changes: 9 additions & 1 deletion man/mlr_tasks_lazy_iris.Rd

Generated file; diff not rendered by default.
