fixed warning and code QC #41

Merged
merged 1 commit on Jul 3, 2024
6 changes: 0 additions & 6 deletions src/torchsurv/metrics/auc.py
@@ -1192,7 +1192,6 @@ def _find_torch_unique_indices(
 def _validate_auc_inputs(
     estimate, time, auc_type, new_time, weight, weight_new_time
 ):
-
     # check new_time and weight are provided, weight_new_time should be provided
     if all([new_time is not None, weight is not None, weight_new_time is None]):
         raise ValueError(
@@ -1222,12 +1221,10 @@ def _update_auc_new_time(
     weight: torch.tensor,
     weight_new_time: torch.tensor,
 ) -> torch.tensor:
-
     # update new time
     if (
         new_time is not None
     ):  # if new_time are specified: ensure it has the correct format
-
         # ensure that new_time are float
         if isinstance(new_time, int):
             new_time = torch.tensor([new_time]).float()
@@ -1237,7 +1234,6 @@
             new_time = new_time.unsqueeze(0)

     else:  # else: find new_time
-
         # if new_time are not specified, use unique event time
         mask = event & (time < torch.max(time))
         new_time, inverse_indices, counts = torch.unique(
@@ -1261,7 +1257,6 @@
 def _update_auc_estimate(
     estimate: torch.tensor, new_time: torch.tensor
 ) -> torch.tensor:
-
     # squeeze estimate if shape = (n_samples, 1)
     if estimate.ndim == 2 and estimate.shape[1] == 1:
         estimate = estimate.squeeze(1)
@@ -1281,7 +1276,6 @@ def _update_auc_weight(
     weight: torch.tensor,
     weight_new_time: torch.tensor,
 ) -> torch.tensor:
-
     # if weight was not specified, weight of 1
     if weight is None:
         weight = torch.ones_like(time)
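The auc.py hunks above are blank-line cleanup only; the validation rule they brush against is unchanged. As a minimal standalone sketch of that rule (a hypothetical function, not the library's code): if new_time and weight are supplied, weight_new_time must be supplied too.

import torch


def validate_auc_inputs(new_time=None, weight=None, weight_new_time=None):
    # mirrors the check in _validate_auc_inputs: when subject weights and
    # evaluation times are both given, weights at those times are required
    if all([new_time is not None, weight is not None, weight_new_time is None]):
        raise ValueError("weight_new_time is required when new_time and weight are set")


try:
    validate_auc_inputs(new_time=torch.tensor([5.0]), weight=torch.ones(10))
except ValueError as err:
    print(err)  # the check above fires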
27 changes: 13 additions & 14 deletions src/torchsurv/metrics/brier_score.py
@@ -171,10 +171,13 @@ def __call__(
         )

         # update inputs as required
-        estimate, new_time, weight, weight_new_time = (
-            BrierScore._update_brier_score_new_time(
-                estimate, time, new_time, weight, weight_new_time
-            )
+        (
+            estimate,
+            new_time,
+            weight,
+            weight_new_time,
+        ) = BrierScore._update_brier_score_new_time(
+            estimate, time, new_time, weight, weight_new_time
         )
         weight, weight_new_time = BrierScore._update_brier_score_weight(
             time, new_time, weight, weight_new_time
@@ -190,14 +193,14 @@

         # Calculating the residuals for each subject and time point
         residuals = torch.zeros_like(estimate)
-        for i, t in enumerate(new_time):
-            est = estimate[:, i]
-            is_case = ((time <= t) & (event)).int()
-            is_control = (time > t).int()
+        for index, new_time_i in enumerate(new_time):
+            est = estimate[:, index]
+            is_case = ((time <= new_time_i) & (event)).int()
+            is_control = (time > new_time_i).int()

-            residuals[:, i] = (
+            residuals[:, index] = (
                 torch.square(est) * is_case * weight
-                + torch.square(1.0 - est) * is_control * weight_new_time[i]
+                + torch.square(1.0 - est) * is_control * weight_new_time[index]
             )

         # Calculating the brier scores at each time point
@@ -827,7 +830,6 @@ def _validate_brier_score_inputs(
     weight: torch.tensor,
     weight_new_time: torch.tensor,
 ) -> torch.tensor:
-
     # check new_time and weight are provided, weight_new_time should be provided
     if all([new_time is not None, weight is not None, weight_new_time is None]):
         raise ValueError(
@@ -859,7 +861,6 @@ def _update_brier_score_new_time(
     weight: torch.tensor,
     weight_new_time: torch.tensor,
 ) -> torch.tensor:
-
     # check format of new_time
     if (
         new_time is not None
@@ -871,7 +872,6 @@
             new_time = new_time.unsqueeze(0)

     else:  # else: find new_time
-
         # if new_time are not specified, use unique time
         new_time, inverse_indices, counts = torch.unique(
             time, sorted=True, return_inverse=True, return_counts=True
@@ -896,7 +896,6 @@ def _update_brier_score_weight(
     weight: torch.tensor,
     weight_new_time: torch.tensor,
 ) -> torch.tensor:
-
     # if weight was not specified, weight of 1
     if weight is None:
         weight = torch.ones_like(time)
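The loop rewrite in __call__ is a pure rename (i, t → index, new_time_i); the quantity it computes is the weighted Brier residual. A vectorized sketch of the same computation, assuming estimate has shape (n_samples, n_times) with per-subject weight and per-time weight_new_time as in the diff (a reading of the code above, not the library's implementation):

import torch


def brier_residuals(estimate, event, time, new_time, weight, weight_new_time):
    # cases: event at or before the evaluation time; controls: still at risk
    is_case = ((time[:, None] <= new_time[None, :]) & event[:, None]).float()
    is_control = (time[:, None] > new_time[None, :]).float()
    return (
        torch.square(estimate) * is_case * weight[:, None]
        + torch.square(1.0 - estimate) * is_control * weight_new_time[None, :]
    )


est = torch.rand(4, 3)  # 4 subjects, 3 evaluation times
ev = torch.tensor([True, False, True, True])
t = torch.tensor([2.0, 5.0, 3.0, 7.0])
nt = torch.tensor([2.0, 4.0, 6.0])
res = brier_residuals(est, ev, t, nt, torch.ones(4), torch.ones(3))
print(res.mean(dim=0))  # unweighted per-time Brier scores, for illustration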
8 changes: 8 additions & 0 deletions src/torchsurv/metrics/cindex.py
@@ -1,5 +1,6 @@
 import copy
 import sys
+import warnings
 from typing import Optional, Tuple

 import torch
@@ -639,6 +640,13 @@ def _compare_noether(self, other):
         cindex1_se = self._concordance_index_se()
         cindex2_se = other._concordance_index_se()

+        # Suppress the specific warning
+        warnings.filterwarnings(
+            "ignore",
+            message="Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.",
+            category=UserWarning,
+        )
+
         # compute spearman correlation between risk prediction
         corr = regression.SpearmanCorrCoef()(
             self.estimate.reshape(-1), other.estimate.reshape(-1)
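One caveat: warnings.filterwarnings as added above mutates the process-wide filter list, so the torchmetrics message stays suppressed after _compare_noether returns. A scoped alternative (a suggestion, not what this PR does) is the standard-library catch_warnings context manager, which restores the filters on exit:

import warnings

import torch
from torchmetrics import regression

x = torch.randn(32)
y = x + 0.1 * torch.randn(32)

with warnings.catch_warnings():
    # the filter applies only inside this block and is undone on exit
    warnings.filterwarnings(
        "ignore",
        message="Metric `SpearmanCorrcoef` will save all targets",
        category=UserWarning,
    )
    corr = regression.SpearmanCorrCoef()(x, y)
print(corr)  # high positive correlation on this toy pair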
4 changes: 0 additions & 4 deletions tests/test_cox.py
@@ -57,9 +57,7 @@ def test_boolean_y(self):
     def test_log_likelihood_without_ties(self):
         """test cox partial log likelihood without ties on lung and gbsg datasets"""
         for benchmark_cox_loglik in benchmark_cox_logliks:
-
             if benchmark_cox_loglik["no_ties"][0] == True:
-
                 log_lik = -cox(
                     torch.tensor(
                         benchmark_cox_loglik["log_hazard"], dtype=torch.float32
@@ -82,9 +80,7 @@ def test_log_likelihood_with_ties(self):
     def test_log_likelihood_with_ties(self):
         """test Efron and Breslow's approximation of cox partial log likelihood with ties on lung and gbsg data"""
         for benchmark_cox_loglik in benchmark_cox_logliks:
-
             if benchmark_cox_loglik["no_ties"][0] == False:
-
                 # efron approximation of partial log likelihood
                 log_lik_efron = -cox(
                     torch.tensor(
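A small follow-up this QC pass could also have made (hypothetical, not part of the commit): the explicit == True / == False comparisons kept above are what pycodestyle flags as E712; plain truthiness reads the same and satisfies the linter.

benchmark = {"no_ties": [True]}  # toy stand-in for benchmark_cox_loglik

if benchmark["no_ties"][0]:  # instead of `== True`
    print("dataset has no ties")
if not benchmark["no_ties"][0]:  # instead of `== False`
    print("dataset has ties")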
7 changes: 3 additions & 4 deletions tests/test_kaplan_meier.py
@@ -27,7 +27,6 @@ def test_kaplan_meier_survival_distribution_real_data(self):
"""test Kaplan Meier survival distribution estimate on lung and gbsg datasets"""

for benchmark_kaplan_meier in benchmark_kaplan_meiers:

event = torch.tensor(benchmark_kaplan_meier["status"]).bool()
time = torch.tensor(benchmark_kaplan_meier["time"], dtype=torch.float32)
new_time = torch.tensor(
@@ -209,9 +208,9 @@ def test_kaplan_meier_prediction_error_raised(self):
         for batch in batch_container.batches:
             (train_time, train_event, test_time, *_) = batch

-            train_event[-1] = (
-                False  # if last event is censoring, the last KM is > 0 and it cannot predict beyond this time
-            )
+            train_event[
+                -1
+            ] = False  # if last event is censoring, the last KM is > 0 and it cannot predict beyond this time
             km = KaplanMeierEstimator()
             km(train_event, train_time, censoring_dist=False)

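The reformatted assignment above leans on a real Kaplan-Meier property: the product-limit estimator S(t) = prod over event times t_i <= t of (1 - d_i / n_i) only steps down at observed events, so a censored last observation leaves S above zero and undefined beyond the last follow-up time. A tiny standalone illustration (not torchsurv's estimator):

import torch


def kaplan_meier(event, time, query):
    # product-limit estimator at a single query time (minimal sketch)
    surv = 1.0
    for t in torch.unique(time[event]):  # step down at event times only
        if t > query:
            break
        at_risk = (time >= t).sum().item()
        deaths = ((time == t) & event).sum().item()
        surv *= 1.0 - deaths / at_risk
    return surv


time = torch.tensor([2.0, 3.0, 5.0])
event = torch.tensor([True, True, False])  # last subject censored
print(kaplan_meier(event, time, 5.0))  # 0.333...; the curve never reaches zero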
1 change: 0 additions & 1 deletion tests/test_momentum.py
@@ -13,7 +13,6 @@


 class TestMometum(unittest.TestCase):
-
     def test_momentum_weibull(self):
         model = Momentum(
             backbone=nn.Sequential(nn.Linear(8, 2)),  # Weibull expect two ouputs
9 changes: 0 additions & 9 deletions tests/utils.py
@@ -279,7 +279,6 @@ def get_input_array(self) -> Tuple[np.array, np.array, np.array, np.array]:
         )

     def _generate_input(self):
-
         # random maximum time in observational period
         tmax = torch.randint(5, 500, (1,)).item()

@@ -297,7 +296,6 @@ def _generate_input(self):
         self._generate_new_time()

     def _generate_data(self, tmax: int, n_train: int, n_test: int):
-
         # time-to-event or censoring in train
         train_time = torch.randint(1, tmax + 1, (n_train,)).float()

@@ -340,7 +338,6 @@ def _generate_data(self, tmax: int, n_train: int, n_test: int):
     def _enforce_conditions_data(
         self, time: torch.tensor, event: torch.tensor, dataset_type: str
     ) -> Tuple[torch.tensor, torch.tensor]:
-
         # if test max time should be greater than train max time
         if dataset_type == "test":
             if self.test_max_time_gt_train_max_time:
@@ -395,15 +392,13 @@ def _enforce_conditions_data(
         return time, event

     def _generate_estimate(self):
-
         # random risk score for observations in test
         estimate = torch.randn(len(self.test_event))

         # enforce conditions risk score
         self.estimate = self._enforce_conditions_estimate(estimate)

     def _enforce_conditions_estimate(self, estimate: torch.tensor) -> torch.tensor:
-
         # if there should be ties in risk score associated to patients with event
         if self.ties_score_events:
             estimate[torch.where(self.test_event == 1.0)[0][0]] = estimate[
@@ -425,7 +420,6 @@ def _enforce_conditions_estimate(self, estimate: torch.tensor) -> torch.tensor:
         return estimate

     def _generate_new_time(self):
-
         if torch.all(self.test_event == False):
             # if all patients are censored in test, no evaluation time
             new_time = torch.tensor([])
@@ -447,7 +441,6 @@ def _generate_new_time(self):
         self.new_time = self._enforce_conditions_time(new_time)

     def _enforce_conditions_time(self, new_time: torch.tensor) -> torch.tensor:
-
         # if the test max time should be included in evaluation time
         if self.test_max_time_in_new_time:
             new_time = torch.cat(
@@ -457,7 +450,6 @@
         return new_time

     def _evaluate_conditions(self):
-
         # are there ties in event times
         self.has_train_ties_time_event = self._has_ties(
             self.train_time[self.train_event == 1]
@@ -614,7 +606,6 @@ def generate_batches(self, n_batch: int, flags_to_set: list):
             n_batch = len(flags_to_set)

         for i in range(n_batch):
-
             if i >= len(flags_to_set):
                 # simulate data without flag
                 self.generate_one_batch()
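All nine utils.py deletions are blank lines inside the synthetic-data helpers. For orientation, a compressed sketch of what those helpers do — draw random censored survival data over a random window, then enforce a condition such as tied event times (hypothetical stand-ins; shapes and ranges mirror the diff, the real class is much richer):

import torch


def generate_survival_data(n: int):
    # minimal stand-in for tests/utils.py's _generate_data
    tmax = torch.randint(5, 500, (1,)).item()  # random observation window
    time = torch.randint(1, tmax + 1, (n,)).float()  # event/censoring times
    event = torch.rand(n) < 0.7  # ~70% observed events, rest censored
    return time, event


def enforce_ties(time: torch.tensor, event: torch.tensor):
    # force at least one tie among event times, like the _enforce_* helpers
    idx = torch.where(event)[0]
    if len(idx) >= 2:
        time[idx[1]] = time[idx[0]]
    return time, event


time, event = generate_survival_data(32)
time, event = enforce_ties(time, event)
print(time[event][:2])  # first two event times now tie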