diff --git a/tests/testthat/test_pipeop_targetmutate.R b/tests/testthat/test_pipeop_targetmutate.R index 1d06fd01d..b01c0ac7c 100644 --- a/tests/testthat/test_pipeop_targetmutate.R +++ b/tests/testthat/test_pipeop_targetmutate.R @@ -40,33 +40,44 @@ test_that("PipeOpTargetMutate - basic properties", { test_that("PipeOpTargetMutate - log base 2 trafo", { skip_if_not_installed("rpart") - g = Graph$new() - g$add_pipeop(PipeOpTargetMutate$new("logtrafo", - param_vals = list( - trafo = function(x) log(x, base = 2), - inverter = function(x) list(response = 2 ^ x$response)) - ) - ) - g$add_pipeop(LearnerRegrRpart$new()) - g$add_pipeop(PipeOpTargetInvert$new()) - g$add_edge(src_id = "logtrafo", dst_id = "targetinvert", src_channel = 1L, dst_channel = 1L) - g$add_edge(src_id = "logtrafo", dst_id = "regr.rpart", src_channel = 2L, dst_channel = 1L) - g$add_edge(src_id = "regr.rpart", dst_id = "targetinvert", src_channel = 1L, dst_channel = 2L) - - task = mlr_tasks$get("boston_housing_classic") - train_out = g$train(task) - predict_out = g$predict(task) - - dat = task$data() - dat$medv = log(dat$medv, base = 2) - task_log = TaskRegr$new("boston_housing_classic_log", backend = dat, target = "medv") - - learner = LearnerRegrRpart$new() - learner$train(task_log) - - learner_predict_out = learner$predict(task_log) - expect_equal(2 ^ learner_predict_out$truth, predict_out[[1L]]$truth) - expect_equal(2 ^ learner_predict_out$response, predict_out[[1L]]$response) + g = Graph$new() + g$add_pipeop(PipeOpTargetMutate$new("logtrafo", + param_vals = list( + trafo = function(x) log(x, base = 2), + inverter = function(x) list(response = 2 ^ x$response)) + ) + ) + g$add_pipeop(LearnerRegrRpart$new()) + g$add_pipeop(PipeOpTargetInvert$new()) + g$add_edge(src_id = "logtrafo", dst_id = "targetinvert", src_channel = 1L, dst_channel = 1L) + g$add_edge(src_id = "logtrafo", dst_id = "regr.rpart", src_channel = 2L, dst_channel = 1L) + g$add_edge(src_id = "regr.rpart", dst_id = "targetinvert", src_channel = 1L, dst_channel = 2L) + + task = mlr_tasks$get("boston_housing_classic") + train_out = g$train(task) + predict_out = g$predict(task) + + dat = task$data() + dat$medv = log(dat$medv, base = 2) + task_log = TaskRegr$new("boston_housing_classic_log", backend = dat, target = "medv") + + learner = LearnerRegrRpart$new() + learner$train(task_log) + + learner_predict_out = learner$predict(task_log) + expect_equal(2 ^ learner_predict_out$truth, predict_out[[1L]]$truth) + expect_equal(2 ^ learner_predict_out$response, predict_out[[1L]]$response) +}) + +test_that("PipeOpTargetMutate - does not drop missing levels, #631", { + task = tsk("boston_housing")$filter(1:100) + train_out = op$train(list(task))[["output"]] + # train_out should also know all levels + expect_equal(task$levels(), train_out$levels()) + + # Check whether we need to fix and test this in other TargetTrafo POs + # Check why this does not occur in predict + # Think about how to solve this: in pipelines or mlr3::convert_task? }) #'test_that("PipeOpTargetMutate - Regr -> Classif", {