diff --git a/py4DSTEM/tomography/tomography.py b/py4DSTEM/tomography/tomography.py index 1f3689857..375d2982c 100644 --- a/py4DSTEM/tomography/tomography.py +++ b/py4DSTEM/tomography/tomography.py @@ -332,9 +332,7 @@ def preprocess( weights_diff_all_counted = xp.bincount( ind_diff_all, weights=weights_diff_all, - minlength=s[3] - * s[4] - * s[5], + minlength=s[3] * s[4] * s[5], ) self._weights_diff_all_counted = weights_diff_all_counted @@ -1034,10 +1032,10 @@ def _solve_for_indicies( ind_real = np.ravel_multi_index((ind0, ind1), (s[1], s[2]), mode="clip") # solve for diffraction space coordinates - length = s[-1] * np.cos(tilt) - line_y_diff = np.arange(-(s[-1] - 1) / 2, s[-1] / 2) * length / s[-1] - line_z_diff = line_y_diff * np.tan(tilt) + (s[-1] - 1) / 2 - line_y_diff += (s[-1] - 1) / 2 + line_y_diff = np.arange(s[-1]) * np.cos(tilt) + line_z_diff = np.arange(s[-1]) * np.sin(tilt) + line_y_diff -= np.mean(line_y_diff) - s[-1] / 2 + line_z_diff -= np.mean(line_z_diff) - s[-1] / 2 yF_diff = np.floor(line_y_diff).astype("int") zF_diff = np.floor(line_z_diff).astype("int") @@ -1085,125 +1083,6 @@ def _solve_for_indicies( "clip", ) - # solve for diffraction normalization - if np.abs(tilt) <= np.pi / 8: - line_y_diff_norm = np.arange(-(s[-1] - 1) / 2, s[-1] / 2) - line_z_diff_norm = line_y_diff_norm * np.tan(tilt) + (s[-1] - 1) / 2 - line_y_diff_norm += (s[-1] - 1) / 2 - else: - line_z_diff_norm = np.arange(-(s[-1] - 1) / 2, s[-1] / 2) - line_y_diff_norm = line_z_diff_norm / np.tan(tilt) + (s[-1] - 1) / 2 - line_z_diff_norm += (s[-1] - 1) / 2 - - yF_diff_norm = np.floor(line_y_diff_norm).astype("int") - zF_diff_norm = np.floor(line_z_diff_norm).astype("int") - dy_diff_norm = line_y_diff_norm - yF_diff_norm - dz_diff_norm = line_z_diff_norm - zF_diff_norm - - ind0_diff_norm = np.hstack( - ( - np.tile(yF_diff_norm, s[-1]), - np.tile(yF_diff_norm + 1, s[-1]), - np.tile(yF_diff_norm, s[-1]), - np.tile(yF_diff_norm + 1, s[-1]), - ) - ) - - ind1_diff_norm = np.hstack( - ( - np.tile(zF_diff_norm, s[-1]), - np.tile(zF_diff_norm, s[-1]), - np.tile(zF_diff_norm + 1, s[-1]), - np.tile(zF_diff_norm + 1, s[-1]), - ) - ) - - weights_diff_norm = np.hstack( - ( - np.tile(((1 - dy_diff_norm) * (1 - dz_diff_norm)), s[-1]), - np.tile(((dy_diff_norm) * (1 - dz_diff_norm)), s[-1]), - np.tile(((1 - dy_diff_norm) * (dz_diff_norm)), s[-1]), - np.tile(((dy_diff_norm) * (dz_diff_norm)), s[-1]), - ) - ) - - ind_diff_norm = np.ravel_multi_index( - ( - np.tile(qxx.ravel(), 4), - ind0_diff_norm.ravel(), - ind1_diff_norm.ravel(), - ), - (s[-1], s[-1], s[-1]), - "clip", - ) - - # normalization real space - bincount_real_max = s[0] * s[1] * s[2] - - ind_real_bincount_weight = np.bincount( - ind_real.ravel(), weights_real.ravel(), minlength=bincount_real_max - ) - ind_real_bincount = np.bincount(ind_real.ravel(), minlength=bincount_real_max) - - ind_real_bincount_weight = ind_real_bincount_weight[ind_real_bincount > 0] - ind_real_bincount = ind_real_bincount[ind_real_bincount > 0] - - ind_real_bincount_weight[ind_real_bincount_weight == 0] = 1 - - correction_factor_real = 1 / ind_real_bincount_weight - - correction_factor_real = np.repeat(correction_factor_real, ind_real_bincount) - sorted_indicies = np.argsort(np.argsort(ind_real.ravel())) - correction_factor_real = correction_factor_real[sorted_indicies].reshape( - ind_real.shape - ) - weights_real = weights_real * correction_factor_real - - # normalization reciprocal space - bincount_diff_max = s[3] * s[4] * s[5] - - ind_diff_bincount_weight = np.bincount( - ind_diff.ravel(), weights_diff.ravel(), minlength=bincount_diff_max - ) - ind_diff_bincount = np.bincount(ind_diff.ravel(), minlength=bincount_diff_max) - - ind_diff_bincount_weight_norm = np.bincount( - ind_diff_norm.ravel(), - weights_diff_norm.ravel(), - minlength=bincount_diff_max, - ) - - ind_diff_bincount_norm = np.bincount( - ind_diff_norm.ravel(), minlength=bincount_diff_max - ) - - ind_diff_bincount_weight_norm = ind_diff_bincount_weight_norm[ - ind_diff_bincount > 0 - ] - ind_diff_bincount_norm = ind_diff_bincount_norm[ind_diff_bincount > 0] - ind_diff_bincount_weight = ind_diff_bincount_weight[ind_diff_bincount > 0] - ind_diff_bincount = ind_diff_bincount[ind_diff_bincount > 0] - - ind_diff_bincount_weight_norm[ind_diff_bincount_weight == 0] = 0 - ind_diff_bincount_weight[ind_diff_bincount_weight == 0] = 1 - ind_diff_bincount_weight_norm[ind_diff_bincount == 0] = 1 - ind_diff_bincount[ind_diff_bincount == 0] = 1 - - correction_factor_diff = ( - ind_diff_bincount_weight_norm - / ind_diff_bincount_weight - * ind_diff_bincount_weight.sum() - / ind_diff_bincount_weight_norm.sum() - ) - - correction_factor_diff = np.repeat(correction_factor_diff, ind_diff_bincount) - sorted_indicies = np.argsort(np.argsort(ind_diff.ravel())) - correction_factor_diff = correction_factor_diff[sorted_indicies].reshape( - ind_diff.shape - ) - - weights_diff = weights_diff * correction_factor_diff - if datacube_number == 0: self._ind_real = [] self._weights_real = [] @@ -1213,10 +1092,6 @@ def _solve_for_indicies( self._ind1_diff = [] self._ind0 = [] self._ind1 = [] - self._ind0_diff_norm = [] - self._ind1_diff_norm = [] - self._ind_diff_norm = [] - self._weights_diff_norm = [] self._ind_real.append(xp.asarray(ind_real)) self._ind_diff.append(xp.asarray(ind_diff)) @@ -1226,10 +1101,6 @@ def _solve_for_indicies( self._ind1_diff.append(xp.asarray(ind1_diff)) self._ind0.append(xp.asarray(ind0)) self._ind1.append(xp.asarray(ind1)) - self._ind0_diff_norm.append(xp.asarray(ind0_diff_norm)) - self._ind1_diff_norm.append(xp.asarray(ind1_diff_norm)) - self._ind_diff_norm.append(xp.asarray(ind_diff_norm)) - self._weights_diff_norm.append(xp.asarray(weights_diff_norm)) def _reshape_4D_array_to_2D( self, @@ -1459,40 +1330,7 @@ def _forward( ind_diff = self._ind_diff[datacube_number] weights_real = self._weights_real[datacube_number].reshape((4, s[1], s[2])) weights_diff = self._weights_diff[datacube_number] - # ind0 = self._ind0[datacube_number].reshape((s[1],s[2])) - - # ind0[ind0 >= s[2]] = s[2] - 1 - # ind0[ind0 < 0] = 0 - - # project - # bincount_diff = ( - # xp.tile( - # (xp.tile(self._ind_diffraction_ravel, 4)), - # (s[1]), - # ) - # + xp.repeat(xp.arange(s[1]), ind_diff.shape[0]) * self._q_length - # ) - - # bincount_real = ( - # xp.tile(xp.arange(obj.shape[1]), ind_real.shape[0]) - # + xp.repeat(ind0, obj.shape[1]) * obj.shape[1] - # ) - # obj_projected = ( - # ( - # xp.bincount( - # bincount_diff, - # ( - # xp.bincount( - # bincount_real, - # (obj[ind_real] * weights_real[:, None]).ravel(), - # ).reshape((-1, obj.shape[1]))[:, ind_diff] - # ).ravel() - # * xp.tile(weights_diff, s[1]).ravel(), - # minlength=self._q_length * s[1], - # ).reshape(s[1], self._q_length)[:, self._circular_mask_bincount] - # ) - # ) bincount_diff = ( xp.tile( (xp.tile(self._ind_diffraction_ravel, 4)), @@ -1507,17 +1345,6 @@ def _forward( minlength=s[1] * s[2] * self._q_length, ).reshape((-1, self._q_length))[:, self._circular_mask_bincount] - # bincount_real = ( - # xp.tile(xp.arange(obj_q_summed.shape[1]), ind_real.shape[0]) - # + xp.repeat(ind0, obj_q_summed.shape[1]) * obj_q_summed.shape[1] - # ) - - # obj_projected = xp.bincount( - # bincount_real, - # (obj_q_summed[ind_real] * weights_real[:, None]).ravel(), - # minlength=s[2] * obj_q_summed.shape[1], - # ).reshape((s[2], obj_q_summed.shape[1])) - obj_projected = (obj_q_summed[ind_real] * weights_real[:, :, :, None]).sum( (0, 1) ) @@ -1675,17 +1502,14 @@ def _back( normalize[:, 0] = 1 update_reshaped = ( - ((xp.tile(xp.repeat(update, 2, axis=1)[:, 1:] / normalize, 4))[:, i]) - * (self._weights_diff[datacube_number][ind_update]) - / (4 * 2) - ) + (xp.tile(xp.repeat(update, 2, axis=1)[:, 1:] / normalize, 4))[:, i] + ) * (self._weights_diff[datacube_number][ind_update]) else: normalize = xp.ones((xp.repeat(update, 2, axis=1)).shape) * 2 update_reshaped = ( - ((xp.tile(xp.repeat(update, 2, axis=1) / normalize, (4)))[:, i]) - * (self._weights_diff[datacube_number][ind_update]) - ) + (xp.tile(xp.repeat(update, 2, axis=1) / normalize, (4)))[:, i] + ) * (self._weights_diff[datacube_number][ind_update]) ind_real = self._ind_real[datacube_number].ravel() ind_diff = self._ind_diff[datacube_number][ind_update] @@ -1730,7 +1554,9 @@ def _back( ) ).reshape((-1, diff_shape_bin))[ind_real_bincount > 0] - i_real, i_diff = xp.meshgrid(xp.unique(ind_real), xp.unique(ind_diff), indexing="ij") + i_real, i_diff = xp.meshgrid( + xp.unique(ind_real), xp.unique(ind_diff), indexing="ij" + ) i_real = copy_to_device(i_real, storage) i_diff = copy_to_device(i_diff, storage) @@ -1931,7 +1757,7 @@ def widget( cyliner_mask=False, mode="dark-field", virtual_image_mask_radius=4, - **kwargs + **kwargs, ): """ """ from ipywidgets import HBox, VBox, widgets, interact, Dropdown, Label, Layout @@ -1968,8 +1794,8 @@ def widget( _, vmin, vmax = return_scaled_histogram_ordering( ((obj_6D) * diffraction_kernel[None, None, None, :, :]).mean((3, 4, 5)), - vmin = vmin, - vmax = vmax, + vmin=vmin, + vmax=vmax, ) # %matplotlib ipympl