from typing import Optional, Union

import torch
from torch.nn.attention.flex_attention import (
    _identity,
    _mask_mod_signature,
    _score_mod_signature,
    BlockMask,
    noop_mask,
)


def _cdiv(x: Union[int, float, torch.Tensor], multiple: Union[int, float, torch.Tensor]):
    # Ceiling division, e.g. _cdiv(10, 4) == 3.
    return (x + multiple - 1) // multiple


class PagedAttention:
    """
    PagedAttention supports flex attention inference with a large batch size.
    With PagedAttention, a batch of key/value tensors with varying kv lengths
    is split into fixed-length blocks and cached compactly. This avoids the
    redundant memory consumption caused by varying kv lengths and supports a
    larger batch size.
    """

    def __init__(
        self,
        n_pages: int,
        page_size: int,
        max_batch_size: int,
        device: str = "cuda",
    ):
        # number of pages
        self.n_pages = n_pages

        # number of tokens per page
        self.page_size = page_size

        # page table: [batch, logical_block_idx] -> physical_page_idx
        self.page_table = -torch.ones(
            (max_batch_size, self.n_pages), dtype=torch.int64, device=device
        )

        # capacity: batch_idx -> allocated sequence length
        self.capacity = torch.zeros(max_batch_size, dtype=torch.int64, device=device)

        # indices of empty pages that are available for allocation
        self.empty_pages = list(range(n_pages - 1, -1, -1))

        # mapping from physical page index to logical page index
        self.physical_to_logical = -torch.ones(
            (max_batch_size, n_pages), dtype=torch.int64, device=device
        )

    def reserve(self, batch_idx: torch.Tensor, seq_len: torch.Tensor) -> None:
        """
        Requests the capacity of a given batch to be at least enough to
        hold `seq_len` elements.

        Args:
            batch_idx (Tensor): batch index to be reserved; shape :math:`(1)`.
            seq_len (Tensor): minimum capacity for the given batch; shape :math:`(1)`.
        """

        if seq_len <= self.capacity[batch_idx]:
            return

        num_pages_to_allocate = _cdiv(seq_len - self.capacity[batch_idx], self.page_size)

        assert len(self.empty_pages) >= num_pages_to_allocate, (
            f"requested {num_pages_to_allocate.item()} pages "
            f"but there are only {len(self.empty_pages)} empty pages"
        )

        start_page_idx = self.capacity[batch_idx] // self.page_size
        end_page_idx = start_page_idx + num_pages_to_allocate

        # find empty physical pages
        allocated_pages = torch.tensor(
            self.empty_pages[-num_pages_to_allocate:],
            device=num_pages_to_allocate.device,
        )
        self.empty_pages = self.empty_pages[:-num_pages_to_allocate]

        # update page table
        self.page_table[
            batch_idx,
            start_page_idx:end_page_idx,
        ] = allocated_pages

        # update metadata
        self.physical_to_logical[batch_idx, allocated_pages] = torch.arange(
            start_page_idx.item(),
            end_page_idx.item(),
            device=num_pages_to_allocate.device,
        )
        self.capacity[batch_idx] += num_pages_to_allocate * self.page_size

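    # A hedged usage sketch for `reserve` (the pool sizes and the `pool` name
    # are illustrative, not part of the API):
    #
    #   pool = PagedAttention(n_pages=16, page_size=128, max_batch_size=4)
    #   batch_idx = torch.tensor([0], device="cuda")
    #   pool.reserve(batch_idx, torch.tensor([300], device="cuda"))
    #   # ceil(300 / 128) = 3 pages are popped from `empty_pages`, so
    #   # pool.capacity[0] == 384 and page_table[0, :3] holds their ids.
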
    def erase(self, batch_idx: torch.Tensor) -> None:
        """
        Removes a single batch from paged attention.

        Args:
            batch_idx (Tensor): batch index to be removed; shape :math:`(1)`.
        """

        # find allocated pages
        allocated_page_idx = self.page_table[batch_idx] != -1
        allocated_pages = self.page_table[batch_idx][allocated_page_idx]

        # clean metadata
        self.capacity[batch_idx] = 0
        self.empty_pages += allocated_pages.tolist()
        # Index with (batch_idx, allocated_pages) in a single advanced-indexing
        # expression; chaining `[batch_idx][:, allocated_pages]` would assign
        # into a temporary copy and leave the buffer unchanged.
        self.physical_to_logical[batch_idx, allocated_pages] = -1
        self.page_table[batch_idx] = -1

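    # Continuing the sketch above: `erase` returns pages to the free pool.
    #
    #   pool.erase(batch_idx)
    #   # pool.capacity[0] == 0 and the 3 physical pages are back in
    #   # pool.empty_pages, available for the next `reserve` call.
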
    def assign(
        self,
        batch_idx: torch.Tensor,
        input_pos: torch.Tensor,
        k_val: torch.Tensor,
        v_val: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
    ) -> None:
        """
        Assigns new contents `k_val`/`v_val` to the storages `k_cache`/`v_cache`
        at the locations selected by `batch_idx` and `input_pos`.

        Args:
            batch_idx (Tensor): batch index; shape :math:`(B)`.
            input_pos (Tensor): input positions to be assigned for the given batch; shape :math:`(B, S)`.
            k_val (Tensor): key values to be assigned; shape :math:`(B, H, S, K_D)`.
            v_val (Tensor): values to be assigned; shape :math:`(B, H, S, V_D)`.
            k_cache (Tensor): cache storing the keys; shape :math:`(1, H, MAX_S, K_D)`.
            v_cache (Tensor): cache storing the values; shape :math:`(1, H, MAX_S, V_D)`.
        """
        if k_val.requires_grad or v_val.requires_grad:
            raise RuntimeError("k_val and v_val must not require gradient")

        B, H, S, K_D = k_val.shape
        V_D = v_val.shape[3]
        if B != batch_idx.shape[0]:
            raise RuntimeError(
                f"Expect val and batch_idx to have the same batch size "
                f"but got B={B} and B={batch_idx.shape[0]}."
            )
        if H != k_cache.shape[1]:
            raise RuntimeError(
                f"Expect val and cache to have the same number of heads "
                f"but got H={H} and H={k_cache.shape[1]}."
            )
        if S != input_pos.shape[1]:
            raise RuntimeError(
                f"Expect val and input_pos to have the same sequence length "
                f"but got S={S} and S={input_pos.shape[1]}."
            )
        if K_D != k_cache.shape[3]:
            raise RuntimeError(
                f"Expect k_val and k_cache to have the same hidden dim "
                f"but got D={K_D} and D={k_cache.shape[3]}."
            )
        if V_D != v_cache.shape[3]:
            raise RuntimeError(
                f"Expect v_val and v_cache to have the same hidden dim "
                f"but got D={V_D} and D={v_cache.shape[3]}."
            )

        # find address
        logical_block_idx = input_pos // self.page_size  # [B, S]
        logical_block_offset = input_pos % self.page_size  # [B, S]
        physical_block_idx = torch.gather(
            self.page_table[batch_idx], 1, logical_block_idx.to(torch.int64)
        ).to(torch.int32)  # [B, S]

        addr = (physical_block_idx * self.page_size + logical_block_offset).view(-1)  # [B*S]

        k_val = k_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, K_D)
        v_val = v_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, V_D)

        k_cache[:, :, addr, :] = k_val
        v_cache[:, :, addr, :] = v_val

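    # Sketch of a single-token `assign` (shapes are illustrative; assumes the
    # `reserve` sketch above provided capacity through position 300):
    #
    #   H, D = 2, 64
    #   k_cache = torch.zeros(1, H, 16 * 128, D, device="cuda")
    #   v_cache = torch.zeros_like(k_cache)
    #   input_pos = torch.tensor([[300]], device="cuda")   # [B, S]
    #   k_val = torch.randn(1, H, 1, D, device="cuda")     # [B, H, S, D]
    #   v_val = torch.randn(1, H, 1, D, device="cuda")
    #   pool.assign(batch_idx, input_pos, k_val, v_val, k_cache, v_cache)
    #   # token 300 lands in physical page page_table[0, 300 // 128],
    #   # at offset 300 % 128 within that page.
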
    def convert_logical_block_mask(
        self,
        block_mask: BlockMask,
        batch_idx: Optional[torch.Tensor] = None,
    ) -> BlockMask:
        """
        Converts a logical block mask by mapping its logical kv indices to the
        corresponding physical kv indices.

        Args:
            block_mask (BlockMask): logical block mask;
                kv_indices shape :math:`(B, H, ROWS, MAX_BLOCKS_IN_COL)`.
            batch_idx (Tensor): batch index corresponding to the block_mask
                batch dimension. This provides flexibility to convert a
                block mask with a smaller batch size than the page table;
                shape :math:`(B)`.
        """
        B, H, ROWS, MAX_BLOCKS_IN_COL = block_mask.kv_indices.shape

        if block_mask.BLOCK_SIZE[1] != self.page_size:
            raise RuntimeError(
                f"Expect block_mask to have the same column block size as page_size "
                f"but got size={block_mask.BLOCK_SIZE[1]} and size={self.page_size}."
            )

        # Increase the number of columns in the converted block mask from the
        # logical block mask's column count to n_pages, since (a) the converted
        # block mask may contain larger index values, and (b) `_ordered_to_dense`
        # realizes a dense tensor from these converted indices. Keeping the
        # logical block mask's column count would raise an IndexError.

        device = block_mask.kv_num_blocks.device

        if batch_idx is None:
            batch_idx = torch.arange(B, device=device)
        page_table = self.page_table[batch_idx]

        new_kv_num_blocks = block_mask.kv_num_blocks.clone()

        new_kv_indices = torch.zeros((B, H, ROWS, self.n_pages), dtype=torch.int32, device=device)
        new_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = (
            torch.gather(page_table, 1, block_mask.kv_indices.view(B, -1).to(torch.int64))
            .view(block_mask.kv_indices.shape)
            .to(torch.int32)
        )

        new_full_kv_indices, new_full_kv_num_blocks = None, None
        if block_mask.full_kv_num_blocks is not None:
            assert block_mask.full_kv_indices is not None
            new_full_kv_num_blocks = block_mask.full_kv_num_blocks.clone()
            new_full_kv_indices = torch.zeros(
                (B, H, ROWS, self.n_pages), dtype=torch.int32, device=device
            )
            new_full_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = (
                torch.gather(
                    page_table,
                    1,
                    block_mask.full_kv_indices.view(B, -1).to(torch.int64),
                )
                .view(block_mask.full_kv_indices.shape)
                .to(torch.int32)
            )

        new_mask_mod = self.get_mask_mod(block_mask.mask_mod)

        seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size)
        return BlockMask.from_kv_blocks(
            new_kv_num_blocks,
            new_kv_indices,
            new_full_kv_num_blocks,
            new_full_kv_indices,
            block_mask.BLOCK_SIZE,
            new_mask_mod,
            seq_lengths=seq_lengths,
        )

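    # Sketch: build a logical causal mask with flex_attention's
    # create_block_mask and convert it to physical pages (sizes continue the
    # illustrative pool above; 384 tokens == the 3 reserved pages):
    #
    #   from torch.nn.attention.flex_attention import create_block_mask
    #   def causal(b, h, q_idx, kv_idx):
    #       return q_idx >= kv_idx
    #   logical = create_block_mask(causal, 1, 1, 384, 384, BLOCK_SIZE=128)
    #   physical = pool.convert_logical_block_mask(logical, batch_idx)
    #   # physical.kv_indices now name physical pages, and physical.mask_mod
    #   # translates physical kv indices back to logical ones (see below).
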
    def get_mask_mod(self, mask_mod: Optional[_mask_mod_signature]) -> _mask_mod_signature:
        """
        Converts a mask_mod written for logical kv indices into one that takes
        physical kv indices, using the physical-to-logical page mapping.

        Args:
            mask_mod (_mask_mod_signature): mask_mod based on the logical block index.
        """
        if mask_mod is None:
            mask_mod = noop_mask

        def new_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ):
            # Translate the physical kv index back to its logical position and
            # delegate to the original mask_mod; positions on unallocated
            # pages (logical_block_idx == -1) are masked out.
            physical_kv_block = physical_kv_idx // self.page_size
            physical_kv_offset = physical_kv_idx % self.page_size
            logical_block_idx = self.physical_to_logical[b, physical_kv_block]
            logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
            return torch.where(
                logical_block_idx >= 0, mask_mod(b, h, q_idx, logical_kv_idx), False
            )

        return new_mask_mod

    def get_score_mod(self, score_mod: Optional[_score_mod_signature]) -> _score_mod_signature:
        """
        Converts a score_mod written for logical kv indices into one that takes
        physical kv indices, using the physical-to-logical page mapping.

        Args:
            score_mod (_score_mod_signature): score_mod based on the logical block index.
        """
        if score_mod is None:
            score_mod = _identity

        def new_score_mod(
            score: torch.Tensor,
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ):
            # Same translation as in get_mask_mod; scores on unallocated
            # pages are set to -inf so they vanish under softmax.
            physical_kv_block = physical_kv_idx // self.page_size
            physical_kv_offset = physical_kv_idx % self.page_size
            logical_block_idx = self.physical_to_logical[b, physical_kv_block]
            logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
            return torch.where(
                logical_block_idx >= 0,
                score_mod(score, b, h, q_idx, logical_kv_idx),
                float("-inf"),
            )

        return new_score_mod
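

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (not part of the class above): wires a
# PagedAttention pool into flex_attention for a single prefill step.
# All sizes below (pages, heads, head dim, sequence length) are illustrative
# assumptions, not values mandated by the API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    device = "cuda" if torch.cuda.is_available() else "cpu"
    n_pages, page_size, max_batch, H, D, S = 8, 64, 2, 2, 32, 128

    pool = PagedAttention(n_pages, page_size, max_batch, device=device)

    # Reserve capacity for one sequence of S tokens (S // page_size pages).
    batch_idx = torch.tensor([0], device=device)
    pool.reserve(batch_idx, torch.tensor([S], device=device))

    # Scatter the prefill keys/values into the paged caches.
    k_cache = torch.zeros(1, H, n_pages * page_size, D, device=device)
    v_cache = torch.zeros_like(k_cache)
    input_pos = torch.arange(S, device=device).unsqueeze(0)  # [1, S]
    k_val = torch.randn(1, H, S, D, device=device)
    v_val = torch.randn(1, H, S, D, device=device)
    pool.assign(batch_idx, input_pos, k_val, v_val, k_cache, v_cache)

    # Build a logical causal mask and translate it to physical pages.
    def causal(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    logical_mask = create_block_mask(
        causal, 1, 1, S, S, device=device, BLOCK_SIZE=page_size
    )
    physical_mask = pool.convert_logical_block_mask(logical_mask, batch_idx)

    # flex_attention reads the paged caches through the converted mask.
    q = torch.randn(1, H, S, D, device=device)
    out = flex_attention(q, k_cache, v_cache, block_mask=physical_mask)
    print(out.shape)  # torch.Size([1, 2, 128, 32])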