
Commit 80f5444

Fix Ray Metrics
1 parent a4feba9 commit 80f5444

File tree

2 files changed: +164 -28 lines changed

tests/metrics/test_metrics.py

Lines changed: 54 additions & 0 deletions
@@ -1,12 +1,14 @@
 from typing import List
 
 import pytest
+import ray
 from prometheus_client import REGISTRY
 
 from vllm import EngineArgs, LLMEngine
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams
+from vllm.engine.metrics import RayPrometheusStatLogger
 
 MODELS = [
     "facebook/opt-125m",
@@ -192,3 +194,55 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
                         labels)
         assert (
             metric_value == num_requests), "Metrics should be collected"
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [16])
+def test_engine_log_metrics_ray(
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    # This test is quite weak - it only checks that we can use
+    # RayPrometheusStatLogger without exceptions.
+    # Checking whether the metrics are actually emitted is unfortunately
+    # non-trivial.
+
+    # We have to run in a Ray task for Ray metrics to be emitted correctly
+    @ray.remote(num_gpus=1)
+    def _inner():
+
+        class _RayPrometheusStatLogger(RayPrometheusStatLogger):
+
+            def __init__(self, *args, **kwargs):
+                self._i = 0
+                super().__init__(*args, **kwargs)
+
+            def log(self, *args, **kwargs):
+                self._i += 1
+                return super().log(*args, **kwargs)
+
+        engine_args = EngineArgs(
+            model=model,
+            dtype=dtype,
+            disable_log_stats=False,
+        )
+        engine = LLMEngine.from_engine_args(engine_args)
+        logger = _RayPrometheusStatLogger(
+            local_interval=0.5,
+            labels=dict(model_name=engine.model_config.served_model_name),
+            max_model_len=engine.model_config.max_model_len)
+        engine.add_logger("ray", logger)
+        for i, prompt in enumerate(example_prompts):
+            engine.add_request(
+                f"request-id-{i}",
+                prompt,
+                SamplingParams(max_tokens=max_tokens),
+            )
+        while engine.has_unfinished_requests():
+            engine.step()
+        assert logger._i > 0, ".log must be called at least once"
+
+    ray.get(_inner.remote())
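A note on the shape of this test (the sketch below is illustrative only and not part of the commit): ray.util.metrics only exports metrics that are created inside a Ray worker process, which is why the entire test body is defined in a @ray.remote task and then driven via ray.get(). A minimal standalone version of that pattern, with an invented metric name, tag, and workload, might look like this:

import ray
from ray.util import metrics as ray_metrics


@ray.remote
def _workload() -> str:
    # The Counter must be created inside the Ray task so the worker's
    # metrics agent registers and exports it.
    counter = ray_metrics.Counter(
        name="example_requests_total",  # hypothetical metric name
        description="Requests handled (example only).",
        tag_keys=("model_name", ))
    counter.set_default_tags({"model_name": "example-model"})
    for _ in range(3):
        counter.inc(1)  # each increment is reported via the Ray metrics agent
    return "done"


ray.init()
print(ray.get(_workload.remote()))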

vllm/engine/metrics.py

Lines changed: 110 additions & 28 deletions
@@ -30,63 +30,63 @@
 # begin-metrics-definitions
 class Metrics:
     labelname_finish_reason = "finished_reason"
-    _base_library = prometheus_client
+    _gauge_cls = prometheus_client.Gauge
+    _counter_cls = prometheus_client.Counter
+    _histogram_cls = prometheus_client.Histogram
 
     def __init__(self, labelnames: List[str], max_model_len: int):
         # Unregister any existing vLLM collectors
         self._unregister_vllm_metrics()
 
         # Config Information
-        self.info_cache_config = prometheus_client.Info(
-            name='vllm:cache_config',
-            documentation='information of cache_config')
+        self._create_info_cache_config()
 
         # System stats
         # Scheduler State
-        self.gauge_scheduler_running = self._base_library.Gauge(
+        self.gauge_scheduler_running = self._gauge_cls(
             name="vllm:num_requests_running",
             documentation="Number of requests currently running on GPU.",
             labelnames=labelnames)
-        self.gauge_scheduler_waiting = self._base_library.Gauge(
+        self.gauge_scheduler_waiting = self._gauge_cls(
             name="vllm:num_requests_waiting",
             documentation="Number of requests waiting to be processed.",
             labelnames=labelnames)
-        self.gauge_scheduler_swapped = self._base_library.Gauge(
+        self.gauge_scheduler_swapped = self._gauge_cls(
             name="vllm:num_requests_swapped",
             documentation="Number of requests swapped to CPU.",
             labelnames=labelnames)
         # KV Cache Usage in %
-        self.gauge_gpu_cache_usage = self._base_library.Gauge(
+        self.gauge_gpu_cache_usage = self._gauge_cls(
             name="vllm:gpu_cache_usage_perc",
             documentation="GPU KV-cache usage. 1 means 100 percent usage.",
             labelnames=labelnames)
-        self.gauge_cpu_cache_usage = self._base_library.Gauge(
+        self.gauge_cpu_cache_usage = self._gauge_cls(
            name="vllm:cpu_cache_usage_perc",
            documentation="CPU KV-cache usage. 1 means 100 percent usage.",
            labelnames=labelnames)
 
         # Iteration stats
-        self.counter_num_preemption = self._base_library.Counter(
+        self.counter_num_preemption = self._counter_cls(
             name="vllm:num_preemptions_total",
             documentation="Cumulative number of preemption from the engine.",
             labelnames=labelnames)
-        self.counter_prompt_tokens = self._base_library.Counter(
+        self.counter_prompt_tokens = self._counter_cls(
             name="vllm:prompt_tokens_total",
             documentation="Number of prefill tokens processed.",
             labelnames=labelnames)
-        self.counter_generation_tokens = self._base_library.Counter(
+        self.counter_generation_tokens = self._counter_cls(
             name="vllm:generation_tokens_total",
             documentation="Number of generation tokens processed.",
             labelnames=labelnames)
-        self.histogram_time_to_first_token = self._base_library.Histogram(
+        self.histogram_time_to_first_token = self._histogram_cls(
             name="vllm:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
             labelnames=labelnames,
             buckets=[
                 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
                 0.75, 1.0, 2.5, 5.0, 7.5, 10.0
             ])
-        self.histogram_time_per_output_token = self._base_library.Histogram(
+        self.histogram_time_per_output_token = self._histogram_cls(
             name="vllm:time_per_output_token_seconds",
             documentation="Histogram of time per output token in seconds.",
             labelnames=labelnames,
@@ -97,67 +97,145 @@ def __init__(self, labelnames: List[str], max_model_len: int):
 
         # Request stats
         # Latency
-        self.histogram_e2e_time_request = self._base_library.Histogram(
+        self.histogram_e2e_time_request = self._histogram_cls(
             name="vllm:e2e_request_latency_seconds",
             documentation="Histogram of end to end request latency in seconds.",
             labelnames=labelnames,
             buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
         # Metadata
-        self.histogram_num_prompt_tokens_request = self._base_library.Histogram(
+        self.histogram_num_prompt_tokens_request = self._histogram_cls(
             name="vllm:request_prompt_tokens",
             documentation="Number of prefill tokens processed.",
             labelnames=labelnames,
             buckets=build_1_2_5_buckets(max_model_len),
         )
         self.histogram_num_generation_tokens_request = \
-            self._base_library.Histogram(
+            self._histogram_cls(
                 name="vllm:request_generation_tokens",
                 documentation="Number of generation tokens processed.",
                 labelnames=labelnames,
                 buckets=build_1_2_5_buckets(max_model_len),
             )
-        self.histogram_best_of_request = self._base_library.Histogram(
+        self.histogram_best_of_request = self._histogram_cls(
             name="vllm:request_params_best_of",
             documentation="Histogram of the best_of request parameter.",
             labelnames=labelnames,
             buckets=[1, 2, 5, 10, 20],
         )
-        self.histogram_n_request = self._base_library.Histogram(
+        self.histogram_n_request = self._histogram_cls(
             name="vllm:request_params_n",
             documentation="Histogram of the n request parameter.",
             labelnames=labelnames,
             buckets=[1, 2, 5, 10, 20],
         )
-        self.counter_request_success = self._base_library.Counter(
+        self.counter_request_success = self._counter_cls(
             name="vllm:request_success_total",
             documentation="Count of successfully processed requests.",
             labelnames=labelnames + [Metrics.labelname_finish_reason])
 
         # Deprecated in favor of vllm:prompt_tokens_total
-        self.gauge_avg_prompt_throughput = self._base_library.Gauge(
+        self.gauge_avg_prompt_throughput = self._gauge_cls(
             name="vllm:avg_prompt_throughput_toks_per_s",
             documentation="Average prefill throughput in tokens/s.",
             labelnames=labelnames,
         )
         # Deprecated in favor of vllm:generation_tokens_total
-        self.gauge_avg_generation_throughput = self._base_library.Gauge(
+        self.gauge_avg_generation_throughput = self._gauge_cls(
             name="vllm:avg_generation_throughput_toks_per_s",
             documentation="Average generation throughput in tokens/s.",
             labelnames=labelnames,
         )
 
+    def _create_info_cache_config(self) -> None:
+        # Config Information
+        self.info_cache_config = prometheus_client.Info(
+            name='vllm:cache_config',
+            documentation='information of cache_config')
+
     def _unregister_vllm_metrics(self) -> None:
-        for collector in list(self._base_library.REGISTRY._collector_to_names):
+        for collector in list(prometheus_client.REGISTRY._collector_to_names):
             if hasattr(collector, "_name") and "vllm" in collector._name:
-                self._base_library.REGISTRY.unregister(collector)
+                prometheus_client.REGISTRY.unregister(collector)
+
+
+# end-metrics-definitions
+
+
+class _RayGaugeWrapper:
+    """Wraps around ray.util.metrics.Gauge to provide same API as
+    prometheus_client.Gauge"""
+
+    def __init__(self,
+                 name: str,
+                 documentation: str = "",
+                 labelnames: Optional[List[str]] = None):
+        labelnames_tuple = tuple(labelnames) if labelnames else None
+        self._gauge = ray_metrics.Gauge(name=name,
+                                        description=documentation,
+                                        tag_keys=labelnames_tuple)
+
+    def labels(self, **labels):
+        self._gauge.set_default_tags(labels)
+        return self
+
+    def set(self, value: Union[int, float]):
+        return self._gauge.set(value)
+
+
+class _RayCounterWrapper:
+    """Wraps around ray.util.metrics.Counter to provide same API as
+    prometheus_client.Counter"""
+
+    def __init__(self,
+                 name: str,
+                 documentation: str = "",
+                 labelnames: Optional[List[str]] = None):
+        labelnames_tuple = tuple(labelnames) if labelnames else None
+        self._counter = ray_metrics.Counter(name=name,
+                                            description=documentation,
+                                            tag_keys=labelnames_tuple)
+
+    def labels(self, **labels):
+        self._counter.set_default_tags(labels)
+        return self
+
+    def inc(self, value: Union[int, float] = 1.0):
+        if value == 0:
+            return
+        return self._counter.inc(value)
+
+
+class _RayHistogramWrapper:
+    """Wraps around ray.util.metrics.Histogram to provide same API as
+    prometheus_client.Histogram"""
+
+    def __init__(self,
+                 name: str,
+                 documentation: str = "",
+                 labelnames: Optional[List[str]] = None,
+                 buckets: Optional[List[float]] = None):
+        labelnames_tuple = tuple(labelnames) if labelnames else None
+        self._histogram = ray_metrics.Histogram(name=name,
+                                                description=documentation,
+                                                tag_keys=labelnames_tuple,
+                                                boundaries=buckets)
+
+    def labels(self, **labels):
+        self._histogram.set_default_tags(labels)
+        return self
+
+    def observe(self, value: Union[int, float]):
+        return self._histogram.observe(value)
 
 
 class RayMetrics(Metrics):
     """
     RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
     Provides the same metrics as Metrics but uses Ray's util.metrics library.
     """
-    _base_library = ray_metrics
+    _gauge_cls = _RayGaugeWrapper
+    _counter_cls = _RayCounterWrapper
+    _histogram_cls = _RayHistogramWrapper
 
     def __init__(self, labelnames: List[str], max_model_len: int):
         if ray_metrics is None:
@@ -168,8 +246,9 @@ def _unregister_vllm_metrics(self) -> None:
         # No-op on purpose
         pass
 
-
-# end-metrics-definitions
+    def _create_info_cache_config(self) -> None:
+        # No-op on purpose
+        pass
 
 
 def build_1_2_5_buckets(max_value: int) -> List[int]:
@@ -457,4 +536,7 @@ def log(self, stats: Stats):
 
 class RayPrometheusStatLogger(PrometheusStatLogger):
     """RayPrometheusStatLogger uses Ray metrics instead."""
-    _metrics_cls = RayMetrics
+    _metrics_cls = RayMetrics
+
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+        return None
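Why the wrapper classes exist (the sketch below is illustrative only and not part of the commit): the previous approach swapped in ray.util.metrics wholesale via _base_library, but Ray's Gauge/Counter/Histogram constructors and labelling API do not match prometheus_client's. The fix has Metrics build every instrument through the overridable _gauge_cls/_counter_cls/_histogram_cls attributes, and RayMetrics substitutes thin adapters that translate the prometheus_client call shape, labels(**tags) followed by set/inc/observe, into Ray's set_default_tags(tags) plus set/inc/observe. A stripped-down illustration of the same pattern, using an invented in-memory backend instead of either real library:

from typing import Dict, List, Optional, Tuple, Union


class _DictGauge:
    """Invented stand-in backend: records the last value per tag set."""

    def __init__(self, name: str):
        self.name = name
        self.values: Dict[Tuple, float] = {}
        self._tags: Tuple = ()

    def set_default_tags(self, tags: Dict[str, str]) -> None:
        self._tags = tuple(sorted(tags.items()))

    def set(self, value: float) -> None:
        self.values[self._tags] = value


class _GaugeAdapter:
    """Exposes the prometheus_client-style labels().set() call shape."""

    def __init__(self,
                 name: str,
                 documentation: str = "",
                 labelnames: Optional[List[str]] = None):
        self._gauge = _DictGauge(name)

    def labels(self, **labels):
        # prometheus_client returns a labelled child metric; the Ray
        # wrappers emulate that by setting default tags and returning self.
        self._gauge.set_default_tags(labels)
        return self

    def set(self, value: Union[int, float]) -> None:
        self._gauge.set(value)


class DemoMetrics:
    # Subclasses may override this factory, exactly as RayMetrics overrides
    # _gauge_cls/_counter_cls/_histogram_cls in the diff above.
    _gauge_cls = _GaugeAdapter

    def __init__(self, labelnames: List[str]):
        self.gauge_running = self._gauge_cls(
            name="num_requests_running",
            documentation="Requests currently running (example).",
            labelnames=labelnames)


metrics = DemoMetrics(labelnames=["model_name"])
metrics.gauge_running.labels(model_name="opt-125m").set(3)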
