@@ -43,9 +43,7 @@ class DataLoaderGenesys:
    Each dataset that is passed must have a "train" split, and its content must be a list of dicts with at least a "problem" and a "ground_truth" key.
    """
-    def __init__(
-        self, config: DataConfig, tokenizer: AutoTokenizer, prime_metric: PrimeMetric, do_tokenization: bool = False
-    ):
+    def __init__(self, config: DataConfig, tokenizer: AutoTokenizer, prime_metric: PrimeMetric):
        self.config = config

        self.paths = list(config.path.split(","))
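The class docstring above spells out the dataset contract. A minimal sketch of a dataset that satisfies it, built with the Hugging Face datasets library (the record contents and the save path are illustrative, not taken from this repository):

from datasets import Dataset, DatasetDict

# Each record needs at least a "problem" and a "ground_truth" field.
records = [
    {"problem": "What is 2 + 2?", "ground_truth": "4"},
    {"problem": "Factor x^2 - 1.", "ground_truth": "(x - 1)(x + 1)"},
]

# The loader expects a "train" split, so wrap the records in a DatasetDict.
dataset = DatasetDict({"train": Dataset.from_list(records)})
dataset.save_to_disk("toy_genesys_dataset")  # illustrative; how config.path resolves (hub name vs. local path) is not shown in this diff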
@@ -74,7 +72,6 @@ def _add_column(dataset, path):
        self.total_samples = min(max_samples, total_samples)

-        self.do_tokenization = do_tokenization
        self.tokenizer = tokenizer

        self.dataset_lengths = [len(dataset) for dataset in self.datasets]
@@ -114,9 +111,7 @@ def _prepare_batch(self, batch: dict, dataset: str) -> tuple:
            [{"role": "user", "content": b["prompt"]}, {"role": "assistant", "content": "<think>\n"}] for b in batch
        ]

-        batch_inputs = self.tokenizer.apply_chat_template(
-            batch_messages, tokenize=self.do_tokenization, continue_final_message=True
-        )
+        batch_inputs = self.tokenizer.apply_chat_template(batch_messages, tokenize=True, continue_final_message=True)

        return batch_inputs, batch
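With do_tokenization gone, apply_chat_template always returns token ids here. A standalone sketch of the new call, assuming any tokenizer that ships a chat template and a reasonably recent transformers release (the model name and prompt below are illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative model

batch_messages = [
    [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "<think>\n"},
    ]
]

# tokenize=True returns token ids rather than a formatted string;
# continue_final_message=True leaves the assistant turn open, so generation
# resumes right after "<think>\n" instead of appending an end-of-turn marker.
batch_inputs = tokenizer.apply_chat_template(
    batch_messages, tokenize=True, continue_final_message=True
)
print(batch_inputs[0][:10])  # first few token ids of the only conversation in the batch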