Skip to content

Commit d6bd71d

Browse files
committed
ExLlamaV2: fix loading when autosplit is not set
1 parent af0bbf5 commit d6bd71d

File tree

2 files changed

+19
-16
lines changed

2 files changed

+19
-16
lines changed

modules/exllamav2.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -51,20 +51,21 @@ def from_pretrained(self, path_to_model):
5151

5252
model = ExLlamaV2(config)
5353

54-
if shared.args.cache_8bit:
55-
cache = ExLlamaV2Cache_8bit(model, lazy=True)
56-
else:
57-
cache = ExLlamaV2Cache(model, lazy=True)
58-
59-
if shared.args.autosplit:
60-
model.load_autosplit(cache)
61-
else:
54+
if not shared.args.autosplit:
6255
split = None
6356
if shared.args.gpu_split:
6457
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
6558

6659
model.load(split)
6760

61+
if shared.args.cache_8bit:
62+
cache = ExLlamaV2Cache_8bit(model, lazy=shared.args.autosplit)
63+
else:
64+
cache = ExLlamaV2Cache(model, lazy=shared.args.autosplit)
65+
66+
if shared.args.autosplit:
67+
model.load_autosplit(cache)
68+
6869
tokenizer = ExLlamaV2Tokenizer(config)
6970
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
7071

modules/exllamav2_hf.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,26 @@ class Exllamav2HF(PreTrainedModel):
3636
def __init__(self, config: ExLlamaV2Config):
3737
super().__init__(PretrainedConfig())
3838
self.ex_config = config
39-
self.ex_model = ExLlamaV2(config)
4039
self.loras = None
4140
self.generation_config = GenerationConfig()
4241

43-
if shared.args.cache_8bit:
44-
self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=True)
45-
else:
46-
self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=True)
42+
self.ex_model = ExLlamaV2(config)
4743

48-
if shared.args.autosplit:
49-
self.ex_model.load_autosplit(self.ex_cache)
50-
else:
44+
if not shared.args.autosplit:
5145
split = None
5246
if shared.args.gpu_split:
5347
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
5448

5549
self.ex_model.load(split)
5650

51+
if shared.args.cache_8bit:
52+
self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)
53+
else:
54+
self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)
55+
56+
if shared.args.autosplit:
57+
self.ex_model.load_autosplit(self.ex_cache)
58+
5759
self.past_seq = None
5860
if shared.args.cfg_cache:
5961
if shared.args.cache_8bit:

0 commit comments

Comments
 (0)