from typing import List

import pytest

from vllm.lora.models import LoRAModel
from vllm.lora.utils import get_adapter_absolute_path
from vllm.model_executor.models.llama import LlamaForCausalLM

# Provide both an absolute path and a Hugging Face LoRA id via fixtures.
lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]


@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
    lora_name = request.getfixturevalue(lora_fixture_name)
    supported_lora_modules = LlamaForCausalLM.supported_lora_modules
    packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
    embedding_modules = LlamaForCausalLM.embedding_modules
    embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
    expected_lora_modules: List[str] = []
    # Packed modules (e.g. qkv_proj) are expanded into the per-projection
    # names that actually appear in the adapter checkpoint.
    for module in supported_lora_modules:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
            expected_lora_modules.append(module)

    # Resolve the fixture value to a local directory; for a Hugging Face
    # repo id this downloads the adapter checkpoint first.
    lora_path = get_adapter_absolute_path(lora_name)

    # LoRA loading should work for either an absolute path or a Hugging Face id.
    lora_model = LoRAModel.from_local_checkpoint(
        lora_path,
        expected_lora_modules,
        lora_model_id=1,
        device="cpu",
        embedding_modules=embedding_modules,
        embedding_padding_modules=embed_padding_modules)

    # Assertions to ensure the model is loaded correctly
    assert lora_model is not None, "LoRAModel is not loaded correctly"
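
For reference, the two parametrized fixtures are expected to come from the suite's shared conftest.py. A minimal sketch of what they could look like, assuming the SQL LoRA test adapter is published on the Hugging Face Hub (the repo id below is illustrative) and that huggingface_hub is installed:

import pytest
from huggingface_hub import snapshot_download


@pytest.fixture(scope="session")
def sql_lora_huggingface_id():
    # Hypothetical repo id of the test adapter on the Hugging Face Hub.
    return "yard1/llama-2-7b-sql-lora-test"


@pytest.fixture(scope="session")
def sql_lora_files(sql_lora_huggingface_id):
    # Download the adapter once per session and return the local snapshot
    # directory, so the "absolute path" variant of the test reads from disk.
    return snapshot_download(repo_id=sql_lora_huggingface_id)

With fixtures of this shape, the parametrized test exercises both code paths: get_adapter_absolute_path returns the local snapshot directory unchanged for sql_lora_files, and downloads the adapter first for sql_lora_huggingface_id.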