Skip to content

Commit ddc729b

Browse files
committed
Move registry to its own file
1 parent 0f6d7a9 commit ddc729b

File tree

7 files changed

+339
-333
lines changed

7 files changed

+339
-333
lines changed

docs/source/models/adding_model.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a
9999
5. Register your model
100100
----------------------
101101

102-
Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/__init__.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/__init__.py>`_.
102+
Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/registry.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/registry.py>`_.
103103

104104
6. Out-of-Tree Model Integration
105105
--------------------------------------------

vllm/lora/models.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424
from vllm.lora.punica import PunicaWrapper
2525
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
2626
parse_fine_tuned_lora_name, replace_submodule)
27-
from vllm.model_executor.models.interfaces import (SupportsLoRA,
28-
supports_multimodal)
27+
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
2928
from vllm.model_executor.models.module_mapping import MultiModelKeys
3029
from vllm.model_executor.models.utils import PPMissingLayer
3130
from vllm.utils import is_pin_memory_available

vllm/model_executor/model_loader/loader.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,8 @@
4141
get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator,
4242
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
4343
safetensors_weights_iterator)
44-
from vllm.model_executor.models.interfaces import (has_inner_state,
45-
supports_lora,
46-
supports_multimodal)
44+
from vllm.model_executor.models import (has_inner_state, supports_lora,
45+
supports_multimodal)
4746
from vllm.model_executor.utils import set_weight_attrs
4847
from vllm.platforms import current_platform
4948
from vllm.utils import is_pin_memory_available
vllm/model_executor/models/__init__.py

Lines changed: 12 additions & 321 deletions
Original file line numberDiff line numberDiff line change
@@ -1,325 +1,16 @@
1-
import importlib
2-
import string
3-
import subprocess
4-
import sys
5-
import uuid
6-
from functools import lru_cache, partial
7-
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
8-
9-
import torch.nn as nn
10-
11-
from vllm.logger import init_logger
12-
from vllm.utils import is_hip
13-
14-
from .interfaces import supports_multimodal, supports_pp
15-
16-
logger = init_logger(__name__)
17-
18-
_GENERATION_MODELS = {
19-
"AquilaModel": ("llama", "LlamaForCausalLM"),
20-
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
21-
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
22-
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
23-
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
24-
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
25-
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
26-
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
27-
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
28-
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
29-
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
30-
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
31-
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
32-
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
33-
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
34-
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
35-
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
36-
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
37-
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
38-
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
39-
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
40-
"GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
41-
"GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
42-
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
43-
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
44-
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
45-
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
46-
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
47-
# For decapoda-research/llama-*
48-
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
49-
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
50-
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
51-
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
52-
# transformers's mpt class has lower case
53-
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
54-
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
55-
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
56-
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
57-
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
58-
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
59-
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
60-
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
61-
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
62-
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
63-
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
64-
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
65-
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
66-
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
67-
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
68-
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
69-
"Qwen2VLForConditionalGeneration":
70-
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
71-
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
72-
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
73-
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
74-
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
75-
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
76-
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
77-
# NOTE: The below models are for speculative decoding only
78-
"MedusaModel": ("medusa", "Medusa"),
79-
"EAGLEModel": ("eagle", "EAGLE"),
80-
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
81-
}
82-
83-
_EMBEDDING_MODELS = {
84-
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
85-
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
86-
}
87-
88-
_MULTIMODAL_MODELS = {
89-
"Blip2ForConditionalGeneration":
90-
("blip2", "Blip2ForConditionalGeneration"),
91-
"ChameleonForConditionalGeneration":
92-
("chameleon", "ChameleonForConditionalGeneration"),
93-
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
94-
"InternVLChatModel": ("internvl", "InternVLChatModel"),
95-
"LlavaForConditionalGeneration": ("llava",
96-
"LlavaForConditionalGeneration"),
97-
"LlavaNextForConditionalGeneration": ("llava_next",
98-
"LlavaNextForConditionalGeneration"),
99-
"LlavaNextVideoForConditionalGeneration":
100-
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
101-
"LlavaOnevisionForConditionalGeneration":
102-
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
103-
"MiniCPMV": ("minicpmv", "MiniCPMV"),
104-
"PaliGemmaForConditionalGeneration": ("paligemma",
105-
"PaliGemmaForConditionalGeneration"),
106-
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
107-
"PixtralForConditionalGeneration": ("pixtral",
108-
"PixtralForConditionalGeneration"),
109-
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
110-
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
111-
"Qwen2VLForConditionalGeneration"),
112-
"UltravoxModel": ("ultravox", "UltravoxModel"),
113-
"MllamaForConditionalGeneration": ("mllama",
114-
"MllamaForConditionalGeneration"),
115-
}
116-
_CONDITIONAL_GENERATION_MODELS = {
117-
"BartModel": ("bart", "BartForConditionalGeneration"),
118-
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
119-
}
120-
121-
_MODELS = {
122-
**_GENERATION_MODELS,
123-
**_EMBEDDING_MODELS,
124-
**_MULTIMODAL_MODELS,
125-
**_CONDITIONAL_GENERATION_MODELS,
126-
}
127-
128-
# Architecture -> type.
129-
# out of tree models
130-
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
131-
132-
# Models not supported by ROCm.
133-
_ROCM_UNSUPPORTED_MODELS: List[str] = []
134-
135-
# Models partially supported by ROCm.
136-
# Architecture -> Reason.
137-
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
138-
"Triton flash attention. For half-precision SWA support, "
139-
"please use CK flash attention by setting "
140-
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
141-
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
142-
"Qwen2ForCausalLM":
143-
_ROCM_SWA_REASON,
144-
"MistralForCausalLM":
145-
_ROCM_SWA_REASON,
146-
"MixtralForCausalLM":
147-
_ROCM_SWA_REASON,
148-
"PaliGemmaForConditionalGeneration":
149-
("ROCm flash attention does not yet "
150-
"fully support 32-bit precision on PaliGemma"),
151-
"Phi3VForCausalLM":
152-
("ROCm Triton flash attention may run into compilation errors due to "
153-
"excessive use of shared memory. If this happens, disable Triton FA "
154-
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
155-
}
156-
157-
158-
class ModelRegistry:
159-
160-
@staticmethod
161-
def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
162-
module_relname, cls_name = _MODELS[model_arch]
163-
return f"vllm.model_executor.models.{module_relname}", cls_name
164-
165-
@staticmethod
166-
@lru_cache(maxsize=128)
167-
def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
168-
if model_arch not in _MODELS:
169-
return None
170-
171-
module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
172-
module = importlib.import_module(module_name)
173-
return getattr(module, cls_name, None)
174-
175-
@staticmethod
176-
def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]:
177-
if model_arch in _OOT_MODELS:
178-
return _OOT_MODELS[model_arch]
179-
180-
if is_hip():
181-
if model_arch in _ROCM_UNSUPPORTED_MODELS:
182-
raise ValueError(
183-
f"Model architecture {model_arch} is not supported by "
184-
"ROCm for now.")
185-
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
186-
logger.warning(
187-
"Model architecture %s is partially supported by ROCm: %s",
188-
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
189-
190-
return None
191-
192-
@staticmethod
193-
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
194-
model = ModelRegistry._try_get_model_stateless(model_arch)
195-
if model is not None:
196-
return model
197-
198-
return ModelRegistry._try_get_model_stateful(model_arch)
199-
200-
@staticmethod
201-
def resolve_model_cls(
202-
architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]:
203-
if isinstance(architectures, str):
204-
architectures = [architectures]
205-
if not architectures:
206-
logger.warning("No model architectures are specified")
207-
208-
for arch in architectures:
209-
model_cls = ModelRegistry._try_load_model_cls(arch)
210-
if model_cls is not None:
211-
return (model_cls, arch)
212-
213-
raise ValueError(
214-
f"Model architectures {architectures} are not supported for now. "
215-
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
216-
217-
@staticmethod
218-
def get_supported_archs() -> List[str]:
219-
return list(_MODELS.keys()) + list(_OOT_MODELS.keys())
220-
221-
@staticmethod
222-
def register_model(model_arch: str, model_cls: Type[nn.Module]):
223-
if model_arch in _MODELS:
224-
logger.warning(
225-
"Model architecture %s is already registered, and will be "
226-
"overwritten by the new model class %s.", model_arch,
227-
model_cls.__name__)
228-
229-
_OOT_MODELS[model_arch] = model_cls
230-
231-
@staticmethod
232-
@lru_cache(maxsize=128)
233-
def _check_stateless(
234-
func: Callable[[Type[nn.Module]], bool],
235-
model_arch: str,
236-
*,
237-
default: Optional[bool] = None,
238-
) -> bool:
239-
"""
240-
Run a boolean function against a model and return the result.
241-
242-
If the model is not found, returns the provided default value.
243-
244-
If the model is not already imported, the function is run inside a
245-
subprocess to avoid initializing CUDA for the main program.
246-
"""
247-
model = ModelRegistry._try_get_model_stateless(model_arch)
248-
if model is not None:
249-
return func(model)
250-
251-
if model_arch not in _MODELS and default is not None:
252-
return default
253-
254-
module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
255-
256-
valid_name_characters = string.ascii_letters + string.digits + "._"
257-
if any(s not in valid_name_characters for s in module_name):
258-
raise ValueError(f"Unsafe module name detected for {model_arch}")
259-
if any(s not in valid_name_characters for s in cls_name):
260-
raise ValueError(f"Unsafe class name detected for {model_arch}")
261-
if any(s not in valid_name_characters for s in func.__module__):
262-
raise ValueError(f"Unsafe module name detected for {func}")
263-
if any(s not in valid_name_characters for s in func.__name__):
264-
raise ValueError(f"Unsafe class name detected for {func}")
265-
266-
err_id = uuid.uuid4()
267-
268-
stmts = ";".join([
269-
f"from {module_name} import {cls_name}",
270-
f"from {func.__module__} import {func.__name__}",
271-
f"assert {func.__name__}({cls_name}), '{err_id}'",
272-
])
273-
274-
result = subprocess.run([sys.executable, "-c", stmts],
275-
capture_output=True)
276-
277-
if result.returncode != 0:
278-
err_lines = [line.decode() for line in result.stderr.splitlines()]
279-
if err_lines and err_lines[-1] != f"AssertionError: {err_id}":
280-
err_str = "\n".join(err_lines)
281-
raise RuntimeError(
282-
"An unexpected error occurred while importing the model in "
283-
f"another process. Error log:\n{err_str}")
284-
285-
return result.returncode == 0
286-
287-
@staticmethod
288-
def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
289-
if isinstance(architectures, str):
290-
architectures = [architectures]
291-
if not architectures:
292-
logger.warning("No model architectures are specified")
293-
294-
return any(arch in _EMBEDDING_MODELS for arch in architectures)
295-
296-
@staticmethod
297-
def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:
298-
if isinstance(architectures, str):
299-
architectures = [architectures]
300-
if not architectures:
301-
logger.warning("No model architectures are specified")
302-
303-
is_mm = partial(ModelRegistry._check_stateless,
304-
supports_multimodal,
305-
default=False)
306-
307-
return any(is_mm(arch) for arch in architectures)
308-
309-
@staticmethod
310-
def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
311-
if isinstance(architectures, str):
312-
architectures = [architectures]
313-
if not architectures:
314-
logger.warning("No model architectures are specified")
315-
316-
is_pp = partial(ModelRegistry._check_stateless,
317-
supports_pp,
318-
default=False)
319-
320-
return any(is_pp(arch) for arch in architectures)
321-
1+
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
2+
SupportsPP, has_inner_state, supports_lora,
3+
supports_multimodal, supports_pp)
4+
from .registry import ModelRegistry
3225

3236
__all__ = [
3247
"ModelRegistry",
8+
"HasInnerState",
9+
"has_inner_state",
10+
"SupportsLoRA",
11+
"supports_lora",
12+
"SupportsMultiModal",
13+
"supports_multimodal",
14+
"SupportsPP",
15+
"supports_pp",
32516
]

vllm/model_executor/models/jamba.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,18 @@
2525
causal_conv1d_fn, causal_conv1d_update)
2626
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
2727
selective_scan_fn, selective_state_update)
28-
from vllm.model_executor.layers.quantization.base_config import (
29-
QuantizationConfig)
28+
from vllm.model_executor.layers.quantization import QuantizationConfig
3029
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
3130
from vllm.model_executor.layers.vocab_parallel_embedding import (
3231
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
3332
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
34-
from vllm.model_executor.models.interfaces import HasInnerState
3533
from vllm.model_executor.sampling_metadata import SamplingMetadata
3634
from vllm.model_executor.utils import set_weight_attrs
3735
from vllm.sequence import IntermediateTensors
3836
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
3937
_get_graph_batch_size)
4038

41-
from .interfaces import SupportsLoRA
39+
from .interfaces import HasInnerState, SupportsLoRA
4240

4341
KVCache = Tuple[torch.Tensor, torch.Tensor]
4442

0 commit comments

Comments
 (0)