Skip to content

Commit 3c4a725

Browse files
Yard1 authored and fialhocoelho committed
[Bugfix] Fix Ray Metrics API usage (vllm-project#6354)
1 parent b97793d commit 3c4a725

File tree

4 files changed

+195
-40
lines changed

4 files changed

+195
-40
lines changed

tests/metrics/test_metrics.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from typing import List
22

33
import pytest
4+
import ray
45
from prometheus_client import REGISTRY
56

67
from vllm import EngineArgs, LLMEngine
78
from vllm.engine.arg_utils import AsyncEngineArgs
89
from vllm.engine.async_llm_engine import AsyncLLMEngine
10+
from vllm.engine.metrics import RayPrometheusStatLogger
911
from vllm.sampling_params import SamplingParams
1012

1113
MODELS = [
@@ -241,3 +243,55 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
241243
labels)
242244
assert (
243245
metric_value == num_requests), "Metrics should be collected"
246+
247+
248+
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [16])
def test_engine_log_metrics_ray(
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Smoke-test RayPrometheusStatLogger: run an engine with it attached.

    This test is quite weak - it only checks that RayPrometheusStatLogger
    can be used without raising. Checking whether the metrics are actually
    emitted is unfortunately non-trivial, so we only assert that the
    logger's ``log`` hook fires at least once.
    """

    # We have to run in a Ray task for Ray metrics to be emitted correctly.
    @ray.remote(num_gpus=1)
    def _inner():

        class _CountingRayLogger(RayPrometheusStatLogger):
            # Subclass that records how many times log() was invoked.

            def __init__(self, *args, **kwargs):
                self._num_log_calls = 0
                super().__init__(*args, **kwargs)

            def log(self, *args, **kwargs):
                self._num_log_calls += 1
                return super().log(*args, **kwargs)

        engine = LLMEngine.from_engine_args(
            EngineArgs(
                model=model,
                dtype=dtype,
                disable_log_stats=False,
            ))
        stat_logger = _CountingRayLogger(
            local_interval=0.5,
            labels={"model_name": engine.model_config.served_model_name},
            max_model_len=engine.model_config.max_model_len)
        engine.add_logger("ray", stat_logger)

        for request_idx, prompt in enumerate(example_prompts):
            engine.add_request(
                f"request-id-{request_idx}",
                prompt,
                SamplingParams(max_tokens=max_tokens),
            )
        # Drain all requests; each step gives the logger a chance to fire.
        while engine.has_unfinished_requests():
            engine.step()

        assert stat_logger._num_log_calls > 0, \
            ".log must be called at least once"

    ray.get(_inner.remote())

vllm/engine/async_llm_engine.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from vllm.engine.arg_utils import AsyncEngineArgs
1313
from vllm.engine.async_timeout import asyncio_timeout
1414
from vllm.engine.llm_engine import LLMEngine
15+
from vllm.engine.metrics import StatLoggerBase
1516
from vllm.executor.ray_utils import initialize_ray_cluster, ray
1617
from vllm.inputs import LLMInputs, PromptInputs
1718
from vllm.logger import init_logger
@@ -389,6 +390,7 @@ def from_engine_args(
389390
engine_args: AsyncEngineArgs,
390391
start_engine_loop: bool = True,
391392
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
393+
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
392394
) -> "AsyncLLMEngine":
393395
"""Creates an async LLM engine from the engine arguments."""
394396
# Create the engine configs.
@@ -451,6 +453,7 @@ def from_engine_args(
451453
max_log_len=engine_args.max_log_len,
452454
start_engine_loop=start_engine_loop,
453455
usage_context=usage_context,
456+
stat_loggers=stat_loggers,
454457
)
455458
return engine
456459

@@ -957,3 +960,19 @@ async def is_tracing_enabled(self) -> bool:
957960
)
958961
else:
959962
return self.engine.is_tracing_enabled()
963+
964+
def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
    """Attach *logger* to the wrapped engine under *logger_name*.

    When the engine lives in a Ray actor (``engine_use_ray``), the call
    is forwarded as a remote method invocation and awaited; otherwise
    the wrapped engine is called directly.
    """
    if not self.engine_use_ray:
        self.engine.add_logger(logger_name=logger_name, logger=logger)
        return
    remote_ref = self.engine.add_logger.remote(  # type: ignore
        logger_name=logger_name, logger=logger)
    ray.get(remote_ref)
971+
972+
def remove_logger(self, logger_name: str) -> None:
    """Detach the stat logger registered under *logger_name*.

    Mirrors ``add_logger``: forwarded as an awaited remote call when the
    engine runs inside a Ray actor, otherwise invoked directly.
    """
    if not self.engine_use_ray:
        self.engine.remove_logger(logger_name=logger_name)
        return
    remote_ref = self.engine.remove_logger.remote(  # type: ignore
        logger_name=logger_name)
    ray.get(remote_ref)

vllm/engine/llm_engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,7 @@ def from_engine_args(
379379
cls,
380380
engine_args: EngineArgs,
381381
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
382+
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
382383
) -> "LLMEngine":
383384
"""Creates an LLM engine from the engine arguments."""
384385
# Create the engine configs.
@@ -423,6 +424,7 @@ def from_engine_args(
423424
executor_class=executor_class,
424425
log_stats=not engine_args.disable_log_stats,
425426
usage_context=usage_context,
427+
stat_loggers=stat_loggers,
426428
)
427429
return engine
428430

0 commit comments

Comments (0)