From dafdb404ee848e366bb22922d5d1a700e4b71293 Mon Sep 17 00:00:00 2001 From: Dembowska Date: Tue, 17 Dec 2024 13:50:11 +0100 Subject: [PATCH] small changes to both the notebook and loss files --- docs/notebooks/introduction.ipynb | 159 ++++++++++++++++++++----- docs/notebooks/loss_time_covariates.py | 12 +- docs/notebooks/time_varying.ipynb | 76 +++++++----- src/torchsurv/loss/time_covariates.py | 45 ------- 4 files changed, 179 insertions(+), 113 deletions(-) delete mode 100644 src/torchsurv/loss/time_covariates.py diff --git a/docs/notebooks/introduction.ipynb b/docs/notebooks/introduction.ipynb index bdfb74d..b42a79d 100644 --- a/docs/notebooks/introduction.ipynb +++ b/docs/notebooks/introduction.ipynb @@ -33,7 +33,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "013dbcb4", "metadata": {}, "outputs": [], @@ -45,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "2601dd00-7bd2-49d5-9bdf-a84205872890", "metadata": {}, "outputs": [], @@ -71,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "d7a98ea2-100f-43ef-8c45-c786ddcd313e", "metadata": {}, "outputs": [ @@ -79,7 +79,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "CUDA-enabled GPU/TPU is available.\n" + "No CUDA-enabled GPU found, using CPU.\n" ] } ], @@ -107,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "1df49737-dc02-4d6b-acd7-d03b79f18a29", "metadata": { "scrolled": true @@ -225,7 +225,7 @@ "4 no 73 Post 35 II 1 26 65 772 1" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -270,7 +270,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "7a5fd9ef-2643-46b7-9c98-05ff919026ea", "metadata": {}, "outputs": [ @@ -399,7 +399,7 @@ "4 0.0 1.0 0.0 " ] }, - "execution_count": 6, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ 
-416,7 +416,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "0f8b7f3b-fb2a-4d74-ac99-8f6390b2f5eb", "metadata": {}, "outputs": [ @@ -446,7 +446,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "id": "326c03fc-91f1-493b-a9ba-820de17fb2f8", "metadata": {}, "outputs": [], @@ -465,7 +465,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "570386fb-f0ea-4061-bae2-11b274e7f851", "metadata": {}, "outputs": [ @@ -473,10 +473,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "x (shape) = torch.Size([128, 9])\n", + "x (shape) = torch.Size([32, 9])\n", "num_features = 9\n", - "event = torch.Size([128])\n", - "time = torch.Size([128])\n" + "event = torch.Size([32])\n", + "time = torch.Size([32])\n" ] } ], @@ -517,7 +517,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "id": "9c2bd89a-c90a-4795-aab5-b5c21906a0de", "metadata": {}, "outputs": [], @@ -534,6 +534,109 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": 10, + "id": "7d97e65d", + "metadata": {}, + "outputs": [], + "source": [ + "# This is for testing the loss function\n", + "x_test, (test_event, test_time) = next(iter(dataloader_train))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "5d102dad", + "metadata": {}, + "outputs": [], + "source": [ + "log_hz = cox_model(x_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "210e6755", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x_test torch.Size([32, 9])\n", + "events torch.Size([32])\n", + "times torch.Size([32])\n", + "\n", + "time_sorted torch.Size([32])\n", + "log_hz_sorted torch.Size([32, 1])\n", + "event_sorted torch.Size([32])\n", + "time_unique torch.Size([30])\n", + "------------------------------\n", + "covariates torch.Size([32, 9])\n", + "cov_inner torch.Size([32, 32])\n", + "log_nom_left torch.Size([1, 32])\n", + 
"bracket torch.Size([32, 9])\n", + "log_nom_right torch.Size([32, 32])\n", + "sum_nom torch.Size([1, 32])\n", + "log_denom torch.Size([1, 32])\n", + "last_bit torch.Size([1, 32])\n" + ] + }, + { + "data": { + "text/plain": [ + "tensor([[-1.4683e+04, -2.9827e+03, -2.9461e+04, -4.0582e+04, -1.7949e+04,\n", + " -1.4714e+05, -1.4940e+03, -7.7085e+04, -5.3855e+04, -9.3090e+03,\n", + " -9.8543e+03, -1.8929e+05, -5.1617e+03, -4.4286e+03, -9.6604e+04,\n", + " -1.5469e+04, -2.7680e+04, -6.3136e+04, -1.2045e+05, -9.3347e+04,\n", + " -1.7911e+05, -1.3205e+05, -1.6203e+05, -3.0884e+04, -2.3050e+03,\n", + " -2.1324e+05, -1.7852e+06, -1.7429e+04, -2.9495e+05, -8.4400e+03,\n", + " -5.5583e+04, 1.2975e+05]], grad_fn=)" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print('x_test', x_test.shape)\n", + "print('events', test_event.shape)\n", + "print('times', test_time.shape)\n", + "\n", + "time_sorted, idx = torch.sort(time)\n", + "log_hz_sorted = log_hz[idx]\n", + "event_sorted = event[idx]\n", + "time_unique = torch.unique(time_sorted)\n", + "print('')\n", + "print(\"time_sorted\", time_sorted.shape)\n", + "print('log_hz_sorted', log_hz_sorted.shape)\n", + "print('event_sorted', event_sorted.shape)\n", + "print(\"time_unique\", time_unique.shape)\n", + "\n", + "print('-'*30)\n", + "cov_fake = torch.clone(x_test)\n", + "print('covariates', cov_fake.shape)\n", + "covariates_sorted = cov_fake[idx, :]\n", + "covariate_inner_product = torch.matmul(covariates_sorted, covariates_sorted.T)\n", + "print('cov_inner', covariate_inner_product.shape)\n", + "log_nominator_left = torch.matmul(log_hz_sorted.T, covariate_inner_product)\n", + "print('log_nom_left', log_nominator_left.shape)\n", + "bracket = torch.mul(log_hz_sorted, covariates_sorted)\n", + "print('bracket', bracket.shape)\n", + "log_nominator_right = torch.matmul(bracket, bracket.T)\n", + "print('log_nom_right', log_nominator_right.shape)\n", + "sum_nominator_right = 
torch.sum(log_nominator_right, dim=0).unsqueeze(0)\n", + "print('sum_nom', sum_nominator_right.shape)\n", + "log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0).T\n", + "print('log_denom', log_denominator.shape)\n", + "last_bit = torch.div(log_nominator_left - sum_nominator_right, log_denominator)\n", + "print('last_bit', last_bit.shape)\n", + "last_bit\n" + ] + }, { "cell_type": "markdown", "id": "97c90244", @@ -544,7 +647,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 15, "id": "d7889dc1-1cfa-424e-a586-481cbc789581", "metadata": {}, "outputs": [ @@ -552,16 +655,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 000, Training loss: 12.75\n", - "Epoch: 010, Training loss: 12.02\n", - "Epoch: 020, Training loss: 11.79\n", - "Epoch: 030, Training loss: 11.84\n", - "Epoch: 040, Training loss: 11.61\n", - "Epoch: 050, Training loss: 11.61\n", - "Epoch: 060, Training loss: 11.46\n", - "Epoch: 070, Training loss: 11.57\n", - "Epoch: 080, Training loss: 11.56\n", - "Epoch: 090, Training loss: 11.20\n" + "Epoch: 000, Training loss: 31.85\n", + "Epoch: 010, Training loss: 30.18\n", + "Epoch: 020, Training loss: 29.73\n", + "Epoch: 030, Training loss: 29.84\n", + "Epoch: 040, Training loss: 29.04\n", + "Epoch: 050, Training loss: 29.61\n", + "Epoch: 060, Training loss: 29.46\n", + "Epoch: 070, Training loss: 28.94\n", + "Epoch: 080, Training loss: 29.31\n", + "Epoch: 090, Training loss: 28.00\n" ] } ], @@ -1183,7 +1286,7 @@ ], "metadata": { "kernelspec": { - "display_name": "torchsurv_env", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -1197,7 +1300,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.10.15" } }, "nbformat": 4, diff --git a/docs/notebooks/loss_time_covariates.py b/docs/notebooks/loss_time_covariates.py index 9e57ccc..e54bdef 100644 --- a/docs/notebooks/loss_time_covariates.py +++ 
b/docs/notebooks/loss_time_covariates.py @@ -9,19 +9,11 @@ def neg_partial_time_log_likelihood( event: torch.Tensor, time: torch.Tensor, ties_method: str = "efron", - reduction: str = "mean", - checks: bool = True, + reduction: str = "mean" ) -> torch.Tensor: ''' - THIS FUNCTION IS NOT DONE, i HAVENT TESTED THE NEGATIVE PART YET + needs further work ''' - if checks: - _check_inputs(log_hz, event, time) - - if any([event.sum() == 0, len(log_hz.size()) == 0]): - warnings.warn("No events OR single sample. Returning zero loss for the batch") - return torch.tensor(0.0, requires_grad=True) - # sort data by time-to-event or censoring time_sorted, idx = torch.sort(time) log_hz_sorted = log_hz[idx] diff --git a/docs/notebooks/time_varying.ipynb b/docs/notebooks/time_varying.ipynb index dc05839..4bd6be9 100644 --- a/docs/notebooks/time_varying.ipynb +++ b/docs/notebooks/time_varying.ipynb @@ -21,8 +21,10 @@ "\n", "### Context and statistical set-up\n", "\n", - "Let $T^*_i$ be the be the failure time of interest for subject $i$ and $C$ be the censoring time. Let $T_i = min(T^*, C)$. We use $\\delta_i$ to denote whether $T*_i$ was observed. We will use $Z(t)$ to denote the value of of covariate $Z$ and time $t$. \n", - "We use $Z(t) to denote the value of Z at time $t$ and $\\overline{Z}(t)$ to denote the set of covariates from the beggining up to time $t$: $ \\overline{Z}(t) = \\{ Z(s): 0 \\leq s \\leq t\\}$.\n", + "Let $i$ be the index of a subject whose failure time is denoted $\\tau^*_i$, and let $C$ be the censoring time. For the moment $C$ is held constant, but there are extensions that allow $C$ to vary over $i$. Let $\\tau_i = \\min(\\tau^*_i, C)$. We use $\\delta_i$ to denote whether $\\tau^*_i$ was observed. \n", + "\n", + "We will use $Z(t)$ to denote the value of the covariate $Z$ at time $t$. 
\n", + "We use $Z(t)$ to denote the value of $Z$ at time $t$ and $\\overline{Z}(t)$ to denote the set of covariates from the beginning up to time $t$: $ \\overline{Z}(t) = \\{ Z(s): 0 \\leq s \\leq t\\}$.\n", "Let $t_k$ for $k \\in \\{1, \\dots, K\\} denote the time points at which the covariates are observed. For the moment, we assume that all subjects have been observed on the same time grid. $R_k$ is the set of individuals who are at risk at $t_k$. \n", "\n", "The conditional hazard function of $T$ given $\\overline{Z}(t)$ is defined as\n", @@ -64,7 +66,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -77,7 +79,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -88,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -113,12 +115,18 @@ "source": [ "## Simulating a dataset\n", "\n", - "We will simulate a dataset of 100 subjects with 6 follow up times where a covariate is observed. The covariates will change over time slightly but will be generated from one random variable per subject so that " + "We will simulate a dataset of 100 subjects with 10 follow-up times where a covariate is observed. The covariates will follow a trigonometric function over time and will depend on a subject-specific random variable to differentiate between subjects.\n", + "\n", + "For each $i$ the covariate follows the function:\n", + "\n", + "$$ Z_i(t) = a_i \\cos(2 \\pi t) $$\n", + "\n", + "where $a_i \\sim N(5, 2.5^2)$, i.e. normally distributed with mean 5 and standard deviation 2.5." 
] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -130,8 +138,8 @@ "\n", "# create random variables following a normal distribution N(1,1) for each subject \n", "mean = 5\n", - "standard_dev = 5\n", - "random_vars = torch.randn(sample_size)#*standard_dev + mean\n", + "standard_dev = 2.5\n", + "random_vars = torch.randn(sample_size)*standard_dev + mean\n", "\n", "# using the random variables from above, we create a set of covariates for each subject \n", "t = torch.linspace(0, 2*math.pi, obs_time) # Generating 6 equidistant time points from 0 to 2*pi\n", @@ -168,21 +176,19 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# make random positive time to event\n", - "time = torch.floor(random_vars * 10)+4\n", - "time[time>9]=9\n", - "time[time<1]=0\n", - "#print(time) \n", + "time = torch.floor(random_vars)\n", + "# print(time) \n", "# tensor([1.2792e+01, -7.7415e+00, 9.2325e+00, 1.0845e+01, 7.6460e+00, ...\n", "\n", - "# decide who has an event, here we cosnider those whose time is greater than one (this means some a small subroup has not experienced an event)\n", - "events = time > 1\n", + "# decide who has an event; here we consider those whose time is greater than 1 and smaller than 8\n", + "events = (time > 1) & (time < 8)\n", "# tensor([ True, True, False, False, True, ...\n", - "#print(events)\n", + "# print(events)\n", "\n", "# remove the covariates for those who have observed an event\n", "\n", @@ -191,33 +197,43 @@ " time_cap = int(time[i])\n", " covars[i, time_cap:] = torch.zeros(obs_time-time_cap)\n", "\n", - "# covars should be tensor([[ 3.3737e-01, 2.5844e-01, 5.8584e-02, -1.6869e-01, -3.1702e-01, ... and zeros after an event occured\n", - "#print(covars)" + "# covars should be tensor([[ 3.3737e-01, 2.5844e-01, 5.8584e-02, -1.6869e-01, -3.1702e-01, ... 
\n", + "# and zeros after an event occurred\n", + "\n", + "# print(covars)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training the RNN \n", + "\n", + "Below we give an example set-up showing how to use the partial log likelihood in a loss function. We import the Python file containing the loss and set up an RNN to work with our simulated data." ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ - "from loss_time_covariates import _partial_likelihood_time_cox" + "from loss_time_covariates import neg_partial_time_log_likelihood" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, "outputs": [ { - "ename": "ImportError", - "evalue": "cannot import name 'time_covariates' from 'torchsurv.loss' (/home/demboso1/conda-env2/lib/python3.10/site-packages/torchsurv/loss/__init__.py)", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[160], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorchsurv\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mloss\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m time_covariates\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m#from torchsurv.metrics.cindex import ConcordanceIndex\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# Parameters\u001b[39;00m\n\u001b[1;32m 5\u001b[0m input_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n", - "\u001b[0;31mImportError\u001b[0m: cannot import name 'time_covariates' from 'torchsurv.loss' (/home/demboso1/conda-env2/lib/python3.10/site-packages/torchsurv/loss/__init__.py)" + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([10, 100, 1])\n", + 
"torch.Size([10, 100, 1])\n", + "torch.Size([2, 100, 1])\n", + "torch.Size([10, 100, 1])\n" ] } ], diff --git a/src/torchsurv/loss/time_covariates.py b/src/torchsurv/loss/time_covariates.py deleted file mode 100644 index 5f6cb59..0000000 --- a/src/torchsurv/loss/time_covariates.py +++ /dev/null @@ -1,45 +0,0 @@ -import sys -import warnings - -import torch - -def time_partial_log_likelihood( - log_hz: torch.Tensor, #nx1 vector - event: torch.Tensor, #n vector (i think) - time: torch.Tensor, #n vector (i think) - covariates: torch.Tensor, #nxp vector, p number of params -) -> torch.Tensor: - - # sort data by time-to-event or censoring - time_sorted, idx = torch.sort(time) - log_hz_sorted = log_hz[idx] - event_sorted = event[idx] - - #keep log if we can - exp_log_hz = torch.exp(log_hz_sorted) - #remove mean over time from covariates - #sort covariates so that the rows match the ordering - covariates_sorted = covariates[idx, :] - covariates.mean(dim=0) - - #the left hand side (HS) of the equation - #below is Z_k Z_k^T - i think it should be a vector matrix dim nxn - covariate_inner_product = torch.matmul(covariates_sorted, covariates_sorted.T) - - #pointwise multiplication of vectors to get the nominator of left HS - #outcome in a vector of length n - # Ends up being (1, n) - log_nominator_left = torch.matmul(exp_log_hz.T, covariate_inner_product) - - #right hand size of the equation - #formulate the brackets \sum exp(theta)Z_k - bracket = torch.mul(exp_log_hz, covariates_sorted) - nominator_right = torch.matmul(bracket, bracket.T) #nxn matrix - ###not sure if the next line is this - #log_nominator_right = torch.sum(nominator_right, dim=0).unsqueeze(0) - ### or this - log_nominator_right = nominator_right[0,].unsqueeze(0) - #the denominator is the same on both sides - log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0) #dim=0 sums over the oth dimension - partial_log_likelihood = torch.div(log_nominator_left - log_nominator_right, 
log_denominator) # (n, n) - - return (partial_log_likelihood)[event_sorted] \ No newline at end of file