
Commit baee950

Jeffwankwang1012 authored and committed
[Core] Support Lora lineage and base model metadata management (vllm-project#6315)
1 parent 701f668 commit baee950

15 files changed (+337, -45 lines)

docs/source/models/lora.rst

Lines changed: 64 additions & 0 deletions
@@ -159,3 +159,67 @@ Example request to unload a LoRA adapter:
        -d '{
            "lora_name": "sql_adapter"
        }'


New format for `--lora-modules`
-------------------------------
In the previous version, users would provide LoRA modules via the following format, either as a key-value pair or in JSON format. For example:

.. code-block:: bash

    --lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/

This would only include the `name` and `path` for each LoRA module, but did not provide a way to specify a `base_model_name`.
174+
Now, you can specify a `base_model_name` alongside the `name` and `path` using JSON format. For example:

.. code-block:: bash

    --lora-modules '{"name": "sql-lora", "path": "/path/to/lora", "base_model_name": "meta-llama/Llama-2-7b"}'

For backward compatibility, you can still use the old key-value format (name=path), but the `base_model_name` will remain unspecified in that case.
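As a concrete illustration (the adapter names and paths below are placeholders, not real artifacts), the JSON value can be assembled programmatically and passed to the OpenAI-compatible server entry point, mirroring how the tests in this commit launch the server:

.. code-block:: python

    import json
    import subprocess

    # New-format module spec: carries the base model so lineage can be reported.
    sql_lora = {
        "name": "sql-lora",
        "path": "/path/to/lora",
        "base_model_name": "meta-llama/Llama-2-7b",
    }

    # Launch the server with one old-format and one new-format LoRA module.
    subprocess.run([
        "python", "-m", "vllm.entrypoints.openai.api_server",
        "--model", "meta-llama/Llama-2-7b-hf",
        "--enable-lora",
        "--lora-modules",
        "legacy-lora=/path/to/legacy-lora",  # old key=value format, base_model_name unset
        json.dumps(sql_lora),                # new JSON format
    ])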
Lora model lineage in model card
--------------------------------

The new format of `--lora-modules` is mainly to support the display of parent model information in the model card. Here's how the `/v1/models` response reflects this:

- The `parent` field of the LoRA model `sql-lora` now links to its base model `meta-llama/Llama-2-7b-hf`. This correctly reflects the hierarchical relationship between the base model and the LoRA adapter.
- The `root` field points to the artifact location of the LoRA adapter.
.. code-block:: bash

    $ curl http://localhost:8000/v1/models

    {
        "object": "list",
        "data": [
            {
                "id": "meta-llama/Llama-2-7b-hf",
                "object": "model",
                "created": 1715644056,
                "owned_by": "vllm",
                "root": "~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9/",
                "parent": null,
                "permission": [
                    {
                    .....
                    }
                ]
            },
            {
                "id": "sql-lora",
                "object": "model",
                "created": 1715644056,
                "owned_by": "vllm",
                "root": "~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/",
                "parent": "meta-llama/Llama-2-7b-hf",
                "permission": [
                    {
                    ....
                    }
                ]
            }
        ]
    }
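The same lineage information can be read programmatically with the official `openai` client, as the lineage test added in this commit does; the sketch below assumes the server above is running locally and that the extra `root` and `parent` fields vLLM returns are accessible on the client's model objects:

.. code-block:: python

    from openai import OpenAI

    # Point the official client at the locally running vLLM OpenAI-compatible server.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

    models = client.models.list().data
    base_model = models[0]    # the served base model is listed first
    lora_models = models[1:]  # registered LoRA adapters follow

    # The base model has no parent; each LoRA entry's parent is its base model id
    # and its root is the adapter artifact location.
    print(base_model.id, base_model.root, base_model.parent)
    for lora in lora_models:
        print(lora.id, lora.root, lora.parent)
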
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
import json
import unittest

from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from vllm.utils import FlexibleArgumentParser

LORA_MODULE = {
    "name": "module2",
    "path": "/path/to/module2",
    "base_model_name": "llama"
}


class TestLoraParserAction(unittest.TestCase):

    def setUp(self):
        # Setting up argparse parser for tests
        parser = FlexibleArgumentParser(
            description="vLLM's remote OpenAI server.")
        self.parser = make_arg_parser(parser)

    def test_valid_key_value_format(self):
        # Test old format: name=path
        args = self.parser.parse_args([
            '--lora-modules',
            'module1=/path/to/module1',
        ])
        expected = [LoRAModulePath(name='module1', path='/path/to/module1')]
        self.assertEqual(args.lora_modules, expected)

    def test_valid_json_format(self):
        # Test valid JSON format input
        args = self.parser.parse_args([
            '--lora-modules',
            json.dumps(LORA_MODULE),
        ])
        expected = [
            LoRAModulePath(name='module2',
                           path='/path/to/module2',
                           base_model_name='llama')
        ]
        self.assertEqual(args.lora_modules, expected)

    def test_invalid_json_format(self):
        # Test invalid JSON format input, missing closing brace
        with self.assertRaises(SystemExit):
            self.parser.parse_args([
                '--lora-modules',
                '{"name": "module3", "path": "/path/to/module3"'
            ])

    def test_invalid_type_error(self):
        # Test type error when values are not JSON or key=value
        with self.assertRaises(SystemExit):
            self.parser.parse_args([
                '--lora-modules',
                'invalid_format'  # This is not JSON or key=value format
            ])

    def test_invalid_json_field(self):
        # Test valid JSON format but missing required fields
        with self.assertRaises(SystemExit):
            self.parser.parse_args([
                '--lora-modules',
                '{"name": "module4"}'  # Missing required 'path' field
            ])

    def test_empty_values(self):
        # Test when no LoRA modules are provided
        args = self.parser.parse_args(['--lora-modules', ''])
        self.assertEqual(args.lora_modules, [])

    def test_multiple_valid_inputs(self):
        # Test multiple valid inputs (both old and JSON format)
        args = self.parser.parse_args([
            '--lora-modules',
            'module1=/path/to/module1',
            json.dumps(LORA_MODULE),
        ])
        expected = [
            LoRAModulePath(name='module1', path='/path/to/module1'),
            LoRAModulePath(name='module2',
                           path='/path/to/module2',
                           base_model_name='llama')
        ]
        self.assertEqual(args.lora_modules, expected)


if __name__ == '__main__':
    unittest.main()
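A minimal sketch of the parsing behavior these tests exercise, assuming a LoRAModulePath container with `name`, `path`, and an optional `base_model_name` (the real logic lives in vllm.entrypoints.openai.cli_args and may differ in detail):

    import json
    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class LoRAModulePath:
        name: str
        path: str
        base_model_name: Optional[str] = None


    def parse_lora_module(value: str) -> LoRAModulePath:
        """Parse one --lora-modules value: either name=path or a JSON object."""
        if value.startswith("{"):
            spec = json.loads(value)  # malformed JSON raises, cf. test_invalid_json_format
            return LoRAModulePath(
                name=spec["name"],
                path=spec["path"],  # missing 'path' raises, cf. test_invalid_json_field
                base_model_name=spec.get("base_model_name"))
        name, sep, path = value.partition("=")
        if not sep or not path:
            # neither JSON nor key=value, cf. test_invalid_type_error
            raise ValueError(f"Invalid LoRA module format: {value!r}")
        return LoRAModulePath(name=name, path=path)
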
Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
import json

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
# downloading lora to test lora requests
from huggingface_hub import snapshot_download

from ...utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"


@pytest.fixture(scope="module")
def zephyr_lora_files():
    return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="module")
def server_with_lora_modules_json(zephyr_lora_files):
    # Define the json format LoRA module configurations
    lora_module_1 = {
        "name": "zephyr-lora",
        "path": zephyr_lora_files,
        "base_model_name": MODEL_NAME
    }

    lora_module_2 = {
        "name": "zephyr-lora2",
        "path": zephyr_lora_files,
        "base_model_name": MODEL_NAME
    }

    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        # lora config below
        "--enable-lora",
        "--lora-modules",
        json.dumps(lora_module_1),
        json.dumps(lora_module_2),
        "--max-lora-rank",
        "64",
        "--max-cpu-loras",
        "2",
        "--max-num-seqs",
        "64",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client_for_lora_lineage(server_with_lora_modules_json):
    async with server_with_lora_modules_json.get_async_client(
    ) as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
                                  zephyr_lora_files):
    models = await client_for_lora_lineage.models.list()
    models = models.data
    served_model = models[0]
    lora_models = models[1:]
    assert served_model.id == MODEL_NAME
    assert served_model.root == MODEL_NAME
    assert served_model.parent is None
    assert all(lora_model.root == zephyr_lora_files
               for lora_model in lora_models)
    assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
    assert lora_models[0].id == "zephyr-lora"
    assert lora_models[1].id == "zephyr-lora2"
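Adapters registered this way remain addressable by name in requests; a brief sketch against the server fixture above (assuming it is still running and exposed on localhost:8000, which the real test infrastructure does not guarantee):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

    # Target the LoRA adapter by the name it was registered under via --lora-modules.
    completion = client.chat.completions.create(
        model="zephyr-lora",
        messages=[{"role": "user", "content": "Say hello."}],
        max_tokens=16,
    )
    print(completion.choices[0].message.content)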

tests/entrypoints/openai/test_models.py

Lines changed: 4 additions & 2 deletions
@@ -51,12 +51,14 @@ async def client(server):


 @pytest.mark.asyncio
-async def test_check_models(client: openai.AsyncOpenAI):
+async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files):
     models = await client.models.list()
     models = models.data
     served_model = models[0]
     lora_models = models[1:]
     assert served_model.id == MODEL_NAME
-    assert all(model.root == MODEL_NAME for model in models)
+    assert served_model.root == MODEL_NAME
+    assert all(lora_model.root == zephyr_lora_files
+               for lora_model in lora_models)
     assert lora_models[0].id == "zephyr-lora"
     assert lora_models[1].id == "zephyr-lora2"

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 4 additions & 2 deletions
@@ -7,10 +7,12 @@
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_engine import BaseModelPath
 from vllm.transformers_utils.tokenizer import get_tokenizer

 MODEL_NAME = "openai-community/gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
+BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]


 @dataclass
@@ -37,7 +39,7 @@ async def _async_serving_chat_init():

     serving_completion = OpenAIServingChat(engine,
                                            model_config,
-                                           served_model_names=[MODEL_NAME],
+                                           BASE_MODEL_PATHS,
                                            response_role="assistant",
                                            chat_template=CHAT_TEMPLATE,
                                            lora_modules=None,
@@ -58,7 +60,7 @@ def test_serving_chat_should_set_correct_max_tokens():

     serving_chat = OpenAIServingChat(mock_engine,
                                      MockModelConfig(),
-                                     served_model_names=[MODEL_NAME],
+                                     BASE_MODEL_PATHS,
                                      response_role="assistant",
                                      chat_template=CHAT_TEMPLATE,
                                      lora_modules=None,

tests/entrypoints/openai/test_serving_engine.py

Lines changed: 3 additions & 2 deletions
@@ -8,9 +8,10 @@
 from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               LoadLoraAdapterRequest,
                                               UnloadLoraAdapterRequest)
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing

 MODEL_NAME = "meta-llama/Llama-2-7b"
+BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 LORA_LOADING_SUCCESS_MESSAGE = (
     "Success: LoRA adapter '{lora_name}' added successfully.")
 LORA_UNLOADING_SUCCESS_MESSAGE = (
@@ -25,7 +26,7 @@ async def _async_serving_engine_init():

     serving_engine = OpenAIServing(mock_engine_client,
                                    mock_model_config,
-                                   served_model_names=[MODEL_NAME],
+                                   BASE_MODEL_PATHS,
                                    lora_modules=None,
                                    prompt_adapters=None,
                                    request_logger=None)
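These tests pass BaseModelPath(name=..., model_path=...) where a bare list of served model names used to go. A plausible shape for that container, assuming it is a simple dataclass in vllm.entrypoints.openai.serving_engine (the actual definition may differ):

    from dataclasses import dataclass


    @dataclass
    class BaseModelPath:
        """Pairs a served model name with the path/identifier it was loaded from,
        so /v1/models can report the artifact location as `root`."""
        name: str
        model_path: str
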

vllm/entrypoints/openai/api_server.py

Lines changed: 10 additions & 4 deletions
@@ -50,6 +50,7 @@
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_engine import BaseModelPath
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.logger import init_logger
@@ -476,13 +477,18 @@ def init_app_state(
     else:
         request_logger = RequestLogger(max_log_len=args.max_log_len)

+    base_model_paths = [
+        BaseModelPath(name=name, model_path=args.model)
+        for name in served_model_names
+    ]
+
     state.engine_client = engine_client
     state.log_stats = not args.disable_log_stats

     state.openai_serving_chat = OpenAIServingChat(
         engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         args.response_role,
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
@@ -494,7 +500,7 @@ def init_app_state(
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
@@ -503,13 +509,13 @@ def init_app_state(
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         request_logger=request_logger,
     )
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         lora_modules=args.lora_modules,
         request_logger=request_logger,
         chat_template=args.chat_template,
