Skip to content

Commit c0fe1c0

Browse files
committed
remove do_tokenization
1 parent b5b63c3 commit c0fe1c0

File tree

2 files changed

+3
-8
lines changed

2 files changed

+3
-8
lines changed

src/genesys/data.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,7 @@ class DataLoaderGenesys:
4343
Each dataset that is passed must have a "train" split, and its content must be a list of dicts, each with at least a "problem" and a "ground_truth" key.
4444
"""
4545

46-
def __init__(
47-
self, config: DataConfig, tokenizer: AutoTokenizer, prime_metric: PrimeMetric, do_tokenization: bool = False
48-
):
46+
def __init__(self, config: DataConfig, tokenizer: AutoTokenizer, prime_metric: PrimeMetric):
4947
self.config = config
5048

5149
self.paths = list(config.path.split(","))
@@ -74,7 +72,6 @@ def _add_column(dataset, path):
7472

7573
self.total_samples = min(max_samples, total_samples)
7674

77-
self.do_tokenization = do_tokenization
7875
self.tokenizer = tokenizer
7976

8077
self.dataset_lengths = [len(dataset) for dataset in self.datasets]
@@ -114,9 +111,7 @@ def _prepare_batch(self, batch: dict, dataset: str) -> tuple:
114111
[{"role": "user", "content": b["prompt"]}, {"role": "assistant", "content": "<think>/n"}] for b in batch
115112
]
116113

117-
batch_inputs = self.tokenizer.apply_chat_template(
118-
batch_messages, tokenize=self.do_tokenization, continue_final_message=True
119-
)
114+
batch_inputs = self.tokenizer.apply_chat_template(batch_messages, tokenize=True, continue_final_message=True)
120115

121116
return batch_inputs, batch
122117

src/genesys/generate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def main(config: GenerateConfig):
7676
tokenizer = AutoTokenizer.from_pretrained(config.name_model)
7777

7878
log("[cyan] Loading dataloader...[/]")
79-
dataloader = DataLoaderGenesys(config.data, tokenizer=tokenizer, prime_metric=prime_metric, do_tokenization=True)
79+
dataloader = DataLoaderGenesys(config.data, tokenizer=tokenizer, prime_metric=prime_metric)
8080
machine_info = get_machine_info()
8181

8282
log("[bold green]✨ Setup complete! Starting generation...[/]")

0 commit comments

Comments
 (0)