
Commit 41ef5bf

adapt paged attention implementation from pytorch (#105)
1 parent 232f8c6 commit 41ef5bf

File tree

3 files changed: +315 −2 lines changed

+313 −0
@@ -0,0 +1,313 @@
from typing import Optional, Union

import torch
from torch.nn.attention.flex_attention import (
    _identity,
    _mask_mod_signature,
    _score_mod_signature,
    BlockMask,
    noop_mask,
)


def _cdiv(x: Union[int, float, torch.Tensor], multiple: Union[int, float, torch.Tensor]):
    return (x + multiple - 1) // multiple


class PagedAttention:
    """
    PagedAttention supports flex attention inference with a large batch size.
    With PagedAttention, a batch of key/value tensors with varying kv length
    is split into tensor blocks of fixed length and cached in a compact way.
    Thus we can avoid redundant memory consumption due to varying kv length and
    support a larger batch size.
    """

    def __init__(
        self,
        n_pages: int,
        page_size: int,
        max_batch_size: int,
        device: str = "cuda",
    ):
        # number of pages
        self.n_pages = n_pages

        # number of tokens per page
        self.page_size = page_size

        # page table: [batch, logical_block_idx] -> physical_page_idx
        self.page_table = -torch.ones(
            (max_batch_size, self.n_pages), dtype=torch.int64, device=device
        )

        # capacity: batch_idx -> allocated sequence length
        self.capacity = torch.zeros(max_batch_size, dtype=torch.int64, device=device)

        # indices of empty pages that are available for allocation
        self.empty_pages = list(range(n_pages - 1, -1, -1))

        # mapping from physical page index to logical page index
        self.physical_to_logical = -torch.ones(
            (max_batch_size, n_pages), dtype=torch.int64, device=device
        )

    def reserve(self, batch_idx: torch.Tensor, seq_len: torch.Tensor) -> None:
        """
        Requests the capacity of a given batch to be at least enough to
        hold `seq_len` elements.

        Args:
            batch_idx (Tensor): batch index to be reserved; shape :math:`(1)`.
            seq_len (Tensor): minimum capacity for the given batch; shape :math:`(1)`.
        """

        if seq_len <= self.capacity[batch_idx]:
            return

        num_pages_to_allocate = _cdiv(seq_len - self.capacity[batch_idx], self.page_size)

        assert len(self.empty_pages) >= num_pages_to_allocate, (
            f"requested {num_pages_to_allocate.item()} pages "
            f"but there are only {len(self.empty_pages)} empty pages"
        )

        start_page_idx = self.capacity[batch_idx] // self.page_size
        end_page_idx = start_page_idx + num_pages_to_allocate

        # find empty physical pages
        allocated_pages = torch.tensor(
            self.empty_pages[-num_pages_to_allocate:],
            device=num_pages_to_allocate.device,
        )
        self.empty_pages = self.empty_pages[:-num_pages_to_allocate]

        # update page table
        self.page_table[
            batch_idx,
            start_page_idx:end_page_idx,
        ] = allocated_pages

        # update metadata
        self.physical_to_logical[batch_idx, allocated_pages] = torch.arange(
            start_page_idx.item(),
            end_page_idx.item(),
            device=num_pages_to_allocate.device,
        )
        self.capacity[batch_idx] += num_pages_to_allocate * self.page_size

    def erase(self, batch_idx: torch.Tensor) -> None:
        """
        Removes a single batch from paged attention.

        Args:
            batch_idx (Tensor): batch index to be removed; shape :math:`(1)`.
        """

        # find allocated pages
        allocated_page_idx = self.page_table[batch_idx] != -1
        allocated_pages = self.page_table[batch_idx][allocated_page_idx]

        # clean metadata
        self.capacity[batch_idx] = 0
        self.empty_pages += allocated_pages.tolist()
        self.physical_to_logical[batch_idx, allocated_pages] = -1
        self.page_table[batch_idx] = -1

    def assign(
        self,
        batch_idx: torch.Tensor,
        input_pos: torch.Tensor,
        k_val: torch.Tensor,
        v_val: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
    ) -> None:
        """
        Assigns new contents `k_val`/`v_val` to the storage `k_cache`/`v_cache`
        at the location `batch_idx` and `input_pos`.

        Args:
            batch_idx (Tensor): batch index; shape :math:`(B)`.
            input_pos (Tensor): input positions to be assigned for the given batch; shape :math:`(B, S)`.
            k_val (Tensor): key values to be assigned; shape :math:`(B, H, S, K_D)`.
            v_val (Tensor): values to be assigned; shape :math:`(B, H, S, V_D)`.
            k_cache (Tensor): the cache to store the keys; shape :math:`(1, H, MAX_S, K_D)`.
            v_cache (Tensor): the cache to store the values; shape :math:`(1, H, MAX_S, V_D)`.
        """
        if k_val.requires_grad:
            raise RuntimeError("val must not require gradient")

        B, H, S, K_D = k_val.shape
        V_D = v_val.shape[3]
        if B != batch_idx.shape[0]:
            raise RuntimeError(
                f"Expect val and batch_idx to have the same batch size "
                f"but got B={B} and B={batch_idx.shape[0]}."
            )
        if H != k_cache.shape[1]:
            raise RuntimeError(
                f"Expect val and cache to have the same number of heads "
                f"but got H={H} and H={k_cache.shape[1]}."
            )
        if S != input_pos.shape[1]:
            raise RuntimeError(
                f"Expect val and input_pos to have the same length "
                f"but got S={S} and S={input_pos.shape[1]}."
            )
        if K_D != k_cache.shape[3]:
            raise RuntimeError(
                f"Expect k_val and k_cache to have the same hidden dim "
                f"but got D={K_D} and D={k_cache.shape[3]}."
            )
        if V_D != v_cache.shape[3]:
            raise RuntimeError(
                f"Expect v_val and v_cache to have the same hidden dim "
                f"but got D={V_D} and D={v_cache.shape[3]}."
            )

        # find address
        logical_block_idx = input_pos // self.page_size  # [B, S]
        logical_block_offset = input_pos % self.page_size  # [B, S]
        physical_block_idx = torch.gather(
            self.page_table[batch_idx], 1, logical_block_idx.to(torch.int64)
        ).to(torch.int32)  # [B, S]

        addr = (physical_block_idx * self.page_size + logical_block_offset).view(-1)  # [B*S]

        k_val = k_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, K_D)
        v_val = v_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, V_D)

        k_cache[:, :, addr, :] = k_val
        v_cache[:, :, addr, :] = v_val

    def convert_logical_block_mask(
        self,
        block_mask: BlockMask,
        batch_idx: Optional[torch.Tensor] = None,
    ) -> BlockMask:
        """
        Converts a logical block mask by mapping its logical kv indices to the corresponding
        physical kv indices.

        Args:
            block_mask (BlockMask): logical block mask;
                kv_indices shape :math:`(B, H, ROWS, MAX_BLOCKS_IN_COL)`.
            batch_idx (Tensor): batch index corresponding to the block_mask
                batch dimension. This provides flexibility to convert a
                block mask with smaller batch size than the page table;
                shape :math:`(B)`.
        """
        B, H, ROWS, MAX_BLOCKS_IN_COL = block_mask.kv_indices.shape

        if block_mask.BLOCK_SIZE[1] != self.page_size:
            raise RuntimeError(
                f"Expect block_mask to have the same column block size as page_size "
                f"but got size={block_mask.BLOCK_SIZE[1]} and size={self.page_size}"
            )

        # Increase the num columns of the converted block mask from the logical block
        # mask's num columns to n_pages, since a) the converted block mask may have
        # larger index values; and b) `_ordered_to_dense` realizes a dense tensor
        # with these converted indices. There would be an IndexError if using the
        # logical block mask's num columns.

        device = block_mask.kv_num_blocks.device

        if batch_idx is None:
            batch_idx = torch.arange(B, device=device)
        page_table = self.page_table[batch_idx]

        new_kv_num_blocks = block_mask.kv_num_blocks.clone()

        new_kv_indices = torch.zeros((B, H, ROWS, self.n_pages), dtype=torch.int32, device=device)
        new_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = (
            torch.gather(page_table, 1, block_mask.kv_indices.view(B, -1).to(torch.int64))
            .view(block_mask.kv_indices.shape)
            .to(torch.int32)
        )

        new_full_kv_indices, new_full_kv_num_blocks = None, None
        if block_mask.full_kv_num_blocks is not None:
            assert block_mask.full_kv_indices is not None
            new_full_kv_num_blocks = block_mask.full_kv_num_blocks.clone()
            new_full_kv_indices = torch.zeros(
                (B, H, ROWS, self.n_pages), dtype=torch.int32, device=device
            )
            new_full_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = (
                torch.gather(
                    page_table,
                    1,
                    block_mask.full_kv_indices.view(B, -1).to(torch.int64),
                )
                .view(block_mask.full_kv_indices.shape)
                .to(torch.int32)
            )

        new_mask_mod = self.get_mask_mod(block_mask.mask_mod)

        seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size)
        return BlockMask.from_kv_blocks(
            new_kv_num_blocks,
            new_kv_indices,
            new_full_kv_num_blocks,
            new_full_kv_indices,
            block_mask.BLOCK_SIZE,
            new_mask_mod,
            seq_lengths=seq_lengths,
        )

    def get_mask_mod(self, mask_mod: Optional[_mask_mod_signature]) -> _mask_mod_signature:
        """
        Converts a mask_mod based on the mapping from the physical block index to the
        logical block index.

        Args:
            mask_mod (_mask_mod_signature): mask_mod based on the logical block index.
        """
        if mask_mod is None:
            mask_mod = noop_mask

        def new_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ):
            physical_kv_block = physical_kv_idx // self.page_size
            physical_kv_offset = physical_kv_idx % self.page_size
            logical_block_idx = self.physical_to_logical[b, physical_kv_block]
            logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
            return torch.where(
                logical_block_idx >= 0, mask_mod(b, h, q_idx, logical_kv_idx), False
            )

        return new_mask_mod

    def get_score_mod(self, score_mod: Optional[_score_mod_signature]) -> _score_mod_signature:
        """
        Converts a score_mod based on the mapping from the physical block index to the
        logical block index.

        Args:
            score_mod (_score_mod_signature): score_mod based on the logical block index.
        """
        if score_mod is None:
            score_mod = _identity

        def new_score_mod(
            score: torch.Tensor,
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ):
            physical_kv_block = physical_kv_idx // self.page_size
            physical_kv_offset = physical_kv_idx % self.page_size
            logical_block_idx = self.physical_to_logical[b, physical_kv_block]
            logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
            return torch.where(
                logical_block_idx >= 0,
                score_mod(score, b, h, q_idx, logical_kv_idx),
                float("-inf"),
            )

        return new_score_mod
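
For orientation, a minimal usage sketch of the class above (not part of this commit): it assumes a CUDA device and a recent PyTorch where flex_attention accepts a batch-1 KV cache broadcast against a batch-B query; B, H, D, n_pages, page_size, and S are illustrative values, and the repo's actual integration lives in throughput.py and model.py.

# Illustrative driver for PagedAttention: reserve pages, write keys/values through
# the page table, remap a logical block mask to physical pages, then attend.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from paged_attention import PagedAttention

device = "cuda"
B, H, D = 2, 4, 64            # batch size, heads, head dim (illustrative)
n_pages, page_size = 64, 128  # physical KV cache layout
S = 256                       # tokens written per request in this sketch

paged = PagedAttention(n_pages, page_size, max_batch_size=B, device=device)

# The physical KV cache has batch size 1: all requests share one paged buffer.
k_cache = torch.zeros(1, H, n_pages * page_size, D, device=device)
v_cache = torch.zeros(1, H, n_pages * page_size, D, device=device)

# Grow each request's capacity, then scatter its keys/values through the page table.
for b in range(B):
    paged.reserve(torch.tensor([b], device=device), torch.tensor(S, device=device))

batch_idx = torch.arange(B, device=device)
input_pos = torch.arange(S, device=device).expand(B, S)
k_val = torch.randn(B, H, S, D, device=device)
v_val = torch.randn(B, H, S, D, device=device)
paged.assign(batch_idx, input_pos, k_val, v_val, k_cache, v_cache)

# Build a logical (causal) block mask whose KV block size equals page_size,
# then remap its logical KV blocks onto physical pages.
def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

logical_mask = create_block_mask(causal, B, 1, S, S, device=device, BLOCK_SIZE=page_size)
physical_mask = paged.convert_logical_block_mask(logical_mask)

q = torch.randn(B, H, S, D, device=device)
out = torch.compile(flex_attention)(q, k_cache, v_cache, block_mask=physical_mask)
print(out.shape)  # (B, H, S, D)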

attn_gym/paged_attention/throughput.py

+1 −1
@@ -23,7 +23,6 @@
 """

 import torch
-from torch.nn.attention.experimental._paged_attention import PagedAttention
 from torch.nn.attention.flex_attention import (
     _identity,
     BlockMask,
@@ -35,6 +34,7 @@
 from typing import Tuple
 from utils import gen_offset, slice_block_mask
 from model import PagedAttentionLayer
+from paged_attention import PagedAttention

 create_block_mask = torch.compile(create_block_mask)

attn_gym/paged_attention/utils.py

+1 −1
@@ -1,9 +1,9 @@
 import torch
-from torch.nn.attention.experimental._paged_attention import PagedAttention
 from torch.nn.attention.flex_attention import (
     _identity,
     BlockMask,
 )
+from paged_attention import PagedAttention


 def batch_reserve(paged_attention: PagedAttention, target_seq_len: torch.Tensor):
