outcome simulation parameters improved
SoniaDem committed Jan 3, 2025
1 parent 8c754bc commit 8a81ae3
Showing 2 changed files with 115 additions and 83 deletions.
2 changes: 1 addition & 1 deletion docs/notebooks/introduction.ipynb
@@ -1286,7 +1286,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "conda-env2",
"language": "python",
"name": "python3"
},
196 changes: 114 additions & 82 deletions docs/notebooks/time_varying.ipynb
@@ -71,7 +71,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -84,7 +84,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -95,7 +95,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -118,17 +118,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simulating a dataset: first approach to test dimensions but doesn't guarantee meaningful results\n",
"\n",
"We will simulate a dataset of 100 subjects with 10 follow up times where a covariate is observed. The covariates will follow a trigonometric function over time and will be dependant on a random variable to differentiate between subjects.\n",
"\n",
"For each $i$ the covariate follows the function:\n",
"\n",
"$$ Z_i(t) = a_i \\cos(2 \\pi t) $$\n",
"\n",
"where $a_i \\sim N(5, 2.5)$.\n",
"\n",
"## Proper simulation guidance: data that can be interpreted\n",
"## Simulating realistic data\n",
"\n",
"A good approach for simulating data is described in detail by [Ngwa et al 2020](https://pmc.ncbi.nlm.nih.gov/articles/PMC7731987/). If this is not yet implemented, it would be a good way of starting to ensure that both methods work as expected. There are tow parts in simulating such a dataset. First, simulating the longitudina lobservational data and then the survival data. Below we describe methodologies for both.\n",
"\n",
@@ -151,12 +141,71 @@
"\n",
"$$ V = Z_i GZ_i ^T + R_i, \\text{ where }Z_i = [[1,1,1,1,1,1]^T, [0,5,10,15,20,25]^T]$$\n",
"\n",
"and $R_i = diag(\\sigma^2)$ and $\\sigma^2$ is set to $0.1161$."
"and $R_i = diag(\\sigma^2)$ and $\\sigma^2$ is set to $0.1161$.\n",
"\n",
"Note: Compared to the paper, we slightly adjust steps 3 and 4 from the simulation algorithm section (6.1) to avoid fitting a random effects model which adds more complexity in terms of data formatting. "
]
},
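{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, the marginal covariance $V = Z_i G Z_i^T + R_i$ implied by the values above can be computed directly. The cell below is only an illustrative sketch: `G`, `Z` and the residual variance $\sigma^2 = 0.1161$ are re-declared from the quantities quoted in the text so that it runs on its own."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Random-effects covariance G and design matrix Z_i as defined in the text\n",
"G = torch.tensor([[0.29, -0.00465], [-0.00465, 0.000320]])\n",
"Z = torch.tensor([[1.0, 0.0], [1.0, 5.0], [1.0, 10.0], [1.0, 15.0], [1.0, 20.0], [1.0, 25.0]])\n",
"sigma2 = 0.1161  # residual variance\n",
"\n",
"# Marginal covariance of the six repeated measurements for one subject\n",
"V = Z @ G @ Z.T + sigma2 * torch.eye(6)\n",
"print(V)"
]
},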
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[34.2016, 34.2866, 34.3716, 34.4566, 34.5416, 34.6266],\n",
" [33.4380, 33.4018, 33.3657, 33.3295, 33.2933, 33.2572],\n",
" [31.5581, 31.5498, 31.5415, 31.5332, 31.5248, 31.5165],\n",
" [35.7813, 35.8513, 35.9212, 35.9912, 36.0611, 36.1310]])\n"
]
}
],
"source": [
"import torch.distributions as dist\n",
"\n",
"# Set random seed for reproducibility\n",
"torch.manual_seed(123)\n",
"\n",
"n = 100 # Number of subjects\n",
"T = 6 # Number of time points\n",
"time_vec = torch.tensor([0, 5, 10, 15, 20, 25])\n",
"\n",
"# Simulation parameters\n",
"age_mean = 35\n",
"age_std = 5\n",
"sex_prob = 0.54\n",
"G = torch.tensor([[0.29, -0.00465],[-0.00465, 0.000320]])\n",
"Z = torch.tensor([[1, 1, 1, 1, 1, 1], time_vec], dtype=torch.float32).T\n",
"sigma = torch.tensor([0.1161])\n",
"alpha = 1\n",
"\n",
"# Simulate age at baseline\n",
"age_dist = dist.Normal(age_mean, age_std)\n",
"age = age_dist.sample((n,))\n",
"\n",
"# Simulate sex\n",
"sex_dist = dist.Bernoulli(probs=sex_prob)\n",
"sex = sex_dist.sample((n,))\n",
"\n",
"# Simulate random effects\n",
"random_effects_dist = dist.MultivariateNormal(torch.zeros(2), G)\n",
"random_effects = random_effects_dist.sample((n,))\n",
"\n",
"# sample random error\n",
"error_sample = dist.Normal(0, sigma).sample((n,))\n",
"\n",
"# Generate expected longitudinal trajectories\n",
"# quite frakly this is useless now - it was based on my bad understanding of the algorithm\n",
"trajectories = random_effects[:, 0].unsqueeze(1) + random_effects[:, 1].unsqueeze(1) * Z[:,1] + alpha * age.unsqueeze(1) + error_sample\n",
"\n",
"print(trajectories[1:5, :])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -178,13 +227,14 @@
"\n",
"n = 100 # Number of subjects\n",
"T = 6 # Number of time points\n",
"time_vec = torch.tensor([0, 5, 10, 15, 20, 25])\n",
"\n",
"# Simulation parameters\n",
"age_mean = 35\n",
"age_std = 5\n",
"sex_prob = 0.54\n",
"G = torch.tensor([[0.29, -0.00465],[-0.00465, 0.000320]])\n",
"Z = torch.tensor([[1, 1, 1, 1, 1, 1], [0, 5, 10, 15, 20, 25]], dtype=torch.float32).T\n",
"Z = torch.tensor([[1, 1, 1, 1, 1, 1], time_vec], dtype=torch.float32).T\n",
"sigma = torch.tensor([0.1161])\n",
"alpha = 1\n",
"\n",
@@ -244,14 +294,14 @@
"Generate the time-to-event $\\tau_j$ using the following equations for the Cox Exponential model:\n",
"$$ t = \\frac{1}{\\gamma \\cdot b_{i2}} W \\Big( \\frac{-\\gamma(b_{i2}) \\log(Q)}{\\lambda \\exp (X^T \\alpha + \\gamma(b_{i1}))} \\Big). $$\n",
"\n",
"Where $W$ is the Lambert W function (LWF) first proposed by [Corless et al. 1996](https://link.springer.com/article/10.1007/BF02124750) provide a history, theory and applications of the LWF. The LWF also known as Omega function is the inverse of the function $f(p) = p \\cdot \\exp(p) $.\n",
"Where $W$ is the Lambert W function (LWF) first proposed by [Corless et al. 1996](https://link.springer.com/article/10.1007/BF02124750) provide a history, theory and applications of the LWF. The LWF is the inverse of the function $f(p) = p \\cdot \\exp(p) $.\n",
"\n",
"Generate the censoring variable $C \\sim Unif⁡(25, 30)$ for censoring to occur later in study. From the survival and censoring times, we obtain the censoring indicator $\\delta_i$ which is defined as 1 if $\\tau_j < C_i$ and 0 otherwise.\n"
]
},
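{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick illustrative check (not part of the simulation), the inverse relation $W(x) \exp(W(x)) = x$ can be verified numerically with `scipy.special.lambertw`, which returns the principal branch as a complex value."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.special import lambertw\n",
"\n",
"x = np.array([0.5, 1.0, 2.0, 5.0])\n",
"w = lambertw(x)  # principal branch, complex dtype\n",
"print(w * np.exp(w))  # recovers x (imaginary parts are ~0)"
]
},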
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -260,6 +310,17 @@
"from scipy.special import lambertw"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: pre-determined parameters such as $\\alpha, \\gamma, \\lambda_0$ have a large effect on the event time outcomes, the values used here are:\n",
"- $\\alpha_{age} = 0.05$,\n",
"- $\\alpha_{sex} = -0.5$,\n",
"- $\\gamma = 0.1$,\n",
"- $\\lambda_0 = 0.05$\n"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -268,50 +329,44 @@
{
"data": {
"text/plain": [
"tensor([2.6643e-08+0.j, 7.9076e-08+0.j, 1.0829e-07+0.j, 2.2052e-07+0.j, 2.5905e-08+0.j,\n",
" 9.7290e-09+0.j, 8.7800e-08+0.j, 1.5979e-07+0.j, 7.4499e-09+0.j, 1.2207e-06+0.j,\n",
" 8.8470e-09+0.j, 1.7203e-06+0.j, 3.2379e-08+0.j, 4.4401e-10+0.j, 1.7946e-08+0.j,\n",
" 2.5867e-08+0.j, 7.5077e-08+0.j, 7.9405e-07+0.j, 9.4477e-08+0.j, 6.2436e-09+0.j,\n",
" 9.7389e-10+0.j, 1.9608e-09+0.j, 2.0911e-09+0.j, 1.0234e-07+0.j, 1.7900e-08+0.j,\n",
" 3.1549e-06+0.j, 4.5246e-09+0.j, 8.9455e-07+0.j, 9.1886e-07+0.j, 2.7448e-08+0.j,\n",
" 3.0028e-07+0.j, 7.5052e-05+0.j, 1.7000e-06+0.j, 3.9642e-08+0.j, 1.2840e-05+0.j,\n",
" 4.4597e-09+0.j, 1.6356e-07+0.j, 5.2418e-07+0.j, 1.7659e-07+0.j, 7.6708e-07+0.j,\n",
" 6.8524e-09+0.j, 1.0414e-09+0.j, 1.2928e-08+0.j, 7.8624e-08+0.j, 3.9656e-08+0.j,\n",
" 6.4054e-08+0.j, 3.1831e-08+0.j, 2.1889e-07+0.j, 5.6625e-09+0.j, 8.7863e-10+0.j,\n",
" 2.0402e-06+0.j, 2.9838e-09+0.j, 1.0714e-06+0.j, 1.8990e-09+0.j, 4.4492e-09+0.j,\n",
" 4.3770e-08+0.j, 2.2484e-10+0.j, 3.7806e-07+0.j, 1.2497e-07+0.j, 9.0410e-09+0.j,\n",
" 3.2210e-07+0.j, 4.5522e-07+0.j, 3.2021e-08+0.j, 2.8577e-08+0.j, 1.7105e-08+0.j,\n",
" 5.1632e-09+0.j, 9.6815e-08+0.j, 9.2736e-07+0.j, 6.9194e-08+0.j, 1.0694e-09+0.j,\n",
" 2.0602e-05+0.j, 3.3414e-09+0.j, 9.4600e-09+0.j, 1.0943e-08+0.j, 6.6366e-07+0.j,\n",
" 7.4981e-08+0.j, 7.4259e-08+0.j, 3.7868e-08+0.j, 2.0600e-08+0.j, 2.3038e-06+0.j,\n",
" 2.6782e-08+0.j, 2.4470e-07+0.j, 4.1181e-09+0.j, 6.4102e-09+0.j, 6.7062e-08+0.j,\n",
" 2.0507e-07+0.j, 7.4554e-09+0.j, 5.7818e-08+0.j, 5.4998e-08+0.j, 6.2367e-07+0.j,\n",
" 8.7814e-07+0.j, 2.0683e-05+0.j, 1.0754e-05+0.j, 4.9900e-08+0.j, 7.3498e-07+0.j,\n",
" 1.0718e-08+0.j, 5.9085e-07+0.j, 3.0981e-08+0.j, 1.0097e-08+0.j, 5.8124e-07+0.j],\n",
" dtype=torch.complex128)"
"tensor([ 9.6428, 6.8019, 7.1837, 8.1690, 10.6510, 5.2226, 7.0858, 11.3846,\n",
" 5.4684, 14.3864, 5.1831, 9.3314, 6.3816, 3.8954, 10.0959, 6.0119,\n",
" 11.7046, 12.7777, 10.7462, 8.4370, 4.6285, 4.8617, 4.3450, 10.8670,\n",
" 10.1935, 16.2546, 5.4758, 8.5248, 9.2135, 9.4407, 11.7310, 21.0234,\n",
" 14.0767, 9.1752, 18.8326, 8.9085, 11.2594, 8.9873, 7.5456, 8.4984,\n",
" 9.0333, 4.8472, 9.4688, 7.7191, 6.2192, 6.4989, 9.8902, 7.8185,\n",
" 5.2405, 4.2516, 9.3067, 5.0147, 8.3767, 4.9315, 8.5749, 11.3669,\n",
" 6.0864, 7.5788, 11.8391, 8.8440, 12.2118, 13.6110, 6.2863, 5.8571,\n",
" 9.5126, 8.6607, 6.8886, 15.5586, 10.6941, 7.2345, 18.2753, 5.4170,\n",
" 5.2679, 9.0509, 12.9154, 11.2252, 7.4939, 6.5494, 10.3731, 14.2850,\n",
" 5.7533, 12.2423, 5.6055, 5.2892, 11.0855, 11.8667, 5.5114, 11.4350,\n",
" 10.3182, 12.8253, 14.6775, 19.0688, 17.0049, 6.3822, 14.5267, 8.7058,\n",
" 8.2680, 10.7909, 5.2648, 12.7710], dtype=torch.float64)"
]
},
"execution_count": 38,
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Specify the values for parameters, generate the random variables and call on relevant variables defined previously\n",
"\n",
"alpha = torch.tensor([0.5, -0.2]) # regression coefficient for time-invariant covariates\n",
"gamma = torch.tensor(0.3) # association strength between longitudinal measures and time-to-event\n",
"lambda_0 = torch.tensor(0.1) # baseline hazard rate\n",
"alpha = torch.tensor([0.05, -0.5]) # regression coefficient for time-invariant covariates\n",
"gamma = torch.tensor(0.1) # association strength between longitudinal measures and time-to-event\n",
"lambda_0 = torch.tensor(0.05) # baseline hazard rate\n",
"\n",
"# Generate the random variables for hazard of a subject and censoring\n",
"Q = dist.Uniform(0, 1).sample() # Random variable for hazard (Q)\n",
"C = dist.Uniform(25, 30).sample() # Random variable for censoring\n",
"C = dist.Uniform(20, 30).sample() # Random variable for censoring\n",
"\n",
"# age and sex are the names of variables corresponding to those covariates\n",
"# create the X matrix of covariates\n",
"XX = torch.stack((age, sex), dim=1)\n",
"\n",
"# b1 = torch.tensor([4.250]), b2 = torch.tensor([0.250])\n",
"# get b1 and b2 from the random sample we made before\n",
"b1 = random_effects[:, 0]\n",
"b2 = random_effects[:, 1]\n",
"\n",
"# Generate time to event T using the equation above\n",
"log_Q = torch.log(Q)\n",
@@ -321,9 +376,20 @@
"lambert_W = lambertw(-lambert_W_nominator/(lambda_0*lambert_W_denominator))\n",
"time_to_event = lambert_W/(gamma*b2)\n",
"\n",
"#take the real part of the LBF, the complex part is =0\n",
"outcome_LWF = time_to_event.real\n",
"\n",
"# implement censoring with some level of intensity\n",
"time_to_event\n",
"#needs to be scaled and floored to be a reasonable time to event I think "
"outcome_LWF\n"
]
},
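{
"cell_type": "markdown",
"metadata": {},
"source": [
"Following the description above, the censoring indicator $\delta_i$ can be obtained by comparing each simulated event time with a censoring time $C_i \sim \text{Unif}(25, 30)$. The cell below is only a sketch: it assumes `outcome_LWF` from the previous cell holds the $n$ event times, and it draws one censoring time per subject rather than a single shared draw."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One censoring time per subject (assumption: independent Unif(25, 30) censoring)\n",
"C = dist.Uniform(25, 30).sample((n,)).to(torch.float64)\n",
"\n",
"# Event times from the previous cell, cast to a float tensor\n",
"event_time = torch.as_tensor(outcome_LWF, dtype=torch.float64)\n",
"\n",
"# delta_i = 1 if the event is observed before censoring, 0 otherwise\n",
"delta = (event_time < C).int()\n",
"\n",
"# Observed follow-up time is the earlier of the event and censoring times\n",
"observed_time = torch.minimum(event_time, C)\n",
"\n",
"print(delta[:10], observed_time[:10])"
]
},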
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A simpler method for generating the time-to-event where the covariate is assumed to have a more straightforward relation in time $Z(t) = kt$ for some $k>0$. This approach is sugested by [Peter C. Austin 2012](https://pmc.ncbi.nlm.nih.gov/articles/PMC3546387/pdf/sim0031-3946.pdf) and here \n",
"$$ t = \\frac{1}{\\gamma k} \\log \\Big ( 1 + \\frac{\\gamma k (-log(u))}{\\lambda \\exp(\\alpha X)}\\Big). $$\n",
"The above equation has been adapted to remain consistent with the parameters defined before. In our case, $k$ could be replaced with $b_{i2}$ if $b_{i2}$ would be sampled such that it is strictly positive. In the above configuration that is not the case."
]
},
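{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of that closed-form generator is given below. It assumes a fixed, strictly positive slope $k$ (the value `k = 0.25` is purely illustrative) and reuses the `age`/`sex` covariates and the `alpha`, `gamma` and `lambda_0` values defined earlier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch of the simpler generator with Z(t) = k*t (k assumed positive)\n",
"k = torch.tensor(0.25)  # illustrative slope, not taken from the paper\n",
"u = dist.Uniform(0, 1).sample((n,))  # one uniform draw per subject\n",
"\n",
"# Linear predictor of the time-invariant covariates, as before\n",
"lin_pred = torch.exp(XX @ alpha)\n",
"\n",
"# t = (1 / (gamma * k)) * log(1 + gamma * k * (-log(u)) / (lambda_0 * exp(X^T alpha)))\n",
"t_simple = torch.log(1 + (gamma * k * (-torch.log(u))) / (lambda_0 * lin_pred)) / (gamma * k)\n",
"\n",
"print(t_simple[:10])"
]
},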
{
@@ -343,40 +409,6 @@
"- impute based on some model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# CODE BELOW IS OLD, WILL REMOVE SOON\n",
"# # make random positive time to event\n",
"# time = torch.floor(random_vars)\n",
"# # this is a workaround the loss function. This is done so that when we find the right\n",
"# # indices in the log_hz we don't try to pick up things that are out of bounds.\n",
"# time[time<0] = 0\n",
"# time[time>9] = 9\n",
"# # print(time) \n",
"# # tensor([1.2792e+01, -7.7415e+00, 9.2325e+00, 1.0845e+01, 7.6460e+00, ...\n",
"\n",
"# # decide who has an event, here we cosnider those whose time is greater than one and smaller than 9\n",
"# events = (time > 1) & (time < 8)\n",
"# # tensor([ True, True, False, False, True, ...\n",
"# # print(events)\n",
"\n",
"# # remove the covariates for those who have observed an event\n",
"\n",
"# for i in range(sample_size):\n",
"# if events[i]==True:\n",
"# time_cap = int(time[i])\n",
"# covars[i, time_cap:] = torch.zeros(obs_time-time_cap)\n",
"\n",
"# # covars should be tensor([[ 3.3737e-01, 2.5844e-01, 5.8584e-02, -1.6869e-01, -3.1702e-01, ... \n",
"# # and zeros after an event occured\n",
"\n",
"# # print(covars)"
]
},
{
"cell_type": "markdown",
"metadata": {},
