
Commit a16ec5d

Add test for loading lora from huggingface
1 parent a6a30b9 commit a16ec5d

File tree

2 files changed: +45 -0 lines changed


tests/lora/conftest.py

Lines changed: 6 additions & 0 deletions
@@ -163,6 +163,12 @@ def sql_lora_files():
     return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")


+@pytest.fixture(scope="session")
+def sql_lora_huggingface_id():
+    # huggingface repo id is used to test lora runtime downloading.
+    return "yard1/llama-2-7b-sql-lora-test"
+
+
 @pytest.fixture(scope="session")
 def mixtral_lora_files():
     # Note: this module has incorrect adapter_config.json to test
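The new fixture returns the raw repo id ("yard1/llama-2-7b-sql-lora-test") rather than a local directory, so the test added below exercises runtime downloading through get_adapter_absolute_path. A minimal sketch of that kind of resolution, assuming huggingface_hub.snapshot_download as the download mechanism; this is an illustration of the expected behavior, not vLLM's actual implementation:

import os

from huggingface_hub import snapshot_download


def resolve_adapter_path(name_or_path: str) -> str:
    # If the value already points at a local directory, use it directly.
    local = os.path.abspath(os.path.expanduser(name_or_path))
    if os.path.isdir(local):
        return local
    # Otherwise treat the value as a Hugging Face repo id and download a
    # snapshot, returning the path of the local cache directory.
    return snapshot_download(repo_id=name_or_path)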

tests/lora/test_lora_huggingface.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+from typing import List
+
+import pytest
+
+from vllm.lora.models import LoRAModel
+from vllm.lora.utils import get_adapter_absolute_path
+from vllm.model_executor.models.llama import LlamaForCausalLM
+
+# Provide absolute paths and huggingface lora ids
+lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]
+
+
+@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
+def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
+    lora_name = request.getfixturevalue(lora_fixture_name)
+    supported_lora_modules = LlamaForCausalLM.supported_lora_modules
+    packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
+    embedding_modules = LlamaForCausalLM.embedding_modules
+    embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
+    expected_lora_modules: List[str] = []
+    for module in supported_lora_modules:
+        if module in packed_modules_mapping:
+            expected_lora_modules.extend(packed_modules_mapping[module])
+        else:
+            expected_lora_modules.append(module)
+
+    lora_path = get_adapter_absolute_path(lora_name)
+
+    # LoRA loading should work for either an absolute path or a huggingface id.
+    lora_model = LoRAModel.from_local_checkpoint(
+        lora_path,
+        expected_lora_modules,
+        lora_model_id=1,
+        device="cpu",
+        embedding_modules=embedding_modules,
+        embedding_padding_modules=embed_padding_modules)
+
+    # Assertions to ensure the model is loaded correctly
+    assert lora_model is not None, "LoRAModel is not loaded correctly"
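From a user's perspective, the behavior this test backs is handing vLLM a Hugging Face repo id instead of a pre-downloaded adapter directory. A hedged end-to-end sketch follows; the exact LoRARequest argument names vary across vLLM versions, and passing a repo id directly is assumed here rather than confirmed by this diff, so treat it as illustrative, not the definitive API:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
outputs = llm.generate(
    ["Write a SQL query that lists all users."],
    SamplingParams(max_tokens=64),
    # The third argument is the adapter location; here a repo id that is
    # assumed to be resolved and downloaded at runtime instead of a local path.
    lora_request=LoRARequest("sql_adapter", 1, "yard1/llama-2-7b-sql-lora-test"),
)
print(outputs[0].outputs[0].text)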
