Skip to content

Commit

Permalink
NA handling for PipeOpEncodeQuantiles
Browse files Browse the repository at this point in the history
  • Loading branch information
advieser committed Jan 23, 2025
1 parent c774e38 commit 31ee558
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions R/PipeOpEncodePL.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#' in `private$.get_bins()`, and then creating new feature columns through a continuous alternative to one-hot encoding.
#' Here, one new feature per bin is constructed, with values being either
#' * `0`, if the original value was below the lower bin boundary,
#' * `1`, if the original value was above the upper bin boundary, or
#' * `1`, if the original value was above or equal to the upper bin boundary, or
#' * a scaled value between `0` and `1`, if the original value was inside the bin boundaries. Scaling is done by
#' offsetting the original value by the lower bin boundary and dividing by the bin width.
#'
Expand Down Expand Up @@ -46,7 +46,7 @@
#' The `$state` is a named `list` with the `$state` elements inherited from [`PipeOpTaskPreprocSimple`], as well as:
#' * `bins` :: named `list`\cr
#' Named list of numeric vectors. Each element corresponds to one of the affected feature columns and contains the
#' bin boundaries derived through the private method `.get_bins()`. The element vectors are named by the respective
#' bin boundaries derived through `private$.get_bins()`. The element vectors are named by the respective
#' feature column.
#'
#' @section Parameters:
Expand Down Expand Up @@ -104,6 +104,10 @@ PipeOpEncodePL = R6Class("PipeOpEncodePL",
return(task) # early exit
}

# if (private$.can_handle_nas && any(task$missings())) {
# stopf("%s does not support handling of tasks with missing data.", class(self))
# }

dt = task$data(cols = cols)
res = imap_dtc(dt, function(d, col) encode_piecewise_linear(d, bins[[col]]))

Expand All @@ -124,10 +128,13 @@ encode_piecewise_linear = function(column, bins) {
for (t in seq_len(n_bins)) {
lower = bins[[t]]
upper = bins[[t + 1]]
colname = colnames(dt)[[t]]

dt[column >= upper, colnames(dt)[[t]] := 1]
indices = column < upper & column >= lower
dt[indices, colnames(dt)[[t]] := (column[indices] - lower) / (upper - lower)]
dt[column >= upper, (colname) := 1]
indices = !is.na(column) & column < upper & column >= lower
dt[indices, (colname) := (column[indices] - lower) / (upper - lower)]
# Filling NAs back in
dt[is.na(column), (colname) := NA]
}

dt
Expand All @@ -148,6 +155,8 @@ encode_piecewise_linear = function(column, bins) {
#' AND HOW THE NUMBER IS TO BE INTERPRETEED
#' (i.e. numsplits=2 creates 3 values, meaning five bins??!)
#'
#' document that nas are removed for calculating quantiles
#'
#' @section Construction:
#' ```
#' PipeOpEncodePL$new(id = "encodeplquantiles", param_vals = list())
Expand All @@ -168,11 +177,13 @@ encode_piecewise_linear = function(column, bins) {
#'
#' @section Parameters:
#' The parameters are the parameters inherited from [`PipeOpEncodePL`]/[`PipeOpTaskPreprocSimple`], as well as:
#' * `numsplits` :: `numeric(1)` \cr
#' Default is ``.
#' * `numsplits` :: `integer(1)` \cr
#' Number of bins to create. Default is `2`.
#' * `type` :: `integer(1)`\cr
#' Method used to calculate sample quantiles. See help of [`stats::quantile`]. Default is `7`.
#'
#' @section Internals:
#'
#' Uses the [`stats::quantile`] function.
#'
#' @section Fields:
#' Only fields inherited from [`PipeOpEncodePL`]/[`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`].
Expand All @@ -196,7 +207,8 @@ PipeOpEncodePLQuantiles = R6Class("PipeOpEncodePLQuantiles",
public = list(
initialize = function(id = "encodeplquantiles", param_vals = list()) {
ps = ps(
numsplits = p_int(lower = 2, default = 2, tags = c("train", "predict"))
numsplits = p_int(lower = 2, default = 2, tags = c("train", "predict")),
type = p_int(lower = 1, upper = 9, default = 7, tags = c("train", "predict"))
)
super$initialize(id, param_set = ps, param_vals = param_vals, packages = "stats")
}
Expand All @@ -205,8 +217,15 @@ PipeOpEncodePLQuantiles = R6Class("PipeOpEncodePLQuantiles",

.get_bins = function(task, cols) {
numsplits = self$param_set$values$numsplits %??% 2
# Defaulting to default value in stats::quantile, i.e. method 7
type = self$param_set$values$type %??% 7

lapply(task$data(cols = cols), function(d) {
unique(c(min(d), stats::quantile(d, seq(1, numsplits - 1) / numsplits, na.rm = TRUE), max(d)))
unique(c(
min(d, na.rm = TRUE),
stats::quantile(d, seq(1, numsplits - 1) / numsplits, na.rm = TRUE, names = FALSE, type = type),
max(d, na.rm = TRUE)
))
})
}
)
Expand Down

0 comments on commit 31ee558

Please sign in to comment.