Skip to content

Commit

Permalink
feat: add internal tuning callback
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Oct 24, 2024
1 parent e917a02 commit 70cd886
Show file tree
Hide file tree
Showing 10 changed files with 116 additions and 21 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ Suggests:
mlr3pipelines,
rpart,
testthat (>= 3.0.0)
Remotes:
mlr-org/bbotk@result_extra
Config/testthat/edition: 3
Config/testthat/parallel: true
Encoding: UTF-8
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# mlr3fselect (development version)

* feat: Add internal tuning callback `mlr3fselect.internal_tuning`.

# mlr3fselect 1.1.1

* compatibility: bbotk 1.1.1
Expand Down
12 changes: 11 additions & 1 deletion R/AutoFSelector.R
Original file line number Diff line number Diff line change
Expand Up @@ -308,11 +308,18 @@ AutoFSelector = R6Class("AutoFSelector",

private = list(
.train = function(task) {
# construct instance from args; then tune
# construct instance from args
ia = self$instance_args
ia$task = task$clone()
instance = invoke(FSelectInstanceBatchSingleCrit$new, .args = ia)

# optimize feature selection
self$fselector$optimize(instance)

# make auto fselector available to callbacks
instance$objective$context$auto_fselector = self
call_back("on_auto_fselector_before_final_model", instance$objective$callbacks, instance$objective$context)

learner = ia$learner$clone(deep = TRUE)
task = task$clone()

Expand All @@ -325,9 +332,12 @@ AutoFSelector = R6Class("AutoFSelector",
task$select(feat)
learner$train(task)

call_back("on_auto_fselector_after_final_model", instance$objective$callbacks, instance$objective$context)

# the return model is a list of "learner", "features" and "fselect_instance"
result_model = list(learner = learner, features = feat)
if (private$.store_fselect_instance) result_model$fselect_instance = instance

result_model
},

Expand Down
46 changes: 31 additions & 15 deletions R/CallbackBatchFSelect.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,19 @@ CallbackBatchFSelect = R6Class("CallbackBatchFSelect",
#' @field on_eval_before_archive (`function()`)\cr
#' Stage called before performance values are written to the archive.
#' Called in `ObjectiveFSelectBatch$eval_many()`.
on_eval_before_archive = NULL
on_eval_before_archive = NULL,

#' @field on_auto_fselector_before_final_model (`function()`)\cr
#' Stage called before the final model is trained.
#' Called in `AutoFSelector$train()`.
#' This stage is called after the optimization has finished and before the final model is trained with the best feature set found.
on_auto_fselector_before_final_model = NULL,

#' @field on_auto_fselector_after_final_model (`function()`)\cr
#' Stage called after the final model is trained.
#' Called in `AutoFSelector$train()`.
#' This stage is called after the final model is trained with the best feature set found.
on_auto_fselector_after_final_model = NULL
)
)

Expand Down Expand Up @@ -59,26 +71,18 @@ CallbackBatchFSelect = R6Class("CallbackBatchFSelect",
#' - on_result
#' - on_optimization_end
#' End Feature Selection
#' Fit Final Model
#' - on_auto_fselector_before_final_model
#' - on_auto_fselector_after_final_model
#' End Fit Final Model
#' ```
#'
#' See also the section on parameters for more information on the stages.
#' A feature selection callback works with [bbotk::ContextBatch] and [ContextBatchFSelect].
#'
#' @details
#' When implementing a callback, each function must have two arguments named `callback` and `context`.
#'
#' A callback can write data to the state (`$state`), e.g. settings that affect the callback itself.
#' Avoid writing large data to the state.
#' This can slow down the feature selection when the evaluation of configurations is parallelized.
#'
#' Feature selection callbacks access two different contexts depending on the stage.
#' The stages `on_eval_after_design`, `on_eval_after_benchmark`, `on_eval_before_archive` access [ContextBatchFSelect].
#' This context can be used to customize the evaluation of a batch of feature sets.
#' Changes to the state of the callback are lost after the evaluation of a batch, and changes to the fselect instance or the fselector are not possible.
#' Persistent data should be written to the archive via `$aggregated_performance` (see [ContextBatchFSelect]).
#' The other stages access [bbotk::ContextBatch].
#' This context can be used to modify the fselect instance, archive, fselector and final result.
#' There are two different contexts because the evaluation can be parallelized, i.e. multiple instances of [ContextBatchFSelect] exist on different workers at the same time.
#'
#' @param id (`character(1)`)\cr
#' Identifier for the new instance.
Expand Down Expand Up @@ -111,6 +115,12 @@ CallbackBatchFSelect = R6Class("CallbackBatchFSelect",
#' @param on_optimization_end (`function()`)\cr
#' Stage called at the end of the optimization.
#' Called in `Optimizer$optimize()`.
#' @param on_auto_fselector_before_final_model (`function()`)\cr
#' Stage called before the final model is trained.
#' Called in `AutoFSelector$train()`.
#' @param on_auto_fselector_after_final_model (`function()`)\cr
#' Stage called after the final model is trained.
#' Called in `AutoFSelector$train()`.
#'
#' @export
#' @inherit CallbackBatchFSelect examples
Expand All @@ -125,7 +135,9 @@ callback_batch_fselect = function(
on_eval_before_archive = NULL,
on_optimizer_after_eval = NULL,
on_result = NULL,
on_optimization_end = NULL
on_optimization_end = NULL,
on_auto_fselector_before_final_model = NULL,
on_auto_fselector_after_final_model = NULL
) {
stages = discard(set_names(list(
on_optimization_begin,
Expand All @@ -135,7 +147,9 @@ callback_batch_fselect = function(
on_eval_before_archive,
on_optimizer_after_eval,
on_result,
on_optimization_end),
on_optimization_end,
on_auto_fselector_before_final_model,
on_auto_fselector_after_final_model),
c(
"on_optimization_begin",
"on_optimizer_before_eval",
Expand All @@ -144,7 +158,9 @@ callback_batch_fselect = function(
"on_eval_before_archive",
"on_optimizer_after_eval",
"on_result",
"on_optimization_end")), is.null)
"on_optimization_end",
"on_auto_fselector_before_final_model",
"on_auto_fselector_after_final_model")), is.null)
walk(stages, function(stage) assert_function(stage, args = c("callback", "context")))
callback = CallbackBatchFSelect$new(id, label, man)
iwalk(stages, function(stage, name) callback[[name]] = stage)
Expand Down
7 changes: 7 additions & 0 deletions R/ContextBatchFSelect.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
#' @export
ContextBatchFSelect = R6Class("ContextBatchFSelect",
inherit = ContextBatch,
public = list(

#' @field auto_fselector ([AutoFSelector])\cr
#' The [AutoFSelector] instance.
auto_fselector = NULL
),

active = list(
#' @field xss (list())\cr
#' The feature sets of the latest batch.
Expand Down
6 changes: 4 additions & 2 deletions R/FSelectInstanceBatchMultiCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,18 @@ FSelectInstanceBatchMultiCrit = R6Class("FSelectInstanceBatchMultiCrit",
#'
#' @param ydt (`data.table::data.table()`)\cr
#' Optimal outcomes, e.g. the Pareto front.
#' @param extra (`data.table::data.table()`)\cr
#' Additional information.
#' @param ... (`any`)\cr
#' ignored.
assign_result = function(xdt, ydt, ...) {
assign_result = function(xdt, ydt, extra = NULL, ...) {
# Add feature names to result for easy task subsetting
features = map(transpose_list(xdt), function(x) {
self$objective$task$feature_names[as.logical(x)]
})
set(xdt, j = "features", value = list(features))
set(xdt, j = "n_features", value = length(features[[1L]]))
super$assign_result(xdt, ydt)
super$assign_result(xdt, ydt, extra = extra)
if (!is.null(private$.result$x_domain)) set(private$.result, j = "x_domain", value = NULL)
},

Expand Down
6 changes: 4 additions & 2 deletions R/FSelectInstanceBatchSingleCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,17 @@ FSelectInstanceBatchSingleCrit = R6Class("FSelectInstanceBatchSingleCrit",
#'
#' @param y (`numeric(1)`)\cr
#' Optimal outcome.
#' @param extra (`data.table::data.table()`)\cr
#' Additional information.
#' @param ... (`any`)\cr
#' ignored.
assign_result = function(xdt, y, ...) {
assign_result = function(xdt, y, extra = NULL, ...) {
# Add feature names to result for easy task subsetting
feature_names = self$objective$task$feature_names
features = list(feature_names[as.logical(xdt[, feature_names, with = FALSE])])
set(xdt, j = "features", value = list(features))
set(xdt, j = "n_features", value = length(features[[1L]]))
super$assign_result(xdt, y)
super$assign_result(xdt, y, extra = extra)
if (!is.null(private$.result$x_domain)) set(private$.result, j = "x_domain", value = NULL)
},

Expand Down
3 changes: 2 additions & 1 deletion R/FSelectorBatchRFE.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,11 @@ FSelectorBatchRFE = R6Class("FSelectorBatchRFE",
res = inst$archive$best()

xdt = res[, c(inst$search_space$ids(), "importance"), with = FALSE]
extra = res[, !c(inst$search_space$ids(), "importance"), with = FALSE]

# unlist keeps name!
y = unlist(res[, inst$archive$cols_y, with = FALSE])
inst$assign_result(xdt, y)
inst$assign_result(xdt, y, extra = extra)

invisible(NULL)
}
Expand Down
52 changes: 52 additions & 0 deletions R/mlr_callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,55 @@ load_callback_one_se_rule = function() {
}
)
}

#' @title Internal Tuning Callback
#'
#' @include CallbackBatchFSelect.R
#' @name mlr3fselect.internal_tuning
#'
#' @description
#' This callback runs internal tuning alongside the feature selection.
#' The internal tuning values are aggregated and stored in the results.
#' The final model is trained with the best feature set and the tuned value.
#'
#' @examples
#' clbk("mlr3fselect.internal_tuning")
NULL

# Loader for the "mlr3fselect.internal_tuning" callback.
#
# Returns the callback built by `callback_batch_fselect()` with four stages:
# - on_eval_before_archive: aggregates the internal tuned values of each
#   resample result and stores them next to the aggregated performance.
# - on_optimization_end: copies the internal tuned values into the result.
# - on_auto_fselector_before_final_model: backs up the learner, then switches
#   internal tuning off and fixes the tuned values on the learner.
# - on_auto_fselector_after_final_model: puts the backed-up learner back.
#
# NOTE(review): `callback$state$internal_search_space` is read by two stages
# but is not assigned anywhere in this loader -- presumably it is supplied when
# the callback is constructed; confirm against the caller.
load_callback_internal_tuning = function() {
  callback_batch_fselect("mlr3fselect.internal_tuning",
    label = "Internal Tuning",
    man = "mlr3fselect::mlr3fselect.internal_tuning",

    on_eval_before_archive = function(callback, context) {
      resample_results = context$benchmark_result$resample_results$resample_result

      # one aggregated set of internal tuned values per resample result
      aggregated = mlr3misc::map(resample_results, function(rr) {
        rr_private = mlr3misc::get_private(rr)
        states = rr_private$.data$learner_states(rr_private$.view)

        # pull the tuned values out of every learner state, then regroup them
        # with transpose_list() before handing them to the aggregation
        per_state = mlr3misc::map(states, "internal_tuned_values")
        regrouped = mlr3misc::transpose_list(per_state)
        callback$state$internal_search_space$aggr_internal_tuned_values(regrouped)
      })

      # written by reference as a list column
      data.table::set(
        context$aggregated_performance,
        j = "internal_tuned_values",
        value = list(aggregated)
      )
    },

    on_optimization_end = function(callback, context) {
      # carry the internal tuned values over from the extra result information
      tuned_values = context$result_extra[["internal_tuned_values"]]
      set(context$result, j = "internal_tuned_values", value = list(tuned_values))
    },

    on_auto_fselector_before_final_model = function(callback, context) {
      learner = context$auto_fselector$instance_args$learner

      # keep a deep copy so the learner can be restored after the final fit
      callback$state$learner = learner$clone(deep = TRUE)

      # the learner itself is modified in place: no internal tuning, tuned
      # values set as parameter values, and validation switched off
      tuned_ids = callback$state$internal_search_space$ids()
      learner$param_set$disable_internal_tuning(tuned_ids)
      learner$param_set$set_values(.values = context$result$internal_tuned_values[[1]])
      set_validate(learner, validate = NULL)
    },

    on_auto_fselector_after_final_model = function(callback, context) {
      # undo the in-place changes by swapping the backup copy back in
      context$auto_fselector$instance_args$learner = callback$state$learner
    }
  )
}
1 change: 1 addition & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
x$add("mlr3fselect.backup", load_callback_backup)
x$add("mlr3fselect.svm_rfe", load_callback_svm_rfe)
x$add("mlr3fselect.one_se_rule", load_callback_one_se_rule)
x$add("mlr3fselect.internal_tuning", load_callback_internal_tuning)

assign("lg", lgr::get_logger("bbotk"), envir = parent.env(environment()))
if (Sys.getenv("IN_PKGDOWN") == "true") {
Expand Down

0 comments on commit 70cd886

Please sign in to comment.