Skip to content

Commit

Permalink
added comments to functions in cox.py file
Browse files Browse the repository at this point in the history
  • Loading branch information
SoniaDem committed Nov 28, 2024
1 parent 63c2080 commit 7b90a04
Showing 1 changed file with 47 additions and 1 deletion.
48 changes: 47 additions & 1 deletion src/torchsurv/loss/cox.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,28 @@ def _partial_likelihood_cox(
) -> torch.Tensor:
"""Calculate the partial log likelihood for the Cox proportional hazards model
in the absence of ties in event time.
Args:
log_hz_sorted (torch.Tensor, float):
Log relative hazard of length n_samples, ordered by time-to-event or censoring.
event_sorted (torch.Tensor, bool):
Event indicator of length n_samples (= True if event occured), ordered by time-to-event or censoring.
Returns:
(torch.tensor, float):
Vector of the partial log likelihoods.
Note:
Let :math:`\tau_1 < \tau_2 < \cdots < \tau_N`
be the ordered times and let :math:`R(\tau_i) = \{ j: \tau_j \geq \tau_i\}`
be the risk set at :math:`\tau_i`. The partial log likelihood is defined as:
.. math::
pll = \sum_{i: \: \delta_i = 1} \left(\log \theta_i - \log\left(\sum_{j \in R(\tau_i)} \theta_j \right) \right)
"""
log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0)

return (log_hz_sorted - log_denominator)[event_sorted]


Expand All @@ -183,6 +203,19 @@ def _partial_likelihood_efron(
) -> torch.Tensor:
"""Calculate the partial log likelihood for the Cox proportional hazards model
using Efron's method to handle ties in event time.
Args:
log_hz_sorted (torch.Tensor, float):
Log relative hazard of length n_samples, ordered by time-to-event or censoring.
event_sorted (torch.Tensor, bool):
Event indicator of length n_samples (= True if event occured), ordered by time-to-event or censoring.
time_sorted (torch.Tensor):
Time-to-event values sorted in order.
time_unique (torch.Tensor):
Set of unique time-to-event values.
Returns:
(torch.tensor, float):
Vector of partial log likelihood estimated using Efron's method.
"""
J = len(time_unique)

Expand All @@ -206,6 +239,7 @@ def _partial_likelihood_efron(
log_denominator_efron[j] += torch.log(
denominator_naive[j] - (l - 1) / m[j] * denominator_ties[j]
)

return (log_nominator - log_denominator_efron)[include]


Expand All @@ -216,14 +250,26 @@ def _partial_likelihood_breslow(
):
"""Calculate the partial log likelihood for the Cox proportional hazards model
using Breslow's method to handle ties in event time.
Args:
log_hz_sorted (torch.Tensor, float):
Log relative hazard of length n_samples, ordered by time-to-event or censoring.
event_sorted (torch.Tensor, bool):
Event indicator of length n_samples (= True if event occured), ordered by time-to-event or censoring.
time_sorted (torch.Tensor):
Time-to-event values sorted in order.
Returns:
(torch.tensor, float):
Vector containing partial log likelihood estimated using Breslow's method.
"""
N = len(time_sorted)

R = [torch.where(time_sorted >= time_sorted[i])[0] for i in range(N)]
log_denominator = torch.tensor(
[torch.logsumexp(log_hz_sorted[R[i]], dim=0) for i in range(N)]
)

return (log_hz_sorted - log_denominator)[event_sorted]


Expand Down

0 comments on commit 7b90a04

Please sign in to comment.