Skip to content

Commit

Permalink
feat: cache results
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed May 31, 2024
1 parent 62011f3 commit df1fd15
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
30 changes: 24 additions & 6 deletions R/EnsembleFSResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@
EnsembleFSResult = R6Class("EnsembleFSResult",
public = list(

#' @field benchmark_result (`BenchmarkResult`)\cr
#' @field benchmark_result ([mlr3::BenchmarkResult])\cr
#' The benchmark result.
benchmark_result = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @param benchmark_result (`BenchmarkResult`)\cr
#' @param benchmark_result ([mlr3::BenchmarkResult])\cr
#' The benchmark result object.
#' @param result ([data.table::data.table])\cr
#' The result of the ensemble feature selection results.
Expand All @@ -53,6 +53,11 @@ EnsembleFSResult = R6Class("EnsembleFSResult",
feature_ranking = function(method = "inclusion_probability") {
assert_choice(method, choices = "inclusion_probability")

# cached results
if (!is.null(private$.feature_ranking[[method]])) {
return(private$.feature_ranking[[method]])
}

features = self$benchmark_result$tasks$task[[1]]$feature_names

count = map_int(features, function(feature) {
Expand All @@ -64,25 +69,36 @@ EnsembleFSResult = R6Class("EnsembleFSResult",
res = data.table(feature = features, inclusion_probability = count / nrow(self$result))
setorderv(res, "inclusion_probability", order = -1L)

res
private$.feature_ranking[[method]] = res
private$.feature_ranking[[method]]
},

#' @description
#' Calculates the stability of the selected features with the `stabm` package.
#' The results are cached.
#' When the same stability measure is requested again with different arguments, the cache must be reset.
#'
#' @param stability_measure (`character(1)`)\cr
#' The stability measure to be used.
#' One of the measures returned by [stabm::listStabilityMeasures()] in lower case.
#' Default is `"jaccard"`.
#' @param ... (`any`)\cr
#' Additional arguments passed to the stability measure function.
stability = function(stability_measure = "jaccard", ...) {
#' @param reset_cache (`logical(1)`)\cr
#' If `TRUE`, the cached results are ignored.
stability = function(stability_measure = "jaccard", ..., reset_cache = FALSE) {
funs = stabm::listStabilityMeasures()$Name
keys = tolower(gsub("stability", "", funs))
assert_choice(stability_measure, choices = keys)

# cached results
if (!is.null(private$.stability[[stability_measure]]) && !reset_cache) {
return(private$.stability[[stability_measure]])
}

fun = get(funs[which(stability_measure == keys)], envir = asNamespace("stabm"))
fun(self$result$features, ...)
private$.stability[[stability_measure]] = fun(self$result$features, ...)
private$.stability[[stability_measure]]
}
),

Expand All @@ -98,7 +114,9 @@ EnsembleFSResult = R6Class("EnsembleFSResult",
),

private = list(
.result = NULL
.result = NULL,
.stability = NULL,
.feature_ranking = NULL
)
)

Expand Down
15 changes: 12 additions & 3 deletions man/EnsembleFSResult.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit df1fd15

Please sign in to comment.