Skip to content

Commit b80b651

Browse files
authored
Merge pull request #417 from mlr-org/pipeline_PEM
PEM Pipeline
2 parents a7bbc2c + 8d94b8b commit b80b651

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1319
-72
lines changed

DESCRIPTION

+10-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: mlr3proba
22
Title: Probabilistic Supervised Learning for 'mlr3'
3-
Version: 0.7.4
3+
Version: 0.7.5
44
Authors@R: c(
55
person("Raphael", "Sonabend", , "raphaelsonabend@gmail.com", role = "aut",
66
comment = c(ORCID = "0000-0001-9225-4654")),
@@ -18,7 +18,8 @@ Authors@R: c(
1818
person("Maximilian", "Muecke", , "muecke.maximilian@gmail.com", role = "ctb",
1919
comment = c(ORCID = "0009-0000-9432-9795")),
2020
person("Lee Xingzhuo", "Li", , "xingzhuo_li@yahoo.com.au", role = "ctb",
21-
comment = c(ORCID = "0000-0001-5259-5198"))
21+
comment = c(ORCID = "0000-0001-5259-5198")),
22+
person("Markus", "Goeswein", , "markus.goeswein@outlook.de", role = "ctb")
2223
)
2324
Description: Provides extensions for probabilistic supervised learning for
2425
'mlr3'. This includes extending the regression task to probabilistic
@@ -28,7 +29,7 @@ License: LGPL-3
2829
URL: https://mlr3proba.mlr-org.com, https://github.com/mlr-org/mlr3proba
2930
BugReports: https://github.com/mlr-org/mlr3proba/issues
3031
Depends:
31-
mlr3 (>= 0.14.1),
32+
mlr3 (>= 0.23.0),
3233
R (>= 3.5.0)
3334
Imports:
3435
checkmate,
@@ -48,7 +49,7 @@ Suggests:
4849
knitr,
4950
lgr,
5051
lifecycle,
51-
mlr3learners,
52+
mlr3learners (>= 0.10.0),
5253
mlr3viz,
5354
pammtools,
5455
param6 (>= 0.2.4),
@@ -58,13 +59,14 @@ Suggests:
5859
set6 (>= 0.2.6),
5960
simsurv,
6061
survAUC,
61-
testthat (>= 3.0.0)
62+
testthat (>= 3.0.0),
63+
glmnet
6264
LinkingTo:
6365
Rcpp
6466
Remotes:
6567
xoopR/distr6,
6668
xoopR/param6,
67-
xoopR/set6
69+
xoopR/set6,
6870
ByteCompile: true
6971
Config/testthat/edition: 3
7072
Encoding: UTF-8
@@ -115,11 +117,13 @@ Collate:
115117
'PipeOpDistrCompositor.R'
116118
'PipeOpPredClassifSurvDiscTime.R'
117119
'PipeOpPredClassifSurvIPCW.R'
120+
'PipeOpPredRegrSurvPEM.R'
118121
'PipeOpProbregrCompositor.R'
119122
'PipeOpResponseCompositor.R'
120123
'PipeOpSurvAvg.R'
121124
'PipeOpTaskSurvClassifDiscTime.R'
122125
'PipeOpTaskSurvClassifIPCW.R'
126+
'PipeOpTaskSurvRegrPEM.R'
123127
'PredictionDataDens.R'
124128
'PredictionDataSurv.R'
125129
'PredictionDens.R'

NAMESPACE

+3
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,13 @@ export(PipeOpCrankCompositor)
7272
export(PipeOpDistrCompositor)
7373
export(PipeOpPredClassifSurvDiscTime)
7474
export(PipeOpPredClassifSurvIPCW)
75+
export(PipeOpPredRegrSurvPEM)
7576
export(PipeOpProbregr)
7677
export(PipeOpResponseCompositor)
7778
export(PipeOpSurvAvg)
7879
export(PipeOpTaskSurvClassifDiscTime)
7980
export(PipeOpTaskSurvClassifIPCW)
81+
export(PipeOpTaskSurvRegrPEM)
8082
export(PredictionDens)
8183
export(PredictionSurv)
8284
export(TaskDens)
@@ -95,6 +97,7 @@ export(get_mortality)
9597
export(pecs)
9698
export(pipeline_survtoclassif_IPCW)
9799
export(pipeline_survtoclassif_disctime)
100+
export(pipeline_survtoregr_pem)
98101
export(plot_probregr)
99102
import(checkmate)
100103
import(data.table)

NEWS.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
# mlr3proba dev
1+
# mlr3proba 0.7.5
22

33
* fix: allow cloning of measures objects
4+
* New `PipeOp`s: `PipeOpTaskSurvRegrPEM`, `PipeOpPredRegrPEM`
5+
* New pipeline (**reduction method**): `pipeline_survtoregr_pem`
46

57
# mlr3proba 0.7.4
68

R/PipeOpPredRegrSurvPEM.R

+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
#' @title PipeOpPredRegrSurvPEM
2+
#' @name mlr_pipeops_trafopred_regrsurv_pem
3+
#'
4+
#' @description
5+
#' Transform [PredictionRegr] to [PredictionSurv].
6+
#' The predicted piece-wise constant hazards contained in [PredictionRegr] are transformed into survival probabilities and wrapped in a
7+
#' [PredictionSurv] object.
8+
#'
9+
#' We compute the survival probability from the predicted hazards using the following relation:
10+
#' \deqn{S(t | \mathbf{x}) = \exp \left( - \int_{0}^{t} \lambda(s | \mathbf{x}) \, ds \right) = \exp \left( - \sum_{j = 1}^{J} \lambda(j | \mathbf{x}) d_j\, \right),}
11+
#' where \eqn{j = 1, \ldots, J} denotes the interval, \eqn{t} the time, and \eqn{d_j} the duration of interval \eqn{j}.
12+
#'
13+
#' For a more detailed description of PEM, refer to [pipeline_survtoregr_pem] or the referred article.
14+
#'
15+
#' @section Dictionary:
16+
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
17+
#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops]
18+
#' or with the associated sugar function [mlr3pipelines::po()]:
19+
#' ```
20+
#' PipeOpPredRegrSurvPEM$new()
21+
#' mlr_pipeops$get("trafopred_regrsurv_pem")
22+
#' po("trafopred_regrsurv_pem")
23+
#' ```
24+
#'
25+
#' @section Input and Output Channels:
26+
#' The input consists of a [PredictionRegr] and a [data.table][data.table::data.table]
27+
#' containing the transformed data. The [PredictionRegr] is provided by the [mlr3::LearnerRegr],
28+
#' while the [data.table] is generated by [PipeOpTaskSurvRegrPEM].
29+
#' The output is the input [PredictionRegr] transformed to a [PredictionSurv].
30+
#' Only works during prediction phase.
31+
#'
32+
#' @references
33+
#' `r format_bib("bender_2018")`
34+
#'
35+
#' @seealso [pipeline_survtoregr_pem]
36+
#' @family PipeOps
37+
#' @family Transformation PipeOps
38+
#' @export
39+
PipeOpPredRegrSurvPEM = R6Class(
40+
"PipeOpPredRegrSurvPEM",
41+
inherit = mlr3pipelines::PipeOp,
42+
43+
public = list(
44+
#' @description
45+
#' Creates a new instance of this [R6][R6::R6Class] class.
46+
#' @param id (character(1))\cr
47+
#' Identifier of the resulting object.
48+
initialize = function(id = "trafopred_regrsurv_pem") {
49+
super$initialize(
50+
id = id,
51+
input = data.table(
52+
name = c("input", "transformed_data"),
53+
train = c("NULL", "data.table"),
54+
predict = c("PredictionRegr", "data.table")
55+
),
56+
output = data.table(
57+
name = "output",
58+
train = "NULL",
59+
predict = "PredictionSurv"
60+
)
61+
)
62+
}
63+
),
64+
65+
active = list(
66+
#' @field predict_type (`character(1)`)\cr
67+
#' Returns the active predict type of this PipeOp, which is `"crank"`
68+
predict_type = function(rhs) {
69+
assert_ro_binding(rhs)
70+
"crank"
71+
}
72+
),
73+
74+
private = list(
75+
.predict = function(input) {
76+
pred = input[[1]] # predicted hazards provided by the regression learner
77+
data = input[[2]] # transformed data
78+
assert_true(!is.null(pred$response))
79+
80+
data = cbind(data, dt_hazard = pred$response)
81+
82+
# From theory, convert hazards to surv as exp(-cumsum(h(t) * exp(offset)))
83+
rows_per_id = nrow(data) / length(unique(data$id))
84+
85+
surv = t(vapply(unique(data$id), function(unique_id) {
86+
exp(-cumsum(data[data$id == unique_id, ][["dt_hazard"]] * exp(data[data$id == unique_id, ][["offset"]])))
87+
}, numeric(rows_per_id)))
88+
89+
unique_end_times = sort(unique(data$tend))
90+
# coerce to distribution and crank
91+
pred_list = .surv_return(times = unique_end_times, surv = surv)
92+
93+
# select the real tend values by only selecting the last row of each id
94+
# basically a slightly more complex unique()
95+
real_tend = data$obs_times[seq_len(nrow(data)) %% rows_per_id == 0]
96+
97+
ids = unique(data$id)
98+
# select last row for every id => observed times
99+
id = pem_status = NULL # to fix note
100+
data = data[, .SD[.N, list(pem_status)], by = id]
101+
102+
# create prediction object
103+
p = PredictionSurv$new(
104+
row_ids = ids,
105+
crank = pred_list$crank, distr = pred_list$distr,
106+
truth = Surv(real_tend, as.integer(as.character(data$pem_status))))
107+
108+
list(p)
109+
},
110+
111+
.train = function(input) {
112+
self$state = list()
113+
list(input)
114+
}
115+
)
116+
)
117+
register_pipeop("trafopred_regrsurv_pem", PipeOpPredRegrSurvPEM)

R/PipeOpTaskSurvClassifDiscTime.R

+6-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
#' [TaskClassif][mlr3::TaskClassif].
2929
#' The target column is named `"disc_status"` and indicates whether an event occurred
3030
#' in each time interval.
31-
#' An additional feature named `"tend"` contains the end time point of each interval.
31+
#' An additional numeric feature named `"tend"` contains the end time point of each interval.
3232
#' Lastly, the "output" task has a column with the original observation ids,
3333
#' under the role `"original_ids"`.
3434
#' The "transformed_data" is an empty [data.table][data.table::data.table].
@@ -134,12 +134,13 @@ PipeOpTaskSurvClassifDiscTime = R6Class("PipeOpTaskSurvClassifDiscTime",
134134

135135
if (!is.null(max_time)) {
136136
assert(max_time > data[get(event_var) == 1, min(get(time_var))],
137-
"max_time must be greater than the minimum event time.")
137+
.var.name = "max_time must be greater than the minimum event time.")
138138
}
139139

140140
form = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".")
141141

142-
long_data = pammtools::as_ped(data = data, formula = form, cut = cut, max_time = max_time)
142+
long_data = pammtools::as_ped(data = data, formula = form,
143+
cut = cut, max_time = max_time)
143144
self$state$cut = attributes(long_data)$trafo_args$cut
144145
long_data = as.data.table(long_data)
145146
setnames(long_data, old = "ped_status", new = "disc_status")
@@ -172,6 +173,8 @@ PipeOpTaskSurvClassifDiscTime = R6Class("PipeOpTaskSurvClassifDiscTime",
172173

173174
max_time = max(cut)
174175
time = data[[time_var]]
176+
# setting time variable to max_time ensures that the ped data spans
177+
# over all intervals for every subject irrespective of event time
175178
data[[time_var]] = max_time
176179

177180
status = data[[event_var]]

0 commit comments

Comments
 (0)