diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f5c0c368d57..ac70df39eb4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -57,7 +57,7 @@ repos: entry: tools/mypy.sh 0 "local" language: python types: [python] - additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests] + additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic] stages: [pre-commit] # Don't run in CI - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.9 diff --git a/pyproject.toml b/pyproject.toml index 0b803a26b65..731fa844449 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,7 @@ ignore = [ ] [tool.mypy] +plugins = ['pydantic.mypy'] ignore_missing_imports = true check_untyped_defs = true follow_imports = "silent" diff --git a/vllm/config.py b/vllm/config.py index d07a1ff0523..9c639398513 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -12,8 +12,8 @@ import warnings from collections import Counter from contextlib import contextmanager -from dataclasses import (MISSING, Field, asdict, dataclass, field, fields, - is_dataclass, replace) +from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass, + replace) from functools import cached_property from importlib.util import find_spec from pathlib import Path @@ -21,9 +21,10 @@ Protocol, TypeVar, Union, cast, get_args, get_origin) import torch +from pydantic import ConfigDict, SkipValidation, TypeAdapter, model_validator from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig -from typing_extensions import deprecated +from typing_extensions import deprecated, runtime_checkable import vllm.envs as envs from vllm import version @@ -43,8 +44,21 @@ from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.utils import is_s3, maybe_model_redirect from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, - get_cpu_memory, get_open_port, is_torch_equal_or_newer, - random_uuid, resolve_obj_by_qualname) + get_cpu_memory, get_open_port, is_in_doc_build, + is_torch_equal_or_newer, random_uuid, + resolve_obj_by_qualname) + +IS_IN_DOC_BUILD = is_in_doc_build() + +if IS_IN_DOC_BUILD: + + def dataclass(*args, **kwargs): + """A non-Pydantic dataclass for docs builds.""" + kwargs.pop("config", None) + from dataclasses import dataclass as _dataclass + return _dataclass(*args, **kwargs) +else: + from pydantic.dataclasses import dataclass if TYPE_CHECKING: from _typeshed import DataclassInstance @@ -57,7 +71,10 @@ ConfigType = type[DataclassInstance] else: + PlacementGroup = Any + ExecutorBase = Any QuantizationConfig = Any + BaseModelLoader = Any ConfigType = type logger = init_logger(__name__) @@ -95,6 +112,7 @@ PretrainedConfig]] +@runtime_checkable class SupportsHash(Protocol): def compute_hash(self) -> str: @@ -226,7 +244,7 @@ def is_init_field(cls: ConfigType, name: str) -> bool: @config -@dataclass +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class ModelConfig: """Configuration for the model.""" @@ -239,7 +257,7 @@ class ModelConfig: task, even if the same model can be used for multiple tasks. 
When the model only supports one task, "auto" can be used to select it; otherwise, you must specify explicitly which task to use.""" - tokenizer: str = None # type: ignore + tokenizer: SkipValidation[str] = None # type: ignore """Name or path of the Hugging Face tokenizer to use. If unspecified, model name or path will be used.""" tokenizer_mode: TokenizerMode = "auto" @@ -287,7 +305,7 @@ class ModelConfig: """The specific revision to use for the tokenizer on the Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.""" - max_model_len: int = None # type: ignore + max_model_len: SkipValidation[int] = None # type: ignore """Model context length (prompt and output). If unspecified, will be automatically derived from the model config. @@ -601,6 +619,13 @@ def __post_init__(self) -> None: self._verify_cuda_graph() self._verify_bnb_config() + @model_validator(mode="after") + def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": + if not isinstance(self.max_model_len, int): + raise ValueError( + "max_model_len must be an integer after __post_init__.") + return self + @property def registry(self): return ModelRegistry @@ -1391,7 +1416,7 @@ def matryoshka_dimensions(self): class CacheConfig: """Configuration for the KV cache.""" - block_size: BlockSize = None # type: ignore + block_size: SkipValidation[BlockSize] = None # type: ignore """Size of a contiguous cache block in number of tokens. This is ignored on neuron devices and set to `--max-model-len`. On CUDA devices, only block sizes up to 32 are supported. On HPU devices, block size defaults to 128. @@ -1695,6 +1720,7 @@ class ParallelConfig: """Port of the data parallel master.""" enable_expert_parallel: bool = False """Use expert parallelism instead of tensor parallelism for MoE layers.""" + max_parallel_loading_workers: Optional[int] = None """Maximum number of parallel loading workers when loading model sequentially in multiple batches. To avoid RAM OOM when using tensor @@ -1923,19 +1949,19 @@ class SchedulerConfig: runner_type: RunnerType = "generate" """The runner type to launch for the model.""" - max_num_batched_tokens: int = None # type: ignore + max_num_batched_tokens: SkipValidation[int] = None # type: ignore """Maximum number of tokens to be processed in a single iteration. This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_num_seqs: int = None # type: ignore + max_num_seqs: SkipValidation[int] = None # type: ignore """Maximum number of sequences to be processed in a single iteration. This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_model_len: int = None # type: ignore + max_model_len: SkipValidation[int] = None # type: ignore """Maximum length of a sequence (including prompt and generated text). 
This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" @@ -1974,7 +2000,7 @@ class SchedulerConfig: """Apply a delay (of delay factor multiplied by previous prompt latency) before scheduling next prompt.""" - enable_chunked_prefill: bool = None # type: ignore + enable_chunked_prefill: SkipValidation[bool] = None # type: ignore """If True, prefill requests can be chunked based on the remaining max_num_batched_tokens.""" @@ -2196,7 +2222,7 @@ def is_multi_step(self) -> bool: @config -@dataclass +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class DeviceConfig: """Configuration for the device to use for vLLM execution.""" @@ -2227,7 +2253,9 @@ def compute_hash(self) -> str: return hash_str def __post_init__(self): - if self.device == "auto": + if IS_IN_DOC_BUILD: + self.device_type = "cpu" + elif self.device == "auto": # Automated device type detection from vllm.platforms import current_platform self.device_type = current_platform.device_type @@ -2262,8 +2290,7 @@ class SpeculativeConfig: """Configuration for speculative decoding.""" # General speculative decoding control - num_speculative_tokens: int = field(default=None, - init=True) # type: ignore + num_speculative_tokens: SkipValidation[int] = None # type: ignore """The number of speculative tokens, if provided. It will default to the number in the draft model config if present, otherwise, it is required.""" model: Optional[str] = None @@ -2339,26 +2366,23 @@ class SpeculativeConfig: """Specifies the tree structure for speculative token generation. """ # required configuration params passed from engine - target_model_config: ModelConfig = field(default=None, - init=True) # type: ignore + target_model_config: SkipValidation[ModelConfig] = None # type: ignore """The configuration of the target model.""" - target_parallel_config: ParallelConfig = field(default=None, - init=True) # type: ignore + target_parallel_config: SkipValidation[ + ParallelConfig] = None # type: ignore """The parallel configuration for the target model.""" - enable_chunked_prefill: bool = field(default=None, - init=True) # type: ignore + enable_chunked_prefill: SkipValidation[bool] = None # type: ignore """Whether vLLM is configured to use chunked prefill or not. 
Used for raising an error since it's not yet compatible with speculative decode.""" - disable_log_stats: bool = field(default=None, init=True) # type: ignore + disable_log_stats: SkipValidation[bool] = None # type: ignore """Whether to disable the periodic printing of stage times in speculative decoding.""" # params generated in the post-init stage - draft_model_config: ModelConfig = field(default=None, - init=True) # type: ignore + draft_model_config: SkipValidation[ModelConfig] = None # type: ignore """The configuration of the draft model initialized internal.""" - draft_parallel_config: ParallelConfig = field(default=None, - init=True) # type: ignore + draft_parallel_config: SkipValidation[ + ParallelConfig] = None # type: ignore """The parallel configuration for the draft model initialized internal.""" def compute_hash(self) -> str: @@ -2748,7 +2772,7 @@ def __repr__(self) -> str: @config -@dataclass +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class LoRAConfig: """Configuration for LoRA.""" @@ -2845,7 +2869,7 @@ def verify_lora_support(self): @config -@dataclass +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class PromptAdapterConfig: """Configuration for PromptAdapters.""" @@ -3860,17 +3884,11 @@ def __repr__(self) -> str: "pass_config", "traced_files", } - include = dict() - for k, v in asdict(self).items(): - if k in exclude: - continue - f = get_field(CompilationConfig, k) - if (d := f.default) is not MISSING and d == v: - continue - if (df := f.default_factory) is not MISSING and df() == v: - continue - include[k] = v - return json.dumps(include) + # The cast to string is necessary because Pydantic is mocked in docs + # builds and sphinx-argparse doesn't know the return type of decode() + return str( + TypeAdapter(CompilationConfig).dump_json( + self, exclude=exclude, exclude_unset=True).decode()) __str__ = __repr__ @@ -3879,7 +3897,7 @@ def from_cli(cls, cli_value: str) -> "CompilationConfig": """Parse the CLI value for the compilation config.""" if cli_value in ["0", "1", "2", "3"]: return cls(level=int(cli_value)) - return cls(**json.loads(cli_value)) + return TypeAdapter(CompilationConfig).validate_json(cli_value) def __post_init__(self) -> None: count_none = self.custom_ops.count("none") @@ -4005,7 +4023,7 @@ def set_splitting_ops_for_v1(self): @config -@dataclass +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class VllmConfig: """Dataclass which contains all vllm-related configuration. This simplifies passing around the distinct configurations in the codebase. 
@@ -4262,8 +4280,6 @@ def __post_init__(self): "To workaround this limitation, vLLM will set 'ieee' input " "precision for chunked prefill triton kernels.") - if self.compilation_config is None: - self.compilation_config = CompilationConfig() if self.compilation_config.pass_config.enable_sequence_parallelism: self.compilation_config.custom_ops.append("+rms_norm") if envs.VLLM_USE_V1 and self.model_config is not None and \ diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index dc2bb3a52ca..85cf27fa273 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -14,6 +14,7 @@ Type, TypeVar, Union, cast, get_args, get_origin) import torch +from pydantic import SkipValidation, TypeAdapter, ValidationError from typing_extensions import TypeIs, deprecated import vllm.envs as envs @@ -42,6 +43,8 @@ # yapf: enable +IS_IN_DOC_BUILD = is_in_doc_build() + logger = init_logger(__name__) # object is used to allow for special typing forms @@ -156,7 +159,11 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: # Get the set of possible types for the field type_hints: set[TypeHint] = set() if get_origin(field.type) in {Union, Annotated}: - type_hints.update(get_args(field.type)) + if IS_IN_DOC_BUILD: + predicate = lambda _: True + else: + predicate = lambda arg: not isinstance(arg, SkipValidation) + type_hints.update(filter(predicate, get_args(field.type))) else: type_hints.add(field.type) @@ -168,7 +175,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: if field.default is not MISSING: default = field.default elif field.default_factory is not MISSING: - if is_dataclass(field.default_factory) and is_in_doc_build(): + if IS_IN_DOC_BUILD and is_dataclass(field.default_factory): default = {} else: default = field.default_factory() @@ -189,12 +196,16 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n""" if dataclass_cls is not None: - dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x)) - # Special case for configs with a from_cli method - if hasattr(dataclass_cls, "from_cli"): - from_cli = dataclass_cls.from_cli - dataclass_init = lambda x, f=from_cli: f(x) - kwargs[name]["type"] = dataclass_init + + def parse_dataclass(val: str, cls=dataclass_cls) -> Any: + try: + if hasattr(cls, "from_cli"): + return cls.from_cli(val) + return TypeAdapter(cls).validate_json(val) + except ValidationError as e: + raise argparse.ArgumentTypeError(repr(e)) from e + + kwargs[name]["type"] = parse_dataclass kwargs[name]["help"] += json_tip elif contains_type(type_hints, bool): # Creates --no- and -- flags @@ -225,12 +236,11 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: kwargs[name]["type"] = human_readable_int elif contains_type(type_hints, float): kwargs[name]["type"] = float - elif contains_type(type_hints, - dict) and (contains_type(type_hints, str) or any( - is_not_builtin(th) for th in type_hints)): + elif (contains_type(type_hints, dict) + and (contains_type(type_hints, str) + or any(is_not_builtin(th) for th in type_hints))): kwargs[name]["type"] = union_dict_and_str elif contains_type(type_hints, dict): - # Dict arguments will always be optional kwargs[name]["type"] = parse_type(json.loads) kwargs[name]["help"] += json_tip elif (contains_type(type_hints, str) @@ -398,7 +408,8 @@ class EngineArgs: get_field(ModelConfig, "override_neuron_config") override_pooler_config: Optional[Union[dict, PoolerConfig]] = \ ModelConfig.override_pooler_config - 
compilation_config: Optional[CompilationConfig] = None + compilation_config: CompilationConfig = \ + get_field(VllmConfig, "compilation_config") worker_cls: str = ParallelConfig.worker_cls worker_extension_cls: str = ParallelConfig.worker_extension_cls @@ -413,7 +424,8 @@ class EngineArgs: calculate_kv_scales: bool = CacheConfig.calculate_kv_scales - additional_config: Optional[Dict[str, Any]] = None + additional_config: dict[str, Any] = \ + get_field(VllmConfig, "additional_config") enable_reasoning: Optional[bool] = None # DEPRECATED reasoning_parser: str = DecodingConfig.reasoning_backend diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 053ee55bb6a..00adb751789 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -215,7 +215,7 @@ def __init__( else: compilation_config_instance = compilation_config else: - compilation_config_instance = None + compilation_config_instance = CompilationConfig() engine_args = EngineArgs( model=model, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index cd6ee367011..52b13d0ed1c 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -175,11 +175,15 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): type: Literal["function"] = "function" +# extra="forbid" is a workaround to have kwargs as a field, +# see https://github.com/pydantic/pydantic/issues/3125 class LogitsProcessorConstructor(BaseModel): qualname: str args: Optional[list[Any]] = None kwargs: Optional[dict[str, Any]] = None + model_config = ConfigDict(extra="forbid") + LogitsProcessors = list[Union[str, LogitsProcessorConstructor]] @@ -234,7 +238,7 @@ class ChatCompletionRequest(OpenAIBaseModel): presence_penalty: Optional[float] = 0.0 response_format: Optional[AnyResponseFormat] = None seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - stop: Optional[Union[str, list[str]]] = Field(default_factory=list) + stop: Optional[Union[str, list[str]]] = [] stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None temperature: Optional[float] = None @@ -258,7 +262,7 @@ class ChatCompletionRequest(OpenAIBaseModel): min_p: Optional[float] = None repetition_penalty: Optional[float] = None length_penalty: float = 1.0 - stop_token_ids: Optional[list[int]] = Field(default_factory=list) + stop_token_ids: Optional[list[int]] = [] include_stop_str_in_output: bool = False ignore_eos: bool = False min_tokens: int = 0 @@ -755,7 +759,7 @@ class CompletionRequest(OpenAIBaseModel): n: int = 1 presence_penalty: Optional[float] = 0.0 seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - stop: Optional[Union[str, list[str]]] = Field(default_factory=list) + stop: Optional[Union[str, list[str]]] = [] stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None suffix: Optional[str] = None @@ -769,7 +773,7 @@ class CompletionRequest(OpenAIBaseModel): min_p: Optional[float] = None repetition_penalty: Optional[float] = None length_penalty: float = 1.0 - stop_token_ids: Optional[list[int]] = Field(default_factory=list) + stop_token_ids: Optional[list[int]] = [] include_stop_str_in_output: bool = False ignore_eos: bool = False min_tokens: int = 0 diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index f9eebde3718..ddcfbb9b138 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -110,10 +110,8 @@ class 
RequestProcessingMixin(BaseModel): Mixin for request processing, handling prompt preparation and engine input. """ - request_prompts: Optional[Sequence[RequestPrompt]] = \ - Field(default_factory=list) - engine_prompts: Optional[list[TokensPrompt]] = \ - Field(default_factory=list) + request_prompts: Optional[Sequence[RequestPrompt]] = [] + engine_prompts: Optional[list[TokensPrompt]] = [] model_config = ConfigDict(arbitrary_types_allowed=True) @@ -497,12 +495,14 @@ def _validate_input( if isinstance(request, (EmbeddingChatRequest, EmbeddingCompletionRequest, ScoreRequest, RerankRequest, ClassificationRequest)): - operation = { - ScoreRequest: "score", - ClassificationRequest: "classification" - }.get(type(request), "embedding generation") if token_num > self.max_model_len: + operations: dict[type[AnyRequest], str] = { + ScoreRequest: "score", + ClassificationRequest: "classification" + } + operation = operations.get(type(request), + "embedding generation") raise ValueError( f"This model's maximum context length is " f"{self.max_model_len} tokens. However, you requested " diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py index 085f37a5d51..316860718b7 100644 --- a/vllm/model_executor/guided_decoding/guided_fields.py +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -3,12 +3,10 @@ from dataclasses import dataclass from typing import Optional, TypedDict, Union -from pydantic import BaseModel - # These classes are deprecated, see SamplingParams class LLMGuidedOptions(TypedDict, total=False): - guided_json: Union[dict, BaseModel, str] + guided_json: Union[dict, str] guided_regex: str guided_choice: list[str] guided_grammar: str @@ -20,7 +18,7 @@ class LLMGuidedOptions(TypedDict, total=False): @dataclass class GuidedDecodingRequest: """One of the fields will be used to retrieve the logit processor.""" - guided_json: Optional[Union[dict, BaseModel, str]] = None + guided_json: Optional[Union[dict, str]] = None guided_regex: Optional[str] = None guided_choice: Optional[list[str]] = None guided_grammar: Optional[str] = None
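
Note for reviewers: below is a minimal, standalone sketch of the Pydantic v2 pattern this diff applies across `vllm/config.py` — a `pydantic.dataclasses.dataclass` with `arbitrary_types_allowed`, a `SkipValidation[...]` field that `__post_init__` fills in, and a `model_validator(mode="after")` check. `ExampleConfig`, its fields, and the 2048 default are made up for illustration; only the API usage mirrors the diff.

```python
from pydantic import ConfigDict, SkipValidation, model_validator
from pydantic.dataclasses import dataclass


@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ExampleConfig:
    """Illustrative stand-in for ModelConfig; values are made up."""

    model: str = "facebook/opt-125m"
    # SkipValidation lets the field start as None even though it is
    # annotated as int; __post_init__ is expected to fill it in.
    max_model_len: SkipValidation[int] = None  # type: ignore

    def __post_init__(self) -> None:
        # Stand-in for deriving the value from the HF model config.
        if self.max_model_len is None:
            self.max_model_len = 2048

    @model_validator(mode="after")
    def _validate_after(self) -> "ExampleConfig":
        # Runs after __post_init__, mirroring validate_model_config_after.
        if not isinstance(self.max_model_len, int):
            raise ValueError(
                "max_model_len must be an integer after __post_init__.")
        return self


cfg = ExampleConfig()
assert cfg.max_model_len == 2048
```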
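
The CLI-facing changes (`parse_dataclass` in `vllm/engine/arg_utils.py`, plus `CompilationConfig.__repr__`/`from_cli`) all route through `pydantic.TypeAdapter`. A rough sketch of that JSON round-trip, using a made-up `DemoConfig` instead of the real `CompilationConfig`:

```python
import argparse
from dataclasses import field

from pydantic import TypeAdapter, ValidationError
from pydantic.dataclasses import dataclass


@dataclass
class DemoConfig:
    """Made-up stand-in for CompilationConfig."""

    level: int = 0
    custom_ops: list[str] = field(default_factory=list)


adapter = TypeAdapter(DemoConfig)

# JSON -> dataclass with validation (replaces json.loads + **kwargs).
cfg = adapter.validate_json('{"level": 3, "custom_ops": ["+rms_norm"]}')

# dataclass -> JSON, emitting only explicitly-set fields (as in __repr__).
print(adapter.dump_json(cfg, exclude_unset=True).decode())

# Invalid input raises ValidationError, which arg_utils wraps in an
# argparse.ArgumentTypeError so argparse prints a readable CLI error.
try:
    adapter.validate_json('{"level": "high"}')
except ValidationError as e:
    err = argparse.ArgumentTypeError(repr(e))
    print(f"argparse would report: {err}")
```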
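
On the `protocol.py`/`serving_engine.py` changes that swap `Field(default_factory=list)` for a plain `[]`: this is safe on Pydantic models because mutable defaults are copied per instance rather than shared, unlike plain dataclasses. A quick check (the `Req` model below is illustrative only):

```python
from pydantic import BaseModel


class Req(BaseModel):
    # On a Pydantic model, a mutable default behaves like
    # Field(default_factory=list): each instance gets its own copy.
    stop: list[str] = []


a, b = Req(), Req()
a.stop.append("</s>")
assert b.stop == []  # b's list is unaffected
```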