
Commit 2ba0e68

Jeffwan authored and Yard1 committed
[Core] Support dynamically loading Lora adapter from HuggingFace (vllm-project#6234)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent b3e41df commit 2ba0e68

11 files changed, +201 -18 lines

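In practical terms, a LoRARequest can now point at a Hugging Face repo id instead of a pre-downloaded local directory; the adapter is fetched the first time it is loaded. A minimal offline-inference sketch of the new behavior, assuming the usual vllm.LLM API (the base-model name is illustrative; the adapter repo id is the one used in the tests below):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Base model with LoRA support enabled.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)

# The adapter is given as a Hugging Face repo id; it is resolved and
# downloaded at load time instead of requiring snapshot_download() up front.
sql_lora = LoRARequest(
    lora_name="sql_adapter",
    lora_int_id=1,
    lora_path="yard1/llama-2-7b-sql-lora-test",
)

outputs = llm.generate(
    ["Translate to SQL: how many users signed up last week?"],
    SamplingParams(max_tokens=64),
    lora_request=sql_lora,
)
print(outputs[0].outputs[0].text)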
tests/core/test_scheduler.py
Lines changed: 2 additions & 2 deletions

@@ -462,7 +462,7 @@ def test_prefill_schedule_max_lora():
         lora_request=LoRARequest(
             lora_name=str(i),
             lora_int_id=i + 1,
-            lora_local_path="abc"))
+            lora_path="abc"))
     waiting.append(seq_group)
     # Add two more requests to verify lora is prioritized.
     # 0: Lora, 1: Lora, 2: regular, 3: regular

@@ -760,7 +760,7 @@ def test_schedule_swapped_max_loras():
         lora_request=LoRARequest(
             lora_name=str(i),
             lora_int_id=i + 1,
-            lora_local_path="abc"))
+            lora_path="abc"))
     scheduler._allocate_and_set_running(seq_group)
     append_new_token_seq_group(60, seq_group, 1)
     scheduler._swap_out(seq_group, blocks_to_swap_out)

tests/lora/conftest.py
Lines changed: 8 additions & 2 deletions

@@ -159,8 +159,14 @@ def dummy_model_gate_up() -> nn.Module:


 @pytest.fixture(scope="session")
-def sql_lora_files():
-    return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
+def sql_lora_huggingface_id():
+    # huggingface repo id is used to test lora runtime downloading.
+    return "yard1/llama-2-7b-sql-lora-test"
+
+
+@pytest.fixture(scope="session")
+def sql_lora_files(sql_lora_huggingface_id):
+    return snapshot_download(repo_id=sql_lora_huggingface_id)


 @pytest.fixture(scope="session")

tests/lora/test_long_context.py
Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@ def _create_lora_request(lora_id, long_context_infos):
     context_len = long_context_infos[lora_id]["context_length"]
     scaling_factor = context_len_to_scaling_factor[context_len]
     return LoRARequest(context_len, lora_id,
-                       long_context_infos[lora_id]["lora"],
+                       long_context_infos[lora_id]["lora"], None,
                        4096 * scaling_factor)


tests/lora/test_lora_huggingface.py
Lines changed: 39 additions & 0 deletions (new file)

@@ -0,0 +1,39 @@
from typing import List

import pytest

from vllm.lora.models import LoRAModel
from vllm.lora.utils import get_adapter_absolute_path
from vllm.model_executor.models.llama import LlamaForCausalLM

# Provide an absolute path and a huggingface lora id
lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]


@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
    lora_name = request.getfixturevalue(lora_fixture_name)
    supported_lora_modules = LlamaForCausalLM.supported_lora_modules
    packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
    embedding_modules = LlamaForCausalLM.embedding_modules
    embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
    expected_lora_modules: List[str] = []
    for module in supported_lora_modules:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
            expected_lora_modules.append(module)

    lora_path = get_adapter_absolute_path(lora_name)

    # lora loading should work for either an absolute path or a huggingface id.
    lora_model = LoRAModel.from_local_checkpoint(
        lora_path,
        expected_lora_modules,
        lora_model_id=1,
        device="cpu",
        embedding_modules=embedding_modules,
        embedding_padding_modules=embed_padding_modules)

    # Assertions to ensure the model is loaded correctly
    assert lora_model is not None, "LoRAModel is not loaded correctly"

tests/lora/test_utils.py
Lines changed: 56 additions & 1 deletion

@@ -1,9 +1,12 @@
 from collections import OrderedDict
+from unittest.mock import patch

 import pytest
+from huggingface_hub.utils import HfHubHTTPError
 from torch import nn

-from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
+from vllm.lora.utils import (get_adapter_absolute_path,
+                             parse_fine_tuned_lora_name, replace_submodule)
 from vllm.utils import LRUCache


@@ -182,3 +185,55 @@ def test_lru_cache():
     assert 2 in cache
     assert 4 in cache
     assert 6 in cache
+
+
+# Unit tests for get_adapter_absolute_path
+@patch('os.path.isabs')
+def test_get_adapter_absolute_path_absolute(mock_isabs):
+    path = '/absolute/path/to/lora'
+    mock_isabs.return_value = True
+    assert get_adapter_absolute_path(path) == path
+
+
+@patch('os.path.expanduser')
+def test_get_adapter_absolute_path_expanduser(mock_expanduser):
+    # Path with ~ that needs to be expanded
+    path = '~/relative/path/to/lora'
+    absolute_path = '/home/user/relative/path/to/lora'
+    mock_expanduser.return_value = absolute_path
+    assert get_adapter_absolute_path(path) == absolute_path
+
+
+@patch('os.path.exists')
+@patch('os.path.abspath')
+def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist):
+    # Relative path that exists locally
+    path = 'relative/path/to/lora'
+    absolute_path = '/absolute/path/to/lora'
+    mock_exist.return_value = True
+    mock_abspath.return_value = absolute_path
+    assert get_adapter_absolute_path(path) == absolute_path
+
+
+@patch('huggingface_hub.snapshot_download')
+@patch('os.path.exists')
+def test_get_adapter_absolute_path_huggingface(mock_exist,
+                                               mock_snapshot_download):
+    # Hugging Face model identifier
+    path = 'org/repo'
+    absolute_path = '/mock/snapshot/path'
+    mock_exist.return_value = False
+    mock_snapshot_download.return_value = absolute_path
+    assert get_adapter_absolute_path(path) == absolute_path
+
+
+@patch('huggingface_hub.snapshot_download')
+@patch('os.path.exists')
def test_get_adapter_absolute_path_huggingface_error(mock_exist,
+                                                     mock_snapshot_download):
+    # Hugging Face model identifier with download error
+    path = 'org/repo'
+    mock_exist.return_value = False
+    mock_snapshot_download.side_effect = HfHubHTTPError(
+        "failed to query model info")
+    assert get_adapter_absolute_path(path) == path

vllm/entrypoints/openai/serving_engine.py
Lines changed: 2 additions & 2 deletions

@@ -43,7 +43,7 @@ class PromptAdapterPath:
 @dataclass
 class LoRAModulePath:
     name: str
-    local_path: str
+    path: str


 AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,

@@ -83,7 +83,7 @@ def __init__(
             LoRARequest(
                 lora_name=lora.name,
                 lora_int_id=i,
-                lora_local_path=lora.local_path,
+                lora_path=lora.path,
             ) for i, lora in enumerate(lora_modules, start=1)
         ]


vllm/lora/request.py
Lines changed: 39 additions & 3 deletions

@@ -1,4 +1,5 @@
-from dataclasses import dataclass
+import warnings
+from dataclasses import dataclass, field
 from typing import Optional

 from vllm.adapter_commons.request import AdapterRequest

@@ -20,10 +21,25 @@ class LoRARequest(AdapterRequest):

     lora_name: str
     lora_int_id: int
-    lora_local_path: str
+    lora_path: str = ""
+    lora_local_path: Optional[str] = field(default=None, repr=False)
     long_lora_max_len: Optional[int] = None
     __hash__ = AdapterRequest.__hash__

+    def __post_init__(self):
+        if 'lora_local_path' in self.__dict__:
+            warnings.warn(
+                "The 'lora_local_path' attribute is deprecated "
+                "and will be removed in a future version. "
+                "Please use 'lora_path' instead.",
+                DeprecationWarning,
+                stacklevel=2)
+            if not self.lora_path:
+                self.lora_path = self.lora_local_path or ""
+
+        # Ensure lora_path is not empty
+        assert self.lora_path, "lora_path cannot be empty"
+
     @property
     def adapter_id(self):
         return self.lora_int_id

@@ -32,6 +48,26 @@ def adapter_id(self):
     def name(self):
         return self.lora_name

+    @property
+    def path(self):
+        return self.lora_path
+
     @property
     def local_path(self):
-        return self.lora_local_path
+        warnings.warn(
+            "The 'local_path' attribute is deprecated "
+            "and will be removed in a future version. "
+            "Please use 'path' instead.",
+            DeprecationWarning,
+            stacklevel=2)
+        return self.lora_path
+
+    @local_path.setter
+    def local_path(self, value):
+        warnings.warn(
+            "The 'local_path' attribute is deprecated "
+            "and will be removed in a future version. "
+            "Please use 'path' instead.",
+            DeprecationWarning,
+            stacklevel=2)
+        self.lora_path = value

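Backward compatibility is handled inside the dataclass itself. A minimal sketch of how the reworked request behaves, with field and property names taken from the diff above (the literal values are illustrative):

from vllm.lora.request import LoRARequest

# New-style: lora_path may be a local directory or a Hugging Face repo id.
req = LoRARequest(lora_name="sql_adapter", lora_int_id=1,
                  lora_path="yard1/llama-2-7b-sql-lora-test")
assert req.path == "yard1/llama-2-7b-sql-lora-test"

# The old keyword is still accepted; __post_init__ copies it into
# lora_path and a DeprecationWarning points callers at the new name.
legacy = LoRARequest(lora_name="sql_adapter", lora_int_id=2,
                     lora_local_path="/tmp/sql_adapter")
assert legacy.lora_path == "/tmp/sql_adapter"
assert legacy.local_path == "/tmp/sql_adapter"  # deprecated accessor, warns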
vllm/lora/utils.py
Lines changed: 47 additions & 0 deletions

@@ -1,5 +1,9 @@
+import os
 from typing import List, Optional, Set, Tuple, Type

+import huggingface_hub
+from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
+                                   HFValidationError, RepositoryNotFoundError)
 from torch import nn
 from transformers import PretrainedConfig

@@ -105,3 +109,46 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
         return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"

     raise ValueError(f"{name} is unsupported LoRA weight")
+
+
+def get_adapter_absolute_path(lora_path: str) -> str:
+    """
+    Resolves the given lora_path to an absolute local path.
+
+    If the lora_path is identified as a Hugging Face model identifier,
+    it will download the model and return the local snapshot path.
+    Otherwise, it treats the lora_path as a local file path and
+    converts it to an absolute path.
+
+    Parameters:
+    lora_path (str): The path to the lora model, which can be an absolute
+        path, a relative path, or a Hugging Face model identifier.
+
+    Returns:
+    str: The resolved absolute local path to the lora model.
+    """
+
+    # If the path is absolute, return it whether or not it exists.
+    if os.path.isabs(lora_path):
+        return lora_path
+
+    # If the path starts with ~, expand the user home directory.
+    if lora_path.startswith('~'):
+        return os.path.expanduser(lora_path)
+
+    # Check if the relative path exists locally.
+    if os.path.exists(lora_path):
+        return os.path.abspath(lora_path)
+
+    # If the path does not exist locally, assume it's a Hugging Face repo.
+    try:
+        local_snapshot_path = huggingface_hub.snapshot_download(
+            repo_id=lora_path)
+    except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
+            HFValidationError):
+        # Handle errors that may occur during the download.
+        # Return the original path instead of raising an error here.
+        logger.exception("Error downloading the HuggingFace model")
+        return lora_path
+
+    return local_snapshot_path

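The resolution order implemented above (absolute path, then ~ expansion, then an existing relative path, then a Hugging Face download that falls back to the original string on failure) can be exercised directly. A small usage sketch; the repo id is the one from the tests, and the download step needs network access:

from vllm.lora.utils import get_adapter_absolute_path

# Absolute paths are returned as-is; ~ and existing relative paths are
# expanded to absolute local paths.
print(get_adapter_absolute_path("/srv/adapters/sql_lora"))
print(get_adapter_absolute_path("~/adapters/sql_lora"))

# Anything that does not exist locally is treated as a Hugging Face repo id
# and resolved to the local snapshot directory via snapshot_download().
local_dir = get_adapter_absolute_path("yard1/llama-2-7b-sql-lora-test")
print(local_dir)  # a directory under the local Hugging Face hub cache

# If the download fails (bad repo id, no network), the original string is
# returned unchanged and the subsequent checkpoint load will raise instead.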
vllm/lora/worker_manager.py
Lines changed: 4 additions & 3 deletions

@@ -13,6 +13,7 @@
 from vllm.lora.models import (LoRAModel, LoRAModelManager,
                               LRUCacheLoRAModelManager, create_lora_manager)
 from vllm.lora.request import LoRARequest
+from vllm.lora.utils import get_adapter_absolute_path

 logger = init_logger(__name__)

@@ -89,8 +90,9 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
                        packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)
+           lora_path = get_adapter_absolute_path(lora_request.lora_path)
            lora = self._lora_model_cls.from_local_checkpoint(
-               lora_request.lora_local_path,
+               lora_path,
                expected_lora_modules,
                max_position_embeddings=self.max_position_embeddings,
                lora_model_id=lora_request.lora_int_id,

@@ -102,8 +104,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
                embedding_padding_modules=self.embedding_padding_modules,
            )
        except Exception as e:
-           raise RuntimeError(
-               f"Loading lora {lora_request.lora_local_path} failed") from e
+           raise RuntimeError(f"Loading lora {lora_path} failed") from e
        if lora.rank > self.lora_config.max_lora_rank:
            raise ValueError(
                f"LoRA rank {lora.rank} is greater than max_lora_rank "

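Putting the pieces together, the worker now resolves the request path first and only then loads the checkpoint. A condensed sketch of that flow, with the surrounding manager class omitted and only the keyword arguments used in this commit's tests:

from vllm.lora.models import LoRAModel
from vllm.lora.utils import get_adapter_absolute_path


def load_adapter_sketch(lora_request, expected_lora_modules,
                        embedding_modules, embedding_padding_modules,
                        device="cpu"):
    # Works for a local directory or a Hugging Face repo id: the id is
    # downloaded and replaced by its local snapshot directory here.
    lora_path = get_adapter_absolute_path(lora_request.lora_path)
    try:
        return LoRAModel.from_local_checkpoint(
            lora_path,
            expected_lora_modules,
            lora_model_id=lora_request.lora_int_id,
            device=device,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules)
    except Exception as e:
        # The error message now reports the resolved path.
        raise RuntimeError(f"Loading lora {lora_path} failed") from e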
vllm/transformers_utils/tokenizer.py
Lines changed: 2 additions & 3 deletions

@@ -137,14 +137,13 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
    if lora_request is None:
        return None
    try:
-       tokenizer = get_tokenizer(lora_request.lora_local_path, *args,
-                                 **kwargs)
+       tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs)
    except OSError as e:
        # No tokenizer was found in the LoRA folder,
        # use base model tokenizer
        logger.warning(
            "No tokenizer found in %s, using base model tokenizer instead. "
-           "(Exception: %s)", lora_request.lora_local_path, e)
+           "(Exception: %s)", lora_request.lora_path, e)
        tokenizer = None
    return tokenizer

vllm/worker/model_runner.py
Lines changed: 1 addition & 1 deletion

@@ -691,7 +691,7 @@ def profile_run(self) -> None:
            dummy_lora_request = LoRARequest(
                lora_name=f"warmup_{lora_id}",
                lora_int_id=lora_id,
-               lora_local_path="/not/a/real/path",
+               lora_path="/not/a/real/path",
            )
            self.lora_manager.add_dummy_lora(dummy_lora_request,
                                             rank=LORA_WARMUP_RANK)
