diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index 8e1cc1c16..13ce74fb6 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -36,6 +36,8 @@
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
+    "VLLM_MODEL_EXECUTE_TIME_OBSERVE":
+    lambda: bool(int(os.getenv("VLLM_MODEL_EXECUTE_TIME_OBSERVE", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 67cc0b8f4..00a00e3d2 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -17,14 +17,18 @@
 # Adapted from vllm-project/vllm/vllm/worker/worker.py
 #
 
+import atexit
 import math
-from typing import TYPE_CHECKING
+from contextlib import contextmanager
+from threading import Lock
+from typing import TYPE_CHECKING, List, Tuple
 
 import torch
 from packaging.version import InvalidVersion, Version
-from vllm.logger import logger
+from torch_npu.npu.streams import Event
 
 import vllm_ascend.envs as envs
+from vllm.logger import logger
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -173,3 +177,51 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
 
 def dispose_tensor(x: torch.Tensor):
     x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))
+
+
+class ProfileExecuteDuration:
+    _instance = None
+    _observations: List[Tuple[str, Event, Event]] = []
+    _lock = Lock()
+
+    def __new__(cls):
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
+                atexit.register(cls._instance.destroy)
+            return cls._instance
+
+    def destroy(self):
+        with self._lock:
+            self._observations.clear()
+
+    @contextmanager
+    def capture_async(self, duration_tag: str):
+        if not envs.VLLM_MODEL_EXECUTE_TIME_OBSERVE:
+            yield
+            return
+
+        observe_start = Event(enable_timing=True)
+        observe_start.record()
+        try:
+            yield
+        finally:
+            observe_end = Event(enable_timing=True)
+            observe_end.record()
+            with self._lock:
+                self._observations.append(
+                    (duration_tag, observe_start, observe_end))
+
+    def pop_captured_sync(self, captured_name: str):
+        """Pop and synchronize all events in the observation list, print all duration"""
+        if not envs.VLLM_MODEL_EXECUTE_TIME_OBSERVE:
+            return
+
+        log = f"Profile execute duration [{captured_name}]:"
+        while self._observations:
+            with self._lock:
+                tag, observe_start, observe_end = self._observations.pop()
+            observe_end.synchronize()
+            duration = observe_start.elapsed_time(observe_end)
+            log += f" [{tag}]:{duration:.2f}ms"
+        print(log)
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 184f3529b..03d2f71c0 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -29,6 +29,7 @@
 import numpy.typing as npt
 import torch
 import torch.nn as nn
+
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
@@ -56,14 +57,15 @@
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
-
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
+from vllm_ascend.utils import ProfileExecuteDuration
 
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
+    from vllm.v1.core.sched.output import SchedulerOutput
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")
 
@@ -628,36 +630,38 @@ def _process_reqs(
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
-            model_kwargs = {}
-            if self.enable_torchair_graph_mode:
-                model_kwargs["kv_caches"] = self.kv_caches
-                model_kwargs["attn_metadata"] = attn_metadata
-            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-                torch._dynamo.mark_static(input_ids)
-                torch._dynamo.mark_static(positions)
-                torch._dynamo.mark_static(attn_metadata.decode.block_table)
-                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
-                torch._dynamo.mark_static(attn_metadata.slot_mapping)
-                for kv in self.kv_caches:
-                    if isinstance(kv, tuple):
-                        torch._dynamo.mark_static(kv[0])
-                        torch._dynamo.mark_static(kv[1])
-                hidden_states = self.compile_model(
-                    input_ids=input_ids,
-                    positions=positions,
-                    intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=None,
-                    **model_kwargs,
-                )
-            else:
-                assert self.model is not None
-                hidden_states = self.model(
-                    input_ids=input_ids,
-                    positions=positions,
-                    intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=None,
-                    **model_kwargs,
-                )
+            with ProfileExecuteDuration().capture_async("forward"):
+                model_kwargs = {}
+                if self.enable_torchair_graph_mode:
+                    model_kwargs["kv_caches"] = self.kv_caches
+                    model_kwargs["attn_metadata"] = attn_metadata
+                if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                    torch._dynamo.mark_static(input_ids)
+                    torch._dynamo.mark_static(positions)
+                    torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                    torch._dynamo.mark_static(
+                        attn_metadata.decode.input_positions)
+                    torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                    for kv in self.kv_caches:
+                        if isinstance(kv, tuple):
+                            torch._dynamo.mark_static(kv[0])
+                            torch._dynamo.mark_static(kv[1])
+                    hidden_states = self.compile_model(
+                        input_ids=input_ids,
+                        positions=positions,
+                        intermediate_tensors=intermediate_tensors,
+                        inputs_embeds=None,
+                        **model_kwargs,
+                    )
+                else:
+                    assert self.model is not None
+                    hidden_states = self.model(
+                        input_ids=input_ids,
+                        positions=positions,
+                        intermediate_tensors=intermediate_tensors,
+                        inputs_embeds=None,
+                        **model_kwargs,
+                    )
 
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -844,103 +848,113 @@ def execute_model(
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, torch.Tensor]:
-        self._update_states(scheduler_output)
-        if not scheduler_output.total_num_scheduled_tokens:
-            # Return empty ModelRunnerOuptut if there's no work to do.
-            return EMPTY_MODEL_RUNNER_OUTPUT
-        (attn_metadata, hidden_states, spec_decode_metadata, positions,
-         num_scheduled_tokens,
-         sample_indices) = (self._process_reqs(scheduler_output,
-                                               intermediate_tensors))
-        logits = self.model.compute_logits(hidden_states[sample_indices], None)
-
-        # Apply structured output bitmasks if present
-        if scheduler_output.grammar_bitmask is not None:
-            logits = self.apply_grammar_bitmask(scheduler_output, logits)
-
-        # Sample the next token and get logprobs if needed.
-        sampling_metadata = self.input_batch.sampling_metadata
-        if spec_decode_metadata is None:
-            sampler_output = self.sampler(
-                logits=logits,
-                sampling_metadata=sampling_metadata,
-            )
-        else:
-            # When indexing with a tensor (bonus_logits_indices), PyTorch
-            # creates a new tensor with separate storage from the original
-            # logits tensor. This means any in-place operations on bonus_logits
-            # won't affect the original logits tensor.
-            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
-            sampler_output = self.sampler(
-                logits=bonus_logits,
-                sampling_metadata=sampling_metadata,
-            )
-            bonus_token_ids = sampler_output.sampled_token_ids
+        with ProfileExecuteDuration().capture_async(
+                "prepare input and forward"):
+            self._update_states(scheduler_output)
+            if not scheduler_output.total_num_scheduled_tokens:
+                # Return empty ModelRunnerOuptut if there's no work to do.
+                return EMPTY_MODEL_RUNNER_OUTPUT
+            (attn_metadata, hidden_states, spec_decode_metadata, positions,
+             num_scheduled_tokens,
+             sample_indices) = (self._process_reqs(scheduler_output,
+                                                   intermediate_tensors))
+
+        with ProfileExecuteDuration().capture_async("post process"):
+            logits = self.model.compute_logits(hidden_states[sample_indices],
+                                               None)
+
+            # Apply structured output bitmasks if present
+            if scheduler_output.grammar_bitmask is not None:
+                logits = self.apply_grammar_bitmask(scheduler_output, logits)
+
+            # Sample the next token and get logprobs if needed.
+            sampling_metadata = self.input_batch.sampling_metadata
+            if spec_decode_metadata is None:
+                sampler_output = self.sampler(
+                    logits=logits,
+                    sampling_metadata=sampling_metadata,
+                )
+            else:
+                # When indexing with a tensor (bonus_logits_indices), PyTorch
+                # creates a new tensor with separate storage from the original
+                # logits tensor. This means any in-place operations on bonus_logits
+                # won't affect the original logits tensor.
+                bonus_logits = logits[
+                    spec_decode_metadata.bonus_logits_indices]
+                sampler_output = self.sampler(
+                    logits=bonus_logits,
+                    sampling_metadata=sampling_metadata,
+                )
+                bonus_token_ids = sampler_output.sampled_token_ids
+
+                # Just like `bonus_logits`, `target_logits` is a new tensor with
+                # separate storage from the original `logits` tensor. Therefore,
+                # it is safe to update `target_logits` in place.
+                target_logits = logits[
+                    spec_decode_metadata.target_logits_indices]
+                output_token_ids = self.rejection_sampler(
+                    spec_decode_metadata,
+                    None,  # draft_probs
+                    target_logits,
+                    bonus_token_ids,
+                    sampling_metadata,
+                )
+                sampler_output.sampled_token_ids = output_token_ids
+
+            # TODO(woosuk): The following loop can be slow since it iterates over
+            # the requests one by one. Optimize.
+            for i, req_id in enumerate(self.input_batch.req_ids):
+                req_state = self.requests[req_id]
+                seq_len = (req_state.num_computed_tokens +
+                           scheduler_output.num_scheduled_tokens[req_id])
+                if seq_len < req_state.num_tokens:
+                    # Ignore the sampled token.
+                    # Rewind the generator state as if the token was not sampled.
+                    generator = self.input_batch.generators.get(i)
+                    if generator is not None:
+                        generator.set_offset(generator.get_offset() - 4)
+
+            # NOTE: NPU -> CPU Sync happens here.
+            # Move as many CPU operations as possible before this sync point.
+            logprobs_tensors = sampler_output.logprobs_tensors
+            logprobs_lists = logprobs_tensors.tolists() \
+                if logprobs_tensors is not None else None
+
+            # Get the valid generated tokens.
+            sampled_token_ids = sampler_output.sampled_token_ids
+            max_gen_len = sampled_token_ids.shape[-1]
+            if max_gen_len == 1:
+                # No spec decode tokens.
+                valid_sampled_token_ids = sampled_token_ids.tolist()
+            else:
+                # Includes spec decode tokens.
+                valid_sampled_token_ids = self.rejection_sampler.parse_output(
+                    sampled_token_ids,
+                    self.input_batch.vocab_size,
+                )
 
-            # Just like `bonus_logits`, `target_logits` is a new tensor with
-            # separate storage from the original `logits` tensor. Therefore,
-            # it is safe to update `target_logits` in place.
-            target_logits = logits[spec_decode_metadata.target_logits_indices]
-            output_token_ids = self.rejection_sampler(
-                spec_decode_metadata,
-                None,  # draft_probs
-                target_logits,
-                bonus_token_ids,
+            spec_token_ids = self._get_spec_token_ids(
+                valid_sampled_token_ids,
                 sampling_metadata,
+                scheduler_output,
+                spec_decode_metadata,
+                positions,
+                num_scheduled_tokens,
+                hidden_states,
+                attn_metadata,
             )
-            sampler_output.sampled_token_ids = output_token_ids
 
-        # TODO(woosuk): The following loop can be slow since it iterates over
-        # the requests one by one. Optimize.
-        for i, req_id in enumerate(self.input_batch.req_ids):
-            req_state = self.requests[req_id]
-            seq_len = (req_state.num_computed_tokens +
-                       scheduler_output.num_scheduled_tokens[req_id])
-            if seq_len < req_state.num_tokens:
-                # Ignore the sampled token.
-                # Rewind the generator state as if the token was not sampled.
-                generator = self.input_batch.generators.get(i)
-                if generator is not None:
-                    generator.set_offset(generator.get_offset() - 4)
-
-        # NOTE: NPU -> CPU Sync happens here.
-        # Move as many CPU operations as possible before this sync point.
-        logprobs_tensors = sampler_output.logprobs_tensors
-        logprobs_lists = logprobs_tensors.tolists() \
-            if logprobs_tensors is not None else None
-
-        # Get the valid generated tokens.
-        sampled_token_ids = sampler_output.sampled_token_ids
-        max_gen_len = sampled_token_ids.shape[-1]
-        if max_gen_len == 1:
-            # No spec decode tokens.
-            valid_sampled_token_ids = sampled_token_ids.tolist()
-        else:
-            # Includes spec decode tokens.
-            valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids,
-                self.input_batch.vocab_size,
+            model_runner_output = ModelRunnerOutput(
+                req_ids=self.input_batch.req_ids,
+                req_id_to_index=self.input_batch.req_id_to_index,
+                sampled_token_ids=valid_sampled_token_ids,
+                spec_token_ids=spec_token_ids,
+                logprobs=logprobs_lists,
+                prompt_logprobs_dict={},
            )
 
-        spec_token_ids = self._get_spec_token_ids(
-            valid_sampled_token_ids,
-            sampling_metadata,
-            scheduler_output,
-            spec_decode_metadata,
-            positions,
-            num_scheduled_tokens,
-            hidden_states,
-            attn_metadata,
-        )
-
-        model_runner_output = ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids,
-            req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids=valid_sampled_token_ids,
-            spec_token_ids=spec_token_ids,
-            logprobs=logprobs_lists,
-            prompt_logprobs_dict={},
-        )
+        capture_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill"
+        ProfileExecuteDuration().pop_captured_sync(capture_name)
 
         return model_runner_output
 
     def _profile_multimodal(self) -> None:
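The patch itself carries no usage notes, so the following is a minimal, hypothetical sketch of how the new ProfileExecuteDuration helper can be exercised outside the model runner. It assumes an Ascend host with torch_npu installed and the patch above applied; the tag names ("matmul", "softmax", "demo") and the standalone-script framing are illustrative only, not part of the change.

# Usage sketch (illustrative only, not part of the patch). Assumes an Ascend
# host with torch_npu available and the patch above applied.
import os

# The flag is read lazily through vllm_ascend.envs, so exporting
# VLLM_MODEL_EXECUTE_TIME_OBSERVE=1 before launch or setting it here both work.
os.environ["VLLM_MODEL_EXECUTE_TIME_OBSERVE"] = "1"

import torch
import torch_npu  # noqa: F401  # registers the "npu" device with torch

from vllm_ascend.utils import ProfileExecuteDuration

profiler = ProfileExecuteDuration()  # singleton, shared by all callers

x = torch.randn(4096, 4096, device="npu")

# capture_async() records a start/end Event pair on the current stream and
# stores them; it does not force a synchronization on the hot path.
with profiler.capture_async("matmul"):
    y = x @ x
with profiler.capture_async("softmax"):
    y = torch.softmax(y, dim=-1)

# pop_captured_sync() synchronizes the recorded events and prints one line of
# the form "Profile execute duration [demo]: [softmax]:...ms [matmul]:...ms"
# (observations are popped most-recent first).
profiler.pop_captured_sync("demo")

In the patched model runner the same pattern wraps "prepare input and forward" and "post process", and execute_model() pops the captured durations once per step under a "Decode" or "Prefill" label.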