Skip to content

Commit acc836a

Browse files
committed
Fix lora bug and modify minicpmv lora tests
1 parent bbfd3e0 commit acc836a

File tree

3 files changed

+108
-10
lines changed

3 files changed

+108
-10
lines changed

tests/lora/test_minicpmv.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
2727
sampling_params = vllm.SamplingParams(
2828
temperature=0,
29-
max_tokens=256,
29+
max_tokens=5,
3030
stop_token_ids=[128001, 128009], # eos_id, eot_id
3131
)
3232

@@ -65,7 +65,7 @@ def test_minicpmv_lora(minicpmv_lora_files):
6565

6666
output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
6767
for i in range(len(EXPECTED_OUTPUT)):
68-
assert output1[i] == EXPECTED_OUTPUT[i]
68+
assert EXPECTED_OUTPUT[i].startswith(output1[i])
6969
output2 = do_sample(llm, minicpmv_lora_files, lora_id=2)
7070
for i in range(len(EXPECTED_OUTPUT)):
71-
assert output2[i] == EXPECTED_OUTPUT[i]
71+
assert EXPECTED_OUTPUT[i].startswith(output2[i])

tests/lora/test_minicpmv_tp.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from typing import List
2+
3+
import pytest
4+
5+
import vllm
6+
from vllm.assets.image import ImageAsset
7+
from vllm.lora.request import LoRARequest
8+
9+
from ..utils import multi_gpu_test
10+
11+
# HF model id of the base multimodal model the LoRA adapter was trained on.
MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"

# Llama-3-style chat template with a single image placeholder; the model's
# processor expands (<image>./</image>) into image tokens.
PROMPT_TEMPLATE = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "(<image>./</image>)\nWhat is in the image?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n")

# Test fixtures: one prompt per asset, answered in the same order below.
IMAGE_ASSETS = [
    ImageAsset("stop_sign"),
    ImageAsset("cherry_blossom"),
]

# After fine-tuning with LoRA, all generated content should begin with `A`.
# Tests only check that these strings start with the (truncated) generations,
# since sampling uses a small max_tokens budget.
EXPECTED_OUTPUT = [
    "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.",  # noqa: E501
    "A pink cherry blossom tree with a blue sky in the background.",
]
28+
29+
30+
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
    """Generate one completion per image asset, optionally through a LoRA.

    Args:
        llm: Engine to run generation with.
        lora_path: Filesystem path to the LoRA adapter weights.
        lora_id: Adapter id; a falsy value (0) disables LoRA entirely.

    Returns:
        Stripped generated text for each asset, in ``IMAGE_ASSETS`` order.
    """
    # Greedy decoding with a tiny token budget: callers only assert that the
    # expected answer starts with what we produce.
    params = vllm.SamplingParams(
        temperature=0,
        max_tokens=5,
        stop_token_ids=[128001, 128009],  # eos_id, eot_id
    )

    requests = [{
        "prompt": PROMPT_TEMPLATE,
        "multi_modal_data": {
            "image": asset.pil_image
        },
    } for asset in IMAGE_ASSETS]

    lora_request = (LoRARequest(str(lora_id), lora_id, lora_path)
                    if lora_id else None)
    results = llm.generate(requests, params, lora_request=lora_request)

    # Collect (and echo, for debugging) each generation.
    texts: List[str] = []
    for result in results:
        prompt = result.prompt
        text = result.outputs[0].text.strip()
        texts.append(text)
        print(f"Prompt: {prompt!r}, Generated text: {text!r}")
    return texts
58+
59+
60+
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
    """LoRA generation under tensor parallelism 2 matches expected prefixes."""
    engine = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=2,
        max_loras=4,
        max_lora_rank=64,
        tensor_parallel_size=2,
        trust_remote_code=True,
        fully_sharded_loras=fully_sharded,
    )

    generated = do_sample(engine, minicpmv_lora_files, lora_id=1)

    # With max_tokens=5 only a prefix is produced, so compare via startswith.
    for i, expected in enumerate(EXPECTED_OUTPUT):
        assert expected.startswith(generated[i])
78+
79+
80+
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
    """LoRA generation under tensor parallelism 4 matches expected prefixes."""
    engine = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=2,
        max_loras=4,
        max_lora_rank=64,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=fully_sharded,
    )
    generated = do_sample(engine, minicpmv_lora_files, lora_id=1)
    # With max_tokens=5 only a prefix is produced, so compare via startswith.
    for i, expected in enumerate(EXPECTED_OUTPUT):
        assert expected.startswith(generated[i])

vllm/lora/models.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -456,13 +456,7 @@ def _create_lora_modules(self):
456456
self.model, module_name,
457457
from_layer(module, self.lora_slots, self.lora_config,
458458
packed_moduled_lst, self.model.config))
459-
# In some models, especially multimodal ones, layers with the same
460-
# name may have different types, such as nn.Linear and
461-
# ReplicatedLinear. The nn.Linear layers cannot be replaced with
462-
# LoRA layers, leading to assertion error. The following check
463-
# aims to prevent this error
464-
if not isinstance(new_module, BaseLayerWithLoRA):
465-
continue
459+
466460
# LinearScalingRotaryEmbeddingWithLora is used to handle
467461
# long context lora. Register relevant metadata.
468462
if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
@@ -480,6 +474,15 @@ def _create_lora_modules(self):
480474
module, self.lora_slots,
481475
self.lora_config,
482476
self.model.config))
477+
478+
# In some models, especially multimodal ones, layers with the same
479+
# name may have different types, such as nn.Linear and
480+
# ReplicatedLinear. The nn.Linear layers cannot be replaced with
481+
# LoRA layers, leading to assertion error. The following check
482+
# aims to prevent this error
483+
if self.supports_mm and not isinstance(new_module,
484+
BaseLayerWithLoRA):
485+
continue
483486
self.register_module(module_name, new_module)
484487
self._register_packed_modules(module_name)
485488
# All lora layers share the same punica_wrapper based on reference.

0 commit comments

Comments
 (0)