Skip to content

Commit 0ece2f3

Browse files
committed
add layer GPU offloading for hidden/target states
1 parent b428e23 commit 0ece2f3

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

conversion/measure.py

+12-10
Original file line number · Diff line number · Diff line change
@@ -382,7 +382,7 @@ def print_status_box(*content_lines):
382382
print('-' * box_width)
383383

384384
@torch.inference_mode()
385-
def measure_quant(job, save_fn, model):
385+
def measure_quant(job, save_fn, model, hidden_state_offload_layers):
386386

387387
# vars for status box
388388
time_spent_list = []
@@ -418,8 +418,9 @@ def measure_quant(job, save_fn, model):
418418

419419
hidden_states = []
420420
with safe_open(states_filename, framework = "pt", device = "cpu") as f:
421-
for k in sorted(f.keys()):
422-
hidden_states.append(f.get_tensor(k))
421+
for i, k in enumerate(sorted(f.keys())):
422+
t = f.get_tensor(k)
423+
hidden_states.append(t.to("cuda:0") if i < hidden_state_offload_layers else t)
423424

424425
index = job["last_module_idx"]
425426
while True:
@@ -515,18 +516,19 @@ def measure_quant(job, save_fn, model):
515516

516517
x = hidden_states[i].to("cuda:0")
517518
outputs = module.forward(x, cache, attn_params, intermediates = True)
519+
target_device = "cuda:0" if i < hidden_state_offload_layers else "cpu"
518520

519521
# Hessians
520522

521523
if mode == "self_attn":
522524
quantizers["q_proj"].add_batch(outputs["post_norm"]) # Reuse H for K and V
523525
quantizers["o_proj"].add_batch(outputs["attn_output"])
524-
target_states.append(outputs["hidden_states"].to("cpu"))
526+
target_states.append(outputs["hidden_states"].to(target_device))
525527

526528
if mode == "mlp":
527529
quantizers["up_proj"].add_batch(outputs["post_norm"]) # Reuse H for gate_proj
528530
quantizers["down_proj"].add_batch(outputs["pre_down"])
529-
target_states.append(outputs["hidden_states"].to("cpu"))
531+
target_states.append(outputs["hidden_states"].to(target_device))
530532

531533
if mode == "block_sparse_moe":
532534
for j in range(model.config.num_experts):
@@ -537,19 +539,19 @@ def measure_quant(job, save_fn, model):
537539
uncalibrated_experts[j] += 1
538540
else:
539541
uncalibrated_experts[j] += 1
540-
target_states.append(outputs["hidden_states"].to("cpu"))
542+
target_states.append(outputs["hidden_states"].to(target_device))
541543

542544
if mode == "parallel_decoder":
543545
quantizers["q_proj"].add_batch(outputs["post_norm"]) # Reuse H for K, V, up_proj and gate_proj
544546
quantizers["o_proj"].add_batch(outputs["attn_output"])
545547
quantizers["down_proj"].add_batch(outputs["pre_down"])
546548
hidden_states[i] = outputs["post_norm"]
547-
target_states_attn.append(outputs["hidden_states_attn"].to("cpu"))
548-
target_states_mlp.append(outputs["hidden_states_mlp"].to("cpu"))
549-
target_states.append(outputs["hidden_states"].to("cpu"))
549+
target_states_attn.append(outputs["hidden_states_attn"].to(target_device))
550+
target_states_mlp.append(outputs["hidden_states_mlp"].to(target_device))
551+
target_states.append(outputs["hidden_states"].to(target_device))
550552

551553
if mode == "pos_emb":
552-
target_states.append(outputs["hidden_states"].to("cpu"))
554+
target_states.append(outputs["hidden_states"].to(target_device))
553555

554556
# For MoE layers, warn if any layer received less than 10% of a calibration batch
555557

convert.py

+2-1
Original file line number · Diff line number · Diff line change
@@ -29,6 +29,7 @@
2929
parser.add_argument("-l", "--length", type = int, default = 2048, help = "Max no. tokens per sample")
3030
parser.add_argument("-ml", "--measurement_length", type = int, default = 2048, help = "Max no. tokens per sample when measuring")
3131
parser.add_argument("-so", "--status_output", action = "store_true", help = "Include machine-parseable status updates in console output")
32+
parser.add_argument("-hsol", "--hidden_state_offload_layers", type = int, default = 0, help = "Number of hidden/target states to keep in VRAM. Speed-up but increases VRAM usage")
3233

3334
args = parser.parse_args()
3435

@@ -242,7 +243,7 @@ def save_job():
242243
model = ExLlamaV2(config)
243244
model.load(lazy = True)
244245

245-
status = measure_quant(job, save_job, model) # capturing the graceful exits
246+
status = measure_quant(job, save_job, model, args.hidden_state_offload_layers) # capturing the graceful exits
246247
if status == "interrupted":
247248
print("Process interrupted. Exiting gracefully.")
248249
save_job()

0 commit comments

Comments (0)