from __future__ import annotations
- import sys, os
+ import sys
+ import os
+ import argparse
+ import subprocess
+ import torch
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from human_eval.data import write_jsonl, read_problems
from exllamav2 import model_init
from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_Q6, ExLlamaV2Cache_Q8
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
- import argparse, contextlib, subprocess
import util

# Args

- parser = argparse.ArgumentParser(description = "Run HumanEval evaluation on EXL2 model")
- parser.add_argument("-o", "--output", type = str, help = "Output .jsonl filename", required = True)
- parser.add_argument("-cs", "--cache_size", type = int, default = None)
- parser.add_argument("-spt", "--samples_per_task", type = int, default = 200)
- parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache")
- parser.add_argument("-cq6", "--cache_q6", action = "store_true", help = "Use Q6 cache")
- parser.add_argument("-cq8", "--cache_q8", action = "store_true", help = "Use Q8 cache")
- parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens for each completion")
- parser.add_argument("-pf", "--prompt_format", type = str, help = "Instruct format to apply. Default is raw completion (for base models) ")
- parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating")
- parser.add_argument("-e", "--eval", action = "store_true", help = "Run evaluation script on output file after sampling")
- parser.add_argument("-temp", "--temperature", type = float, help = "Sampling temperature (0 for greedy), default: 0.6")
+ parser = argparse.ArgumentParser(description = "Run HumanEval evaluation on EXL2 model")
+ parser.add_argument("-o", "--output", type = str, help = "Output .jsonl filename", required = True)
+ parser.add_argument("-cs", "--cache_size", type = int, default = None)
+ parser.add_argument("-spt", "--samples_per_task", type = int, default = 200)
+ parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache")
+ parser.add_argument("-cq6", "--cache_q6", action = "store_true", help = "Use Q6 cache")
+ parser.add_argument("-cq8", "--cache_q8", action = "store_true", help = "Use Q8 cache")
+ parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens for each completion")
+ parser.add_argument("-pf", "--prompt_format", type = str,
+                     help = "Instruct format to apply. Default is raw completion (for base models)")
+ parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating")
+ parser.add_argument("-e", "--eval", action = "store_true", help = "Run evaluation script on output file after sampling")
+ parser.add_argument("-temp", "--temperature", type = float, default = 0.6, help = "Sampling temperature (0 for greedy)")
+ parser.add_argument("-bs", "--batch_size", type = int, default = 50, help = "Number of problems to process in each batch")
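+ # model_init.add_args adds exllamav2's shared model-loading arguments to this parser; model_init.init consumes them below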
model_init.add_args(parser)
args = parser.parse_args()

...

# Prompt formats

prompt_formats = {
-     "raw": (
-         "```python\n{{problem}} ",
-         " "
-     ),
+     "raw": ("```python\n{{problem}} ", " "),
    "granite": (
        "Question:\nComplete the following Python function:\n\n{{problem}}\n\nAnswer:\n"
        "Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
        " "
    ),
    "llama": (
-         "[INST] <<SYS>>\n"
-         "You are a helpful AI coding assistant.\n"
-         "<</SYS>>\n\n"
-         "Complete the following Python function:\n\n"
-         "{{problem}} [/INST] "
+         "[INST] <<SYS>>\nYou are a helpful AI coding assistant.\n<</SYS>>\n\n"
+         "Complete the following Python function:\n\n{{problem}} [/INST] "
        "Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
        " "
    ),
    "llama3": (
-         "<|start_header_id|>system<|end_header_id|>\n\n"
-         "You are a helpful AI coding assistant.<|eot_id|>"
-         "<|start_header_id|>user<|end_header_id|>\n\n"
-         "Complete the following Python function:\n\n{{problem}}<|eot_id|>"
-         "<|start_header_id|>assistant<|end_header_id|>\n\n"
-         "Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
+         "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI coding assistant.<|eot_id|>"
+         "<|start_header_id|>user<|end_header_id|>\n\nComplete the following Python function:\n\n{{problem}}<|eot_id|>"
+         "<|start_header_id|>assistant<|end_header_id|>\n\nSure! Here is how you might implement the function:\n\n```python\n{{problem}}",
        " "
    ),
    "gemma": (
-         "<bos><start_of_turn>user\n"
-         "Complete the following Python function:\n\n{{problem}}<|eot_id|>"
-         "<start_of_turn>model\n"
-         "```python\n{{problem}}",
+         "<bos><start_of_turn>user\nComplete the following Python function:\n\n{{problem}}<|eot_id|>"
+         "<start_of_turn>model\n```python\n{{problem}}",
        " "
    )
}

...

model_init.print_options(args)
model, tokenizer = model_init.init(
    args,
-     allow_auto_split = True,
-     progress = True,
-     max_output_len = 4,
-     max_input_len = 2048
+     allow_auto_split = True,
+     progress = True,
+     max_output_len = 4,
+     max_input_len = 2048
)

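+ # Choose the KV cache implementation: quantized Q4/Q6/Q8 variants if requested, otherwise the unquantized default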
- if args.cache_q4: cache_type = ExLlamaV2Cache_Q4
- elif args.cache_q6: cache_type = ExLlamaV2Cache_Q6
- elif args.cache_q8: cache_type = ExLlamaV2Cache_Q8
- else: cache_type = ExLlamaV2Cache
+ if args.cache_q4:
+     cache_type = ExLlamaV2Cache_Q4
+ elif args.cache_q6:
+     cache_type = ExLlamaV2Cache_Q6
+ elif args.cache_q8:
+     cache_type = ExLlamaV2Cache_Q8
+ else:
+     cache_type = ExLlamaV2Cache
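+ # Allocate the cache (lazily if the model has not been loaded yet), sized to --cache_size or the model's native max_seq_len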
cache = cache_type(
    model,
-     lazy = not model.loaded,
-     max_seq_len = args.cache_size or model.config.max_seq_len
+     lazy = not model.loaded,
+     max_seq_len = args.cache_size or model.config.max_seq_len
)

if not model.loaded:
-     model.load_autosplit(cache, progress = True)
+     model.load_autosplit(cache, progress = True)

# Generator

generator = ExLlamaV2DynamicGenerator(
-     model = model,
-     cache = cache,
-     tokenizer = tokenizer,
-     max_batch_size = 256,
-     max_q_size = 4
+     model = model,
+     cache = cache,
+     tokenizer = tokenizer,
+     max_batch_size = 256,
+     max_q_size = 4
)

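+ # Sampling settings; temperature comes from --temperature (0 selects greedy decoding, per the argument help)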
gen_settings = ExLlamaV2Sampler.Settings(
-     token_repetition_penalty = 1.0,
-     temperature = 0.6,
-     top_k = 50,
-     top_p = 0.6
+     token_repetition_penalty = 1.0,
+     temperature = args.temperature,
+     top_k = 50,
+     top_p = 0.6
)

- # Get problems
-
- problems = read_problems()
- num_samples_per_task = args.samples_per_task
-
- # Create jobs

- with util.get_progress() as progress:
-
-     task1 = progress.add_task("[red]Sample", total = len(problems) * num_samples_per_task, name = "Creating sample jobs")
-     for problem_id, problem in problems.items():
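+ # Enqueues batch_size completions per problem in batch_problems and drains the generator before returning the collected samples.
+ # Note: batch_size here is the per-problem sample count (samples_per_task); problems are chunked by --batch_size in the main loop below.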
+ def process_batch(batch_problems, batch_size, progress, sample_task, generate_task):
+     samples = []

+     for problem_id, problem in batch_problems.items():
        b_problem = problem["prompt"]
        f_problem = prompt_format.replace("{{problem}}", b_problem)
        input_ids = tokenizer.encode(f_problem, encode_special_tokens = True, add_bos = True)

-         for s in range(num_samples_per_task):
-
+         for s in range(batch_size):
            job = ExLlamaV2DynamicJob(
-                 input_ids = input_ids,
-                 gen_settings = gen_settings,
-                 max_new_tokens = args.max_tokens,
-                 stop_conditions = [tokenizer.eos_token_id],
-                 token_healing = True,
-                 identifier = (problem_id, s),
-                 min_new_tokens = 6
+                 input_ids = input_ids,
+                 gen_settings = gen_settings,
+                 max_new_tokens = args.max_tokens,
+                 stop_conditions = [tokenizer.eos_token_id],
+                 token_healing = True,
+                 identifier = (problem_id, s),
+                 min_new_tokens = 6
            )
-
            generator.enqueue(job)
-             progress.update(task1, advance = 1)
-
- # Collect samples here
-
- samples = []
-
- # Work
-
- total_jobs = generator.num_remaining_jobs()
- cm = contextlib.nullcontext() if args.verbose else util.get_progress()
- with cm as progress:
-
-     if not args.verbose:
-         task1 = progress.add_task("[red]Sample", total = total_jobs, name = "Generating samples")
+             progress.update(sample_task, advance = 1)

    while generator.num_remaining_jobs():
-
        results = generator.iterate()
        for result in results:
-
-             # End sample if generator says EOS or if there is a non-indented line at the end of the output
-
            job = result["job"]
            eos = False
            completion = job.full_completion

...

                eos = True
            eos = eos or result["eos"]

-             # Collect completed sample
-
            if eos:
                identifier = result["identifier"]
-                 sample = problems[identifier[0]]["prompt"] + prefix + completion.strip()
+                 sample = batch_problems[identifier[0]]["prompt"] + prefix + completion.strip()
                if not result["eos"]:
                    generator.cancel(job)

                if args.verbose:
                    print("----------------------------------------------------------------------")
-                     print(f" ** Problem {identifier[0]}, sample {identifier[1] + 1} / {num_samples_per_task}")
+                     print(f" ** Problem {identifier[0]}, sample {identifier[1] + 1} / {batch_size}")
                    print("----------------------------------------------------------------------")
                    print(sample)
                    print()
                else:
-                     progress.update(task1, advance = 1)
+                     progress.update(generate_task, advance = 1)

-                 samples.append(dict(task_id = identifier[0], completion = prefix + completion.strip()))
+                 samples.append(dict(task_id = identifier[0], completion = prefix + completion.strip()))

- # Save output
+     return samples
+
+
+ # Main execution
+ problems = read_problems()
+ all_samples = []
+ batch_size = args.batch_size
+ total_samples = len(problems) * args.samples_per_task
+
+ with util.get_progress() as progress:
+     sample_task = progress.add_task("[red]Sample", total = total_samples, name = "Creating sample jobs")
+     generate_task = progress.add_task("[green]Sample", total = total_samples, name = "Generating samples")
+
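+     # Process the problems in chunks of --batch_size so that only one chunk's jobs are queued on the generator at a time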
+     for i in range(0, len(problems), batch_size):
+         batch_problems = dict(list(problems.items())[i:i + batch_size])
+         batch_samples = process_batch(batch_problems, args.samples_per_task, progress, sample_task, generate_task)
+         all_samples.extend(batch_samples)

+ # Optional: Clear CUDA cache to free up memory
+ if torch.cuda.is_available():
+     torch.cuda.empty_cache()
+
+ # Save output
print(f" -- Saving: {args.output}")
- write_jsonl(args.output, samples)
+ write_jsonl(args.output, all_samples)

# Optionally launch eval script
-
if args.eval:
-     subprocess.run(["evaluate_functional_correctness", args.output])
-
+     subprocess.run(["evaluate_functional_correctness", args.output])