Commit dafb508

Merge pull request #403 from turboderp/dev
Merge dev branch
2 parents f6b7faa + 3e8e306 commit dafb508

55 files changed: +1606, -576 lines

conversion/adaptivegptq.py (+172, -125)

Large diffs are not rendered by default.

conversion/compile.py (+32, -25)

@@ -97,12 +97,17 @@ def compile_model(job, save_fn, model):
         if isinstance(module, ExLlamaV2ParallelDecoder):
 
             has_gate = model.config.arch.mlp_gate
+            has_qk_norm = model.config.use_qk_norm
             d = get_f_module(job, module.input_layernorm); out_dict.update(d); current_size += _dsize(d)
             d = get_q_module(job, module.attn.q_proj); out_dict.update(d); current_size += _dsize(d)
             d = get_q_module(job, module.attn.k_proj); out_dict.update(d); current_size += _dsize(d)
             d = get_q_module(job, module.attn.v_proj); out_dict.update(d); current_size += _dsize(d)
             d = get_q_module(job, module.attn.o_proj); out_dict.update(d); current_size += _dsize(d)
-            if has_gate: d = get_q_module(job, module.mlp.gate_proj); out_dict.update(d); current_size += _dsize(d)
+            if has_qk_norm:
+                d = get_f_module(job, module.attn.q_norm); out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module.attn.k_norm); out_dict.update(d); current_size += _dsize(d)
+            if has_gate:
+                d = get_q_module(job, module.mlp.gate_proj); out_dict.update(d); current_size += _dsize(d)
             d = get_q_module(job, module.mlp.up_proj); out_dict.update(d); current_size += _dsize(d)
             d = get_q_module(job, module.mlp.down_proj); out_dict.update(d); current_size += _dsize(d)
 

@@ -206,27 +211,29 @@ def compile_model(job, save_fn, model):
 
     # Add signature to config.json
 
-    ds = job["cal_dataset"]
-    if ds is not None: qcfg_ds = os.path.split(ds)[1]
-    else: qcfg_ds = "(default)"
-
-    qcfg = {
-        "quant_method": "exl2",
-        "version": __version__,
-        "bits": job["bits"],
-        "head_bits": job["head_bits"],
-        "calibration": {
-            "rows": job["dataset_rows"],
-            "length": job["length"],
-            "dataset": qcfg_ds
-        },
-    }
-
-    config_json = os.path.join(out_dir, "config.json")
-    with open(config_json, "r") as f:
-        config_dict = json.load(f)
-
-    config_dict["quantization_config"] = qcfg
-
-    with open(config_json, "w") as f:
-        f.write(json.dumps(config_dict, indent = 4))
+    if job["compile_full"] is not None:
+
+        ds = job["cal_dataset"]
+        if ds is not None: qcfg_ds = os.path.split(ds)[1]
+        else: qcfg_ds = "(default)"
+
+        qcfg = {
+            "quant_method": "exl2",
+            "version": __version__,
+            "bits": job["bits"],
+            "head_bits": job["head_bits"],
+            "calibration": {
+                "rows": job["dataset_rows"],
+                "length": job["length"],
+                "dataset": qcfg_ds
+            },
+        }
+
+        config_json = os.path.join(out_dir, "config.json")
+        with open(config_json, "r") as f:
+            config_dict = json.load(f)
+
+        config_dict["quantization_config"] = qcfg
+
+        with open(config_json, "w") as f:
+            f.write(json.dumps(config_dict, indent = 4))
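
The signature is now written only when a full model is being compiled (job["compile_full"]). As a minimal sketch, not part of the commit, the resulting entry can be read back from the output model's config.json; the path below is a placeholder and the printed values depend on the job settings:

# Sketch: inspect the "quantization_config" signature written by compile_model().
# "out_dir/config.json" stands in for the compiled model's config file.
import json

with open("out_dir/config.json", "r") as f:
    config_dict = json.load(f)

qcfg = config_dict.get("quantization_config", {})
print(qcfg.get("quant_method"))                  # expected: "exl2"
print(qcfg.get("bits"), qcfg.get("head_bits"))   # bitrate targets from the job
print(qcfg.get("calibration"))                   # {"rows": ..., "length": ..., "dataset": ...}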

conversion/measure.py (+30, -4)

@@ -133,7 +133,7 @@ def test_error(module, hidden_states, target_states, cache, attn_params):
     return max(1e-6, 1 - (rfn_sum / rfn_count))
 
 
-def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_params):
+def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_params, keep_q = False):
 
     qjobs, qmaps = get_qparams_reduced(qparams_attn)
     results = []

@@ -181,6 +181,10 @@ def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_p
               "o_proj": qjobs[3][o].get_dict() }
         results.append(r)
 
+    for x in ["k_proj", "v_proj", "o_proj"] + (["q_proj"] if not keep_q else []):
+        if x in quantizers:
+            del quantizers[x]
+
     return results
 
 

@@ -257,6 +261,9 @@ def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_pa
               "down_proj": qjobs[2][d].get_dict() }
         results.append(r)
 
+    for x in ["up_proj", "down_proj", "gate_proj"]:
+        if x in quantizers:
+            del quantizers[x]
 
     return results
 

@@ -329,10 +336,22 @@ def measure_moe_mlp(module, hidden_states, target_states, quantizers, cache, att
 
 def measure_parallel_decoder(module, hidden_states, target_states_attn, target_states_mlp, quantizers, cache, attn_params):
 
+    for i in range(len(hidden_states)):
+        hidden_states[i] = hidden_states[i].cpu()
+
     print(f" -- Sublayer: {module.key}.self_attn")
-    results_attn = measure_attn(module.attn, hidden_states, target_states_attn, quantizers, cache, attn_params)
+    results_attn = measure_attn(module.attn, hidden_states, target_states_attn, quantizers, cache, attn_params, keep_q = True)
+
+    module.attn.unload()
+    gc.collect()
+    torch.cuda.empty_cache()
+
     print(f" -- Sublayer: {module.key}.mlp")
     results_mlp = measure_mlp(module.mlp, hidden_states, target_states_mlp, quantizers, cache, attn_params, "q_proj")
+
+    for i in range(len(hidden_states)):
+        hidden_states[i] = hidden_states[i].to("cuda:0")
+
     r = { "attn": results_attn,
           "mlp": results_mlp }
     return r

@@ -367,7 +386,9 @@ def measure_quant(job, save_fn, model):
     accuracy_count = 0
     overall_rolling_accuracy = 0
 
-    snapshot_interval = 10
+    last_snapshot_time = time.time()
+    snapshot_interval_s = 90
+
     temp_filename = os.path.join(job["out_dir"], "hidden_states_temp.safetensors")
     states_filename = os.path.join(job["out_dir"], "hidden_states.safetensors")
     measurement = job.get("measurement", {})

@@ -602,7 +623,10 @@ def measure_quant(job, save_fn, model):
 
         # Checkpoint
 
-        if index % snapshot_interval == 0 or index == len(model.modules) - 1:
+        time_since_snapshot = time.time() - last_snapshot_time
+        if time_since_snapshot > snapshot_interval_s or index == len(model.modules) - 1:
+
+            print(" -- Saving checkpoint...")
 
             save_dict = {f"row.{idx:05}": h for idx, h in enumerate(hidden_states)}
             save_file(save_dict, temp_filename)

@@ -621,6 +645,8 @@ def measure_quant(job, save_fn, model):
             del job["invalid"]
             save_fn()
 
+            last_snapshot_time = time.time()
+
         # Export measurement
 
         exp_measurement = { "measurement": job["measurement"],
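
The last three hunks replace the fixed every-10-modules snapshot with a wall-clock interval. A minimal, self-contained sketch of that pattern, assuming only the 90-second interval from the diff (the function name and callback are illustrative):

import time

snapshot_interval_s = 90              # interval used in the diff
last_snapshot_time = time.time()

def maybe_checkpoint(index, num_modules, save_checkpoint):
    # Save when enough wall-clock time has passed, or unconditionally on the last module.
    global last_snapshot_time
    if time.time() - last_snapshot_time > snapshot_interval_s or index == num_modules - 1:
        print(" -- Saving checkpoint...")
        save_checkpoint()
        last_snapshot_time = time.time()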

conversion/optimize.py (+17, -4)

@@ -11,6 +11,9 @@ def optimize(job, save_fn, model):
 
     error_norm = 2.4
     max_step_size = 2
+    first_layer_bias = 10
+    bias_layers = 2
+    bias_iter = 10
 
     key = "model.layers.0"
     key_q = key + ".self_attn.q_proj"

@@ -60,8 +63,11 @@ def optimize(job, save_fn, model):
 
     measurement = job["measurement"]
 
-    def fn(x):
-        return 1 - ((1 - x) ** error_norm)
+    def fn(x, idx):
+        if idx < bias_layers:
+            return 1 - ((1 - x) ** error_norm) * first_layer_bias
+        else:
+            return 1 - ((1 - x) ** error_norm)
 
     weights = []
     values = []

@@ -74,7 +80,7 @@ def fn(x):
         m1 = measurement["model.layers." + str(i) + ".self_attn"]
         m2 = measurement["model.layers." + str(i) + "." + mlp_mode]
         for m in [m1, m2]:
-            v = [fn(e["accuracy"]) for e in m]
+            v = [fn(e["accuracy"], i) for e in m]
             w = [e["total_bits"] for e in m]
             weights.append(w)
             values.append(v)

@@ -111,10 +117,13 @@ def fn(x):
     value = 1
     for i in range(num_layers * 2): value *= values[i][0]
 
+    iteration = 0
+
     while True:
         min_idx = -1
         min_value = float("inf")
-        for i in range(num_layers * 2):
+        iteration += 1
+        for i in range(bias_layers if iteration < bias_iter else num_layers * 2):
             s = f_solution[i]
             if values[i][s] < min_value:
                 if s < len(weights[i]) - 1:

@@ -211,6 +220,7 @@ def improve(solution, s_weight, hold = None):
 
     print(" -- Quantization strategy:")
 
+    errp = 1
     job["strategy"] = {}
     for layer_ in range(num_layers):
 

@@ -224,5 +234,8 @@ def improve(solution, s_weight, hold = None):
             bpw = p["total_bits"] / n
             err = 1 - p["accuracy"]
             print(f" -- {k:50} {bpw:1.4f} bpw - exp. error: {err:1.8f}")
+            errp *= (1 - err)
+
+    print(f" -- Total exp. error: {1 - errp:1.12f}")
 
     xx = 0
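
The new fn(x, idx) scales the transformed error of the first bias_layers layers by first_layer_bias, so the bit-allocation search is pushed to spend bits on early layers first. A standalone check of the curve, with the constants copied from the diff; the accuracy value 0.99 is only an example:

error_norm = 2.4
first_layer_bias = 10
bias_layers = 2

def fn(x, idx):
    # Biased error transform: early layers get a 10x penalty on (1 - accuracy).
    if idx < bias_layers:
        return 1 - ((1 - x) ** error_norm) * first_layer_bias
    else:
        return 1 - ((1 - x) ** error_norm)

print(fn(0.99, 0))   # ~0.99984: early layer, penalty applied
print(fn(0.99, 5))   # ~0.99998: later layer, plain curve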

conversion/quantize.py (+28, -13)

@@ -83,24 +83,30 @@ def quant_linear(job: dict,
     recons_dict = {}
     recons_keys = ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups"]
     if source.has_bias: recons_keys += ["bias"]
+    r_device = packed_dict[source.key + ".q_weight"].device
+    recons_linear.set_device_idx(r_device.index)
     for k in recons_keys:
-        recons_dict[k] = packed_dict[source.key + "." + k]
+        recons_dict[k] = packed_dict[source.key + "." + k].to(r_device)
     recons_dict["q_perm"] = torch.argsort(recons_dict["q_invperm"]).to(torch.int)
-    recons_linear.load(recons_dict)
+    recons_linear.load(recons_dict, device_tensors = False)
 
     # Sanity test to ensure reconstructed matrix matches unpacked matrix
 
     quant_w = source.linear.weight.T
     recons_w = recons_linear.get_weight_tensor_dq()
 
-    if quant_w.numel() <= 1e9:
-        ident = torch.eye(recons_linear.in_features, dtype = torch.half).cuda()
-        recons_w2 = recons_linear.forward(ident, force_cuda = True)
-        recons_w2.sub_(quant_w)
-        if recons_linear.has_bias: recons_w2.sub_(recons_dict["bias"])
-        recons_w2.abs_()
-        diff2 = torch.max(recons_w2)
-    else:
+    try:
+        if quant_w.numel() <= 1e9:
+            ident = torch.eye(recons_linear.in_features, dtype = torch.half, device = r_device)
+            recons_w2 = recons_linear.forward(ident, force_cuda = True)
+            recons_w2.sub_(quant_w)
+            if recons_linear.has_bias: recons_w2.sub_(recons_dict["bias"])
+            recons_w2.abs_()
+            diff2 = torch.max(recons_w2)
+        else:
+            diff2 = 0
+    except torch.cuda.OutOfMemoryError as e:
+        print(f" !! Warning, not enough VRAM for second sanity check of {source.key}")
         diff2 = 0
 
     quant_w.sub_(recons_w)

@@ -120,7 +126,7 @@ def quant_linear(job: dict,
 
     # Apply reconstructed matrix to source layer
 
-    source.linear.weight.data = recons_w.T
+    source.linear.weight.data = recons_w.T.to("cuda:0")
 
 
 def quant_attn(job, module, hidden_states, target_states, quantizers, attn_params, strat):

@@ -246,7 +252,9 @@ def quant_parallel_decoder(job, module, hidden_states, target_states, quantizers
 @torch.inference_mode()
 def quant(job, save_fn, model):
 
-    snapshot_interval = 10
+    last_snapshot_time = time.time()
+    snapshot_interval_s = 90
+
     temp_filename = os.path.join(job["out_dir"], "hidden_states_temp.safetensors")
     states_filename = os.path.join(job["out_dir"], "hidden_states.safetensors")
     strategy = job["strategy"]

@@ -412,6 +420,7 @@ def quant(job, save_fn, model):
             strat_mlp = strategy[module.key + ".mlp"]
             quant_parallel_decoder(job, module, hidden_states, target_states, quantizers, attn_params, strat_attn, strat_mlp)
 
+        torch.cuda.synchronize()
         quantizers.clear()
         gc.collect()
         torch.cuda.empty_cache()

@@ -421,6 +430,7 @@ def quant(job, save_fn, model):
         if mode == "linear":
             with safe_open(job["cal_filename"], framework = "pt", device = "cpu") as f:
                 cal_ids = f.get_tensor("input_ids")
+            module.linear.weight.data = module.linear.weight.data.to("cuda:0")
 
             rfn_sum = 0
             rfn_count = 0

@@ -494,7 +504,10 @@ def quant(job, save_fn, model):
 
         # Checkpoint
 
-        if index % snapshot_interval == 0 or index == len(model.modules) - 1:
+        time_since_snapshot = time.time() - last_snapshot_time
+        if time_since_snapshot > snapshot_interval_s or index == len(model.modules) - 1:
+
+            print(" -- Saving checkpoint...")
 
             if mode != "linear":
                 save_dict = {f"row.{idx:05}": h for idx, h in enumerate(hidden_states)}

@@ -512,3 +525,5 @@ def quant(job, save_fn, model):
 
             del job["invalid"]
             save_fn()
+
+            time_since_snapshot = time.time()
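
The reworked sanity test runs on whichever device holds the packed weights and degrades gracefully when the identity matrix does not fit in VRAM. The same guard pattern, reduced to a self-contained helper; the name and signature are illustrative, and forward(..., force_cuda = True) mirrors the call used in the diff:

import torch

def max_abs_diff_or_skip(linear, reference, device):
    # Compare a reconstructed linear module against a reference weight via an
    # identity forward pass; on OOM, skip the check instead of aborting the job.
    try:
        ident = torch.eye(linear.in_features, dtype = torch.half, device = device)
        return (linear.forward(ident, force_cuda = True) - reference).abs().max().item()
    except torch.cuda.OutOfMemoryError:
        print(" !! Warning, not enough VRAM for sanity check, skipping")
        return 0.0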

convert.py (+12)

@@ -216,6 +216,12 @@ def save_job():
 
 if progress == "measure_quant":
     print(f" -- Measuring quantization impact...")
+
+    model.unload()
+    config.max_output_len = 16
+    model = ExLlamaV2(config)
+    model.load(lazy = True)
+
     status = measure_quant(job, save_job, model) # capturing the graceful exits
     if status == "interrupted":
         print("Process interrupted. Exiting gracefully.")

@@ -227,6 +233,12 @@ def save_job():
     job["progress"] = "finished"
     save_job()
 
+    model.unload()
+    config.max_output_len = None
+    model = ExLlamaV2(config)
+    model.load(lazy = True)
+
+
 if progress == "optimize":
 
     print(f" -- Optimizing...")
