|
2 | 2 | # SPDX-License-Identifier: BUSL-1.1
|
3 | 3 |
|
4 | 4 | import logging
|
5 |
| -from typing import Tuple |
| 5 | +from typing import Optional, Tuple |
6 | 6 |
|
7 | 7 | import torch
|
8 | 8 |
|
| 9 | +from .rope_helpers import apply_rope, estimate_pos_id_best |
| 10 | + |
9 | 11 | LOG = logging.getLogger(__name__)
|
10 | 12 |
|
11 | 13 |
|
@@ -114,3 +116,175 @@ def batch_omp(
|
114 | 116 | LOG.debug(f"OMP final RMS: {residuals.norm(dim=1).mean()}")
|
115 | 117 |
|
116 | 118 | return selected_indices, final_coeff
|
| 119 | + |
| 120 | + |
| 121 | +def batch_mp_resets( |
| 122 | + targets: torch.Tensor, |
| 123 | + candidate_points: torch.Tensor, |
| 124 | + k: int, |
| 125 | + eps: float = 1e-8, |
| 126 | + total_iterations: Optional[int] = None, |
| 127 | +) -> Tuple[torch.LongTensor, torch.Tensor]: |
| 128 | + """ |
| 129 | + Matching Pursuit with Resets |
| 130 | +
|
| 131 | + Algorithm: |
| 132 | + 1. first perform k iterations of standard matching pursuit |
| 133 | + 2. then, for each excess iteration, select a random index from the |
| 134 | + and remove it from the set of selected indices |
| 135 | + 3. select a new one from the remaining candidates (may be the same, may be different) |
| 136 | + 4. repeat until total_iterations are completed |
| 137 | + """ |
| 138 | + if total_iterations is None: |
| 139 | + total_iterations = k * 3 |
| 140 | + if total_iterations < k: |
| 141 | + raise ValueError( |
| 142 | + f"total_iterations {total_iterations} must be greater than or equal to k {k}" |
| 143 | + ) |
| 144 | + B, D = targets.shape |
| 145 | + N, _ = candidate_points.shape |
| 146 | + device = targets.device |
| 147 | + if k > N: |
| 148 | + raise ValueError(f"Cannot select {k} points from {N} candidates") |
| 149 | + work_dtype = ( |
| 150 | + targets.dtype |
| 151 | + if targets.dtype in (torch.float32, torch.float64) |
| 152 | + else torch.float32 |
| 153 | + ) |
| 154 | + targets_work = targets.to(dtype=work_dtype) |
| 155 | + points_work = candidate_points.to(dtype=work_dtype) |
| 156 | + selected_indices = torch.zeros((B, k), dtype=torch.long, device=device) |
| 157 | + mask = torch.zeros((B, N), dtype=torch.bool, device=device) |
| 158 | + coeff = torch.zeros((B, k), dtype=work_dtype, device=device) |
| 159 | + residuals = targets_work.clone() |
| 160 | + |
| 161 | + iter_indices = list(range(k)) |
| 162 | + while len(iter_indices) < total_iterations: |
| 163 | + honk = torch.randperm(k).tolist() |
| 164 | + iter_indices.extend(honk) |
| 165 | + iter_indices = iter_indices[:total_iterations] |
| 166 | + |
| 167 | + for step, t in enumerate(iter_indices): |
| 168 | + if step < k: |
| 169 | + # Initial phase: select a new candidate for position t |
| 170 | + inner_products = torch.matmul(residuals, points_work.T) # B x N |
| 171 | + # Mask already selected points |
| 172 | + inner_products = inner_products.masked_fill(mask, -float("inf")) |
| 173 | + max_values, max_indices = torch.max(inner_products, dim=1) # B, B |
| 174 | + selected_points = points_work[max_indices] # B x D |
| 175 | + norms_sq = torch.sum(selected_points**2, dim=1) + eps # B |
| 176 | + coeffs = max_values / norms_sq # B |
| 177 | + residuals -= coeffs.unsqueeze(-1) * selected_points |
| 178 | + selected_indices[:, t] = max_indices |
| 179 | + coeff[:, t] = coeffs |
| 180 | + mask.scatter_(1, max_indices.unsqueeze(1), True) |
| 181 | + else: |
| 182 | + # Replacement phase: replace the candidate at position t |
| 183 | + old_indices = selected_indices[:, t] |
| 184 | + old_coeffs = coeff[:, t] |
| 185 | + old_points = points_work[old_indices] |
| 186 | + # Add back the old contribution |
| 187 | + residuals += old_coeffs.unsqueeze(-1) * old_points |
| 188 | + # Remove old index from mask |
| 189 | + mask.scatter_(1, old_indices.unsqueeze(1), False) |
| 190 | + # Compute new inner products |
| 191 | + inner_products = torch.matmul(residuals, points_work.T) |
| 192 | + inner_products = inner_products.masked_fill(mask, -float("inf")) |
| 193 | + new_max_values, new_max_indices = torch.max(inner_products, dim=1) |
| 194 | + new_points = points_work[new_max_indices] |
| 195 | + norms_sq = torch.sum(new_points**2, dim=1) + eps |
| 196 | + new_coeffs = new_max_values / norms_sq |
| 197 | + residuals -= new_coeffs.unsqueeze(-1) * new_points |
| 198 | + selected_indices[:, t] = new_max_indices |
| 199 | + coeff[:, t] = new_coeffs |
| 200 | + # Update mask with new index |
| 201 | + mask.scatter_(1, new_max_indices.unsqueeze(1), True) |
| 202 | + |
| 203 | + return selected_indices, coeff |
| 204 | + |
| 205 | + |
| 206 | +def batch_mp_rope( |
| 207 | + targets: torch.Tensor, |
| 208 | + points_a: torch.Tensor, |
| 209 | + points_b: torch.Tensor, |
| 210 | + k: int, |
| 211 | + num_heads_a: int, |
| 212 | + num_heads_b: int, |
| 213 | + eps: float = 1e-8, |
| 214 | + a_rope_base: float = 10000.0, |
| 215 | + b_rope_base: float = 10000.0, |
| 216 | +) -> torch.Tensor: |
| 217 | + B, D_a = targets.shape |
| 218 | + N, _ = points_a.shape |
| 219 | + _, D_b = points_b.shape |
| 220 | + assert ( |
| 221 | + points_a.shape[0] == points_b.shape[0] |
| 222 | + ), "Number of points in A and B must match" |
| 223 | + device = targets.device |
| 224 | + if k > N: |
| 225 | + raise ValueError(f"Cannot select {k} points from {N} candidates") |
| 226 | + work_dtype = ( |
| 227 | + targets.dtype |
| 228 | + if targets.dtype in (torch.float32, torch.float64) |
| 229 | + else torch.float32 |
| 230 | + ) |
| 231 | + out_dtype = targets.dtype |
| 232 | + points_a = points_a.to(dtype=work_dtype) |
| 233 | + points_b = points_b.to(dtype=work_dtype) |
| 234 | + targets = targets.to(dtype=work_dtype) |
| 235 | + selected_indices = torch.zeros((B, k), dtype=torch.long, device=device) |
| 236 | + coeffs = torch.zeros((B, k), dtype=work_dtype, device=device) |
| 237 | + pos_ids = torch.zeros((B, k), dtype=work_dtype, device=device) |
| 238 | + mask = torch.zeros((B, N), dtype=torch.bool, device=device) |
| 239 | + residuals = targets.clone() |
| 240 | + |
| 241 | + for t in range(k): |
| 242 | + abs_inner = (residuals @ points_a.T).abs() # (B, N) |
| 243 | + abs_inner.masked_fill_(mask, -float("inf")) |
| 244 | + |
| 245 | + # Select new index with maximum correlation |
| 246 | + _, new_idx = torch.max(abs_inner, dim=1) # (B,) |
| 247 | + |
| 248 | + # update state |
| 249 | + selected_indices[:, t] = new_idx |
| 250 | + mask[torch.arange(B, device=device), new_idx] = True |
| 251 | + new_atom = points_a[new_idx] |
| 252 | + |
| 253 | + # compute position id for new atom |
| 254 | + pos_id = estimate_pos_id_best( |
| 255 | + new_atom, |
| 256 | + residuals, |
| 257 | + num_heads=num_heads_a, |
| 258 | + head_dim=D_a // num_heads_a, |
| 259 | + base=a_rope_base, |
| 260 | + ) |
| 261 | + pos_ids[:, t] = pos_id |
| 262 | + new_atom = apply_rope( |
| 263 | + new_atom, |
| 264 | + pos_id, |
| 265 | + num_heads=num_heads_a, |
| 266 | + head_dim=D_a // num_heads_a, |
| 267 | + base=a_rope_base, |
| 268 | + ) |
| 269 | + |
| 270 | + # compute coefficients |
| 271 | + current_coeff = (residuals * new_atom).sum(dim=1) / ( |
| 272 | + new_atom.pow(2).sum(dim=1).clamp(min=eps) |
| 273 | + ) |
| 274 | + coeffs[:, t] = current_coeff |
| 275 | + |
| 276 | + # update residuals |
| 277 | + residuals = residuals - current_coeff.unsqueeze(1) * new_atom |
| 278 | + |
| 279 | + # return result in b space |
| 280 | + selected_points_b = points_b[selected_indices] |
| 281 | + atoms_b = apply_rope( |
| 282 | + selected_points_b, |
| 283 | + pos_ids, |
| 284 | + num_heads=num_heads_b, |
| 285 | + head_dim=D_b // num_heads_b, |
| 286 | + base=b_rope_base, |
| 287 | + ) |
| 288 | + approx_b = (atoms_b * coeffs.unsqueeze(-1)).sum(dim=1) |
| 289 | + final_tensor = approx_b.to(out_dtype) |
| 290 | + return selected_indices, coeffs, final_tensor |
0 commit comments