Skip to content

Enable Pydantic mypy checks and convert configs to Pydantic dataclasses #17599

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
10d543b
Enable Pydantic mypy plugin
hmellor May 2, 2025
29fc238
Fix Pydantic errors in `protocol.py`
hmellor May 2, 2025
0ecc5d2
Fix other mypy errors
hmellor May 2, 2025
ad9535a
Convert dataclasses to pydantic dataclasses
hmellor May 2, 2025
2ec9ba6
Merge branch 'main' into enable-pydantic-mypy
hmellor May 4, 2025
0426733
Merge branch 'main' into enable-pydantic-mypy
hmellor May 12, 2025
f6b1be7
Fix missing imports
hmellor May 12, 2025
b1f0fdc
Make mypy pass
hmellor May 12, 2025
d2ff0da
Assert no longer needed
hmellor May 12, 2025
2ed29c8
Pydantic base models correctly handle mutable defaults
hmellor May 12, 2025
bb47b84
remove parenthesis
hmellor May 12, 2025
30fcc16
remove comment
hmellor May 12, 2025
73dba34
Remove model validator only used for tokenizer
hmellor May 12, 2025
eca37a2
Fix Pydantic runtime error
hmellor May 12, 2025
a72251f
Add Pydantic validation to dataclass instantiation from CLI
hmellor May 12, 2025
712b312
Skip validation for defaults which are deferred
hmellor May 12, 2025
4397de9
Fix docs build
hmellor May 12, 2025
caa1dc5
Fix docs build 2
hmellor May 15, 2025
b267203
Merge branch 'main' into enable-pydantic-mypy
hmellor May 15, 2025
96876a6
Merge branch 'main' into enable-pydantic-mypy
hmellor May 16, 2025
18ea0ac
`VllmConfig.compilation_config` should never be `None`
hmellor May 16, 2025
905ab3e
Type adapter works for non-Pydantic dataclasses
hmellor May 16, 2025
80d03a6
Using stdlib dataclass when not type checking breaks pydantic validation
hmellor May 16, 2025
92e2b75
Fix `compilation_config_instance` being `None`
hmellor May 16, 2025
e995cc0
Use stdlib dataclasses when not in docs build
hmellor May 16, 2025
fdea28a
Undo whitespace change
hmellor May 16, 2025
ead89d7
Make docs build pass
hmellor May 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ repos:
entry: tools/mypy.sh 0 "local"
language: python
types: [python]
additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests]
additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic]
stages: [pre-commit] # Don't run in CI
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.9
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ ignore = [
]

[tool.mypy]
plugins = ['pydantic.mypy']
ignore_missing_imports = true
check_untyped_defs = true
follow_imports = "silent"
Expand Down
101 changes: 58 additions & 43 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,19 @@
import warnings
from collections import Counter
from contextlib import contextmanager
from dataclasses import (MISSING, Field, asdict, dataclass, field, fields,
is_dataclass, replace)
from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
replace)
from functools import cached_property
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
Protocol, TypeVar, Union, cast, get_args, get_origin)

import torch
from pydantic import ConfigDict, SkipValidation, TypeAdapter, model_validator
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
from typing_extensions import deprecated
from typing_extensions import deprecated, runtime_checkable

import vllm.envs as envs
from vllm import version
Expand All @@ -43,11 +44,15 @@
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, get_open_port, is_torch_equal_or_newer,
random_uuid, resolve_obj_by_qualname)
get_cpu_memory, get_open_port, is_in_doc_build,
is_torch_equal_or_newer, random_uuid,
resolve_obj_by_qualname)

IS_IN_DOC_BUILD = is_in_doc_build()

if TYPE_CHECKING:
from _typeshed import DataclassInstance
from pydantic.dataclasses import dataclass
from ray.util.placement_group import PlacementGroup

from vllm.executor.executor_base import ExecutorBase
Expand All @@ -57,7 +62,17 @@

ConfigType = type[DataclassInstance]
else:

def dataclass(*args, **kwargs):
"""A non-Pydantic dataclass for docs builds."""
kwargs.pop("config", None)
from dataclasses import dataclass as _dataclass
return _dataclass(*args, **kwargs)

PlacementGroup = Any
ExecutorBase = Any
QuantizationConfig = Any
BaseModelLoader = Any
ConfigType = type

logger = init_logger(__name__)
Expand Down Expand Up @@ -95,6 +110,7 @@
PretrainedConfig]]


@runtime_checkable
class SupportsHash(Protocol):

def compute_hash(self) -> str:
Expand Down Expand Up @@ -226,7 +242,7 @@ def is_init_field(cls: ConfigType, name: str) -> bool:


@config
@dataclass
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ModelConfig:
"""Configuration for the model."""

Expand All @@ -239,7 +255,7 @@ class ModelConfig:
task, even if the same model can be used for multiple tasks. When the model
only supports one task, "auto" can be used to select it; otherwise, you
must specify explicitly which task to use."""
tokenizer: str = None # type: ignore
tokenizer: SkipValidation[str] = None # type: ignore
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used."""
tokenizer_mode: TokenizerMode = "auto"
Expand Down Expand Up @@ -287,7 +303,7 @@ class ModelConfig:
"""The specific revision to use for the tokenizer on the Hugging Face Hub.
It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version."""
max_model_len: int = None # type: ignore
max_model_len: SkipValidation[int] = None # type: ignore
"""Model context length (prompt and output). If unspecified, will be
automatically derived from the model config.

Expand Down Expand Up @@ -601,6 +617,13 @@ def __post_init__(self) -> None:
self._verify_cuda_graph()
self._verify_bnb_config()

@model_validator(mode="after")
def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
if not isinstance(self.max_model_len, int):
raise ValueError(
"max_model_len must be an integer after __post_init__.")
return self

@property
def registry(self):
return ModelRegistry
Expand Down Expand Up @@ -1391,7 +1414,7 @@ def matryoshka_dimensions(self):
class CacheConfig:
"""Configuration for the KV cache."""

block_size: BlockSize = None # type: ignore
block_size: SkipValidation[BlockSize] = None # type: ignore
"""Size of a contiguous cache block in number of tokens. This is ignored on
neuron devices and set to `--max-model-len`. On CUDA devices, only block
sizes up to 32 are supported. On HPU devices, block size defaults to 128.
Expand Down Expand Up @@ -1924,19 +1947,19 @@ class SchedulerConfig:
runner_type: RunnerType = "generate"
"""The runner type to launch for the model."""

max_num_batched_tokens: int = None # type: ignore
max_num_batched_tokens: SkipValidation[int] = None # type: ignore
"""Maximum number of tokens to be processed in a single iteration.

This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""

max_num_seqs: int = None # type: ignore
max_num_seqs: SkipValidation[int] = None # type: ignore
"""Maximum number of sequences to be processed in a single iteration.

This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""

max_model_len: int = None # type: ignore
max_model_len: SkipValidation[int] = None # type: ignore
"""Maximum length of a sequence (including prompt and generated text). This
is primarily set in `ModelConfig` and that value should be manually
duplicated here."""
Expand Down Expand Up @@ -1975,7 +1998,7 @@ class SchedulerConfig:
"""Apply a delay (of delay factor multiplied by previous
prompt latency) before scheduling next prompt."""

enable_chunked_prefill: bool = None # type: ignore
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""

Expand Down Expand Up @@ -2197,7 +2220,7 @@ def is_multi_step(self) -> bool:


@config
@dataclass
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class DeviceConfig:
"""Configuration for the device to use for vLLM execution."""

Expand Down Expand Up @@ -2228,7 +2251,9 @@ def compute_hash(self) -> str:
return hash_str

def __post_init__(self):
if self.device == "auto":
if IS_IN_DOC_BUILD:
self.device_type = "cpu"
elif self.device == "auto":
# Automated device type detection
from vllm.platforms import current_platform
self.device_type = current_platform.device_type
Expand Down Expand Up @@ -2263,8 +2288,7 @@ class SpeculativeConfig:
"""Configuration for speculative decoding."""

# General speculative decoding control
num_speculative_tokens: int = field(default=None,
init=True) # type: ignore
num_speculative_tokens: SkipValidation[int] = None # type: ignore
"""The number of speculative tokens, if provided. It will default to the
number in the draft model config if present, otherwise, it is required."""
model: Optional[str] = None
Expand Down Expand Up @@ -2340,26 +2364,23 @@ class SpeculativeConfig:
"""Specifies the tree structure for speculative token generation.
"""
# required configuration params passed from engine
target_model_config: ModelConfig = field(default=None,
init=True) # type: ignore
target_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the target model."""
target_parallel_config: ParallelConfig = field(default=None,
init=True) # type: ignore
target_parallel_config: SkipValidation[
ParallelConfig] = None # type: ignore
"""The parallel configuration for the target model."""
enable_chunked_prefill: bool = field(default=None,
init=True) # type: ignore
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
"""Whether vLLM is configured to use chunked prefill or not. Used for
raising an error since it's not yet compatible with speculative decode."""
disable_log_stats: bool = field(default=None, init=True) # type: ignore
disable_log_stats: SkipValidation[bool] = None # type: ignore
"""Whether to disable the periodic printing of stage times in speculative
decoding."""

# params generated in the post-init stage
draft_model_config: ModelConfig = field(default=None,
init=True) # type: ignore
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the draft model initialized internal."""
draft_parallel_config: ParallelConfig = field(default=None,
init=True) # type: ignore
draft_parallel_config: SkipValidation[
ParallelConfig] = None # type: ignore
"""The parallel configuration for the draft model initialized internal."""

def compute_hash(self) -> str:
Expand Down Expand Up @@ -2749,7 +2770,7 @@ def __repr__(self) -> str:


@config
@dataclass
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class LoRAConfig:
"""Configuration for LoRA."""

Expand Down Expand Up @@ -2846,7 +2867,7 @@ def verify_lora_support(self):


@config
@dataclass
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class PromptAdapterConfig:
"""Configuration for PromptAdapters."""

Expand Down Expand Up @@ -3861,17 +3882,11 @@ def __repr__(self) -> str:
"pass_config",
"traced_files",
}
include = dict()
for k, v in asdict(self).items():
if k in exclude:
continue
f = get_field(CompilationConfig, k)
if (d := f.default) is not MISSING and d == v:
continue
if (df := f.default_factory) is not MISSING and df() == v:
continue
include[k] = v
return json.dumps(include)
# The cast to string is necessary because Pydantic is mocked in docs
# builds and sphinx-argparse doesn't know the return type of decode()
return str(
TypeAdapter(CompilationConfig).dump_json(
self, exclude=exclude, exclude_unset=True).decode())

__str__ = __repr__

Expand All @@ -3880,7 +3895,7 @@ def from_cli(cls, cli_value: str) -> "CompilationConfig":
"""Parse the CLI value for the compilation config."""
if cli_value in ["0", "1", "2", "3"]:
return cls(level=int(cli_value))
return cls(**json.loads(cli_value))
return TypeAdapter(CompilationConfig).validate_json(cli_value)

def __post_init__(self) -> None:
count_none = self.custom_ops.count("none")
Expand Down Expand Up @@ -4006,7 +4021,7 @@ def set_splitting_ops_for_v1(self):


@config
@dataclass
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
"""Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase.
Expand Down
35 changes: 21 additions & 14 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import sys
import threading
import warnings
from dataclasses import MISSING, dataclass, fields, is_dataclass
from dataclasses import MISSING, dataclass, fields
from itertools import permutations
from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
Type, TypeVar, Union, cast, get_args, get_origin)

import torch
from pydantic import TypeAdapter, ValidationError
from pydantic.dataclasses import is_pydantic_dataclass
from typing_extensions import TypeIs, deprecated

import vllm.envs as envs
Expand Down Expand Up @@ -161,14 +163,15 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
type_hints.add(field.type)

# If the field is a dataclass, we can use the model_validate_json
generator = (th for th in type_hints if is_dataclass(th))
generator = (th for th in type_hints if is_pydantic_dataclass(th))
dataclass_cls = next(generator, None)

# Get the default value of the field
if field.default is not MISSING:
default = field.default
elif field.default_factory is not MISSING:
if is_dataclass(field.default_factory) and is_in_doc_build():
if (is_pydantic_dataclass(field.default_factory)
and is_in_doc_build()):
default = {}
else:
default = field.default_factory()
Expand All @@ -185,12 +188,16 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
# Set other kwargs based on the type hints
json_tip = "\n\nShould be a valid JSON string."
if dataclass_cls is not None:
dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x))
# Special case for configs with a from_cli method
if hasattr(dataclass_cls, "from_cli"):
from_cli = dataclass_cls.from_cli
dataclass_init = lambda x, f=from_cli: f(x)
kwargs[name]["type"] = dataclass_init

def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
try:
if hasattr(cls, "from_cli"):
return cls.from_cli(val)
return TypeAdapter(cls).validate_json(val)
except ValidationError as e:
raise argparse.ArgumentTypeError(repr(e)) from e

kwargs[name]["type"] = parse_dataclass
kwargs[name]["help"] += json_tip
elif contains_type(type_hints, bool):
# Creates --no-<name> and --<name> flags
Expand Down Expand Up @@ -221,12 +228,11 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs[name]["type"] = human_readable_int
elif contains_type(type_hints, float):
kwargs[name]["type"] = float
elif contains_type(type_hints,
dict) and (contains_type(type_hints, str) or any(
is_not_builtin(th) for th in type_hints)):
elif (contains_type(type_hints, dict)
and (contains_type(type_hints, str)
or any(is_not_builtin(th) for th in type_hints))):
kwargs[name]["type"] = union_dict_and_str
elif contains_type(type_hints, dict):
# Dict arguments will always be optional
kwargs[name]["type"] = parse_type(json.loads)
kwargs[name]["help"] += json_tip
elif (contains_type(type_hints, str)
Expand Down Expand Up @@ -409,7 +415,8 @@ class EngineArgs:

calculate_kv_scales: bool = CacheConfig.calculate_kv_scales

additional_config: Optional[Dict[str, Any]] = None
additional_config: dict[str, Any] = \
get_field(VllmConfig, "additional_config")
enable_reasoning: Optional[bool] = None # DEPRECATED
reasoning_parser: str = DecodingConfig.reasoning_backend

Expand Down
Loading