From e05ac7a7e7b446e1096adfe77d1d16b10fef1889 Mon Sep 17 00:00:00 2001 From: smribet Date: Mon, 27 Jan 2025 05:54:31 -0800 Subject: [PATCH] whoops --- py4DSTEM/tomography/tomography.py | 45 ++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/tomography/tomography.py b/py4DSTEM/tomography/tomography.py index 65ac2e44d..840c72339 100644 --- a/py4DSTEM/tomography/tomography.py +++ b/py4DSTEM/tomography/tomography.py @@ -1000,7 +1000,7 @@ def _solve_for_indicies( # solve for real space coordinates y = np.arange(s[1]) z = np.arange(s[2]) - yy, zz = np.meshgrid(y, z) + yy, zz = np.meshgrid(y, z, indexing="ij") sin = np.sin(tilt) cos = np.cos(tilt) r = [[cos, sin], [-sin, cos]] @@ -1083,6 +1083,28 @@ def _solve_for_indicies( "clip", ) + # normalization real space + bincount_real_max = s[0] * s[1] * s[2] + + ind_real_bincount_weight = np.bincount( + ind_real.ravel(), weights_real.ravel(), minlength=bincount_real_max + ) + ind_real_bincount = np.bincount(ind_real.ravel(), minlength=bincount_real_max) + + ind_real_bincount_weight = ind_real_bincount_weight[ind_real_bincount > 0] + ind_real_bincount = ind_real_bincount[ind_real_bincount > 0] + + ind_real_bincount_weight[ind_real_bincount_weight == 0] = 1 + + correction_factor_real = 1 / ind_real_bincount_weight + + correction_factor_real = np.repeat(correction_factor_real, ind_real_bincount) + sorted_indicies = np.argsort(np.argsort(ind_real.ravel())) + correction_factor_real = correction_factor_real[sorted_indicies].reshape( + ind_real.shape + ) + weights_real = weights_real * correction_factor_real + if datacube_number == 0: self._ind_real = [] self._weights_real = [] @@ -1327,21 +1349,27 @@ def _forward( obj = copy_to_device(self._object[x_index], device) ind_real = self._ind_real[datacube_number].reshape((4, s[1], s[2])) - ind_diff = self._ind_diff[datacube_number] + ind_diff = self._ind_diff[datacube_number].reshape((4, s[-1], s[-1])) weights_real = self._weights_real[datacube_number].reshape((4, s[1], s[2])) - weights_diff = self._weights_diff[datacube_number] + weights_diff = self._weights_diff[datacube_number].reshape((4, s[-1], s[-1])) + + xp = np + obj_q_summed = (obj[:, ind_diff] * weights_diff).sum((1)) bincount_diff = ( xp.tile( - (xp.tile(self._ind_diffraction_ravel, 4)), + self._ind_diffraction_ravel, (s[1] * s[2]), ) - + xp.repeat(xp.arange(s[1] * s[2]), ind_diff.shape[0]) * self._q_length + + xp.repeat( + xp.arange(s[1] * s[2]), obj_q_summed.shape[1] * obj_q_summed.shape[2] + ) + * self._q_length ) obj_q_summed = xp.bincount( bincount_diff, - (obj[:, ind_diff] * weights_diff[None, :]).ravel(), + obj_q_summed.ravel(), minlength=s[1] * s[2] * self._q_length, ).reshape((-1, self._q_length))[:, self._circular_mask_bincount] @@ -1531,7 +1559,10 @@ def _back( minlength=((diff_max) * s[1]), ).reshape((s[1], -1))[:, ind_diff_bincount > 0] - update_q_summed = xp.tile(update_q_summed, (s[2] * 4, 1)) / (s[2]) + # update_q_summed = xp.tile(update_q_summed, (s[2] * 4, 1)) / (s[2]) + update_q_summed = xp.tile(xp.repeat(update_q_summed, s[2], axis=0), (4, 1)) / ( + s[2] + ) diff_shape_bin = update_q_summed.shape[-1]