Commit 81b475d

be-marc committed Nov 11, 2024
1 parent 14acd73 commit 81b475d

Showing 5 changed files with 60 additions and 30 deletions.
14 changes: 12 additions & 2 deletions R/EnsembleFSResult.R
@@ -71,12 +71,21 @@ EnsembleFSResult = R6Class("EnsembleFSResult",
#' The benchmark result object.
#' @param measure_id (`character(1)`)\cr
#' Column name of `"result"` that corresponds to the measure used.
#' @param inner_measure_id (`character(1)`)\cr
#' Column name of `"result"` that corresponds to the inner measure used.
#' @param minimize (`logical(1)`)\cr
#' If `TRUE` (default), lower values of the measure correspond to higher performance.
initialize = function(result, features, benchmark_result = NULL, measure_id,
minimize = TRUE) {
initialize = function(
result,
features,
benchmark_result = NULL,
measure_id,
inner_measure_id = NULL,
minimize = TRUE
) {
assert_data_table(result)
private$.measure_id = assert_string(measure_id, null.ok = FALSE)
private$.inner_measure_id = assert_string(inner_measure_id, null.ok = TRUE)
mandatory_columns = c("resampling_iteration", "learner_id", "features", "n_features")
assert_names(names(result), must.include = c(mandatory_columns, measure_id))
private$.result = result
@@ -423,6 +432,7 @@ EnsembleFSResult = R6Class("EnsembleFSResult",
.feature_ranking = NULL,
.features = NULL,
.measure_id = NULL,
.inner_measure_id = NULL,
.minimize = NULL
)
)
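A minimal construction sketch for the extended signature; the table below uses assumed, illustrative column values ("x1", "x2", ...), and "classif.ce_inner" follows the naming used later in ensemble_fselect(). This is a sketch under those assumptions, not code from the commit:

library(data.table)
library(mlr3fselect)

# Assumed result table: one row per (subsampling iteration, learner), with the
# outer score ("classif.ce") and the new inner score ("classif.ce_inner").
result = data.table(
  resampling_iteration = c(1L, 2L),
  learner_id = c("classif.rpart", "classif.rpart"),
  features = list(c("x1", "x2"), "x1"),
  n_features = c(2L, 1L),
  classif.ce = c(0.25, 0.30),
  classif.ce_inner = c(0.20, 0.28)
)

efsr = EnsembleFSResult$new(
  result = result,
  features = c("x1", "x2", "x3"),
  measure_id = "classif.ce",
  inner_measure_id = "classif.ce_inner",
  minimize = TRUE
)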
4 changes: 2 additions & 2 deletions R/embedded_ensemble_fselect.R
@@ -67,13 +67,13 @@ embedded_ensemble_fselect = function(

init_resampling$instantiate(task)

grid = benchmark_grid(
design = benchmark_grid(
task = task,
learners = learners,
resamplings = init_resampling
)

bmr = benchmark(grid, store_models = TRUE)
bmr = benchmark(design, store_models = TRUE)

trained_learners = bmr$score()$learner

66 changes: 40 additions & 26 deletions R/ensemble_fselect.R
@@ -50,6 +50,7 @@
#' init_resampling = rsmp("subsampling", repeats = 2),
#' inner_resampling = rsmp("cv", folds = 3),
#' measure = msr("classif.ce"),
#' inner_measure = msr("classif.ce"),
#' terminator = trm("evals", n_evals = 10)
#' )
#' efsr
@@ -61,6 +62,7 @@ ensemble_fselect = function(
init_resampling,
inner_resampling,
measure,
inner_measure,
terminator,
callbacks = NULL,
store_benchmark_result = TRUE,
@@ -79,30 +81,34 @@
fselector = fselector,
learner = learner,
resampling = inner_resampling,
measure = measure,
measure = inner_measure,
terminator = terminator,
store_models = store_models,
callbacks = callbacks[[i]]
)
})

init_resampling$instantiate(task)
grid = map_dtr(seq(init_resampling$iters), function(i) {

  # create task and resampling for each outer iteration
  task_subset = task$clone()$filter(init_resampling$train_set(i))
  resampling = rsmp("insample")$instantiate(task_subset)

  data.table(
    resampling_iteration = i,
    learner_id = map_chr(learners, "id"),
    learner = afss,
    task = list(task_subset),
    resampling = list(resampling)
  )
})

design = grid[, c("learner", "task", "resampling"), with = FALSE]
design = benchmark_grid(
task = task,
learners = afss,
resamplings = init_resampling
)
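benchmark_grid() now builds the jobs that the removed map_dtr() loop assembled by hand: each AutoFSelector is paired with the instantiated subsampling. A standalone sketch of that shape, with task and learners as illustrative assumptions:

library(mlr3)

# One design row per (task, learner, resampling) triple; with 2 learners and
# 3 subsampling repeats, benchmark() later runs 2 * 3 = 6 resampled jobs.
design = benchmark_grid(
  task = tsk("sonar"),
  learners = lrns(c("classif.rpart", "classif.featureless")),
  resamplings = rsmp("subsampling", repeats = 3)
)
nrow(design)  # 2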

bmr = benchmark(design, store_models = TRUE)

@@ -119,13 +125,25 @@
})

# extract scores
scores = map_dbl(afss, function(afs) {
afs$fselect_instance$archive$best()[, measure$id, with = FALSE][[1]]
inner_scores = map_dbl(afss, function(afs) {
afs$fselect_instance$archive$best()[, inner_measure$id, with = FALSE][[1]]
})

set(grid, j = "features", value = features)
set(grid, j = "n_features", value = n_features)
set(grid, j = measure$id, value = scores)
scores = bmr$score(measure)

set(scores, j = "features", value = features)
set(scores, j = "n_features", value = n_features)
set(scores, j = sprintf("%s_inner", inner_measure$id), value = inner_scores)
setnames(scores, "iteration", "resampling_iteration")

# remove R6 objects
set(scores, j = "learner", value = NULL)
set(scores, j = "task", value = NULL)
set(scores, j = "resampling", value = NULL)
set(scores, j = "prediction_test", value = NULL)
set(scores, j = "task_id", value = NULL)
set(scores, j = "nr", value = NULL)
set(scores, j = "resampling_id", value = NULL)

# extract importance scores if RFE optimization was used
if (class(fselector)[1] == "FSelectorBatchRFE") {
@@ -135,16 +153,12 @@
set(grid, j = "importance", value = imp_scores)
}

# remove R6 objects
set(grid, j = "learner", value = NULL)
set(grid, j = "task", value = NULL)
set(grid, j = "resampling", value = NULL)

EnsembleFSResult$new(
result = grid,
result = scores,
features = task$feature_names,
benchmark_result = if (store_benchmark_result) bmr,
measure_id = measure$id,
inner_measure_id = sprintf("%s_inner", inner_measure$id),
minimize = measure$minimize
)
}
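End to end, the result table should now carry both performance columns. A hedged sketch mirroring the roxygen example above; with classif.ce as both measures, the inner column is assumed to be named "classif.ce_inner":

library(mlr3)
library(mlr3fselect)

efsr = ensemble_fselect(
  fselector = fs("rfe", n_features = 2, feature_fraction = 0.8),
  task = tsk("sonar"),
  learners = lrns(c("classif.rpart", "classif.featureless")),
  init_resampling = rsmp("subsampling", repeats = 2),
  inner_resampling = rsmp("cv", folds = 3),
  measure = msr("classif.ce"),
  inner_measure = msr("classif.ce"),
  terminator = trm("evals", n_evals = 10)
)

# outer score per subsampling iteration plus the inner validation score
efsr$result[, .(resampling_iteration, learner_id, n_features,
                classif.ce, classif.ce_inner)]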
4 changes: 4 additions & 0 deletions man/ensemble_fs_result.Rd
(generated file; diff not shown)

2 changes: 2 additions & 0 deletions man/ensemble_fselect.Rd
(generated file; diff not shown)