Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Jan 23, 2024
1 parent 559ade7 commit 6990a2b
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 27 deletions.
51 changes: 36 additions & 15 deletions R/DataDescriptor.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
#' indicate the batch dimension.
#' @param graph ([`Graph`])\cr
#' The preprocessing graph.
#' If left `NULL`, no preprocessing is applied to the data and `input_map`, `pointer` and `pointer_shape`
#' are inferred in case the dataset returns only one element.
#' If left `NULL`, no preprocessing is applied to the data and `input_map`, `pointer`, `pointer_shape`, and
#' `pointer_shape_predict` are inferred in case the dataset returns only one element.
#' @param input_map (`character()`)\cr
#' Character vector that must have the same length as the input of the graph.
#' Specifies how the data from the `dataset` is fed into the preprocessing graph.
Expand Down Expand Up @@ -81,29 +81,50 @@ DataDescriptor = R6Class("DataDescriptor",
}
})
}

if (is.null(graph)) {
if ((length(dataset_shapes) == 1L) && is.null(input_map)) {
# avoid name conflcts
if (is.null(input_map)) {
assert_true(length(dataset_shapes) == 1L)
input_map = names(dataset_shapes)
} else {
assert_true(length(input_map) == 1L)
assert_subset(input_map, names(dataset_shapes))
}

graph = as_graph(po("nop", id = paste0(class(dataset)[[1L]], "_", input_map)))
pointer = c(graph$output$op.id, graph$output$channel.name)
pointer_shape = dataset_shapes[[input_map]]
# get unique ID for input PipeOp
graph = as_graph(po("nop", id =
paste0("nop.", calculate_hash(address(dataset)), ".", input_map)
))
} else {
graph = as_graph(graph, clone = clone_graph)
assert_true(length(graph$pipeops) >= 1L)
assert_true(!is.null(input_map))
}
# no preprocessing, dataset returns only a single element (there we can infer a lot)
simple_case = length(graph$pipeops) == 1L && inherits(graph$pipeops[[1L]], "PipeOpNOP") &&
length(dataset_shapes) == 1L

if (is.null(input_map) && simple_case) {
input_map = names(dataset_shapes)
} else {
assert_subset(input_map, names(dataset_shapes))
}
if (is.null(pointer) && simple_case) {
pointer = c(graph$output$op.id, graph$output$channel.name)
} else {
assert_choice(pointer[[1]], names(graph$pipeops))
assert_choice(pointer[[2]], graph$pipeops[[pointer[[1]]]]$output$name)
assert_subset(paste0(pointer, collapse = "."), graph$output$name)
}
if (is.null(pointer_shape) && simple_case) {
pointer_shape = dataset_shapes[[1L]]
} else {
assert_shape(pointer_shape, null_ok = TRUE)
assert_subset(input_map, names(dataset_shapes))
assert_true(length(input_map) == length(graph$input$name))
}
if (is.null(pointer_shape_predict) && simple_case) {
pointer_shape_predict = pointer_shape
} else if (simple_case) {
assert_true(isTRUE(all.equal(pointer_shape, pointer_shape_predict)))
} else {
assert_shape(pointer_shape_predict, null_ok = TRUE)
}

assert_subset(paste0(pointer, collapse = "."), graph$output$name)
assert_true(length(input_map) == length(graph$input$name))

# We hash the address of the environment, because the hashes of an environment are not stable,
# even with a .dataset (that should usually not really have a state), hashes might change due to byte-code
Expand Down
4 changes: 2 additions & 2 deletions man/DataDescriptor.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 1 addition & 9 deletions man/mlr_tasks_lazy_iris.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test_lazy_tensor.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ test_that("transform_lazy_tensor works", {
dd1 = dd(lt1)

expect_equal(dd1$graph$edges,
data.table(src_id = "dataset_x", src_channel = "output", dst_id = "mod", dst_channel = "input")
data.table(src_id = dd1$graph$input$op.id, src_channel = "output", dst_id = "mod", dst_channel = "input")
)

dd = dd(lt)
Expand Down

0 comments on commit 6990a2b

Please sign in to comment.