Skip to content

Commit

Permalink
simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Nov 11, 2024
1 parent 0581cdc commit 14acd73
Showing 1 changed file with 20 additions and 30 deletions.
50 changes: 20 additions & 30 deletions R/embedded_ensemble_fselect.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,30 +65,15 @@ embedded_ensemble_fselect = function(
assert_choice(class(init_resampling)[1], choices = c("ResamplingBootstrap", "ResamplingSubsampling"))
assert_flag(store_benchmark_result)

# activate predict sets
walk(learners, function(learner) {
learner$predict_sets = c("test", "internal_valid")
})

# set up the task and resampling for each iteration
init_resampling$instantiate(task)
grid = map_dtr(seq(init_resampling$iters), function(i) {
task_train = task$clone()$filter(init_resampling$train_set(i))
task_valid = task$clone()$filter(init_resampling$test_set(i))
task_train$internal_valid_task = task_valid
resampling = rsmp("insample")$instantiate(task_train)

data.table(
resampling_iteration = i,
learner_id = map_chr(learners, "id"),
learner = learners,
task = list(task_train),
resampling = list(resampling)
)
})
grid = benchmark_grid(
task = task,
learners = learners,
resamplings = init_resampling
)

design = grid[, c("learner", "task", "resampling"), with = FALSE]
bmr = benchmark(design, store_models = TRUE)
bmr = benchmark(grid, store_models = TRUE)

trained_learners = bmr$score()$learner

Expand All @@ -100,19 +85,24 @@ embedded_ensemble_fselect = function(
# extract n_features
n_features = map_int(features, length)

measure$predict_sets = "internal_valid"
scores_internal = bmr$score(measure)
# performance scores
scores = bmr$score(measure)

set(scores, j = "features", value = features)
set(scores, j = "n_features", value = n_features)
setnames(scores, "iteration", "resampling_iteration")

set(grid, j = "features", value = features)
set(grid, j = "n_features", value = n_features)
set(grid, j = measure$id, value = scores_internal[[measure$id]])
set(grid, j = "learner", value = NULL)
# remove R6 objects
set(grid, j = "task", value = NULL)
set(grid, j = "resampling", value = NULL)
set(scores, j = "learner", value = NULL)
set(scores, j = "task", value = NULL)
set(scores, j = "resampling", value = NULL)
set(scores, j = "prediction_test", value = NULL)
set(scores, j = "task_id", value = NULL)
set(scores, j = "nr", value = NULL)
set(scores, j = "resampling_id", value = NULL)

EnsembleFSResult$new(
result = grid,
result = scores,
features = task$feature_names,
benchmark_result = if (store_benchmark_result) bmr,
measure_id = measure$id,
Expand Down

0 comments on commit 14acd73

Please sign in to comment.