Skip to content

Commit 6084145

Browse files
committed
add gcp push
1 parent c7f97b5 commit 6084145

File tree

4 files changed

+319
-88
lines changed

4 files changed

+319
-88
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies = [
1111
"tqdm",
1212
"antlr4-python3-runtime==4.11",
1313
"pydantic_config @ git+https://github.com/samsja/pydantic_config.git@74c94ee",
14+
"google-cloud-storage",
1415
]
1516

1617
[project.optional-dependencies]

src/genesys/generate.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,51 @@
11
import itertools
2+
from pydantic import model_validator
23
from pydantic_config import BaseConfig, parse_argv
34
import sglang as sgl
45
from datasets import load_dataset
56
from tqdm import tqdm
67
from transformers import AutoTokenizer
7-
8-
from genesys.utils import repeat_elements, save_batch_results
8+
from genesys.utils import GcpBucket, repeat_elements, save_batch_results
99

1010
SYSTEM_PROMPT = "Solve the following math problem efficiently and clearly. Think carefully and step by step about your response and reason before providing a final response. Conclude your response with: \n\nTherefore, the final answer is: $\\boxed{answer}$. I hope it is correct.\n\nWhere [answer] is just the final number or expression that solves the problem. If the question is a multiple choice question, [answer] should be the letter indicating your correct response (e.g. \\text{A} or \\text{B})."
1111

1212

1313
class Config(BaseConfig):
    """Command-line configuration for the generation run."""

    name_model: str = "Qwen/QwQ-32B-Preview"
    num_responses_per_question: int = 1
    num_gpus: int = 8
    temperature: float = 0.9
    batch_size: int = 10_000
    max_samples: int | None = None
    # Optional: when set, each completed output file (of `sample_per_file` samples) is pushed to this GCP bucket.
    gcp_bucket: str | None = None
    # Number of samples each output file contains.
    sample_per_file: int = 10_000

    @model_validator(mode="after")
    def check_batch_size(self):
        """Reject size settings that would make file rotation impossible."""
        # A file must hold at least one full batch.
        if self.sample_per_file < self.batch_size:
            raise ValueError("sample_per_file must be greater than or equal to batch_size")
        # A capped run must still produce at least one full file.
        if self.max_samples is not None:
            if self.max_samples < self.sample_per_file:
                raise ValueError("max_samples must be greater than or equal to sample_per_file")
        return self
2130

2231

2332
def main(config: Config):
33+
if config.gcp_bucket is not None:
34+
gcp_bucket = GcpBucket(config.gcp_bucket)
35+
2436
llm = sgl.Engine(model_path=config.name_model, tp_size=config.num_gpus)
2537
tokenizer = AutoTokenizer.from_pretrained(config.name_model)
2638

27-
math_dataset = load_dataset("Primegenesys/NuminaMath-groundtruth")["train"]
39+
math_dataset = load_dataset("PrimeIntellect/NuminaMath-groundtruth")["train"]
2840
math_dataset = math_dataset.add_column("problem_id", range(len(math_dataset)))
2941

3042
sampling_params = dict(temperature=config.temperature, max_new_tokens=8192, stop=["<|eot_id|>"])
3143

32-
open(config.out_file_name, "w").close()
33-
3444
max_samples = config.max_samples if config.max_samples is not None else len(math_dataset)
3545

46+
all_results = []
47+
file_counter = 0
48+
3649
for i in tqdm(range(0, min(max_samples, len(math_dataset)), config.batch_size), desc="Generating data"):
3750
batch = math_dataset[i : min(i + config.batch_size, len(math_dataset))]
3851
batch_ids = list(
@@ -50,7 +63,6 @@ def main(config: Config):
5063
batch_inputs = tokenizer.apply_chat_template(batch_messages, tokenize=False, add_generation_prompt=True)
5164
batch_output = llm.generate(batch_inputs, sampling_params)
5265

53-
all_results = []
5466
for j, out in enumerate(batch_output):
5567
result = dict()
5668
result["prompt"] = batch_messages[j][1]["content"]
@@ -60,7 +72,11 @@ def main(config: Config):
6072

6173
all_results.append(result)
6274

63-
save_batch_results(all_results, config.out_file_name)
75+
if len(all_results) >= config.sample_per_file:
76+
file_name = f"out_{file_counter}.jsonl"
77+
save_batch_results(all_results, file_name, gcp_bucket)
78+
all_results = []
79+
file_counter += 1
6480

6581

6682
if __name__ == "__main__":

src/genesys/utils.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,45 @@
11
import json
2+
import os
3+
from google.cloud import storage
24

35

46
def repeat_elements(lst, n):
    """Return a new list where each element of `lst` appears `n` consecutive times."""
    repeated = []
    for item in lst:
        repeated.extend([item] * n)
    return repeated
68

79

8-
def save_batch_results(batch_results, results_file):
10+
class GcpBucket:
    """Uploads local files into a fixed folder of a Google Cloud Storage bucket.

    Args:
        gcp_path: Destination in the form "gs://bucket-name[/folder/path]".
    """

    def __init__(self, gcp_path: str):
        # Strip only a *leading* scheme; str.replace("gs://", "") would also
        # clobber the substring anywhere later in the path.
        path = gcp_path.removeprefix("gs://")
        # Everything before the first "/" is the bucket; the rest (possibly
        # empty) is the destination folder — same split as the original
        # split("/")[0] / "/".join(...[1:]) pair, done in one pass.
        self.bucket_name, _, self.destination_folder = path.partition("/")

        # NOTE(review): storage.Client() relies on application-default
        # credentials being configured in the environment — confirm at deploy.
        self.client = storage.Client()
        self.bucket = self.client.bucket(self.bucket_name)
        print(f"Initialized GCP bucket: {self.bucket_name}, folder: {self.destination_folder}")

    def push(self, file_name: str):
        """Upload `file_name` to the bucket under the configured folder."""
        # Join with "/" explicitly: GCS object names always use "/", whereas
        # os.path.join would insert "\\" on Windows and corrupt the blob name.
        base_name = os.path.basename(file_name)
        if self.destination_folder:
            destination_blob_name = f"{self.destination_folder}/{base_name}"
        else:
            destination_blob_name = base_name
        print(f"Uploading {file_name} to gs://{self.bucket_name}/{destination_blob_name}")

        # Upload the file
        blob = self.bucket.blob(destination_blob_name)
        blob.upload_from_filename(file_name)
30+
31+
32+
def save_batch_results(batch_results, results_file, gcp_bucket: GcpBucket | None = None):
    """Append `batch_results` to `results_file` as JSON lines, then optionally push the file to GCP."""
    # Serialize every record first, then append them with a single writelines call.
    lines = [json.dumps(record) + "\n" for record in batch_results]
    with open(results_file, "a") as out:
        out.writelines(lines)

    # Best-effort upload: a failed push is reported but never aborts generation.
    if gcp_bucket is None:
        return
    try:
        gcp_bucket.push(results_file)
        print(f"Successfully uploaded {results_file} to GCP bucket")
    except Exception as e:
        print(f"Error uploading to GCP: {str(e)}")

0 commit comments

Comments
 (0)