
Commit 367b2e8

winglian and dsesclei authored

Switch to parallel FFD bin packing algorithm. (axolotl-ai-cloud#1619)

* Switch to parallel FFD bin packing algorithm. Add support for packing in a distributed context. Add packing efficiency estimate back.
* revert changes to distributed code
* chore: lint
* fix config w new params for packing test
* add sample_packing_group_size and sample_packing_bin_size to cfg schema
* fix lambda function
* fix sampler/dataloader calculations for packing

Co-authored-by: dsesclei <dave@sescleifer.com>

1 parent bbfed31 · commit 367b2e8

File tree

8 files changed, +169 -219 lines changed

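The commit title names the algorithm, but the diffs below only show its call sites and configuration. As a rough orientation, here is a minimal, hypothetical sketch of grouped first-fit decreasing (FFD) packing and of how the two new knobs bound it: group_size limits how many samples are packed together in one pass (groups are independent, which is what makes the packing easy to parallelize), and bin_size caps how many samples one packed sequence may contain. All names below are made up for illustration; this is not the repository's implementation.

# Illustrative sketch of grouped first-fit decreasing (FFD) bin packing,
# loosely mirroring the knobs introduced by this commit. Function and
# variable names here are hypothetical, not the repository's code.
from typing import List


def ffd_pack(lengths: List[int], bin_capacity: int, bin_size: int) -> List[List[int]]:
    """Pack sample lengths into bins of at most `bin_capacity` tokens.

    FFD: sort items longest-first, then place each item into the first bin
    that still has room. `bin_size` caps how many samples one bin may hold.
    Returns lists of indices into `lengths`.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins: List[List[int]] = []
    remaining: List[int] = []  # free token budget per bin
    for idx in order:
        length = lengths[idx]
        for b, free in enumerate(remaining):
            if length <= free and len(bins[b]) < bin_size:
                bins[b].append(idx)
                remaining[b] -= length
                break
        else:
            # no existing bin fits: open a new one (oversized samples get their own bin)
            bins.append([idx])
            remaining.append(bin_capacity - length)
    return bins


def pack_in_groups(lengths: List[int], bin_capacity: int, group_size: int, bin_size: int) -> List[List[int]]:
    """Split the dataset into chunks of `group_size` samples and pack each chunk
    independently; the chunks share no state, so they can be packed in parallel
    (e.g. with multiprocessing)."""
    packed: List[List[int]] = []
    for start in range(0, len(lengths), group_size):
        group = list(range(start, min(start + group_size, len(lengths))))
        group_bins = ffd_pack([lengths[i] for i in group], bin_capacity, bin_size)
        packed.extend([[group[i] for i in local_bin] for local_bin in group_bins])
    return packed


if __name__ == "__main__":
    # Toy example: pack short sequences into bins of 2048 tokens.
    sample_lengths = [1500, 600, 300, 900, 1200, 400, 2000, 100]
    print(pack_in_groups(sample_lengths, bin_capacity=2048, group_size=4, bin_size=200))

The defaults added in this commit (group_size 100000, bin_size 200) keep the per-group work bounded; per the docs change below, raising them usually improves packing only slightly.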

docs/config.qmd

+5
@@ -186,6 +186,11 @@ eval_sample_packing:
 # The trainer will provide recommended values for these values.
 sample_packing_eff_est:
 total_num_tokens:
+# Increasing the following values helps with packing, but usually only slightly (<%1.)
+# The number of samples packed at a time.
+sample_packing_group_size: 100000
+# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
+sample_packing_bin_size: 200

 # Passed through to transformers when loading the model when launched without accelerate
 # Use `sequential` when training w/ model parallelism to limit memory

src/axolotl/core/trainer_builder.py

+40 -30

@@ -125,14 +125,22 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=1.0,
         metadata={"help": "Sample packing efficiency for calculating batch length."},
     )
+    sample_packing_bin_size: int = field(
+        default=200,
+        metadata={
+            "help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
+        },
+    )
+    sample_packing_group_size: int = field(
+        default=100000,
+        metadata={
+            "help": "The number of samples to group together for packing. Increase for better packing."
+        },
+    )
     max_seq_length: int = field(
         default=2048,
         metadata={"help": "The maximum sequence length the model can handle"},
     )
-    sample_packing_seq_len_multiplier: int = field(
-        default=1,
-        metadata={"help": "the multiplier for the max len for packed sequences"},
-    )
     relora_steps: Optional[int] = field(
         default=None,
         metadata={"help": "how often to reset for ReLoRA"},

@@ -346,11 +354,11 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
             )
             return MultipackBatchSampler(
                 RandomSampler(self.train_dataset),
-                batch_size=batch_size,
-                drop_last=True,
-                batch_max_len=batch_max_len,
                 lengths=get_dataset_lengths(self.train_dataset),
-                packing_efficiency_estimate=self.args.sample_packing_efficiency,
+                batch_max_len=batch_max_len,
+                batch_size=batch_size,
+                group_size=self.args.sample_packing_group_size,
+                bin_size=self.args.sample_packing_bin_size,
             )
         if self.args.curriculum_sampling:
             return SequentialSampler(self.train_dataset)

@@ -370,11 +378,11 @@ def _get_eval_sampler(
             )
             return MultipackBatchSampler(
                 SequentialSampler(eval_dataset),
-                batch_size=batch_size,
-                drop_last=True,
+                lengths=get_dataset_lengths(self.eval_dataset),
                 batch_max_len=batch_max_len,
-                lengths=get_dataset_lengths(eval_dataset),
-                packing_efficiency_estimate=self.args.sample_packing_efficiency,
+                batch_size=batch_size,
+                group_size=self.args.sample_packing_group_size,
+                bin_size=self.args.sample_packing_bin_size,
             )
         return super()._get_eval_sampler(eval_dataset)

@@ -1113,11 +1121,6 @@ def build(self, total_num_steps):
         if self.cfg.save_safetensors is not None:
             training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors

-        if self.cfg.sample_packing_eff_est:
-            training_arguments_kwargs[
-                "sample_packing_efficiency"
-            ] = self.cfg.sample_packing_eff_est
-
         if self.cfg.dataloader_pin_memory is not None:
             training_arguments_kwargs[
                 "dataloader_pin_memory"

@@ -1293,20 +1296,27 @@ def build(self, total_num_steps):
         training_arguments_kwargs["weight_decay"] = (
             self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
         )
-        training_arguments_kwargs["sample_packing"] = (
-            self.cfg.sample_packing if self.cfg.sample_packing else False
-        )
-        training_arguments_kwargs["multipack_real_batches"] = (
-            self.cfg.flash_attention is not True
-        )
-        training_arguments_kwargs["eval_sample_packing"] = (
-            self.cfg.sample_packing
-            if self.cfg.eval_sample_packing is not False
-            else False
-        )
+
+        training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
         training_arguments_kwargs[
-            "sample_packing_seq_len_multiplier"
-        ] = self.cfg.micro_batch_size
+            "multipack_real_batches"
+        ] = not self.cfg.flash_attention
+        training_arguments_kwargs["eval_sample_packing"] = bool(
+            self.cfg.eval_sample_packing
+        )
+        if self.cfg.sample_packing_bin_size is not None:
+            training_arguments_kwargs[
+                "sample_packing_bin_size"
+            ] = self.cfg.sample_packing_bin_size
+        if self.cfg.sample_packing_group_size is not None:
+            training_arguments_kwargs[
+                "sample_packing_group_size"
+            ] = self.cfg.sample_packing_group_size
+        if self.cfg.sample_packing_eff_est:
+            training_arguments_kwargs[
+                "sample_packing_efficiency"
+            ] = self.cfg.sample_packing_eff_est
+
         if self.cfg.relora_steps:
             training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
         training_arguments_kwargs[
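Pulled out of the hunks above, the train sampler construction now looks roughly like the following. This is an illustrative recombination of the lines in this diff, not a verbatim excerpt; the import paths and the wrapper function are assumptions.

# Illustrative recombination of the _get_train_sampler call shown above.
# The import paths below are assumptions, not part of this diff.
from torch.utils.data import RandomSampler

from axolotl.utils.samplers import MultipackBatchSampler  # assumed location
from axolotl.utils.trainer import get_dataset_lengths  # assumed location


def build_train_sampler(train_dataset, args, batch_size, batch_max_len):
    # group_size / bin_size come straight from the two new AxolotlTrainingArguments
    # fields; the old packing_efficiency_estimate kwarg is no longer passed.
    return MultipackBatchSampler(
        RandomSampler(train_dataset),
        lengths=get_dataset_lengths(train_dataset),
        batch_max_len=batch_max_len,
        batch_size=batch_size,
        group_size=args.sample_packing_group_size,
        bin_size=args.sample_packing_bin_size,
    )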

src/axolotl/utils/config/models/input/v0_4_1/__init__.py

+2
@@ -551,6 +551,8 @@ class Config:
         default=512, metadata={"help": "maximum prompt length for RL training"}
     )
     sample_packing: Optional[bool] = None
+    sample_packing_group_size: Optional[int] = 100_000
+    sample_packing_bin_size: Optional[int] = 200
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
     curriculum_sampling: Optional[bool] = None

src/axolotl/utils/data/pretraining.py

+9 -3

@@ -150,6 +150,8 @@ def wrap_pretraining_dataset(
         max_seq_length=max_tokens,
         batch_size=batch_size,
         multipack_attn=cfg.pretrain_multipack_attn,
+        group_size=cfg.sample_packing_group_size,
+        bin_size=cfg.sample_packing_bin_size,
     )
     # set this to 1 so downstream data_loader doesn't try to increase the batch again
     cfg.micro_batch_size = 1

@@ -189,6 +191,8 @@ def encode_packed_pretraining(
     max_seq_length: int = 2048,
     batch_size: int = 4,
     multipack_attn: Optional[bool] = False,
+    group_size: int = 100000,
+    bin_size: int = 200,
 ) -> Dict[str, List]:
     # pylint: disable=duplicate-code
     # tokenize all the examples

@@ -202,11 +206,13 @@
     )

     sampler = MultipackBatchSampler(
-        RandomSampler(train_dataset),
+        sampler=RandomSampler(train_dataset),
+        lengths=get_dataset_lengths(train_dataset),
         batch_size=1,
-        drop_last=True,
         batch_max_len=batch_size * max_seq_length,
-        lengths=get_dataset_lengths(train_dataset),
+        group_size=group_size,
+        bin_size=bin_size,
+        drop_last=True,
     )

     chunked_data = defaultdict(list)
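For pretraining, the hunks above forward group_size and bin_size into MultipackBatchSampler, building the sampler with batch_size=1 and a batch_max_len of batch_size * max_seq_length tokens (using the function's batch_size argument). As a self-contained picture of what packing produces downstream, here is a deliberately simplified, hypothetical sketch that concatenates a few short tokenized samples into one fixed-length sequence; the field names and the segment-id attention-mask convention are illustrative assumptions, not this commit's code.

# Hypothetical, simplified illustration of packing tokenized samples into one
# fixed-length sequence; not the implementation in pretraining.py.
from typing import Dict, List


def pack_examples(tokenized: List[List[int]], max_seq_length: int, pad_token_id: int = 0) -> Dict[str, List[int]]:
    """Concatenate token id lists into a single sequence, padding the tail up
    to max_seq_length. The returned attention_mask uses an increasing segment
    id per original sample (1, 1, ..., 2, 2, ...), a common convention for
    letting packed samples ignore one another during attention."""
    input_ids: List[int] = []
    attention_mask: List[int] = []
    for segment_id, ids in enumerate(tokenized, start=1):
        if len(input_ids) + len(ids) > max_seq_length:
            break  # a real packer would start a new bin instead of dropping
        input_ids.extend(ids)
        attention_mask.extend([segment_id] * len(ids))
    pad = max_seq_length - len(input_ids)
    input_ids.extend([pad_token_id] * pad)
    attention_mask.extend([0] * pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


if __name__ == "__main__":
    packed = pack_examples([[5, 6, 7], [8, 9], [10, 11, 12, 13]], max_seq_length=12)
    print(packed["input_ids"])
    print(packed["attention_mask"])  # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0, 0]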
