Skip to content

Commit

Permalink
feat: allow different callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed May 31, 2024
1 parent df1fd15 commit bb55020
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
14 changes: 8 additions & 6 deletions R/ensemble_fselect.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@
#' The inner resampling strategy used by the [FSelector].
#' @param store_models (`logical(1)`)\cr
#' Whether to store models in [auto_fselector] or not.
#' @param callbacks (list of lists of [CallbackBatchFSelect])\cr
#' Callbacks to be used for each learner.
#' The lists must have the same length as the number of learners.
#'
#' @template param_fselector
#' @template param_task
#' @template param_measure
#' @template param_terminator
#' @template param_callbacks
#'
#' @source
#' `r format_bib("saeys2008", "abeel2010", "pes2020")`
Expand All @@ -62,25 +64,25 @@ ensemble_fselect = function(
inner_resampling,
measure,
terminator,
callbacks = list(),
callbacks = NULL,
store_models = TRUE
) {
assert_task(task)
assert_learners(as_learners(learners), task = task)
assert_resampling(init_resampling)
assert_choice(class(init_resampling)[1],
choices = c("ResamplingBootstrap", "ResamplingSubsampling"))
assert_choice(class(init_resampling)[1], choices = c("ResamplingBootstrap", "ResamplingSubsampling"))
assert_list(callbacks, types = "list", len = length(learners), null.ok = TRUE)

# create auto_fselector for each learner
afss = map(learners, function(learner) {
afss = imap(unname(learners), function(learner, i) {
auto_fselector(
fselector = fselector,
learner = learner,
resampling = inner_resampling,
measure = measure,
terminator = terminator,
store_models = store_models,
callbacks = callbacks
callbacks = callbacks[[i]]
)
})

Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/test_ensemble_fselect.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,15 @@ test_that("ensemble feature selection works", {
expect_data_table(feature_ranking, nrows = 60)
expect_names(names(feature_ranking), identical.to = c("feature", "inclusion_probability"))
})

test_that("different callbacks can be set", {
efsr = ensemble_fselect(
fselector = fs("rfe", n_features = 2, feature_fraction = 0.8),
task = tsk("sonar"),
learners = lrns(c("classif.rpart", "classif.featureless")),
init_resampling = rsmp("subsampling", repeats = 2),
inner_resampling = rsmp("cv", folds = 3),
measure = msr("classif.ce"),
terminator = trm("none")
)
})

0 comments on commit bb55020

Please sign in to comment.