Commit 4b07d36

Improve configs - CacheConfig (#16835)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
1 parent 87aaade commit 4b07d36

3 files changed: +123 -157 lines changed

vllm/config.py

Lines changed: 69 additions & 51 deletions
@@ -1245,22 +1245,70 @@ def is_matryoshka(self) -> bool:
                 or getattr(self.hf_config, "is_matryoshka", False))
 
 
+BlockSize = Literal[8, 16, 32, 64, 128]
+CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
+PrefixCachingHashAlgo = Literal["builtin", "sha256"]
+
+
+@config
+@dataclass
 class CacheConfig:
-    """Configuration for the KV cache.
+    """Configuration for the KV cache."""
 
-    Args:
-        block_size: Size of a cache block in number of tokens.
-        gpu_memory_utilization: Fraction of GPU memory to use for the
-            vLLM execution.
-        swap_space: Size of the CPU swap space per GPU (in GiB).
-        cache_dtype: Data type for kv cache storage.
-        is_attention_free: Whether the model is attention-free.
-        num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
-            profiled num_gpu_blocks if specified. Does nothing if None.
-        sliding_window: Sliding window size for the KV cache.
-        enable_prefix_caching: Whether to enable prefix caching.
-        cpu_offload_gb: Size of the CPU offload buffer in GiB.
+    block_size: Optional[BlockSize] = None
+    """Size of a contiguous cache block in number of tokens. This is ignored on
+    neuron devices and set to `--max-model-len`. On CUDA devices, only block
+    sizes up to 32 are supported. On HPU devices, block size defaults to 128.
+    """
+    gpu_memory_utilization: float = 0.9
+    """The fraction of GPU memory to be used for the model executor, which can
+    range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
+    utilization. If unspecified, will use the default value of 0.9. This is a
+    per-instance limit, and only applies to the current vLLM instance. It does
+    not matter if you have another vLLM instance running on the same GPU. For
+    example, if you have two vLLM instances running on the same GPU, you can
+    set the GPU memory utilization to 0.5 for each instance."""
+    swap_space: float = 4
+    """Size of the CPU swap space per GPU (in GiB)."""
+    cache_dtype: CacheDType = "auto"
+    """Data type for kv cache storage. If "auto", will use model data type.
+    CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
+    fp8 (=fp8_e4m3)."""
+    is_attention_free: bool = False
+    """Whether the model is attention-free. This is primarily set in
+    `ModelConfig` and that value should be manually duplicated here."""
+    num_gpu_blocks_override: Optional[int] = None
+    """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
+    if specified. Does nothing if `None`. Used for testing preemption."""
+    sliding_window: Optional[int] = None
+    """Sliding window size for the KV cache. This is primarily set in
+    `ModelConfig` and that value should be manually duplicated here."""
+    enable_prefix_caching: Optional[bool] = None
+    """Whether to enable prefix caching. Disabled by default for V0. Enabled by
+    default for V1."""
+    prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
+    """Set the hash algorithm for prefix caching:\n
+    - "builtin" is Python's built-in hash.\n
+    - "sha256" is collision resistant but with certain overheads."""
+    cpu_offload_gb: float = 0
+    """The space in GiB to offload to CPU, per GPU. Default is 0, which means
+    no offloading. Intuitively, this argument can be seen as a virtual way to
+    increase the GPU memory size. For example, if you have one 24 GB GPU and
+    set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
+    load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
+    Note that this requires fast CPU-GPU interconnect, as part of the model is
+    loaded from CPU memory to GPU memory on the fly in each model forward pass.
     """
+    calculate_kv_scales: bool = False
+    """This enables dynamic calculation of `k_scale` and `v_scale` when
+    kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
+    checkpoint if available. Otherwise, the scales will default to 1.0."""
+
+    # Will be set after profiling.
+    num_gpu_blocks: Optional[int] = field(default=None, init=False)
+    """The number of blocks to allocate for GPU memory."""
+    num_cpu_blocks: Optional[int] = field(default=None, init=False)
+    """The number of blocks to allocate for CPU memory."""
 
     def compute_hash(self) -> str:
         """
@@ -1281,43 +1329,13 @@ def compute_hash(self) -> str:
                                usedforsecurity=False).hexdigest()
         return hash_str
 
-    def __init__(
-        self,
-        block_size: int,
-        gpu_memory_utilization: float,
-        swap_space: float,
-        cache_dtype: str,
-        is_attention_free: bool = False,
-        num_gpu_blocks_override: Optional[int] = None,
-        sliding_window: Optional[int] = None,
-        enable_prefix_caching: bool = False,
-        prefix_caching_hash_algo: str = "builtin",
-        cpu_offload_gb: float = 0,
-        calculate_kv_scales: Optional[bool] = None,
-    ) -> None:
-        self.block_size = block_size
-        self.gpu_memory_utilization = gpu_memory_utilization
-        self.swap_space_bytes = swap_space * GiB_bytes
-        self.num_gpu_blocks_override = num_gpu_blocks_override
-        self.cache_dtype = cache_dtype
-        self.is_attention_free = is_attention_free
-        self.sliding_window = sliding_window
-        self.enable_prefix_caching = enable_prefix_caching
-        self.prefix_caching_hash_algo = prefix_caching_hash_algo
-        self.cpu_offload_gb = cpu_offload_gb
-        self.calculate_kv_scales = calculate_kv_scales
+    def __post_init__(self) -> None:
+        self.swap_space_bytes = self.swap_space * GiB_bytes
+
         self._verify_args()
         self._verify_cache_dtype()
         self._verify_prefix_caching()
 
-        # Will be set after profiling.
-        self.num_gpu_blocks: Optional[int] = None
-        self.num_cpu_blocks: Optional[int] = None
-
-        # Set calculate_kv_scales to False if the value is unset.
-        if self.calculate_kv_scales is None:
-            self.calculate_kv_scales = False
-
     def metrics_info(self):
         # convert cache_config to dict(key: str, value: str) for prometheus
         # metrics info
@@ -1336,7 +1354,7 @@ def _verify_args(self) -> None:
     def _verify_cache_dtype(self) -> None:
         if self.cache_dtype == "auto":
             pass
-        elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
+        elif self.cache_dtype in get_args(CacheDType):
             logger.info(
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "
@@ -1354,12 +1372,12 @@ def _verify_prefix_caching(self) -> None:
                 "Prefix caching is not supported with sliding window. "
                 "Run with --disable-sliding-window to use prefix caching.")
 
-        if self.enable_prefix_caching and self.prefix_caching_hash_algo not in (
-                "builtin", "sha256"):
+        if (self.enable_prefix_caching and self.prefix_caching_hash_algo
+                not in get_args(PrefixCachingHashAlgo)):
             raise ValueError(
                 "Unknown prefix caching hash algorithm: "
-                f"{self.prefix_caching_hash_algo}. Must be either "
-                "'builtin' or 'sha256'.")
+                f"{self.prefix_caching_hash_algo}. Must be one of "
+                f"{get_args(PrefixCachingHashAlgo)}.")
 
     def verify_with_parallel_config(
         self,
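
As a rough illustration of the pattern this commit introduces (a `@dataclass` config whose option types are `Literal` aliases, with derived values and validation handled in `__post_init__` via `get_args(...)`), here is a minimal standalone sketch. `KVCacheSettings` and the inlined `GiB_bytes` constant are illustrative stand-ins rather than vLLM's actual `CacheConfig`; the real class additionally uses vLLM's `@config` decorator, logger, and more fields.

from dataclasses import dataclass, field
from typing import Literal, Optional, get_args

# Inlined for self-containment; vLLM has its own GiB_bytes constant.
GiB_bytes = 1 << 30

CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
PrefixCachingHashAlgo = Literal["builtin", "sha256"]


@dataclass
class KVCacheSettings:
    """Toy stand-in mirroring the dataclass pattern used by CacheConfig."""

    swap_space: float = 4
    """Size of the CPU swap space per GPU (in GiB)."""
    cache_dtype: CacheDType = "auto"
    prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"

    # Late-bound value, excluded from the generated __init__, like
    # num_gpu_blocks/num_cpu_blocks in the diff above.
    num_gpu_blocks: Optional[int] = field(default=None, init=False)

    def __post_init__(self) -> None:
        # Derived quantity computed from a user-facing field.
        self.swap_space_bytes = self.swap_space * GiB_bytes
        # Validate against the Literal alias so the runtime check and the
        # error message never drift from the declared type.
        if self.cache_dtype not in get_args(CacheDType):
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}. "
                             f"Must be one of {get_args(CacheDType)}.")
        if self.prefix_caching_hash_algo not in get_args(PrefixCachingHashAlgo):
            raise ValueError(
                "Unknown prefix caching hash algorithm: "
                f"{self.prefix_caching_hash_algo}. Must be one of "
                f"{get_args(PrefixCachingHashAlgo)}.")


if __name__ == "__main__":
    cfg = KVCacheSettings(cache_dtype="fp8")
    print(cfg.swap_space_bytes)  # 4294967296 (4 GiB)
    print(cfg.num_gpu_blocks)    # None until profiling fills it in

Checking membership in `get_args(CacheDType)` instead of a hard-coded tuple keeps the runtime check, the type checker's view, and the error message in sync, which is the same motivation behind the `_verify_cache_dtype` and `_verify_prefix_caching` changes in this commit.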
