Skip to content

Commit 51d6351

Browse files
SolitaryThinker authored and Alvant committed
[bugfix] torch profiler bug for single gpu with GPUExecutor (vllm-project#8354)
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent 6b30dd5 commit 51d6351

File tree

3 files changed

+27
-5
lines changed

3 files changed

+27
-5
lines changed

examples/offline_inference_with_profiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
1717

1818
# Create an LLM.
19-
llm = LLM(model="facebook/opt-125m")
19+
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
2020

2121
llm.start_profile()
2222

vllm/engine/async_llm_engine.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
1414
from vllm.engine.metrics_types import StatLoggerBase
1515
from vllm.executor.executor_base import ExecutorAsyncBase
16+
from vllm.executor.gpu_executor import GPUExecutorAsync
1617
from vllm.executor.ray_utils import initialize_ray_cluster
1718
from vllm.inputs import PromptInputs
1819
from vllm.logger import init_logger
@@ -1061,7 +1062,17 @@ def remove_logger(self, logger_name: str) -> None:
10611062
self.engine.remove_logger(logger_name=logger_name)
10621063

10631064
async def start_profile(self) -> None:
1064-
self.engine.model_executor._run_workers("start_profile")
1065+
# using type instead of isinstance to check to avoid capturing
1066+
# inherited classes
1067+
if type(self.engine.model_executor) == GPUExecutorAsync:
1068+
self.engine.model_executor.start_profile()
1069+
else:
1070+
self.engine.model_executor._run_workers("start_profile")
10651071

10661072
async def stop_profile(self) -> None:
1067-
self.engine.model_executor._run_workers("stop_profile")
1073+
# using type instead of isinstance to check to avoid capturing
1074+
# inherited classes
1075+
if type(self.engine.model_executor) == GPUExecutorAsync:
1076+
self.engine.model_executor.stop_profile()
1077+
else:
1078+
self.engine.model_executor._run_workers("stop_profile")

vllm/engine/llm_engine.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from vllm.engine.output_processor.stop_checker import StopChecker
2727
from vllm.engine.output_processor.util import create_output_by_sequence_group
2828
from vllm.executor.executor_base import ExecutorBase
29+
from vllm.executor.gpu_executor import GPUExecutor
2930
from vllm.executor.ray_utils import initialize_ray_cluster
3031
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
3132
InputRegistry, LLMInputs, PromptInputs)
@@ -1644,10 +1645,20 @@ def check_health(self) -> None:
16441645
self.model_executor.check_health()
16451646

16461647
def start_profile(self) -> None:
1647-
self.model_executor.start_profile()
1648+
# using type instead of isinstance to check to avoid capturing
1649+
# inherited classes (MultiprocessingGPUExecutor)
1650+
if type(self.model_executor) == GPUExecutor:
1651+
self.model_executor.start_profile()
1652+
else:
1653+
self.model_executor._run_workers("start_profile")
16481654

16491655
def stop_profile(self) -> None:
1650-
self.model_executor.stop_profile()
1656+
# using type instead of isinstance to check to avoid capturing
1657+
# inherited classes (MultiprocessingGPUExecutor)
1658+
if type(self.model_executor) == GPUExecutor:
1659+
self.model_executor.stop_profile()
1660+
else:
1661+
self.model_executor._run_workers("stop_profile")
16511662

16521663
def is_tracing_enabled(self) -> bool:
16531664
return self.tracer is not None

0 commit comments

Comments (0)