use internal valid task
be-marc committed Nov 11, 2024
1 parent 123624e commit 0581cdc
Showing 1 changed file with 16 additions and 25 deletions.
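The change replaces manual prediction on separately kept test tasks with mlr3's internal validation task: each per-iteration training task carries the held-out split as its internal_valid_task, the learners additionally predict on the "internal_valid" set, and the measure is scored on that set directly from the benchmark result. A minimal standalone sketch of that mechanism (illustrative only, not part of the commit; the sonar task, rpart learner, and 80/20 split are assumptions):

library(mlr3)

# split a task and attach the held-out part as the internal validation task
task = tsk("sonar")
split = partition(task, ratio = 0.8)
task_train = task$clone()$filter(split$train)
task_train$internal_valid_task = task$clone()$filter(split$test)

# predict on both the (insample) test set and the internal validation task
learner = lrn("classif.rpart")
learner$predict_sets = c("test", "internal_valid")

rr = resample(task_train, learner, rsmp("insample"))

# score the predictions made on the internal validation task
measure = msr("classif.ce")
measure$predict_sets = "internal_valid"
rr$score(measure)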
41 changes: 16 additions & 25 deletions R/embedded_ensemble_fselect.R
@@ -53,43 +53,41 @@
#' eefsr
#' }
embedded_ensemble_fselect = function(
task,
learners,
init_resampling,
measure,
store_benchmark_result = TRUE
) {
# TODO: check whether callbacks are needed here
assert_task(task)
# check that all learners support `selected_features()`, i.e. embedded feature selection
assert_learners(as_learners(learners), task = task, properties = "selected_features")
assert_resampling(init_resampling)
assert_choice(class(init_resampling)[1], choices = c("ResamplingBootstrap", "ResamplingSubsampling"))
assert_flag(store_benchmark_result)

# activate predict sets
walk(learners, function(learner) {
learner$predict_sets = c("test", "internal_valid")
})

# set up the task and resampling for each iteration
init_resampling$instantiate(task)
grid = map_dtr(seq(init_resampling$iters), function(i) {
# create task and resampling for each outer iteration
task_train = task$clone()$filter(init_resampling$train_set(i))
# keep the test set as the internal validation task
task_valid = task$clone()$filter(init_resampling$test_set(i))
task_train$internal_valid_task = task_valid
resampling = rsmp("insample")$instantiate(task_train)

data.table(
resampling_iteration = i,
learner_id = map_chr(learners, "id"),
learner = learners,
task = list(task_train),
resampling = list(resampling)
)
})

# benchmark design: one row per learner and resampling iteration
design = grid[, c("learner", "task", "resampling"), with = FALSE]

bmr = benchmark(design, store_models = TRUE)

trained_learners = bmr$score()$learner
@@ -102,22 +100,15 @@ embedded_ensemble_fselect = function(
# extract n_features
n_features = map_int(features, length)

# extract scores on the init_resampling's test sets via the internal validation task
measure$predict_sets = "internal_valid"
scores_internal = bmr$score(measure)

set(grid, j = "features", value = features)
set(grid, j = "n_features", value = n_features)
set(grid, j = measure$id, value = scores_internal[[measure$id]])
set(grid, j = "learner", value = NULL)
# remove R6 objects
set(grid, j = "task", value = NULL)
set(grid, j = "pred_task", value = NULL)
set(grid, j = "resampling", value = NULL)

EnsembleFSResult$new(
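For context, the roxygen example above is truncated to `eefsr`; a hypothetical call consistent with the signature in this diff could look as follows (the task, learners, and resampling settings are illustrative assumptions; both learners carry the "selected_features" property required by the assertion above, and the call returns the EnsembleFSResult constructed at the end of the function):

library(mlr3)
library(mlr3fselect)

eefsr = embedded_ensemble_fselect(
  task = tsk("sonar"),
  learners = lrns(c("classif.rpart", "classif.featureless")),
  init_resampling = rsmp("subsampling", repeats = 5),
  measure = msr("classif.ce"),
  store_benchmark_result = TRUE
)
eefsr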
