From 10d543b8ab427869e26e42502250797f7cc0c4b2 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 2 May 2025 20:46:10 +0200 Subject: [PATCH 01/23] Enable Pydantic mypy plugin Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- pyproject.toml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 90ed492d992..3bccc3bb5b0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,7 +55,7 @@ repos: entry: tools/mypy.sh 0 "local" language: python types: [python] - additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests] + additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic] stages: [pre-commit] # Don't run in CI - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.9 diff --git a/pyproject.toml b/pyproject.toml index 069e295bfb9..1e64450e0e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,6 +118,7 @@ ignore = [ ] [tool.mypy] +plugins = ['pydantic.mypy'] ignore_missing_imports = true check_untyped_defs = true follow_imports = "silent" From 29fc238143b4eb1132a6c2ec5d78bd61761f62ca Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 2 May 2025 20:48:17 +0200 Subject: [PATCH 02/23] Fix Pydantic errors in `protocol.py` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/entrypoints/openai/protocol.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 389557dfb7c..4c66754029f 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -190,7 +190,9 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): type: Literal["function"] = "function" -class LogitsProcessorConstructor(BaseModel): +# extra="forbid" is a workaround, see +# https://github.com/pydantic/pydantic/issues/3125 +class LogitsProcessorConstructor(BaseModel, extra="forbid"): qualname: str args: Optional[list[Any]] = None kwargs: Optional[dict[str, Any]] = None @@ -249,7 +251,7 @@ class ChatCompletionRequest(OpenAIBaseModel): presence_penalty: Optional[float] = 0.0 response_format: Optional[AnyResponseFormat] = None seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - stop: Optional[Union[str, list[str]]] = Field(default_factory=list) + stop: Optional[Union[str, list[str]]] = Field(default_factory=lambda: []) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None temperature: Optional[float] = None @@ -273,7 +275,7 @@ class ChatCompletionRequest(OpenAIBaseModel): min_p: Optional[float] = None repetition_penalty: Optional[float] = None length_penalty: float = 1.0 - stop_token_ids: Optional[list[int]] = Field(default_factory=list) + stop_token_ids: Optional[list[int]] = Field(default_factory=lambda: []) include_stop_str_in_output: bool = False ignore_eos: bool = False min_tokens: int = 0 @@ -765,7 +767,7 @@ class CompletionRequest(OpenAIBaseModel): n: int = 1 presence_penalty: Optional[float] = 0.0 seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - stop: Optional[Union[str, list[str]]] = Field(default_factory=list) + stop: Optional[Union[str, list[str]]] = 
Field(default_factory=lambda: []) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None suffix: Optional[str] = None @@ -779,7 +781,7 @@ class CompletionRequest(OpenAIBaseModel): min_p: Optional[float] = None repetition_penalty: Optional[float] = None length_penalty: float = 1.0 - stop_token_ids: Optional[list[int]] = Field(default_factory=list) + stop_token_ids: Optional[list[int]] = Field(default_factory=lambda: []) include_stop_str_in_output: bool = False ignore_eos: bool = False min_tokens: int = 0 From 0ecc5d2c5cf3f62922cd3a81cdbc1d7731dc49f0 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 2 May 2025 20:50:19 +0200 Subject: [PATCH 03/23] Fix other mypy errors Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 4 ++-- vllm/entrypoints/openai/serving_engine.py | 2 +- vllm/model_executor/guided_decoding/guided_fields.py | 6 ++---- vllm/platforms/tpu.py | 2 +- vllm/transformers_utils/tokenizer_group.py | 1 + 5 files changed, 7 insertions(+), 8 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 9738d2fd0e0..106f237461b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -203,7 +203,7 @@ def get_field(cls: ConfigType, name: str) -> Field: cls_fields = {f.name: f for f in fields(cls)} if name not in cls_fields: raise ValueError(f"Field '{name}' not found in {cls.__name__}.") - named_field: Field = cls_fields.get(name) + named_field: Field = cls_fields[name] if (default_factory := named_field.default_factory) is not MISSING: return field(default_factory=default_factory) if (default := named_field.default) is not MISSING: @@ -230,7 +230,7 @@ class ModelConfig: task, even if the same model can be used for multiple tasks. When the model only supports one task, "auto" can be used to select it; otherwise, you must specify explicitly which task to use.""" - tokenizer: str = None # type: ignore + tokenizer: Optional[str] = None """Name or path of the Hugging Face tokenizer to use. 
If unspecified, model name or path will be used.""" tokenizer_mode: TokenizerMode = "auto" diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 6123811aabe..029d7bfbaa6 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -246,7 +246,7 @@ def _validate_input( if isinstance(request, ChatCompletionRequest): # TODO(#9845): remove max_tokens when field dropped from OpenAI API max_tokens = request.max_completion_tokens or request.max_tokens - else: + elif isinstance(request, CompletionRequest): max_tokens = request.max_tokens if max_tokens is None: if token_num >= self.max_model_len: diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py index 1593868a164..14475993871 100644 --- a/vllm/model_executor/guided_decoding/guided_fields.py +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -3,12 +3,10 @@ from dataclasses import dataclass from typing import Dict, List, Optional, TypedDict, Union -from pydantic import BaseModel - # These classes are deprecated, see SamplingParams class LLMGuidedOptions(TypedDict, total=False): - guided_json: Union[Dict, BaseModel, str] + guided_json: Union[Dict, str] guided_regex: str guided_choice: List[str] guided_grammar: str @@ -20,7 +18,7 @@ class LLMGuidedOptions(TypedDict, total=False): @dataclass class GuidedDecodingRequest: """One of the fields will be used to retrieve the logit processor.""" - guided_json: Optional[Union[Dict, BaseModel, str]] = None + guided_json: Optional[Union[Dict, str]] = None guided_regex: Optional[str] = None guided_choice: Optional[List[str]] = None guided_grammar: Optional[str] = None diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index d5923557a21..cee82c116c5 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -107,7 +107,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: vllm_config.cache_config.block_size, min_page_size, ) - vllm_config.cache_config.block_size = min_page_size + vllm_config.cache_config.block_size = min_page_size # type: ignore[assignment] parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config diff --git a/vllm/transformers_utils/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group.py index aff2d2eb1c3..a28efb181f8 100644 --- a/vllm/transformers_utils/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group.py @@ -107,6 +107,7 @@ async def get_lora_tokenizer_async( def init_tokenizer_from_configs(model_config: ModelConfig, scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig]): + assert isinstance(model_config.tokenizer, str) return TokenizerGroup( tokenizer_id=model_config.tokenizer, enable_lora=bool(lora_config), From ad9535a65e85b12c06d619971d60fb508de9e483 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 2 May 2025 21:25:44 +0200 Subject: [PATCH 04/23] Convert dataclasses to pydantic dataclasses Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 101 ++++++++++++++++++++------------------- vllm/engine/arg_utils.py | 35 ++++++++------ 2 files changed, 72 insertions(+), 64 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 106f237461b..b36c7bd95b2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -12,8 +12,7 @@ import warnings from collections import Counter from contextlib import contextmanager -from dataclasses 
import (MISSING, dataclass, field, fields, is_dataclass, - replace) +from dataclasses import MISSING, Field, field, fields, is_dataclass, replace from functools import cached_property from importlib.util import find_spec from pathlib import Path @@ -21,7 +20,9 @@ Protocol, TypeVar, Union, cast, get_args, get_origin) import torch -from pydantic import BaseModel, Field, PrivateAttr +from pydantic import (BaseModel, ConfigDict, SkipValidation, TypeAdapter, + model_validator) +from pydantic.dataclasses import dataclass from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig from typing_extensions import deprecated @@ -58,7 +59,10 @@ ConfigType = type[DataclassInstance] else: - QuantizationConfig = None + PlacementGroup = Any + ExecutorBase = Any + QuantizationConfig = Any + BaseModelLoader = Any ConfigType = type logger = init_logger(__name__) @@ -217,7 +221,7 @@ def get_field(cls: ConfigType, name: str) -> Field: @config -@dataclass +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class ModelConfig: """Configuration for the model.""" @@ -277,7 +281,7 @@ class ModelConfig: """The specific revision to use for the tokenizer on the Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.""" - max_model_len: int = None # type: ignore + max_model_len: SkipValidation[int] = None # type: ignore """Model context length (prompt and output). If unspecified, will be automatically derived from the model config. @@ -569,6 +573,13 @@ def __post_init__(self) -> None: self._verify_cuda_graph() self._verify_bnb_config() + @model_validator(mode="after") + def validate(self: "ModelConfig") -> "ModelConfig": + if not isinstance(self.max_model_len, int): + raise ValueError( + "max_model_len must be an integer after __post_init__.") + return self + @property def registry(self): return ModelRegistry @@ -1329,7 +1340,7 @@ def matryoshka_dimensions(self): class CacheConfig: """Configuration for the KV cache.""" - block_size: BlockSize = None # type: ignore + block_size: SkipValidation[BlockSize] = None # type: ignore """Size of a contiguous cache block in number of tokens. This is ignored on neuron devices and set to `--max-model-len`. On CUDA devices, only block sizes up to 32 are supported. On HPU devices, block size defaults to 128. @@ -2130,7 +2141,7 @@ def is_multi_step(self) -> bool: @config -@dataclass +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class DeviceConfig: """Configuration for the device to use for vLLM execution.""" @@ -2665,7 +2676,7 @@ def __repr__(self) -> str: @config -@dataclass +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class LoRAConfig: """Configuration for LoRA.""" @@ -2761,7 +2772,7 @@ def verify_lora_support(self): @config -@dataclass +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class PromptAdapterConfig: """Configuration for PromptAdapters.""" @@ -2818,8 +2829,8 @@ def verify_with_model_config(self, model_config: ModelConfig): class MultiModalConfig: """Controls the behavior of multimodal models.""" - limit_per_prompt: dict[str, int] = get_field(ModelConfig, - "limit_mm_per_prompt") + limit_per_prompt: dict[str, int] = \ + cast(dict[str, int], get_field(ModelConfig, "limit_mm_per_prompt")) """ The maximum number of input items allowed per prompt for each modality. Defaults to 1 (V0) or 999 (V1) for each modality. 
@@ -3533,7 +3544,8 @@ class CompilationLevel: PIECEWISE = 3 -class CompilationConfig(BaseModel): +@dataclass +class CompilationConfig: """ Configuration for compilation. It has three parts: @@ -3618,13 +3630,13 @@ class CompilationConfig(BaseModel): debug_dump_path: str = "" cache_dir: str = "" backend: str = "" - custom_ops: list[str] = Field(default_factory=list) - splitting_ops: list[str] = Field(default=None) # type: ignore + custom_ops: list[str] = field(default_factory=list) + splitting_ops: list[str] = field(default_factory=list) use_inductor: bool = True - compile_sizes: Optional[list[Union[int, str]]] = Field(default=None) - inductor_compile_config: dict = Field(default_factory=dict) - inductor_passes: dict[str, str] = Field(default_factory=dict) + compile_sizes: Optional[list[Union[int, str]]] = None + inductor_compile_config: dict = field(default_factory=dict) + inductor_passes: dict[str, str] = field(default_factory=dict) use_cudagraph: bool = False cudagraph_num_of_warmups: int = 0 @@ -3645,8 +3657,8 @@ class PassConfig(BaseModel): TODO(luka) better pass enabling system. - enable_sequence_parallelism: whether to enable sequence parallelism. """ - dump_graph_stages: list[str] = Field(default_factory=list) - dump_graph_dir: Path = Field(default=Path(".")) + dump_graph_stages: list[str] = field(default_factory=list) + dump_graph_dir: Path = Path(".") enable_fusion: bool = True enable_noop: bool = True enable_sequence_parallelism: bool = False @@ -3668,27 +3680,28 @@ def model_post_init(self, __context: Any) -> None: "Fusion enabled but reshape elimination disabled. " "RMSNorm + quant (fp8) fusion might not work") - pass_config: PassConfig = Field(default_factory=PassConfig) + pass_config: PassConfig = field(default_factory=PassConfig) # not configurable, computed after init - max_capture_size: int = PrivateAttr - local_cache_dir: str = PrivateAttr # local cache dir for each rank + max_capture_size: int = None # type: ignore + # local cache dir for each rank + local_cache_dir: str = None # type: ignore # optimization: # Intuitively, bs_to_padded_graph_size should be dict[int, int]. # since we know all keys are in a range [0, max_capture_size], # we can optimize it to list[int] for better lookup performance. - bs_to_padded_graph_size: list[int] = PrivateAttr + bs_to_padded_graph_size: list[int] = None # type: ignore # keep track of enabled and disabled custom ops - enabled_custom_ops: Counter[str] = PrivateAttr - disabled_custom_ops: Counter[str] = PrivateAttr - traced_files: set[str] = PrivateAttr - compilation_time: float = PrivateAttr + enabled_custom_ops: Counter[str] = field(default_factory=Counter) + disabled_custom_ops: Counter[str] = field(default_factory=Counter) + traced_files: set[str] = field(default_factory=set) + compilation_time: float = 0.0 # Per-model forward context # Map from layer name to layer objects that need to be accessed outside # model code, e.g., Attention, FusedMOE when dp_size>1. 
- static_forward_context: dict[str, Any] = PrivateAttr + static_forward_context: dict[str, Any] = field(default_factory=dict) def compute_hash(self) -> str: """ @@ -3723,7 +3736,8 @@ def __repr__(self) -> str: "pass_config", "traced_files", } - return self.model_dump_json(exclude=exclude, exclude_unset=True) + return TypeAdapter(CompilationConfig).dump_json( + self, exclude=exclude, exclude_unset=True).decode() __str__ = __repr__ @@ -3732,11 +3746,9 @@ def from_cli(cls, cli_value: str) -> "CompilationConfig": """Parse the CLI value for the compilation config.""" if cli_value in ["0", "1", "2", "3"]: return cls(level=int(cli_value)) - # do not use `eval`, it is dangerous and can execute arbitrary code - dict_value = ast.literal_eval(cli_value) - return CompilationConfig.model_validate(dict_value) + return TypeAdapter(CompilationConfig).validate_json(cli_value) - def model_post_init(self, __context: Any) -> None: + def __post_init__(self) -> None: count_none = self.custom_ops.count("none") count_all = self.custom_ops.count("all") @@ -3755,9 +3767,6 @@ def model_post_init(self, __context: Any) -> None: if KEY not in self.inductor_compile_config: self.inductor_compile_config[KEY] = False - if self.splitting_ops is None: - self.splitting_ops = [] - for k, v in self.inductor_passes.items(): if not isinstance(v, str): assert callable(v), ( @@ -3774,12 +3783,6 @@ def model_post_init(self, __context: Any) -> None: self.inductor_compile_config[k] = func if isinstance( func, InductorPass) else CallableInductorPass(func) - self.enabled_custom_ops = Counter() - self.disabled_custom_ops = Counter() - self.traced_files = set() - self.static_forward_context = {} - self.compilation_time = 0.0 - def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: if self.level == CompilationLevel.NO_COMPILATION: raise ValueError("No compilation level is set.") @@ -3877,22 +3880,19 @@ class VllmConfig: init=True) # type: ignore load_config: LoadConfig = field(default=None, init=True) # type: ignore lora_config: Optional[LoRAConfig] = None - speculative_config: SpeculativeConfig = field(default=None, - init=True) # type: ignore + speculative_config: Optional[SpeculativeConfig] = None decoding_config: Optional[DecodingConfig] = None observability_config: Optional[ObservabilityConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None quant_config: Optional[QuantizationConfig] = None compilation_config: CompilationConfig = field(default=None, init=True) # type: ignore - kv_transfer_config: KVTransferConfig = field(default=None, - init=True) # type: ignore + kv_transfer_config: Optional[KVTransferConfig] = None kv_events_config: Optional[KVEventsConfig] = None # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. 
- additional_config: SupportsHash = field(default=None, - init=True) # type: ignore + additional_config: dict[str, Any] = field(default_factory=dict) instance_id: str = "" def compute_hash(self) -> str: @@ -3974,7 +3974,8 @@ def compute_hash(self) -> str: else: vllm_factors.append("None") if self.additional_config: - vllm_factors.append(self.additional_config.compute_hash()) + vllm_factors.append(hash(frozenset( + self.additional_config.items()))) else: vllm_factors.append("None") factors.append(vllm_factors) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0ba14c4dee0..2b85df7bd45 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -47,12 +47,9 @@ TypeHintT = Union[type[T], object] -def optional_type( - return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: +def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: - def _optional_type(val: str) -> Optional[T]: - if val == "" or val == "None": - return None + def _parse_type(val: str) -> T: try: if return_type is json.loads and not re.match("^{.*}$", val): return cast(T, nullable_kvs(val)) @@ -61,14 +58,24 @@ def _optional_type(val: str) -> Optional[T]: raise argparse.ArgumentTypeError( f"Value {val} cannot be converted to {return_type}.") from e + return _parse_type + + +def optional_type( + return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: + + def _optional_type(val: str) -> Optional[T]: + if val == "" or val == "None": + return None + return parse_type(return_type)(val) + return _optional_type def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]: if not re.match("^{.*}$", val): return str(val) - else: - return optional_type(json.loads)(val) + return optional_type(json.loads)(val) @deprecated( @@ -195,13 +202,12 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: kwargs[name]["type"] = human_readable_int elif contains_type(type_hints, float): kwargs[name]["type"] = float - elif contains_type(type_hints, - dict) and (contains_type(type_hints, str) or any( - is_not_builtin(th) for th in type_hints)): + elif (contains_type(type_hints, dict) + and (contains_type(type_hints, str) + or any(is_not_builtin(th) for th in type_hints))): kwargs[name]["type"] = union_dict_and_str elif contains_type(type_hints, dict): - # Dict arguments will always be optional - kwargs[name]["type"] = optional_type(json.loads) + kwargs[name]["type"] = parse_type(json.loads) kwargs[name]["help"] += json_tip elif (contains_type(type_hints, str) or any(is_not_builtin(th) for th in type_hints)): @@ -379,7 +385,8 @@ class EngineArgs: calculate_kv_scales: bool = CacheConfig.calculate_kv_scales - additional_config: Optional[Dict[str, Any]] = None + additional_config: dict[str, Any] = \ + get_field(VllmConfig, "additional_config") enable_reasoning: Optional[bool] = None # DEPRECATED reasoning_parser: str = DecodingConfig.reasoning_backend @@ -808,7 +815,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: vllm_group.add_argument( "--additional-config", type=json.loads, - default=None, + default=dict(), help="Additional config for specified platform in JSON format. " "Different platforms may support different configs. Make sure the " "configs are valid for the platform you are using. 
The input format" From f6b1be7a0ee45b4c8077680a0f283919b78eb8f7 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 12 May 2025 11:43:46 +0200 Subject: [PATCH 05/23] Fix missing imports Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/config.py b/vllm/config.py index 3519ffbd27e..c0c86c29fc2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -20,6 +20,7 @@ Protocol, TypeVar, Union, cast, get_args, get_origin) import torch +from pydantic import ConfigDict, SkipValidation, TypeAdapter, model_validator from pydantic.dataclasses import dataclass from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig From b1f0fdcf3d0b13b78cf1dff3c5320fdc522d13dc Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 12 May 2025 12:08:04 +0200 Subject: [PATCH 06/23] Make mypy pass Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 11 ++++++++--- vllm/entrypoints/openai/serving_engine.py | 19 ++++++++++--------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index c0c86c29fc2..a1170197cec 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -243,7 +243,7 @@ class ModelConfig: task, even if the same model can be used for multiple tasks. When the model only supports one task, "auto" can be used to select it; otherwise, you must specify explicitly which task to use.""" - tokenizer: Optional[str] = None + tokenizer: str = "" """Name or path of the Hugging Face tokenizer to use. If unspecified, model name or path will be used.""" tokenizer_mode: TokenizerMode = "auto" @@ -447,8 +447,6 @@ def compute_hash(self) -> str: def __post_init__(self) -> None: self.model = maybe_model_redirect(self.model) # The tokenizer is consistent with the model by default. - if self.tokenizer is None: - self.tokenizer = self.model if self.tokenizer_revision is None: self.tokenizer_revision = self.revision self.tokenizer = maybe_model_redirect(self.tokenizer) @@ -586,6 +584,13 @@ def __post_init__(self) -> None: self._verify_cuda_graph() self._verify_bnb_config() + @model_validator(mode="before") + @classmethod + def validate_model_config_before(cls, data): + if not data.get("tokenizer"): + data["tokenizer"] = data["model"] + return data + @model_validator(mode="after") def validate(self: "ModelConfig") -> "ModelConfig": if not isinstance(self.max_model_len, int): diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index f1d907f519c..f8e7f29bc4e 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -110,10 +110,9 @@ class RequestProcessingMixin(BaseModel): Mixin for request processing, handling prompt preparation and engine input. 
""" - request_prompts: Optional[Sequence[RequestPrompt]] = \ - Field(default_factory=list) - engine_prompts: Optional[list[TokensPrompt]] = \ - Field(default_factory=list) + # Pydantic base models handle mutable defaults correctly + request_prompts: Optional[Sequence[RequestPrompt]] = [] + engine_prompts: Optional[list[TokensPrompt]] = [] model_config = ConfigDict(arbitrary_types_allowed=True) @@ -497,12 +496,14 @@ def _validate_input( if isinstance(request, (EmbeddingChatRequest, EmbeddingCompletionRequest, ScoreRequest, RerankRequest, ClassificationRequest)): - operation = { - ScoreRequest: "score", - ClassificationRequest: "classification" - }.get(type(request), "embedding generation") - if token_num > self.max_model_len: + if (token_num > self.max_model_len): + operations: dict[type[AnyRequest], str] = { + ScoreRequest: "score", + ClassificationRequest: "classification" + } + operation = operations.get(type(request), + "embedding generation") raise ValueError( f"This model's maximum context length is " f"{self.max_model_len} tokens. However, you requested " From d2ff0da95d38534dc19824b04fdf84915294317f Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 12 May 2025 12:11:04 +0200 Subject: [PATCH 07/23] Assert no longer needed Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/transformers_utils/tokenizer_group.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/transformers_utils/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group.py index a28efb181f8..aff2d2eb1c3 100644 --- a/vllm/transformers_utils/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group.py @@ -107,7 +107,6 @@ async def get_lora_tokenizer_async( def init_tokenizer_from_configs(model_config: ModelConfig, scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig]): - assert isinstance(model_config.tokenizer, str) return TokenizerGroup( tokenizer_id=model_config.tokenizer, enable_lora=bool(lora_config), From 2ed29c828c829f0882512490efd6b68f7ded1731 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 12 May 2025 12:11:45 +0200 Subject: [PATCH 08/23] Pydantic base models correctnyl handle mutable defaults Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/entrypoints/openai/protocol.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0f58c769619..7e98983c1fd 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -174,8 +174,8 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): type: Literal["function"] = "function" -# extra="forbid" is a workaround, see -# https://github.com/pydantic/pydantic/issues/3125 +# extra="forbid" is a workaround to have kwargs as a field, +# see https://github.com/pydantic/pydantic/issues/3125 class LogitsProcessorConstructor(BaseModel, extra="forbid"): qualname: str args: Optional[list[Any]] = None @@ -235,7 +235,7 @@ class ChatCompletionRequest(OpenAIBaseModel): presence_penalty: Optional[float] = 0.0 response_format: Optional[AnyResponseFormat] = None seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - stop: Optional[Union[str, list[str]]] = Field(default_factory=lambda: []) + stop: Optional[Union[str, list[str]]] = [] stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None temperature: Optional[float] = None @@ 
-259,7 +259,7 @@ class ChatCompletionRequest(OpenAIBaseModel): min_p: Optional[float] = None repetition_penalty: Optional[float] = None length_penalty: float = 1.0 - stop_token_ids: Optional[list[int]] = Field(default_factory=lambda: []) + stop_token_ids: Optional[list[int]] = [] include_stop_str_in_output: bool = False ignore_eos: bool = False min_tokens: int = 0 @@ -751,7 +751,7 @@ class CompletionRequest(OpenAIBaseModel): n: int = 1 presence_penalty: Optional[float] = 0.0 seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - stop: Optional[Union[str, list[str]]] = Field(default_factory=lambda: []) + stop: Optional[Union[str, list[str]]] = [] stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None suffix: Optional[str] = None @@ -765,7 +765,7 @@ class CompletionRequest(OpenAIBaseModel): min_p: Optional[float] = None repetition_penalty: Optional[float] = None length_penalty: float = 1.0 - stop_token_ids: Optional[list[int]] = Field(default_factory=lambda: []) + stop_token_ids: Optional[list[int]] = [] include_stop_str_in_output: bool = False ignore_eos: bool = False min_tokens: int = 0 From bb47b84867c7029b4aa7090ad3ee52afb5341dc7 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 12 May 2025 12:14:37 +0200 Subject: [PATCH 09/23] remove parenthesis Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/entrypoints/openai/serving_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index f8e7f29bc4e..ebe83f1e92e 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -497,7 +497,7 @@ def _validate_input( (EmbeddingChatRequest, EmbeddingCompletionRequest, ScoreRequest, RerankRequest, ClassificationRequest)): - if (token_num > self.max_model_len): + if token_num > self.max_model_len: operations: dict[type[AnyRequest], str] = { ScoreRequest: "score", ClassificationRequest: "classification" From 30fcc161ee63c2c0ab22b7c7738c109c1500bf51 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 12 May 2025 12:15:54 +0200 Subject: [PATCH 10/23] remove comment Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/entrypoints/openai/serving_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index ebe83f1e92e..3409e423ca5 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -110,7 +110,6 @@ class RequestProcessingMixin(BaseModel): Mixin for request processing, handling prompt preparation and engine input. 
""" - # Pydantic base models handle mutable defaults correctly request_prompts: Optional[Sequence[RequestPrompt]] = [] engine_prompts: Optional[list[TokensPrompt]] = [] From 73dba34439e1939a375458238312da0123189f00 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 12 May 2025 12:27:04 +0200 Subject: [PATCH 11/23] Remove model validator only used for tokenizer Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index a1170197cec..11c548705c3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -243,7 +243,7 @@ class ModelConfig: task, even if the same model can be used for multiple tasks. When the model only supports one task, "auto" can be used to select it; otherwise, you must specify explicitly which task to use.""" - tokenizer: str = "" + tokenizer: SkipValidation[str] = None # type: ignore """Name or path of the Hugging Face tokenizer to use. If unspecified, model name or path will be used.""" tokenizer_mode: TokenizerMode = "auto" @@ -447,6 +447,8 @@ def compute_hash(self) -> str: def __post_init__(self) -> None: self.model = maybe_model_redirect(self.model) # The tokenizer is consistent with the model by default. + if self.tokenizer is None: + self.tokenizer = self.model if self.tokenizer_revision is None: self.tokenizer_revision = self.revision self.tokenizer = maybe_model_redirect(self.tokenizer) @@ -584,15 +586,8 @@ def __post_init__(self) -> None: self._verify_cuda_graph() self._verify_bnb_config() - @model_validator(mode="before") - @classmethod - def validate_model_config_before(cls, data): - if not data.get("tokenizer"): - data["tokenizer"] = data["model"] - return data - @model_validator(mode="after") - def validate(self: "ModelConfig") -> "ModelConfig": + def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": if not isinstance(self.max_model_len, int): raise ValueError( "max_model_len must be an integer after __post_init__.") From eca37a2634b2898ea3dfbe2d16bc6a3e00967f2d Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 12 May 2025 12:38:53 +0200 Subject: [PATCH 12/23] Fix Pydantic runtime error Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 11c548705c3..11954cbb609 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -24,7 +24,7 @@ from pydantic.dataclasses import dataclass from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig -from typing_extensions import deprecated +from typing_extensions import deprecated, runtime_checkable import vllm.envs as envs from vllm import version @@ -99,6 +99,7 @@ PretrainedConfig]] +@runtime_checkable class SupportsHash(Protocol): def compute_hash(self) -> str: @@ -3951,7 +3952,7 @@ def set_splitting_ops_for_v1(self): @config -@dataclass +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class VllmConfig: """Dataclass which contains all vllm-related configuration. This simplifies passing around the distinct configurations in the codebase. 
From a72251fa675a2b5cab346dc1a5ef45a817ea2ab9 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 12 May 2025 13:21:45 +0200 Subject: [PATCH 13/23] Add Pydantic validation to dataclass instantiation from CLI Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/engine/arg_utils.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bfd54accf60..51a566bc063 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -8,12 +8,14 @@ import sys import threading import warnings -from dataclasses import MISSING, dataclass, fields, is_dataclass +from dataclasses import MISSING, dataclass, fields from itertools import permutations from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, Union, cast, get_args, get_origin) import torch +from pydantic import TypeAdapter, ValidationError +from pydantic.dataclasses import is_pydantic_dataclass from typing_extensions import TypeIs, deprecated import vllm.envs as envs @@ -161,14 +163,15 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: type_hints.add(field.type) # If the field is a dataclass, we can use the model_validate_json - generator = (th for th in type_hints if is_dataclass(th)) + generator = (th for th in type_hints if is_pydantic_dataclass(th)) dataclass_cls = next(generator, None) # Get the default value of the field if field.default is not MISSING: default = field.default elif field.default_factory is not MISSING: - if is_dataclass(field.default_factory) and is_in_doc_build(): + if (is_pydantic_dataclass(field.default_factory) + and is_in_doc_build()): default = {} else: default = field.default_factory() @@ -185,12 +188,16 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: # Set other kwargs based on the type hints json_tip = "\n\nShould be a valid JSON string." if dataclass_cls is not None: - dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x)) - # Special case for configs with a from_cli method - if hasattr(dataclass_cls, "from_cli"): - from_cli = dataclass_cls.from_cli - dataclass_init = lambda x, f=from_cli: f(x) - kwargs[name]["type"] = dataclass_init + + def parse_dataclass(val: str, cls=dataclass_cls) -> Any: + try: + if hasattr(cls, "from_cli"): + return cls.from_cli(val) + return TypeAdapter(cls).validate_json(val) + except ValidationError as e: + raise argparse.ArgumentTypeError(repr(e)) from e + + kwargs[name]["type"] = parse_dataclass kwargs[name]["help"] += json_tip elif contains_type(type_hints, bool): # Creates --no- and -- flags From 712b312f798e979a5d5d1df38b750f61cddca966 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 12 May 2025 16:35:18 +0200 Subject: [PATCH 14/23] Skip validation for defaults which are deferred Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 11954cbb609..9e197e678fa 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1890,19 +1890,19 @@ class SchedulerConfig: runner_type: RunnerType = "generate" """The runner type to launch for the model.""" - max_num_batched_tokens: int = None # type: ignore + max_num_batched_tokens: SkipValidation[int] = None # type: ignore """Maximum number of tokens to be processed in a single iteration. 
This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_num_seqs: int = None # type: ignore + max_num_seqs: SkipValidation[int] = None # type: ignore """Maximum number of sequences to be processed in a single iteration. This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_model_len: int = None # type: ignore + max_model_len: SkipValidation[int] = None # type: ignore """Maximum length of a sequence (including prompt and generated text). This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" @@ -1941,7 +1941,7 @@ class SchedulerConfig: """Apply a delay (of delay factor multiplied by previous prompt latency) before scheduling next prompt.""" - enable_chunked_prefill: bool = None # type: ignore + enable_chunked_prefill: SkipValidation[bool] = None # type: ignore """If True, prefill requests can be chunked based on the remaining max_num_batched_tokens.""" @@ -2235,8 +2235,7 @@ class SpeculativeConfig: """Configuration for speculative decoding.""" # General speculative decoding control - num_speculative_tokens: int = field(default=None, - init=True) # type: ignore + num_speculative_tokens: SkipValidation[int] = None # type: ignore """The number of speculative tokens, if provided. It will default to the number in the draft model config if present, otherwise, it is required.""" model: Optional[str] = None @@ -2312,26 +2311,23 @@ class SpeculativeConfig: """Specifies the tree structure for speculative token generation. """ # required configuration params passed from engine - target_model_config: ModelConfig = field(default=None, - init=True) # type: ignore + target_model_config: SkipValidation[ModelConfig] = None # type: ignore """The configuration of the target model.""" - target_parallel_config: ParallelConfig = field(default=None, - init=True) # type: ignore + target_parallel_config: SkipValidation[ + ParallelConfig] = None # type: ignore """The parallel configuration for the target model.""" - enable_chunked_prefill: bool = field(default=None, - init=True) # type: ignore + enable_chunked_prefill: SkipValidation[bool] = None # type: ignore """Whether vLLM is configured to use chunked prefill or not. 
Used for raising an error since it's not yet compatible with speculative decode.""" - disable_log_stats: bool = field(default=None, init=True) # type: ignore + disable_log_stats: SkipValidation[bool] = None # type: ignore """Whether to disable the periodic printing of stage times in speculative decoding.""" # params generated in the post-init stage - draft_model_config: ModelConfig = field(default=None, - init=True) # type: ignore + draft_model_config: SkipValidation[ModelConfig] = None # type: ignore """The configuration of the draft model initialized internal.""" - draft_parallel_config: ParallelConfig = field(default=None, - init=True) # type: ignore + draft_parallel_config: SkipValidation[ + ParallelConfig] = None # type: ignore """The parallel configuration for the draft model initialized internal.""" def compute_hash(self) -> str: From 4397de9749a436f3d9e39c416f467b3f95f48a40 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 12 May 2025 20:38:24 +0100 Subject: [PATCH 15/23] Fix docs build Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 16 +++++++++++++--- vllm/entrypoints/openai/protocol.py | 4 +++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 9e197e678fa..bf61cce9df2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -21,7 +21,6 @@ import torch from pydantic import ConfigDict, SkipValidation, TypeAdapter, model_validator -from pydantic.dataclasses import dataclass from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig from typing_extensions import deprecated, runtime_checkable @@ -49,6 +48,7 @@ if TYPE_CHECKING: from _typeshed import DataclassInstance + from pydantic.dataclasses import dataclass from ray.util.placement_group import PlacementGroup from vllm.executor.executor_base import ExecutorBase @@ -58,6 +58,13 @@ ConfigType = type[DataclassInstance] else: + + def dataclass(*args, **kwargs): + """A non-Pydantic dataclass for docs builds.""" + kwargs.pop("config", None) + from dataclasses import dataclass as _dataclass + return _dataclass(*args, **kwargs) + PlacementGroup = Any ExecutorBase = Any QuantizationConfig = Any @@ -3813,8 +3820,11 @@ def __repr__(self) -> str: "pass_config", "traced_files", } - return TypeAdapter(CompilationConfig).dump_json( - self, exclude=exclude, exclude_unset=True).decode() + # The cast to string is necessary because Pydantic is mocked in docs + # builds and sphinx-argparse doesn't know the return type of decode() + return str( + TypeAdapter(CompilationConfig).dump_json( + self, exclude=exclude, exclude_unset=True).decode()) __str__ = __repr__ diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 7e98983c1fd..9497308040a 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -176,11 +176,13 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): # extra="forbid" is a workaround to have kwargs as a field, # see https://github.com/pydantic/pydantic/issues/3125 -class LogitsProcessorConstructor(BaseModel, extra="forbid"): +class LogitsProcessorConstructor(BaseModel): qualname: str args: Optional[list[Any]] = None kwargs: Optional[dict[str, Any]] = None + model_config = ConfigDict(extra="forbid") + LogitsProcessors = list[Union[str, LogitsProcessorConstructor]] From caa1dc560bf72010f8292dd41b95386858eb96d2 Mon Sep 17 00:00:00 2001 From: Harry Mellor 
<19981378+hmellor@users.noreply.github.com> Date: Thu, 15 May 2025 19:09:33 +0100 Subject: [PATCH 16/23] Fix docs build 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index bf61cce9df2..7bb76f37f21 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -43,8 +43,11 @@ from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.utils import is_s3, maybe_model_redirect from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, - get_cpu_memory, get_open_port, is_torch_equal_or_newer, - random_uuid, resolve_obj_by_qualname) + get_cpu_memory, get_open_port, is_in_doc_build, + is_torch_equal_or_newer, random_uuid, + resolve_obj_by_qualname) + +IS_IN_DOC_BUILD = is_in_doc_build() if TYPE_CHECKING: from _typeshed import DataclassInstance @@ -2207,7 +2210,9 @@ def compute_hash(self) -> str: return hash_str def __post_init__(self): - if self.device == "auto": + if IS_IN_DOC_BUILD: + self.device_type = "cpu" + elif self.device == "auto": # Automated device type detection from vllm.platforms import current_platform self.device_type = current_platform.device_type From 18ea0acb2f47768abca4896aec0bf8aebb9be368 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 16 May 2025 09:31:34 +0200 Subject: [PATCH 17/23] `VllmConfig.compilation_config` should never be `None` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 2 -- vllm/engine/arg_utils.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 44a19977141..29274f7300b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4277,8 +4277,6 @@ def __post_init__(self): "To workaround this limitation, vLLM will set 'ieee' input " "precision for chunked prefill triton kernels.") - if self.compilation_config is None: - self.compilation_config = CompilationConfig() if self.compilation_config.pass_config.enable_sequence_parallelism: self.compilation_config.custom_ops.append("+rms_norm") if envs.VLLM_USE_V1 and self.model_config is not None and \ diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a3e4562a124..f9f2a899a30 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -404,7 +404,8 @@ class EngineArgs: get_field(ModelConfig, "override_neuron_config") override_pooler_config: Optional[Union[dict, PoolerConfig]] = \ ModelConfig.override_pooler_config - compilation_config: Optional[CompilationConfig] = None + compilation_config: CompilationConfig = \ + get_field(VllmConfig, "compilation_config") worker_cls: str = ParallelConfig.worker_cls worker_extension_cls: str = ParallelConfig.worker_extension_cls From 905ab3e45f6ceaddf7866365a5562bccf7d2b40f Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 16 May 2025 09:40:08 +0200 Subject: [PATCH 18/23] Type adapter works for non-Pydantic dataclasses Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/engine/arg_utils.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f9f2a899a30..f54cf413bdd 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -8,14 +8,13 @@ import sys import threading import warnings -from dataclasses import MISSING, dataclass, fields +from 
dataclasses import MISSING, dataclass, fields, is_dataclass from itertools import permutations from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, Union, cast, get_args, get_origin) import torch -from pydantic import TypeAdapter, ValidationError -from pydantic.dataclasses import is_pydantic_dataclass +from pydantic import SkipValidation, TypeAdapter, ValidationError from typing_extensions import TypeIs, deprecated import vllm.envs as envs @@ -44,6 +43,8 @@ # yapf: enable +IS_IN_DOC_BUILD = is_in_doc_build() + logger = init_logger(__name__) # object is used to allow for special typing forms @@ -158,20 +159,20 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: # Get the set of possible types for the field type_hints: set[TypeHint] = set() if get_origin(field.type) in {Union, Annotated}: - type_hints.update(get_args(field.type)) + predicate = lambda arg: not isinstance(arg, SkipValidation) + type_hints.update(filter(predicate, get_args(field.type))) else: type_hints.add(field.type) # If the field is a dataclass, we can use the model_validate_json - generator = (th for th in type_hints if is_pydantic_dataclass(th)) + generator = (th for th in type_hints if is_dataclass(th)) dataclass_cls = next(generator, None) # Get the default value of the field if field.default is not MISSING: default = field.default elif field.default_factory is not MISSING: - if (is_pydantic_dataclass(field.default_factory) - and is_in_doc_build()): + if IS_IN_DOC_BUILD and is_dataclass(field.default_factory): default = {} else: default = field.default_factory() From 80d03a6b8de2a4ee75bcb3373678fbfa87429499 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 16 May 2025 09:42:10 +0200 Subject: [PATCH 19/23] Using stdlib dataclass when not type checking breaks pydantic validation Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 29274f7300b..969df6fb1a2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -22,6 +22,7 @@ import torch from pydantic import ConfigDict, SkipValidation, TypeAdapter, model_validator +from pydantic.dataclasses import dataclass from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig from typing_extensions import deprecated, runtime_checkable @@ -52,7 +53,6 @@ if TYPE_CHECKING: from _typeshed import DataclassInstance - from pydantic.dataclasses import dataclass from ray.util.placement_group import PlacementGroup from vllm.executor.executor_base import ExecutorBase @@ -62,13 +62,6 @@ ConfigType = type[DataclassInstance] else: - - def dataclass(*args, **kwargs): - """A non-Pydantic dataclass for docs builds.""" - kwargs.pop("config", None) - from dataclasses import dataclass as _dataclass - return _dataclass(*args, **kwargs) - PlacementGroup = Any ExecutorBase = Any QuantizationConfig = Any From 92e2b75e74201bb026b6a26c5ca505172cc0886e Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 16 May 2025 09:43:34 +0200 Subject: [PATCH 20/23] Fix `compilation_config_instance` being `None` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/entrypoints/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 053ee55bb6a..00adb751789 100644 --- a/vllm/entrypoints/llm.py +++ 
b/vllm/entrypoints/llm.py @@ -215,7 +215,7 @@ def __init__( else: compilation_config_instance = compilation_config else: - compilation_config_instance = None + compilation_config_instance = CompilationConfig() engine_args = EngineArgs( model=model, From e995cc0cf35f87358c84b5eda023a86bd997a164 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 16 May 2025 10:16:44 +0200 Subject: [PATCH 21/23] Use stdlib dataclasses when not in docs build Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 969df6fb1a2..1fe55f10d27 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -22,7 +22,6 @@ import torch from pydantic import ConfigDict, SkipValidation, TypeAdapter, model_validator -from pydantic.dataclasses import dataclass from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig from typing_extensions import deprecated, runtime_checkable @@ -51,6 +50,16 @@ IS_IN_DOC_BUILD = is_in_doc_build() +if IS_IN_DOC_BUILD: + + def dataclass(*args, **kwargs): + """A non-Pydantic dataclass for docs builds.""" + kwargs.pop("config", None) + from dataclasses import dataclass as _dataclass + return _dataclass(*args, **kwargs) +else: + from pydantic.dataclasses import dataclass + if TYPE_CHECKING: from _typeshed import DataclassInstance from ray.util.placement_group import PlacementGroup From fdea28abfc5b156c391df21177982548de4998c3 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 16 May 2025 10:22:12 +0200 Subject: [PATCH 22/23] Undo whitespace change Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/config.py b/vllm/config.py index 1fe55f10d27..9c639398513 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1720,6 +1720,7 @@ class ParallelConfig: """Port of the data parallel master.""" enable_expert_parallel: bool = False """Use expert parallelism instead of tensor parallelism for MoE layers.""" + max_parallel_loading_workers: Optional[int] = None """Maximum number of parallel loading workers when loading model sequentially in multiple batches. To avoid RAM OOM when using tensor From ead89d741c41e99f85c03a9fd3c4df0450c4cb76 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 16 May 2025 12:14:05 +0200 Subject: [PATCH 23/23] Make docs build pass Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/engine/arg_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f54cf413bdd..85cf27fa273 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -159,7 +159,10 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: # Get the set of possible types for the field type_hints: set[TypeHint] = set() if get_origin(field.type) in {Union, Annotated}: - predicate = lambda arg: not isinstance(arg, SkipValidation) + if IS_IN_DOC_BUILD: + predicate = lambda _: True + else: + predicate = lambda arg: not isinstance(arg, SkipValidation) type_hints.update(filter(predicate, get_args(field.type))) else: type_hints.add(field.type)
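The CLI-parsing approach introduced in patch 13 and refined in patches 18 and 23 can be illustrated with a short standalone sketch (not vLLM code; `CompileOptions` and the flag name are illustrative stand-ins for `CompilationConfig` and `--compilation-config`, and Pydantic v2 is assumed): `TypeAdapter` validates a JSON string directly into a dataclass, and Pydantic validation failures surface as `argparse` errors.

```python
# Standalone sketch of the TypeAdapter-based CLI parsing pattern.
# CompileOptions is an illustrative stand-in for vLLM's CompilationConfig.
import argparse

from pydantic import TypeAdapter, ValidationError
from pydantic.dataclasses import dataclass


@dataclass
class CompileOptions:
    level: int = 0
    backend: str = ""


def parse_compile_options(val: str) -> CompileOptions:
    """Parse a JSON string into CompileOptions, reporting argparse errors."""
    try:
        return TypeAdapter(CompileOptions).validate_json(val)
    except ValidationError as e:
        raise argparse.ArgumentTypeError(repr(e)) from e


parser = argparse.ArgumentParser()
parser.add_argument("--compilation-config",
                    type=parse_compile_options,
                    default=CompileOptions())
args = parser.parse_args(["--compilation-config", '{"level": 3}'])
assert args.compilation_config.level == 3
```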