Skip to content

Commit 9ade8bb

Browse files
[Model] add a bunch of supported lora modules for mixtral (#9008)
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
1 parent 22482e4 commit 9ade8bb

File tree

3 files changed

+69
-20
lines changed

3 files changed

+69
-20
lines changed

tests/lora/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ def mixtral_lora_files():
173173
return snapshot_download(repo_id="SangBinCho/mixtral-lora")
174174

175175

176+
@pytest.fixture(scope="session")
177+
def mixtral_lora_files_all_target_modules():
178+
return snapshot_download(repo_id="dyang415/mixtral-lora-v0")
179+
180+
176181
@pytest.fixture(scope="session")
177182
def gemma_lora_files():
178183
return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")

tests/lora/test_mixtral.py

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,9 @@
99
MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"
1010

1111

12-
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
13-
prompts = [
14-
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]", # noqa: E501
15-
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]", # noqa: E501
16-
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]", # noqa: E501
17-
]
12+
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
13+
prompts: List[str]) -> List[str]:
14+
1815
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
1916
outputs = llm.generate(
2017
prompts,
@@ -33,22 +30,71 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
3330

3431
@pytest.mark.parametrize("tp_size", [4])
3532
def test_mixtral_lora(mixtral_lora_files, tp_size):
33+
"""Original test, the LoRA model has the common target modules, not all"""
3634
if torch.cuda.device_count() < tp_size:
3735
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
3836

39-
llm = vllm.LLM(MODEL_PATH,
40-
enable_lora=True,
41-
max_num_seqs=16,
42-
max_loras=4,
43-
distributed_executor_backend="ray",
44-
tensor_parallel_size=tp_size)
37+
prompts = [
38+
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]", # noqa: E501
39+
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]", # noqa: E501
40+
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]", # noqa: E501
41+
]
42+
43+
llm = vllm.LLM(
44+
MODEL_PATH,
45+
enable_lora=True,
46+
max_num_seqs=16,
47+
max_loras=4,
48+
distributed_executor_backend="ray",
49+
tensor_parallel_size=tp_size,
50+
)
4551

4652
expected_lora_output = [
4753
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", # noqa: E501
4854
"give_opinion(name[SpellForce 3], developer[Grimlore Games], release_year[2017], rating[poor])", # noqa: E501
4955
"inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])", # noqa: E501
5056
]
51-
assert do_sample(llm, mixtral_lora_files,
52-
lora_id=1) == expected_lora_output
53-
assert do_sample(llm, mixtral_lora_files,
54-
lora_id=2) == expected_lora_output
57+
assert do_sample(llm, mixtral_lora_files, lora_id=1,
58+
prompts=prompts) == expected_lora_output
59+
assert do_sample(llm, mixtral_lora_files, lora_id=2,
60+
prompts=prompts) == expected_lora_output
61+
62+
63+
@pytest.mark.parametrize("tp_size", [4])
64+
def test_mixtral_lora_all_target_modules(mixtral_lora_files_all_target_modules,
65+
tp_size):
66+
"""This LoRA model has all supported Mixtral target modules"""
67+
68+
if torch.cuda.device_count() < tp_size:
69+
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
70+
71+
prompts = [
72+
"Question: What happens if you touch the eyes of a blind man?\nChoices:\nB: The man's sight will be restored.\nA: Nothing happens if you touch the eyes of a blind man.\nC: The man will get his eyesight back.\nD: The man will no longer be blind.\nAnswer:", # noqa: E501
73+
"Question: Which will increase the energy of motion of water molecules?\nChoices:\nA: add heat\nB: add ice\nC: remove some of the water\nD: place the water in the freezer\nAnswer:", # noqa: E501
74+
"Since Craig threw aluminum cans in the trash and Benjamin recycled, _ was environmentally irresponsible.\nChoices:\n1: Craig\n2: Benjamin\nAnswer:", # noqa: E501
75+
]
76+
77+
llm = vllm.LLM(
78+
MODEL_PATH,
79+
enable_lora=True,
80+
max_num_seqs=16,
81+
max_loras=4,
82+
distributed_executor_backend="ray",
83+
tensor_parallel_size=tp_size,
84+
max_lora_rank=32,
85+
)
86+
87+
expected_lora_output = [
88+
"A: Nothing happens if you touch the eyes of a blind man.",
89+
"A: add heat",
90+
"1: Craig",
91+
]
92+
93+
assert do_sample(llm,
94+
mixtral_lora_files_all_target_modules,
95+
lora_id=1,
96+
prompts=prompts) == expected_lora_output
97+
assert do_sample(llm,
98+
mixtral_lora_files_all_target_modules,
99+
lora_id=2,
100+
prompts=prompts) == expected_lora_output

vllm/model_executor/models/mixtral.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,10 +322,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
322322

323323
# LoRA specific attributes
324324
supported_lora_modules = [
325-
"qkv_proj",
326-
"o_proj",
327-
"embed_tokens",
328-
"lm_head",
325+
"qkv_proj", "o_proj", "embed_tokens", "lm_head", "w1", "w2", "w3",
326+
"gate"
329327
]
330328
embedding_modules = {
331329
"embed_tokens": "input_embeddings",

0 commit comments

Comments (0)