Skip to content

Commit

Permalink
removing extra text and simplifying
Browse files Browse the repository at this point in the history
  • Loading branch information
smribet committed Jan 25, 2025
1 parent 970dc4f commit a02ed87
Showing 1 changed file with 15 additions and 189 deletions.
204 changes: 15 additions & 189 deletions py4DSTEM/tomography/tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,7 @@ def preprocess(
weights_diff_all_counted = xp.bincount(
ind_diff_all,
weights=weights_diff_all,
minlength=s[3]
* s[4]
* s[5],
minlength=s[3] * s[4] * s[5],
)
self._weights_diff_all_counted = weights_diff_all_counted

Expand Down Expand Up @@ -1034,10 +1032,10 @@ def _solve_for_indicies(
ind_real = np.ravel_multi_index((ind0, ind1), (s[1], s[2]), mode="clip")

# solve for diffraction space coordinates
length = s[-1] * np.cos(tilt)
line_y_diff = np.arange(-(s[-1] - 1) / 2, s[-1] / 2) * length / s[-1]
line_z_diff = line_y_diff * np.tan(tilt) + (s[-1] - 1) / 2
line_y_diff += (s[-1] - 1) / 2
line_y_diff = np.arange(s[-1]) * np.cos(tilt)
line_z_diff = np.arange(s[-1]) * np.sin(tilt)
line_y_diff -= np.mean(line_y_diff) - s[-1] / 2
line_z_diff -= np.mean(line_z_diff) - s[-1] / 2

yF_diff = np.floor(line_y_diff).astype("int")
zF_diff = np.floor(line_z_diff).astype("int")
Expand Down Expand Up @@ -1085,125 +1083,6 @@ def _solve_for_indicies(
"clip",
)

# solve for diffraction normalization
if np.abs(tilt) <= np.pi / 8:
line_y_diff_norm = np.arange(-(s[-1] - 1) / 2, s[-1] / 2)
line_z_diff_norm = line_y_diff_norm * np.tan(tilt) + (s[-1] - 1) / 2
line_y_diff_norm += (s[-1] - 1) / 2
else:
line_z_diff_norm = np.arange(-(s[-1] - 1) / 2, s[-1] / 2)
line_y_diff_norm = line_z_diff_norm / np.tan(tilt) + (s[-1] - 1) / 2
line_z_diff_norm += (s[-1] - 1) / 2

yF_diff_norm = np.floor(line_y_diff_norm).astype("int")
zF_diff_norm = np.floor(line_z_diff_norm).astype("int")
dy_diff_norm = line_y_diff_norm - yF_diff_norm
dz_diff_norm = line_z_diff_norm - zF_diff_norm

ind0_diff_norm = np.hstack(
(
np.tile(yF_diff_norm, s[-1]),
np.tile(yF_diff_norm + 1, s[-1]),
np.tile(yF_diff_norm, s[-1]),
np.tile(yF_diff_norm + 1, s[-1]),
)
)

ind1_diff_norm = np.hstack(
(
np.tile(zF_diff_norm, s[-1]),
np.tile(zF_diff_norm, s[-1]),
np.tile(zF_diff_norm + 1, s[-1]),
np.tile(zF_diff_norm + 1, s[-1]),
)
)

weights_diff_norm = np.hstack(
(
np.tile(((1 - dy_diff_norm) * (1 - dz_diff_norm)), s[-1]),
np.tile(((dy_diff_norm) * (1 - dz_diff_norm)), s[-1]),
np.tile(((1 - dy_diff_norm) * (dz_diff_norm)), s[-1]),
np.tile(((dy_diff_norm) * (dz_diff_norm)), s[-1]),
)
)

ind_diff_norm = np.ravel_multi_index(
(
np.tile(qxx.ravel(), 4),
ind0_diff_norm.ravel(),
ind1_diff_norm.ravel(),
),
(s[-1], s[-1], s[-1]),
"clip",
)

# normalization real space
bincount_real_max = s[0] * s[1] * s[2]

ind_real_bincount_weight = np.bincount(
ind_real.ravel(), weights_real.ravel(), minlength=bincount_real_max
)
ind_real_bincount = np.bincount(ind_real.ravel(), minlength=bincount_real_max)

ind_real_bincount_weight = ind_real_bincount_weight[ind_real_bincount > 0]
ind_real_bincount = ind_real_bincount[ind_real_bincount > 0]

ind_real_bincount_weight[ind_real_bincount_weight == 0] = 1

correction_factor_real = 1 / ind_real_bincount_weight

correction_factor_real = np.repeat(correction_factor_real, ind_real_bincount)
sorted_indicies = np.argsort(np.argsort(ind_real.ravel()))
correction_factor_real = correction_factor_real[sorted_indicies].reshape(
ind_real.shape
)
weights_real = weights_real * correction_factor_real

# normalization reciprocal space
bincount_diff_max = s[3] * s[4] * s[5]

ind_diff_bincount_weight = np.bincount(
ind_diff.ravel(), weights_diff.ravel(), minlength=bincount_diff_max
)
ind_diff_bincount = np.bincount(ind_diff.ravel(), minlength=bincount_diff_max)

ind_diff_bincount_weight_norm = np.bincount(
ind_diff_norm.ravel(),
weights_diff_norm.ravel(),
minlength=bincount_diff_max,
)

ind_diff_bincount_norm = np.bincount(
ind_diff_norm.ravel(), minlength=bincount_diff_max
)

ind_diff_bincount_weight_norm = ind_diff_bincount_weight_norm[
ind_diff_bincount > 0
]
ind_diff_bincount_norm = ind_diff_bincount_norm[ind_diff_bincount > 0]
ind_diff_bincount_weight = ind_diff_bincount_weight[ind_diff_bincount > 0]
ind_diff_bincount = ind_diff_bincount[ind_diff_bincount > 0]

ind_diff_bincount_weight_norm[ind_diff_bincount_weight == 0] = 0
ind_diff_bincount_weight[ind_diff_bincount_weight == 0] = 1
ind_diff_bincount_weight_norm[ind_diff_bincount == 0] = 1
ind_diff_bincount[ind_diff_bincount == 0] = 1

correction_factor_diff = (
ind_diff_bincount_weight_norm
/ ind_diff_bincount_weight
* ind_diff_bincount_weight.sum()
/ ind_diff_bincount_weight_norm.sum()
)

correction_factor_diff = np.repeat(correction_factor_diff, ind_diff_bincount)
sorted_indicies = np.argsort(np.argsort(ind_diff.ravel()))
correction_factor_diff = correction_factor_diff[sorted_indicies].reshape(
ind_diff.shape
)

weights_diff = weights_diff * correction_factor_diff

if datacube_number == 0:
self._ind_real = []
self._weights_real = []
Expand All @@ -1213,10 +1092,6 @@ def _solve_for_indicies(
self._ind1_diff = []
self._ind0 = []
self._ind1 = []
self._ind0_diff_norm = []
self._ind1_diff_norm = []
self._ind_diff_norm = []
self._weights_diff_norm = []

self._ind_real.append(xp.asarray(ind_real))
self._ind_diff.append(xp.asarray(ind_diff))
Expand All @@ -1226,10 +1101,6 @@ def _solve_for_indicies(
self._ind1_diff.append(xp.asarray(ind1_diff))
self._ind0.append(xp.asarray(ind0))
self._ind1.append(xp.asarray(ind1))
self._ind0_diff_norm.append(xp.asarray(ind0_diff_norm))
self._ind1_diff_norm.append(xp.asarray(ind1_diff_norm))
self._ind_diff_norm.append(xp.asarray(ind_diff_norm))
self._weights_diff_norm.append(xp.asarray(weights_diff_norm))

def _reshape_4D_array_to_2D(
self,
Expand Down Expand Up @@ -1459,40 +1330,7 @@ def _forward(
ind_diff = self._ind_diff[datacube_number]
weights_real = self._weights_real[datacube_number].reshape((4, s[1], s[2]))
weights_diff = self._weights_diff[datacube_number]
# ind0 = self._ind0[datacube_number].reshape((s[1],s[2]))

# ind0[ind0 >= s[2]] = s[2] - 1
# ind0[ind0 < 0] = 0

# project
# bincount_diff = (
# xp.tile(
# (xp.tile(self._ind_diffraction_ravel, 4)),
# (s[1]),
# )
# + xp.repeat(xp.arange(s[1]), ind_diff.shape[0]) * self._q_length
# )

# bincount_real = (
# xp.tile(xp.arange(obj.shape[1]), ind_real.shape[0])
# + xp.repeat(ind0, obj.shape[1]) * obj.shape[1]
# )

# obj_projected = (
# (
# xp.bincount(
# bincount_diff,
# (
# xp.bincount(
# bincount_real,
# (obj[ind_real] * weights_real[:, None]).ravel(),
# ).reshape((-1, obj.shape[1]))[:, ind_diff]
# ).ravel()
# * xp.tile(weights_diff, s[1]).ravel(),
# minlength=self._q_length * s[1],
# ).reshape(s[1], self._q_length)[:, self._circular_mask_bincount]
# )
# )
bincount_diff = (
xp.tile(
(xp.tile(self._ind_diffraction_ravel, 4)),
Expand All @@ -1507,17 +1345,6 @@ def _forward(
minlength=s[1] * s[2] * self._q_length,
).reshape((-1, self._q_length))[:, self._circular_mask_bincount]

# bincount_real = (
# xp.tile(xp.arange(obj_q_summed.shape[1]), ind_real.shape[0])
# + xp.repeat(ind0, obj_q_summed.shape[1]) * obj_q_summed.shape[1]
# )

# obj_projected = xp.bincount(
# bincount_real,
# (obj_q_summed[ind_real] * weights_real[:, None]).ravel(),
# minlength=s[2] * obj_q_summed.shape[1],
# ).reshape((s[2], obj_q_summed.shape[1]))

obj_projected = (obj_q_summed[ind_real] * weights_real[:, :, :, None]).sum(
(0, 1)
)
Expand Down Expand Up @@ -1675,17 +1502,14 @@ def _back(
normalize[:, 0] = 1

update_reshaped = (
((xp.tile(xp.repeat(update, 2, axis=1)[:, 1:] / normalize, 4))[:, i])
* (self._weights_diff[datacube_number][ind_update])
/ (4 * 2)
)
(xp.tile(xp.repeat(update, 2, axis=1)[:, 1:] / normalize, 4))[:, i]
) * (self._weights_diff[datacube_number][ind_update])
else:
normalize = xp.ones((xp.repeat(update, 2, axis=1)).shape) * 2

update_reshaped = (
((xp.tile(xp.repeat(update, 2, axis=1) / normalize, (4)))[:, i])
* (self._weights_diff[datacube_number][ind_update])
)
(xp.tile(xp.repeat(update, 2, axis=1) / normalize, (4)))[:, i]
) * (self._weights_diff[datacube_number][ind_update])

ind_real = self._ind_real[datacube_number].ravel()
ind_diff = self._ind_diff[datacube_number][ind_update]
Expand Down Expand Up @@ -1730,7 +1554,9 @@ def _back(
)
).reshape((-1, diff_shape_bin))[ind_real_bincount > 0]

i_real, i_diff = xp.meshgrid(xp.unique(ind_real), xp.unique(ind_diff), indexing="ij")
i_real, i_diff = xp.meshgrid(
xp.unique(ind_real), xp.unique(ind_diff), indexing="ij"
)

i_real = copy_to_device(i_real, storage)
i_diff = copy_to_device(i_diff, storage)
Expand Down Expand Up @@ -1931,7 +1757,7 @@ def widget(
cyliner_mask=False,
mode="dark-field",
virtual_image_mask_radius=4,
**kwargs
**kwargs,
):
""" """
from ipywidgets import HBox, VBox, widgets, interact, Dropdown, Label, Layout
Expand Down Expand Up @@ -1968,8 +1794,8 @@ def widget(

_, vmin, vmax = return_scaled_histogram_ordering(
((obj_6D) * diffraction_kernel[None, None, None, :, :]).mean((3, 4, 5)),
vmin = vmin,
vmax = vmax,
vmin=vmin,
vmax=vmax,
)

# %matplotlib ipympl
Expand Down

0 comments on commit a02ed87

Please sign in to comment.