Commit
ready for submission
melodiemonod committed Jul 24, 2024
1 parent 60b3f54 commit 32963b4
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions paper/paper.md
@@ -35,12 +35,12 @@ bibliography: paper.bib
# Summary

`TorchSurv` ([GitHub](https://github.com/Novartis/torchsurv) and [PyPI](https://pypi.org/project/torchsurv/)) is a Python package that serves as a companion tool to perform deep survival modeling within the `PyTorch` environment [@paszke2019pytorch]. With its lightweight design, minimal input requirements, full `PyTorch` backend, and freedom from restrictive parameterizations, `TorchSurv` facilitates efficient deep survival model implementation and is particularly beneficial for high-dimensional and complex data scenarios.
`TorchSurv` has been rigorously tested using both open-source and synthetically generated survival data. The package is thoroughly documented and includes illustrative examples. The latest documentation for TorchSurv can be found on the[`TorchSurv`'s website](https://opensource.nibr.com/torchsurv/).
`TorchSurv` has been rigorously tested using both open-source and synthetically generated survival data. The package is thoroughly documented and includes illustrative examples. The latest documentation for `TorchSurv` can be found on the [`TorchSurv` website](https://opensource.nibr.com/torchsurv/).

`TorchSurv` provides a user-friendly workflow for training and evaluating `PyTorch`-based deep survival models.
At its core, `TorchSurv` features `PyTorch`-based calculations of log-likelihoods for prominent survival models, including the Cox proportional hazards model [@Cox1972] and the Weibull Accelerated Failure Time (AFT) model [@Carroll2003].
In survival analysis, each observation is associated with survival reponse, denoted by $y$ (comprising the event indicator and the time-to-event or censoring), and covariates denoted by $x$. A survival model is parametrized by parameters, denoted by $\theta$. Within the `TorchSurv` framework, a `PyTorch`-based neural network is defined to act as a flexible function that takes the covariates $x$ as input and outputs the parameters $\theta$. Estimation of the parameters $\theta$ is achieved via maximum likelihood estimation facilitated by backpropagation.
Additionally, `TorchSurv` offers evaluation metrics, including the time-dependent Area Under the cure (AUC) under the Receiver operating characteristic (ROC) curve, the Concordance index (C-index) and the Brier Score, to characterize the predictive performance of survival models.
In survival analysis, each observation is associated with a survival response, denoted by $y$ (comprising the event indicator and the time-to-event or censoring), and covariates, denoted by $x$. A survival model is parametrized by parameters $\theta$. Within the `TorchSurv` framework, a `PyTorch`-based neural network is defined to act as a flexible function that takes the covariates $x$ as input and outputs the parameters $\theta$. Estimation of the parameters $\theta$ is achieved via maximum likelihood estimation.
Additionally, `TorchSurv` offers evaluation metrics, including the time-dependent Area Under the Receiver Operating Characteristic (ROC) Curve (AUC), the Concordance index (C-index) and the Brier Score, to characterize the predictive performance of survival models.
Below is an overview of the workflow for model inference and evaluation with `TorchSurv`:

1. Initialize a `PyTorch`-based neural network that defines the function from the covariates $x$ to the parameters $\theta$. In the context of the Cox proportional hazards model for example, the parameters are the log relative hazards.
@@ -56,7 +56,7 @@ Below is an overview of the workflow for model inference and evaluation with `To
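
To make the maximum likelihood step described above concrete, the following is a minimal plain-Python sketch of the negative Cox partial log-likelihood. It is not `TorchSurv`'s implementation: the function name, the restriction to untied event times, and the choice to sum rather than average over events are all assumptions made for illustration. The input `log_hz` plays the role of the log relative hazards output by the neural network.

```python
import math


def neg_partial_log_likelihood(log_hz, event, time):
    """Negative Cox partial log-likelihood (illustrative, assumes no tied event times)."""
    total = 0.0
    n = len(time)
    for i in range(n):
        if not event[i]:
            # Censored subjects contribute only through the risk sets of others.
            continue
        # Risk set: every subject still under observation at time[i].
        risk_set = [log_hz[j] for j in range(n) if time[j] >= time[i]]
        log_denominator = math.log(sum(math.exp(v) for v in risk_set))
        total += log_hz[i] - log_denominator
    return -total


# Hypothetical toy data: four subjects, one of them censored.
log_hz = [0.5, -0.2, 1.0, 0.0]
event = [True, False, True, True]
time = [2.0, 3.0, 1.0, 4.0]
loss = neg_partial_log_likelihood(log_hz, event, time)
```

In a deep survival model, the same quantity would be computed on tensors so that its gradient with respect to the network weights can be obtained automatically.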

# Statement of need

Survival analysis plays a crucial role in various domains, such as medicine and engineering. Deep learning presents promising opportunities for developing sophisticated survival models, where the parameters depend on covariates through complex functions. However, no existing library provides the flexibility to define survival model parameters using a custom `PyTorch`-based neural network.
Survival analysis plays a crucial role in various domains, such as medicine and engineering. Deep learning presents promising opportunities for developing sophisticated survival models, where the parameters depend on covariates through complex functions. However, no existing library provides the flexibility to define the survival model's parameters using a custom `PyTorch`-based neural network.

\autoref{tab:bibliography} compares the functionalities of `TorchSurv` with those of
`auton-survival` [@nagpal2022auton],
@@ -72,9 +72,9 @@ Our package, `TorchSurv`, is specifically designed for use in Python, but we als

`TorchSurv`'s log-likelihood and evaluation metrics functions have undergone thorough comparison with benchmarks generated with Python packages and R packages on open-source data and synthetic data. High agreement between the outputs is consistently observed, providing users with confidence in the accuracy and reliability of `TorchSurv`'s functionalities. The comparison is presented on the [`TorchSurv` website](https://opensource.nibr.com/torchsurv/benchmarks.html).

![**Survival analysis libraries in Python.** $^1$[@nagpal2022auton], $^{2}$[@Kvamme2019pycox], $^{3}$[@torchlifeAbeywardana], $^{4}$[@polsterl2020scikit], $^{5}$[@davidson2019lifelines], $^{6}$[@katzman2018deepsurv]. A green tick indicates a fully supported feature, a red cross indicates an unsupported feature, a blue crossed tick indicates a partially supported feature. For computing the concordance index, `pycox` requires the use of the estimated survival function as the risk score and does not support other types of time-dependent risk scores. `scikit-survival` does not support time-dependent risk scores in both the concordance index and AUC computation. Additionally, both `pycox` and `scikit-survival `impose the use of inverse probability of censoring weighting (IPCW) for subject-specific weights. `scikit-survival` only offers the Breslow approximation of the Cox partial log-likelihood in case of ties in the event time, while it lacks the Efron approximation.\label{tab:bibliography}](table_1.png)
![**Survival analysis libraries in Python.** $^1$[@nagpal2022auton], $^{2}$[@Kvamme2019pycox], $^{3}$[@torchlifeAbeywardana], $^{4}$[@polsterl2020scikit], $^{5}$[@davidson2019lifelines], $^{6}$[@katzman2018deepsurv]. A green tick indicates a fully supported feature, a red cross indicates an unsupported feature, a blue crossed tick indicates a partially supported feature. For computing the concordance index, `pycox` requires the use of the estimated survival function as the risk score and does not support other types of time-dependent risk scores. `scikit-survival` does not support time-dependent risk scores in both the concordance index and AUC computation. Additionally, both `pycox` and `scikit-survival` impose the use of inverse probability of censoring weighting (IPCW) for subject-specific weights. `scikit-survival` only offers the Breslow approximation of the Cox partial log-likelihood in case of ties in the event time.\label{tab:bibliography}](table_1.png)

![**Survival analysis libraries in R.** $^1$[@survivalpackage], $^{2}$[@survAUCpackage], $^{3}$[@timeROCpackage], $^{4}$[@risksetROCpackage], $^{5}$[@survcomppackage], $^{6}$[@survivalROCpackage], $^{7}$[@riskRegressionpackage], $^{8}$[@SurvMetricspackage], $^{9}$[@pecpackage]. A green tick indicates a fully supported feature, a red cross indicates an unsupported feature, a blue crossed tick indicates a partially supported feature. For obtaining the evaluation metrics, packages `survival`, `riskRegression`, `SurvMetrics` and `pec` require the fitted model object as input (a specific object format) and `RisksetROC` imposes a smoothing method. Packages `timeROC`, `riskRegression` and `pec` force the user to choose a form for subject-specific weights (e.g., inverse probability of censoring weighting (IPCW)). Packages `survcomp` and `SurvivalROC` do not implement the general AUC but the censoring-adjusted AUC estimator proposed by @Heagerty2000.\label{tab:bibliography_R}](table_2.png)
![**Survival analysis libraries in R.** $^1$[@survivalpackage], $^{2}$[@survAUCpackage], $^{3}$[@timeROCpackage], $^{4}$[@risksetROCpackage], $^{5}$[@survcomppackage], $^{6}$[@survivalROCpackage], $^{7}$[@riskRegressionpackage], $^{8}$[@SurvMetricspackage], $^{9}$[@pecpackage]. A green tick indicates a fully supported feature, a red cross indicates an unsupported feature, a blue crossed tick indicates a partially supported feature. For obtaining the evaluation metrics, packages `survival`, `riskRegression`, `SurvMetrics` and `pec` require the fitted model object as input (a specific object format) and `RisksetROC` imposes the use of a smoothing method. Packages `timeROC`, `riskRegression` and `pec` force the user to choose a form for subject-specific weights (e.g., inverse probability of censoring weighting (IPCW)). Packages `survcomp` and `SurvivalROC` do not implement the general AUC but the censoring-adjusted AUC estimator proposed by @Heagerty2000.\label{tab:bibliography_R}](table_2.png)


# Functionality
@@ -85,7 +85,9 @@ Our package, `TorchSurv`, is specifically designed for use in Python, but we als

```python
from torchsurv.loss import cox
my_model = MyPyTorchCoxModel() # PyTorch model outputs one (1) log hazards for Cox model

# PyTorch model outputs one log hazard per observation
my_model = MyPyTorchCoxModel()

for data in dataloader:
    x, event, time = data # covariate, event indicator, time
@@ -98,7 +100,9 @@ for data in dataloader:

```python
from torchsurv.loss import weibull
my_model = MyPyTorcWeibullhModel() # PyTorch model outputs two (2) log parameters for Weibull model

# PyTorch model outputs two Weibull parameters per observation
my_model = MyPyTorchWeibullModel()

for data in dataloader:
    x, event, time = data # covariate, event indicator, time
@@ -113,7 +117,7 @@ snippet below.

```python
from torchsurv.loss import cox, Momentum
my_model = MyPyTorchXCoxModel() # PyTorch model outputs one (1) log hazards for Cox model
my_model = MyPyTorchXCoxModel()
my_loss = cox.neg_partial_log_likelihood # Works with any TorchSurv loss
momentum = Momentum(backbone=my_model, loss=my_loss)

@@ -128,9 +132,9 @@ log_hzs = model_momentum.infer(x) # torch.Size([16, 1])

## Evaluation Metrics Functions

The `TorchSurv` package offers a comprehensive set of metrics to evaluate the predictive performance of survival models, including the AUC, C-index, and Brier score. The inputs of the evaluation metrics functions are the individual risk score estimated on the test set and the survival response on the test set. The risk score measures the risk (or a proxy thereof) that a subject has an event. We provide definitions for each metric and demonstrate their use through illustrative code snippets.
The `TorchSurv` package offers a comprehensive set of metrics to evaluate the predictive performance of survival models, including the AUC, C-index, and Brier score. The inputs of the evaluation metrics functions are the subject-specific risk score estimated on the test set and the survival response on the test set. The risk score measures the risk (or a proxy thereof) that a subject has an event. We provide definitions for each metric and demonstrate their use through illustrative code snippets.

**AUC.** The AUC measures the discriminatory capacity of a model at a given time $t$, i.e., the model’s ability to provide a reliable ranking of times-to-event based on estimated individual risk scores [@Heagerty2005;@Uno2007;@Blanche2013].
**AUC.** The AUC measures the discriminatory capacity of the survival model at a given time $t$, i.e., the ability to provide a reliable ranking of times-to-event based on estimated subject-specific risk scores [@Heagerty2005;@Uno2007;@Blanche2013].

```python
from torchsurv.metrics import Auc
@@ -139,15 +143,15 @@ auc(log_hzs, event, time) # AUC at each time
auc(log_hzs, event, time, new_time=torch.tensor(10.)) # AUC at time 10
```
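
As a rough illustration of what a time-dependent AUC measures, the sketch below computes a cumulative/dynamic AUC at a single time $t$ in plain Python. It is a simplified stand-in, not the estimator used by `TorchSurv`: the function name is hypothetical, and censoring is handled naively (subjects censored before $t$ are simply dropped, with no subject-specific weights such as IPCW).

```python
def auc_at_time(risk, event, time, t):
    """Cumulative/dynamic AUC at time t (illustrative, no censoring weights)."""
    # Cases: subjects who experienced the event at or before t.
    cases = [r for r, e, s in zip(risk, event, time) if e and s <= t]
    # Controls: subjects still event-free after t.
    controls = [r for r, s in zip(risk, time) if s > t]
    if not cases or not controls:
        raise ValueError("AUC at t needs at least one case and one control")
    # Count case/control pairs where the case has the higher risk score;
    # ties in the risk score count for one half.
    score = 0.0
    for case_risk in cases:
        for control_risk in controls:
            if case_risk > control_risk:
                score += 1.0
            elif case_risk == control_risk:
                score += 0.5
    return score / (len(cases) * len(controls))


# Hypothetical toy data: higher risk scores go with earlier events.
risk = [2.0, 1.0, 0.5, 0.1]
event = [True, True, False, True]
time = [1.0, 2.0, 3.0, 4.0]
auc_value = auc_at_time(risk, event, time, t=2.5)
```

A value of 1 means every case is ranked above every control at time $t$, while 0.5 corresponds to a ranking no better than chance.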

**C-index.** The C-index is a generalization of the AUC that represents the assessment of the discriminatory capacity of the model across the time period [@Harrell1996;@Uno_2011].
**C-index.** The C-index is a generalization of the AUC that represents the assessment of the discriminatory capacity of the survival model across the entire time period [@Harrell1996;@Uno_2011].

```python
from torchsurv.metrics import ConcordanceIndex
cindex = ConcordanceIndex()
cindex(log_hzs, event, time) # C-index
```
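
For intuition, the C-index can be sketched in plain Python as the fraction of comparable pairs that the risk score orders correctly. The version below follows Harrell's pairwise definition without any censoring weights. The function name is hypothetical and the code is not `TorchSurv`'s implementation.

```python
def concordance_index(risk, event, time):
    """Harrell-style C-index (illustrative, no censoring weights)."""
    concordant = 0.0
    comparable = 0
    n = len(time)
    for i in range(n):
        if not event[i]:
            # A pair is comparable only if the earlier time is an observed event.
            continue
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0  # earlier event has the higher risk score
                elif risk[i] == risk[j]:
                    concordant += 0.5  # tied risk scores count for one half
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Hypothetical toy data: one of the three comparable pairs is ordered incorrectly.
risk = [2.0, 3.0, 1.0]
event = [True, True, True]
time = [1.0, 2.0, 3.0]
c_index = concordance_index(risk, event, time)
```

Unlike the AUC at a fixed time, this quantity pools comparisons over all observed event times.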

**Brier Score.** The Brier score evaluates the accuracy of a model at a given time $t$. It represents the average squared distance between the observed survival status and the predicted survival probability [@Graf_1999]. The Brier score cannot be obtained for the Cox proportional hazards model because the survival function is not available, but it can be obtained for the Weibull ATF model.
**Brier Score.** The Brier score evaluates the accuracy of a model at a given time $t$ [@Graf_1999]. It represents the average squared distance between the observed survival status and the predicted survival probability. The Brier score cannot be obtained for the Cox proportional hazards model because the survival function is not available, but it can be obtained for the Weibull AFT model.
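
The definition can be sketched in plain Python as follows. The sketch assumes fully observed event times (no censoring weights) and a Weibull survival function parametrized by a scale $\lambda$ and a shape $\rho$, $S(t) = \exp(-(t/\lambda)^\rho)$. Both the function names and this parametrization are assumptions for illustration and may differ from the conventions used in `TorchSurv`.

```python
import math


def weibull_survival(t, scale, shape):
    """Weibull survival probability S(t) = exp(-(t / scale) ** shape)."""
    return math.exp(-((t / scale) ** shape))


def brier_score(surv_prob, time, t):
    """Brier score at time t (illustrative, assumes no censoring)."""
    # Observed survival status at t: 1 if the subject is still event-free.
    status = [1.0 if s > t else 0.0 for s in time]
    squared_errors = [(y - p) ** 2 for y, p in zip(status, surv_prob)]
    return sum(squared_errors) / len(squared_errors)


# Hypothetical per-subject Weibull parameters, e.g. produced by a network.
scales = [5.0, 1.0, 3.0]
shapes = [1.5, 1.0, 2.0]
time = [6.0, 0.8, 2.5]
t = 2.0
surv_prob = [weibull_survival(t, lam, rho) for lam, rho in zip(scales, shapes)]
score = brier_score(surv_prob, time, t)
```

Lower values indicate predicted survival probabilities that sit closer to the observed status at time $t$.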

```python
from torchsurv.metrics import Brier