Skip to content

Commit

Permalink
added rough test for problem
Browse files Browse the repository at this point in the history
  • Loading branch information
advieser committed Feb 15, 2025
1 parent 11cdde6 commit 76d6150
Showing 1 changed file with 38 additions and 27 deletions.
65 changes: 38 additions & 27 deletions tests/testthat/test_pipeop_targetmutate.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,33 +40,44 @@ test_that("PipeOpTargetMutate - basic properties", {

test_that("PipeOpTargetMutate - log base 2 trafo", {
skip_if_not_installed("rpart")
g = Graph$new()
g$add_pipeop(PipeOpTargetMutate$new("logtrafo",
param_vals = list(
trafo = function(x) log(x, base = 2),
inverter = function(x) list(response = 2 ^ x$response))
)
)
g$add_pipeop(LearnerRegrRpart$new())
g$add_pipeop(PipeOpTargetInvert$new())
g$add_edge(src_id = "logtrafo", dst_id = "targetinvert", src_channel = 1L, dst_channel = 1L)
g$add_edge(src_id = "logtrafo", dst_id = "regr.rpart", src_channel = 2L, dst_channel = 1L)
g$add_edge(src_id = "regr.rpart", dst_id = "targetinvert", src_channel = 1L, dst_channel = 2L)

task = mlr_tasks$get("boston_housing_classic")
train_out = g$train(task)
predict_out = g$predict(task)

dat = task$data()
dat$medv = log(dat$medv, base = 2)
task_log = TaskRegr$new("boston_housing_classic_log", backend = dat, target = "medv")

learner = LearnerRegrRpart$new()
learner$train(task_log)

learner_predict_out = learner$predict(task_log)
expect_equal(2 ^ learner_predict_out$truth, predict_out[[1L]]$truth)
expect_equal(2 ^ learner_predict_out$response, predict_out[[1L]]$response)
g = Graph$new()
g$add_pipeop(PipeOpTargetMutate$new("logtrafo",
param_vals = list(
trafo = function(x) log(x, base = 2),
inverter = function(x) list(response = 2 ^ x$response))
)
)
g$add_pipeop(LearnerRegrRpart$new())
g$add_pipeop(PipeOpTargetInvert$new())
g$add_edge(src_id = "logtrafo", dst_id = "targetinvert", src_channel = 1L, dst_channel = 1L)
g$add_edge(src_id = "logtrafo", dst_id = "regr.rpart", src_channel = 2L, dst_channel = 1L)
g$add_edge(src_id = "regr.rpart", dst_id = "targetinvert", src_channel = 1L, dst_channel = 2L)

task = mlr_tasks$get("boston_housing_classic")
train_out = g$train(task)
predict_out = g$predict(task)

dat = task$data()
dat$medv = log(dat$medv, base = 2)
task_log = TaskRegr$new("boston_housing_classic_log", backend = dat, target = "medv")

learner = LearnerRegrRpart$new()
learner$train(task_log)

learner_predict_out = learner$predict(task_log)
expect_equal(2 ^ learner_predict_out$truth, predict_out[[1L]]$truth)
expect_equal(2 ^ learner_predict_out$response, predict_out[[1L]]$response)
})

test_that("PipeOpTargetMutate - does not drop missing levels, #631", {
task = tsk("boston_housing")$filter(1:100)
train_out = op$train(list(task))[["output"]]
# train_out should also know all levels
expect_equal(task$levels(), train_out$levels())

# Check whether we need to fix and test this in other TargetTrafo POs
# Check why this does not occur in predict
# Think about how to solve this: in pipelines or mlr3::convert_task?
})

#'test_that("PipeOpTargetMutate - Regr -> Classif", {
Expand Down

0 comments on commit 76d6150

Please sign in to comment.