Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Jan 23, 2024
1 parent ed6b014 commit 9b2b1e4
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 68 deletions.
6 changes: 2 additions & 4 deletions R/PipeOpTaskPreprocTorch.R
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ PipeOpTaskPreprocTorch = R6Class("PipeOpTaskPreprocTorch",
public = list(
#' @description
#' Creates a new instance of this [`R6`][R6::R6Class] class.
initialize = function(fn, id = "preproc_torch", param_vals = list(), param_set = ps(), packages = character(0), rowwise = FALSE) { # nolint
initialize = function(fn, id = "preproc_torch", param_vals = list(), param_set = ps(), packages = character(0), rowwise = FALSE,
stages_init = NULL) { # nolint
assert(check_function(fn), check_character(fn, len = 2L))
private$.fn = fn
private$.rowwise = assert_flag(rowwise)
Expand Down Expand Up @@ -440,7 +441,6 @@ create_ps = function(fn) {
#' Initial value for the `stages` parameter.
#' If `NULL` (default), will be set to `"both"` in case the `id` starts with `"trafo"` and to `"train"`
#' if it starts with `"augment"`. Otherwise it must be specified.
#' @param
#' @template param_packages
#' @export
#' @returns An [`R6Class`][R6::R6Class] instance inheriting from [`PipeOpTaskPreprocTorch`]
Expand All @@ -460,8 +460,6 @@ pipeop_preproc_torch_class = function(id, fn, shapes_out, param_set = NULL, pack
"both"
} else if (startsWith(id, "augment_")) {
"train"
} else {
stopf("stages_init must be specified")
}
}

Expand Down
5 changes: 4 additions & 1 deletion R/shape.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ check_shape = function(shape, null_ok = FALSE, unknown_batch = NULL) {
sprintf("Invalid shape: %s.", paste0(format(shape), collapse = ", "))
}
assert_shapes = function(shapes, coerce = TRUE, named = FALSE, null_ok = FALSE, unknown_batch = NULL) { # nolint
ok = test_list(shapes, names = if (named && !identical(unique(names(shapes)), "...")) "unique", min.len = 1L)
ok = test_list(shapes, min.len = 1L)
if (named) {
assert_names(setdiff(names(shapes), "..."), type = "unique")
}
if (!ok) {
stopf("Invalid shape")
}
Expand Down
13 changes: 3 additions & 10 deletions R/task_dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,8 @@ task_dataset = dataset(
# Here, we could have multiple `lazy_tensor` columns that share parts of the graph
# We try to merge those graphs if possible
if (length(lazy_tensor_features) > 1L) {
merge_result = try(merge_lazy_tensor_graphs(data), silent = TRUE)

if (inherits(merge_result, "try-error")) {
# This should basically never happen
lg$debug("Failed to merge data descriptor, this might lead to inefficient preprocessing.")
# TODO: test that is still works when this triggers
} else {
self$task$cbind(merge_result)
}
merge_result = merge_lazy_tensor_graphs(data)
self$task$cbind(merge_result)
}

# we can cache the output (hash) or the data (dataset_hash)
Expand Down Expand Up @@ -123,7 +116,7 @@ merge_compatible_lazy_tensor_graphs = function(lts) {
input_map = unname(unlist(input_map[graph$input$name]))

# some PipeOps that were previously terminal might not be anymore,
# for those we add nops and updaate the pointers for their data descriptors
# for those we add nops and update the pointers for their data descriptors
map_dtc(lts, function(lt) {
pointer_name = paste0(dd(lt)$pointer, collapse = ".")

Expand Down
25 changes: 0 additions & 25 deletions man/assert_shape.Rd

This file was deleted.

21 changes: 10 additions & 11 deletions man/mlr_pipeops_preproc_torch.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 9 additions & 2 deletions man/pipeop_preproc_torch.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 9 additions & 6 deletions man/pipeop_preproc_torch_class.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 14 additions & 8 deletions tests/testthat/test_PipeOpTaskPreprocTorch.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
test_that("PipeOpTaskPreprocTorch: basic checks", {
po_test = po("preproc_torch", identity, packages = "R6")
po_test = po("preproc_torch", identity, packages = "R6", stages_init = "both")
expect_pipeop(po_test)
expect_equal(po_test$feature_types, "lazy_tensor")
expect_true(po_test$innum == 1L)
Expand All @@ -16,7 +16,8 @@ test_that("PipeOpTaskPreprocTorch: basic checks", {
super$initialize(
id = id,
param_vals = param_vals,
fn = function(x) -x
fn = function(x) -x,
stages_init = "both"
)
}
),
Expand Down Expand Up @@ -65,7 +66,8 @@ test_that("PipeOpTaskPreprocTorch: basic checks", {
param_set = param_set,
id = id,
param_vals = param_vals,
fn = function(x, a = 2) x * a
fn = function(x, a = 2) x * a,
stages_init = "both"
)
}
)
Expand Down Expand Up @@ -108,8 +110,12 @@ test_that("PipeOpTaskPreprocTorch: basic checks", {
# need to finish the augment -> stages transition, i.e. add tests and "both", then finish the preprocess implementation add the tests and test the preprocess autotest
# also test that the rowwise parameter works

po_test3 = pipeop_preproc_torch("test3", rowwise = FALSE, fn = function(x) x$reshape(-1), shapes_out = NULL)
po_test4 = pipeop_preproc_torch("test4", rowwise = TRUE, fn = function(x) x$reshape(-1), shapes_out = NULL)
po_test3 = pipeop_preproc_torch("test3", rowwise = FALSE, fn = function(x) x$reshape(-1), shapes_out = NULL,
stages_init = "both"
)
po_test4 = pipeop_preproc_torch("test4", rowwise = TRUE, fn = function(x) x$reshape(-1), shapes_out = NULL,
stages_init = "both"
)

expect_equal(
materialize(po_test3$train(list(task))[[1L]]$data(cols = "x1")$x1, rbind = TRUE)$shape,
Expand All @@ -120,8 +126,8 @@ test_that("PipeOpTaskPreprocTorch: basic checks", {
c(10, 1)
)

expect_true(pipeop_preproc_torch("test3", identity, rowwise = TRUE, shapes_out = NULL)$rowwise)
expect_false(pipeop_preproc_torch("test3", identity, rowwise = FALSE, shapes_out = NULL)$rowwise)
expect_true(pipeop_preproc_torch("test3", identity, rowwise = TRUE, shapes_out = NULL, stages_init = "both")$rowwise)
expect_false(pipeop_preproc_torch("test3", identity, rowwise = FALSE, shapes_out = NULL, stages_init = "both")$rowwise)

# stages_init works
})
Expand Down Expand Up @@ -155,7 +161,7 @@ test_that("PipeOpTaskPreprocTorch modifies the underlying lazy tensor columns co
taskin = as_task_regr(d, target = "y")

po_test = po("preproc_torch", fn = crate(function(x, a) x + a), param_set = ps(a = p_int(tags = c("train", "required"))),
a = -10, rowwise = FALSE)
a = -10, rowwise = FALSE, stages_init = "both")

taskout_train = po_test$train(list(taskin))[[1L]]
taskout_pred = po_test$predict(list(taskin))[[1L]]
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_materialize.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ test_that("PipeOpFeatureUnion can properly check whether two lazy tensors are id
task = tsk("lazy_iris")

graph = po("nop") %>>%
list(po("preproc_torch", function(x) x + 1), po("trafo_nop")) %>>%
list(po("preproc_torch", function(x) x + 1, stages_init = "both"), po("trafo_nop")) %>>%
po("featureunion")

expect_error(graph$train(task), "cannot aggregate different features sharing")
Expand Down

0 comments on commit 9b2b1e4

Please sign in to comment.