
Commit de29076

add sample_per_file
1 parent 780551a commit de29076

File tree

1 file changed: +20 -7 lines


src/genesys/generate.py

+20 -7
@@ -1,4 +1,5 @@
 import itertools
+from pydantic import model_validator
 from pydantic_config import BaseConfig, parse_argv
 import sglang as sgl
 from datasets import load_dataset
@@ -12,13 +13,21 @@
 
 class Config(BaseConfig):
     name_model: str = "Qwen/QwQ-32B-Preview"
-    out_file_name: str = "out.jsonl"
     num_responses_per_question: int = 1
     num_gpus: int = 8
     temperature: float = 0.9
-    batch_size: int = 10000
+    batch_size: int = 10_000
     max_samples: int | None = None
-    gcp_bucket: str | None = None
+    gcp_bucket: str | None = None  # optional; if provided, each output file of sample_per_file samples is also uploaded to GCP
+    sample_per_file: int = 10_000  # how many samples each output file contains
+
+    @model_validator(mode="after")
+    def check_batch_size(self):
+        if self.sample_per_file < self.batch_size:
+            raise ValueError("sample_per_file must be greater than or equal to batch_size")
+        if self.max_samples is not None and self.max_samples < self.sample_per_file:
+            raise ValueError("max_samples must be greater than or equal to sample_per_file")
+        return self
 
 
 def main(config: Config):
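
The new @model_validator runs after field parsing, so an inconsistent configuration fails immediately instead of partway through generation. Below is a minimal sketch of the same pattern, assuming a plain pydantic BaseModel as a stand-in for pydantic_config's BaseConfig (ShardConfig and its defaults are hypothetical):

# Minimal sketch; ShardConfig is a hypothetical stand-in for the Config above,
# using pydantic.BaseModel instead of pydantic_config.BaseConfig.
from pydantic import BaseModel, ValidationError, model_validator


class ShardConfig(BaseModel):
    batch_size: int = 10_000
    sample_per_file: int = 10_000
    max_samples: int | None = None

    @model_validator(mode="after")
    def check_batch_size(self):
        # mirrors the checks added in this commit
        if self.sample_per_file < self.batch_size:
            raise ValueError("sample_per_file must be greater than or equal to batch_size")
        if self.max_samples is not None and self.max_samples < self.sample_per_file:
            raise ValueError("max_samples must be greater than or equal to sample_per_file")
        return self


ShardConfig(batch_size=1_000, sample_per_file=10_000)  # accepted: 10_000 >= 1_000

try:
    ShardConfig(batch_size=10_000, sample_per_file=100)  # rejected: 100 < 10_000
except ValidationError as exc:
    print(exc)
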
@@ -33,10 +42,11 @@ def main(config: Config):
 
     sampling_params = dict(temperature=config.temperature, max_new_tokens=8192, stop=["<|eot_id|>"])
 
-    open(config.out_file_name, "w").close()
-
     max_samples = config.max_samples if config.max_samples is not None else len(math_dataset)
 
+    all_results = []
+    file_counter = 0
+
     for i in tqdm(range(0, min(max_samples, len(math_dataset)), config.batch_size), desc="Generating data"):
         batch = math_dataset[i : min(i + config.batch_size, len(math_dataset))]
         batch_ids = list(
@@ -54,7 +64,6 @@ def main(config: Config):
         batch_inputs = tokenizer.apply_chat_template(batch_messages, tokenize=False, add_generation_prompt=True)
         batch_output = llm.generate(batch_inputs, sampling_params)
 
-        all_results = []
         for j, out in enumerate(batch_output):
             result = dict()
             result["prompt"] = batch_messages[j][1]["content"]
@@ -64,7 +73,11 @@ def main(config: Config):
 
             all_results.append(result)
 
-        save_batch_results(all_results, config.out_file_name, gcp_bucket)
+        if len(all_results) >= config.sample_per_file:
+            file_name = f"out_{file_counter}.jsonl"
+            save_batch_results(all_results, file_name, gcp_bucket)
+            all_results = []
+            file_counter += 1
 
 
 if __name__ == "__main__":
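
Downstream, the loop no longer truncates a single out.jsonl up front; it accumulates results across batches and flushes a numbered shard (out_0.jsonl, out_1.jsonl, ...) whenever the buffer reaches sample_per_file. Below is a minimal sketch of that sharding pattern, with a hypothetical write_jsonl helper standing in for save_batch_results (whose body is not shown in this diff) and toy batches in place of model output:

# Minimal sketch of the sharded-output pattern; write_jsonl is a hypothetical
# stand-in for save_batch_results(all_results, file_name, gcp_bucket).
import json


def write_jsonl(results: list[dict], file_name: str) -> None:
    with open(file_name, "w") as f:
        for record in results:
            f.write(json.dumps(record) + "\n")


sample_per_file = 4  # tiny value for illustration; the commit defaults to 10_000
all_results: list[dict] = []
file_counter = 0

# toy "batches" standing in for llm.generate output
for batch in ([{"sample_id": i} for i in range(start, start + 2)] for start in range(0, 10, 2)):
    all_results.extend(batch)
    if len(all_results) >= sample_per_file:
        write_jsonl(all_results, f"out_{file_counter}.jsonl")  # out_0.jsonl, out_1.jsonl, ...
        all_results = []
        file_counter += 1

Because the validator keeps batch_size at or below sample_per_file, a flush empties the whole buffer, so each shard is written in a single call.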

0 comments