diff --git a/src/whirlwind/_cost.py b/src/whirlwind/_cost.py index e2a0d8f..c583087 100644 --- a/src/whirlwind/_cost.py +++ b/src/whirlwind/_cost.py @@ -41,8 +41,8 @@ def load_carballo_pdf_splines(): return spline_pdf0, spline_pdf1 -def compute_carballo_costs(igram, corr, nlooks, mask): - """ """ +def compute_carballo_costs(igram, corr, nlooks, mask, batch_size: int = 1000): + """Compute phase gradient costs for unwrapping grid.""" phase_dy_smooth, phase_dx_smooth = calc_smooth_phase_gradients(igram) corr = np.asanyarray(corr) @@ -51,11 +51,26 @@ def compute_carballo_costs(igram, corr, nlooks, mask): spline_pdf0, spline_pdf1 = load_carballo_pdf_splines() - def compute_cost(phase_diff, min_corr): - p1 = spline_pdf1((phase_diff, min_corr, nlooks)) - p0 = spline_pdf0((phase_diff, min_corr, nlooks)) - return -np.log(p1 / p0) + def compute_cost(phase_diff, min_corr, batch_size=batch_size): + total_size = phase_diff.size + costs = np.empty_like(phase_diff) + + for start_idx in range(0, total_size, batch_size): + end_idx = min(start_idx + batch_size, total_size) + # Flatten the input arrays for the batch + phase_batch = phase_diff.ravel()[start_idx:end_idx] + corr_batch = min_corr.ravel()[start_idx:end_idx] + + # Compute probabilities for the batch + p1_batch = spline_pdf1((phase_batch, corr_batch, nlooks)) + p0_batch = spline_pdf0((phase_batch, corr_batch, nlooks)) + # Store results back in original shape + costs.ravel()[start_idx:end_idx] = -np.log(p1_batch / p0_batch) + + return costs + + # Calculate costs with batched processing cost_up = compute_cost(-phase_dx_smooth, corr_dx) cost_lt = compute_cost(phase_dy_smooth, corr_dy) cost_dn = compute_cost(phase_dx_smooth, corr_dx) @@ -71,6 +86,7 @@ def compute_cost(phase_diff, min_corr): cost_rt[mask_dy] = np.nan cost_lt[mask_dy] = np.nan + # Original concatenation logic since that wasn't the memory issue cost = np.ascontiguousarray( np.concatenate( [ @@ -82,6 +98,4 @@ def compute_cost(phase_diff, min_corr): ) ) cost[np.isnan(cost)] = 0.0 - cost = (100.0 * cost).astype(np.int32) - - return cost + return (100.0 * cost).astype(np.int32)