|
48 | 48 | "n = 5000\n",
|
49 | 49 | "p_X = 5\n",
|
50 | 50 | "X = rng.uniform(0, 1, (n, p_X))\n",
|
51 |
| - "pi_X = 0.25 + 0.5*X[:,0]\n", |
| 51 | + "pi_X = np.c_[0.25 + 0.5*X[:,0], 0.75 - 0.5*X[:,1]]\n", |
52 | 52 | "# Z = rng.uniform(0, 1, (n, 2))\n",
|
53 |
| - "Z = rng.binomial(1, 0.5, (n, 2))\n", |
| 53 | + "Z = rng.binomial(1, pi_X, (n, 2))\n", |
54 | 54 | "\n",
|
55 | 55 | "# Define the outcome mean functions (prognostic and treatment effects)\n",
|
56 |
| - "mu_X = pi_X*5 + 2*X[:,2]\n", |
| 56 | + "mu_X = pi_X[:,0]*5 + pi_X[:,1]*2 + 2*X[:,2]\n", |
57 | 57 | "tau_X = np.stack((X[:,1], X[:,2]), axis=-1)\n",
|
58 | 58 | "\n",
|
59 | 59 | "# Generate outcome\n",
|
|
129 | 129 | "plt.show()"
|
130 | 130 | ]
|
131 | 131 | },
|
| 132 | + { |
| 133 | + "cell_type": "code", |
| 134 | + "execution_count": null, |
| 135 | + "metadata": {}, |
| 136 | + "outputs": [], |
| 137 | + "source": [ |
| 138 | + "np.sqrt(np.mean(np.power(y_avg_mcmc - y_test, 2)))" |
| 139 | + ] |
| 140 | + }, |
132 | 141 | {
|
133 | 142 | "cell_type": "code",
|
134 | 143 | "execution_count": null,
|
|
144 | 153 | "plt.show()"
|
145 | 154 | ]
|
146 | 155 | },
|
| 156 | + { |
| 157 | + "cell_type": "code", |
| 158 | + "execution_count": null, |
| 159 | + "metadata": {}, |
| 160 | + "outputs": [], |
| 161 | + "source": [ |
| 162 | + "treatment_idx = 1\n", |
| 163 | + "forest_preds_tau_mcmc = np.squeeze(bcf_model.tau_hat_test[:,:,treatment_idx])\n", |
| 164 | + "tau_avg_mcmc = np.squeeze(forest_preds_tau_mcmc).mean(axis = 1, keepdims = True)\n", |
| 165 | + "tau_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(tau_test[:,treatment_idx],1), tau_avg_mcmc), axis = 1), columns=[\"True tau\", \"Average estimated tau\"])\n", |
| 166 | + "sns.scatterplot(data=tau_df_mcmc, x=\"True tau\", y=\"Average estimated tau\")\n", |
| 167 | + "plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n", |
| 168 | + "plt.show()" |
| 169 | + ] |
| 170 | + }, |
147 | 171 | {
|
148 | 172 | "cell_type": "code",
|
149 | 173 | "execution_count": null,
|
|
0 commit comments