Commit 294933e

Add profile execute duration observation
1 parent: a93bed4

3 files changed (+189 / -122 lines)

vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
+    "VLLM_MODEL_EXECUTE_TIME_OBSERVE":
+    lambda: bool(int(os.getenv("VLLM_MODEL_EXECUTE_TIME_OBSERVE", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":

vllm_ascend/utils.py

Lines changed: 53 additions & 1 deletion
@@ -17,11 +17,15 @@
 # Adapted from vllm-project/vllm/vllm/worker/worker.py
 #
 
+import atexit
 import math
-from typing import TYPE_CHECKING
+from contextlib import contextmanager
+from threading import Lock
+from typing import TYPE_CHECKING, List, Tuple
 
 import torch
 from packaging.version import InvalidVersion, Version
+from torch_npu.npu.streams import Event
 from vllm.logger import logger
 
 import vllm_ascend.envs as envs
@@ -173,3 +177,51 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
 
 
 def dispose_tensor(x: torch.Tensor):
     x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))
+
+
+class ProfileExecuteDuration:
+    _instance = None
+    _observations: List[Tuple[str, Event, Event]] = []
+    _lock = Lock()
+
+    def __new__(cls):
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
+                atexit.register(cls._instance.destroy)
+            return cls._instance
+
+    def destroy(self):
+        with self._lock:
+            self._observations.clear()
+
+    @contextmanager
+    def capture_async(self, duration_tag: str):
+        if not envs.VLLM_MODEL_EXECUTE_TIME_OBSERVE:
+            yield
+            return
+
+        observe_start = Event(enable_timing=True)
+        observe_start.record()
+        try:
+            yield
+        finally:
+            observe_end = Event(enable_timing=True)
+            observe_end.record()
+            with self._lock:
+                self._observations.append(
+                    (duration_tag, observe_start, observe_end))
+
+    def pop_captured_sync(self, captured_name: str):
+        """Pop and synchronize all events in the observation list, print all duration"""
+        if not envs.VLLM_MODEL_EXECUTE_TIME_OBSERVE:
+            return
+
+        log = f"Profile execute duration [{captured_name}]:"
+        while self._observations:
+            with self._lock:
+                tag, observe_start, observe_end = self._observations.pop()
+            observe_end.synchronize()
+            duration = observe_start.elapsed_time(observe_end)
+            log += f" [{tag}]:{duration:.2f}ms"
+        print(log)
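Note: ProfileExecuteDuration is a lock-guarded singleton: capture_async records a pair of NPU timing events around a block without forcing a host sync, and pop_captured_sync later drains the list, synchronizes on the end events, and prints one duration per tag. A minimal usage sketch, assuming torch_npu is installed and a dummy matmul stands in for the real model step:

    import torch
    import torch_npu  # noqa: F401  # registers the "npu" device

    from vllm_ascend.utils import ProfileExecuteDuration

    prof = ProfileExecuteDuration()  # always returns the shared instance

    # No-op unless VLLM_MODEL_EXECUTE_TIME_OBSERVE=1 is set in the environment.
    with prof.capture_async("forward"):
        a = torch.randn(1024, 1024, device="npu")
        b = a @ a  # stand-in for the model forward pass

    # Synchronizes the recorded events and prints something like:
    #   Profile execute duration [Decode]: [forward]:3.21ms
    prof.pop_captured_sync("Decode")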

vllm_ascend/worker/model_runner_v1.py

Lines changed: 134 additions & 121 deletions
@@ -61,6 +61,7 @@
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
+from vllm_ascend.utils import ProfileExecuteDuration
 
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
@@ -628,36 +629,38 @@ def _process_reqs(
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
-            model_kwargs = {}
-            if self.enable_torchair_graph_mode:
-                model_kwargs["kv_caches"] = self.kv_caches
-                model_kwargs["attn_metadata"] = attn_metadata
-            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-                torch._dynamo.mark_static(input_ids)
-                torch._dynamo.mark_static(positions)
-                torch._dynamo.mark_static(attn_metadata.decode.block_table)
-                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
-                torch._dynamo.mark_static(attn_metadata.slot_mapping)
-                for kv in self.kv_caches:
-                    if isinstance(kv, tuple):
-                        torch._dynamo.mark_static(kv[0])
-                        torch._dynamo.mark_static(kv[1])
-                hidden_states = self.compile_model(
-                    input_ids=input_ids,
-                    positions=positions,
-                    intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=None,
-                    **model_kwargs,
-                )
-            else:
-                assert self.model is not None
-                hidden_states = self.model(
-                    input_ids=input_ids,
-                    positions=positions,
-                    intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=None,
-                    **model_kwargs,
-                )
+            with ProfileExecuteDuration().capture_async("forward"):
+                model_kwargs = {}
+                if self.enable_torchair_graph_mode:
+                    model_kwargs["kv_caches"] = self.kv_caches
+                    model_kwargs["attn_metadata"] = attn_metadata
+                if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                    torch._dynamo.mark_static(input_ids)
+                    torch._dynamo.mark_static(positions)
+                    torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                    torch._dynamo.mark_static(
+                        attn_metadata.decode.input_positions)
+                    torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                    for kv in self.kv_caches:
+                        if isinstance(kv, tuple):
+                            torch._dynamo.mark_static(kv[0])
+                            torch._dynamo.mark_static(kv[1])
+                    hidden_states = self.compile_model(
+                        input_ids=input_ids,
+                        positions=positions,
+                        intermediate_tensors=intermediate_tensors,
+                        inputs_embeds=None,
+                        **model_kwargs,
+                    )
+                else:
+                    assert self.model is not None
+                    hidden_states = self.model(
+                        input_ids=input_ids,
+                        positions=positions,
+                        intermediate_tensors=intermediate_tensors,
+                        inputs_embeds=None,
+                        **model_kwargs,
+                    )
 
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -844,103 +847,113 @@ def execute_model(
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, torch.Tensor]:
-        self._update_states(scheduler_output)
-        if not scheduler_output.total_num_scheduled_tokens:
-            # Return empty ModelRunnerOuptut if there's no work to do.
-            return EMPTY_MODEL_RUNNER_OUTPUT
-        (attn_metadata, hidden_states, spec_decode_metadata, positions,
-         num_scheduled_tokens,
-         sample_indices) = (self._process_reqs(scheduler_output,
-                                               intermediate_tensors))
-        logits = self.model.compute_logits(hidden_states[sample_indices], None)
-
-        # Apply structured output bitmasks if present
-        if scheduler_output.grammar_bitmask is not None:
-            logits = self.apply_grammar_bitmask(scheduler_output, logits)
-
-        # Sample the next token and get logprobs if needed.
-        sampling_metadata = self.input_batch.sampling_metadata
-        if spec_decode_metadata is None:
-            sampler_output = self.sampler(
-                logits=logits,
-                sampling_metadata=sampling_metadata,
-            )
-        else:
-            # When indexing with a tensor (bonus_logits_indices), PyTorch
-            # creates a new tensor with separate storage from the original
-            # logits tensor. This means any in-place operations on bonus_logits
-            # won't affect the original logits tensor.
-            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
-            sampler_output = self.sampler(
-                logits=bonus_logits,
-                sampling_metadata=sampling_metadata,
-            )
-            bonus_token_ids = sampler_output.sampled_token_ids
+        with ProfileExecuteDuration().capture_async(
+                "prepare input and forward"):
+            self._update_states(scheduler_output)
+            if not scheduler_output.total_num_scheduled_tokens:
+                # Return empty ModelRunnerOuptut if there's no work to do.
+                return EMPTY_MODEL_RUNNER_OUTPUT
+            (attn_metadata, hidden_states, spec_decode_metadata, positions,
+             num_scheduled_tokens,
+             sample_indices) = (self._process_reqs(scheduler_output,
+                                                   intermediate_tensors))
+
+        with ProfileExecuteDuration().capture_async("post process"):
+            logits = self.model.compute_logits(hidden_states[sample_indices],
+                                               None)
+
+            # Apply structured output bitmasks if present
+            if scheduler_output.grammar_bitmask is not None:
+                logits = self.apply_grammar_bitmask(scheduler_output, logits)
+
+            # Sample the next token and get logprobs if needed.
+            sampling_metadata = self.input_batch.sampling_metadata
+            if spec_decode_metadata is None:
+                sampler_output = self.sampler(
+                    logits=logits,
+                    sampling_metadata=sampling_metadata,
+                )
+            else:
+                # When indexing with a tensor (bonus_logits_indices), PyTorch
+                # creates a new tensor with separate storage from the original
+                # logits tensor. This means any in-place operations on bonus_logits
+                # won't affect the original logits tensor.
+                bonus_logits = logits[
+                    spec_decode_metadata.bonus_logits_indices]
+                sampler_output = self.sampler(
+                    logits=bonus_logits,
+                    sampling_metadata=sampling_metadata,
+                )
+                bonus_token_ids = sampler_output.sampled_token_ids
+
+                # Just like `bonus_logits`, `target_logits` is a new tensor with
+                # separate storage from the original `logits` tensor. Therefore,
+                # it is safe to update `target_logits` in place.
+                target_logits = logits[
+                    spec_decode_metadata.target_logits_indices]
+                output_token_ids = self.rejection_sampler(
+                    spec_decode_metadata,
+                    None,  # draft_probs
+                    target_logits,
+                    bonus_token_ids,
+                    sampling_metadata,
+                )
+                sampler_output.sampled_token_ids = output_token_ids
+
+            # TODO(woosuk): The following loop can be slow since it iterates over
+            # the requests one by one. Optimize.
+            for i, req_id in enumerate(self.input_batch.req_ids):
+                req_state = self.requests[req_id]
+                seq_len = (req_state.num_computed_tokens +
+                           scheduler_output.num_scheduled_tokens[req_id])
+                if seq_len < req_state.num_tokens:
+                    # Ignore the sampled token.
+                    # Rewind the generator state as if the token was not sampled.
+                    generator = self.input_batch.generators.get(i)
+                    if generator is not None:
+                        generator.set_offset(generator.get_offset() - 4)
+
+            # NOTE: NPU -> CPU Sync happens here.
+            # Move as many CPU operations as possible before this sync point.
+            logprobs_tensors = sampler_output.logprobs_tensors
+            logprobs_lists = logprobs_tensors.tolists() \
+                if logprobs_tensors is not None else None
+
+            # Get the valid generated tokens.
+            sampled_token_ids = sampler_output.sampled_token_ids
+            max_gen_len = sampled_token_ids.shape[-1]
+            if max_gen_len == 1:
+                # No spec decode tokens.
+                valid_sampled_token_ids = sampled_token_ids.tolist()
+            else:
+                # Includes spec decode tokens.
+                valid_sampled_token_ids = self.rejection_sampler.parse_output(
+                    sampled_token_ids,
+                    self.input_batch.vocab_size,
+                )
 
-            # Just like `bonus_logits`, `target_logits` is a new tensor with
-            # separate storage from the original `logits` tensor. Therefore,
-            # it is safe to update `target_logits` in place.
-            target_logits = logits[spec_decode_metadata.target_logits_indices]
-            output_token_ids = self.rejection_sampler(
-                spec_decode_metadata,
-                None,  # draft_probs
-                target_logits,
-                bonus_token_ids,
+            spec_token_ids = self._get_spec_token_ids(
+                valid_sampled_token_ids,
                 sampling_metadata,
+                scheduler_output,
+                spec_decode_metadata,
+                positions,
+                num_scheduled_tokens,
+                hidden_states,
+                attn_metadata,
             )
-            sampler_output.sampled_token_ids = output_token_ids
 
-        # TODO(woosuk): The following loop can be slow since it iterates over
-        # the requests one by one. Optimize.
-        for i, req_id in enumerate(self.input_batch.req_ids):
-            req_state = self.requests[req_id]
-            seq_len = (req_state.num_computed_tokens +
-                       scheduler_output.num_scheduled_tokens[req_id])
-            if seq_len < req_state.num_tokens:
-                # Ignore the sampled token.
-                # Rewind the generator state as if the token was not sampled.
-                generator = self.input_batch.generators.get(i)
-                if generator is not None:
-                    generator.set_offset(generator.get_offset() - 4)
-
-        # NOTE: NPU -> CPU Sync happens here.
-        # Move as many CPU operations as possible before this sync point.
-        logprobs_tensors = sampler_output.logprobs_tensors
-        logprobs_lists = logprobs_tensors.tolists() \
-            if logprobs_tensors is not None else None
-
-        # Get the valid generated tokens.
-        sampled_token_ids = sampler_output.sampled_token_ids
-        max_gen_len = sampled_token_ids.shape[-1]
-        if max_gen_len == 1:
-            # No spec decode tokens.
-            valid_sampled_token_ids = sampled_token_ids.tolist()
-        else:
-            # Includes spec decode tokens.
-            valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids,
-                self.input_batch.vocab_size,
+            model_runner_output = ModelRunnerOutput(
+                req_ids=self.input_batch.req_ids,
+                req_id_to_index=self.input_batch.req_id_to_index,
+                sampled_token_ids=valid_sampled_token_ids,
+                spec_token_ids=spec_token_ids,
+                logprobs=logprobs_lists,
+                prompt_logprobs_dict={},
             )
 
-        spec_token_ids = self._get_spec_token_ids(
-            valid_sampled_token_ids,
-            sampling_metadata,
-            scheduler_output,
-            spec_decode_metadata,
-            positions,
-            num_scheduled_tokens,
-            hidden_states,
-            attn_metadata,
-        )
-
-        model_runner_output = ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids,
-            req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids=valid_sampled_token_ids,
-            spec_token_ids=spec_token_ids,
-            logprobs=logprobs_lists,
-            prompt_logprobs_dict={},
-        )
+        capture_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill"
+        ProfileExecuteDuration().pop_captured_sync(capture_name)
         return model_runner_output
 
     def _profile_multimodal(self) -> None:
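Note: with the flag enabled, each execute_model call now records up to three tags: "forward" (nested inside _process_reqs), "prepare input and forward", and "post process", then pops and prints them under a "Decode" or "Prefill" heading chosen from the attention state. Since pop_captured_sync drains the list in reverse insertion order, a printed line looks roughly like this (timings hypothetical):

    Profile execute duration [Decode]: [post process]:1.84ms [prepare input and forward]:6.95ms [forward]:6.02ms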
