[ModelRunner] Add profile execute duration observation #996

Closed · wants to merge 1 commit
2 changes: 2 additions & 0 deletions vllm_ascend/envs.py
@@ -36,6 +36,8 @@
lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
"VLLM_ENABLE_MC2":
lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
"VLLM_MODEL_EXECUTE_TIME_OBSERVE":
lambda: bool(int(os.getenv("VLLM_MODEL_EXECUTE_TIME_OBSERVE", '0'))),
"USING_LCCL_COM":
lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
"SOC_VERSION":
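For reference, a minimal sketch of how this flag would be switched on. Only the variable name and its `bool(int(...))` parsing come from the hunk above; setting it from Python rather than exporting it in the shell that launches vLLM is an assumption.

```python
# Hedged sketch: turning the new flag on. Only the variable name and its
# bool(int(...)) parsing come from the hunk above; setting it from Python
# instead of exporting it in the launch shell is an assumption.
import os

os.environ["VLLM_MODEL_EXECUTE_TIME_OBSERVE"] = "1"

import vllm_ascend.envs as envs

# The default is "0", so duration logging stays off unless explicitly enabled.
if envs.VLLM_MODEL_EXECUTE_TIME_OBSERVE:
    print("model execute duration observation enabled")
```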
56 changes: 54 additions & 2 deletions vllm_ascend/utils.py
@@ -17,14 +17,18 @@
# Adapted from vllm-project/vllm/vllm/worker/worker.py
#

import atexit
import math
from typing import TYPE_CHECKING
from contextlib import contextmanager
from threading import Lock
from typing import TYPE_CHECKING, List, Tuple

import torch
from packaging.version import InvalidVersion, Version
from vllm.logger import logger
from torch_npu.npu.streams import Event

import vllm_ascend.envs as envs
from vllm.logger import logger

if TYPE_CHECKING:
from vllm.config import VllmConfig
@@ -173,3 +177,51 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:

def dispose_tensor(x: torch.Tensor):
x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))


class ProfileExecuteDuration:
    _instance = None
    _observations: List[Tuple[str, Event, Event]] = []
    _lock = Lock()

    def __new__(cls):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                atexit.register(cls._instance.destroy)
            return cls._instance

    def destroy(self):
        with self._lock:
            self._observations.clear()

    @contextmanager
    def capture_async(self, duration_tag: str):
        if not envs.VLLM_MODEL_EXECUTE_TIME_OBSERVE:
            yield
            return

        observe_start = Event(enable_timing=True)
        observe_start.record()
        try:
            yield
        finally:
            observe_end = Event(enable_timing=True)
            observe_end.record()
            with self._lock:
                self._observations.append(
                    (duration_tag, observe_start, observe_end))

    def pop_captured_sync(self, captured_name: str):
        """Pop and synchronize all events in the observation list, then print all durations."""
        if not envs.VLLM_MODEL_EXECUTE_TIME_OBSERVE:
            return

        log = f"Profile execute duration [{captured_name}]:"
        while self._observations:
            with self._lock:
                tag, observe_start, observe_end = self._observations.pop()
                observe_end.synchronize()
                duration = observe_start.elapsed_time(observe_end)
                log += f" [{tag}]:{duration:.2f}ms"
        print(log)
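A short usage sketch of the helper above, mirroring how the model runner wires it in below. The matmul workload and the "matmul"/"demo" names are illustrative, and running it requires an Ascend NPU with `torch_npu` installed.

```python
# Hedged usage sketch of ProfileExecuteDuration. Requires torch_npu and an
# Ascend NPU; the matmul workload and the "matmul"/"demo" names are illustrative.
import os

os.environ["VLLM_MODEL_EXECUTE_TIME_OBSERVE"] = "1"  # both calls are no-ops otherwise

import torch
import torch_npu  # noqa: F401  -- registers the "npu" device

from vllm_ascend.utils import ProfileExecuteDuration

prof = ProfileExecuteDuration()  # singleton: every call returns the same instance
x = torch.randn(1024, 1024, device="npu")

# capture_async() records a start/end NPU event pair around the block without
# forcing a device sync; the pair is appended to the shared observation list.
with prof.capture_async("matmul"):
    x = x @ x

# pop_captured_sync() drains the list, synchronizes each end event, and prints
# one aggregated line such as (numbers illustrative):
#   Profile execute duration [demo]: [matmul]:1.23ms
prof.pop_captured_sync("demo")
```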
258 changes: 136 additions & 122 deletions vllm_ascend/worker/model_runner_v1.py
@@ -29,6 +29,7 @@
import numpy.typing as npt
import torch
import torch.nn as nn

from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
@@ -56,14 +57,15 @@
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.utils import ProfileExecuteDuration

if TYPE_CHECKING:
import xgrammar as xgr # type: ignore[import-untyped]

from vllm.v1.core.sched.output import SchedulerOutput
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
@@ -628,36 +630,38 @@ def _process_reqs(
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
model_kwargs = {}
if self.enable_torchair_graph_mode:
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
for kv in self.kv_caches:
if isinstance(kv, tuple):
torch._dynamo.mark_static(kv[0])
torch._dynamo.mark_static(kv[1])
hidden_states = self.compile_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
**model_kwargs,
)
else:
assert self.model is not None
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
**model_kwargs,
)
with ProfileExecuteDuration().capture_async("forward"):
model_kwargs = {}
if self.enable_torchair_graph_mode:
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(
attn_metadata.decode.input_positions)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
for kv in self.kv_caches:
if isinstance(kv, tuple):
torch._dynamo.mark_static(kv[0])
torch._dynamo.mark_static(kv[1])
hidden_states = self.compile_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
**model_kwargs,
)
else:
assert self.model is not None
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
**model_kwargs,
)

use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -844,103 +848,113 @@ def execute_model(
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, torch.Tensor]:
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
(attn_metadata, hidden_states, spec_decode_metadata, positions,
num_scheduled_tokens,
sample_indices) = (self._process_reqs(scheduler_output,
intermediate_tensors))
logits = self.model.compute_logits(hidden_states[sample_indices], None)

# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
logits = self.apply_grammar_bitmask(scheduler_output, logits)

# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
else:
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
)
bonus_token_ids = sampler_output.sampled_token_ids
with ProfileExecuteDuration().capture_async(
"prepare input and forward"):
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
(attn_metadata, hidden_states, spec_decode_metadata, positions,
num_scheduled_tokens,
sample_indices) = (self._process_reqs(scheduler_output,
intermediate_tensors))

with ProfileExecuteDuration().capture_async("post process"):
logits = self.model.compute_logits(hidden_states[sample_indices],
None)

# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
logits = self.apply_grammar_bitmask(scheduler_output, logits)

# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
else:
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
bonus_logits = logits[
spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
)
bonus_token_ids = sampler_output.sampled_token_ids

# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[
spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids

# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
for i, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens:
# Ignore the sampled token.
# Rewind the generator state as if the token was not sampled.
generator = self.input_batch.generators.get(i)
if generator is not None:
generator.set_offset(generator.get_offset() - 4)

# NOTE: NPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_tensors = sampler_output.logprobs_tensors
logprobs_lists = logprobs_tensors.tolists() \
if logprobs_tensors is not None else None

# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
)

# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
spec_token_ids = self._get_spec_token_ids(
valid_sampled_token_ids,
sampling_metadata,
scheduler_output,
spec_decode_metadata,
positions,
num_scheduled_tokens,
hidden_states,
attn_metadata,
)
sampler_output.sampled_token_ids = output_token_ids

# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
for i, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens:
# Ignore the sampled token.
# Rewind the generator state as if the token was not sampled.
generator = self.input_batch.generators.get(i)
if generator is not None:
generator.set_offset(generator.get_offset() - 4)

# NOTE: NPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_tensors = sampler_output.logprobs_tensors
logprobs_lists = logprobs_tensors.tolists() \
if logprobs_tensors is not None else None

# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict={},
)

spec_token_ids = self._get_spec_token_ids(
valid_sampled_token_ids,
sampling_metadata,
scheduler_output,
spec_decode_metadata,
positions,
num_scheduled_tokens,
hidden_states,
attn_metadata,
)

model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict={},
)
capture_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill"
ProfileExecuteDuration().pop_captured_sync(capture_name)
return model_runner_output

def _profile_multimodal(self) -> None:
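Taken together, the wrappers added above nest three timed scopes inside each execute_model step. The sketch below shows the nesting and the order the tags come back out; the tags, the LIFO pop, and the Decode/Prefill capture name come from the diff, while the toy tensor work and the timings are illustrative.

```python
# Hedged sketch of how the three scopes added in execute_model nest per step.
# Assumes VLLM_MODEL_EXECUTE_TIME_OBSERVE=1 and an Ascend NPU; the toy tensor
# work stands in for input prep, the model forward, and sampling.
import os

os.environ["VLLM_MODEL_EXECUTE_TIME_OBSERVE"] = "1"

import torch
import torch_npu  # noqa: F401

from vllm_ascend.utils import ProfileExecuteDuration

prof = ProfileExecuteDuration()
x = torch.randn(512, 512, device="npu")

with prof.capture_async("prepare input and forward"):  # outer scope in execute_model
    with prof.capture_async("forward"):                 # inner scope from _process_reqs
        x = x @ x
with prof.capture_async("post process"):                # logits / sampling scope
    x = x.softmax(dim=-1)

# Observations are appended on scope exit and popped LIFO, so the single line
# printed per step lists tags in reverse completion order, roughly:
#   Profile execute duration [Decode]: [post process]:0.31ms [prepare input and forward]:2.10ms [forward]:2.05ms
prof.pop_captured_sync("Decode")  # "Prefill" when the step is not decode-only
```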