
Commit d17c847

jeejeelee and simon-mo authored
[Bugfix] Fix LoRA loading check (#4138)
Co-authored-by: simon-mo <simon.mo@hey.com>
1 parent: a134ef6 · commit: d17c847

File tree

3 files changed: +29 -3 lines changed

tests/lora/conftest.py
tests/lora/test_lora_checkpoints.py
vllm/lora/models.py

tests/lora/conftest.py

Lines changed: 6 additions & 0 deletions
@@ -143,6 +143,12 @@ def baichuan_lora_files():
     return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")
 
 
+@pytest.fixture(scope="session")
+def baichuan_zero_lora_files():
+    # all the lora_B weights are initialized to zero.
+    return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")
+
+
 @pytest.fixture(scope="session")
 def tinyllama_lora_files():
     return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
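
For context, a LoRA checkpoint whose lora_B weights are all zero can be produced with peft, which leaves lora_B zero-initialized by default. The following is a minimal sketch; the model id, rank, and target modules are illustrative assumptions, not values taken from jeeejeee/baichuan7b-zero-init:

    # Sketch only (assumed workflow): peft initializes lora_B to zero by default,
    # so saving an untrained adapter yields a checkpoint whose lora_B weights
    # are all zero. Model id, rank, and target_modules are illustrative.
    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    base = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan-7B",
                                                trust_remote_code=True)
    lora_config = LoraConfig(r=8, lora_alpha=16,
                             target_modules=["W_pack", "o_proj"])
    model = get_peft_model(base, lora_config)
    # Writes adapter_config.json plus the (zero-initialized) adapter weights.
    model.save_pretrained("baichuan7b-zero-init")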

tests/lora/test_lora_checkpoints.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,16 @@
 from vllm.lora.models import LoRAModel
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
 
+lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"]
 
-@pytest.mark.parametrize("lora_name", ["baichuan7B", "chatglm3-6b"])
-def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
+
+@pytest.mark.parametrize("lora_name", lora_lst)
+def test_load_checkpoints(
+    lora_name,
+    baichuan_lora_files,
+    baichuan_zero_lora_files,
+    chatglm3_lora_files,
+):
     supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
     packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
     embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
@@ -26,6 +33,17 @@ def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
             device="cpu",
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
+    elif lora_name == "baichuan7B-zero":
+        # Test that the target_modules contain a prefix
+        # such as "model.layers.0.self_attn.W_pack", and
+        # the test should pass.
+        LoRAModel.from_local_checkpoint(
+            baichuan_zero_lora_files,
+            expected_lora_modules,
+            lora_model_id=1,
+            device="cpu",
+            embedding_modules=embedding_modules,
+            embedding_padding_modules=embed_padding_modules)
     else:
         # For the baichuan7B model, load chatglm3-6b's LoRA,
         # and the test should raise the following error.
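
The new "baichuan7B-zero" case covers checkpoints whose adapter_config.json lists fully prefixed target modules. As a rough illustration (the module lists below are assumed, not copied from either test checkpoint), the pre-fix check rejected such names:

    # Illustrative only: before this change the loader compared each target
    # module name directly against expected_lora_modules, so fully prefixed
    # names were always flagged as unexpected and the checkpoint was rejected.
    expected_lora_modules = ["W_pack", "o_proj", "gate_up_proj", "down_proj"]

    leaf_names = ["W_pack", "o_proj"]                     # accepted before the fix
    prefixed_names = ["model.layers.0.self_attn.W_pack",  # rejected before the fix
                      "model.layers.0.mlp.down_proj"]

    print([m for m in leaf_names if m not in expected_lora_modules])      # []
    print([m for m in prefixed_names if m not in expected_lora_modules])  # both flagged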

vllm/lora/models.py

Lines changed: 3 additions & 1 deletion
@@ -212,7 +212,9 @@ def from_local_checkpoint(
         target_modules = config["target_modules"]
         unexpected_modules = []
         for module in target_modules:
-            if module not in expected_lora_modules:
+            # Compatible with more modules, such as: layers.11.self_attn.k_proj
+            part_name = module.split(".")[-1]
+            if part_name not in expected_lora_modules:
                 unexpected_modules.append(module)
         # loaded lora's target modules must be a subset of expected_lora_modules
         if unexpected_modules:
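
A minimal standalone sketch of the check after this change (module names are illustrative): only the last dotted component of each target module is compared against expected_lora_modules, so prefixed entries load while genuinely unsupported modules are still reported.

    # Minimal sketch of the updated subset check; module names are illustrative.
    def find_unexpected_modules(target_modules, expected_lora_modules):
        unexpected_modules = []
        for module in target_modules:
            # "model.layers.0.self_attn.W_pack" -> "W_pack"
            part_name = module.split(".")[-1]
            if part_name not in expected_lora_modules:
                unexpected_modules.append(module)
        return unexpected_modules

    expected = ["W_pack", "o_proj", "gate_up_proj", "down_proj"]
    print(find_unexpected_modules(
        ["model.layers.0.self_attn.W_pack", "layers.11.self_attn.k_proj"],
        expected))
    # -> ['layers.11.self_attn.k_proj']; only truly unsupported modules remain.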
