Skip to content

Commit 9aac95c

Browse files
authored
Merge pull request #51 from StochasticTree/multivariate_example_fix
Updated multivariate treatment python demo to be an observational study
2 parents 99b33b6 + 8a9fbbb commit 9aac95c

File tree

1 file changed

+27
-3
lines changed

1 file changed

+27
-3
lines changed

demo/notebooks/multivariate_treatment_causal_inference.ipynb

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,12 @@
4848
"n = 5000\n",
4949
"p_X = 5\n",
5050
"X = rng.uniform(0, 1, (n, p_X))\n",
51-
"pi_X = 0.25 + 0.5*X[:,0]\n",
51+
"pi_X = np.c_[0.25 + 0.5*X[:,0], 0.75 - 0.5*X[:,1]]\n",
5252
"# Z = rng.uniform(0, 1, (n, 2))\n",
53-
"Z = rng.binomial(1, 0.5, (n, 2))\n",
53+
"Z = rng.binomial(1, pi_X, (n, 2))\n",
5454
"\n",
5555
"# Define the outcome mean functions (prognostic and treatment effects)\n",
56-
"mu_X = pi_X*5 + 2*X[:,2]\n",
56+
"mu_X = pi_X[:,0]*5 + pi_X[:,1]*2 + 2*X[:,2]\n",
5757
"tau_X = np.stack((X[:,1], X[:,2]), axis=-1)\n",
5858
"\n",
5959
"# Generate outcome\n",
@@ -129,6 +129,15 @@
129129
"plt.show()"
130130
]
131131
},
132+
{
133+
"cell_type": "code",
134+
"execution_count": null,
135+
"metadata": {},
136+
"outputs": [],
137+
"source": [
138+
"np.sqrt(np.mean(np.power(y_avg_mcmc - y_test, 2)))"
139+
]
140+
},
132141
{
133142
"cell_type": "code",
134143
"execution_count": null,
@@ -144,6 +153,21 @@
144153
"plt.show()"
145154
]
146155
},
156+
{
157+
"cell_type": "code",
158+
"execution_count": null,
159+
"metadata": {},
160+
"outputs": [],
161+
"source": [
162+
"treatment_idx = 1\n",
163+
"forest_preds_tau_mcmc = np.squeeze(bcf_model.tau_hat_test[:,:,treatment_idx])\n",
164+
"tau_avg_mcmc = np.squeeze(forest_preds_tau_mcmc).mean(axis = 1, keepdims = True)\n",
165+
"tau_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(tau_test[:,treatment_idx],1), tau_avg_mcmc), axis = 1), columns=[\"True tau\", \"Average estimated tau\"])\n",
166+
"sns.scatterplot(data=tau_df_mcmc, x=\"True tau\", y=\"Average estimated tau\")\n",
167+
"plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n",
168+
"plt.show()"
169+
]
170+
},
147171
{
148172
"cell_type": "code",
149173
"execution_count": null,

0 commit comments

Comments
 (0)