Skip to content

Commit

Permalink
feat: ties methods (#92)
Browse files Browse the repository at this point in the history
* feat: ties methods for archive$best()

* fix: n_features

* feat: add global ties option

* fix: rfe

* test: remove default

* test: remove duplicated test

* feat: add to fsi

* test: formal args

* docs: global default

* test: fsi

* chore: update news

* fix: ties method to auto_fselector

* refactor: rename to least features and remove first option

* chore: whitespace

* fix: default

* fix: nested

* tests: remove first

* fix: archive

* docs: update
  • Loading branch information
be-marc authored Dec 16, 2023
1 parent 09540cb commit 2824c69
Show file tree
Hide file tree
Showing 24 changed files with 473 additions and 37 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# mlr3fselect (development version)

* feat: Add `ties_method` options `"least_features"` and `"random"` to `ArchiveFSelect$best()`.
* refactor: Optimize runtime of `ArchiveFSelect$best()` method.
* feat: Add importance scores to result of `FSelectorRFE`.
* feat: Add number of features to `as.data.table.ArchiveFSelect()`.
* feat: Features can be always included with the `always_include` column role.
Expand Down
68 changes: 67 additions & 1 deletion R/ArchiveFSelect.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
#' * `measures` (list of [mlr3::Measure])\cr
#' Score feature sets on additional measures.
#'
#' @template param_ties_method
#'
#' @export
ArchiveFSelect = R6Class("ArchiveFSelect",
inherit = Archive,
Expand All @@ -72,8 +74,14 @@ ArchiveFSelect = R6Class("ArchiveFSelect",
#'
#' @param check_values (`logical(1)`)\cr
#' If `TRUE` (default), hyperparameter configurations are checked for validity.
initialize = function(search_space, codomain, check_values = TRUE) {
initialize = function(
search_space,
codomain,
check_values = TRUE,
ties_method = "least_features"
) {
super$initialize(search_space, codomain, check_values)
self$ties_method = ties_method

# initialize empty benchmark result
self$benchmark_result = BenchmarkResult$new()
Expand Down Expand Up @@ -141,7 +149,65 @@ ArchiveFSelect = R6Class("ArchiveFSelect",
print = function() {
  # header line produced by the object's format() method
  catf(format(self))

  # show the archive data without the "uhash" column
  shown_cols = setdiff(names(self$data), "uhash")
  print(self$data[, shown_cols, with = FALSE], digits = 2)
},

#' @description
#' Returns the best scoring feature sets.
#'
#' @param batch (`integer()`)\cr
#' The batch number(s) to limit the best results to.
#' Default is all batches.
#' @param ties_method (`character(1)`)\cr
#' Method to handle ties.
#' If `NULL` (default), the global ties method set during initialization is used.
#' The default global ties method is `least_features` which selects the feature set with the least features.
#' If there are multiple best feature sets with the same number of features, one is selected randomly.
#' The `random` method returns a random feature set from the best feature sets.
#'
#' @return [data.table::data.table()]
best = function(batch = NULL, ties_method = NULL) {
  # fall back to the archive-wide ties method when none is passed explicitly
  ties_method = assert_choice(ties_method, c("least_features", "random"), null.ok = TRUE) %??% self$ties_method
  assert_subset(batch, seq_len(self$n_batch))

  # nothing has been evaluated yet
  if (self$n_batch == 0L) {
    return(data.table())
  }

  # restrict the candidate rows to the requested batches, if any
  candidates = if (is.null(batch)) {
    self$data
  } else {
    self$data[list(batch), , on = "batch_nr"]
  }

  # more than one target: keep every row that is not dominated
  if (self$codomain$target_length != 1L) {
    scores = t(as.matrix(candidates[, self$cols_y, with = FALSE]))
    scores = self$codomain$maximization_to_minimization * scores
    keep = !is_dominated(scores)
    return(candidates[keep])
  }

  # single target: flip the sign so that the largest value is the best one
  perf = candidates[[self$cols_y]] * -self$codomain$maximization_to_minimization

  if (ties_method != "least_features") {
    # "random": one of the top scoring rows, ties are broken randomly
    idx = which_max(perf, ties_method = "random")
    return(candidates[idx])
  }

  # "least_features": among the top scoring rows take the one with the
  # smallest row sum over the feature columns, remaining ties are broken randomly
  top_rows = which(perf == max(perf))
  candidates = candidates[top_rows]
  n_selected = rowSums(candidates[, self$cols_x, with = FALSE])
  idx = which_min(n_selected, ties_method = "random")
  candidates[idx]
}
),

active = list(

#' @field ties_method (`character(1)`)\cr
#' Method to handle ties.
ties_method = function(rhs) {
  # read access: hand back the stored value
  if (missing(rhs)) {
    return(private$.ties_method)
  }

  # write access: only the two supported methods are accepted
  assert_choice(rhs, c("least_features", "random"))
  private$.ties_method = rhs
}
),

private = list(
.ties_method = NULL
)
)

Expand Down
16 changes: 15 additions & 1 deletion R/AutoFSelector.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#' @template param_store_models
#' @template param_check_values
#' @template param_callbacks
#' @template param_ties_method
#'
#' @export
#' @examples
Expand Down Expand Up @@ -111,7 +112,19 @@ AutoFSelector = R6Class("AutoFSelector",
#'
#' @param fselector ([FSelector])\cr
#' Optimization algorithm.
initialize = function(fselector, learner, resampling, measure = NULL, terminator, store_fselect_instance = TRUE, store_benchmark_result = TRUE, store_models = FALSE, check_values = FALSE, callbacks = list()) {
initialize = function(
fselector,
learner,
resampling,
measure = NULL,
terminator,
store_fselect_instance = TRUE,
store_benchmark_result = TRUE,
store_models = FALSE,
check_values = FALSE,
callbacks = list(),
ties_method = "least_features"
) {
ia = list()
self$fselector = assert_r6(fselector, "FSelector")$clone()
ia$learner = assert_learner(as_learner(learner, clone = TRUE))
Expand All @@ -125,6 +138,7 @@ AutoFSelector = R6Class("AutoFSelector",

ia$check_values = assert_flag(check_values)
ia$callbacks = assert_callbacks(as_callbacks(callbacks))
ia$ties_method = assert_choice(ties_method, c("least_features", "random"))
self$instance_args = ia

super$initialize(
Expand Down
17 changes: 15 additions & 2 deletions R/FSelectInstanceSingleCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#' @template param_store_benchmark_result
#' @template param_callbacks
#' @template param_xdt
#' @template param_ties_method
#'
#' @export
#' @examples
Expand Down Expand Up @@ -93,12 +94,24 @@ FSelectInstanceSingleCrit = R6Class("FSelectInstanceSingleCrit",

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(task, learner, resampling, measure, terminator, store_benchmark_result = TRUE, store_models = FALSE, check_values = FALSE, callbacks = list()) {
initialize = function(
task,
learner,
resampling,
measure,
terminator,
store_benchmark_result = TRUE,
store_models = FALSE,
check_values = FALSE,
callbacks = list(),
ties_method = "least_features"
) {
# initialized specialized fselect archive and objective
archive = ArchiveFSelect$new(
search_space = task_to_domain(assert_task(task)),
codomain = measures_to_codomain(assert_measure(measure)),
check_values = check_values)
check_values = check_values,
ties_method = ties_method)

objective = ObjectiveFSelect$new(
task = task,
Expand Down
20 changes: 18 additions & 2 deletions R/auto_fselector.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,25 @@
#' @template param_store_models
#' @template param_check_values
#' @template param_callbacks
#' @template param_ties_method
#'
#' @export
#' @inherit AutoFSelector examples
auto_fselector = function(fselector, learner, resampling, measure = NULL, term_evals = NULL, term_time = NULL, terminator = NULL, store_fselect_instance = TRUE, store_benchmark_result = TRUE, store_models = FALSE, check_values = FALSE, callbacks = list()) {
auto_fselector = function(
fselector,
learner,
resampling,
measure = NULL,
term_evals = NULL,
term_time = NULL,
terminator = NULL,
store_fselect_instance = TRUE,
store_benchmark_result = TRUE,
store_models = FALSE,
check_values = FALSE,
callbacks = list(),
ties_method = "least_features"
) {
terminator = terminator %??% terminator_selection(term_evals, term_time)

AutoFSelector$new(
Expand All @@ -35,5 +50,6 @@ auto_fselector = function(fselector, learner, resampling, measure = NULL, term_e
store_benchmark_result = store_benchmark_result,
store_models = store_models,
check_values = check_values,
callbacks = callbacks)
callbacks = callbacks,
ties_method = ties_method)
}
52 changes: 40 additions & 12 deletions R/fselect.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#' @template param_store_models
#' @template param_check_values
#' @template param_callbacks
#' @template param_ties_method
#'
#' @export
#' @examples
Expand All @@ -60,21 +61,48 @@
#'
#' # Inspect all evaluated configurations
#' as.data.table(instance$archive)
fselect = function(fselector, task, learner, resampling, measures = NULL, term_evals = NULL, term_time = NULL, terminator = NULL, store_benchmark_result = TRUE, store_models = FALSE, check_values = FALSE, callbacks = list()) {
fselect = function(
fselector,
task,
learner,
resampling,
measures = NULL,
term_evals = NULL,
term_time = NULL,
terminator = NULL,
store_benchmark_result = TRUE,
store_models = FALSE,
check_values = FALSE,
callbacks = list(),
ties_method = "least_features"
) {
assert_fselector(fselector)
terminator = terminator %??% terminator_selection(term_evals, term_time)

FSelectInstance = if (!is.list(measures)) FSelectInstanceSingleCrit else FSelectInstanceMultiCrit
instance = FSelectInstance$new(
task = task,
learner = learner,
resampling = resampling,
measures,
terminator = terminator,
store_benchmark_result = store_benchmark_result,
store_models = store_models,
check_values = check_values,
callbacks = callbacks)
instance = if (!is.list(measures)) {
FSelectInstanceSingleCrit$new(
task = task,
learner = learner,
resampling = resampling,
measure = measures,
terminator = terminator,
store_benchmark_result = store_benchmark_result,
store_models = store_models,
check_values = check_values,
callbacks = callbacks,
ties_method = ties_method)
} else {
FSelectInstanceMultiCrit$new(
task = task,
learner = learner,
resampling = resampling,
measures = measures,
terminator = terminator,
store_benchmark_result = store_benchmark_result,
store_models = store_models,
check_values = check_values,
callbacks = callbacks)
}

fselector$optimize(instance)
instance
Expand Down
22 changes: 20 additions & 2 deletions R/fselect_nested.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#' @template param_store_models
#' @template param_check_values
#' @template param_callbacks
#' @template param_ties_method
#'
#' @export
#' @examples
Expand All @@ -40,7 +41,23 @@
#'
#' # Unbiased performance of the final model trained on the full data set
#' rr$aggregate()
fselect_nested = function(fselector, task, learner, inner_resampling, outer_resampling, measure = NULL, term_evals = NULL, term_time = NULL, terminator = NULL, store_fselect_instance = TRUE, store_benchmark_result = TRUE, store_models = FALSE, check_values = FALSE, callbacks = list()) {
fselect_nested = function(
fselector,
task,
learner,
inner_resampling,
outer_resampling,
measure = NULL,
term_evals = NULL,
term_time = NULL,
terminator = NULL,
store_fselect_instance = TRUE,
store_benchmark_result = TRUE,
store_models = FALSE,
check_values = FALSE,
callbacks = list(),
ties_method = "least_features"
) {
assert_task(task)
assert_resampling(inner_resampling)
assert_resampling(outer_resampling)
Expand All @@ -56,7 +73,8 @@ fselect_nested = function(fselector, task, learner, inner_resampling, outer_resa
store_benchmark_result = store_benchmark_result,
store_models = store_models,
check_values = check_values,
callbacks = callbacks)
callbacks = callbacks,
ties_method = ties_method)

resample(task, afs, outer_resampling, store_models = TRUE)
}
1 change: 1 addition & 0 deletions R/mlr_callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ load_callback_svm_rfe = function() {
#'
#' @description
#' Selects the smallest feature set within one standard error of the best as the result.
#' If there are multiple feature sets with the same performance and number of features, the first one is selected.
#'
#' @source
#' `r format_bib("kuhn2013")`
Expand Down
40 changes: 37 additions & 3 deletions R/sugar.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,47 @@ fss = function(.keys, ...) {
#' @template param_store_models
#' @template param_check_values
#' @template param_callbacks
#' @template param_ties_method
#'
#' @inheritSection FSelectInstanceSingleCrit Resources
#' @inheritSection FSelectInstanceSingleCrit Default Measures
#'
#' @export
#' @inherit FSelectInstanceSingleCrit examples
fsi = function(task, learner, resampling, measures = NULL, terminator, store_benchmark_result = TRUE, store_models = FALSE, check_values = FALSE, callbacks = list()) {
FSelectInstance = if (!is.list(measures)) FSelectInstanceSingleCrit else FSelectInstanceMultiCrit
FSelectInstance$new(task, learner, resampling, measures, terminator, store_benchmark_result, store_models, check_values, callbacks)
fsi = function(
  task,
  learner,
  resampling,
  measures = NULL,
  terminator,
  store_benchmark_result = TRUE,
  store_models = FALSE,
  check_values = FALSE,
  callbacks = list(),
  ties_method = "least_features"
  ) {
  # a list of measures requests the multi-criteria instance;
  # `ties_method` is not forwarded in that case
  if (is.list(measures)) {
    return(FSelectInstanceMultiCrit$new(
      task = task,
      learner = learner,
      resampling = resampling,
      measures = measures,
      terminator = terminator,
      store_benchmark_result = store_benchmark_result,
      store_models = store_models,
      check_values = check_values,
      callbacks = callbacks
    ))
  }

  # anything else (a single measure or `NULL`) goes to the single-criterion instance
  FSelectInstanceSingleCrit$new(
    task = task,
    learner = learner,
    resampling = resampling,
    measure = measures,
    terminator = terminator,
    store_benchmark_result = store_benchmark_result,
    store_models = store_models,
    check_values = check_values,
    callbacks = callbacks,
    ties_method = ties_method
  )
}
4 changes: 4 additions & 0 deletions inst/testthat/helper_expectations.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,7 @@ expect_max_features = function(features, n) {
res = max(rowSums(features))
expect_set_equal(res, n)
}

# Test helper: checks which entries of `res` are selected.
#
# The names of `res` whose values coerce to `TRUE` via `as.logical()` are
# collected and passed to `expect_names()`.
#
# res:          named object whose values can be coerced to logical
#               (presumably a row of feature indicators - TODO confirm against callers).
# identical_to: forwarded as `identical.to` to `expect_names()`.
# must_include: forwarded as `must.include` to `expect_names()`.
expect_features = function(res, identical_to = NULL, must_include = NULL) {
expect_names(names(res)[as.logical(res)], must.include = must_include, identical.to = identical_to)
}
Loading

0 comments on commit 2824c69

Please sign in to comment.