
Commit a6730f8

Add --autosplit flag for ExLlamaV2 (oobabooga#5524)
1 parent 4039999 commit a6730f8

File tree: 6 files changed, +29 -17 lines

modules/exllamav2.py (+12 -9)

@@ -51,18 +51,21 @@ def from_pretrained(self, path_to_model):
 
         model = ExLlamaV2(config)
 
-        split = None
-        if shared.args.gpu_split:
-            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
-
-        model.load(split)
-
-        tokenizer = ExLlamaV2Tokenizer(config)
         if shared.args.cache_8bit:
-            cache = ExLlamaV2Cache_8bit(model)
+            cache = ExLlamaV2Cache_8bit(model, lazy=True)
         else:
-            cache = ExLlamaV2Cache(model)
+            cache = ExLlamaV2Cache(model, lazy=True)
 
+        if shared.args.autosplit:
+            model.load_autosplit(cache)
+        else:
+            split = None
+            if shared.args.gpu_split:
+                split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+
+            model.load(split)
+
+        tokenizer = ExLlamaV2Tokenizer(config)
         generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
 
         result = self()
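
Both loaders now follow the same order: build the KV cache lazily first, then either let ExLlamaV2 place the layers itself (--autosplit) or fall back to the manual --gpu-split allocation. Below is a minimal standalone sketch of that flow; the load_model name and the args parameter are illustrative stand-ins for shared.args, not part of the commit.

# Sketch of the loading order introduced by this commit. Assumes `config`
# is an already-prepared ExLlamaV2Config and `args` exposes the same
# attributes as shared.args (cache_8bit, autosplit, gpu_split).
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Tokenizer,
)


def load_model(config, args):
    model = ExLlamaV2(config)

    # The cache is created before the weights are loaded, so it is lazy:
    # its tensors are only allocated once the layers have been placed.
    if args.cache_8bit:
        cache = ExLlamaV2Cache_8bit(model, lazy=True)
    else:
        cache = ExLlamaV2Cache(model, lazy=True)

    if args.autosplit:
        # Let ExLlamaV2 spread the layers over the visible GPUs,
        # reserving room for the cache as it goes.
        model.load_autosplit(cache)
    else:
        # Manual path: per-GPU VRAM budgets (in GB) from --gpu-split.
        split = None
        if args.gpu_split:
            split = [float(alloc) for alloc in args.gpu_split.split(",")]
        model.load(split)

    tokenizer = ExLlamaV2Tokenizer(config)
    return model, cache, tokenizer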

modules/exllamav2_hf.py (+12 -8)

@@ -37,18 +37,22 @@ def __init__(self, config: ExLlamaV2Config):
         super().__init__(PretrainedConfig())
         self.ex_config = config
         self.ex_model = ExLlamaV2(config)
-        split = None
-        if shared.args.gpu_split:
-            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
-
-        self.ex_model.load(split)
-        self.generation_config = GenerationConfig()
         self.loras = None
+        self.generation_config = GenerationConfig()
 
         if shared.args.cache_8bit:
-            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model)
+            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=True)
         else:
-            self.ex_cache = ExLlamaV2Cache(self.ex_model)
+            self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=True)
+
+        if shared.args.autosplit:
+            self.ex_model.load_autosplit(self.ex_cache)
+        else:
+            split = None
+            if shared.args.gpu_split:
+                split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+
+            self.ex_model.load(split)
 
         self.past_seq = None
         if shared.args.cfg_cache:
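
The HF wrapper mirrors the same change. In both files the manual branch still parses --gpu-split into a list of per-device VRAM budgets in GB, so for the "20,7,7" example from the flag's help text the parsing expression above evaluates as follows (a small hedged illustration, not code from the commit):

gpu_split = "20,7,7"                                      # value of --gpu-split
split = [float(alloc) for alloc in gpu_split.split(",")]  # per-GPU budgets in GB
print(split)                                              # [20.0, 7.0, 7.0]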

modules/loaders.py (+2 -0)

@@ -78,6 +78,7 @@
         'no_flash_attn',
         'num_experts_per_token',
         'cache_8bit',
+        'autosplit',
         'alpha_value',
         'compress_pos_emb',
         'trust_remote_code',
@@ -89,6 +90,7 @@
         'no_flash_attn',
         'num_experts_per_token',
         'cache_8bit',
+        'autosplit',
         'alpha_value',
         'compress_pos_emb',
         'exllamav2_info',

modules/shared.py (+1 -0)

@@ -134,6 +134,7 @@
 # ExLlamaV2
 group = parser.add_argument_group('ExLlamaV2')
 group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
+group.add_argument('--autosplit', action='store_true', help='Autosplit the model tensors across the available GPUs. This causes --gpu-split to be ignored.')
 group.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.')
 group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
 group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
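
Because --autosplit is a store_true argument, it defaults to False and simply flips on when the flag is passed. A small self-contained argparse sketch of the two flags (the parser itself is illustrative; only the add_argument calls come from the commit):

import argparse

# Illustrative parser reproducing just the two ExLlamaV2 flags shown above.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('ExLlamaV2')
group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
group.add_argument('--autosplit', action='store_true', help='Autosplit the model tensors across the available GPUs. This causes --gpu-split to be ignored.')

args = parser.parse_args(['--autosplit'])
print(args.autosplit)  # True
print(args.gpu_split)  # None; ignored anyway when --autosplit is set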

modules/ui.py (+1 -0)

@@ -76,6 +76,7 @@ def list_model_elements():
         'no_flash_attn',
         'num_experts_per_token',
         'cache_8bit',
+        'autosplit',
         'threads',
         'threads_batch',
         'n_batch',

modules/ui_model_menu.py (+1 -0)

@@ -132,6 +132,7 @@ def create_ui():
             shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
             shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
             shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.')
+            shared.gradio['autosplit'] = gr.Checkbox(label="autosplit", value=shared.args.autosplit, info='Automatically split the model tensors across the available GPUs.')
             shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.')
             shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.')
             shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')
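
On the UI side, the 'autosplit' entries added to modules/loaders.py and modules/ui.py are what surface this checkbox for the ExLlamaV2 loaders and feed its value back into shared.args. A minimal standalone rendering of the new checkbox (the Blocks wrapper and launch call are illustrative, not the webui's actual layout):

import gradio as gr

with gr.Blocks() as demo:
    # Same label/value/info as the checkbox added in create_ui() above;
    # the default mirrors the store_true default of --autosplit.
    autosplit = gr.Checkbox(
        label="autosplit",
        value=False,
        info='Automatically split the model tensors across the available GPUs.',
    )

demo.launch()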
