|
1 |
| -import importlib |
2 |
| -import string |
3 |
| -import subprocess |
4 |
| -import sys |
5 |
| -import uuid |
6 |
| -from functools import lru_cache, partial |
7 |
| -from typing import Callable, Dict, List, Optional, Tuple, Type, Union |
8 |
| - |
9 |
| -import torch.nn as nn |
10 |
| - |
11 |
| -from vllm.logger import init_logger |
12 |
| -from vllm.utils import is_hip |
13 |
| - |
14 |
| -from .interfaces import supports_multimodal, supports_pp |
15 |
| - |
16 |
| -logger = init_logger(__name__) |
17 |
| - |
18 |
| -_GENERATION_MODELS = { |
19 |
| - "AquilaModel": ("llama", "LlamaForCausalLM"), |
20 |
| - "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 |
21 |
| - "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), |
22 |
| - "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b |
23 |
| - "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b |
24 |
| - "BloomForCausalLM": ("bloom", "BloomForCausalLM"), |
25 |
| - "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), |
26 |
| - "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), |
27 |
| - "CohereForCausalLM": ("commandr", "CohereForCausalLM"), |
28 |
| - "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"), |
29 |
| - "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), |
30 |
| - "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), |
31 |
| - "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), |
32 |
| - "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), |
33 |
| - "FalconForCausalLM": ("falcon", "FalconForCausalLM"), |
34 |
| - "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), |
35 |
| - "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), |
36 |
| - "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), |
37 |
| - "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), |
38 |
| - "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), |
39 |
| - "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), |
40 |
| - "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), |
41 |
| - "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), |
42 |
| - "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), |
43 |
| - "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), |
44 |
| - "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), |
45 |
| - "JambaForCausalLM": ("jamba", "JambaForCausalLM"), |
46 |
| - "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), |
47 |
| - # For decapoda-research/llama-* |
48 |
| - "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), |
49 |
| - "MistralForCausalLM": ("llama", "LlamaForCausalLM"), |
50 |
| - "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), |
51 |
| - "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), |
52 |
| - # transformers's mpt class has lower case |
53 |
| - "MptForCausalLM": ("mpt", "MPTForCausalLM"), |
54 |
| - "MPTForCausalLM": ("mpt", "MPTForCausalLM"), |
55 |
| - "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), |
56 |
| - "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), |
57 |
| - "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), |
58 |
| - "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), |
59 |
| - "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), |
60 |
| - "OPTForCausalLM": ("opt", "OPTForCausalLM"), |
61 |
| - "OrionForCausalLM": ("orion", "OrionForCausalLM"), |
62 |
| - "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), |
63 |
| - "PhiForCausalLM": ("phi", "PhiForCausalLM"), |
64 |
| - "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), |
65 |
| - "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), |
66 |
| - "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), |
67 |
| - "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), |
68 |
| - "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), |
69 |
| - "Qwen2VLForConditionalGeneration": |
70 |
| - ("qwen2_vl", "Qwen2VLForConditionalGeneration"), |
71 |
| - "RWForCausalLM": ("falcon", "FalconForCausalLM"), |
72 |
| - "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), |
73 |
| - "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), |
74 |
| - "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), |
75 |
| - "SolarForCausalLM": ("solar", "SolarForCausalLM"), |
76 |
| - "XverseForCausalLM": ("xverse", "XverseForCausalLM"), |
77 |
| - # NOTE: The below models are for speculative decoding only |
78 |
| - "MedusaModel": ("medusa", "Medusa"), |
79 |
| - "EAGLEModel": ("eagle", "EAGLE"), |
80 |
| - "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), |
81 |
| -} |
82 |
| - |
83 |
| -_EMBEDDING_MODELS = { |
84 |
| - "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), |
85 |
| - "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), |
86 |
| -} |
87 |
| - |
88 |
| -_MULTIMODAL_MODELS = { |
89 |
| - "Blip2ForConditionalGeneration": |
90 |
| - ("blip2", "Blip2ForConditionalGeneration"), |
91 |
| - "ChameleonForConditionalGeneration": |
92 |
| - ("chameleon", "ChameleonForConditionalGeneration"), |
93 |
| - "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), |
94 |
| - "InternVLChatModel": ("internvl", "InternVLChatModel"), |
95 |
| - "LlavaForConditionalGeneration": ("llava", |
96 |
| - "LlavaForConditionalGeneration"), |
97 |
| - "LlavaNextForConditionalGeneration": ("llava_next", |
98 |
| - "LlavaNextForConditionalGeneration"), |
99 |
| - "LlavaNextVideoForConditionalGeneration": |
100 |
| - ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), |
101 |
| - "LlavaOnevisionForConditionalGeneration": |
102 |
| - ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), |
103 |
| - "MiniCPMV": ("minicpmv", "MiniCPMV"), |
104 |
| - "PaliGemmaForConditionalGeneration": ("paligemma", |
105 |
| - "PaliGemmaForConditionalGeneration"), |
106 |
| - "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), |
107 |
| - "PixtralForConditionalGeneration": ("pixtral", |
108 |
| - "PixtralForConditionalGeneration"), |
109 |
| - "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), |
110 |
| - "Qwen2VLForConditionalGeneration": ("qwen2_vl", |
111 |
| - "Qwen2VLForConditionalGeneration"), |
112 |
| - "UltravoxModel": ("ultravox", "UltravoxModel"), |
113 |
| - "MllamaForConditionalGeneration": ("mllama", |
114 |
| - "MllamaForConditionalGeneration"), |
115 |
| -} |
116 |
| -_CONDITIONAL_GENERATION_MODELS = { |
117 |
| - "BartModel": ("bart", "BartForConditionalGeneration"), |
118 |
| - "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), |
119 |
| -} |
120 |
| - |
121 |
| -_MODELS = { |
122 |
| - **_GENERATION_MODELS, |
123 |
| - **_EMBEDDING_MODELS, |
124 |
| - **_MULTIMODAL_MODELS, |
125 |
| - **_CONDITIONAL_GENERATION_MODELS, |
126 |
| -} |
127 |
| - |
128 |
| -# Architecture -> type. |
129 |
| -# out of tree models |
130 |
| -_OOT_MODELS: Dict[str, Type[nn.Module]] = {} |
131 |
| - |
132 |
| -# Models not supported by ROCm. |
133 |
| -_ROCM_UNSUPPORTED_MODELS: List[str] = [] |
134 |
| - |
135 |
| -# Models partially supported by ROCm. |
136 |
| -# Architecture -> Reason. |
137 |
| -_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in " |
138 |
| - "Triton flash attention. For half-precision SWA support, " |
139 |
| - "please use CK flash attention by setting " |
140 |
| - "`VLLM_USE_TRITON_FLASH_ATTN=0`") |
141 |
| -_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { |
142 |
| - "Qwen2ForCausalLM": |
143 |
| - _ROCM_SWA_REASON, |
144 |
| - "MistralForCausalLM": |
145 |
| - _ROCM_SWA_REASON, |
146 |
| - "MixtralForCausalLM": |
147 |
| - _ROCM_SWA_REASON, |
148 |
| - "PaliGemmaForConditionalGeneration": |
149 |
| - ("ROCm flash attention does not yet " |
150 |
| - "fully support 32-bit precision on PaliGemma"), |
151 |
| - "Phi3VForCausalLM": |
152 |
| - ("ROCm Triton flash attention may run into compilation errors due to " |
153 |
| - "excessive use of shared memory. If this happens, disable Triton FA " |
154 |
| - "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") |
155 |
| -} |
156 |
| - |
157 |
| - |
158 |
| -class ModelRegistry: |
159 |
| - |
160 |
| - @staticmethod |
161 |
| - def _get_module_cls_name(model_arch: str) -> Tuple[str, str]: |
162 |
| - module_relname, cls_name = _MODELS[model_arch] |
163 |
| - return f"vllm.model_executor.models.{module_relname}", cls_name |
164 |
| - |
165 |
| - @staticmethod |
166 |
| - @lru_cache(maxsize=128) |
167 |
| - def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]: |
168 |
| - if model_arch not in _MODELS: |
169 |
| - return None |
170 |
| - |
171 |
| - module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) |
172 |
| - module = importlib.import_module(module_name) |
173 |
| - return getattr(module, cls_name, None) |
174 |
| - |
175 |
| - @staticmethod |
176 |
| - def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]: |
177 |
| - if model_arch in _OOT_MODELS: |
178 |
| - return _OOT_MODELS[model_arch] |
179 |
| - |
180 |
| - if is_hip(): |
181 |
| - if model_arch in _ROCM_UNSUPPORTED_MODELS: |
182 |
| - raise ValueError( |
183 |
| - f"Model architecture {model_arch} is not supported by " |
184 |
| - "ROCm for now.") |
185 |
| - if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: |
186 |
| - logger.warning( |
187 |
| - "Model architecture %s is partially supported by ROCm: %s", |
188 |
| - model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) |
189 |
| - |
190 |
| - return None |
191 |
| - |
192 |
| - @staticmethod |
193 |
| - def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: |
194 |
| - model = ModelRegistry._try_get_model_stateless(model_arch) |
195 |
| - if model is not None: |
196 |
| - return model |
197 |
| - |
198 |
| - return ModelRegistry._try_get_model_stateful(model_arch) |
199 |
| - |
200 |
| - @staticmethod |
201 |
| - def resolve_model_cls( |
202 |
| - architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]: |
203 |
| - if isinstance(architectures, str): |
204 |
| - architectures = [architectures] |
205 |
| - if not architectures: |
206 |
| - logger.warning("No model architectures are specified") |
207 |
| - |
208 |
| - for arch in architectures: |
209 |
| - model_cls = ModelRegistry._try_load_model_cls(arch) |
210 |
| - if model_cls is not None: |
211 |
| - return (model_cls, arch) |
212 |
| - |
213 |
| - raise ValueError( |
214 |
| - f"Model architectures {architectures} are not supported for now. " |
215 |
| - f"Supported architectures: {ModelRegistry.get_supported_archs()}") |
216 |
| - |
217 |
| - @staticmethod |
218 |
| - def get_supported_archs() -> List[str]: |
219 |
| - return list(_MODELS.keys()) + list(_OOT_MODELS.keys()) |
220 |
| - |
221 |
| - @staticmethod |
222 |
| - def register_model(model_arch: str, model_cls: Type[nn.Module]): |
223 |
| - if model_arch in _MODELS: |
224 |
| - logger.warning( |
225 |
| - "Model architecture %s is already registered, and will be " |
226 |
| - "overwritten by the new model class %s.", model_arch, |
227 |
| - model_cls.__name__) |
228 |
| - |
229 |
| - _OOT_MODELS[model_arch] = model_cls |
230 |
| - |
231 |
| - @staticmethod |
232 |
| - @lru_cache(maxsize=128) |
233 |
| - def _check_stateless( |
234 |
| - func: Callable[[Type[nn.Module]], bool], |
235 |
| - model_arch: str, |
236 |
| - *, |
237 |
| - default: Optional[bool] = None, |
238 |
| - ) -> bool: |
239 |
| - """ |
240 |
| - Run a boolean function against a model and return the result. |
241 |
| -
|
242 |
| - If the model is not found, returns the provided default value. |
243 |
| -
|
244 |
| - If the model is not already imported, the function is run inside a |
245 |
| - subprocess to avoid initializing CUDA for the main program. |
246 |
| - """ |
247 |
| - model = ModelRegistry._try_get_model_stateless(model_arch) |
248 |
| - if model is not None: |
249 |
| - return func(model) |
250 |
| - |
251 |
| - if model_arch not in _MODELS and default is not None: |
252 |
| - return default |
253 |
| - |
254 |
| - module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) |
255 |
| - |
256 |
| - valid_name_characters = string.ascii_letters + string.digits + "._" |
257 |
| - if any(s not in valid_name_characters for s in module_name): |
258 |
| - raise ValueError(f"Unsafe module name detected for {model_arch}") |
259 |
| - if any(s not in valid_name_characters for s in cls_name): |
260 |
| - raise ValueError(f"Unsafe class name detected for {model_arch}") |
261 |
| - if any(s not in valid_name_characters for s in func.__module__): |
262 |
| - raise ValueError(f"Unsafe module name detected for {func}") |
263 |
| - if any(s not in valid_name_characters for s in func.__name__): |
264 |
| - raise ValueError(f"Unsafe class name detected for {func}") |
265 |
| - |
266 |
| - err_id = uuid.uuid4() |
267 |
| - |
268 |
| - stmts = ";".join([ |
269 |
| - f"from {module_name} import {cls_name}", |
270 |
| - f"from {func.__module__} import {func.__name__}", |
271 |
| - f"assert {func.__name__}({cls_name}), '{err_id}'", |
272 |
| - ]) |
273 |
| - |
274 |
| - result = subprocess.run([sys.executable, "-c", stmts], |
275 |
| - capture_output=True) |
276 |
| - |
277 |
| - if result.returncode != 0: |
278 |
| - err_lines = [line.decode() for line in result.stderr.splitlines()] |
279 |
| - if err_lines and err_lines[-1] != f"AssertionError: {err_id}": |
280 |
| - err_str = "\n".join(err_lines) |
281 |
| - raise RuntimeError( |
282 |
| - "An unexpected error occurred while importing the model in " |
283 |
| - f"another process. Error log:\n{err_str}") |
284 |
| - |
285 |
| - return result.returncode == 0 |
286 |
| - |
287 |
| - @staticmethod |
288 |
| - def is_embedding_model(architectures: Union[str, List[str]]) -> bool: |
289 |
| - if isinstance(architectures, str): |
290 |
| - architectures = [architectures] |
291 |
| - if not architectures: |
292 |
| - logger.warning("No model architectures are specified") |
293 |
| - |
294 |
| - return any(arch in _EMBEDDING_MODELS for arch in architectures) |
295 |
| - |
296 |
| - @staticmethod |
297 |
| - def is_multimodal_model(architectures: Union[str, List[str]]) -> bool: |
298 |
| - if isinstance(architectures, str): |
299 |
| - architectures = [architectures] |
300 |
| - if not architectures: |
301 |
| - logger.warning("No model architectures are specified") |
302 |
| - |
303 |
| - is_mm = partial(ModelRegistry._check_stateless, |
304 |
| - supports_multimodal, |
305 |
| - default=False) |
306 |
| - |
307 |
| - return any(is_mm(arch) for arch in architectures) |
308 |
| - |
309 |
| - @staticmethod |
310 |
| - def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool: |
311 |
| - if isinstance(architectures, str): |
312 |
| - architectures = [architectures] |
313 |
| - if not architectures: |
314 |
| - logger.warning("No model architectures are specified") |
315 |
| - |
316 |
| - is_pp = partial(ModelRegistry._check_stateless, |
317 |
| - supports_pp, |
318 |
| - default=False) |
319 |
| - |
320 |
| - return any(is_pp(arch) for arch in architectures) |
321 |
| - |
| 1 | +from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, |
| 2 | + SupportsPP, has_inner_state, supports_lora, |
| 3 | + supports_multimodal, supports_pp) |
| 4 | +from .registry import ModelRegistry |
322 | 5 |
|
323 | 6 | __all__ = [
|
324 | 7 | "ModelRegistry",
|
| 8 | + "HasInnerState", |
| 9 | + "has_inner_state", |
| 10 | + "SupportsLoRA", |
| 11 | + "supports_lora", |
| 12 | + "SupportsMultiModal", |
| 13 | + "supports_multimodal", |
| 14 | + "SupportsPP", |
| 15 | + "supports_pp", |
325 | 16 | ]
|
0 commit comments