1
1
import itertools
2
+ from pydantic import model_validator
2
3
from pydantic_config import BaseConfig , parse_argv
3
4
import sglang as sgl
4
5
from datasets import load_dataset
12
13
13
14
class Config(BaseConfig):
    """Configuration for the batched response-generation script.

    Validated by pydantic; ``check_batch_size`` enforces the size ordering
    ``batch_size <= sample_per_file <= max_samples`` (when set) so that
    output files can be filled from whole batches.
    """

    # HF model id used for generation.
    name_model: str = "Qwen/QwQ-32B-Preview"
    # Number of responses sampled per input question.
    num_responses_per_question: int = 1
    # Number of GPUs to shard the model across.
    num_gpus: int = 8
    # Sampling temperature passed to the generator.
    temperature: float = 0.9
    # Number of samples generated per batch.
    batch_size: int = 10_000
    # Optional cap on total samples; None means use the full dataset.
    max_samples: int | None = None
    # Optional GCP bucket; if provided, each output file of
    # `sample_per_file` samples is uploaded there.
    gcp_bucket: str | None = None
    # Number of samples written to each output file.
    sample_per_file: int = 10_000

    @model_validator(mode="after")
    def check_batch_size(self):
        """Validate the relative ordering of the sizing parameters.

        Raises:
            ValueError: if ``sample_per_file`` is smaller than ``batch_size``,
                or if ``max_samples`` (when set) is smaller than
                ``sample_per_file``.
        """
        if self.sample_per_file < self.batch_size:
            raise ValueError("sample_per_file must be greater than or equal to batch_size")
        if self.max_samples is not None and self.max_samples < self.sample_per_file:
            raise ValueError("max_samples must be greater than or equal to sample_per_file")
        return self
22
31
23
32
24
33
def main (config : Config ):
@@ -33,10 +42,11 @@ def main(config: Config):
33
42
34
43
sampling_params = dict (temperature = config .temperature , max_new_tokens = 8192 , stop = ["<|eot_id|>" ])
35
44
36
- open (config .out_file_name , "w" ).close ()
37
-
38
45
max_samples = config .max_samples if config .max_samples is not None else len (math_dataset )
39
46
47
+ all_results = []
48
+ file_counter = 0
49
+
40
50
for i in tqdm (range (0 , min (max_samples , len (math_dataset )), config .batch_size ), desc = "Generating data" ):
41
51
batch = math_dataset [i : min (i + config .batch_size , len (math_dataset ))]
42
52
batch_ids = list (
@@ -54,7 +64,6 @@ def main(config: Config):
54
64
batch_inputs = tokenizer .apply_chat_template (batch_messages , tokenize = False , add_generation_prompt = True )
55
65
batch_output = llm .generate (batch_inputs , sampling_params )
56
66
57
- all_results = []
58
67
for j , out in enumerate (batch_output ):
59
68
result = dict ()
60
69
result ["prompt" ] = batch_messages [j ][1 ]["content" ]
@@ -64,7 +73,11 @@ def main(config: Config):
64
73
65
74
all_results .append (result )
66
75
67
- save_batch_results (all_results , config .out_file_name , gcp_bucket )
76
+ if len (all_results ) >= config .sample_per_file :
77
+ file_name = f"out_{ file_counter } .jsonl"
78
+ save_batch_results (all_results , file_name , gcp_bucket )
79
+ all_results = []
80
+ file_counter += 1
68
81
69
82
70
83
if __name__ == "__main__" :
0 commit comments