
Commit 6a6f477

Merge branch 'main' into nvlm_d

2 parents 2ec7fc1 + 0e36fd4

8 files changed, +333 −327 lines


docs/source/models/adding_model.rst

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a
 5. Register your model
 ----------------------
 
-Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/__init__.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/__init__.py>`_.
+Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/registry.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/registry.py>`_.
 
 6. Out-of-Tree Model Integration
 --------------------------------------------
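For context, the registry file the documentation now points at keeps the same shape as the `_MODELS` table removed from `__init__.py` later in this diff: each architecture name maps to a `(module name, class name)` pair that is imported lazily. A minimal sketch of an entry, using hypothetical placeholder names rather than a real vLLM model:

    # Hypothetical entry following the (module, class) convention of the
    # _MODELS table shown later in this diff; "your_model" and
    # "YourModelForCausalLM" are placeholders, not real vLLM names.
    _MODELS = {
        "YourModelForCausalLM": ("your_model", "YourModelForCausalLM"),
    }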

tests/models/test_registry.py

Lines changed: 2 additions & 2 deletions
@@ -3,13 +3,13 @@
 import pytest
 import torch.cuda
 
-from vllm.model_executor.models import _MODELS, ModelRegistry
+from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
 
 from ..utils import fork_new_process_for_each_test
 
 
-@pytest.mark.parametrize("model_arch", _MODELS)
+@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
 def test_registry_imports(model_arch):
     # Ensure all model classes can be imported successfully
     ModelRegistry.resolve_model_cls(model_arch)
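A note on this change: per the registry code removed later in this diff, `get_supported_archs()` returns `list(_MODELS.keys()) + list(_OOT_MODELS.keys())`, so the test now iterates the public API (including out-of-tree registrations) instead of the private `_MODELS` dict. A sketch of what the parametrization exercises:

    # Sketch: every supported architecture should resolve to an importable
    # class. Per the registry code shown later in this diff,
    # resolve_model_cls returns a (model_cls, arch) tuple.
    from vllm.model_executor.models import ModelRegistry

    for arch in ModelRegistry.get_supported_archs():
        model_cls, resolved_arch = ModelRegistry.resolve_model_cls(arch)
        assert resolved_arch == arch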

vllm/lora/models.py

Lines changed: 1 addition & 2 deletions
@@ -24,8 +24,7 @@
 from vllm.lora.punica import PunicaWrapper
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                              parse_fine_tuned_lora_name, replace_submodule)
-from vllm.model_executor.models.interfaces import (SupportsLoRA,
-                                                   supports_multimodal)
+from vllm.model_executor.models import SupportsLoRA, supports_multimodal
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.utils import PPMissingLayer
 from vllm.utils import is_pin_memory_available

vllm/model_executor/model_loader/loader.py

Lines changed: 2 additions & 3 deletions
@@ -41,9 +41,8 @@
     get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator,
     initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
     safetensors_weights_iterator)
-from vllm.model_executor.models.interfaces import (has_inner_state,
-                                                   supports_lora,
-                                                   supports_multimodal)
+from vllm.model_executor.models import (has_inner_state, supports_lora,
+                                        supports_multimodal)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available
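Both this file and vllm/lora/models.py above now import the interface helpers from the package root, which re-exports them (see the `__init__.py` diff below). A minimal sketch of how such predicates are used; each accepts a model class (or instance) and returns a bool, and the `describe_capabilities` helper here is illustrative, not part of vLLM:

    # Illustrative only: summarize a model class's capabilities using the
    # re-exported predicates. "describe_capabilities" is a made-up helper.
    from vllm.model_executor.models import (has_inner_state, supports_lora,
                                            supports_multimodal)

    def describe_capabilities(model_cls: type) -> dict:
        return {
            "lora": supports_lora(model_cls),
            "multimodal": supports_multimodal(model_cls),
            "inner_state": has_inner_state(model_cls),
        }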

vllm/model_executor/models/__init__.py

Lines changed: 12 additions & 313 deletions

@@ -1,317 +1,16 @@
-import importlib
-import string
-import subprocess
-import sys
-import uuid
-from functools import lru_cache, partial
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
-
-import torch.nn as nn
-
-from vllm.logger import init_logger
-from vllm.utils import is_hip
-
-from .interfaces import supports_multimodal, supports_pp
-
-logger = init_logger(__name__)
-
-_GENERATION_MODELS = {
-    "AquilaModel": ("llama", "LlamaForCausalLM"),
-    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
-    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
-    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
-    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
-    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
-    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
-    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
-    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
-    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
-    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
-    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
-    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
-    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
-    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
-    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
-    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
-    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
-    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
-    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
-    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
-    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
-    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
-    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
-    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
-    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
-    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
-    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
-    # For decapoda-research/llama-*
-    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
-    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
-    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
-    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
-    # transformers's mpt class has lower case
-    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
-    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
-    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
-    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
-    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
-    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
-    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
-    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
-    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
-    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
-    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
-    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
-    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
-    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
-    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
-    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
-    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
-    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
-    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
-    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
-    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
-    "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
-    # NOTE: The below models are for speculative decoding only
-    "MedusaModel": ("medusa", "Medusa"),
-    "EAGLEModel": ("eagle", "EAGLE"),
-    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
-}
-
-_EMBEDDING_MODELS = {
-    "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
-    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
-}
-
-# yapf: disable
-_MULTIMODAL_MODELS = {
-    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
-    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
-    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
-    "InternVLChatModel": ("internvl", "InternVLChatModel"),
-    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
-    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
-    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
-    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
-    "MiniCPMV": ("minicpmv", "MiniCPMV"),
-    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
-    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
-    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
-    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
-    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
-    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
-    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
-    "UltravoxModel": ("ultravox", "UltravoxModel"),
-}
-# yapf: enable
-
-_CONDITIONAL_GENERATION_MODELS = {
-    "BartModel": ("bart", "BartForConditionalGeneration"),
-    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
-}
-
-_MODELS = {
-    **_GENERATION_MODELS,
-    **_EMBEDDING_MODELS,
-    **_MULTIMODAL_MODELS,
-    **_CONDITIONAL_GENERATION_MODELS,
-}
-
-# Architecture -> type.
-# out of tree models
-_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
-
-# Models not supported by ROCm.
-_ROCM_UNSUPPORTED_MODELS: List[str] = []
-
-# Models partially supported by ROCm.
-# Architecture -> Reason.
-_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
-                    "Triton flash attention. For half-precision SWA support, "
-                    "please use CK flash attention by setting "
-                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
-_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
-    "Qwen2ForCausalLM":
-    _ROCM_SWA_REASON,
-    "MistralForCausalLM":
-    _ROCM_SWA_REASON,
-    "MixtralForCausalLM":
-    _ROCM_SWA_REASON,
-    "PaliGemmaForConditionalGeneration":
-    ("ROCm flash attention does not yet "
-     "fully support 32-bit precision on PaliGemma"),
-    "Phi3VForCausalLM":
-    ("ROCm Triton flash attention may run into compilation errors due to "
-     "excessive use of shared memory. If this happens, disable Triton FA "
-     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
-}
-
-
-class ModelRegistry:
-
-    @staticmethod
-    def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
-        module_relname, cls_name = _MODELS[model_arch]
-        return f"vllm.model_executor.models.{module_relname}", cls_name
-
-    @staticmethod
-    @lru_cache(maxsize=128)
-    def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
-        if model_arch not in _MODELS:
-            return None
-
-        module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
-        module = importlib.import_module(module_name)
-        return getattr(module, cls_name, None)
-
-    @staticmethod
-    def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]:
-        if model_arch in _OOT_MODELS:
-            return _OOT_MODELS[model_arch]
-
-        if is_hip():
-            if model_arch in _ROCM_UNSUPPORTED_MODELS:
-                raise ValueError(
-                    f"Model architecture {model_arch} is not supported by "
-                    "ROCm for now.")
-            if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
-                logger.warning(
-                    "Model architecture %s is partially supported by ROCm: %s",
-                    model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
-
-        return None
-
-    @staticmethod
-    def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
-        model = ModelRegistry._try_get_model_stateless(model_arch)
-        if model is not None:
-            return model
-
-        return ModelRegistry._try_get_model_stateful(model_arch)
-
-    @staticmethod
-    def resolve_model_cls(
-        architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]:
-        if isinstance(architectures, str):
-            architectures = [architectures]
-        if not architectures:
-            logger.warning("No model architectures are specified")
-
-        for arch in architectures:
-            model_cls = ModelRegistry._try_load_model_cls(arch)
-            if model_cls is not None:
-                return (model_cls, arch)
-
-        raise ValueError(
-            f"Model architectures {architectures} are not supported for now. "
-            f"Supported architectures: {ModelRegistry.get_supported_archs()}")
-
-    @staticmethod
-    def get_supported_archs() -> List[str]:
-        return list(_MODELS.keys()) + list(_OOT_MODELS.keys())
-
-    @staticmethod
-    def register_model(model_arch: str, model_cls: Type[nn.Module]):
-        if model_arch in _MODELS:
-            logger.warning(
-                "Model architecture %s is already registered, and will be "
-                "overwritten by the new model class %s.", model_arch,
-                model_cls.__name__)
-
-        _OOT_MODELS[model_arch] = model_cls
-
-    @staticmethod
-    @lru_cache(maxsize=128)
-    def _check_stateless(
-        func: Callable[[Type[nn.Module]], bool],
-        model_arch: str,
-        *,
-        default: Optional[bool] = None,
-    ) -> bool:
-        """
-        Run a boolean function against a model and return the result.
-
-        If the model is not found, returns the provided default value.
-
-        If the model is not already imported, the function is run inside a
-        subprocess to avoid initializing CUDA for the main program.
-        """
-        model = ModelRegistry._try_get_model_stateless(model_arch)
-        if model is not None:
-            return func(model)
-
-        if model_arch not in _MODELS and default is not None:
-            return default
-
-        module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
-
-        valid_name_characters = string.ascii_letters + string.digits + "._"
-        if any(s not in valid_name_characters for s in module_name):
-            raise ValueError(f"Unsafe module name detected for {model_arch}")
-        if any(s not in valid_name_characters for s in cls_name):
-            raise ValueError(f"Unsafe class name detected for {model_arch}")
-        if any(s not in valid_name_characters for s in func.__module__):
-            raise ValueError(f"Unsafe module name detected for {func}")
-        if any(s not in valid_name_characters for s in func.__name__):
-            raise ValueError(f"Unsafe class name detected for {func}")
-
-        err_id = uuid.uuid4()
-
-        stmts = ";".join([
-            f"from {module_name} import {cls_name}",
-            f"from {func.__module__} import {func.__name__}",
-            f"assert {func.__name__}({cls_name}), '{err_id}'",
-        ])
-
-        result = subprocess.run([sys.executable, "-c", stmts],
-                                capture_output=True)
-
-        if result.returncode != 0:
-            err_lines = [line.decode() for line in result.stderr.splitlines()]
-            if err_lines and err_lines[-1] != f"AssertionError: {err_id}":
-                err_str = "\n".join(err_lines)
-                raise RuntimeError(
-                    "An unexpected error occurred while importing the model in "
-                    f"another process. Error log:\n{err_str}")
-
-        return result.returncode == 0
-
-    @staticmethod
-    def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
-        if isinstance(architectures, str):
-            architectures = [architectures]
-        if not architectures:
-            logger.warning("No model architectures are specified")
-
-        return any(arch in _EMBEDDING_MODELS for arch in architectures)
-
-    @staticmethod
-    def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:
-        if isinstance(architectures, str):
-            architectures = [architectures]
-        if not architectures:
-            logger.warning("No model architectures are specified")
-
-        is_mm = partial(ModelRegistry._check_stateless,
-                        supports_multimodal,
-                        default=False)
-
-        return any(is_mm(arch) for arch in architectures)
-
-    @staticmethod
-    def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
-        if isinstance(architectures, str):
-            architectures = [architectures]
-        if not architectures:
-            logger.warning("No model architectures are specified")
-
-        is_pp = partial(ModelRegistry._check_stateless,
-                        supports_pp,
-                        default=False)
-
-        return any(is_pp(arch) for arch in architectures)
-
+from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
+                         SupportsPP, has_inner_state, supports_lora,
+                         supports_multimodal, supports_pp)
+from .registry import ModelRegistry
 
 __all__ = [
     "ModelRegistry",
+    "HasInnerState",
+    "has_inner_state",
+    "SupportsLoRA",
+    "supports_lora",
+    "SupportsMultiModal",
+    "supports_multimodal",
+    "SupportsPP",
+    "supports_pp",
 ]
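
The registry implementation moves to vllm/model_executor/models/registry.py, and what remains in `__init__.py` is a re-export facade, so existing `from vllm.model_executor.models import ...` call sites keep working. For reference, the out-of-tree path served by `_OOT_MODELS` above runs through `register_model`; a minimal sketch using the `register_model` / `get_supported_archs` API shown in the removed code, with a stand-in model class:

    # Minimal sketch of out-of-tree registration. MyToyForCausalLM is a
    # stand-in class, not a real model.
    import torch.nn as nn

    from vllm.model_executor.models import ModelRegistry

    class MyToyForCausalLM(nn.Module):
        pass

    ModelRegistry.register_model("MyToyForCausalLM", MyToyForCausalLM)
    assert "MyToyForCausalLM" in ModelRegistry.get_supported_archs()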

vllm/model_executor/models/jamba.py

Lines changed: 2 additions & 4 deletions
@@ -25,20 +25,18 @@
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import HasInnerState
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
 from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
                                       _get_graph_batch_size)
 
-from .interfaces import SupportsLoRA
+from .interfaces import HasInnerState, SupportsLoRA
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]

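The net effect in jamba.py is that both interface classes now arrive through the single relative `.interfaces` import. A minimal sketch (absolute imports so it stands alone, class body elided) of a model declaring both capabilities the way Jamba does:

    # Sketch only: a model advertising both capabilities by subclassing the
    # interface classes; "ToyStatefulLoRAModel" is a placeholder name.
    import torch.nn as nn

    from vllm.model_executor.models import HasInnerState, SupportsLoRA

    class ToyStatefulLoRAModel(nn.Module, HasInnerState, SupportsLoRA):
        ...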