diff --git a/R/DataDescriptor.R b/R/DataDescriptor.R
index ca2b9b32..ece97899 100644
--- a/R/DataDescriptor.R
+++ b/R/DataDescriptor.R
@@ -56,12 +56,12 @@ DataDescriptor = R6Class("DataDescriptor",
     initialize = function(dataset, dataset_shapes, graph = NULL, input_map = NULL, pointer = NULL,
       pointer_shape = NULL, pointer_shape_predict = NULL, clone_graph = TRUE) {
       assert_class(dataset, "dataset")
+      assert_flag(clone_graph)
       # If the dataset implements a .getbatch() method the shape must be specified, as it should be the same for
       # all batches
       # For simplicity we here require the first dimension of the shape to be NA so we don't have to deal with it,
       # e.g. during subsetting
       assert_shapes(dataset_shapes, null_ok = is.null(dataset$.getbatch), unknown_batch = TRUE, named = TRUE)
-      assert_shape(pointer_shape_predict, null_ok = TRUE, unknown_batch = TRUE)
 
       # prevent user from e.g. forgetting to wrap the return in a list
       example = if (is.null(dataset$.getbatch)) {
@@ -82,7 +82,7 @@ DataDescriptor = R6Class("DataDescriptor",
         })
       }
       if (is.null(graph)) {
-        # avoid name conflcts
+        # avoid name conflicts
         if (is.null(input_map)) {
           assert_true(length(dataset_shapes) == 1L)
           input_map = names(dataset_shapes)
@@ -170,10 +170,10 @@ DataDescriptor = R6Class("DataDescriptor",
     #' @field input_map (`character()`)\cr
     #'   The input map from the dataset to the preprocessing graph.
     input_map = NULL,
-    #' @field pointer (`character(2)` | `NULL`)\cr
+    #' @field pointer (`character(2)`)\cr
     #'   The output pointer.
     pointer = NULL,
-    #' @field pointer_shape (`integer` | `NULL`)\cr
+    #' @field pointer_shape (`integer()` | `NULL`)\cr
     #'   The shape of the output indicated by `pointer`.
     pointer_shape = NULL,
     #' @field dataset_hash (`character(1)`)\cr
diff --git a/R/LearnerTorch.R b/R/LearnerTorch.R
index c3577c5d..5114f00c 100644
--- a/R/LearnerTorch.R
+++ b/R/LearnerTorch.R
@@ -80,6 +80,9 @@
 #' or `"cb."`, as these are preserved for the dynamically constructed parameters of the optimizer, the loss function,
 #' and the callbacks.
 #'
+#' To perform additional input checks on the task, the private methods `.verify_train_task(task, param_vals)` and
+#' `.verify_predict_task(task, param_vals)` can be overwritten.
+#'
 #' @family Learner
 #' @export
 LearnerTorch = R6Class("LearnerTorch",
@@ -237,8 +240,32 @@ LearnerTorch = R6Class("LearnerTorch",
   ),
   private = list(
     .train = function(task) {
-      private$.verify_train_task(task)
       param_vals = self$param_set$get_values(tags = "train")
+      first_row = task$head(1)
+      measures = c(normalize_to_list(param_vals$measures_train), normalize_to_list(param_vals$measures_valid))
+      available_predict_types = mlr_reflections$learner_predict_types[[self$task_type]][[self$predict_type]]
+      walk(measures, function(m) {
+        if (m$predict_type %nin% available_predict_types) {
+          stopf(paste0("Measure '%s' requires predict type '%s' but learner has '%s'.\n",
+            "Change the predict type or select other measures."),
+            m$id, m$predict_type, self$predict_type)
+        }
+      })
+
+      iwalk(first_row, function(x, nm) {
+        if (!is_lazy_tensor(x)) return(NULL)
+        predict_shape = dd(x)$pointer_shape_predict
+        train_shape = dd(x)$pointer_shape
+        if (is.null(train_shape) || is.null(predict_shape)) {
+          return(NULL)
+        }
+        if (!isTRUE(all.equal(train_shape, predict_shape))) {
+          stopf("Lazy tensor column '%s' has a different shape during training (%s) and prediction (%s).",
+            nm, paste0(train_shape, collapse = "x"), paste0(predict_shape, collapse = "x"))
+        }
+      })
+      private$.verify_train_task(task, param_vals)
+
       param_vals$device = auto_device(param_vals$device)
       if (param_vals$seed == "random") param_vals$seed = sample.int(10000000L, 1L)
 
@@ -249,13 +276,26 @@ LearnerTorch = R6Class("LearnerTorch",
       return(model)
     },
     .predict = function(task) {
-      private$.verify_predict_task(task)
+      cols = c(task$feature_names, task$target_names)
+      ci_predict = task$col_info[get("id") %in% cols, c("id", "type", "levels")]
+      ci_train = self$model$task_col_info[get("id") %in% cols, c("id", "type", "levels")]
+      # permuted factor levels cause issues, because we are converting fct -> int
+      if (!test_equal_col_info(ci_train, ci_predict)) { # nolint
+        stopf(paste0(
+          "Predict task's column info does not match the train task's column info.\n",
+          "This might be handled more gracefully in the future.\n",
+          "Training column info:\n'%s'\n",
+          "Prediction column info:\n'%s'"),
+          paste0(capture.output(ci_train), collapse = "\n"),
+          paste0(capture.output(ci_predict), collapse = "\n"))
+      }
+      param_vals = self$param_set$get_values(tags = "predict")
+      private$.verify_predict_task(task, param_vals)
       # FIXME: https://github.com/mlr-org/mlr3/issues/946
       # This addresses the issues with the facto lrvels and is only a temporary fix
       # Should be handled outside of mlr3torch
       # Ideally we could rely on state$train_task, but there is this bug
       # https://github.com/mlr-org/mlr3/issues/947
-      param_vals = self$param_set$get_values(tags = "predict")
       param_vals$device = auto_device(param_vals$device)
 
       with_torch_settings(seed = self$model$seed, num_threads = param_vals$num_threads, {
@@ -289,47 +329,8 @@ LearnerTorch = R6Class("LearnerTorch",
     .loss = NULL,
     .param_set_base = NULL,
     .callbacks = NULL,
-    .verify_train_task = function(task, row_ids) {
-      first_row = task$head(1)
-      pv = self$param_set$values
-      measures = c(normalize_to_list(pv$measures_train), normalize_to_list(pv$measures_valid))
-      available_predict_types = mlr_reflections$learner_predict_types[[self$task_type]][[self$predict_type]]
-      walk(measures, function(m) {
-        if (m$predict_type %nin% available_predict_types) {
-          stopf(paste0("Measure '%s' requires predict type '%s' but learner has '%s'.\n",
-            "Change the predict type or select other measures."),
-            m$id, m$predict_type, self$predict_type)
-        }
-      })
-
-      iwalk(first_row, function(x, nm) {
-        if (!is_lazy_tensor(x)) return(NULL)
-        predict_shape = dd(x)$pointer_shape_predict
-        train_shape = dd(x)$pointer_shape
-        if (is.null(train_shape) || is.null(predict_shape)) {
-          return(NULL)
-        }
-        if (!isTRUE(all.equal(train_shape, predict_shape))) {
-          stopf("Lazy tensor column '%s' has a different shape during training (%s) and prediction (%s).",
-            nm, paste0(train_shape, collapse = "x"), paste0(predict_shape, collapse = "x"))
-        }
-      })
-    },
-    .verify_predict_task = function(task, row_ids) {
-      cols = c(task$feature_names, task$target_names)
-      ci_predict = task$col_info[get("id") %in% cols, c("id", "type", "levels")]
-      ci_train = self$model$task_col_info[get("id") %in% cols, c("id", "type", "levels")]
-      # permuted factor levels cause issues, because we are converting fct -> int
-      if (!test_equal_col_info(ci_train, ci_predict)) { # nolint
-        stopf(paste0(
-          "Predict task's `$col_info` does not match the train task's column info.\n",
-          "This migth be handled more gracefully in the future.\n",
-          "Training column info:\n'%s'\n",
-          "Prediction column info:\n'%s'"),
-          paste0(capture.output(ci_train), collapse = "\n"),
-          paste0(capture.output(ci_predict), collapse = "\n"))
-      }
-    },
+    .verify_train_task = function(task, param_vals) NULL,
+    .verify_predict_task = function(task, param_vals) NULL,
     deep_clone = function(name, value) deep_clone(self, private, super, name, value)
   )
 )
diff --git a/man/DataDescriptor.Rd b/man/DataDescriptor.Rd
index e64dad92..3eab4374 100644
--- a/man/DataDescriptor.Rd
+++ b/man/DataDescriptor.Rd
@@ -44,10 +44,10 @@
 The shapes of the output.}
 
 \item{\code{input_map}}{(\code{character()})\cr
 The input map from the dataset to the preprocessing graph.}
 
-\item{\code{pointer}}{(\code{character(2)} | \code{NULL})\cr
+\item{\code{pointer}}{(\code{character(2)})\cr
 The output pointer.}
 
-\item{\code{pointer_shape}}{(\code{integer} | \code{NULL})\cr
+\item{\code{pointer_shape}}{(\code{integer()} | \code{NULL})\cr
 The shape of the output indicated by \code{pointer}.}
 
 \item{\code{dataset_hash}}{(\code{character(1)})\cr
diff --git a/man/mlr_learners_torch.Rd b/man/mlr_learners_torch.Rd
index 55d72e08..e447cd5d 100644
--- a/man/mlr_learners_torch.Rd
+++ b/man/mlr_learners_torch.Rd
@@ -90,6 +90,9 @@
 While it is possible to add parameters by specifying the \code{param_set} constructor argument, it is currently
 not possible to remove existing parameters, i.e. those listed in section \emph{Parameters}.
 None of the parameters provided in \code{param_set} can have an id that starts with \code{"loss."}, \verb{"opt.", or }"cb."`,
 as these are preserved for the dynamically constructed parameters of the optimizer, the loss function,
 and the callbacks.
+
+To perform additional input checks on the task, the private methods \code{.verify_train_task(task, param_vals)} and
+\code{.verify_predict_task(task, param_vals)} can be overwritten.
 }
 
 \seealso{
diff --git a/man/mlr_tasks_lazy_iris.Rd b/man/mlr_tasks_lazy_iris.Rd
index fa9fbc0a..3dd7ee28 100644
--- a/man/mlr_tasks_lazy_iris.Rd
+++ b/man/mlr_tasks_lazy_iris.Rd
@@ -22,7 +22,15 @@
 Just like the iris task, but the features are represented as one lazy tensor column.
 }
 \section{Meta Information}{
-\verb{r rd_info_task_torch("lazy_iris", missings = FALSE)}
+\itemize{
+\item Task type: \dQuote{classif}
+\item Properties: \dQuote{multiclass}
+\item Has Missings: no
+\item Target: \dQuote{Species}
+\item Features: \dQuote{x}
+\item Backend Dimension: 150x3
+\item Default Roles (use / test / holdout): 150, 0, 0
+}
 }
 
 \examples{
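
Example (not part of the patch itself): after this change, `.verify_train_task()` and `.verify_predict_task()` default to no-ops returning `NULL`, and subclasses override them to run learner-specific input checks. A minimal sketch of such an override follows; the subclass name `LearnerTorchCustom` and the missing-value check are purely illustrative, and the remaining constructor details (network, parameter set, ...) are omitted.

library(mlr3torch)

# Hypothetical subclass that rejects tasks containing missing values.
LearnerTorchCustom = R6::R6Class("LearnerTorchCustom",
  inherit = LearnerTorch,
  private = list(
    # called at the start of .train(), after the built-in measure and shape checks
    .verify_train_task = function(task, param_vals) {
      if (any(task$missings() > 0)) {
        mlr3misc::stopf("Learner '%s' does not support missing values.", self$id)
      }
    },
    # called in .predict(), after the train/predict column info comparison
    .verify_predict_task = function(task, param_vals) {
      if (any(task$missings() > 0)) {
        mlr3misc::stopf("Learner '%s' does not support missing values.", self$id)
      }
    }
  )
)

Because both hooks now receive `param_vals`, a check can also depend on the learner's current train- or predict-time parameter values.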