@@ -125,14 +125,22 @@ class AxolotlTrainingArguments(TrainingArguments):
125
125
default = 1.0 ,
126
126
metadata = {"help" : "Sample packing efficiency for calculating batch length." },
127
127
)
128
+ sample_packing_bin_size : int = field (
129
+ default = 200 ,
130
+ metadata = {
131
+ "help" : "The max number of samples that packed sample can contain after packing. Increase for better packing."
132
+ },
133
+ )
134
+ sample_packing_group_size : int = field (
135
+ default = 100000 ,
136
+ metadata = {
137
+ "help" : "The number of samples to group together for packing. Increase for better packing."
138
+ },
139
+ )
128
140
max_seq_length : int = field (
129
141
default = 2048 ,
130
142
metadata = {"help" : "The maximum sequence length the model can handle" },
131
143
)
132
- sample_packing_seq_len_multiplier : int = field (
133
- default = 1 ,
134
- metadata = {"help" : "the multiplier for the max len for packed sequences" },
135
- )
136
144
relora_steps : Optional [int ] = field (
137
145
default = None ,
138
146
metadata = {"help" : "how often to reset for ReLoRA" },
@@ -346,11 +354,11 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
346
354
)
347
355
return MultipackBatchSampler (
348
356
RandomSampler (self .train_dataset ),
349
- batch_size = batch_size ,
350
- drop_last = True ,
351
- batch_max_len = batch_max_len ,
352
357
lengths = get_dataset_lengths (self .train_dataset ),
353
- packing_efficiency_estimate = self .args .sample_packing_efficiency ,
358
+ batch_max_len = batch_max_len ,
359
+ batch_size = batch_size ,
360
+ group_size = self .args .sample_packing_group_size ,
361
+ bin_size = self .args .sample_packing_bin_size ,
354
362
)
355
363
if self .args .curriculum_sampling :
356
364
return SequentialSampler (self .train_dataset )
@@ -370,11 +378,11 @@ def _get_eval_sampler(
370
378
)
371
379
return MultipackBatchSampler (
372
380
SequentialSampler (eval_dataset ),
373
- batch_size = batch_size ,
374
- drop_last = True ,
381
+ lengths = get_dataset_lengths (eval_dataset ),
375
382
batch_max_len = batch_max_len ,
376
- lengths = get_dataset_lengths (eval_dataset ),
377
- packing_efficiency_estimate = self .args .sample_packing_efficiency ,
383
+ batch_size = batch_size ,
384
+ group_size = self .args .sample_packing_group_size ,
385
+ bin_size = self .args .sample_packing_bin_size ,
378
386
)
379
387
return super ()._get_eval_sampler (eval_dataset )
380
388
@@ -1113,11 +1121,6 @@ def build(self, total_num_steps):
1113
1121
if self .cfg .save_safetensors is not None :
1114
1122
training_arguments_kwargs ["save_safetensors" ] = self .cfg .save_safetensors
1115
1123
1116
- if self .cfg .sample_packing_eff_est :
1117
- training_arguments_kwargs [
1118
- "sample_packing_efficiency"
1119
- ] = self .cfg .sample_packing_eff_est
1120
-
1121
1124
if self .cfg .dataloader_pin_memory is not None :
1122
1125
training_arguments_kwargs [
1123
1126
"dataloader_pin_memory"
@@ -1293,20 +1296,27 @@ def build(self, total_num_steps):
1293
1296
training_arguments_kwargs ["weight_decay" ] = (
1294
1297
self .cfg .weight_decay if self .cfg .weight_decay is not None else 0.0
1295
1298
)
1296
- training_arguments_kwargs ["sample_packing" ] = (
1297
- self .cfg .sample_packing if self .cfg .sample_packing else False
1298
- )
1299
- training_arguments_kwargs ["multipack_real_batches" ] = (
1300
- self .cfg .flash_attention is not True
1301
- )
1302
- training_arguments_kwargs ["eval_sample_packing" ] = (
1303
- self .cfg .sample_packing
1304
- if self .cfg .eval_sample_packing is not False
1305
- else False
1306
- )
1299
+
1300
+ training_arguments_kwargs ["sample_packing" ] = bool (self .cfg .sample_packing )
1307
1301
training_arguments_kwargs [
1308
- "sample_packing_seq_len_multiplier"
1309
- ] = self .cfg .micro_batch_size
1302
+ "multipack_real_batches"
1303
+ ] = not self .cfg .flash_attention
1304
+ training_arguments_kwargs ["eval_sample_packing" ] = bool (
1305
+ self .cfg .eval_sample_packing
1306
+ )
1307
+ if self .cfg .sample_packing_bin_size is not None :
1308
+ training_arguments_kwargs [
1309
+ "sample_packing_bin_size"
1310
+ ] = self .cfg .sample_packing_bin_size
1311
+ if self .cfg .sample_packing_group_size is not None :
1312
+ training_arguments_kwargs [
1313
+ "sample_packing_group_size"
1314
+ ] = self .cfg .sample_packing_group_size
1315
+ if self .cfg .sample_packing_eff_est :
1316
+ training_arguments_kwargs [
1317
+ "sample_packing_efficiency"
1318
+ ] = self .cfg .sample_packing_eff_est
1319
+
1310
1320
if self .cfg .relora_steps :
1311
1321
training_arguments_kwargs ["relora_steps" ] = self .cfg .relora_steps
1312
1322
training_arguments_kwargs [
0 commit comments