diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py
index cd06c27175c..7046751203d 100644
--- a/tests/kernels/test_flash_attn.py
+++ b/tests/kernels/test_flash_attn.py
@@ -4,11 +4,11 @@
 import torch
 from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 
-NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
-HEAD_SIZES = [128, 256]
-BLOCK_SIZES = [16, 32]
+NUM_HEADS = [(12, 12)]
+HEAD_SIZES = [64]
+BLOCK_SIZES = [16]
 DTYPES = [torch.float16, torch.bfloat16]
-NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
+NUM_BLOCKS = 99682  # Large enough to test overflow in index calculation.
 
 
 def ref_paged_attn(
@@ -123,23 +123,8 @@ def test_flash_attn_with_paged_kv(
         f"{torch.max(torch.abs(output - ref_output))}"
 
 
-@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("block_size", BLOCK_SIZES)
-@pytest.mark.parametrize("sliding_window", [None])
-@pytest.mark.parametrize("dtype", DTYPES)
-@torch.inference_mode
-def test_varlen_with_paged_kv(
-    seq_lens: List[Tuple[int, int]],
-    num_heads: Tuple[int, int],
-    head_size: int,
-    sliding_window: Optional[int],
-    dtype: torch.dtype,
-    block_size: int,
-) -> None:
-    torch.set_default_device("cuda")
-    torch.cuda.manual_seed_all(0)
+def prepare_varlen_with_paged_kv_input(seq_lens, num_heads, head_size,
+                                       sliding_window, dtype, block_size):
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
     kv_lens = [x[1] for x in seq_lens]
@@ -174,12 +159,43 @@ def test_varlen_with_paged_kv(
                                 dtype=torch.int32).cumsum(dim=0,
                                                           dtype=torch.int32)
 
-    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
+    max_num_blocks_per_seq = 128
     block_tables = torch.randint(0,
                                  NUM_BLOCKS,
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)
 
+    output = (query_lens, kv_lens, query, key_cache, value_cache,
+              cu_query_lens, cu_kv_lens, max_query_len, max_kv_len, scale,
+              window_size, block_tables)
+    return output
+
+
+@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("sliding_window", [None])
+@pytest.mark.parametrize("dtype", DTYPES)
+@torch.inference_mode
+def test_varlen_with_paged_kv(
+    seq_lens: List[Tuple[int, int]],
+    num_heads: Tuple[int, int],
+    head_size: int,
+    sliding_window: Optional[int],
+    dtype: torch.dtype,
+    block_size: int,
+) -> None:
+    torch.set_default_device("cuda")
+    torch.cuda.manual_seed_all(0)
+
+    query_lens, kv_lens, query, key_cache, value_cache, \
+        cu_query_lens, cu_kv_lens, max_query_len, \
+        max_kv_len, scale, window_size, block_tables \
+        = prepare_varlen_with_paged_kv_input(seq_lens, num_heads,
+                                             head_size, sliding_window,
+                                             dtype, block_size)
+
     output = flash_attn_varlen_func(
         q=query,
         k=key_cache,
         v=value_cache,
@@ -206,3 +222,90 @@ def test_varlen_with_paged_kv(
     )
     assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
+
+
+@pytest.mark.parametrize("seq_lens", [[(1, 912)]])
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("sliding_window", [None])
+@pytest.mark.parametrize("dtype", DTYPES)
+@torch.inference_mode
+def test_varlen_with_paged_kv_cudagraph(seq_lens, num_heads, head_size,
+                                        sliding_window, dtype, block_size):
+    torch.set_default_device("cuda")
+    torch.cuda.manual_seed_all(0)
+    graph_seq_lens = [(1, 2)]
+    query_lens, kv_lens, g_query, g_key_cache, g_value_cache, \
+        g_cu_query_lens, g_cu_kv_lens, max_query_len, \
+        max_kv_len, scale, window_size, g_block_tables \
+        = prepare_varlen_with_paged_kv_input(graph_seq_lens, num_heads,
+                                             head_size, sliding_window,
+                                             dtype, block_size)
+
+    # Warmup
+    s = torch.cuda.Stream()
+    s.wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(s):
+        for _ in range(3):
+            flash_attn_varlen_func(
+                q=g_query,
+                k=g_key_cache,
+                v=g_value_cache,
+                cu_seqlens_q=g_cu_query_lens,
+                cu_seqlens_k=g_cu_kv_lens,
+                max_seqlen_q=max_query_len,
+                max_seqlen_k=max_kv_len,
+                softmax_scale=scale,
+                causal=True,
+                window_size=window_size,
+                block_table=g_block_tables,
+            )
+    torch.cuda.current_stream().wait_stream(s)
+
+    # Capture
+    graph = torch.cuda.CUDAGraph()
+    with torch.cuda.graph(graph):
+        output = flash_attn_varlen_func(
+            q=g_query,
+            k=g_key_cache,
+            v=g_value_cache,
+            cu_seqlens_q=g_cu_query_lens,
+            cu_seqlens_k=g_cu_kv_lens,
+            max_seqlen_q=max_query_len,
+            max_seqlen_k=max_kv_len,
+            softmax_scale=scale,
+            causal=True,
+            window_size=window_size,
+            block_table=g_block_tables,
+        )
+    torch.cuda.synchronize()
+
+    # Replay
+    query_lens, kv_lens, query, key_cache, value_cache, \
+        cu_query_lens, cu_kv_lens, max_query_len, \
+        max_kv_len, scale, window_size, block_tables \
+        = prepare_varlen_with_paged_kv_input(seq_lens, num_heads,
+                                             head_size, sliding_window,
+                                             dtype, block_size)
+    g_query.copy_(query)
+    g_key_cache.copy_(key_cache)
+    g_value_cache.copy_(value_cache)
+    g_cu_query_lens.copy_(cu_query_lens)
+    g_cu_kv_lens.copy_(cu_kv_lens)
+    g_block_tables.copy_(block_tables)
+
+    graph.replay()
+
+    ref_output = ref_paged_attn(
+        query=query,
+        key_cache=key_cache,
+        value_cache=value_cache,
+        query_lens=query_lens,
+        kv_lens=kv_lens,
+        block_tables=block_tables,
+        scale=scale,
+        sliding_window=sliding_window,
+    )
+    assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
+        f"{torch.max(torch.abs(output - ref_output))}"
diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index 4a0e2b41849..53fbe2d5da3 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -200,6 +200,11 @@ def test_prepare_decode_cuda_graph(batch_size):
         # decode has only 1 token for query.
         start_idx += 1
         start_loc.append(start_idx)
+    # start_loc is padded to expected_bs + 1
+    last_loc = start_loc[-1] + 1
+    for _ in range(expected_bs - (len(start_loc) - 1)):
+        start_loc.append(last_loc)
+        last_loc += 1
     assert torch.allclose(
         attn_metadata.query_start_loc,
         torch.tensor(start_loc, dtype=torch.int32, device=device))
@@ -209,6 +214,10 @@ def test_prepare_decode_cuda_graph(batch_size):
     for seq_len in seq_lens:
         start_idx += seq_len
         seq_start_loc.append(start_idx)
+    last_loc = seq_start_loc[-1] + 1
+    for _ in range(expected_bs - (len(seq_start_loc) - 1)):
+        seq_start_loc.append(last_loc)
+        last_loc += 1
     assert torch.allclose(
         attn_metadata.seq_start_loc,
         torch.tensor(seq_start_loc, dtype=torch.int32, device=device))
@@ -375,9 +384,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     attn_metadata = model_runner._prepare_model_input_tensors(
         seq_group_metadata_list).attn_metadata
 
-    for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
-                                          vars(prefill_meta_actual)):
-        assert attr_expected[1] == attr_actual[1]
-    for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata),
-                                          vars(decode_meta_actual)):
-        assert attr_expected[1] == attr_actual[1]
+    if attn_metadata.prefill_metadata:
+        for attr_expected, attr_actual in zip(
+                vars(attn_metadata.prefill_metadata),
+                vars(prefill_meta_actual)):
+            assert attr_expected[1] == attr_actual[1]
+    for attr_expected, attr_actual in zip(
+            vars(attn_metadata.decode_metadata), vars(decode_meta_actual)):
+        assert attr_expected[1] == attr_actual[1]
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 7d7aff9dc3c..0b4e06fec29 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -3,7 +3,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
 import torch
-from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from vllm_flash_attn import flash_attn_varlen_func
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -14,6 +14,7 @@
                                             compute_slot_mapping_start_idx,
                                             is_block_tables_empty)
 from vllm.utils import make_tensor_with_pad
+import debug_print
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUBuilder
@@ -86,10 +87,8 @@ class FlashAttentionMetadata(AttentionMetadata):
     updated from `CUDAGraphRunner.forward` API.
     """
     # (batch_size,). The sequence length per sequence. Sequence length means
-    # the computed tokens + new tokens None if it is a decoding.
-    seq_lens: Optional[List[int]]
-    # seq_lens stored as a tensor.
-    seq_lens_tensor: Optional[torch.Tensor]
+    # the computed tokens + new tokens.
+    seq_lens: List[int]
 
     # NOTE(sang): Definition of context_len, query_len, and seq_len.
     # |---------- N-1 iteration --------|
@@ -101,12 +100,6 @@ class FlashAttentionMetadata(AttentionMetadata):
 
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int]
-    # Maximum sequence length among prefill batch. 0 if there are decoding
-    # requests only.
-    max_prefill_seq_len: int
-    # Maximum sequence length among decode batch. 0 if there are prefill
-    # requests only.
-    max_decode_seq_len: int
     # (batch_size + 1,). The cumulative subquery lengths of the sequences in
     # the batch, used to index into subquery. E.g., if the subquery length
     # is [4, 6], it is [0, 4, 10].
@@ -115,9 +108,6 @@
     # the batch, used to index into sequence. E.g., if the sequence length is
     # [4, 6], it is [0, 4, 10].
     seq_start_loc: Optional[torch.Tensor]
-    # (batch_size,) A tensor of context lengths (tokens that are computed
-    # so far).
-    context_lens_tensor: Optional[torch.Tensor]
 
     # (batch_size, max_blocks_per_seq).
     # Block addresses per sequence. (Seq id -> list of physical block)
@@ -132,69 +122,20 @@ class FlashAttentionMetadata(AttentionMetadata):
     # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
     use_cuda_graph: bool
 
-    _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
-    _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
+    # Fields that are not used in flash attention backend,
+    # but used in other backends
+    context_lens_tensor: Optional[torch.Tensor] = None
+    seq_lens_tensor: Optional[torch.Tensor] = None
+    max_prefill_seq_len: Optional[int] = None
+    max_decode_seq_len: Optional[int] = None
 
     @property
     def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
-        if self.num_prefills == 0:
-            return None
-
-        if self._cached_prefill_metadata is not None:
-            return self._cached_prefill_metadata
-
-        assert self.seq_lens is not None
-        assert self.seq_lens_tensor is not None
-        assert self.query_start_loc is not None
-        assert self.context_lens_tensor is not None
-        assert self.block_tables is not None
-        assert self.seq_start_loc is not None
-
-        self._cached_prefill_metadata = FlashAttentionMetadata(
-            num_prefills=self.num_prefills,
-            num_prefill_tokens=self.num_prefill_tokens,
-            num_decode_tokens=0,
-            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
-            seq_lens=self.seq_lens[:self.num_prefills],
-            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
-            max_query_len=self.max_query_len,
-            max_prefill_seq_len=self.max_prefill_seq_len,
-            max_decode_seq_len=0,
-            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
-            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
-            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
-            block_tables=self.block_tables[:self.num_prefills],
-            use_cuda_graph=False,
-        )
-        return self._cached_prefill_metadata
+        return None
 
     @property
     def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
-        if self.num_decode_tokens == 0:
-            return None
-
-        if self._cached_decode_metadata is not None:
-            return self._cached_decode_metadata
-        assert self.block_tables is not None
-        assert self.seq_lens_tensor is not None
-
-        self._cached_decode_metadata = FlashAttentionMetadata(
-            num_prefills=0,
-            num_prefill_tokens=0,
-            num_decode_tokens=self.num_decode_tokens,
-            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
-            seq_lens=None,
-            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
-            max_query_len=None,
-            max_prefill_seq_len=0,
-            max_decode_seq_len=self.max_decode_seq_len,
-            query_start_loc=None,
-            seq_start_loc=None,
-            context_lens_tensor=None,
-            block_tables=self.block_tables[self.num_prefills:],
-            use_cuda_graph=self.use_cuda_graph,
-        )
-        return self._cached_decode_metadata
+        return None
 
 
 class FlashAttentionMetadataBuilder(
@@ -314,6 +255,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
             block_tables = torch.tensor(input_block_tables, device=device)
+            # print("block_tables", block_tables.shape, block_tables.data_ptr())
         else:
             block_tables = make_tensor_with_pad(
                 self.block_tables,
@@ -323,18 +265,15 @@ ...
             )
         assert max_query_len > 0, ("query_lens: {}".format(query_lens))
-        context_lens_tensor = torch.tensor(self.context_lens,
-                                           dtype=torch.int,
-                                           device=device)
+        context_lens_tensor = None
+        query_start_loc = torch.tensor([0] + query_lens,
+                                       dtype=torch.int32,
+                                       device=device).cumsum(dim=0,
+                                                             dtype=torch.int32)
         seq_lens_tensor = torch.tensor(seq_lens,
                                        dtype=torch.int,
                                        device=device)
-        query_lens_tensor = torch.tensor(query_lens,
-                                         dtype=torch.long,
-                                         device=device)
-        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=device)
+        # print("seq_lens_tensor", seq_lens_tensor)
         seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                     dtype=torch.int32,
                                     device=device)
@@ -342,10 +281,12 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                      dim=0,
                      dtype=seq_start_loc.dtype,
                      out=seq_start_loc[1:])
-        torch.cumsum(query_lens_tensor,
-                     dim=0,
-                     dtype=query_start_loc.dtype,
-                     out=query_start_loc[1:])
+        # print("seq_lens_tensor", seq_lens_tensor)
+        # print("query_start_loc", query_start_loc)
+        # print("seq_start_loc", seq_start_loc)
+        # print("slot_mapping", self.slot_mapping)
+        # print("max_seq_lens", max(seq_lens))
+        # print("max_query_len", max_query_len)
 
         slot_mapping_tensor = torch.tensor(self.slot_mapping,
                                            dtype=torch.long,
@@ -475,7 +416,6 @@ def forward(
         if kv_cache is not None:
             key_cache = kv_cache[0]
             value_cache = kv_cache[1]
-
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
@@ -490,74 +430,65 @@ def forward(
                 v_scale,
             )
 
-        num_prefill_tokens = attn_metadata.num_prefill_tokens
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
-        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
-
-        output = torch.empty_like(query)
-        # Query for decode. KV is not needed because it is already cached.
-        decode_query = query[num_prefill_tokens:]
-        # QKV for prefill.
-        query = query[:num_prefill_tokens]
-        key = key[:num_prefill_tokens]
-        value = value[:num_prefill_tokens]
-
-        assert query.shape[0] == num_prefill_tokens
-        assert decode_query.shape[0] == num_decode_tokens
-
-        if prefill_meta := attn_metadata.prefill_metadata:
-            # Prompt run.
-            if (kv_cache is None or prefill_meta.block_tables is None
-                    or prefill_meta.block_tables.numel() == 0):
-                # normal attention
-                # When block_tables are not filled, it means q and k are the
-                # prompt, and they have the same length.
-                out = flash_attn_varlen_func(
-                    q=query,
-                    k=key,
-                    v=value,
-                    cu_seqlens_q=prefill_meta.seq_start_loc,
-                    cu_seqlens_k=prefill_meta.seq_start_loc,
-                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
-                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
-                    softmax_scale=self.scale,
-                    causal=True,
-                    window_size=self.sliding_window,
-                    alibi_slopes=self.alibi_slopes,
-                )
-                assert output[:num_prefill_tokens].shape == out.shape
-                output[:num_prefill_tokens] = out
-            else:
-                # prefix-enabled attention
-                assert prefill_meta.seq_lens is not None
-                max_seq_len = max(prefill_meta.seq_lens)
-                output[:num_prefill_tokens] = flash_attn_varlen_func(
-                    q=query,
-                    k=key_cache,
-                    v=value_cache,
-                    cu_seqlens_q=prefill_meta.query_start_loc,
-                    max_seqlen_q=prefill_meta.max_query_len,
-                    cu_seqlens_k=prefill_meta.seq_start_loc,
-                    max_seqlen_k=max_seq_len,
-                    softmax_scale=self.scale,
-                    causal=True,
-                    alibi_slopes=self.alibi_slopes,
-                    block_table=prefill_meta.block_tables,
-                )
-
-        if decode_meta := attn_metadata.decode_metadata:
-            # Decoding run.
-            output[num_prefill_tokens:] = flash_attn_with_kvcache(
-                decode_query.unsqueeze(1),
-                key_cache,
-                value_cache,
-                block_table=decode_meta.block_tables,
-                cache_seqlens=decode_meta.seq_lens_tensor,
-                softmax_scale=self.scale,
-                causal=True,
-                alibi_slopes=self.alibi_slopes,
-            ).squeeze(1)
+        # This is used during the profiling or prefill phase.
+        if (kv_cache is None or attn_metadata.block_tables is None
+                or attn_metadata.block_tables.numel() == 0):
+            # print("1-----------------")
+            k = key
+            v = value
+            block_tables = None
+        else:
+            # print("2-----------------")
+            k = kv_cache[0]
+            v = kv_cache[1]
+            block_tables = attn_metadata.block_tables
+
+        max_seq_len = max(attn_metadata.seq_lens)
+        max_k = torch.max(k).reshape(1)
+        max_v = torch.max(v).reshape(1)
+        if attn_metadata.use_cuda_graph:
+            pass
+            # # block_tables.zero_()
+            # debug_print.print_tensor(query)
+            # debug_print.print_tensor(max_k)
+            # debug_print.print_tensor(max_v)
+            # debug_print.print_tensor(attn_metadata.query_start_loc)
+            # debug_print.print_tensor(attn_metadata.seq_start_loc)
+            # debug_print.print_tensor(attn_metadata.block_tables)
+            # debug_print.print_tensor(attn_metadata.seq_lens_tensor)
+        else:
+            pass
+            # print("query", query.shape, query[0, 0])
+            # print("max_k", k.shape, max_k)
+            # print("max_v", v.shape, max_v)
+            # print("query_start_loc", attn_metadata.query_start_loc)
+            # print("seq_start_loc", attn_metadata.seq_start_loc)
+            # print("block_tables", block_tables)
+            # print("seq_lens", attn_metadata.seq_lens)
+            # print("max_query_len", attn_metadata.max_query_len)
+            # print("max_seqlen_k", max_seq_len)
+            # print("scale", self.scale)
+            # print("sliding_window", self.sliding_window)
+            # print("alibi_slopes", self.alibi_slopes)
+
+        output = flash_attn_varlen_func(
+            q=query,
+            k=k,
+            v=v,
+            cu_seqlens_q=attn_metadata.query_start_loc,
+            cu_seqlens_k=attn_metadata.seq_start_loc,
+            max_seqlen_q=attn_metadata.max_query_len,
+            max_seqlen_k=max_seq_len,
+            softmax_scale=self.scale,
+            causal=True,
+            window_size=self.sliding_window,
+            alibi_slopes=self.alibi_slopes,
+            block_table=block_tables)
+        # if attn_metadata.use_cuda_graph:
+        #     pass
+        #     # debug_print.print_tensor(output[0,0])
+        # else:
+        #     print(output[0,0])
 
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 86d26b4a84c..2eaed209b00 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -557,6 +557,7 @@ def build(self) -> ModelInputForGPU:
 
             # Sequence and query lengths.
             seq_lens.extend([1] * cuda_graph_pad_size)
+            query_lens.extend([1] * cuda_graph_pad_size)
 
             # Attention metadata.
         attn_metadata = self.attn_metadata_builder.build(
@@ -1033,8 +1034,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
         input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
         slot_mapping.fill_(_PAD_SLOT_ID)
-        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
         block_tables = torch.from_numpy(self.graph_block_tables).cuda()
+        intermediate_inputs = None
 
         if not get_pp_group().is_first_rank:
             intermediate_inputs = self.model.make_empty_intermediate_tensors(
@@ -1071,6 +1072,19 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
            last_page_len_buffer = torch.empty(max_batch_size,
                                               dtype=torch.int32,
                                               device=self.device)
+        else:
+            query_start_loc = torch.arange(0,
+                                           max_batch_size + 2,
+                                           dtype=torch.int32,
+                                           device=self.device)
+            seq_start_loc = torch.arange(0,
+                                         max_batch_size + 2,
+                                         dtype=torch.int32,
+                                         device=self.device)
+            seq_lens = [1] * max_batch_size
+            seq_lens_tensor = torch.ones(max_batch_size,
+                                         dtype=torch.int32,
+                                         device=self.device)
 
         with graph_capture() as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
@@ -1139,13 +1153,13 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                         num_prefill_tokens=0,
                         num_decode_tokens=batch_size,
                         slot_mapping=slot_mapping[:batch_size],
-                        seq_lens=None,
-                        seq_lens_tensor=seq_lens[:batch_size],
-                        max_query_len=None,
+                        seq_lens=seq_lens[:batch_size],
+                        seq_lens_tensor=seq_lens_tensor[:batch_size],
+                        max_query_len=1,
                         max_prefill_seq_len=0,
                         max_decode_seq_len=self.max_seq_len_to_capture,
-                        query_start_loc=None,
-                        seq_start_loc=None,
+                        query_start_loc=query_start_loc[:batch_size + 1],
+                        seq_start_loc=seq_start_loc[:batch_size + 1],
                         context_lens_tensor=None,
                         block_tables=block_tables[:batch_size],
                         use_cuda_graph=True,
@@ -1333,16 +1347,11 @@ def execute_model(
 
         # Currently cuda graph is only supported by the decode phase.
         assert model_input.attn_metadata is not None
-        prefill_meta = model_input.attn_metadata.prefill_metadata
-        decode_meta = model_input.attn_metadata.decode_metadata
-        # TODO(andoorve): We can remove this once all
-        # virtual engines share the same kv cache.
-        virtual_engine = model_input.virtual_engine
-        if prefill_meta is None and decode_meta.use_cuda_graph:
+        if model_input.attn_metadata.use_cuda_graph:
             assert model_input.input_tokens is not None
             graph_batch_size = model_input.input_tokens.shape[0]
-            model_executable = self.graph_runners[virtual_engine][
-                graph_batch_size]
+            model_executable = self.graph_runners[
+                model_input.virtual_engine][graph_batch_size]
         else:
             model_executable = self.model
 
@@ -1383,7 +1392,7 @@ def execute_model(
         if model_input.is_prompt:
             hidden_states = hidden_or_intermediate_states.index_select(
                 0, indices)
-        elif decode_meta.use_cuda_graph:
+        elif model_input.attn_metadata.use_cuda_graph:
             hidden_states = hidden_or_intermediate_states[:len(indices)]
         else:
             hidden_states = hidden_or_intermediate_states
@@ -1489,9 +1498,10 @@ def capture(
             "positions": positions,
             "kv_caches": kv_caches,
             "slot_mapping": attn_metadata.slot_mapping,
-            "seq_lens_tensor":
-            attn_metadata.decode_metadata.seq_lens_tensor,
-            "block_tables": attn_metadata.decode_metadata.block_tables,
+            "block_tables": attn_metadata.block_tables,
+            "seq_start_loc": attn_metadata.seq_start_loc,
+            "query_start_loc": attn_metadata.query_start_loc,
+            "seq_lens_tensor": attn_metadata.seq_lens_tensor,
             **kwargs,
         }
         if intermediate_inputs is not None:
@@ -1522,11 +1532,16 @@ def forward(
         self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                  non_blocking=True)
         if self.backend_name != "flashinfer":
-            self.input_buffers["seq_lens_tensor"].copy_(
-                attn_metadata.decode_metadata.seq_lens_tensor,
-                non_blocking=True)
             self.input_buffers["block_tables"].copy_(
-                attn_metadata.decode_metadata.block_tables, non_blocking=True)
+                attn_metadata.block_tables, non_blocking=True)
+            self.input_buffers["query_start_loc"].copy_(
+                attn_metadata.query_start_loc, non_blocking=True)
+            self.input_buffers["seq_start_loc"].copy_(
+                attn_metadata.seq_start_loc, non_blocking=True)
+
+            self.input_buffers["seq_lens_tensor"].copy_(
+                attn_metadata.seq_lens_tensor, non_blocking=True)
+
         if "seqlen_agnostic_capture_inputs" in self.input_buffers:
             self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                       **kwargs)
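
The patch relies on two invariants that are easy to check in isolation: query_start_loc and seq_start_loc are exclusive prefix sums of the per-sequence lengths (the "[4, 6] -> [0, 4, 10]" convention documented in FlashAttentionMetadata), and during CUDA-graph decode every real or padded slot contributes exactly one query token, so query_start_loc degenerates to an arange. A standalone sketch of that bookkeeping, illustrative only and not part of the patch:

import torch


def make_start_loc(lengths):
    # Exclusive prefix sum: [4, 6] -> [0, 4, 10].
    return torch.tensor([0] + lengths,
                        dtype=torch.int32).cumsum(dim=0, dtype=torch.int32)


query_lens = [4, 6]
seq_lens = [27, 35]
print(make_start_loc(query_lens))  # [0, 4, 10]
print(make_start_loc(seq_lens))    # [0, 27, 62]

# Decode under CUDA graphs: the batch is padded with dummy sequences of
# length 1, which is what both the model_runner padding and the test padding
# loops above assume.
batch_size, expected_bs = 3, 8
decode_query_lens = [1] * batch_size + [1] * (expected_bs - batch_size)
print(make_start_loc(decode_query_lens))  # [0, 1, 2, ..., 8], i.e. an arange
# This is why capture_model can pre-allocate the buffer with
# torch.arange(0, max_batch_size + 2, dtype=torch.int32) and slice it later.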
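
The new test_varlen_with_paged_kv_cudagraph follows PyTorch's usual warmup -> capture -> copy inputs -> replay recipe, where the captured kernel only ever reads from static buffers that are overwritten in place before each replay. A self-contained sketch of the same structure, with a plain matmul standing in for flash_attn_varlen_func (assumes a CUDA device; all names here are illustrative):

import torch

assert torch.cuda.is_available(), "CUDA graphs need a GPU"
torch.set_default_device("cuda")

static_a = torch.randn(64, 64)
static_b = torch.randn(64, 64)

# Warmup on a side stream so one-time allocations are not recorded.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_a @ static_b
torch.cuda.current_stream().wait_stream(s)

# Capture: the output produced inside the capture region is the static
# output buffer that every later replay writes into.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = static_a @ static_b
torch.cuda.synchronize()

# Replay with fresh data: overwrite the static inputs in place, then replay.
new_a = torch.randn(64, 64)
new_b = torch.randn(64, 64)
static_a.copy_(new_a)
static_b.copy_(new_b)
graph.replay()
torch.cuda.synchronize()

assert torch.allclose(static_out, new_a @ new_b, atol=1e-5, rtol=1e-5)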
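
ref_paged_attn, which the tests compare against, is defined earlier in tests/kernels/test_flash_attn.py and its body is not shown in this diff. Conceptually it has to gather a contiguous K/V view per sequence out of the paged cache before it can run ordinary attention. The sketch below shows only that gathering step; it assumes a cache layout of [num_blocks, block_size, num_kv_heads, head_size], and the helper name is hypothetical:

import torch


def gather_kv_for_seq(key_cache, value_cache, block_table, kv_len, block_size):
    # Map the first kv_len token positions of one sequence onto
    # (block id, in-block offset) pairs and pull them out of the paged cache.
    num_blocks_needed = (kv_len + block_size - 1) // block_size
    block_ids = block_table[:num_blocks_needed].long()
    keys = key_cache[block_ids].reshape(-1, *key_cache.shape[2:])[:kv_len]
    values = value_cache[block_ids].reshape(-1, *value_cache.shape[2:])[:kv_len]
    return keys, values  # each: [kv_len, num_kv_heads, head_size]


block_size, num_kv_heads, head_size = 4, 2, 8
key_cache = torch.randn(16, block_size, num_kv_heads, head_size)
value_cache = torch.randn_like(key_cache)
block_table = torch.tensor([7, 3, 11], dtype=torch.int32)
k, v = gather_kv_for_seq(key_cache, value_cache, block_table,
                         kv_len=10, block_size=block_size)
print(k.shape, v.shape)  # torch.Size([10, 2, 8]) for both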