From eaf0c011a7627ca3276a89ef5a446e0e024f1081 Mon Sep 17 00:00:00 2001
From: Will Lin
Date: Tue, 10 Sep 2024 21:57:23 -0700
Subject: [PATCH 1/4] [bugfix] torch profiler bug for single gpu with
 GPUExecutor

---
 examples/offline_inference_with_profiler.py |  2 +-
 vllm/engine/async_llm_engine.py             | 11 +++++++++--
 vllm/engine/llm_engine.py                   | 11 +++++++++--
 3 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/examples/offline_inference_with_profiler.py b/examples/offline_inference_with_profiler.py
index 906c9502800..1f00d268087 100644
--- a/examples/offline_inference_with_profiler.py
+++ b/examples/offline_inference_with_profiler.py
@@ -16,7 +16,7 @@
 sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
 # Create an LLM.
-llm = LLM(model="facebook/opt-125m")
+llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
 
 llm.start_profile()
 
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 362b0f3a44b..f355361f01a 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -17,6 +17,7 @@
 from vllm.engine.metrics_types import StatLoggerBase
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.ray_utils import initialize_ray_cluster
+from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
                          SingletonPromptInputs)
 from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
@@ -1156,7 +1157,13 @@ def remove_logger(self, logger_name: str) -> None:
         self.engine.remove_logger(logger_name=logger_name)
 
     async def start_profile(self) -> None:
-        self.engine.model_executor._run_workers("start_profile")
+        if isinstance(self.engine.model_executor, GPUExecutorAsync):
+            self.engine.model_executor.start_profile()
+        else:
+            self.engine.model_executor._run_workers("start_profile")
 
     async def stop_profile(self) -> None:
-        self.engine.model_executor._run_workers("stop_profile")
+        if isinstance(self.engine.model_executor, GPUExecutorAsync):
+            self.engine.model_executor.stop_profile()
+        else:
+            self.engine.model_executor._run_workers("stop_profile")
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 92e46c7af51..0425aa3e2f9 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -26,6 +26,7 @@
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.engine.output_processor.util import create_output_by_sequence_group
 from vllm.executor.executor_base import ExecutorBase
+from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
                          InputRegistry, LLMInputs, PromptInputs,
@@ -1964,10 +1965,16 @@ def check_health(self) -> None:
         self.model_executor.check_health()
 
     def start_profile(self) -> None:
-        self.model_executor.start_profile()
+        if isinstance(self.model_executor, GPUExecutorAsync):
+            self.model_executor.start_profile()
+        else:
+            self.model_executor._run_workers("start_profile")
 
     def stop_profile(self) -> None:
-        self.model_executor.stop_profile()
+        if isinstance(self.model_executor, GPUExecutorAsync):
+            self.model_executor.stop_profile()
+        else:
+            self.model_executor._run_workers("stop_profile")
 
     def is_tracing_enabled(self) -> bool:
         return self.tracer is not None
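
For context between patches: as the commit subject suggests, the single-GPU executor has no _run_workers() method, so the pre-patch engine hooks, which called _run_workers() unconditionally, failed on a single GPU with GPUExecutor. A minimal sketch of the failure mode and the fixed dispatch, using a hypothetical stand-in class rather than the real vLLM executor:

    # Stand-in for vllm.executor.gpu_executor.GPUExecutorAsync; the real class
    # drives a single GPU in the driver process and exposes the profiler hooks
    # directly instead of broadcasting them to worker processes.
    class FakeGPUExecutorAsync:
        def start_profile(self) -> None:
            print("torch profiler started in the driver process")

    executor = FakeGPUExecutorAsync()

    try:
        executor._run_workers("start_profile")  # pre-patch code path
    except AttributeError as exc:
        print(f"single-GPU executor breaks: {exc}")

    executor.start_profile()  # post-patch code path for the single-GPU case
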
From e132dde0384210ad4242f3335acd4936eb27e4d2 Mon Sep 17 00:00:00 2001
From: Will Lin
Date: Thu, 12 Sep 2024 01:15:20 -0700
Subject: [PATCH 2/4] change to type

---
 vllm/engine/llm_engine.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 0425aa3e2f9..518137a031d 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -26,7 +26,7 @@
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.engine.output_processor.util import create_output_by_sequence_group
 from vllm.executor.executor_base import ExecutorBase
-from vllm.executor.gpu_executor import GPUExecutorAsync
+from vllm.executor.gpu_executor import GPUExecutor
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
                          InputRegistry, LLMInputs, PromptInputs,
@@ -1965,13 +1965,13 @@ def check_health(self) -> None:
         self.model_executor.check_health()
 
     def start_profile(self) -> None:
-        if isinstance(self.model_executor, GPUExecutorAsync):
+        if type(self.model_executor) == GPUExecutor:
             self.model_executor.start_profile()
         else:
             self.model_executor._run_workers("start_profile")
 
     def stop_profile(self) -> None:
-        if isinstance(self.model_executor, GPUExecutorAsync):
+        if type(self.model_executor) == GPUExecutor:
             self.model_executor.stop_profile()
         else:
             self.model_executor._run_workers("stop_profile")

From b6c9c0f8f67bbe360c1dc66856aff6404d060c38 Mon Sep 17 00:00:00 2001
From: Will Lin
Date: Thu, 12 Sep 2024 01:43:34 -0700
Subject: [PATCH 3/4] format

---
 vllm/engine/async_llm_engine.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index f355361f01a..4f28b412ec9 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -16,8 +16,8 @@
                              PromptComponents, SchedulerOutputState)
 from vllm.engine.metrics_types import StatLoggerBase
 from vllm.executor.executor_base import ExecutorAsyncBase
-from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.executor.gpu_executor import GPUExecutorAsync
+from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
                          SingletonPromptInputs)
 from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
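
The switch from isinstance() to a direct type() comparison in patch 2 (and in patch 4 below) matters because isinstance() also matches subclasses such as MultiprocessingGPUExecutor, which does manage worker processes and must keep taking the _run_workers() path. A self-contained illustration; only the class names mirror vLLM, the bodies are empty stand-ins:

    class GPUExecutor:  # stand-in: single GPU, profiles in-process
        pass

    class MultiprocessingGPUExecutor(GPUExecutor):  # stand-in: has workers
        pass

    executor = MultiprocessingGPUExecutor()

    # isinstance() matches the subclass too, which would wrongly route a
    # multi-GPU executor down the single-GPU code path:
    print(isinstance(executor, GPUExecutor))  # True

    # A direct type comparison matches only GPUExecutor itself:
    print(type(executor) == GPUExecutor)      # False
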
From 6801e4448d8a3dcc975d175e2dfc8537546ccfee Mon Sep 17 00:00:00 2001
From: Will Lin
Date: Thu, 12 Sep 2024 12:06:28 -0700
Subject: [PATCH 4/4] comments

---
 vllm/engine/async_llm_engine.py | 8 ++++++--
 vllm/engine/llm_engine.py       | 4 ++++
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 4f28b412ec9..23883c1efec 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -1157,13 +1157,17 @@ def remove_logger(self, logger_name: str) -> None:
         self.engine.remove_logger(logger_name=logger_name)
 
     async def start_profile(self) -> None:
-        if isinstance(self.engine.model_executor, GPUExecutorAsync):
+        # using type instead of isinstance to check to avoid capturing
+        # inherited classes
+        if type(self.engine.model_executor) == GPUExecutorAsync:
             self.engine.model_executor.start_profile()
         else:
             self.engine.model_executor._run_workers("start_profile")
 
     async def stop_profile(self) -> None:
-        if isinstance(self.engine.model_executor, GPUExecutorAsync):
+        # using type instead of isinstance to check to avoid capturing
+        # inherited classes
+        if type(self.engine.model_executor) == GPUExecutorAsync:
             self.engine.model_executor.stop_profile()
         else:
             self.engine.model_executor._run_workers("stop_profile")
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 518137a031d..af5ba33fe5e 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1965,12 +1965,16 @@ def check_health(self) -> None:
         self.model_executor.check_health()
 
     def start_profile(self) -> None:
+        # using type instead of isinstance to check to avoid capturing
+        # inherited classes (MultiprocessingGPUExecutor)
         if type(self.model_executor) == GPUExecutor:
             self.model_executor.start_profile()
         else:
             self.model_executor._run_workers("start_profile")
 
     def stop_profile(self) -> None:
+        # using type instead of isinstance to check to avoid capturing
+        # inherited classes (MultiprocessingGPUExecutor)
         if type(self.model_executor) == GPUExecutor:
             self.model_executor.stop_profile()
         else:
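
With the series applied, the patched example should profile a single-GPU run end to end. A sketch of the full flow, assuming (as I read the example script this series touches) that the trace output directory is selected via the VLLM_TORCH_PROFILER_DIR environment variable:

    import os

    # Assumed profiler configuration; traces are expected to land in this
    # directory when the torch profiler is enabled this way.
    os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile"

    from vllm import LLM, SamplingParams

    prompts = ["Hello, my name is"]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # tensor_parallel_size=1 exercises the GPUExecutor path this series fixes.
    llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)

    llm.start_profile()
    outputs = llm.generate(prompts, sampling_params)
    llm.stop_profile()

    for output in outputs:
        print(output.outputs[0].text)
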