From 32963b4fbe07efcafb661199f3da72fc9303c923 Mon Sep 17 00:00:00 2001 From: melodiemonod Date: Wed, 24 Jul 2024 15:35:42 +0100 Subject: [PATCH] ready for submission --- paper/paper.md | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/paper/paper.md b/paper/paper.md index 3c6fd15..fc0c339 100644 --- a/paper/paper.md +++ b/paper/paper.md @@ -35,12 +35,12 @@ bibliography: paper.bib # Summary `TorchSurv` ([GitHub](https://github.com/Novartis/torchsurv) and [PyPI](https://pypi.org/project/torchsurv/)) is a Python package that serves as a companion tool to perform deep survival modeling within the `PyTorch` environment [@paszke2019pytorch]. With its lightweight design, minimal input requirements, full `PyTorch` backend, and freedom from restrictive parameterizations, `TorchSurv` facilitates efficient deep survival model implementation and is particularly beneficial for high-dimensional and complex data scenarios. -`TorchSurv` has been rigorously tested using both open-source and synthetically generated survival data. The package is thoroughly documented and includes illustrative examples. The latest documentation for TorchSurv can be found on the[`TorchSurv`'s website](https://opensource.nibr.com/torchsurv/). +`TorchSurv` has been rigorously tested using both open-source and synthetically generated survival data. The package is thoroughly documented and includes illustrative examples. The latest documentation for TorchSurv can be found on the [`TorchSurv`'s website](https://opensource.nibr.com/torchsurv/). `TorchSurv` provides a user-friendly workflow for training and evaluating `PyTorch`-based deep survival models. At its core, `TorchSurv` features `PyTorch`-based calculations of log-likelihoods for prominent survival models, including the Cox proportional hazards model [@Cox1972] and the Weibull Accelerated Time Failure (AFT) model [@Carroll2003]. -In survival analysis, each observation is associated with survival reponse, denoted by $y$ (comprising the event indicator and the time-to-event or censoring), and covariates denoted by $x$. A survival model is parametrized by parameters, denoted by $\theta$. Within the `TorchSurv` framework, a `PyTorch`-based neural network is defined to act as a flexible function that takes the covariates $x$ as input and outputs the parameters $\theta$. Estimation of the parameters $\theta$ is achieved via maximum likelihood estimation facilitated by backpropagation. -Additionally, `TorchSurv` offers evaluation metrics, including the time-dependent Area Under the cure (AUC) under the Receiver operating characteristic (ROC) curve, the Concordance index (C-index) and the Brier Score, to characterize the predictive performance of survival models. +In survival analysis, each observation is associated with a survival reponse, denoted by $y$ (comprising the event indicator and the time-to-event or censoring), and covariates, denoted by $x$. A survival model is parametrized by parameters $\theta$. Within the `TorchSurv` framework, a `PyTorch`-based neural network is defined to act as a flexible function that takes the covariates $x$ as input and outputs the parameters $\theta$. Estimation of the parameters $\theta$ is achieved via maximum likelihood estimation. +Additionally, `TorchSurv` offers evaluation metrics, including the time-dependent Area Under under the Receiver operating characteristic (ROC) curve (AUC), the Concordance index (C-index) and the Brier Score, to characterize the predictive performance of survival models. Below is an overview of the workflow for model inference and evaluation with `TorchSurv`: 1. Initialize a `PyTorch`-based neural network that defines the function from the covariates $x$ to the parameters $\theta$. In the context of the Cox proportional hazards model for example, the parameters are the log relative hazards. @@ -56,7 +56,7 @@ Below is an overview of the workflow for model inference and evaluation with `To # Statement of need -Survival analysis plays a crucial role in various domains, such as medicine and engineering. Deep learning presents promising opportunities for developing sophisticated survival models, where the parameters depend on covariates through complex functions. However, no existing library provides the flexibility to define survival model parameters using a custom `PyTorch`-based neural network. +Survival analysis plays a crucial role in various domains, such as medicine and engineering. Deep learning presents promising opportunities for developing sophisticated survival models, where the parameters depend on covariates through complex functions. However, no existing library provides the flexibility to define the survival model's parameters using a custom `PyTorch`-based neural network. \autoref{tab:bibliography} compares the functionalities of `TorchSurv` with those of `auton-survival` [@nagpal2022auton], @@ -72,9 +72,9 @@ Our package, `TorchSurv`, is specifically designed for use in Python, but we als `TorchSurv`'s log-likelihood and evaluation metrics functions have undergone thorough comparison with benchmarks generated with Python packages and R packages on open-source data and synthetic data. High agreement between the outputs is consistently observed, providing users with confidence in the accuracy and reliability of `TorchSurv`'s functionalities. The comparison is presented in the [`TorchSurv`'s website](https://opensource.nibr.com/torchsurv/benchmarks.html). -![**Survival analysis libraries in Python.** $^1$[@nagpal2022auton], $^{2}$[@Kvamme2019pycox], $^{3}$[@torchlifeAbeywardana], $^{4}$[@polsterl2020scikit], $^{5}$[@davidson2019lifelines], $^{6}$[@katzman2018deepsurv]. A green tick indicates a fully supported feature, a red cross indicates an unsupported feature, a blue crossed tick indicates a partially supported feature. For computing the concordance index, `pycox` requires the use of the estimated survival function as the risk score and does not support other types of time-dependent risk scores. `scikit-survival` does not support time-dependent risk scores in both the concordance index and AUC computation. Additionally, both `pycox` and `scikit-survival `impose the use of inverse probability of censoring weighting (IPCW) for subject-specific weights. `scikit-survival` only offers the Breslow approximation of the Cox partial log-likelihood in case of ties in the event time, while it lacks the Efron approximation.\label{tab:bibliography}](table_1.png) +![**Survival analysis libraries in Python.** $^1$[@nagpal2022auton], $^{2}$[@Kvamme2019pycox], $^{3}$[@torchlifeAbeywardana], $^{4}$[@polsterl2020scikit], $^{5}$[@davidson2019lifelines], $^{6}$[@katzman2018deepsurv]. A green tick indicates a fully supported feature, a red cross indicates an unsupported feature, a blue crossed tick indicates a partially supported feature. For computing the concordance index, `pycox` requires the use of the estimated survival function as the risk score and does not support other types of time-dependent risk scores. `scikit-survival` does not support time-dependent risk scores in both the concordance index and AUC computation. Additionally, both `pycox` and `scikit-survival` impose the use of inverse probability of censoring weighting (IPCW) for subject-specific weights. `scikit-survival` only offers the Breslow approximation of the Cox partial log-likelihood in case of ties in the event time.\label{tab:bibliography}](table_1.png) -![**Survival analysis libraries in R.** $^1$[@survivalpackage], $^{2}$[@survAUCpackage], $^{3}$[@timeROCpackage], $^{4}$[@risksetROCpackage], $^{5}$[@survcomppackage], $^{6}$[@survivalROCpackage], $^{7}$[@riskRegressionpackage], $^{8}$[@SurvMetricspackage], $^{9}$[@pecpackage]. A green tick indicates a fully supported feature, a red cross indicates an unsupported feature, a blue crossed tick indicates a partially supported feature. For obtaining the evaluation metrics, packages `survival`, `riskRegression`, `SurvMetrics` and `pec` require the fitted model object as input (a specific object format) and `RisksetROC` imposes a smoothing method. Packages `timeROC`, `riskRegression` and `pec` force the user to choose a form for subject-specific weights (e.g., inverse probability of censoring weighting (IPCW)). Packages `survcomp` and `SurvivalROC` do not implement the general AUC but the censoring-adjusted AUC estimator proposed by @Heagerty2000.\label{tab:bibliography_R}](table_2.png) +![**Survival analysis libraries in R.** $^1$[@survivalpackage], $^{2}$[@survAUCpackage], $^{3}$[@timeROCpackage], $^{4}$[@risksetROCpackage], $^{5}$[@survcomppackage], $^{6}$[@survivalROCpackage], $^{7}$[@riskRegressionpackage], $^{8}$[@SurvMetricspackage], $^{9}$[@pecpackage]. A green tick indicates a fully supported feature, a red cross indicates an unsupported feature, a blue crossed tick indicates a partially supported feature. For obtaining the evaluation metrics, packages `survival`, `riskRegression`, `SurvMetrics` and `pec` require the fitted model object as input (a specific object format) and `RisksetROC` imposes to use a smoothing method. Packages `timeROC`, `riskRegression` and `pec` force the user to choose a form for subject-specific weights (e.g., inverse probability of censoring weighting (IPCW)). Packages `survcomp` and `SurvivalROC` do not implement the general AUC but the censoring-adjusted AUC estimator proposed by @Heagerty2000.\label{tab:bibliography_R}](table_2.png) # Functionality @@ -85,7 +85,9 @@ Our package, `TorchSurv`, is specifically designed for use in Python, but we als ```python from torchsurv.loss import cox -my_model = MyPyTorchCoxModel() # PyTorch model outputs one (1) log hazards for Cox model + +# PyTorch model outputs one log hazard by observation +my_model = MyPyTorchCoxModel() for data in dataloader: x, event, time = data # covariate, event indicator, time @@ -98,7 +100,9 @@ for data in dataloader: ```python from torchsurv.loss import weibull -my_model = MyPyTorcWeibullhModel() # PyTorch model outputs two (2) log parameters for Weibull model + +# PyTorch model outputs two Weibull parameters by observation +my_model = MyPyTorcWeibullhModel() for data in dataloader: x, event, time = data # covariate, event indicator, time @@ -113,7 +117,7 @@ snippet below. ```python from torchsurv.loss import Momentum -my_model = MyPyTorchXCoxModel() # PyTorch model outputs one (1) log hazards for Cox model +my_model = MyPyTorchXCoxModel() my_loss = cox.neg_partial_log_likelihood # Works with any TorchSurv loss momentum = Momentum(backbone=my_model, loss=my_loss) @@ -128,9 +132,9 @@ log_hzs = model_momentum.infer(x) # torch.Size([16, 1]) ## Evaluation Metrics Functions -The `TorchSurv` package offers a comprehensive set of metrics to evaluate the predictive performance of survival models, including the AUC, C-index, and Brier score. The inputs of the evaluation metrics functions are the individual risk score estimated on the test set and the survival response on the test set. The risk score measures the risk (or a proxy thereof) that a subject has an event. We provide definitions for each metric and demonstrate their use through illustrative code snippets. +The `TorchSurv` package offers a comprehensive set of metrics to evaluate the predictive performance of survival models, including the AUC, C-index, and Brier score. The inputs of the evaluation metrics functions are the subject-specific risk score estimated on the test set and the survival response on the test set. The risk score measures the risk (or a proxy thereof) that a subject has an event. We provide definitions for each metric and demonstrate their use through illustrative code snippets. -**AUC.** The AUC measures the discriminatory capacity of a model at a given time $t$, i.e., the model’s ability to provide a reliable ranking of times-to-event based on estimated individual risk scores [@Heagerty2005;@Uno2007;@Blanche2013]. +**AUC.** The AUC measures the discriminatory capacity of the survival model at a given time $t$, i.e., the ability to provide a reliable ranking of times-to-event based on estimated subject-specific risk scores [@Heagerty2005;@Uno2007;@Blanche2013]. ```python from torchsurv.metrics import Auc @@ -139,7 +143,7 @@ auc(log_hzs, event, time) # AUC at each time auc(log_hzs, event, time, new_time=torch.tensor(10.)) # AUC at time 10 ``` -**C-index.** The C-index is a generalization of the AUC that represents the assessment of the discriminatory capacity of the model across the time period [@Harrell1996;@Uno_2011]. +**C-index.** The C-index is a generalization of the AUC that represents the assessment of the discriminatory capacity of the survival model across the entire time period [@Harrell1996;@Uno_2011]. ```python from torchsurv.metrics import ConcordanceIndex @@ -147,7 +151,7 @@ cindex = ConcordanceIndex() cindex(log_hzs, event, time) # C-index ``` -**Brier Score.** The Brier score evaluates the accuracy of a model at a given time $t$. It represents the average squared distance between the observed survival status and the predicted survival probability [@Graf_1999]. The Brier score cannot be obtained for the Cox proportional hazards model because the survival function is not available, but it can be obtained for the Weibull ATF model. +**Brier Score.** The Brier score evaluates the accuracy of a model at a given time $t$ [@Graf_1999]. It represents the average squared distance between the observed survival status and the predicted survival probability. The Brier score cannot be obtained for the Cox proportional hazards model because the survival function is not available, but it can be obtained for the Weibull ATF model. ```python from torchsurv.metrics import Brier