Commit c6e01fe

Add profile execute duration observation
1 parent e2a0c19 commit c6e01fe

File tree

3 files changed (+182, -121 lines):
vllm_ascend/envs.py
vllm_ascend/utils.py
vllm_ascend/worker/model_runner_v1.py

vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
+    "VLLM_MODEL_EXECUTE_TIME_OBSERVE":
+    lambda: bool(int(os.getenv("VLLM_MODEL_EXECUTE_TIME_OBSERVE", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":

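The new switch defaults to off, so existing deployments see no behavior change unless it is enabled explicitly. A minimal sketch of turning it on from Python before the engine starts (exporting VLLM_MODEL_EXECUTE_TIME_OBSERVE=1 in the shell before launching vLLM should work the same way; the snippet is illustrative and not part of this commit):

import os

# Enable the execute-duration observation added in this commit; the lambda in
# vllm_ascend/envs.py maps "1"/"0" onto True/False via bool(int(...)).
os.environ["VLLM_MODEL_EXECUTE_TIME_OBSERVE"] = "1"

import vllm_ascend.envs as envs
print(envs.VLLM_MODEL_EXECUTE_TIME_OBSERVE)  # expected: True
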
vllm_ascend/utils.py

Lines changed: 51 additions & 0 deletions
@@ -18,11 +18,16 @@
 #
 
 import math
+import atexit
 from typing import TYPE_CHECKING
+from contextlib import contextmanager
+from typing import List, Tuple
+from threading import Lock
 
 import torch
 from packaging.version import InvalidVersion, Version
 from vllm.logger import logger
+from torch_npu.npu.streams import Event
 
 import vllm_ascend.envs as envs
 
@@ -169,3 +174,49 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
             "No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
             vllm_config.model_config.architectures[0], num_hidden_layers,
             len(original_sizes))
+
+class ProfileExecuteDuration:
+    _instance = None
+    _observations: List[Tuple[str, Event, Event]] = []
+    _lock = Lock()
+
+    def __new__(cls):
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
+                atexit.register(cls._instance.destroy)
+            return cls._instance
+
+    def destroy(self):
+        with self._lock:
+            self._observations.clear()
+
+    @contextmanager
+    def capture_async(self, duration_tag: str):
+        if not envs.VLLM_MODEL_EXECUTE_TIME_OBSERVE:
+            yield
+            return
+
+        observe_start = Event(enable_timing=True)
+        observe_start.record()
+        try:
+            yield
+        finally:
+            observe_end = Event(enable_timing=True)
+            observe_end.record()
+            with self._lock:
+                self._observations.append((duration_tag, observe_start, observe_end))
+
+    def pop_captured_sync(self, captured_name: str):
+        """Pop and synchronize all events in the observation list, print all durations"""
+        if not envs.VLLM_MODEL_EXECUTE_TIME_OBSERVE:
+            return
+
+        log = f"Profile execute duration [{captured_name}]:"
+        while self._observations:
+            with self._lock:
+                tag, observe_start, observe_end = self._observations.pop()
+            observe_end.synchronize()
+            duration = observe_start.elapsed_time(observe_end)
+            log += f" [{tag}]:{duration:.2f}ms"
+        print(log)

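ProfileExecuteDuration is a process-wide singleton: capture_async only records a start/end Event pair on the current NPU stream (no host synchronization on the hot path), while pop_captured_sync later synchronizes the recorded pairs and prints one aggregated line, popping them in reverse capture order. With VLLM_MODEL_EXECUTE_TIME_OBSERVE unset (or 0), both methods are effectively no-ops. A standalone usage sketch, assuming a working torch_npu install and an available Ascend device (the matmul and softmax merely stand in for real model work):

import torch
import torch_npu  # noqa: F401  # registers the "npu" device (assumed available)

from vllm_ascend.utils import ProfileExecuteDuration

profiler = ProfileExecuteDuration()  # every call returns the same instance

x = torch.randn(1024, 1024, device="npu")
with profiler.capture_async("prepare input and forward"):
    y = x @ x                        # stand-in for the model forward pass

with profiler.capture_async("post process"):
    z = y.softmax(dim=-1)            # stand-in for sampling / output handling

# Synchronizes all pending event pairs and prints a single line such as
#   Profile execute duration [Decode]: [post process]:0.08ms [prepare input and forward]:1.31ms
# (illustrative numbers; tags appear in reverse capture order because
# observations are popped from the end of the list).
profiler.pop_captured_sync("Decode")
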
vllm_ascend/worker/model_runner_v1.py

Lines changed: 129 additions & 121 deletions
@@ -57,6 +57,7 @@
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
+from vllm_ascend.utils import ProfileExecuteDuration
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.platform import NPUPlatform
@@ -640,36 +641,37 @@ def _process_reqs(
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
-            model_kwargs = {}
-            if self.enable_torchair_graph_mode:
-                model_kwargs["kv_caches"] = self.kv_caches
-                model_kwargs["attn_metadata"] = attn_metadata
-            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-                torch._dynamo.mark_static(input_ids)
-                torch._dynamo.mark_static(positions)
-                torch._dynamo.mark_static(attn_metadata.decode.block_table)
-                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
-                torch._dynamo.mark_static(attn_metadata.slot_mapping)
-                for kv in self.kv_caches:
-                    if isinstance(kv, tuple):
-                        torch._dynamo.mark_static(kv[0])
-                        torch._dynamo.mark_static(kv[1])
-                hidden_states = self.compile_model(
-                    input_ids=input_ids,
-                    positions=positions,
-                    intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=None,
-                    **model_kwargs,
-                )
-            else:
-                assert self.model is not None
-                hidden_states = self.model(
-                    input_ids=input_ids,
-                    positions=positions,
-                    intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=None,
-                    **model_kwargs,
-                )
+            with ProfileExecuteDuration().capture_async("forward"):
+                model_kwargs = {}
+                if self.enable_torchair_graph_mode:
+                    model_kwargs["kv_caches"] = self.kv_caches
+                    model_kwargs["attn_metadata"] = attn_metadata
+                if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                    torch._dynamo.mark_static(input_ids)
+                    torch._dynamo.mark_static(positions)
+                    torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                    torch._dynamo.mark_static(attn_metadata.decode.input_positions)
+                    torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                    for kv in self.kv_caches:
+                        if isinstance(kv, tuple):
+                            torch._dynamo.mark_static(kv[0])
+                            torch._dynamo.mark_static(kv[1])
+                    hidden_states = self.compile_model(
+                        input_ids=input_ids,
+                        positions=positions,
+                        intermediate_tensors=intermediate_tensors,
+                        inputs_embeds=None,
+                        **model_kwargs,
+                    )
+                else:
+                    assert self.model is not None
+                    hidden_states = self.model(
+                        input_ids=input_ids,
+                        positions=positions,
+                        intermediate_tensors=intermediate_tensors,
+                        inputs_embeds=None,
+                        **model_kwargs,
+                    )
 
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -856,103 +858,109 @@ def execute_model(
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, torch.Tensor]:
-        self._update_states(scheduler_output)
-        if not scheduler_output.total_num_scheduled_tokens:
-            # Return empty ModelRunnerOuptut if there's no work to do.
-            return EMPTY_MODEL_RUNNER_OUTPUT
-        (attn_metadata, hidden_states, spec_decode_metadata, positions,
-         num_scheduled_tokens,
-         sample_indices) = (self._process_reqs(scheduler_output,
-                                               intermediate_tensors))
-        logits = self.model.compute_logits(hidden_states[sample_indices], None)
-
-        # Apply structured output bitmasks if present
-        if scheduler_output.grammar_bitmask is not None:
-            logits = self.apply_grammar_bitmask(scheduler_output, logits)
-
-        # Sample the next token and get logprobs if needed.
-        sampling_metadata = self.input_batch.sampling_metadata
-        if spec_decode_metadata is None:
-            sampler_output = self.sampler(
-                logits=logits,
-                sampling_metadata=sampling_metadata,
-            )
-        else:
-            # When indexing with a tensor (bonus_logits_indices), PyTorch
-            # creates a new tensor with separate storage from the original
-            # logits tensor. This means any in-place operations on bonus_logits
-            # won't affect the original logits tensor.
-            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
-            sampler_output = self.sampler(
-                logits=bonus_logits,
-                sampling_metadata=sampling_metadata,
-            )
-            bonus_token_ids = sampler_output.sampled_token_ids
+        with ProfileExecuteDuration().capture_async("prepare input and forward"):
+            self._update_states(scheduler_output)
+            if not scheduler_output.total_num_scheduled_tokens:
+                # Return empty ModelRunnerOuptut if there's no work to do.
+                return EMPTY_MODEL_RUNNER_OUTPUT
+            (attn_metadata, hidden_states, spec_decode_metadata, positions,
+             num_scheduled_tokens,
+             sample_indices) = (self._process_reqs(scheduler_output,
+                                                   intermediate_tensors))
+
+        with ProfileExecuteDuration().capture_async("post process"):
+            logits = self.model.compute_logits(hidden_states[sample_indices], None)
+
+            # Apply structured output bitmasks if present
+            if scheduler_output.grammar_bitmask is not None:
+                logits = self.apply_grammar_bitmask(scheduler_output, logits)
+
+            # Sample the next token and get logprobs if needed.
+            sampling_metadata = self.input_batch.sampling_metadata
+            if spec_decode_metadata is None:
+                sampler_output = self.sampler(
+                    logits=logits,
+                    sampling_metadata=sampling_metadata,
+                )
+            else:
+                # When indexing with a tensor (bonus_logits_indices), PyTorch
+                # creates a new tensor with separate storage from the original
+                # logits tensor. This means any in-place operations on bonus_logits
+                # won't affect the original logits tensor.
+                bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
+                sampler_output = self.sampler(
+                    logits=bonus_logits,
+                    sampling_metadata=sampling_metadata,
+                )
+                bonus_token_ids = sampler_output.sampled_token_ids
+
+                # Just like `bonus_logits`, `target_logits` is a new tensor with
+                # separate storage from the original `logits` tensor. Therefore,
+                # it is safe to update `target_logits` in place.
+                target_logits = logits[spec_decode_metadata.target_logits_indices]
+                output_token_ids = self.rejection_sampler(
+                    spec_decode_metadata,
+                    None,  # draft_probs
+                    target_logits,
+                    bonus_token_ids,
+                    sampling_metadata,
+                )
+                sampler_output.sampled_token_ids = output_token_ids
+
+            # TODO(woosuk): The following loop can be slow since it iterates over
+            # the requests one by one. Optimize.
+            for i, req_id in enumerate(self.input_batch.req_ids):
+                req_state = self.requests[req_id]
+                seq_len = (req_state.num_computed_tokens +
+                           scheduler_output.num_scheduled_tokens[req_id])
+                if seq_len < req_state.num_tokens:
+                    # Ignore the sampled token.
+                    # Rewind the generator state as if the token was not sampled.
+                    generator = self.input_batch.generators.get(i)
+                    if generator is not None:
+                        generator.set_offset(generator.get_offset() - 4)
+
+            # NOTE: NPU -> CPU Sync happens here.
+            # Move as many CPU operations as possible before this sync point.
+            logprobs_tensors = sampler_output.logprobs_tensors
+            logprobs_lists = logprobs_tensors.tolists() \
+                if logprobs_tensors is not None else None
+
+            # Get the valid generated tokens.
+            sampled_token_ids = sampler_output.sampled_token_ids
+            max_gen_len = sampled_token_ids.shape[-1]
+            if max_gen_len == 1:
+                # No spec decode tokens.
+                valid_sampled_token_ids = sampled_token_ids.tolist()
+            else:
+                # Includes spec decode tokens.
+                valid_sampled_token_ids = self.rejection_sampler.parse_output(
+                    sampled_token_ids,
+                    self.input_batch.vocab_size,
+                )
 
-            # Just like `bonus_logits`, `target_logits` is a new tensor with
-            # separate storage from the original `logits` tensor. Therefore,
-            # it is safe to update `target_logits` in place.
-            target_logits = logits[spec_decode_metadata.target_logits_indices]
-            output_token_ids = self.rejection_sampler(
-                spec_decode_metadata,
-                None,  # draft_probs
-                target_logits,
-                bonus_token_ids,
+            spec_token_ids = self._get_spec_token_ids(
+                valid_sampled_token_ids,
                 sampling_metadata,
+                scheduler_output,
+                spec_decode_metadata,
+                positions,
+                num_scheduled_tokens,
+                hidden_states,
+                attn_metadata,
             )
-            sampler_output.sampled_token_ids = output_token_ids
 
-        # TODO(woosuk): The following loop can be slow since it iterates over
-        # the requests one by one. Optimize.
-        for i, req_id in enumerate(self.input_batch.req_ids):
-            req_state = self.requests[req_id]
-            seq_len = (req_state.num_computed_tokens +
-                       scheduler_output.num_scheduled_tokens[req_id])
-            if seq_len < req_state.num_tokens:
-                # Ignore the sampled token.
-                # Rewind the generator state as if the token was not sampled.
-                generator = self.input_batch.generators.get(i)
-                if generator is not None:
-                    generator.set_offset(generator.get_offset() - 4)
-
-        # NOTE: NPU -> CPU Sync happens here.
-        # Move as many CPU operations as possible before this sync point.
-        logprobs_tensors = sampler_output.logprobs_tensors
-        logprobs_lists = logprobs_tensors.tolists() \
-            if logprobs_tensors is not None else None
-
-        # Get the valid generated tokens.
-        sampled_token_ids = sampler_output.sampled_token_ids
-        max_gen_len = sampled_token_ids.shape[-1]
-        if max_gen_len == 1:
-            # No spec decode tokens.
-            valid_sampled_token_ids = sampled_token_ids.tolist()
-        else:
-            # Includes spec decode tokens.
-            valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids,
-                self.input_batch.vocab_size,
+            model_runner_output = ModelRunnerOutput(
+                req_ids=self.input_batch.req_ids,
+                req_id_to_index=self.input_batch.req_id_to_index,
+                sampled_token_ids=valid_sampled_token_ids,
+                spec_token_ids=spec_token_ids,
+                logprobs=logprobs_lists,
+                prompt_logprobs_dict={},
             )
 
-        spec_token_ids = self._get_spec_token_ids(
-            valid_sampled_token_ids,
-            sampling_metadata,
-            scheduler_output,
-            spec_decode_metadata,
-            positions,
-            num_scheduled_tokens,
-            hidden_states,
-            attn_metadata,
-        )
-
-        model_runner_output = ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids,
-            req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids=valid_sampled_token_ids,
-            spec_token_ids=spec_token_ids,
-            logprobs=logprobs_lists,
-            prompt_logprobs_dict={},
-        )
+        capture_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill"
+        ProfileExecuteDuration().pop_captured_sync(capture_name)
         return model_runner_output
 
     def _profile_multimodal(self) -> None:

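With the flag enabled, each execute_model step now emits one summary line covering the nested "forward" capture from _process_reqs plus the two captures added here, labelled "Decode" or "Prefill" from the attention state. An illustrative (not measured) example of the output format, with tags in reverse capture order:

Profile execute duration [Decode]: [post process]:1.85ms [prepare input and forward]:15.40ms [forward]:14.92ms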