
Commit 0f53b13

Authored by paulyu12, jesse996, and paulyu
[V1][LoRA][Test] V1 Engine LoRA support & e2e test (#893)
### What this PR does / why we need it?
Add V1 Engine LoRA support. Add LoRA e2e tests on a single card and on multiple cards.

### Does this PR introduce _any_ user-facing change?
Support LoRA for V1.

### How was this patch tested?
CI passed with the newly added tests.

---------

Signed-off-by: jesse <szxfml@gmail.com>
Signed-off-by: paulyu <paulyu0307@gmail.com>
Signed-off-by: paulyu12 <507435917@qq.com>
Co-authored-by: jesse <szxfml@gmail.com>
Co-authored-by: paulyu <paulyu0307@gmail.com>
1 parent 7aa4f85 commit 0f53b13

File tree

6 files changed (+168, -39 lines)

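For reference, a minimal offline-inference sketch of the user-facing change (LoRA on the V1 engine), assembled from the test code and fixtures added in this commit. It is illustrative only: it assumes the stock vLLM offline API, and that the `VLLM_USE_V1` environment variable (toggled in the CI workflow below) selects the engine; the adapter name "ilama-sql" is made up for the example.

```python
import os

# Assumption: VLLM_USE_V1=1 selects the V1 engine; the CI below runs the same
# tests once with the default engine and once with VLLM_USE_V1=0.
os.environ.setdefault("VLLM_USE_V1", "1")

from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Base model and LoRA adapter used by the e2e tests in this commit.
lora_path = snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")
llm = LLM(model="ArthurZ/ilama-3.2-1B",
          enable_lora=True,
          max_loras=4,
          max_model_len=1024,
          max_num_seqs=16)

outputs = llm.generate(
    ["How many singers do we have?"],
    SamplingParams(temperature=0, max_tokens=32),
    # LoRARequest(adapter name, integer id, local adapter path)
    lora_request=LoRARequest("ilama-sql", 1, lora_path))
print(outputs[0].outputs[0].text)
```

The single-card and TP=2 tests added below exercise this same path through the `VllmRunner` helper.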

.github/workflows/vllm_ascend_test.yaml

Lines changed: 9 additions & 5 deletions
@@ -51,11 +51,11 @@ jobs:
       vllm_verison: [main, v0.8.5.post1]
     concurrency:
       group: >
-        ${{
-        matrix.os == 'linux-arm64-npu-4'
-        && github.event.pull_request.number
-        && format('pr-{0}-limit-npu-4', github.event.pull_request.number)
-        || format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number)
+        ${{
+        matrix.os == 'linux-arm64-npu-4'
+        && github.event.pull_request.number
+        && format('pr-{0}-limit-npu-4', github.event.pull_request.number)
+        || format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number)
        }}
     cancel-in-progress: false
     name: vLLM Ascend test
@@ -112,10 +112,12 @@ jobs:
       run: |
         if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
           pytest -sv tests/singlecard/test_offline_inference.py
+          pytest -sv tests/singlecard/test_ilama_lora.py
           pytest -sv tests/ops
           pytest -sv tests/compile
         else
           pytest -sv -k "QwQ" tests/multicard/test_offline_inference_distributed.py
+          pytest -sv tests/multicard/test_ilama_lora_tp2.py
           pytest -sv tests/ops
           pytest -sv tests/compile
         fi
@@ -125,9 +127,11 @@ jobs:
         VLLM_USE_V1: 0
       run: |
         if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
+          pytest -sv tests/singlecard/test_ilama_lora.py
           pytest -sv tests/singlecard/test_offline_inference.py
           pytest -sv tests/ops
         else
+          pytest -sv tests/multicard/test_ilama_lora_tp2.py
           pytest -sv -k "QwQ" tests/multicard/test_offline_inference_distributed.py
           pytest -sv -k "DeepSeek" tests/multicard/test_offline_inference_distributed.py
           pytest -sv tests/ops

tests/conftest.py

Lines changed: 7 additions & 1 deletion
@@ -23,6 +23,7 @@
 import numpy as np
 import pytest
 import torch
+from huggingface_hub import snapshot_download
 from PIL import Image
 from vllm import LLM, SamplingParams
 from vllm.config import TaskOption
@@ -348,4 +349,9 @@ def vllm_runner():

 @pytest.fixture(params=list(PROMPT_TEMPLATES.keys()))
 def prompt_template(request):
-    return PROMPT_TEMPLATES[request.param]
+    return PROMPT_TEMPLATES[request.param]
+
+
+@pytest.fixture(scope="session")
+def ilama_lora_files():
+    return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")
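The session-scoped fixture above downloads the adapter repository once per test session and returns its local directory; the tests hand that path straight to `LoRARequest`. A small sketch of the same pattern outside pytest (the adapter name "ilama-sql" is illustrative, not from the diff):

```python
from huggingface_hub import snapshot_download
from vllm.lora.request import LoRARequest

# snapshot_download caches the repo locally and returns its path, which is why
# a session-scoped fixture avoids re-downloading the adapter for every test.
lora_dir = snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")
lora_request = LoRARequest("ilama-sql", 1, lora_dir)
```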
tests/multicard/test_ilama_lora_tp2.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import pytest
+
+from tests.conftest import VllmRunner
+from tests.singlecard.test_ilama_lora import (EXPECTED_LORA_OUTPUT, MODEL_PATH,
+                                              do_sample)
+
+
+@pytest.mark.parametrize("distributed_executor_backend", ["mp"])
+def test_ilama_lora_tp2(distributed_executor_backend, ilama_lora_files):
+    with VllmRunner(model_name=MODEL_PATH,
+                    enable_lora=True,
+                    max_loras=4,
+                    max_model_len=1024,
+                    max_num_seqs=16,
+                    tensor_parallel_size=2,
+                    distributed_executor_backend=distributed_executor_backend
+                    ) as vllm_model:
+        output = do_sample(vllm_model.model, ilama_lora_files, lora_id=2)
+
+    for i in range(len(EXPECTED_LORA_OUTPUT)):
+        assert output[i] == EXPECTED_LORA_OUTPUT[i]

tests/singlecard/test_ilama_lora.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import vllm
+from vllm.lora.request import LoRARequest
+
+from tests.conftest import VllmRunner
+
+MODEL_PATH = "ArthurZ/ilama-3.2-1B"
+
+PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:"""  # noqa: E501
+
+EXPECTED_LORA_OUTPUT = [
+    "SELECT count(*) FROM singer",
+    "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'",  # noqa: E501
+    "SELECT DISTINCT Country FROM singer WHERE Age > 20",
+]
+
+
+def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
+    prompts = [
+        PROMPT_TEMPLATE.format(query="How many singers do we have?"),
+        PROMPT_TEMPLATE.format(
+            query=
+            "What is the average, minimum, and maximum age of all singers from France?"  # noqa: E501
+        ),
+        PROMPT_TEMPLATE.format(
+            query=
+            "What are all distinct countries where singers above age 20 are from?"  # noqa: E501
+        ),
+    ]
+    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+        if lora_id else None)
+    # Print the outputs.
+    generated_texts: list[str] = []
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text.strip()
+        generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    return generated_texts
+
+
+def test_ilama_lora(ilama_lora_files):
+    with VllmRunner(model_name=MODEL_PATH,
+                    enable_lora=True,
+                    max_loras=4,
+                    max_model_len=1024,
+                    max_num_seqs=16) as vllm_model:
+
+        output1 = do_sample(vllm_model.model, ilama_lora_files, lora_id=1)
+        for i in range(len(EXPECTED_LORA_OUTPUT)):
+            assert output1[i] == EXPECTED_LORA_OUTPUT[i]
+
+        output2 = do_sample(vllm_model.model, ilama_lora_files, lora_id=2)
+        for i in range(len(EXPECTED_LORA_OUTPUT)):
+            assert output2[i] == EXPECTED_LORA_OUTPUT[i]

vllm_ascend/worker/model_runner_v1.py

Lines changed: 58 additions & 33 deletions
@@ -50,6 +50,7 @@
 from vllm.v1.sample.sampler import Sampler
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -102,7 +103,7 @@ def graph_capture(device: torch.device):
     yield graph_capture_context


-class NPUModelRunner:
+class NPUModelRunner(LoRAModelRunnerMixin):

     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
@@ -543,6 +544,10 @@ def _process_reqs(
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)

+        # Hot-Swap lora model
+        if self.lora_config:
+            self.set_active_loras(self.input_batch, num_scheduled_tokens)
+
         # Prepare positions
         req_indices = np.repeat(self.arange_np[:num_reqs],
                                 num_scheduled_tokens)
@@ -867,39 +872,55 @@ def _profile_multimodal(self) -> None:

     @torch.inference_mode()
     def _dummy_run(self, num_tokens: int) -> torch.Tensor:
-        model = self.model
-        if self.is_multimodal_model:
-            input_ids = None
-            inputs_embeds = self.inputs_embeds[:num_tokens]
-        else:
-            input_ids = self.input_ids[:num_tokens]
-            inputs_embeds = None
+        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
+        # for dummy run with LoRA so that the num_reqs collectively
+        # has num_tokens in total.
+        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
+        max_num_reqs = self.scheduler_config.max_num_seqs
+        num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
+        min_tokens_per_req = num_tokens // num_reqs
+        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
+        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
+        assert sum(num_scheduled_tokens_list) == num_tokens
+        assert len(num_scheduled_tokens_list) == num_reqs
+        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
+                                        dtype=np.int32)
+        with self.maybe_dummy_run_with_lora(self.lora_config,
+                                            num_scheduled_tokens):
+            model = self.model
+            if self.is_multimodal_model:
+                input_ids = None
+                inputs_embeds = self.inputs_embeds[:num_tokens]
+            else:
+                input_ids = self.input_ids[:num_tokens]
+                inputs_embeds = None

-        if self.uses_mrope:
-            positions = self.mrope_positions[:, :num_tokens]
-        else:
-            positions = self.positions[:num_tokens]
+            if self.uses_mrope:
+                positions = self.mrope_positions[:, :num_tokens]
+            else:
+                positions = self.positions[:num_tokens]

-        if get_pp_group().is_first_rank:
-            intermediate_tensors = None
-        else:
-            if self.intermediate_tensors is None:
-                self.intermediate_tensors = (
-                    self.model.make_empty_intermediate_tensors(
-                        batch_size=num_tokens,
-                        dtype=self.dtype,
-                        device=self.device))
-            intermediate_tensors = IntermediateTensors({
-                k: v[:num_tokens]
-                for k, v in self.intermediate_tensors.items()
-            })
-
-        with set_forward_context(None, self.vllm_config):
-            hidden_states = model(input_ids=input_ids,
-                                  positions=positions,
-                                  intermediate_tensors=intermediate_tensors,
-                                  inputs_embeds=inputs_embeds)
-        return hidden_states
+            if get_pp_group().is_first_rank:
+                intermediate_tensors = None
+            else:
+                if self.intermediate_tensors is None:
+                    self.intermediate_tensors = (
+                        self.model.make_empty_intermediate_tensors(
+                            batch_size=num_tokens,
+                            dtype=self.dtype,
+                            device=self.device))
+                intermediate_tensors = IntermediateTensors({
+                    k: v[:num_tokens]
+                    for k, v in self.intermediate_tensors.items()
+                })
+
+            with set_forward_context(None, self.vllm_config):
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds)
+            return hidden_states

     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
@@ -948,7 +969,11 @@ def load_model(self) -> None:
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             self.model = get_model(vllm_config=self.vllm_config)
             if self.lora_config:
-                raise ValueError("LoRA model is not supported on NPU now.")
+                self.model = self.load_lora_model(self.model,
+                                                  self.model_config,
+                                                  self.scheduler_config,
+                                                  self.lora_config,
+                                                  self.device)
         logger.info("Loading model weights took %.4f GB",
                     m.consumed_memory / float(2**30))
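As a side note on the `_dummy_run` change above: the LoRA-aware dummy run needs a per-request token split rather than a single flat batch, so that profiling sees a realistic batch shape. A standalone sketch of that split logic follows; the function name is illustrative, not part of the diff.

```python
import numpy as np

def split_dummy_tokens(num_tokens: int, max_num_seqs: int) -> np.ndarray:
    # Mirror of the logic in _dummy_run: cap the request count at
    # max_num_seqs, give each request an equal share, and push the
    # remainder onto the last request so the total is exactly num_tokens.
    num_reqs = max_num_seqs if num_tokens >= max_num_seqs else num_tokens
    per_req = [num_tokens // num_reqs] * num_reqs
    per_req[-1] += num_tokens % num_reqs
    assert sum(per_req) == num_tokens
    return np.array(per_req, dtype=np.int32)

# Example: a budget of 10 tokens across at most 4 sequences -> [2, 2, 2, 4]
print(split_dummy_tokens(10, 4))
```

With `num_tokens=10` and `max_num_seqs=4` this yields `[2, 2, 2, 4]`, which sums back to the original token budget.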

vllm_ascend/worker/worker_v1.py

Lines changed: 13 additions & 0 deletions
@@ -31,6 +31,7 @@
                             set_custom_all_reduce)
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.logger import logger
+from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -216,6 +217,18 @@ def profile(self, is_start: bool = True):
         else:
             self.profiler.stop()

+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.model_runner.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        return self.model_runner.remove_lora(lora_id)
+
+    def list_loras(self) -> set[int]:
+        return self.model_runner.list_loras()
+
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.model_runner.pin_lora(lora_id)
+
     def execute_dummy_batch(self) -> None:
         self.model_runner._dummy_run(1)
