
Commit 56ae879

Merge branch 'refs/heads/202405-cached-states' into dev
2 parents: f0b81bb + 0ece2f3

File tree

3 files changed: +20 −17 lines


conversion/measure.py (+15 −13)
@@ -125,18 +125,18 @@ def test_quant(source: ExLlamaV2Linear,


 def test_error(module, hidden_states, target_states, cache, attn_params):

-    rfn_sum = 0
+    rfn_sum = torch.tensor(0.0).cuda()
     rfn_count = 0
     for x, xref in zip(hidden_states, target_states):
         x = x.cuda()
         xref = xref.cuda()
         xtest = module.forward(x, cache, attn_params)
         xtest = xtest[0].float()
         xref = xref[0].float()
-        rfn_sum += (torch.linalg.norm(xtest - xref, 'fro') / torch.linalg.norm(xref, 'fro')).item()
+        rfn_sum += torch.linalg.norm(xtest - xref, 'fro') / torch.linalg.norm(xref, 'fro')
         rfn_count += 1

-    return max(1e-6, 1 - (rfn_sum / rfn_count))
+    return max(1e-6, 1 - (rfn_sum.item() / rfn_count))


 def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_params, keep_q = False):
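
For context, test_error computes the mean relative Frobenius-norm error between the quantized and reference hidden states; the change keeps the running sum as a CUDA tensor, so the only host/device synchronization is the final .item() call rather than one per cached state. A minimal standalone sketch of the same pattern, assuming a CUDA device is available (the function name and random inputs are illustrative, not from the repository):

    import torch

    def avg_rfn_error(test_states, ref_states):
        # Accumulate ||x_test - x_ref||_F / ||x_ref||_F on the GPU;
        # a Python float accumulator would force a sync on every iteration.
        rfn_sum = torch.tensor(0.0).cuda()
        rfn_count = 0
        for xtest, xref in zip(test_states, ref_states):
            xtest = xtest.cuda().float()
            xref = xref.cuda().float()
            rfn_sum += torch.linalg.norm(xtest - xref, 'fro') / torch.linalg.norm(xref, 'fro')
            rfn_count += 1
        return rfn_sum.item() / rfn_count  # single sync point

    # Illustrative usage with random 2-D states:
    ref = [torch.randn(64, 64) for _ in range(4)]
    test = [x + 0.01 * torch.randn_like(x) for x in ref]
    print(avg_rfn_error(test, ref))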
@@ -382,7 +382,7 @@ def print_status_box(*content_lines):
     print('-' * box_width)

 @torch.inference_mode()
-def measure_quant(job, save_fn, model):
+def measure_quant(job, save_fn, model, hidden_state_offload_layers):

     # vars for status box
     time_spent_list = []
@@ -418,8 +418,9 @@ def measure_quant(job, save_fn, model):

     hidden_states = []
     with safe_open(states_filename, framework = "pt", device = "cpu") as f:
-        for k in sorted(f.keys()):
-            hidden_states.append(f.get_tensor(k))
+        for i, k in enumerate(sorted(f.keys())):
+            t = f.get_tensor(k)
+            hidden_states.append(t.to("cuda:0") if i < hidden_state_offload_layers else t)

     index = job["last_module_idx"]
     while True:
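
The new hidden_state_offload_layers parameter controls placement of the cached states as they are loaded: the first N states are moved to cuda:0 and the rest stay on the CPU. A minimal sketch of that placement rule (the helper name is hypothetical):

    def place_states(tensors, offload_layers):
        # Keep the first `offload_layers` states resident in VRAM; leave the rest on the CPU.
        return [t.to("cuda:0") if i < offload_layers else t for i, t in enumerate(tensors)]

With the default of 0 (see convert.py below), every state stays on the CPU and the previous behaviour is preserved.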
@@ -515,18 +516,19 @@ def measure_quant(job, save_fn, model):

            x = hidden_states[i].to("cuda:0")
            outputs = module.forward(x, cache, attn_params, intermediates = True)
+           target_device = "cuda:0" if i < hidden_state_offload_layers else "cpu"

            # Hessians

            if mode == "self_attn":
                quantizers["q_proj"].add_batch(outputs["post_norm"])  # Reuse H for K and V
                quantizers["o_proj"].add_batch(outputs["attn_output"])
-               target_states.append(outputs["hidden_states"].to("cpu"))
+               target_states.append(outputs["hidden_states"].to(target_device))

            if mode == "mlp":
                quantizers["up_proj"].add_batch(outputs["post_norm"])  # Reuse H for gate_proj
                quantizers["down_proj"].add_batch(outputs["pre_down"])
-               target_states.append(outputs["hidden_states"].to("cpu"))
+               target_states.append(outputs["hidden_states"].to(target_device))

            if mode == "block_sparse_moe":
                for j in range(model.config.num_experts):
@@ -537,19 +539,19 @@ def measure_quant(job, save_fn, model):
                        uncalibrated_experts[j] += 1
                    else:
                        uncalibrated_experts[j] += 1
-               target_states.append(outputs["hidden_states"].to("cpu"))
+               target_states.append(outputs["hidden_states"].to(target_device))

            if mode == "parallel_decoder":
                quantizers["q_proj"].add_batch(outputs["post_norm"])  # Reuse H for K, V, up_proj and gate_proj
                quantizers["o_proj"].add_batch(outputs["attn_output"])
                quantizers["down_proj"].add_batch(outputs["pre_down"])
                hidden_states[i] = outputs["post_norm"]
-               target_states_attn.append(outputs["hidden_states_attn"].to("cpu"))
-               target_states_mlp.append(outputs["hidden_states_mlp"].to("cpu"))
-               target_states.append(outputs["hidden_states"].to("cpu"))
+               target_states_attn.append(outputs["hidden_states_attn"].to(target_device))
+               target_states_mlp.append(outputs["hidden_states_mlp"].to(target_device))
+               target_states.append(outputs["hidden_states"].to(target_device))

            if mode == "pos_emb":
-               target_states.append(outputs["hidden_states"].to("cpu"))
+               target_states.append(outputs["hidden_states"].to(target_device))

            # For MoE layers, warn if any layer received less than 10% of a calibration batch
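
Note that test_error (above) still calls .cuda() / .to("cuda:0") on each state before use; for states already kept in VRAM via target_device this returns the tensor unchanged, so the mixed CPU/GPU placement is transparent to the rest of the measurement code and only changes how much host-to-device copying happens per pass.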

conversion/quantize.py (+3 −3)
@@ -439,7 +439,7 @@ def quant(job, save_fn, model):
         cal_ids = f.get_tensor("input_ids")
     module.linear.weight.data = module.linear.weight.data.to("cuda:0")

-    rfn_sum = 0
+    rfn_sum = torch.tensor(0.0).cuda()
     rfn_count = 0
     logprob_sum = 0.0
     logprob_count = 0
@@ -458,7 +458,7 @@ def quant(job, save_fn, model):
         output_ref = target_states[i].to("cuda:0")
         output_ref = output_ref[0].float()

-        rfn_sum += (torch.linalg.norm(output - output_ref, 'fro') / torch.linalg.norm(output_ref, 'fro')).item()
+        rfn_sum += torch.linalg.norm(output - output_ref, 'fro') / torch.linalg.norm(output_ref, 'fro')
         rfn_count += 1

         output_ref = None
@@ -485,7 +485,7 @@ def quant(job, save_fn, model):

     if mode != "linear":

-        err = rfn_sum / rfn_count
+        err = rfn_sum.item() / rfn_count
         print(f" -- Module quantized, rfn_error: {err:1.6f}")

     else:

convert.py (+2 −1)
@@ -29,6 +29,7 @@
 parser.add_argument("-l", "--length", type = int, default = 2048, help = "Max no. tokens per sample")
 parser.add_argument("-ml", "--measurement_length", type = int, default = 2048, help = "Max no. tokens per sample when measuring")
 parser.add_argument("-so", "--status_output", action = "store_true", help = "Include machine-parseable status updates in console output")
+parser.add_argument("-hsol", "--hidden_state_offload_layers", type = int, default = 0, help = "Number of hidden/target states to keep in VRAM. Speed-up but increases VRAM usage")

 args = parser.parse_args()

@@ -242,7 +243,7 @@ def save_job():
     model = ExLlamaV2(config)
     model.load(lazy = True)

-    status = measure_quant(job, save_job, model) # capturing the graceful exits
+    status = measure_quant(job, save_job, model, args.hidden_state_offload_layers) # capturing the graceful exits
     if status == "interrupted":
         print("Process interrupted. Exiting gracefully.")
         save_job()
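
As a usage sketch, the new option is passed alongside the script's existing conversion arguments; the paths and the value below are placeholders, not from the commit, and the -i/-o input and working-directory options are assumed from the existing convert.py interface:

    python convert.py -i /models/my-model -o /tmp/exl2-work -hsol 8

Per the help string, this keeps the first 8 layers' hidden/target states in VRAM during measurement, trading extra VRAM for speed; the default of 0 keeps the previous all-CPU behaviour.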
