Skip to content

Commit eaf0c01

Browse files
[Bugfix] Fix torch profiler start/stop for single-GPU runs with GPUExecutor
1 parent 7de49aa commit eaf0c01

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

examples/offline_inference_with_profiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
1717

1818
# Create an LLM.
19-
llm = LLM(model="facebook/opt-125m")
19+
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
2020

2121
llm.start_profile()
2222

vllm/engine/async_llm_engine.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from vllm.engine.metrics_types import StatLoggerBase
1818
from vllm.executor.executor_base import ExecutorAsyncBase
1919
from vllm.executor.ray_utils import initialize_ray_cluster
20+
from vllm.executor.gpu_executor import GPUExecutorAsync
2021
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
2122
SingletonPromptInputs)
2223
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
@@ -1156,7 +1157,13 @@ def remove_logger(self, logger_name: str) -> None:
11561157
self.engine.remove_logger(logger_name=logger_name)
11571158

11581159
async def start_profile(self) -> None:
1159-
self.engine.model_executor._run_workers("start_profile")
1160+
if isinstance(self.engine.model_executor, GPUExecutorAsync):
1161+
self.engine.model_executor.start_profile()
1162+
else:
1163+
self.engine.model_executor._run_workers("start_profile")
11601164

11611165
async def stop_profile(self) -> None:
1162-
self.engine.model_executor._run_workers("stop_profile")
1166+
if isinstance(self.engine.model_executor, GPUExecutorAsync):
1167+
self.engine.model_executor.stop_profile()
1168+
else:
1169+
self.engine.model_executor._run_workers("stop_profile")

vllm/engine/llm_engine.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from vllm.engine.output_processor.stop_checker import StopChecker
2727
from vllm.engine.output_processor.util import create_output_by_sequence_group
2828
from vllm.executor.executor_base import ExecutorBase
29+
from vllm.executor.gpu_executor import GPUExecutor
2930
from vllm.executor.ray_utils import initialize_ray_cluster
3031
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
3132
InputRegistry, LLMInputs, PromptInputs,
@@ -1964,10 +1965,16 @@ def check_health(self) -> None:
19641965
self.model_executor.check_health()
19651966

19661967
def start_profile(self) -> None:
1967-
self.model_executor.start_profile()
1968+
if isinstance(self.model_executor, GPUExecutorAsync):
1969+
self.model_executor.start_profile()
1970+
else:
1971+
self.model_executor._run_workers("start_profile")
19681972

19691973
def stop_profile(self) -> None:
1970-
self.model_executor.stop_profile()
1974+
if isinstance(self.model_executor, GPUExecutor):
1975+
self.model_executor.stop_profile()
1976+
else:
1977+
self.model_executor._run_workers("stop_profile")
19711978

19721979
def is_tracing_enabled(self) -> bool:
19731980
return self.tracer is not None

0 commit comments

Comments
 (0)