Skip to content

Commit

Permalink
generalized guess_common_aberrations
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Jan 17, 2025
1 parent d7ebd28 commit e6ebf70
Showing 1 changed file with 57 additions and 46 deletions.
103 changes: 57 additions & 46 deletions py4DSTEM/process/phase/parallax.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,12 +878,7 @@ def guess_common_aberrations(
kde_sigma_px=0.125,
kde_lowpass_filter=False,
lanczos_interpolation_order=None,
defocus=0,
astigmatism=0,
astigmatism_angle_deg=0,
coma=0,
coma_angle_deg=0,
spherical_aberration=0,
polar_parameters=None,
max_batch_size=None,
interpolation_max_batch_size=None,
plot_shifts_and_aligned_bf=True,
Expand All @@ -910,18 +905,9 @@ def guess_common_aberrations(
If True, the resulting KDE upsampled image is lowpass-filtered using a sinc-function
lanczos_interpolation_order: int, optional
If not None, Lanczos interpolation with the specified order is used instead of bilinear
defocus: float, optional
Defocus value to use in computing analytical BF shifts
astigmatism: float, optional
Astigmatism value to use in computing analytical BF shifts
astigmatism_angle_deg: float, optional
Astigmatism angle to use in computing analytical BF shifts
coma: float, optional
Coma value to use in computing analytical BF shifts
coma_angle_deg: float, optional
Coma angle to use in computing analytical BF shifts
spherical_aberration: float, optional
Spherical aberration value to use in computing analytical BF shifts
polar_parameters: dict, optional
Mapping from aberration symbols to their corresponding values. All aberration
magnitudes should be given in Å and angles should be given in radians.
max_batch_size: int, optional
Max number of virtual BF images to use at once in computing cross-correlation
plot_shifts_and_aligned_bf: bool, optional
Expand All @@ -934,7 +920,9 @@ def guess_common_aberrations(
Scale to multiply shifts by
interpolation_max_batch_size: int, optional
Max number of pixels to use at once in upsampling (max is #BF * upsampled_pixels)
kwargs:
Provide the aberration coefficients as keyword arguments.
kwargs which are not recognized as aberrations will be passed to the visualizations
"""
xp = self._xp
asnumpy = self._asnumpy
Expand All @@ -947,29 +935,52 @@ def guess_common_aberrations(
)
)

# aberrations_coefs
aberrations_mn = [
[1, 0, 0],
[1, 2, 0],
[1, 2, 1],
[2, 1, 0],
[2, 1, 1],
[3, 0, 0],
]
astigmatism_x = astigmatism * np.cos(np.deg2rad(astigmatism_angle_deg) * 2)
astigmatism_y = astigmatism * np.sin(np.deg2rad(astigmatism_angle_deg) * 2)
coma_x = coma * np.cos(np.deg2rad(coma_angle_deg) * 1)
coma_y = coma * np.sin(np.deg2rad(coma_angle_deg) * 1)
aberrations_coefs = xp.array(
[
-defocus,
astigmatism_x,
astigmatism_y,
coma_x,
coma_y,
spherical_aberration,
]
)
aberrations_kwargs = kwargs.copy()
plotting_kwargs = {}

for key, val in kwargs.items():
if (key not in polar_symbols) and (key not in polar_aliases.keys()):
aberrations_kwargs.pop(key)
plotting_kwargs[key] = val

if polar_parameters is None:
polar_parameters = {}

polar_dict = self._polar_parameters.copy() # copy init dict
polar_dict |= polar_parameters # merge with passed dict

for symbol, value in aberrations_kwargs.items():
if symbol in polar_dict.keys():
polar_dict[symbol] = value
elif symbol == "defocus":
polar_dict[polar_aliases[symbol]] = -value
elif symbol in polar_aliases.keys():
polar_dict[polar_aliases[symbol]] = value
else:
raise ValueError("{} not a recognized parameter".format(symbol))

aberrations_mn = []
for symbol in polar_symbols:
if symbol[0] == "C":
if np.abs(polar_dict[symbol]) > 0:
_, m, n = symbol
aberrations_mn.append([int(m), int(n), 0])
if int(n) > 0:
aberrations_mn.append([int(m), int(n), 1])

coeffs = []
for m, n, a in aberrations_mn:
mag = polar_dict.get(f"C{m}{n}")
if n == 0:
coeffs.append(mag)
else:
angle = polar_dict.get(f"phi{m}{n}")
if a == 0:
coeffs.append(mag * np.cos(angle * n))
else:
coeffs.append(mag * np.sin(angle * n))

aberrations_coefs = xp.array(coeffs)

# transpose rotation matrix
if transpose:
Expand Down Expand Up @@ -1076,9 +1087,9 @@ def guess_common_aberrations(
)

if plot_shifts_and_aligned_bf:
figsize = kwargs.pop("figsize", (8, 4))
color = kwargs.pop("color", (1, 0, 0, 1))
cmap = kwargs.pop("cmap", "magma")
figsize = plotting_kwargs.pop("figsize", (8, 4))
color = plotting_kwargs.pop("color", (1, 0, 0, 1))
cmap = plotting_kwargs.pop("cmap", "magma")

fig, axs = plt.subplots(1, 2, figsize=figsize)

Expand All @@ -1100,7 +1111,7 @@ def guess_common_aberrations(
0,
]

axs[1].imshow(cropped_image, cmap=cmap, extent=extent, **kwargs)
axs[1].imshow(cropped_image, cmap=cmap, extent=extent, **plotting_kwargs)
axs[1].set_ylabel("x [A]")
axs[1].set_xlabel("y [A]")
axs[1].set_title("Predicted Aligned BF Image")
Expand Down

0 comments on commit e6ebf70

Please sign in to comment.