Skip to content

Commit

Permalink
small changes to both the notebook and loss files
Browse files Browse the repository at this point in the history
  • Loading branch information
SoniaDem committed Dec 17, 2024
1 parent 0d3905b commit dafdb40
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 113 deletions.
159 changes: 131 additions & 28 deletions docs/notebooks/introduction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "013dbcb4",
"metadata": {},
"outputs": [],
Expand All @@ -45,7 +45,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "2601dd00-7bd2-49d5-9bdf-a84205872890",
"metadata": {},
"outputs": [],
Expand All @@ -71,15 +71,15 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"id": "d7a98ea2-100f-43ef-8c45-c786ddcd313e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CUDA-enabled GPU/TPU is available.\n"
"No CUDA-enabled GPU found, using CPU.\n"
]
}
],
Expand Down Expand Up @@ -107,7 +107,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"id": "1df49737-dc02-4d6b-acd7-d03b79f18a29",
"metadata": {
"scrolled": true
Expand Down Expand Up @@ -225,7 +225,7 @@
"4 no 73 Post 35 II 1 26 65 772 1"
]
},
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -270,7 +270,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"id": "7a5fd9ef-2643-46b7-9c98-05ff919026ea",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -399,7 +399,7 @@
"4 0.0 1.0 0.0 "
]
},
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -416,7 +416,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"id": "0f8b7f3b-fb2a-4d74-ac99-8f6390b2f5eb",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -446,7 +446,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"id": "326c03fc-91f1-493b-a9ba-820de17fb2f8",
"metadata": {},
"outputs": [],
Expand All @@ -465,18 +465,18 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"id": "570386fb-f0ea-4061-bae2-11b274e7f851",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x (shape) = torch.Size([128, 9])\n",
"x (shape) = torch.Size([32, 9])\n",
"num_features = 9\n",
"event = torch.Size([128])\n",
"time = torch.Size([128])\n"
"event = torch.Size([32])\n",
"time = torch.Size([32])\n"
]
}
],
Expand Down Expand Up @@ -517,7 +517,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"id": "9c2bd89a-c90a-4795-aab5-b5c21906a0de",
"metadata": {},
"outputs": [],
Expand All @@ -534,6 +534,109 @@
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7d97e65d",
"metadata": {},
"outputs": [],
"source": [
"# This is for testing the loss function\n",
"x_test, (test_event, test_time) = next(iter(dataloader_train))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "5d102dad",
"metadata": {},
"outputs": [],
"source": [
"log_hz = cox_model(x_test)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "210e6755",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x_test torch.Size([32, 9])\n",
"events torch.Size([32])\n",
"times torch.Size([32])\n",
"\n",
"time_sorted torch.Size([32])\n",
"log_hz_sorted torch.Size([32, 1])\n",
"event_sorted torch.Size([32])\n",
"time_unique torch.Size([30])\n",
"------------------------------\n",
"covariates torch.Size([32, 9])\n",
"cov_inner torch.Size([32, 32])\n",
"log_nom_left torch.Size([1, 32])\n",
"bracket torch.Size([32, 9])\n",
"log_nom_right torch.Size([32, 32])\n",
"sum_nom torch.Size([1, 32])\n",
"log_denom torch.Size([1, 32])\n",
"last_bit torch.Size([1, 32])\n"
]
},
{
"data": {
"text/plain": [
"tensor([[-1.4683e+04, -2.9827e+03, -2.9461e+04, -4.0582e+04, -1.7949e+04,\n",
" -1.4714e+05, -1.4940e+03, -7.7085e+04, -5.3855e+04, -9.3090e+03,\n",
" -9.8543e+03, -1.8929e+05, -5.1617e+03, -4.4286e+03, -9.6604e+04,\n",
" -1.5469e+04, -2.7680e+04, -6.3136e+04, -1.2045e+05, -9.3347e+04,\n",
" -1.7911e+05, -1.3205e+05, -1.6203e+05, -3.0884e+04, -2.3050e+03,\n",
" -2.1324e+05, -1.7852e+06, -1.7429e+04, -2.9495e+05, -8.4400e+03,\n",
" -5.5583e+04, 1.2975e+05]], grad_fn=<DivBackward0>)"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Shapes of the mini-batch used for this manual walk-through of the loss terms.\n",
"print('x_test', x_test.shape)\n",
"print('events', test_event.shape)\n",
"print('times', test_time.shape)\n",
"\n",
"# Sort by the times of the SAME batch that log_hz was computed from\n",
"# (x_test / test_event / test_time). The bare names time / event are not defined\n",
"# in this cell and would presumably belong to an earlier batch, so a permutation\n",
"# derived from them would not line up row-for-row with log_hz and x_test.\n",
"time_sorted, idx = torch.sort(test_time)\n",
"log_hz_sorted = log_hz[idx]\n",
"event_sorted = test_event[idx]\n",
"time_unique = torch.unique(time_sorted)\n",
"print('')\n",
"print(\"time_sorted\", time_sorted.shape)\n",
"print('log_hz_sorted', log_hz_sorted.shape)\n",
"print('event_sorted', event_sorted.shape)\n",
"print(\"time_unique\", time_unique.shape)\n",
"\n",
"print('-'*30)\n",
"# Work on a copy of the covariates so x_test itself is left untouched.\n",
"cov_fake = torch.clone(x_test)\n",
"print('covariates', cov_fake.shape)\n",
"# Reorder the covariate rows with the same permutation as the times.\n",
"covariates_sorted = cov_fake[idx, :]\n",
"# Gram matrix: inner product between every pair of sorted covariate rows.\n",
"covariate_inner_product = torch.matmul(covariates_sorted, covariates_sorted.T)\n",
"print('cov_inner', covariate_inner_product.shape)\n",
"# Transposed sorted log hazards multiplied into the Gram matrix.\n",
"log_nominator_left = torch.matmul(log_hz_sorted.T, covariate_inner_product)\n",
"print('log_nom_left', log_nominator_left.shape)\n",
"# Scale each covariate row by its log hazard (broadcast over the feature axis).\n",
"bracket = torch.mul(log_hz_sorted, covariates_sorted)\n",
"print('bracket', bracket.shape)\n",
"# Pairwise inner products of the scaled rows, then summed over dim 0 and kept 2-D.\n",
"log_nominator_right = torch.matmul(bracket, bracket.T)\n",
"print('log_nom_right', log_nominator_right.shape)\n",
"sum_nominator_right = torch.sum(log_nominator_right, dim=0).unsqueeze(0)\n",
"print('sum_nom', sum_nominator_right.shape)\n",
"# Reverse cumulative log-sum-exp: entry i aggregates log_hz from position i to the\n",
"# end of the time-sorted batch (flip, accumulate, flip back).\n",
"log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0).T\n",
"print('log_denom', log_denominator.shape)\n",
"# Element-wise ratio of the nominator difference to the log denominator.\n",
"last_bit = torch.div(log_nominator_left - sum_nominator_right, log_denominator)\n",
"print('last_bit', last_bit.shape)\n",
"last_bit\n"
]
},
{
"cell_type": "markdown",
"id": "97c90244",
Expand All @@ -544,24 +647,24 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 15,
"id": "d7889dc1-1cfa-424e-a586-481cbc789581",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 000, Training loss: 12.75\n",
"Epoch: 010, Training loss: 12.02\n",
"Epoch: 020, Training loss: 11.79\n",
"Epoch: 030, Training loss: 11.84\n",
"Epoch: 040, Training loss: 11.61\n",
"Epoch: 050, Training loss: 11.61\n",
"Epoch: 060, Training loss: 11.46\n",
"Epoch: 070, Training loss: 11.57\n",
"Epoch: 080, Training loss: 11.56\n",
"Epoch: 090, Training loss: 11.20\n"
"Epoch: 000, Training loss: 31.85\n",
"Epoch: 010, Training loss: 30.18\n",
"Epoch: 020, Training loss: 29.73\n",
"Epoch: 030, Training loss: 29.84\n",
"Epoch: 040, Training loss: 29.04\n",
"Epoch: 050, Training loss: 29.61\n",
"Epoch: 060, Training loss: 29.46\n",
"Epoch: 070, Training loss: 28.94\n",
"Epoch: 080, Training loss: 29.31\n",
"Epoch: 090, Training loss: 28.00\n"
]
}
],
Expand Down Expand Up @@ -1183,7 +1286,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "torchsurv_env",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
Expand All @@ -1197,7 +1300,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.10.15"
}
},
"nbformat": 4,
Expand Down
12 changes: 2 additions & 10 deletions docs/notebooks/loss_time_covariates.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,11 @@ def neg_partial_time_log_likelihood(
event: torch.Tensor,
time: torch.Tensor,
ties_method: str = "efron",
reduction: str = "mean",
checks: bool = True,
reduction: str = "mean"
) -> torch.Tensor:
'''
WARNING: this function is not finished -- the negative part has not been tested yet
and it needs further work.
'''
if checks:
_check_inputs(log_hz, event, time)

if any([event.sum() == 0, len(log_hz.size()) == 0]):
warnings.warn("No events OR single sample. Returning zero loss for the batch")
return torch.tensor(0.0, requires_grad=True)

# sort data by time-to-event or censoring
time_sorted, idx = torch.sort(time)
log_hz_sorted = log_hz[idx]
Expand Down
Loading

0 comments on commit dafdb40

Please sign in to comment.