
Commit 3959b30

[Core][VLM] Test registration for OOT multimodal models (vllm-project#8717)

ywang96 authored and DarkLight1337 committed
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Alvant <alvasian@yandex.ru>

1 parent 8877bd8, commit 3959b30

File tree: 12 files changed (+227, -49 lines)

docs/source/models/adding_model.rst

Lines changed: 15 additions & 3 deletions
@@ -85,16 +85,16 @@ When it comes to the linear layers, we provide the following options to parallel
 * :code:`ReplicatedLinear`: Replicates the inputs and weights across multiple GPUs. No memory saving.
 * :code:`RowParallelLinear`: The input tensor is partitioned along the hidden dimension. The weight matrix is partitioned along the rows (input dimension). An *all-reduce* operation is performed after the matrix multiplication to reduce the results. Typically used for the second FFN layer and the output linear transformation of the attention layer.
 * :code:`ColumnParallelLinear`: The input tensor is replicated. The weight matrix is partitioned along the columns (output dimension). The result is partitioned along the column dimension. Typically used for the first FFN layer and the separated QKV transformation of the attention layer in the original Transformer.
-* :code:`MergedColumnParallelLinear`: Column-parallel linear that merges multiple `ColumnParallelLinear` operators. Typically used for the first FFN layer with weighted activation functions (e.g., SiLU). This class handles the sharded weight loading logic of multiple weight matrices.
+* :code:`MergedColumnParallelLinear`: Column-parallel linear that merges multiple :code:`ColumnParallelLinear` operators. Typically used for the first FFN layer with weighted activation functions (e.g., SiLU). This class handles the sharded weight loading logic of multiple weight matrices.
 * :code:`QKVParallelLinear`: Parallel linear layer for the query, key, and value projections of the multi-head and grouped-query attention mechanisms. When number of key/value heads are less than the world size, this class replicates the key/value heads properly. This class handles the weight loading and replication of the weight matrices.
 
-Note that all the linear layers above take `linear_method` as an input. vLLM will set this parameter according to different quantization schemes to support weight quantization.
+Note that all the linear layers above take :code:`linear_method` as an input. vLLM will set this parameter according to different quantization schemes to support weight quantization.
 
 4. Implement the weight loading logic
 -------------------------------------
 
 You now need to implement the :code:`load_weights` method in your :code:`*ForCausalLM` class.
-This method should load the weights from the HuggingFace's checkpoint file and assign them to the corresponding layers in your model. Specifically, for `MergedColumnParallelLinear` and `QKVParallelLinear` layers, if the original model has separated weight matrices, you need to load the different parts separately.
+This method should load the weights from the HuggingFace's checkpoint file and assign them to the corresponding layers in your model. Specifically, for :code:`MergedColumnParallelLinear` and :code:`QKVParallelLinear` layers, if the original model has separated weight matrices, you need to load the different parts separately.
 
 5. Register your model
 ----------------------
@@ -114,6 +114,18 @@ Just add the following lines in your code:
     from your_code import YourModelForCausalLM
     ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
 
+If your model imports modules that initialize CUDA, consider instead lazy-importing it to avoid an error like :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`:
+
+.. code-block:: python
+
+    from vllm import ModelRegistry
+
+    ModelRegistry.register_model("YourModelForCausalLM", "your_code:YourModelForCausalLM")
+
+.. important::
+    If your model is a multimodal model, make sure the model class implements the :class:`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
+    Read more about that :ref:`here <enabling_multimodal_inputs>`.
+
 If you are running api server with :code:`vllm serve <args>`, you can wrap the entrypoint with the following code:
 
 .. code-block:: python
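In the lazy form, register_model stores only the "module:ClassName" string, so the model module (and any CUDA-touching imports inside it) is not imported until vLLM actually needs the class. A simplified sketch of how such a reference can be resolved (an illustration only, not vLLM's actual registry code; "your_code" is the same placeholder used in the docs above):

import importlib


def resolve_model_reference(qualname: str) -> type:
    """Resolve a "module.path:ClassName" string at the last possible moment."""
    module_name, _, class_name = qualname.partition(":")
    # The import (and any CUDA initialization it triggers) happens here,
    # inside the engine process, not at registration time.
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# With the placeholder names from the documentation above:
# model_cls = resolve_model_reference("your_code:YourModelForCausalLM")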

find_cuda_init.py

Lines changed: 33 additions & 0 deletions
import importlib
import traceback
from typing import Callable
from unittest.mock import patch


def find_cuda_init(fn: Callable[[], object]) -> None:
    """
    Helper function to debug CUDA re-initialization errors.

    If `fn` initializes CUDA, prints the stack trace of how this happens.
    """
    from torch.cuda import _lazy_init

    stack = None

    def wrapper():
        nonlocal stack
        stack = traceback.extract_stack()
        return _lazy_init()

    with patch("torch.cuda._lazy_init", wrapper):
        fn()

    if stack is not None:
        print("==== CUDA Initialized ====")
        print("".join(traceback.format_list(stack)).strip())
        print("==========================")


if __name__ == "__main__":
    find_cuda_init(
        lambda: importlib.import_module("vllm.model_executor.models.llava"))
tests/conftest.py

Lines changed: 25 additions & 5 deletions
@@ -879,15 +879,16 @@ def num_gpus_available():
 
 
 temp_dir = tempfile.gettempdir()
-_dummy_path = os.path.join(temp_dir, "dummy_opt")
+_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
+_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
 
 
 @pytest.fixture
 def dummy_opt_path():
-    json_path = os.path.join(_dummy_path, "config.json")
-    if not os.path.exists(_dummy_path):
+    json_path = os.path.join(_dummy_opt_path, "config.json")
+    if not os.path.exists(_dummy_opt_path):
         snapshot_download(repo_id="facebook/opt-125m",
-                          local_dir=_dummy_path,
+                          local_dir=_dummy_opt_path,
                           ignore_patterns=[
                               "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                               "*.msgpack"
@@ -898,4 +899,23 @@ def dummy_opt_path():
         config["architectures"] = ["MyOPTForCausalLM"]
         with open(json_path, "w") as f:
             json.dump(config, f)
-    return _dummy_path
+    return _dummy_opt_path
+
+
+@pytest.fixture
+def dummy_llava_path():
+    json_path = os.path.join(_dummy_llava_path, "config.json")
+    if not os.path.exists(_dummy_llava_path):
+        snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
+                          local_dir=_dummy_llava_path,
+                          ignore_patterns=[
+                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
+                              "*.msgpack"
+                          ])
+        assert os.path.exists(json_path)
+        with open(json_path, "r") as f:
+            config = json.load(f)
+        config["architectures"] = ["MyLlava"]
+        with open(json_path, "w") as f:
+            json.dump(config, f)
+    return _dummy_llava_path
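Both fixtures download only the config and tokenizer files (every weight format is excluded via ignore_patterns) and then rewrite config.json so the checkpoint advertises the out-of-tree architecture, which is why the tests below load these paths with load_format="dummy". Since dummy_opt_path and dummy_llava_path differ only in repo id, directory, and architecture name, the shared logic could be expressed as a helper like the following sketch, reusing the os/json/snapshot_download imports already present in conftest.py (not part of this commit):

def _dummy_model_path(local_dir: str, repo_id: str, architecture: str) -> str:
    """Download a config-only snapshot and point it at an OOT architecture."""
    json_path = os.path.join(local_dir, "config.json")
    if not os.path.exists(local_dir):
        snapshot_download(repo_id=repo_id,
                          local_dir=local_dir,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        with open(json_path, "r") as f:
            config = json.load(f)
        config["architectures"] = [architecture]
        with open(json_path, "w") as f:
            json.dump(config, f)
    return local_dir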

tests/entrypoints/openai/test_audio.py

Lines changed: 3 additions & 1 deletion
@@ -21,7 +21,9 @@ def server():
         "--dtype",
         "bfloat16",
         "--max-model-len",
-        "4096",
+        "2048",
+        "--max-num-seqs",
+        "5",
         "--enforce-eager",
     ]
 
tests/entrypoints/openai/test_vision.py

Lines changed: 10 additions & 3 deletions
@@ -23,9 +23,16 @@
 @pytest.fixture(scope="module")
 def server():
     args = [
-        "--dtype", "bfloat16", "--max-model-len", "4096", "--max-num-seqs",
-        "5", "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt",
-        f"image={MAXIMUM_IMAGES}"
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "2048",
+        "--max-num-seqs",
+        "5",
+        "--enforce-eager",
+        "--trust-remote-code",
+        "--limit-mm-per-prompt",
+        f"image={MAXIMUM_IMAGES}",
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

tests/models/test_oot_registration.py

Lines changed: 38 additions & 0 deletions
@@ -3,6 +3,7 @@
 import pytest
 
 from vllm import LLM, SamplingParams
+from vllm.assets.image import ImageAsset
 
 from ..utils import fork_new_process_for_each_test
 
@@ -29,3 +30,40 @@ def test_oot_registration(dummy_opt_path):
         # make sure only the first token is generated
         rest = generated_text.replace(first_token, "")
         assert rest == ""
+
+
+image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
+
+
+@fork_new_process_for_each_test
+def test_oot_multimodal_registration(dummy_llava_path):
+    os.environ["VLLM_PLUGINS"] = "register_dummy_model"
+    prompts = [{
+        "prompt": "What's in the image?<image>",
+        "multi_modal_data": {
+            "image": image
+        },
+    }, {
+        "prompt": "Describe the image<image>",
+        "multi_modal_data": {
+            "image": image
+        },
+    }]
+
+    sampling_params = SamplingParams(temperature=0)
+    llm = LLM(model=dummy_llava_path,
+              load_format="dummy",
+              max_num_seqs=1,
+              trust_remote_code=True,
+              gpu_memory_utilization=0.98,
+              max_model_len=4096,
+              enforce_eager=True,
+              limit_mm_per_prompt={"image": 1})
+    first_token = llm.get_tokenizer().decode(0)
+    outputs = llm.generate(prompts, sampling_params)
+
+    for output in outputs:
+        generated_text = output.outputs[0].text
+        # make sure only the first token is generated
+        rest = generated_text.replace(first_token, "")
+        assert rest == ""
vllm_add_dummy_model/__init__.py

Lines changed: 8 additions & 20 deletions

@@ -1,26 +1,14 @@
-from typing import Optional
-
-import torch
-
 from vllm import ModelRegistry
-from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-
-
-class MyOPTForCausalLM(OPTForCausalLM):
-
-    def compute_logits(
-            self, hidden_states: torch.Tensor,
-            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
-        # this dummy model always predicts the first token
-        logits = super().compute_logits(hidden_states, sampling_metadata)
-        if logits is not None:
-            logits.zero_()
-            logits[:, 0] += 1.0
-        return logits
 
 
 def register():
-    # register our dummy model
+    # Test directly passing the model
+    from .my_opt import MyOPTForCausalLM
+
     if "MyOPTForCausalLM" not in ModelRegistry.get_supported_archs():
         ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)
+
+    # Test passing lazy model
+    if "MyLlava" not in ModelRegistry.get_supported_archs():
+        ModelRegistry.register_model("MyLlava",
+                                     "vllm_add_dummy_model.my_llava:MyLlava")
vllm_add_dummy_model/my_llava.py

Lines changed: 28 additions & 0 deletions

from typing import Optional

import torch

from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
                                              dummy_data_for_llava,
                                              get_max_llava_image_tokens,
                                              input_processor_for_llava)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class MyLlava(LlavaForConditionalGeneration):

    def compute_logits(
            self, hidden_states: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
        # this dummy model always predicts the first token
        logits = super().compute_logits(hidden_states, sampling_metadata)
        if logits is not None:
            logits.zero_()
            logits[:, 0] += 1.0
        return logits
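MyLlava subclasses LlavaForConditionalGeneration, so it already satisfies the SupportsMultiModal interface required by the documentation change above; the LLaVA registration decorators are re-applied so the subclass gets its own entries in the input and multimodal registries. A quick sanity check one could run, assuming the test plugin package is installed and that the supports_multimodal helper is available in vllm.model_executor.models.interfaces:

from vllm.model_executor.models.interfaces import supports_multimodal

from vllm_add_dummy_model.my_llava import MyLlava

# True if the class implements the SupportsMultiModal interface.
assert supports_multimodal(MyLlava)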
vllm_add_dummy_model/my_opt.py

Lines changed: 19 additions & 0 deletions

from typing import Optional

import torch

from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata


class MyOPTForCausalLM(OPTForCausalLM):

    def compute_logits(
            self, hidden_states: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
        # this dummy model always predicts the first token
        logits = super().compute_logits(hidden_states, sampling_metadata)
        if logits is not None:
            logits.zero_()
            logits[:, 0] += 1.0
        return logits

vllm/engine/arg_utils.py

Lines changed: 2 additions & 0 deletions
@@ -183,6 +183,8 @@ class EngineArgs:
     def __post_init__(self):
         if self.tokenizer is None:
             self.tokenizer = self.model
+        from vllm.plugins import load_general_plugins
+        load_general_plugins()
 
     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:

vllm/engine/llm_engine.py

Lines changed: 0 additions & 3 deletions
@@ -290,9 +290,6 @@ def __init__(
             model_config.mm_processor_kwargs,
         )
         # TODO(woosuk): Print more configs in debug mode.
-        from vllm.plugins import load_general_plugins
-        load_general_plugins()
-
         self.model_config = model_config
         self.cache_config = cache_config
         self.lora_config = lora_config
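Together, these last two changes move load_general_plugins() from LLMEngine.__init__ into EngineArgs.__post_init__, so out-of-tree registrations run as soon as the arguments object is constructed, before the engine configuration inspects the checkpoint's config["architectures"]. A small sketch of the resulting ordering, assuming the dummy plugin package from the tests is installed:

import os

os.environ["VLLM_PLUGINS"] = "register_dummy_model"  # select which plugins to load

from vllm import EngineArgs, ModelRegistry

# Constructing EngineArgs now triggers load_general_plugins() in __post_init__,
# so the plugin's register() has run before any engine or model is created.
engine_args = EngineArgs(model="facebook/opt-125m")
assert "MyOPTForCausalLM" in ModelRegistry.get_supported_archs()
assert "MyLlava" in ModelRegistry.get_supported_archs()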
