From bb55020ea66acb07cd58e53a39b8e99b1f11a91c Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 31 May 2024 12:51:11 +0200 Subject: [PATCH] feat: allow different callbacks --- R/ensemble_fselect.R | 14 ++++++++------ tests/testthat/test_ensemble_fselect.R | 12 ++++++++++++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/R/ensemble_fselect.R b/R/ensemble_fselect.R index 2f0fdf71..0feef2f9 100644 --- a/R/ensemble_fselect.R +++ b/R/ensemble_fselect.R @@ -32,12 +32,14 @@ #' The inner resampling strategy used by the [FSelector]. #' @param store_models (`logical(1)`)\cr #' Whether to store models in [auto_fselector] or not. +#' @param callbacks (list of lists of [CallbackBatchFSelect])\cr +#' Callbacks to be used for each learner. +#' The lists must have the same length as the number of learners. #' #' @template param_fselector #' @template param_task #' @template param_measure #' @template param_terminator -#' @template param_callbacks #' #' @source #' `r format_bib("saeys2008", "abeel2010", "pes2020")` @@ -62,17 +64,17 @@ ensemble_fselect = function( inner_resampling, measure, terminator, - callbacks = list(), + callbacks = NULL, store_models = TRUE ) { assert_task(task) assert_learners(as_learners(learners), task = task) assert_resampling(init_resampling) - assert_choice(class(init_resampling)[1], - choices = c("ResamplingBootstrap", "ResamplingSubsampling")) + assert_choice(class(init_resampling)[1], choices = c("ResamplingBootstrap", "ResamplingSubsampling")) + assert_list(callbacks, types = "list", len = length(learners), null.ok = TRUE) # create auto_fselector for each learner - afss = map(learners, function(learner) { + afss = imap(unname(learners), function(learner, i) { auto_fselector( fselector = fselector, learner = learner, @@ -80,7 +82,7 @@ ensemble_fselect = function( measure = measure, terminator = terminator, store_models = store_models, - callbacks = callbacks + callbacks = callbacks[[i]] ) }) diff --git a/tests/testthat/test_ensemble_fselect.R b/tests/testthat/test_ensemble_fselect.R index 74ff1678..cb3eed5d 100644 --- a/tests/testthat/test_ensemble_fselect.R +++ b/tests/testthat/test_ensemble_fselect.R @@ -21,3 +21,15 @@ test_that("ensemble feature selection works", { expect_data_table(feature_ranking, nrows = 60) expect_names(names(feature_ranking), identical.to = c("feature", "inclusion_probability")) }) + +test_that("different callbacks can be set", { + efsr = ensemble_fselect( + fselector = fs("rfe", n_features = 2, feature_fraction = 0.8), + task = tsk("sonar"), + learners = lrns(c("classif.rpart", "classif.featureless")), + init_resampling = rsmp("subsampling", repeats = 2), + inner_resampling = rsmp("cv", folds = 3), + measure = msr("classif.ce"), + terminator = trm("none") + ) +})