Skip to content

Commit 5ca6902

Browse files
committed
Move the plotting function for prior mean bo to utils
1 parent 2286e70 commit 5ca6902

File tree

2 files changed

+100
-107
lines changed

2 files changed

+100
-107
lines changed

src/utils.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
"""Utility functions for the cheetah tutorial notebook."""
22

3+
34
import matplotlib.pyplot as plt
45
import numpy as np
6+
import torch
7+
from torch import nn
8+
from xopt import VOCS, Evaluator
59

610

711
def plot_tuning_history(history: dict, fig=None, figsize=(16, 3)):
@@ -124,3 +128,66 @@ def plot_system_identification_training(
124128
fig.savefig(save_path)
125129

126130
plt.show()
131+
132+
133+
def plot_parameter_space_difference(
134+
vocs: VOCS,
135+
evaluator: Evaluator,
136+
prior_mean_module: nn.Module,
137+
num_points: int = 50,
138+
figsize: tuple = None,
139+
) -> plt.Figure:
140+
q1 = np.linspace(
141+
vocs.variables["q1"][0], vocs.variables["q1"][1], num_points, dtype=np.float32
142+
)
143+
q2 = np.linspace(
144+
vocs.variables["q2"][0], vocs.variables["q2"][1], num_points, dtype=np.float32
145+
)
146+
147+
X, Y = np.meshgrid(q1, q2)
148+
149+
Z_problem = np.zeros_like(X)
150+
for i in range(X.shape[0]):
151+
for j in range(X.shape[1]):
152+
Z_problem[i, j] = evaluator.evaluate({"q1": X[i, j], "q2": Y[i, j]})["mae"]
153+
Z_priormean = (
154+
prior_mean_module(torch.tensor(np.stack([X, Y], axis=-1), dtype=torch.float32))
155+
.detach()
156+
.numpy()
157+
)
158+
159+
figsize = (4, 1.8) if figsize is None else figsize
160+
fig, axes = plt.subplots(1, 2, figsize=figsize)
161+
162+
v_min = min(Z_problem.min(), Z_priormean.min())
163+
v_max = max(Z_problem.max(), Z_priormean.max())
164+
165+
axes[0].contourf(X, Y, Z_problem, levels=20, vmin=v_min, vmax=v_max)
166+
axes[1].contourf(X, Y, Z_priormean, levels=20, vmin=v_min, vmax=v_max)
167+
# Mark the minimum for both plots
168+
idx_min_problem = np.unravel_index(np.argmin(Z_problem, axis=None), Z_problem.shape)
169+
axes[0].scatter(
170+
q1[idx_min_problem[1]], q2[idx_min_problem[0]], color="red", marker="x"
171+
)
172+
idx_min_priormean = np.unravel_index(
173+
np.argmin(Z_priormean, axis=None), Z_priormean.shape
174+
)
175+
axes[1].scatter(
176+
q1[idx_min_priormean[1]], q2[idx_min_priormean[0]], color="red", marker="x"
177+
)
178+
179+
axes[0].set_title("Optimization Problem")
180+
axes[1].set_title("Cheetah Prior Mean Model")
181+
182+
axes[1].set_yticks([])
183+
axes[0].set_ylabel(r"$k_{Q2}$ (1/m)")
184+
for ax in axes:
185+
ax.set_xlabel(r"$k_{Q1}$ (1/m)")
186+
187+
# Plot colorbar
188+
fig.subplots_adjust(right=0.9)
189+
cbar_ax = fig.add_axes([0.95, 0.1, 0.03, 0.8])
190+
fig.colorbar(
191+
plt.cm.ScalarMappable(cmap="viridis"), cax=cbar_ax, aspect=30, shrink=0.5
192+
)
193+
cbar_ax.set_ylabel("Beam size (mm)")

tutorial.ipynb

Lines changed: 33 additions & 107 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)