Commit: add embedded ensemble feature selection
Showing 4 changed files with 209 additions and 0 deletions.
@@ -0,0 +1,128 @@
#' @title Embedded Ensemble Feature Selection
#'
#' @include CallbackBatchFSelect.R
#'
#' @description
#' Ensemble feature selection using multiple learners.
#' The ensemble feature selection method identifies the most informative features of a dataset by combining multiple machine learning models with resampling.
#' Returns an [EnsembleFSResult].
#'
#' @details
#' The method begins by applying an initial resampling technique, specified by the user, to create **multiple subsamples** of the original dataset.
#' This resampling step generates diverse subsets of the data, which makes the feature selection more robust.
#'
#' On each subsample, the method trains learners that support **embedded feature selection**.
#' For every combination of subsample and learner, the features selected during training are stored and the learner is scored on its ability to predict the corresponding test set.
#'
#' Results are stored in an [EnsembleFSResult].
#'
#' @param learners (list of [mlr3::Learner])\cr
#' The learners to be used for feature selection.
#' All learners must have the `selected_features` property, i.e. implement
#' embedded feature selection (e.g. regularized models).
#' @param init_resampling ([mlr3::Resampling])\cr
#' The initial resampling strategy of the data, from which each train set
#' is passed on to the learners and each test set is used for prediction.
#' Can only be [mlr3::ResamplingSubsampling] or [mlr3::ResamplingBootstrap].
#' @param measure ([mlr3::Measure])\cr
#' The measure used to score each learner on the test sets generated by
#' `init_resampling`.
#' If `NULL`, the default measure is used.
#' @param store_benchmark_result (`logical(1)`)\cr
#' Whether to store the benchmark result in the [EnsembleFSResult].
#'
#' @template param_task
#'
#' @returns an [EnsembleFSResult] object.
#'
#' @source
#' `r format_bib("meinshausen2010", "hedou2024")`
#' @export
#' @examples
#' \donttest{
#' eefsr = embedded_ensemble_fselect(
#'   task = tsk("sonar"),
#'   learners = lrns(c("classif.rpart", "classif.featureless")),
#'   init_resampling = rsmp("subsampling", repeats = 5),
#'   measure = msr("classif.ce")
#' )
#' eefsr
#' }
embedded_ensemble_fselect = function(
  task,
  learners,
  init_resampling,
  measure,
  store_benchmark_result = TRUE
  ) {
  # TOCHECK: DO WE NEED callbacks?
  assert_task(task)
  # check that all learners support `selected_features()`, i.e. embedded feature selection
  assert_learners(as_learners(learners), task = task, properties = "selected_features")
  assert_resampling(init_resampling)
  assert_choice(class(init_resampling)[1], choices = c("ResamplingBootstrap", "ResamplingSubsampling"))
  assert_flag(store_benchmark_result)

  init_resampling$instantiate(task)

  # build one benchmark design row per (subsample, learner) combination
  grid = map_dtr(seq(init_resampling$iters), function(i) {
    # create task and resampling for each outer iteration
    task_train = task$clone()$filter(init_resampling$train_set(i))
    pred_task = task$clone()$filter(init_resampling$test_set(i))
    # TOCHECK: better way to keep the test sets?
    resampling = rsmp("insample")$instantiate(task_train)

    data.table(
      resampling_iteration = i,
      learner_id = map_chr(learners, "id"),
      learner = learners,
      task = list(task_train),
      pred_task = list(pred_task),
      resampling = list(resampling)
    )
  })

  design = grid[, c("learner", "task", "resampling"), with = FALSE]
  print(design)

  bmr = benchmark(design, store_models = TRUE)

  trained_learners = bmr$score()$learner

  # extract the features selected by each trained learner
  features = map(trained_learners, function(learner) {
    learner$selected_features()
  })

  # extract the number of selected features
  n_features = map_int(features, length)

  # TOCHECK: better way to predict on the kept test sets? if measure needs
  # extract scores on the init_resampling's test sets
  scores = map_dbl(seq_along(trained_learners), function(i) {
    learner = trained_learners[[i]]
    pred_task = grid$pred_task[[i]] # TOCHECK: 1-1 with the trained_learners?
    p = learner$predict(pred_task)
    p$score(measure) # TOCHECK: what if train_task is needed, can we do this with a resampling scheme?
  })

  set(grid, j = "features", value = features)
  set(grid, j = "n_features", value = n_features)
  set(grid, j = measure$id, value = scores)
  set(grid, j = "learner", value = NULL)
  # remove R6 objects
  set(grid, j = "task", value = NULL)
  set(grid, j = "pred_task", value = NULL)
  set(grid, j = "resampling", value = NULL)

  EnsembleFSResult$new(
    result = grid,
    features = task$feature_names,
    benchmark_result = if (store_benchmark_result) bmr,
    measure_id = measure$id,
    minimize = measure$minimize
  )
}
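For orientation, a minimal usage sketch that is not part of the commit. It assumes mlr3 and data.table are attached, the function above is available, and the EnsembleFSResult object exposes the table passed to its constructor as `$result` (the class itself lives in one of the changed files not rendered here). The column names follow those assembled in the function body, with `classif.ce` coming from the measure's id.

library(mlr3)
library(data.table)

task = tsk("sonar")
learners = lrns(c("classif.rpart", "classif.featureless"))

# both learners implement embedded feature selection
sapply(learners, function(learner) "selected_features" %in% learner$properties)

eefsr = embedded_ensemble_fselect(
  task = task,
  learners = learners,
  init_resampling = rsmp("subsampling", repeats = 5),
  measure = msr("classif.ce")
)

# assumed accessor: one row per (subsample, learner) combination
res = eefsr$result
res[, .(resampling_iteration, learner_id, n_features, classif.ce)]

# e.g. average number of selected features and mean score per learner
res[, .(mean_n_features = mean(n_features), mean_ce = mean(classif.ce)), by = learner_id]

With 5 subsampling iterations and 2 learners, 10 models are fitted, so the result table has 10 rows.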
Some generated files are not rendered by default.