diff --git a/R/DataDescriptor.R b/R/DataDescriptor.R index 157f2ede..ca2b9b32 100644 --- a/R/DataDescriptor.R +++ b/R/DataDescriptor.R @@ -16,8 +16,8 @@ #' indicate the batch dimension. #' @param graph ([`Graph`])\cr #' The preprocessing graph. -#' If left `NULL`, no preprocessing is applied to the data and `input_map`, `pointer` and `pointer_shape` -#' are inferred in case the dataset returns only one element. +#' If left `NULL`, no preprocessing is applied to the data and `input_map`, `pointer`, `pointer_shape`, and +#' `pointer_shape_predict` are inferred in case the dataset returns only one element. #' @param input_map (`character()`)\cr #' Character vector that must have the same length as the input of the graph. #' Specifies how the data from the `dataset` is fed into the preprocessing graph. @@ -81,29 +81,50 @@ DataDescriptor = R6Class("DataDescriptor", } }) } - if (is.null(graph)) { - if ((length(dataset_shapes) == 1L) && is.null(input_map)) { + # avoid name conflicts + if (is.null(input_map)) { + assert_true(length(dataset_shapes) == 1L) input_map = names(dataset_shapes) - } else { - assert_true(length(input_map) == 1L) - assert_subset(input_map, names(dataset_shapes)) } - - graph = as_graph(po("nop", id = paste0(class(dataset)[[1L]], "_", input_map))) - pointer = c(graph$output$op.id, graph$output$channel.name) - pointer_shape = dataset_shapes[[input_map]] + # get unique ID for input PipeOp + graph = as_graph(po("nop", id = + paste0("nop.", calculate_hash(address(dataset)), ".", input_map) + )) } else { graph = as_graph(graph, clone = clone_graph) assert_true(length(graph$pipeops) >= 1L) - assert_true(!is.null(input_map)) + } + # no preprocessing, dataset returns only a single element (there we can infer a lot) + simple_case = length(graph$pipeops) == 1L && inherits(graph$pipeops[[1L]], "PipeOpNOP") && + length(dataset_shapes) == 1L + + if (is.null(input_map) && simple_case) { + input_map = names(dataset_shapes) + } else { + assert_subset(input_map, 
names(dataset_shapes)) + } + if (is.null(pointer) && simple_case) { + pointer = c(graph$output$op.id, graph$output$channel.name) + } else { assert_choice(pointer[[1]], names(graph$pipeops)) assert_choice(pointer[[2]], graph$pipeops[[pointer[[1]]]]$output$name) - assert_subset(paste0(pointer, collapse = "."), graph$output$name) + } + if (is.null(pointer_shape) && simple_case) { + pointer_shape = dataset_shapes[[1L]] + } else { assert_shape(pointer_shape, null_ok = TRUE) - assert_subset(input_map, names(dataset_shapes)) - assert_true(length(input_map) == length(graph$input$name)) } + if (is.null(pointer_shape_predict) && simple_case) { + pointer_shape_predict = pointer_shape + } else if (simple_case) { + assert_true(isTRUE(all.equal(pointer_shape, pointer_shape_predict))) + } else { + assert_shape(pointer_shape_predict, null_ok = TRUE) + } + + assert_subset(paste0(pointer, collapse = "."), graph$output$name) + assert_true(length(input_map) == length(graph$input$name)) # We hash the address of the environment, because the hashes of an environment are not stable, # even with a .dataset (that should usually not really have a state), hashes might change due to byte-code diff --git a/man/DataDescriptor.Rd b/man/DataDescriptor.Rd index 5d015f5f..e64dad92 100644 --- a/man/DataDescriptor.Rd +++ b/man/DataDescriptor.Rd @@ -104,8 +104,8 @@ indicate the batch dimension.} \item{\code{graph}}{(\code{\link{Graph}})\cr The preprocessing graph. -If left \code{NULL}, no preprocessing is applied to the data and \code{input_map}, \code{pointer} and \code{pointer_shape} -are inferred in case the dataset returns only one element.} +If left \code{NULL}, no preprocessing is applied to the data and \code{input_map}, \code{pointer}, \code{pointer_shape}, and +\code{pointer_shape_predict} are inferred in case the dataset returns only one element.} \item{\code{input_map}}{(\code{character()})\cr Character vector that must have the same length as the input of the graph. 
diff --git a/man/mlr_tasks_lazy_iris.Rd b/man/mlr_tasks_lazy_iris.Rd index 3dd7ee28..fa9fbc0a 100644 --- a/man/mlr_tasks_lazy_iris.Rd +++ b/man/mlr_tasks_lazy_iris.Rd @@ -22,15 +22,7 @@ Just like the iris task, but the features are represented as one lazy tensor col \section{Meta Information}{ -\itemize{ -\item Task type: \dQuote{classif} -\item Properties: \dQuote{multiclass} -\item Has Missings: no -\item Target: \dQuote{Species} -\item Features: \dQuote{x} -\item Backend Dimension: 150x3 -\item Default Roles (use / test / holdout): 150, 0, 0 -} +\verb{r rd_info_task_torch("lazy_iris", missings = FALSE)} } \examples{ diff --git a/tests/testthat/test_lazy_tensor.R b/tests/testthat/test_lazy_tensor.R index 891936fc..9018fccc 100644 --- a/tests/testthat/test_lazy_tensor.R +++ b/tests/testthat/test_lazy_tensor.R @@ -74,7 +74,7 @@ test_that("transform_lazy_tensor works", { dd1 = dd(lt1) expect_equal(dd1$graph$edges, - data.table(src_id = "dataset_x", src_channel = "output", dst_id = "mod", dst_channel = "input") + data.table(src_id = dd1$graph$input$op.id, src_channel = "output", dst_id = "mod", dst_channel = "input") ) dd = dd(lt)