|
#' @title PipeOpPredRegrSurvPEM
#' @name mlr_pipeops_trafopred_regrsurv_pem
#'
#' @description
#' Transform [PredictionRegr] to [PredictionSurv].
#' The predicted piece-wise constant hazards contained in [PredictionRegr] are
#' transformed into survival probabilities and wrapped in a [PredictionSurv] object.
#'
#' We compute the survival probability from the predicted hazards using the following relation:
#' \deqn{S(t | \mathbf{x}) = \exp \left( - \int_{0}^{t} \lambda(s | \mathbf{x}) \, ds \right) = \exp \left( - \sum_{j = 1}^{J} \lambda(j | \mathbf{x}) d_j\, \right),}
#' where \eqn{j = 1, \ldots, J} denotes the interval, \eqn{t} the time, and
#' \eqn{d_j} the duration of interval \eqn{j}.
#' The survival probability is evaluated cumulatively, once per interval of each observation.
#'
#' For a more detailed description of PEM, refer to [pipeline_survtoregr_pem] or the referred article.
#'
#' @section Dictionary:
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops]
#' or with the associated sugar function [mlr3pipelines::po()]:
#' ```
#' PipeOpPredRegrSurvPEM$new()
#' mlr_pipeops$get("trafopred_regrsurv_pem")
#' po("trafopred_regrsurv_pem")
#' ```
#'
#' @section Input and Output Channels:
#' The input consists of a [PredictionRegr] and a [data.table][data.table::data.table]
#' containing the transformed data. The [PredictionRegr] is provided by the [mlr3::LearnerRegr],
#' while the [data.table] is generated by [PipeOpTaskSurvRegrPEM].
#' The transformed data must contain the columns `id`, `offset`, `tend`, `obs_times`
#' and `pem_status`, with one row per prediction and the same number of rows for every `id`.
#' The output is the input [PredictionRegr] transformed to a [PredictionSurv].
#' Only works during prediction phase.
#'
#' @references
#' `r format_bib("bender_2018")`
#'
#' @seealso [pipeline_survtoregr_pem]
#' @family PipeOps
#' @family Transformation PipeOps
#' @export
PipeOpPredRegrSurvPEM = R6Class(
  "PipeOpPredRegrSurvPEM",
  inherit = mlr3pipelines::PipeOp,

  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    #' @param id (character(1))\cr
    #' Identifier of the resulting object.
    initialize = function(id = "trafopred_regrsurv_pem") {
      super$initialize(
        id = id,
        input = data.table(
          name = c("input", "transformed_data"),
          train = c("NULL", "data.table"),
          predict = c("PredictionRegr", "data.table")
        ),
        output = data.table(
          name = "output",
          train = "NULL",
          predict = "PredictionSurv"
        )
      )
    }
  ),

  active = list(
    #' @field predict_type (`character(1)`)\cr
    #' Returns the active predict type of this PipeOp, which is `"crank"`.
    predict_type = function(rhs) {
      assert_ro_binding(rhs)
      "crank"
    }
  ),

  private = list(
    # Convert the predicted hazards into one survival curve per `id` and wrap
    # them, together with the observed outcome, in a `PredictionSurv`.
    #
    # `input[[1]]` is the regression prediction (its `$response` holds the hazards),
    # `input[[2]]` is the transformed data with one row per (id, interval).
    # Returns a list of length one containing the `PredictionSurv`.
    .predict = function(input) {
      pred = input[[1]] # predicted hazards provided by the regression learner
      data = input[[2]] # transformed data
      assert_true(!is.null(pred$response))

      # Fail early and explicitly instead of producing recycled columns or
      # cryptic errors further down.
      required_cols = c("id", "offset", "tend", "obs_times", "pem_status")
      missing_cols = setdiff(required_cols, names(data))
      if (length(missing_cols)) {
        stop(sprintf("Transformed data is missing column(s): %s",
          paste(missing_cols, collapse = ", ")))
      }
      if (nrow(data) == 0L) {
        stop("Transformed data has no rows, cannot compute survival probabilities")
      }
      if (length(pred$response) != nrow(data)) {
        stop(sprintf("Number of predicted hazards (%i) does not match the number of rows of the transformed data (%i)",
          length(pred$response), nrow(data)))
      }

      data = cbind(data, dt_hazard = pred$response)

      # From theory, convert hazards to surv as exp(-cumsum(h(t) * exp(offset))).
      # Everything that is "per id" is derived in ONE grouped pass, so the
      # survival curve, observed time and status of an id always come from the
      # same rows and stay aligned with `row_ids`, even if the rows of an id
      # are not stored contiguously.
      # NOTE(review): within each id the rows are assumed to be ordered by
      # interval (ascending `tend`), as before -- confirm against PipeOpTaskSurvRegrPEM.
      id = pem_status = obs_times = dt_hazard = NULL # to fix note
      per_id = data[, list(
        surv = list(exp(-cumsum(dt_hazard * exp(offset)))),
        obs_time = obs_times[.N], # last row of every id => observed time
        status = pem_status[.N]   # last row of every id => observed status
      ), by = id]

      # every id must contribute the same number of intervals, otherwise the
      # curves cannot be stacked into a matrix
      n_rows = lengths(per_id$surv)
      rows_per_id = n_rows[1L]
      if (any(n_rows != rows_per_id)) {
        stop("All ids in the transformed data must have the same number of rows")
      }

      # one row per id, one column per interval; built with matrix() so the
      # shape is also correct when there is a single interval per id
      surv = matrix(unlist(per_id$surv, use.names = FALSE),
        nrow = nrow(per_id), ncol = rows_per_id, byrow = TRUE)

      unique_end_times = sort(unique(data$tend))
      # coerce to distribution and crank
      pred_list = .surv_return(times = unique_end_times, surv = surv)

      # create prediction object
      p = PredictionSurv$new(
        row_ids = per_id$id,
        crank = pred_list$crank, distr = pred_list$distr,
        truth = Surv(per_id$obs_time, as.integer(as.character(per_id$status))))

      list(p)
    },

    # Nothing is learned during training: the state is set to an empty list
    # and the input is handed back wrapped in a list.
    .train = function(input) {
      self$state = list()
      list(input)
    }
  )
)
# make the PipeOp available under the key "trafopred_regrsurv_pem"
register_pipeop("trafopred_regrsurv_pem", PipeOpPredRegrSurvPEM)
0 commit comments