
Commit 509f9aa

Merge branch 'refs/heads/dev'

2 parents 61e9ae8 + 4bbd969

10 files changed: +304 -274 lines


exllamav2/architecture.py  (+142 -254)

Large diff not rendered.

exllamav2/config.py  (+33 -7)
@@ -10,7 +10,9 @@
 T = TypeVar('T')
 no_default = object()
 
-def read(input_dict: dict[str, Any], expected_type: type, keys: str | list[str], default = no_default) -> T:
+def read(input_dict: dict[str, Any], expected_type: type | list[type], keys: str | list[str], default = no_default) -> T:
+
+    expected_types = expected_type if isinstance(expected_type, list) else [expected_type]
 
     if isinstance(keys, str): keys = [keys]
 
@@ -34,10 +36,10 @@ def read(input_dict: dict[str, Any], expected_type: type, keys: str | list[str],
             if expected_type == int and isinstance(x, float) and x == int(x):
                 x = int(x)
 
-            if isinstance(x, expected_type):
-                return cast(T, x)
-            else:
-                raise TypeError(f"Value for {key} is not of expected type {expected_type}")
+            for t in expected_types:
+                if isinstance(x, t):
+                    return cast(T, x)
+            raise TypeError(f"Value for {key} is not of expected type {expected_type}")
 
     if default != no_default: return default
     raise ValueError(f"Missing any of the following keys: {keys}")
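
With this change, read accepts either a single type or a list of acceptable types and returns the value as soon as one of them matches. A minimal usage sketch; the config dict and token IDs below are made up for illustration, not taken from the repo:

    # Hypothetical illustration of the widened signature: a single type still
    # works, and a list of types accepts whichever one the JSON value happens to be.
    config_json = {"vocab_size": 32000, "eos_token_id": [128001, 128009]}

    vocab_size = read(config_json, int, "vocab_size")            # -> 32000, matched against int
    eos = read(config_json, [int, list], "eos_token_id", None)   # -> [128001, 128009], matched against list
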
@@ -104,8 +106,13 @@ class ExLlamaV2Config:
     final_logit_softcapping: float | None
     attn_logit_softcapping: float | None
     sliding_window: int
-
+    norm_head: int | None
+    l3_rope_factor: float | None
+    l3_rope_low_freq_factor: float | None
+    l3_rope_high_freq_factor: float | None
+    l3_rope_original_max_position_embeddings: int | None
     checkpoint_fused_mlp: bool
+    checkpoint_offset_qzeros: bool
 
 
     def __init__(self,
@@ -189,10 +196,13 @@ def prepare(self, no_tensors: bool = False):
         # Vocab params
 
         self.bos_token_id = read(read_config, int, "bos_token_id", None)  # 1
-        self.eos_token_id = read(read_config, int, "eos_token_id", None)  # 2
+        self.eos_token_id = read(read_config, [int, list], "eos_token_id", None)  # 2
         self.pad_token_id = read(read_config, int, "pad_token_id", None)  # 0
         self.vocab_size = read(read_config, int, "vocab_size")
 
+        if isinstance(self.eos_token_id, list):
+            self.eos_token_id = self.eos_token_id[0]  # TODO: Figure out a way to maybe use all the EOS tokens somehow
+
         # Standard params
 
         self.initializer_range = read(read_config, float, ["initializer_range"])
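
Checkpoints that publish several EOS token IDs as a list in config.json now load without a type error; per the TODO above, only the first entry is kept as the config-level EOS. A tiny illustration with made-up IDs:

    # Made-up IDs, mirroring the collapse performed in prepare() above.
    eos_token_id = [128001, 128008, 128009]     # as some config.json files list it
    if isinstance(eos_token_id, list):
        eos_token_id = eos_token_id[0]
    print(eos_token_id)                          # 128001
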
@@ -251,6 +261,10 @@ def prepare(self, no_tensors: bool = False):
         self.attn_logit_softcapping = read(read_config, float, "attn_logit_softcapping", None)
         self.final_logit_softcapping = read(read_config, float, "final_logit_softcapping", None)
 
+        # Normalize weights in head layer
+
+        self.norm_head = read(read_config, int, "norm_head", None)
+
         # Positional embeddings
 
         self.rotary_embedding_base = read(read_config, float, ["rope_theta", "attn_config->rope_theta"], 10000.0)
@@ -281,6 +295,18 @@ def prepare(self, no_tensors: bool = False):
                 self.alt_rope_method = "su"
             # if scaling_type == "yarn":
             #     self.scale_alpha_value = factor
+            rope_type = rs.get("rope_type", None)
+            if rope_type == "llama3":
+                self.alt_rope_method = "llama3"
+                self.l3_rope_factor = rs["factor"]
+                self.l3_rope_low_freq_factor = rs["low_freq_factor"]
+                self.l3_rope_high_freq_factor = rs["high_freq_factor"]
+                self.l3_rope_original_max_position_embeddings = rs["original_max_position_embeddings"]
+
+        # Checkpoint format (for GPTQ models)
+
+        checkpoint_format = read(read_config, str, ["quantization_config->checkpoint_format"], None)
+        self.checkpoint_offset_qzeros = (checkpoint_format == "gptq_v2")
 
         # Create map of model tensors
 
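
The four l3_rope_* values drive Llama-3.1-style frequency scaling when the RoPE embeddings are built elsewhere in the loader. As a rough sketch of how such parameters are conventionally applied (this is the commonly published llama3 scaling rule with typical Llama 3.1 defaults, not a copy of exllamav2's kernel code):

    import math

    def llama3_scale_inv_freq(inv_freq: list[float],
                              factor: float = 8.0,
                              low_freq_factor: float = 1.0,
                              high_freq_factor: float = 4.0,
                              original_max_position_embeddings: int = 8192) -> list[float]:
        # Wavelengths longer than low_wavelen are scaled down by `factor`,
        # wavelengths shorter than high_wavelen are left untouched, and the
        # band in between is interpolated smoothly.
        low_wavelen = original_max_position_embeddings / low_freq_factor
        high_wavelen = original_max_position_embeddings / high_freq_factor
        out = []
        for f in inv_freq:
            wavelen = 2 * math.pi / f
            if wavelen > low_wavelen:
                out.append(f / factor)
            elif wavelen < high_wavelen:
                out.append(f)
            else:
                smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / \
                         (high_freq_factor - low_freq_factor)
                out.append((1 - smooth) * f / factor + smooth * f)
        return out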

exllamav2/conversion/convert_exl2.py  (+8)
@@ -31,6 +31,7 @@
 parser.add_argument("-ml", "--measurement_length", type = int, default = 2048, help = "Max no. tokens per sample when measuring")
 parser.add_argument("-so", "--status_output", action = "store_true", help = "Include machine-parseable status updates in console output")
 parser.add_argument("-hsol", "--hidden_state_offload_layers", type = int, default = 0, help = "Number of hidden/target states to keep in VRAM. Speed-up but increases VRAM usage")
+parser.add_argument("-fst", "--fast_safetensors", action = "store_true", help = "Use fast-safetensors to load layers of the unquantized model. This can help alleviate some out-of-memory issues, especially on Windows.")
 
 args = parser.parse_args()

@@ -112,6 +113,7 @@ def save_job():
            "rope_scale": args.rope_scale,
            "rope_alpha": args.rope_alpha,
            "output_measurement": output_measurement,
+           "fast_safetensors": args.fast_safetensors,
            "progress": "begin"}
 
    if args.measurement is not None:
@@ -160,6 +162,8 @@ def save_job():
 else:
     print(f" -- Measurement will be saved to {job['output_measurement']}")
     print(f" !! Conversion script will end after measurement pass")
+if job.get("fast_safetensors"):
+    print(f" -- Enabled fast_safetensors option.")
 
 if job['rope_scale']: print(f" -- RoPE scale: {job['rope_scale']:.2f}")
 if job['rope_alpha']: print(f" -- RoPE alpha: {job['rope_alpha']:.2f}")
@@ -190,6 +194,10 @@ def save_job():
 
 tokenizer = ExLlamaV2Tokenizer(config)
 
+# Set fast_safetensors in config
+
+if job.get("fast_safetensors"): config.fasttensors = True
+
 # Set scaling for input model
 
 if job["rope_scale"] is not None: config.scale_pos_emb = job["rope_scale"]

exllamav2/ext.py  (+5 -1)
@@ -320,7 +320,8 @@ def make_q_matrix(w: dict,
                   temp_dq: torch.Tensor,
                   key: str = None,
                   prescale: float = 1,
-                  max_dq_rows = 0):
+                  max_dq_rows = 0,
+                  offset_qzeros: bool = False):
 
     # EXL2

@@ -354,6 +355,9 @@ def make_q_matrix(w: dict,
         if prescale != 1: w["scales"] *= prescale
         if w["scales"].dtype == torch.float: w["scales"] = w["scales"].half()
 
+        if offset_qzeros:
+            w["qzeros"] -= 0b00010001000100010001000100010001
+
         # GPTQ with g_idx (act_order)
 
         if "g_idx" in w and not (w["g_idx"] == 0).all().item():

exllamav2/generator/dynamic.py  (+51 -7)
@@ -893,6 +893,11 @@ def iterate(self) -> list[dict]:
                 "stop_string"
                 "max_new_tokens"
                 "end_filter"
+            optional, if "eos_reason" == "stop_token":
+                "eos_triggering_token_id": int
+                "eos_triggering_token_str": str
+            optional, if "eos_reason" == "stop_string":
+                "eos_triggering_string": str
             "full_completion": str - full text completion
             "new_tokens": int - number of tokens generated
             "time_enqueued": float - time from job was enqueued until it started, in seconds
@@ -1849,7 +1854,10 @@ def emit(
            eos_reason: str = None,
            emit_held = False,
            suppressed_text = None,
-            suppressed_tokens = None
+            suppressed_tokens = None,
+            stop_token: int = None,
+            stop_string: str = None,
+            rem_held_text: str = None
         ):
             r = {
                 "job": self,
@@ -1860,6 +1868,15 @@ def emit(
 
             if eos_reason is not None:
                 r.update({ "eos_reason": eos_reason })
+                if eos_reason == "stop_token":
+                    id_to_piece = self.generator.tokenizer.get_id_to_piece_list(True)
+                    r.update({
+                        "eos_triggering_token_id": stop_token,
+                        "eos_triggering_token_str": id_to_piece[stop_token]
+                    })
+                    pass
+                if eos_reason == "stop_string":
+                    r.update({ "eos_triggering_string": stop_string })
 
             if emit_held:
                 if self.held_text != "":
@@ -1903,18 +1920,29 @@ def emit(
                     "accepted_draft_tokens": self.accepted_draft_tokens,
                     "rejected_draft_tokens": self.rejected_draft_tokens
                 })
+            if eos_reason == "stop_string":
+                self.held_text = rem_held_text
+                rh = {}
+                if self.held_text:
+                    rh.update({ "text": self.held_text })
+                if self.held_tokens:
+                    rh.update({ "token_ids": self.held_tokens.torch().clone() })
+                if self.held_probs:
+                    rh.update({ "token_probs": self.held_probs.torch().clone() })
+                if self.held_k_tokens:
+                    rh.update({ "top_k_tokens": self.held_k_tokens.torch().clone() })
+                    rh.update({ "top_k_probs": self.held_k_probs.torch().clone() })
+                if self.held_logits:
+                    rh.update({ "logits": self.held_logits.torch().clone() })
+                if rh:
+                    r.update({ "held": rh })
 
             if self.identifier is not None:
                 r.update({ "identifier": self.identifier })
 
             results.append(r)
             return emit_eos, next_token
 
-        # End on stop tokens
-
-        if next_token.item() in self.stop_tokens:
-            return emit(results, emit_eos = True, eos_reason = "stop_token")
-
         # Decode and buffer output
 
         id_to_piece = self.generator.tokenizer.get_id_to_piece_list(self.decode_special_tokens)
@@ -1934,6 +1962,11 @@ def emit(
         if self.return_logits:
             self.held_logits.append(logits[:1, :, :])
 
+        # End on stop tokens
+
+        if next_token.item() in self.stop_tokens:
+            return emit(results, emit_eos = True, eos_reason = "stop_token", stop_token = next_token.item())
+
         # Stop if we reach max_new_tokens
 
         if self.new_tokens >= self.max_new_tokens - self.generator.num_draft_tokens:
@@ -2032,8 +2065,19 @@ def rewind_checkpoint():
                 self.stop_strings_utf32_buffer
             )
             if match >= 0:
+                held = self.held_text[match:]
                 self.held_text = self.held_text[:match]
-                return emit(results, emit_eos = True, emit_held = True, eos_reason = "stop_string")
+                for s in self.stop_strings:
+                    if held.startswith(s):
+                        return emit(
+                            results,
+                            emit_eos = True,
+                            emit_held = True,
+                            eos_reason = "stop_string",
+                            stop_string = s,
+                            rem_held_text = held
+                        )
+                assert False, "Detected stop string but couldn't identify it (logic error)"
             if match == -2:
                 return emit(results)

exllamav2/generator/dynamic_async.py  (+6)
@@ -75,6 +75,7 @@ class ExLlamaV2DynamicJobAsync:
     job: ExLlamaV2DynamicJob
     queue: asyncio.Queue
     generator: ExLlamaV2DynamicGeneratorAsync
+    cancelled: bool = False
 
     def __init__(self, generator: ExLlamaV2DynamicGeneratorAsync, *args: object, **kwargs: object):
         self.generator = generator
@@ -87,6 +88,10 @@ async def put_result(self, result):
 
     async def __aiter__(self):
         while True:
+            # Get out if the job is cancelled
+            if self.cancelled:
+                break
+
             result = await self.queue.get()
             if isinstance(result, Exception):
                 raise result
@@ -96,3 +101,4 @@ async def __aiter__(self):
 
     async def cancel(self):
         await self.generator.cancel(self)
+        self.cancelled = True
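The new flag lets an async iterator drop out promptly once cancel() has been awaited. A minimal consumer sketch (model and generator setup omitted; assumes a job object from this module whose streamed results carry a "text" field):

    async def stream_until(job, max_chars: int = 200):
        # Iterate results as they arrive; stop the job once enough text has been produced.
        text = ""
        async for result in job:
            text += result.get("text", "")
            if len(text) >= max_chars:
                await job.cancel()   # marks the job cancelled; any further iteration exits via the new flag
                break
        return text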

exllamav2/linear.py  (+16 -3)
@@ -54,7 +54,8 @@ def __init__(self,
                  f_beg: int = None,
                  f_end: int = None,
                  is_sub_module: bool = True,
-                 altpack_qkv: bool = False):
+                 altpack_qkv: bool = False,
+                 normalize_unq: bool = False):
         super().__init__(model, key)
 
         self.is_sub_module = is_sub_module
@@ -89,20 +90,23 @@ def __init__(self,
         self.altpack_qkv = altpack_qkv
 
         self.assumed_footprint = in_features * (out_features + self.padding) * 2 + 128
+        self.normalize_unq = normalize_unq
 
 
     @torch.inference_mode
     def load(self,
              w: dict | nn.Parameter | tuple | None = None,
              device_tensors: bool = True):
 
+        cfg = self.model.config
+
         if self.f_key: w = self.load_weight_fused(self.f_key, self.f_beg, self.f_end, self.in_features, self.out_features, self.altpack_qkv)
         if w is None: w = self.load_weight()
 
         # Load quantized linear layer from dictionary
 
         if isinstance(w, dict):
-            assert not self.model.config.load_in_q4, "Can't load quantized layer in Q4 mode"
+            assert not cfg.load_in_q4, "Can't load quantized layer in Q4 mode"
             if self.has_bias:
                 assert "bias" in w, self.key + " has no bias but bias expected"
             else:
@@ -117,14 +121,17 @@ def load(self,
             self.q_handle = ext.make_q_matrix(w,
                                               self.temp_dq,
                                               prescale = self.prescale,
-                                              max_dq_rows = self.model.config.max_dq_size // self.out_features)
+                                              max_dq_rows = cfg.max_dq_size // self.out_features,
+                                              offset_qzeros = cfg.checkpoint_offset_qzeros)
             self.prev_prescale = self.prescale
             self.prescale = 1
 
         # Load FP16 linear layer without bias, optionally quantize to Q4
 
         elif isinstance(w, nn.Parameter):
             assert not self.has_bias, self.key + " has no bias tensor but bias is expected"
+            if self.normalize_unq:
+                w = self.normalize(w)
             if self.padding > 0: w = nn.Parameter(F.pad(w.data, (0, 0, 0, self.padding)).contiguous())
             if not self.model.config.load_in_q4 or not ".layers." in self.key:
                 self.linear = nn.Linear(self.in_features, self.out_features, self.has_bias, device = "meta", dtype = torch.float16)
@@ -138,6 +145,8 @@ def load(self,
 
         elif isinstance(w, tuple):
             assert self.has_bias, self.key + " has bias tensor but bias is not expected"
+            if self.normalize_unq:
+                w = self.normalize(w[0]), w[1]
             ww = w[0]
             wb = w[1]
             if self.padding > 0:
@@ -154,6 +163,10 @@ def load(self,
             self.fp16_bias = wb
 
 
+    def normalize(self, w: torch.Tensor):
+        return nn.functional.normalize(w)
+
+
     def matrix_shape(self):
 
         return self.in_features, self.out_features
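
nn.functional.normalize defaults to L2 normalization along dim 1, so for a head weight of shape (vocab_size, hidden_dim) each output row is scaled to unit norm before loading; this is what architectures that declare norm_head in their config expect for the unquantized head. A small standalone illustration of the same operation, with made-up shapes:

    import torch
    import torch.nn.functional as F

    # Row-wise L2 normalization, equivalent to the normalize() helper above
    # with its default dim=1 behaviour.
    w = torch.randn(6, 4)            # stands in for a (vocab_size, hidden_dim) head weight
    w_normed = F.normalize(w)
    print(w_normed.norm(dim = 1))    # ~tensor([1., 1., 1., 1., 1., 1.])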

exllamav2/lora.py  (+2)
@@ -81,6 +81,8 @@ def __init__(self,
         f = load_file(self.lora_path, map_location = "cpu")
 
         for key in f.keys():
+            if any(key.endswith(x) for x in [".original_module.weight", ".modules_to_save.weight"]):
+                continue
             tensor = f[key]
 
             # Find target
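
PEFT adapters trained with modules_to_save ship full copies of those modules (plus the matching original_module weights) alongside the LoRA A/B matrices; those tensors aren't low-rank weights, so the loader now skips them instead of trying to map them onto a target layer. A hedged sketch of the filter on hypothetical key names:

    # Hypothetical key names, mimicking a PEFT checkpoint that used modules_to_save.
    keys = [
        "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight",
        "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight",
        "base_model.model.lm_head.modules_to_save.weight",
        "base_model.model.lm_head.original_module.weight",
    ]

    skip_suffixes = [".original_module.weight", ".modules_to_save.weight"]
    lora_keys = [k for k in keys if not any(k.endswith(s) for s in skip_suffixes)]
    print(lora_keys)   # only the lora_A / lora_B tensors remain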
