Skip to content

Commit

Permalink
fix variable names
Browse files Browse the repository at this point in the history
  • Loading branch information
bblodfon committed Jun 18, 2024
1 parent 9423cb9 commit 727f3a8
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 24 deletions.
28 changes: 14 additions & 14 deletions R/EnsembleFSResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,16 @@ EnsembleFSResult = R6Class("EnsembleFSResult",
#' selection.
#' @param benchmark_result ([mlr3::BenchmarkResult])\cr
#' The benchmark result object.
#' @param measure_var (`character(1)`)\cr
#' @param measure_id (`character(1)`)\cr
#' Column name of `"result"` that corresponds to the measure used.
#' @param minimize (`logical(1)`)\cr
#' If `TRUE` (default), lower values of the measure correspond to higher performance.
initialize = function(result, features, benchmark_result = NULL, measure_id,
minimize = TRUE) {
assert_data_table(result)
private$.measure_var = assert_string(measure_var, null.ok = FALSE)
private$.measure_id = assert_string(measure_id, null.ok = FALSE)
mandatory_columns = c("resampling_iteration", "learner_id", "features", "n_features")
assert_names(names(result), must.include = c(mandatory_columns, measure_var))
assert_names(names(result), must.include = c(mandatory_columns, measure_id))
private$.result = result
private$.features = assert_character(features, any.missing = FALSE, null.ok = FALSE)
private$.minimize = assert_logical(minimize, null.ok = FALSE)
Expand Down Expand Up @@ -219,15 +219,15 @@ EnsembleFSResult = R6Class("EnsembleFSResult",
pareto_front = function(type = "empirical") {
assert_choice(type, choices = c("empirical", "estimated"))
result = private$.result
measure_var = private$.measure_var
measure_id = private$.measure_id

# Keep only n_features and performance scores
cols_to_keep = c("n_features", measure_var)
cols_to_keep = c("n_features", measure_id)
data = result[, ..cols_to_keep][order(n_features)]

# Initialize the Pareto front
pf = data.table(n_features = numeric(0))
pf[, (measure_var) := numeric(0)]
pf[, (measure_id) := numeric(0)]

# Initialize the best performance to a large number so
# that the Pareto front has at least one point
Expand All @@ -237,14 +237,14 @@ EnsembleFSResult = R6Class("EnsembleFSResult",
for (i in seq_row(data)) {
# Determine the condition based on minimize
if (minimize) {
condition = data[[measure_var]][i] < best_score
condition = data[[measure_id]][i] < best_score
} else {
condition = data[[measure_var]][i] > best_score
condition = data[[measure_id]][i] > best_score
}

if (condition) {
pf = rbind(pf, data[i])
best_score = data[[measure_var]][i]
best_score = data[[measure_id]][i]
}
}

Expand All @@ -253,13 +253,13 @@ EnsembleFSResult = R6Class("EnsembleFSResult",
pf[, n_features_inv := 1 / n_features]

# Fit the linear model
form = mlr3misc::formulate(lhs = measure_var, rhs = "n_features_inv")
form = mlr3misc::formulate(lhs = measure_id, rhs = "n_features_inv")
model = stats::lm(formula = form, data = pf)

# Predict values using the model to create a smooth curve
pf_pred = data.table(n_features = seq(1, max(data$n_features)))
pf_pred[, n_features_inv := 1 / n_features]
pf_pred[, (measure_var) := predict(model, newdata = pf_pred)]
pf_pred[, (measure_id) := predict(model, newdata = pf_pred)]
pf_pred$n_features_inv = NULL
pf = pf_pred
}
Expand All @@ -279,7 +279,7 @@ EnsembleFSResult = R6Class("EnsembleFSResult",
cbind(private$.result, tab)
},

#' @field nlearners (`numeric(1)`)\cr
#' @field n_learners (`numeric(1)`)\cr
#' Returns the number of learners used in the ensemble feature selection.
n_learners = function(rhs) {
assert_ro_binding(rhs)
Expand All @@ -290,7 +290,7 @@ EnsembleFSResult = R6Class("EnsembleFSResult",
#' Returns the measure id used in the ensemble feature selection.
measure = function(rhs) {
assert_ro_binding(rhs)
private$.measure_var
private$.measure_id
}
),

Expand All @@ -300,7 +300,7 @@ EnsembleFSResult = R6Class("EnsembleFSResult",
.stability_learner = NULL,
.feature_ranking = NULL,
.features = NULL,
.measure_var = NULL,
.measure_id = NULL,
.minimize = NULL
)
)
Expand Down
2 changes: 1 addition & 1 deletion R/ensemble_fselect.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ ensemble_fselect = function(
result = grid,
features = task$feature_names,
benchmark_result = if (store_benchmark_result) bmr,
measure_var = measure$id,
measure_id = measure$id,
minimize = measure$minimize
)
}
6 changes: 3 additions & 3 deletions man/ensemble_fs_result.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions tests/testthat/test_ensemble_fselect.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ test_that("ensemble feature selection works", {
expect_vector(efsr$result$classif.ce, size = 4)
expect_benchmark_result(efsr$benchmark_result)
expect_equal(efsr$measure, "classif.ce")
expect_equal(efsr$nlearners, 2)
expect_equal(efsr$n_learners, 2)

# stability
expect_number(efsr$stability(stability_measure = "jaccard"))
Expand Down Expand Up @@ -63,7 +63,7 @@ test_that("ensemble feature selection works without benchmark result", {
expect_vector(efsr$result$classif.ce, size = 4)
expect_null(efsr$benchmark_result)
expect_equal(efsr$measure, "classif.ce")
expect_equal(efsr$nlearners, 2)
expect_equal(efsr$n_learners, 2)

# stability
expect_number(efsr$stability(stability_measure = "jaccard"))
Expand Down Expand Up @@ -109,7 +109,7 @@ test_that("ensemble feature selection works with rfe", {
expect_list(efsr$result$importance, any.missing = FALSE, len = 4)
expect_benchmark_result(efsr$benchmark_result)
expect_equal(efsr$measure, "classif.ce")
expect_equal(efsr$nlearners, 2)
expect_equal(efsr$n_learners, 2)

# stability
expect_number(efsr$stability(stability_measure = "jaccard"))
Expand Down Expand Up @@ -137,7 +137,7 @@ test_that("ensemble feature selection works with rfe", {

test_that("EnsembleFSResult initialization", {
result = data.table(a = 1, b = 3)
expect_error(EnsembleFSResult$new(result = result, features = LETTERS, measure_var = "a"), "missing elements")
expect_error(EnsembleFSResult$new(result = result, features = LETTERS, measure_id = "a"), "missing elements")

result = data.table(
resampling_iteration = c(1, 1, 1, 2, 2, 2, 3, 3, 3),
Expand All @@ -157,9 +157,9 @@ test_that("EnsembleFSResult initialization", {
)

# works without benchmark result object
efsr = EnsembleFSResult$new(result = result, features = paste0("V", 1:20), measure_var = "classif.ce")
efsr = EnsembleFSResult$new(result = result, features = paste0("V", 1:20), measure_id = "classif.ce")
expect_class(efsr, "EnsembleFSResult")
expect_equal(efsr$nlearners, 3)
expect_equal(efsr$n_learners, 3)
tab = as.data.table(efsr)
expect_data_table(tab)
expect_names(names(tab), identical.to = c("resampling_iteration", "learner_id",
Expand Down

0 comments on commit 727f3a8

Please sign in to comment.