From 7b90a04ebccc27e53fd5c8aa98bf66a8f77dec65 Mon Sep 17 00:00:00 2001 From: demboso1 Date: Thu, 28 Nov 2024 11:33:20 +0000 Subject: [PATCH] added comments to functions in cox.py file --- src/torchsurv/loss/cox.py | 48 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/src/torchsurv/loss/cox.py b/src/torchsurv/loss/cox.py index ce49f46..703caa6 100644 --- a/src/torchsurv/loss/cox.py +++ b/src/torchsurv/loss/cox.py @@ -170,8 +170,28 @@ def _partial_likelihood_cox( ) -> torch.Tensor: """Calculate the partial log likelihood for the Cox proportional hazards model in the absence of ties in event time. + + Args: + log_hz_sorted (torch.Tensor, float): + Log relative hazard of length n_samples, ordered by time-to-event or censoring. + event_sorted (torch.Tensor, bool): + Event indicator of length n_samples (= True if event occured), ordered by time-to-event or censoring. + + Returns: + (torch.tensor, float): + Vector of the partial log likelihoods. + + Note: + Let :math:`\tau_1 < \tau_2 < \cdots < \tau_N` + be the ordered times and let :math:`R(\tau_i) = \{ j: \tau_j \geq \tau_i\}` + be the risk set at :math:`\tau_i`. The partial log likelihood is defined as: + + .. math:: + + pll = \sum_{i: \: \delta_i = 1} \left(\log \theta_i - \log\left(\sum_{j \in R(\tau_i)} \theta_j \right) \right) """ log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0) + return (log_hz_sorted - log_denominator)[event_sorted] @@ -183,6 +203,19 @@ def _partial_likelihood_efron( ) -> torch.Tensor: """Calculate the partial log likelihood for the Cox proportional hazards model using Efron's method to handle ties in event time. + + Args: + log_hz_sorted (torch.Tensor, float): + Log relative hazard of length n_samples, ordered by time-to-event or censoring. + event_sorted (torch.Tensor, bool): + Event indicator of length n_samples (= True if event occured), ordered by time-to-event or censoring. + time_sorted (torch.Tensor): + Time-to-event values sorted in order. + time_unique (torch.Tensor): + Set of unique time-to-event values. + Returns: + (torch.tensor, float): + Vector of partial log likelihood estimated using Efron's method. """ J = len(time_unique) @@ -206,6 +239,7 @@ def _partial_likelihood_efron( log_denominator_efron[j] += torch.log( denominator_naive[j] - (l - 1) / m[j] * denominator_ties[j] ) + return (log_nominator - log_denominator_efron)[include] @@ -216,6 +250,18 @@ def _partial_likelihood_breslow( ): """Calculate the partial log likelihood for the Cox proportional hazards model using Breslow's method to handle ties in event time. + + Args: + log_hz_sorted (torch.Tensor, float): + Log relative hazard of length n_samples, ordered by time-to-event or censoring. + event_sorted (torch.Tensor, bool): + Event indicator of length n_samples (= True if event occured), ordered by time-to-event or censoring. + time_sorted (torch.Tensor): + Time-to-event values sorted in order. + + Returns: + (torch.tensor, float): + Vector containing partial log likelihood estimated using Breslow's method. """ N = len(time_sorted) @@ -223,7 +269,7 @@ def _partial_likelihood_breslow( log_denominator = torch.tensor( [torch.logsumexp(log_hz_sorted[R[i]], dim=0) for i in range(N)] ) - + return (log_hz_sorted - log_denominator)[event_sorted]