Skip to content

Commit

Permalink
JOSS submission (#59)
Browse files Browse the repository at this point in the history
* edits

* ⭐

* fixed typo

* added boilerplate links
  • Loading branch information
tcoroller authored Oct 23, 2024
1 parent 86de946 commit d460f9c
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 8 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
![Docs](https://github.com/Novartis/torchsurv/actions/workflows/docs.yml/badge.svg?branch=main)
[![PyPI - Version](https://img.shields.io/pypi/v/torchsurv?)](https://pypi.org/project/torchsurv/)
[![Conda](https://img.shields.io/conda/v/conda-forge/torchsurv?label=conda)](https://anaconda.org/conda-forge/torchsurv)
[![arXiv](https://img.shields.io/badge/arXiv-2404.10761-f9f107.svg?)](https://arxiv.org/abs/2404.10761)
[![arXiv](https://img.shields.io/badge/arXiv-2404.10761-f9f107.svg?color=green)](https://arxiv.org/abs/2404.10761)
[![status](https://joss.theoj.org/papers/02d7496da2b9cc34f9a6e04cabf2298d/status.svg)](https://joss.theoj.org/papers/02d7496da2b9cc34f9a6e04cabf2298d)
[![Documentation](https://img.shields.io/badge/GithubPage-Sphinx-blue)](https://opensource.nibr.com/torchsurv/)
[![PyPI Downloads](https://img.shields.io/pypi/dm/torchsurv.svg?label=PyPI%20downloads)](
https://pypi.org/project/torchsurv/)
Expand All @@ -13,6 +14,8 @@ https://anaconda.org/conda-forge/torchsurv)

`TorchSurv` is a Python package that serves as a companion tool to perform deep survival modeling within the `PyTorch` environment. Unlike existing libraries that impose specific parametric forms on users, `TorchSurv` enables the use of custom `PyTorch`-based deep survival models. With its lightweight design, minimal input requirements, full `PyTorch` backend, and freedom from restrictive survival model parameterizations, `TorchSurv` facilitates efficient survival model implementation, particularly beneficial for high-dimensional input data scenarios.

If you find this repository useful, please consider giving a star! ⭐

## TL;DR

Our idea is to **keep things simple**. You are free to use any model architecture you want! Our code has 100% PyTorch backend and behaves like any other functions (losses or metrics) you may be familiar with.
Expand Down Expand Up @@ -44,6 +47,7 @@ cindex.p_value(method="noether", alternative="two_sided")
cindex.compare(cindexB)
```


## Installation and dependencies


Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks/introduction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
"from torchsurv.metrics.cindex import ConcordanceIndex\n",
"from torchsurv.metrics.auc import Auc\n",
"\n",
"# local helpers\n",
"# PyTorch boilerplate - see https://github.com/Novartis/torchsurv/blob/main/docs/notebooks/helpers_introduction.py\n",
"from helpers_introduction import Custom_dataset, plot_losses"
]
},
Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks/momentum.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
"metadata": {},
"outputs": [],
"source": [
"# For simplicity (or laziness), we already implemented the datamodule for MNIST. See code for details\n",
"# PyTorch boilerplate - see https://github.com/Novartis/torchsurv/blob/main/docs/notebooks/helpers_momentum.py\n",
"from helpers_momentum import MNISTDataModule, LitMomentum, LitMNIST"
]
},
Expand Down
133 changes: 128 additions & 5 deletions paper/paper.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ for data in dataloader:
from torchsurv.loss import weibull

# PyTorch model outputs TWO Weibull parameters per observation
my_model = MyPyTorcWeibullhModel()
my_model = MyPyTorchWeibullModel()

for data in dataloader:
x, event, time = data
Expand All @@ -95,7 +95,8 @@ for data in dataloader:

```python
from torchsurv.loss import Momentum
my_model = MyPyTorchXCoxModel()

my_model = MyPyTorchCoxModel()
my_loss = cox.neg_partial_log_likelihood # Works with any TorchSurv loss
momentum = Momentum(backbone=my_model, loss=my_loss)

Expand All @@ -115,7 +116,7 @@ The `TorchSurv` package offers a comprehensive set of metrics to evaluate the pr
**AUC.** The AUC measures the discriminatory capacity of the survival model at a given time $t$, i.e., the ability to provide a reliable ranking of times-to-event based on estimated subject-specific risk scores [@Heagerty2005;@Uno2007;@Blanche2013].

```python
from torchsurv.metrics import Auc
from torchsurv.metrics.auc import Auc
auc = Auc()
auc(log_hzs, event, time) # AUC at each time
auc(log_hzs, event, time, new_time=torch.tensor(10.)) # AUC at time 10
Expand All @@ -124,15 +125,15 @@ auc(log_hzs, event, time, new_time=torch.tensor(10.)) # AUC at time 10
**C-index.** The C-index is a generalization of the AUC that represents the assessment of the discriminatory capacity of the survival model across the entire time period [@Harrell1996;@Uno_2011].

```python
from torchsurv.metrics import ConcordanceIndex
from torchsurv.metrics.cindex import ConcordanceIndex
cindex = ConcordanceIndex()
cindex(log_hzs, event, time)
```

**Brier Score.** The Brier score evaluates the accuracy of a model at a given time $t$ [@Graf_1999]. It represents the average squared distance between the observed survival status and the predicted survival probability. The Brier score cannot be obtained for the Cox proportional hazards model because the survival function is not available, but it can be obtained for the Weibull ATF model.

```python
from torchsurv.metrics import Brier
from torchsurv.metrics.brier_score import BrierScore
surv = survival_function(log_params, time)
brier = Brier()
brier(surv, event, time) # Brier score at each time
Expand All @@ -147,6 +148,128 @@ cindex.p_value(alternative='greater') # pvalue, H0:c=0.5, HA:c>0.5
cindex.compare(cindex_other) # pvalue, H0:c1=c2, HA:c1>c2
```

# Comprehensive Example: Fitting a Cox Proportional Hazards Model with TorchSurv

In this section, we provide a reproducible code example to demonstrate how to use `TorchSurv` for fitting a Cox proportional hazards model. We simulate data where each observations is associated with 10 features, a time-to-event that depends linearly on these features, and a time-to-censoring. The observable data include only the minimum between the time-to-event and time-to-censoring, representing the first event that occurs. Subsequently, we fit a Cox proportional hazards model using maximum likelihood estimation and assess the model's predictive performance through the AUC and the concordance index. To facilitate rapid execution, we use a simple linear backend model in PyTorch to define the log relative hazards. For more comprehensive examples using real data, we encourage readers to visit the Torchsurv website.



```python
import torch
from torch.utils.data import DataLoader, Dataset
from torchsurv.loss import cox
from torchsurv.metrics.cindex import ConcordanceIndex
from torchsurv.metrics.auc import Auc

torch.manual_seed(42)


# 1. Simulate data.
n_features = 10 # int, number of features per observation
time_end = torch.tensor(
2000.0
) # float, end of observational period after which all observations are censored
weights = (
torch.randn(n_features) * 5
) # float, weights associated with the features ~ normal(0, 5^2)


# Define the generator function
def tte_generator(batch_size: int):
while True:
x = torch.randn(batch_size, n_features) # features

mean_event_time, mean_censoring_time = 1000.0 + x @ weights, 1000.0

event_time = (
mean_event_time + torch.randn(batch_size) * 50
) # event time ~ normal(mean_event_time, 50^2)
censoring_time = torch.distributions.Exponential(
1 / mean_censoring_time
).sample(
(batch_size,)
) # censoring time ~ Exponential(mean = mean_censoring_time)
censoring_time = torch.minimum(
censoring_time, time_end
) # truncate censoring time to time_end

event = (event_time <= censoring_time).bool() # event indicator
time = torch.minimum(event_time, censoring_time) # observed time

yield x, event, time


# 2. Define the PyTorch dataset class
class TTE_dataset(Dataset):
def __init__(self, generator: callable, batch_size: int):
self.batch_size = batch_size
self.generatated_data = generator(batch_size=batch_size)

def __len__(self):
return self.batch_size

def __getitem__(self, index):
return next(self.generatated_data)


# 3. Define the backbone model on the log hazards.
class MyPyTorchCoxModel(torch.nn.Module):
def __init__(self):
super(MyPyTorchCoxModel, self).__init__()
self.fc = torch.nn.Linear(n_features, 1, bias=False) # Simple linear model

def forward(self, x):
return self.fc(x)


# 4. Instantiate the model, optimizer, dataset and dataloader
cox_model = MyPyTorchCoxModel()
optimizer = torch.optim.Adam(cox_model.parameters(), lr=0.01)

batch_size = 64 # int, batch size
dataset = TTE_dataset(tte_generator, batch_size=batch_size)
dataloader = DataLoader(
dataset, batch_size=1, shuffle=True
) # Batch size of 1 because dataset yields batches


# 5. Training loop
for epoch in range(100):
for i, batch in enumerate(dataloader):
x, event, time = [t.squeeze() for t in batch] # Squeeze extra dimension
optimizer.zero_grad()
log_hzs = cox_model(x) # torch.Size([batch_size, 1])
loss = cox.neg_partial_log_likelihood(log_hzs, event, time)
loss.backward()
optimizer.step()


# 6. Evaluate the model
n_samples_test = 1000 # int, number of observations in test set

data_test = next(tte_generator(batch_size=n_samples_test))
x, event, time = [t.squeeze() for t in data_test] # test set
log_hzs = cox_model(x) # log hazards evaluated on test set

# AUC at time point 1000
auc = Auc()
print(
"AUC:", auc(log_hzs, event, time, new_time=torch.tensor(1000.0))
) # tensor([0.5902])
print("AUC Confidence Interval:", auc.confidence_interval()) # tensor([0.5623, 0.6180])
print("AUC p-value:", auc.p_value(alternative="greater")) # tensor([0.])

# C-index
cindex = ConcordanceIndex()
print("C-Index:", cindex(log_hzs, event, time)) # tensor(0.5774)
print(
"C-Index Confidence Interval:", cindex.confidence_interval()
) # tensor([0.5086, 0.6463])
print("C-Index p-value:", cindex.p_value(alternative="greater")) # tensor(0.0138)
```



# Conflicts of interest

MM, PK, DO and TC are employees and stockholders of Novartis, a global pharmaceutical company.
Expand Down
Binary file modified paper/table_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit d460f9c

Please sign in to comment.