Skip to content

Commit

Permalink
Compute costs from spline in batches of 1000 for lower memory usage
Browse files Browse the repository at this point in the history
  • Loading branch information
scottstanie committed Oct 29, 2024
1 parent 40defb3 commit dce837b
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions src/whirlwind/_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def load_carballo_pdf_splines():
return spline_pdf0, spline_pdf1


def compute_carballo_costs(igram, corr, nlooks, mask):
""" """
def compute_carballo_costs(igram, corr, nlooks, mask, batch_size: int = 1000):
"""Compute phase gradient costs for unwrapping grid."""
phase_dy_smooth, phase_dx_smooth = calc_smooth_phase_gradients(igram)

corr = np.asanyarray(corr)
Expand All @@ -51,11 +51,26 @@ def compute_carballo_costs(igram, corr, nlooks, mask):

spline_pdf0, spline_pdf1 = load_carballo_pdf_splines()

def compute_cost(phase_diff, min_corr):
p1 = spline_pdf1((phase_diff, min_corr, nlooks))
p0 = spline_pdf0((phase_diff, min_corr, nlooks))
return -np.log(p1 / p0)
def compute_cost(phase_diff, min_corr, batch_size=batch_size):
total_size = phase_diff.size
costs = np.empty_like(phase_diff)

for start_idx in range(0, total_size, batch_size):
end_idx = min(start_idx + batch_size, total_size)
# Flatten the input arrays for the batch
phase_batch = phase_diff.ravel()[start_idx:end_idx]
corr_batch = min_corr.ravel()[start_idx:end_idx]

# Compute probabilities for the batch
p1_batch = spline_pdf1((phase_batch, corr_batch, nlooks))
p0_batch = spline_pdf0((phase_batch, corr_batch, nlooks))

# Store results back in original shape
costs.ravel()[start_idx:end_idx] = -np.log(p1_batch / p0_batch)

return costs

# Calculate costs with batched processing
cost_up = compute_cost(-phase_dx_smooth, corr_dx)
cost_lt = compute_cost(phase_dy_smooth, corr_dy)
cost_dn = compute_cost(phase_dx_smooth, corr_dx)
Expand All @@ -71,6 +86,7 @@ def compute_cost(phase_diff, min_corr):
cost_rt[mask_dy] = np.nan
cost_lt[mask_dy] = np.nan

# Original concatenation logic since that wasn't the memory issue
cost = np.ascontiguousarray(
np.concatenate(
[
Expand All @@ -82,6 +98,4 @@ def compute_cost(phase_diff, min_corr):
)
)
cost[np.isnan(cost)] = 0.0
cost = (100.0 * cost).astype(np.int32)

return cost
return (100.0 * cost).astype(np.int32)

0 comments on commit dce837b

Please sign in to comment.