
Commit 91da711

improve
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
1 parent 68565b6 commit 91da711


3 files changed: +60 -26 lines changed


vllm/v1/core/kv_cache_manager.py

Lines changed: 13 additions & 18 deletions
@@ -73,10 +73,7 @@ def __init__(
                                         List[KVCacheBlock]] = defaultdict(list)
 
         # Prefix cache metrics.
-        self.prefix_caching_metrics: PrefixCachingMetrics = {
-            "query_total": 0,
-            "query_hit": 0,
-        }
+        self.prefix_caching_metrics = PrefixCachingMetrics()
 
     @property
     def usage(self) -> float:
@@ -88,21 +85,14 @@ def usage(self) -> float:
         return 1.0 - (self.free_block_queue.num_free_blocks /
                       self.num_gpu_blocks)
 
-    def get_and_reset_prefix_cache_hit_rate(self) -> float:
-        """Get the overall hit rate of prefix caching and reset
-        the metrics.
+    @property
+    def prefix_cache_hit_rate(self) -> float:
+        """Get the prefix caching hit rate.
 
         Returns:
-            The hit rate of prefix caching (between 0.0 and 1.0).
+            The prefix caching hit rate.
         """
-        hit_rate = 0.0
-        if self.prefix_caching_metrics["query_total"] > 0:
-            hit_rate = self.prefix_caching_metrics[
-                "query_hit"] / self.prefix_caching_metrics["query_total"]
-
-        self.prefix_caching_metrics["query_hit"] = 0
-        self.prefix_caching_metrics["query_total"] = 0
-        return hit_rate
+        return self.prefix_caching_metrics.hit_rate
 
     def get_computed_blocks(
             self, request: Request) -> Tuple[List[KVCacheBlock], int]:
@@ -139,8 +129,10 @@ def get_computed_blocks(
             else:
                 break
 
-        self.prefix_caching_metrics["query_total"] += len(block_hashes)
-        self.prefix_caching_metrics["query_hit"] += len(computed_blocks)
+        self.prefix_caching_metrics.add_request_query(
+            num_queries=len(block_hashes),
+            num_hits=len(computed_blocks),
+        )
 
         # NOTE(woosuk): Since incomplete blocks are not eligible for
         # sharing, `num_computed_tokens` is always a multiple of
@@ -306,6 +298,9 @@ def reset_prefix_cache(self) -> bool:
         for block in self.block_pool:
             block.reset_hash()
 
+        # Reset the prefix caching metrics.
+        self.prefix_caching_metrics.reset()
+
         logger.info("Successfully reset prefix cache")
         return True
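
A minimal usage sketch (not part of this commit) of the new read pattern: the scheduler-facing hit rate is now a property backed by the PrefixCachingMetrics class added in kv_cache_utils.py below, and reading it no longer clears the counters the way get_and_reset_prefix_cache_hit_rate() did; only reset_prefix_cache() resets them. The import path is assumed to mirror the file path shown in this commit.

# Sketch only, assuming vllm/v1/core/kv_cache_utils.py maps to this module path.
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics

metrics = PrefixCachingMetrics()  # aggregates the most recent 1000 requests

# get_computed_blocks() records one entry per scheduled request.
metrics.add_request_query(num_queries=8, num_hits=6)
metrics.add_request_query(num_queries=4, num_hits=0)

# make_stats() reads the value through the prefix_cache_hit_rate property;
# unlike the removed get_and_reset method, reading does not reset anything.
assert metrics.hit_rate == 6 / 12
assert metrics.hit_rate == 6 / 12  # unchanged after a second read

# reset_prefix_cache() is what clears the window, via reset().
metrics.reset()
assert metrics.hit_rate == 0.0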

vllm/v1/core/kv_cache_utils.py

Lines changed: 46 additions & 7 deletions
@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 """KV-Cache Utilities."""
+from collections import deque
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, List, NamedTuple, Optional, Tuple, TypedDict
+from typing import Any, List, NamedTuple, Optional, Tuple
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -28,14 +29,52 @@ class BlockHashType(NamedTuple):
     extra_keys: Optional[Any] = None
 
 
-class PrefixCachingMetrics(TypedDict):
-    """Metrics for prefix caching."""
+class PrefixCachingMetrics:
+    """Metrics for prefix caching with a hit rate of the most recent N requests.
 
-    query_total: int
-    """The total number of queries."""
+    Args:
+        interval: The number of the most recent requests to aggregate.
+            Defaults to 1000.
+    """
+
+    def __init__(self, interval: int = 1000):
+        self.interval = interval
+        self.aggregated_query_total = 0
+        self.aggregated_query_hit = 0
+        self.request_queries: deque[Tuple[int, int]] = deque()
 
-    query_hit: int
-    """The number of queries that hit the prefix cache."""
+    def add_request_query(self, num_queries: int, num_hits: int):
+        """Add a request to the metrics. This function is called when
+        a new request is being scheduled and is looking for computed blocks.
+        When there are more than `interval` requests, the oldest request
+        is removed from the metrics.
+
+        Args:
+            num_queries: The number of queries in the request.
+            num_hits: The number of hits in the request.
+        """
+
+        self.request_queries.append((num_queries, num_hits))
+        if len(self.request_queries) > self.interval:
+            old_num_queries, old_num_hits = self.request_queries.popleft()
+            self.aggregated_query_total -= old_num_queries
+            self.aggregated_query_hit -= old_num_hits
+
+        self.aggregated_query_total += num_queries
+        self.aggregated_query_hit += num_hits
+
+    def reset(self):
+        """Reset the metrics."""
+        self.aggregated_query_total = 0
+        self.aggregated_query_hit = 0
+        self.request_queries.clear()
+
+    @property
+    def hit_rate(self) -> float:
+        """Calculate the hit rate for the past N requests."""
+        if self.aggregated_query_total == 0:
+            return 0.0
+        return self.aggregated_query_hit / self.aggregated_query_total
 
 
 @dataclass
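
A small sketch (not part of this commit) of the sliding-window behavior implemented above: once more than `interval` requests have been recorded, the oldest (num_queries, num_hits) pair is popped from the deque and subtracted from the aggregates, so hit_rate reflects only the most recent N requests. A tiny interval is used here purely to make the eviction visible; the default is 1000.

# Sketch only: a window of 2 requests instead of the default 1000.
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics

metrics = PrefixCachingMetrics(interval=2)

metrics.add_request_query(num_queries=10, num_hits=0)   # request 1: all misses
metrics.add_request_query(num_queries=10, num_hits=10)  # request 2: all hits
assert metrics.hit_rate == 0.5  # 10 hits out of 20 queries in the window

# Request 3 pushes request 1 out of the window, so only requests 2 and 3
# contribute to the aggregated totals.
metrics.add_request_query(num_queries=10, num_hits=10)
assert metrics.hit_rate == 1.0  # 20 hits out of 20 queries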

vllm/v1/core/scheduler.py

Lines changed: 1 addition & 1 deletion
@@ -554,7 +554,7 @@ def make_stats(self) -> SchedulerStats:
             num_waiting_reqs=len(self.waiting),
             gpu_cache_usage=self.kv_cache_manager.usage,
             gpu_prefix_cache_hit_rate=self.kv_cache_manager.
-            get_and_reset_prefix_cache_hit_rate(),
+            prefix_cache_hit_rate,
         )

0 commit comments
