
Commit d3e46da

Add batch processing to prevent OOM errors
1 parent 57ee846 commit d3e46da

2 files changed: +241 -261 lines changed

eval/humaneval.py (+92 -102)
@@ -1,27 +1,33 @@
 from __future__ import annotations
-import sys, os
+import sys
+import os
+import argparse
+import subprocess
+import torch
+
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from human_eval.data import write_jsonl, read_problems
 from exllamav2 import model_init
 from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_Q6, ExLlamaV2Cache_Q8
 from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
-import argparse, contextlib, subprocess
 import util
 
 # Args
 
-parser = argparse.ArgumentParser(description = "Run HumanEval evaluation on EXL2 model")
-parser.add_argument("-o", "--output", type = str, help = "Output .jsonl filename", required = True)
-parser.add_argument("-cs", "--cache_size", type = int, default = None)
-parser.add_argument("-spt", "--samples_per_task", type = int, default = 200)
-parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache")
-parser.add_argument("-cq6", "--cache_q6", action = "store_true", help = "Use Q6 cache")
-parser.add_argument("-cq8", "--cache_q8", action = "store_true", help = "Use Q8 cache")
-parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens for each completion")
-parser.add_argument("-pf", "--prompt_format", type = str, help = "Instruct format to apply. Default is raw completion (for base models) ")
-parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating")
-parser.add_argument("-e", "--eval", action = "store_true", help = "Run evaluation script on output file after sampling")
-parser.add_argument("-temp", "--temperature", type = float, help = "Sampling temperature (0 for greedy), default: 0.6")
+parser = argparse.ArgumentParser(description="Run HumanEval evaluation on EXL2 model")
+parser.add_argument("-o", "--output", type=str, help="Output .jsonl filename", required=True)
+parser.add_argument("-cs", "--cache_size", type=int, default=None)
+parser.add_argument("-spt", "--samples_per_task", type=int, default=200)
+parser.add_argument("-cq4", "--cache_q4", action="store_true", help="Use Q4 cache")
+parser.add_argument("-cq6", "--cache_q6", action="store_true", help="Use Q6 cache")
+parser.add_argument("-cq8", "--cache_q8", action="store_true", help="Use Q8 cache")
+parser.add_argument("--max_tokens", type=int, default=768, help="Max number of tokens for each completion")
+parser.add_argument("-pf", "--prompt_format", type=str,
+                    help="Instruct format to apply. Default is raw completion (for base models)")
+parser.add_argument("-v", "--verbose", action="store_true", help="Spam completions to console while generating")
+parser.add_argument("-e", "--eval", action="store_true", help="Run evaluation script on output file after sampling")
+parser.add_argument("-temp", "--temperature", type=float, default=0.6, help="Sampling temperature (0 for greedy)")
+parser.add_argument("-bs", "--batch_size", type=int, default=50, help="Number of problems to process in each batch")
 model_init.add_args(parser)
 args = parser.parse_args()
 
@@ -37,38 +43,27 @@
 # Prompt formats
 
 prompt_formats = {
-    "raw": (
-        "```python\n{{problem}} ",
-        " "
-    ),
+    "raw": ("```python\n{{problem}} ", " "),
     "granite": (
         "Question:\nComplete the following Python function:\n\n{{problem}}\n\nAnswer:\n"
         "Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
         " "
     ),
     "llama": (
-        "[INST] <<SYS>>\n"
-        "You are a helpful AI coding assistant.\n"
-        "<</SYS>>\n\n"
-        "Complete the following Python function:\n\n"
-        "{{problem}} [/INST] "
+        "[INST] <<SYS>>\nYou are a helpful AI coding assistant.\n<</SYS>>\n\n"
+        "Complete the following Python function:\n\n{{problem}} [/INST] "
         "Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
         " "
     ),
     "llama3": (
-        "<|start_header_id|>system<|end_header_id|>\n\n"
-        "You are a helpful AI coding assistant.<|eot_id|>"
-        "<|start_header_id|>user<|end_header_id|>\n\n"
-        "Complete the following Python function:\n\n{{problem}}<|eot_id|>"
-        "<|start_header_id|>assistant<|end_header_id|>\n\n"
-        "Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
+        "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI coding assistant.<|eot_id|>"
+        "<|start_header_id|>user<|end_header_id|>\n\nComplete the following Python function:\n\n{{problem}}<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\nSure! Here is how you might implement the function:\n\n```python\n{{problem}}",
         " "
     ),
     "gemma": (
-        "<bos><start_of_turn>user\n"
-        "Complete the following Python function:\n\n{{problem}}<|eot_id|>"
-        "<start_of_turn>model\n"
-        "```python\n{{problem}}",
+        "<bos><start_of_turn>user\nComplete the following Python function:\n\n{{problem}}<|eot_id|>"
+        "<start_of_turn>model\n```python\n{{problem}}",
         " "
     )
 }
@@ -88,93 +83,71 @@
 model_init.print_options(args)
 model, tokenizer = model_init.init(
     args,
-    allow_auto_split = True,
-    progress = True,
-    max_output_len = 4,
-    max_input_len = 2048
+    allow_auto_split=True,
+    progress=True,
+    max_output_len=4,
+    max_input_len=2048
 )
 
-if args.cache_q4: cache_type = ExLlamaV2Cache_Q4
-elif args.cache_q6: cache_type = ExLlamaV2Cache_Q6
-elif args.cache_q8: cache_type = ExLlamaV2Cache_Q8
-else: cache_type = ExLlamaV2Cache
+if args.cache_q4:
+    cache_type = ExLlamaV2Cache_Q4
+elif args.cache_q6:
+    cache_type = ExLlamaV2Cache_Q6
+elif args.cache_q8:
+    cache_type = ExLlamaV2Cache_Q8
+else:
+    cache_type = ExLlamaV2Cache
 cache = cache_type(
     model,
-    lazy = not model.loaded,
-    max_seq_len = args.cache_size or model.config.max_seq_len
+    lazy=not model.loaded,
+    max_seq_len=args.cache_size or model.config.max_seq_len
 )
 
 if not model.loaded:
-    model.load_autosplit(cache, progress = True)
+    model.load_autosplit(cache, progress=True)
 
 # Generator
 
 generator = ExLlamaV2DynamicGenerator(
-    model = model,
-    cache = cache,
-    tokenizer = tokenizer,
-    max_batch_size = 256,
-    max_q_size = 4
+    model=model,
+    cache=cache,
+    tokenizer=tokenizer,
+    max_batch_size=256,
+    max_q_size=4
 )
 
 gen_settings = ExLlamaV2Sampler.Settings(
-    token_repetition_penalty = 1.0,
-    temperature = 0.6,
-    top_k = 50,
-    top_p = 0.6
+    token_repetition_penalty=1.0,
+    temperature=args.temperature,
+    top_k=50,
+    top_p=0.6
 )
 
-# Get problems
-
-problems = read_problems()
-num_samples_per_task = args.samples_per_task
-
-# Create jobs
 
-with util.get_progress() as progress:
-
-    task1 = progress.add_task("[red]Sample", total = len(problems) * num_samples_per_task, name = "Creating sample jobs")
-    for problem_id, problem in problems.items():
+def process_batch(batch_problems, batch_size, progress, sample_task, generate_task):
+    samples = []
 
+    for problem_id, problem in batch_problems.items():
         b_problem = problem["prompt"]
         f_problem = prompt_format.replace("{{problem}}", b_problem)
         input_ids = tokenizer.encode(f_problem, encode_special_tokens=True, add_bos=True)
 
-        for s in range(num_samples_per_task):
-
+        for s in range(batch_size):
             job = ExLlamaV2DynamicJob(
-                input_ids = input_ids,
-                gen_settings = gen_settings,
-                max_new_tokens = args.max_tokens,
-                stop_conditions = [tokenizer.eos_token_id],
-                token_healing = True,
-                identifier = (problem_id, s),
-                min_new_tokens = 6
+                input_ids=input_ids,
+                gen_settings=gen_settings,
+                max_new_tokens=args.max_tokens,
+                stop_conditions=[tokenizer.eos_token_id],
+                token_healing=True,
+                identifier=(problem_id, s),
+                min_new_tokens=6
             )
-
             generator.enqueue(job)
-            progress.update(task1, advance = 1)
-
-# Collect samples here
-
-samples = []
-
-# Work
-
-total_jobs = generator.num_remaining_jobs()
-cm = contextlib.nullcontext() if args.verbose else util.get_progress()
-with cm as progress:
-
-    if not args.verbose:
-        task1 = progress.add_task("[red]Sample", total = total_jobs, name = "Generating samples")
+            progress.update(sample_task, advance=1)
 
     while generator.num_remaining_jobs():
-
         results = generator.iterate()
         for result in results:
-
-            # End sample if generator says EOS or if there is a non-indented line at the end of the output
-
             job = result["job"]
             eos = False
             completion = job.full_completion
@@ -186,32 +159,49 @@
                     eos = True
             eos = eos or result["eos"]
 
-            # Collect completed sample
-
             if eos:
                 identifier = result["identifier"]
-                sample = problems[identifier[0]]["prompt"] + prefix + completion.strip()
+                sample = batch_problems[identifier[0]]["prompt"] + prefix + completion.strip()
                 if not result["eos"]:
                     generator.cancel(job)
 
                 if args.verbose:
                     print("----------------------------------------------------------------------")
-                    print(f" ** Problem {identifier[0]}, sample {identifier[1] + 1} / {num_samples_per_task}")
+                    print(f" ** Problem {identifier[0]}, sample {identifier[1] + 1} / {batch_size}")
                     print("----------------------------------------------------------------------")
                     print(sample)
                     print()
                 else:
-                    progress.update(task1, advance = 1)
+                    progress.update(generate_task, advance=1)
 
-                samples.append(dict(task_id = identifier[0], completion = prefix + completion.strip()))
+                samples.append(dict(task_id=identifier[0], completion=prefix + completion.strip()))
 
-# Save output
+    return samples
+
+
+# Main execution
+problems = read_problems()
+all_samples = []
+batch_size = args.batch_size
+total_samples = len(problems) * args.samples_per_task
+
+with util.get_progress() as progress:
+    sample_task = progress.add_task("[red]Sample", total=total_samples, name="Creating sample jobs")
+    generate_task = progress.add_task("[green]Sample", total=total_samples, name="Generating samples")
+
+    for i in range(0, len(problems), batch_size):
+        batch_problems = dict(list(problems.items())[i:i + batch_size])
+        batch_samples = process_batch(batch_problems, args.samples_per_task, progress, sample_task, generate_task)
+        all_samples.extend(batch_samples)
 
+        # Optional: Clear CUDA cache to free up memory
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+# Save output
 print(f" -- Saving: {args.output}")
-write_jsonl(args.output, samples)
+write_jsonl(args.output, all_samples)
 
 # Optionally launch eval script
-
 if args.eval:
-    subprocess.run(["evaluate_functional_correctness", args.output])
-
+    subprocess.run(["evaluate_functional_correctness", args.output])
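
For readers skimming the diff, here is a minimal, self-contained sketch of the batching pattern this commit introduces: the problem set is sliced into chunks of --batch_size, each chunk is fully generated before the next begins, and torch.cuda.empty_cache() is called between chunks. The process_batch stub and the dummy problem dict below are illustrative placeholders, not the real exllamav2 generation code shown in the diff above.

# Illustrative sketch only: process_batch here stands in for the real function,
# which enqueues ExLlamaV2DynamicJob instances and drains them with generator.iterate().
import torch

def process_batch(batch_problems, samples_per_task):
    # Placeholder generation: one dummy completion per (task, sample) pair
    return [
        dict(task_id=task_id, completion="    pass")
        for task_id in batch_problems
        for _ in range(samples_per_task)
    ]

problems = {f"HumanEval/{i}": {"prompt": f"def task_{i}():\n"} for i in range(164)}  # dummy data
batch_size = 50          # mirrors the new --batch_size default
samples_per_task = 200   # mirrors the --samples_per_task default
all_samples = []

for i in range(0, len(problems), batch_size):
    # Slice the next chunk of problems, keeping their task ids
    batch_problems = dict(list(problems.items())[i:i + batch_size])
    all_samples.extend(process_batch(batch_problems, samples_per_task))

    # Release cached GPU memory between batches to keep peak usage bounded
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print(len(all_samples))  # 164 * 200 completions, collected batch by batch

Note that torch.cuda.empty_cache() only returns unused blocks held by PyTorch's caching allocator; the main relief comes from enqueueing at most one batch's worth of jobs at a time instead of all problems at once.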
