Commit 0a88b24

alexm-redhat authored and Alvant committed
[Bugfix] multi-step + flashinfer: ensure cuda graph compatible (vllm-project#8427)
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent c6c8b91 commit 0a88b24

File tree

1 file changed: +11 -1 lines changed


vllm/attention/backends/flashinfer.py

Lines changed: 11 additions & 1 deletion
@@ -597,9 +597,19 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
             input_block_tables = self.runner.graph_block_tables[:batch_size]
+            max_blocks = input_block_tables.shape[1]
             for i, block_table in enumerate(self.block_tables):
                 if block_table:
-                    input_block_tables[i, :len(block_table)] = block_table
+                    num_blocks = len(block_table)
+                    if num_blocks <= max_blocks:
+                        input_block_tables[i, :num_blocks] = block_table
+                    else:
+                        # It may be possible to have more blocks allocated due
+                        # to lookahead slots of multi-step, however, they are
+                        # not used anyway, so can be safely ignored.
+                        input_block_tables[
+                            i, :max_blocks] = block_table[:max_blocks]
+
             block_tables = torch.from_numpy(input_block_tables).to(
                 device, non_blocking=True)
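
For context, the patch clamps each per-sequence block table to the width of the pre-allocated CUDA graph block table, so that extra blocks allocated for multi-step lookahead slots (which decoding never reads) cannot overflow the captured buffer. Below is a minimal standalone sketch of that clamping loop; the sizes and variable names are illustrative only and do not reflect vLLM's actual runner state.

    # Minimal sketch (not vLLM code) of the clamping behavior the patch adds.
    # All sizes here are made up for illustration.
    import numpy as np

    max_batch_size, max_blocks = 4, 8          # shape of the captured graph's table
    graph_block_tables = np.zeros((max_batch_size, max_blocks), dtype=np.int32)

    # The second sequence has more blocks than the graph table can hold,
    # e.g. because lookahead slots of multi-step allocated extra blocks.
    block_tables = [[1, 2, 3], list(range(100, 100 + max_blocks + 2)), [], [7]]

    input_block_tables = graph_block_tables[:len(block_tables)]
    for i, block_table in enumerate(block_tables):
        if block_table:
            num_blocks = len(block_table)
            if num_blocks <= max_blocks:
                input_block_tables[i, :num_blocks] = block_table
            else:
                # Blocks beyond the captured width are unused, so drop them.
                input_block_tables[i, :max_blocks] = block_table[:max_blocks]

    print(input_block_tables)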
