diff --git a/README.md b/README.md index 656e576..3ca12f9 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,8 @@ ![Docs](https://github.com/Novartis/torchsurv/actions/workflows/docs.yml/badge.svg?branch=main) [![PyPI - Version](https://img.shields.io/pypi/v/torchsurv?)](https://pypi.org/project/torchsurv/) [![Conda](https://img.shields.io/conda/v/conda-forge/torchsurv?label=conda)](https://anaconda.org/conda-forge/torchsurv) -[![arXiv](https://img.shields.io/badge/arXiv-2404.10761-f9f107.svg?)](https://arxiv.org/abs/2404.10761) +[![arXiv](https://img.shields.io/badge/arXiv-2404.10761-f9f107.svg?color=green)](https://arxiv.org/abs/2404.10761) +[![status](https://joss.theoj.org/papers/02d7496da2b9cc34f9a6e04cabf2298d/status.svg)](https://joss.theoj.org/papers/02d7496da2b9cc34f9a6e04cabf2298d) [![Documentation](https://img.shields.io/badge/GithubPage-Sphinx-blue)](https://opensource.nibr.com/torchsurv/) [![PyPI Downloads](https://img.shields.io/pypi/dm/torchsurv.svg?label=PyPI%20downloads)]( https://pypi.org/project/torchsurv/) @@ -13,6 +14,8 @@ https://anaconda.org/conda-forge/torchsurv) `TorchSurv` is a Python package that serves as a companion tool to perform deep survival modeling within the `PyTorch` environment. Unlike existing libraries that impose specific parametric forms on users, `TorchSurv` enables the use of custom `PyTorch`-based deep survival models. With its lightweight design, minimal input requirements, full `PyTorch` backend, and freedom from restrictive survival model parameterizations, `TorchSurv` facilitates efficient survival model implementation, particularly beneficial for high-dimensional input data scenarios. +If you find this repository useful, please consider giving a star! ⭐ + ## TL;DR Our idea is to **keep things simple**. You are free to use any model architecture you want! Our code has 100% PyTorch backend and behaves like any other functions (losses or metrics) you may be familiar with. @@ -44,6 +47,7 @@ cindex.p_value(method="noether", alternative="two_sided") cindex.compare(cindexB) ``` + ## Installation and dependencies diff --git a/docs/notebooks/introduction.ipynb b/docs/notebooks/introduction.ipynb index 0b3a2f0..56e3496 100644 --- a/docs/notebooks/introduction.ipynb +++ b/docs/notebooks/introduction.ipynb @@ -65,7 +65,7 @@ "from torchsurv.metrics.cindex import ConcordanceIndex\n", "from torchsurv.metrics.auc import Auc\n", "\n", - "# local helpers\n", + "# PyTorch boilerplate - see https://github.com/Novartis/torchsurv/blob/main/docs/notebooks/helpers_introduction.py\n", "from helpers_introduction import Custom_dataset, plot_losses" ] }, diff --git a/docs/notebooks/momentum.ipynb b/docs/notebooks/momentum.ipynb index 017f023..cfd0fe6 100644 --- a/docs/notebooks/momentum.ipynb +++ b/docs/notebooks/momentum.ipynb @@ -65,7 +65,7 @@ "metadata": {}, "outputs": [], "source": [ - "# For simplicity (or laziness), we already implemented the datamodule for MNIST. See code for details\n", + "# PyTorch boilerplate - see https://github.com/Novartis/torchsurv/blob/main/docs/notebooks/helpers_momentum.py\n", "from helpers_momentum import MNISTDataModule, LitMomentum, LitMNIST" ] }, diff --git a/paper/paper.md b/paper/paper.md index cc21d58..77e8872 100644 --- a/paper/paper.md +++ b/paper/paper.md @@ -82,7 +82,7 @@ for data in dataloader: from torchsurv.loss import weibull # PyTorch model outputs TWO Weibull parameters per observation -my_model = MyPyTorcWeibullhModel() +my_model = MyPyTorchWeibullModel() for data in dataloader: x, event, time = data @@ -95,7 +95,8 @@ for data in dataloader: ```python from torchsurv.loss import Momentum -my_model = MyPyTorchXCoxModel() + +my_model = MyPyTorchCoxModel() my_loss = cox.neg_partial_log_likelihood # Works with any TorchSurv loss momentum = Momentum(backbone=my_model, loss=my_loss) @@ -115,7 +116,7 @@ The `TorchSurv` package offers a comprehensive set of metrics to evaluate the pr **AUC.** The AUC measures the discriminatory capacity of the survival model at a given time $t$, i.e., the ability to provide a reliable ranking of times-to-event based on estimated subject-specific risk scores [@Heagerty2005;@Uno2007;@Blanche2013]. ```python -from torchsurv.metrics import Auc +from torchsurv.metrics.auc import Auc auc = Auc() auc(log_hzs, event, time) # AUC at each time auc(log_hzs, event, time, new_time=torch.tensor(10.)) # AUC at time 10 @@ -124,7 +125,7 @@ auc(log_hzs, event, time, new_time=torch.tensor(10.)) # AUC at time 10 **C-index.** The C-index is a generalization of the AUC that represents the assessment of the discriminatory capacity of the survival model across the entire time period [@Harrell1996;@Uno_2011]. ```python -from torchsurv.metrics import ConcordanceIndex +from torchsurv.metrics.cindex import ConcordanceIndex cindex = ConcordanceIndex() cindex(log_hzs, event, time) ``` @@ -132,7 +133,7 @@ cindex(log_hzs, event, time) **Brier Score.** The Brier score evaluates the accuracy of a model at a given time $t$ [@Graf_1999]. It represents the average squared distance between the observed survival status and the predicted survival probability. The Brier score cannot be obtained for the Cox proportional hazards model because the survival function is not available, but it can be obtained for the Weibull ATF model. ```python -from torchsurv.metrics import Brier +from torchsurv.metrics.brier_score import BrierScore surv = survival_function(log_params, time) brier = Brier() brier(surv, event, time) # Brier score at each time @@ -147,6 +148,128 @@ cindex.p_value(alternative='greater') # pvalue, H0:c=0.5, HA:c>0.5 cindex.compare(cindex_other) # pvalue, H0:c1=c2, HA:c1>c2 ``` +# Comprehensive Example: Fitting a Cox Proportional Hazards Model with TorchSurv + +In this section, we provide a reproducible code example to demonstrate how to use `TorchSurv` for fitting a Cox proportional hazards model. We simulate data where each observations is associated with 10 features, a time-to-event that depends linearly on these features, and a time-to-censoring. The observable data include only the minimum between the time-to-event and time-to-censoring, representing the first event that occurs. Subsequently, we fit a Cox proportional hazards model using maximum likelihood estimation and assess the model's predictive performance through the AUC and the concordance index. To facilitate rapid execution, we use a simple linear backend model in PyTorch to define the log relative hazards. For more comprehensive examples using real data, we encourage readers to visit the Torchsurv website. + + + +```python +import torch +from torch.utils.data import DataLoader, Dataset +from torchsurv.loss import cox +from torchsurv.metrics.cindex import ConcordanceIndex +from torchsurv.metrics.auc import Auc + +torch.manual_seed(42) + + +# 1. Simulate data. +n_features = 10 # int, number of features per observation +time_end = torch.tensor( + 2000.0 +) # float, end of observational period after which all observations are censored +weights = ( + torch.randn(n_features) * 5 +) # float, weights associated with the features ~ normal(0, 5^2) + + +# Define the generator function +def tte_generator(batch_size: int): + while True: + x = torch.randn(batch_size, n_features) # features + + mean_event_time, mean_censoring_time = 1000.0 + x @ weights, 1000.0 + + event_time = ( + mean_event_time + torch.randn(batch_size) * 50 + ) # event time ~ normal(mean_event_time, 50^2) + censoring_time = torch.distributions.Exponential( + 1 / mean_censoring_time + ).sample( + (batch_size,) + ) # censoring time ~ Exponential(mean = mean_censoring_time) + censoring_time = torch.minimum( + censoring_time, time_end + ) # truncate censoring time to time_end + + event = (event_time <= censoring_time).bool() # event indicator + time = torch.minimum(event_time, censoring_time) # observed time + + yield x, event, time + + +# 2. Define the PyTorch dataset class +class TTE_dataset(Dataset): + def __init__(self, generator: callable, batch_size: int): + self.batch_size = batch_size + self.generatated_data = generator(batch_size=batch_size) + + def __len__(self): + return self.batch_size + + def __getitem__(self, index): + return next(self.generatated_data) + + +# 3. Define the backbone model on the log hazards. +class MyPyTorchCoxModel(torch.nn.Module): + def __init__(self): + super(MyPyTorchCoxModel, self).__init__() + self.fc = torch.nn.Linear(n_features, 1, bias=False) # Simple linear model + + def forward(self, x): + return self.fc(x) + + +# 4. Instantiate the model, optimizer, dataset and dataloader +cox_model = MyPyTorchCoxModel() +optimizer = torch.optim.Adam(cox_model.parameters(), lr=0.01) + +batch_size = 64 # int, batch size +dataset = TTE_dataset(tte_generator, batch_size=batch_size) +dataloader = DataLoader( + dataset, batch_size=1, shuffle=True +) # Batch size of 1 because dataset yields batches + + +# 5. Training loop +for epoch in range(100): + for i, batch in enumerate(dataloader): + x, event, time = [t.squeeze() for t in batch] # Squeeze extra dimension + optimizer.zero_grad() + log_hzs = cox_model(x) # torch.Size([batch_size, 1]) + loss = cox.neg_partial_log_likelihood(log_hzs, event, time) + loss.backward() + optimizer.step() + + +# 6. Evaluate the model +n_samples_test = 1000 # int, number of observations in test set + +data_test = next(tte_generator(batch_size=n_samples_test)) +x, event, time = [t.squeeze() for t in data_test] # test set +log_hzs = cox_model(x) # log hazards evaluated on test set + +# AUC at time point 1000 +auc = Auc() +print( + "AUC:", auc(log_hzs, event, time, new_time=torch.tensor(1000.0)) +) # tensor([0.5902]) +print("AUC Confidence Interval:", auc.confidence_interval()) # tensor([0.5623, 0.6180]) +print("AUC p-value:", auc.p_value(alternative="greater")) # tensor([0.]) + +# C-index +cindex = ConcordanceIndex() +print("C-Index:", cindex(log_hzs, event, time)) # tensor(0.5774) +print( + "C-Index Confidence Interval:", cindex.confidence_interval() +) # tensor([0.5086, 0.6463]) +print("C-Index p-value:", cindex.p_value(alternative="greater")) # tensor(0.0138) +``` + + + # Conflicts of interest MM, PK, DO and TC are employees and stockholders of Novartis, a global pharmaceutical company. diff --git a/paper/table_1.png b/paper/table_1.png index efac5e0..f06d474 100644 Binary files a/paper/table_1.png and b/paper/table_1.png differ