diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index e38ae6c94..f33a9b803 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -878,12 +878,7 @@ def guess_common_aberrations( kde_sigma_px=0.125, kde_lowpass_filter=False, lanczos_interpolation_order=None, - defocus=0, - astigmatism=0, - astigmatism_angle_deg=0, - coma=0, - coma_angle_deg=0, - spherical_aberration=0, + polar_parameters=None, max_batch_size=None, interpolation_max_batch_size=None, plot_shifts_and_aligned_bf=True, @@ -910,18 +905,9 @@ def guess_common_aberrations( If True, the resulting KDE upsampled image is lowpass-filtered using a sinc-function lanczos_interpolation_order: int, optional If not None, Lanczos interpolation with the specified order is used instead of bilinear - defocus: float, optional - Defocus value to use in computing analytical BF shifts - astigmatism: float, optional - Astigmatism value to use in computing analytical BF shifts - astigmatism_angle_deg: float, optional - Astigmatism angle to use in computing analytical BF shifts - coma: float, optional - Coma value to use in computing analytical BF shifts - coma_angle_deg: float, optional - Coma angle to use in computing analytical BF shifts - spherical_aberration: float, optional - Spherical aberration value to use in computing analytical BF shifts + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. max_batch_size: int, optional Max number of virtual BF images to use at once in computing cross-correlation plot_shifts_and_aligned_bf: bool, optional @@ -934,7 +920,9 @@ def guess_common_aberrations( Scale to multiply shifts by interpolation_max_batch_size: int, optional Max number of pixels to use at once in upsampling (max is #BF * upsampled_pixels) - + kwargs: + Provide the aberration coefficients as keyword arguments. + kwargs which are not recognized as aberrations will be passed to the visualizations """ xp = self._xp asnumpy = self._asnumpy @@ -947,29 +935,52 @@ def guess_common_aberrations( ) ) - # aberrations_coefs - aberrations_mn = [ - [1, 0, 0], - [1, 2, 0], - [1, 2, 1], - [2, 1, 0], - [2, 1, 1], - [3, 0, 0], - ] - astigmatism_x = astigmatism * np.cos(np.deg2rad(astigmatism_angle_deg) * 2) - astigmatism_y = astigmatism * np.sin(np.deg2rad(astigmatism_angle_deg) * 2) - coma_x = coma * np.cos(np.deg2rad(coma_angle_deg) * 1) - coma_y = coma * np.sin(np.deg2rad(coma_angle_deg) * 1) - aberrations_coefs = xp.array( - [ - -defocus, - astigmatism_x, - astigmatism_y, - coma_x, - coma_y, - spherical_aberration, - ] - ) + aberrations_kwargs = kwargs.copy() + plotting_kwargs = {} + + for key, val in kwargs.items(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + aberrations_kwargs.pop(key) + plotting_kwargs[key] = val + + if polar_parameters is None: + polar_parameters = {} + + polar_dict = self._polar_parameters.copy() # copy init dict + polar_dict |= polar_parameters # merge with passed dict + + for symbol, value in aberrations_kwargs.items(): + if symbol in polar_dict.keys(): + polar_dict[symbol] = value + elif symbol == "defocus": + polar_dict[polar_aliases[symbol]] = -value + elif symbol in polar_aliases.keys(): + polar_dict[polar_aliases[symbol]] = value + else: + raise ValueError("{} not a recognized parameter".format(symbol)) + + aberrations_mn = [] + for symbol in polar_symbols: + if symbol[0] == "C": + if np.abs(polar_dict[symbol]) > 0: + _, m, n = symbol + aberrations_mn.append([int(m), int(n), 0]) + if int(n) > 0: + aberrations_mn.append([int(m), int(n), 1]) + + coeffs = [] + for m, n, a in aberrations_mn: + mag = polar_dict.get(f"C{m}{n}") + if n == 0: + coeffs.append(mag) + else: + angle = polar_dict.get(f"phi{m}{n}") + if a == 0: + coeffs.append(mag * np.cos(angle * n)) + else: + coeffs.append(mag * np.sin(angle * n)) + + aberrations_coefs = xp.array(coeffs) # transpose rotation matrix if transpose: @@ -1076,9 +1087,9 @@ def guess_common_aberrations( ) if plot_shifts_and_aligned_bf: - figsize = kwargs.pop("figsize", (8, 4)) - color = kwargs.pop("color", (1, 0, 0, 1)) - cmap = kwargs.pop("cmap", "magma") + figsize = plotting_kwargs.pop("figsize", (8, 4)) + color = plotting_kwargs.pop("color", (1, 0, 0, 1)) + cmap = plotting_kwargs.pop("cmap", "magma") fig, axs = plt.subplots(1, 2, figsize=figsize) @@ -1100,7 +1111,7 @@ def guess_common_aberrations( 0, ] - axs[1].imshow(cropped_image, cmap=cmap, extent=extent, **kwargs) + axs[1].imshow(cropped_image, cmap=cmap, extent=extent, **plotting_kwargs) axs[1].set_ylabel("x [A]") axs[1].set_xlabel("y [A]") axs[1].set_title("Predicted Aligned BF Image")