From 1f80a0f6c0675c6dbefc3765fd5e73d0c6009009 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 12 May 2025 11:16:27 +0200 Subject: [PATCH 1/2] Remove checks for `None` for fields which should never be `None` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4a503665503..3385591bfac 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4176,25 +4176,22 @@ def __post_init__(self): self.device_config) self.model_config.verify_with_parallel_config(self.parallel_config) - if self.cache_config is not None: - self.cache_config.verify_with_parallel_config(self.parallel_config) + self.cache_config.verify_with_parallel_config(self.parallel_config) - if self.lora_config: + if self.lora_config is not None: self.lora_config.verify_with_cache_config(self.cache_config) self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_lora_support() - if self.prompt_adapter_config: + if self.prompt_adapter_config is not None: self.prompt_adapter_config.verify_with_model_config( self.model_config) - if self.quant_config is None and \ - self.model_config is not None and self.load_config is not None: + if self.quant_config is None and self.model_config is not None: self.quant_config = VllmConfig._get_quantization_config( self.model_config, self.load_config) from vllm.platforms import current_platform - if self.scheduler_config is not None and \ - self.model_config is not None and \ + if self.model_config is not None and \ self.scheduler_config.chunked_prefill_enabled and \ self.model_config.dtype == torch.float32 and \ current_platform.get_device_capability() == (7, 5): @@ -4203,8 +4200,6 @@ def __post_init__(self): "To workaround this limitation, vLLM will set 'ieee' input " "precision for chunked prefill triton kernels.") - if self.compilation_config is None: - self.compilation_config = CompilationConfig() if self.compilation_config.pass_config.enable_sequence_parallelism: self.compilation_config.custom_ops.append("+rms_norm") if envs.VLLM_USE_V1 and self.model_config is not None and \ @@ -4224,11 +4219,8 @@ def __post_init__(self): self.compilation_config.level = CompilationLevel.PIECEWISE self.compilation_config.set_splitting_ops_for_v1() - if self.parallel_config is not None and \ - self.parallel_config.tensor_parallel_size > 1 and \ + if self.parallel_config.tensor_parallel_size > 1 and \ self.parallel_config.pipeline_parallel_size > 1 and \ - self.compilation_config is not None and \ - self.compilation_config.pass_config is not None and \ self.compilation_config.pass_config.enable_sequence_parallelism: logger.warning_once( "Sequence parallelism is not supported with pipeline " @@ -4238,8 +4230,7 @@ def __post_init__(self): self._set_cudagraph_sizes() - if self.cache_config is not None and \ - self.cache_config.cpu_offload_gb > 0 and \ + if self.cache_config.cpu_offload_gb > 0 and \ self.compilation_config.level != CompilationLevel.NO_COMPILATION \ and not envs.VLLM_USE_V1: logger.warning( @@ -4262,7 +4253,7 @@ def __post_init__(self): "cascade attention. Disabling cascade attention.") self.model_config.disable_cascade_attn = True - if self.model_config and self.model_config.use_mla and \ + if self.model_config is not None and self.model_config.use_mla and \ not (current_platform.is_cuda() or current_platform.is_rocm()): logger.info( "MLA is enabled on a non-GPU platform; forcing chunked " @@ -4273,16 +4264,16 @@ def __post_init__(self): self.scheduler_config.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS) - if self.cache_config is not None: - self.cache_config.enable_prefix_caching = False + self.cache_config.enable_prefix_caching = False - if (self.kv_events_config + if (self.kv_events_config is not None and self.kv_events_config.enable_kv_cache_events and not self.cache_config.enable_prefix_caching): logger.warning( "KV cache events are on, but prefix caching is not enabled." "Use --enable-prefix-caching to enable.") - if (self.kv_events_config and self.kv_events_config.publisher != "null" + if (self.kv_events_config is not None + and self.kv_events_config.publisher != "null" and not self.kv_events_config.enable_kv_cache_events): logger.warning("KV cache events are disabled," "but the scheduler is configured to publish them." From a8794daed88b70893d922339c82525b845a28611 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 12 May 2025 15:20:30 +0200 Subject: [PATCH 2/2] Fix tests Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/engine/arg_utils.py | 25 +++++++++++++------------ vllm/entrypoints/llm.py | 2 +- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 388e04323ac..00d4ac7a068 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -8,7 +8,7 @@ import sys import threading import warnings -from dataclasses import MISSING, dataclass, fields, is_dataclass +from dataclasses import MISSING, dataclass, field, fields, is_dataclass from itertools import permutations from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, Union, cast, get_args, get_origin) @@ -152,29 +152,29 @@ def is_not_builtin(type_hint: TypeHint) -> bool: def get_kwargs(cls: ConfigType) -> dict[str, Any]: cls_docs = get_attr_docs(cls) kwargs = {} - for field in fields(cls): + for fld in fields(cls): # Get the set of possible types for the field type_hints: set[TypeHint] = set() - if get_origin(field.type) in {Union, Annotated}: - type_hints.update(get_args(field.type)) + if get_origin(fld.type) in {Union, Annotated}: + type_hints.update(get_args(fld.type)) else: - type_hints.add(field.type) + type_hints.add(fld.type) # If the field is a dataclass, we can use the model_validate_json generator = (th for th in type_hints if is_dataclass(th)) dataclass_cls = next(generator, None) # Get the default value of the field - if field.default is not MISSING: - default = field.default - elif field.default_factory is not MISSING: - if is_dataclass(field.default_factory) and is_in_doc_build(): + if fld.default is not MISSING: + default = fld.default + elif fld.default_factory is not MISSING: + if is_dataclass(fld.default_factory) and is_in_doc_build(): default = {} else: - default = field.default_factory() + default = fld.default_factory() # Get the help text for the field - name = field.name + name = fld.name help = cls_docs[name].strip() # Escape % for argparse help = help.replace("%", "%%") @@ -391,7 +391,8 @@ class EngineArgs: get_field(ModelConfig, "override_neuron_config") override_pooler_config: Optional[Union[dict, PoolerConfig]] = \ ModelConfig.override_pooler_config - compilation_config: Optional[CompilationConfig] = None + compilation_config: CompilationConfig = \ + field(default_factory=CompilationConfig) worker_cls: str = ParallelConfig.worker_cls worker_extension_cls: str = ParallelConfig.worker_extension_cls diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index cebddcc8e6a..8aa5da58a0f 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -215,7 +215,7 @@ def __init__( else: compilation_config_instance = compilation_config else: - compilation_config_instance = None + compilation_config_instance = CompilationConfig() engine_args = EngineArgs( model=model,