|
1 | 1 | """Utility functions for the cheetah tutorial notebook."""
|
2 | 2 |
|
| 3 | + |
3 | 4 | import matplotlib.pyplot as plt
|
4 | 5 | import numpy as np
|
| 6 | +import torch |
| 7 | +from torch import nn |
| 8 | +from xopt import VOCS, Evaluator |
5 | 9 |
|
6 | 10 |
|
7 | 11 | def plot_tuning_history(history: dict, fig=None, figsize=(16, 3)):
|
@@ -124,3 +128,66 @@ def plot_system_identification_training(
|
124 | 128 | fig.savefig(save_path)
|
125 | 129 |
|
126 | 130 | plt.show()
|
| 131 | + |
| 132 | + |
| 133 | +def plot_parameter_space_difference( |
| 134 | + vocs: VOCS, |
| 135 | + evaluator: Evaluator, |
| 136 | + prior_mean_module: nn.Module, |
| 137 | + num_points: int = 50, |
| 138 | + figsize: tuple = None, |
| 139 | +) -> plt.Figure: |
| 140 | + q1 = np.linspace( |
| 141 | + vocs.variables["q1"][0], vocs.variables["q1"][1], num_points, dtype=np.float32 |
| 142 | + ) |
| 143 | + q2 = np.linspace( |
| 144 | + vocs.variables["q2"][0], vocs.variables["q2"][1], num_points, dtype=np.float32 |
| 145 | + ) |
| 146 | + |
| 147 | + X, Y = np.meshgrid(q1, q2) |
| 148 | + |
| 149 | + Z_problem = np.zeros_like(X) |
| 150 | + for i in range(X.shape[0]): |
| 151 | + for j in range(X.shape[1]): |
| 152 | + Z_problem[i, j] = evaluator.evaluate({"q1": X[i, j], "q2": Y[i, j]})["mae"] |
| 153 | + Z_priormean = ( |
| 154 | + prior_mean_module(torch.tensor(np.stack([X, Y], axis=-1), dtype=torch.float32)) |
| 155 | + .detach() |
| 156 | + .numpy() |
| 157 | + ) |
| 158 | + |
| 159 | + figsize = (4, 1.8) if figsize is None else figsize |
| 160 | + fig, axes = plt.subplots(1, 2, figsize=figsize) |
| 161 | + |
| 162 | + v_min = min(Z_problem.min(), Z_priormean.min()) |
| 163 | + v_max = max(Z_problem.max(), Z_priormean.max()) |
| 164 | + |
| 165 | + axes[0].contourf(X, Y, Z_problem, levels=20, vmin=v_min, vmax=v_max) |
| 166 | + axes[1].contourf(X, Y, Z_priormean, levels=20, vmin=v_min, vmax=v_max) |
| 167 | + # Mark the minimum for both plots |
| 168 | + idx_min_problem = np.unravel_index(np.argmin(Z_problem, axis=None), Z_problem.shape) |
| 169 | + axes[0].scatter( |
| 170 | + q1[idx_min_problem[1]], q2[idx_min_problem[0]], color="red", marker="x" |
| 171 | + ) |
| 172 | + idx_min_priormean = np.unravel_index( |
| 173 | + np.argmin(Z_priormean, axis=None), Z_priormean.shape |
| 174 | + ) |
| 175 | + axes[1].scatter( |
| 176 | + q1[idx_min_priormean[1]], q2[idx_min_priormean[0]], color="red", marker="x" |
| 177 | + ) |
| 178 | + |
| 179 | + axes[0].set_title("Optimization Problem") |
| 180 | + axes[1].set_title("Cheetah Prior Mean Model") |
| 181 | + |
| 182 | + axes[1].set_yticks([]) |
| 183 | + axes[0].set_ylabel(r"$k_{Q2}$ (1/m)") |
| 184 | + for ax in axes: |
| 185 | + ax.set_xlabel(r"$k_{Q1}$ (1/m)") |
| 186 | + |
| 187 | + # Plot colorbar |
| 188 | + fig.subplots_adjust(right=0.9) |
| 189 | + cbar_ax = fig.add_axes([0.95, 0.1, 0.03, 0.8]) |
| 190 | + fig.colorbar( |
| 191 | + plt.cm.ScalarMappable(cmap="viridis"), cax=cbar_ax, aspect=30, shrink=0.5 |
| 192 | + ) |
| 193 | + cbar_ax.set_ylabel("Beam size (mm)") |
0 commit comments