Skip to content

[v1][KVCacheManager] Change prefix caching metric from counting blocks to counting tokens #18003

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,22 +149,22 @@ def get_computed_blocks(self,

computed_blocks = (
self.single_type_manager.find_longest_cache_hit(block_hashes))
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size

if self.log_stats:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.queries += len(block_hashes)
self.prefix_cache_stats.hits += len(computed_blocks)
self.prefix_cache_stats.queries += request.num_tokens
self.prefix_cache_stats.hits += num_computed_tokens

if last_block_hash is not None:
# Add back the last block hash if it was removed.
# NOTE: Because block_hashes is cached in req_to_block_hashes,
# we shouldn't modify it directly.
block_hashes.append(last_block_hash)

# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size
return KVCacheBlocks(computed_blocks), num_computed_tokens

def allocate_slots(
Expand Down
4 changes: 2 additions & 2 deletions vllm/v1/metrics/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,13 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self.counter_gpu_prefix_cache_queries = prometheus_client.Counter(
name="vllm:gpu_prefix_cache_queries",
documentation=
"GPU prefix cache queries, in terms of number of queried blocks.",
"GPU prefix cache queries, in terms of number of queried tokens.",
labelnames=labelnames).labels(*labelvalues)

self.counter_gpu_prefix_cache_hits = prometheus_client.Counter(
name="vllm:gpu_prefix_cache_hits",
documentation=
"GPU prefix cache hits, in terms of number of cached blocks.",
"GPU prefix cache hits, in terms of number of cached tokens.",
labelnames=labelnames).labels(*labelvalues)

#
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/metrics/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class PrefixCacheStats:
# The number of requests in this update.
requests: int = 0
# The number of queries in these requests. Note that "queries" here
# means the number of blocks that were queried from the cache.
# means the number of tokens that were queried from the cache.
queries: int = 0
# The number of hits in these requests.
hits: int = 0
Expand Down