Skip to content

Commit

Permalink
support for auto_find_batch_size when packing (#1885)
Browse files Browse the repository at this point in the history
* support for auto_find_batch_size when packing

* make sure to return data from validation

* make sure to return data from validation

* actually expose multipack_real_batches in the config

* calculate gathered efficiency in sampler

* tweak to fix auto find and use actual sampler len for multipack

* uncomment

* use args for bsz when not available from auto find
  • Loading branch information
winglian authored Sep 4, 2024
1 parent 0aeb277 commit 4e5400c
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 10 deletions.
15 changes: 10 additions & 5 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,9 +506,10 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
batch_max_len = (
self.args.per_device_train_batch_size * self.args.max_seq_length
train_batch_size = (
self.state.train_batch_size or self.args.per_device_train_batch_size
)
batch_max_len = train_batch_size * self.args.max_seq_length
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
lengths=get_dataset_lengths(self.train_dataset),
Expand Down Expand Up @@ -1379,6 +1380,10 @@ def build(self, total_num_steps):
training_arguments_kwargs[
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
if self.cfg.auto_find_batch_size is not None:
training_arguments_kwargs[
"auto_find_batch_size"
] = self.cfg.auto_find_batch_size
training_arguments_kwargs[
"gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
Expand Down Expand Up @@ -1461,9 +1466,9 @@ def build(self, total_num_steps):
)

training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
training_arguments_kwargs[
"multipack_real_batches"
] = not self.cfg.flash_attention
training_arguments_kwargs["multipack_real_batches"] = (
not self.cfg.flash_attention or self.cfg.multipack_real_batches
)
training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing
)
Expand Down
3 changes: 3 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,8 @@ class HyperparametersConfig(BaseModel):
},
)

auto_find_batch_size: Optional[bool] = None

train_on_inputs: Optional[bool] = False
group_by_length: Optional[bool] = None

Expand Down Expand Up @@ -592,6 +594,7 @@ class Config:
eval_sample_packing: Optional[bool] = None
pad_to_sequence_len: Optional[bool] = None
curriculum_sampling: Optional[bool] = None
multipack_real_batches: Optional[bool] = None

# for PoSE context length extension
use_pose: Optional[bool] = None
Expand Down
40 changes: 36 additions & 4 deletions src/axolotl/utils/samplers/multipack.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import numpy as np
from torch.utils.data import BatchSampler, Sampler

from axolotl.utils.distributed import reduce_and_broadcast

LOG = logging.getLogger("axolotl.utils.samplers.multipack")


Expand Down Expand Up @@ -174,16 +176,46 @@ def num_batches(self):
def efficiency(self):
    """Return the packing efficiency: fraction of available token slots
    actually filled by samples (used / total)."""
    used, total = self.eff_total_used, self.eff_total_slots
    return used / total

def gather_efficiency(self):
    """Agree on a single packing-efficiency estimate across all ranks.

    Each rank reports its local ``efficiency()`` ratio; the reduction takes
    the max across ranks, shaves a 0.3% safety margin, and rounds up to the
    nearest 1/200 so every rank ends up with the identical estimate.

    Returns:
        float: the gathered efficiency estimate, in (0, 1].
    """

    def calc_sample_packing_eff_est(estimates: List[float]):
        LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
        # BUG FIX: efficiency() yields a ratio in [0, 1], so the original
        # math.floor(0.997 * max(estimates)) always truncated the estimate
        # to 0 (and a 0 estimate would later divide-by-zero in _len_est).
        # Keep the fractional value; the ceil below does the rounding.
        return 0.997 * max(estimates)

    sample_packing_actual_eff_all = reduce_and_broadcast(
        lambda: self.efficiency(),  # pylint: disable=unnecessary-lambda
        calc_sample_packing_eff_est,
    )
    # round UP to the nearest 0.005 so the shared estimate never
    # under-counts relative to any rank's actual efficiency
    sample_packing_eff_est = (
        math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0
    )
    return sample_packing_eff_est

def gather_len_batches(self, num):
    """Agree on a single per-epoch batch count across all ranks.

    Every rank must iterate the same number of batches, otherwise the
    slower ranks deadlock waiting on collectives; take the minimum of each
    rank's local count, shaved by a 0.2% safety margin.

    Args:
        num: this rank's locally computed number of batches.

    Returns:
        int: the batch count all ranks will use.
    """

    # NOTE: annotation fixed — the original ``list[(int, float)]`` was a
    # tuple subscript, not a valid generic; the estimates are int counts.
    def calc_min_len(estimates: List[int]):
        LOG.info(f"gather_len_batches: {repr(estimates)}")
        # floor(0.998 * min) keeps a small margin so no rank runs out of
        # batches before the others
        return math.floor(0.998 * min(estimates))

    min_len_batches = reduce_and_broadcast(
        lambda: num,
        calc_min_len,
    )
    return min_len_batches

def __len__(self):
    """Number of batches this sampler yields, synchronized across ranks.

    The scraped diff left the stale pre-change body (``self.num_batches();
    return self._len_est()``) interleaved here; this is the post-change
    implementation: compute the local batch count, then gather the
    rank-wide minimum so all ranks report the same length.
    """
    len_batches = self.num_batches()
    return self.gather_len_batches(len_batches)

def _len_est(self):
efficiency = (
self.packing_efficiency_estimate
if self.packing_efficiency_estimate
else self.gather_efficiency()
)
world_size = int(os.getenv("WORLD_SIZE", "1"))
lengths_sum = np.sum(self.lengths)
lengths_sum_per_device = lengths_sum // world_size
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"packing_efficiency_estimate: {efficiency} "
f"total_num_tokens per device: {lengths_sum_per_device}"
)

Expand All @@ -195,7 +227,7 @@ def _len_est(self):
* math.floor(
0.99
* lengths_sum_per_device
/ self.packing_efficiency_estimate
/ efficiency
// (self.batch_max_len * self.batch_size)
)
- 1
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
main_process_only=True,
)
else:
if cfg.flash_attention:
if cfg.flash_attention and not cfg.multipack_real_batches:
sampler_batch_size = 1
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
else:
Expand Down

0 comments on commit 4e5400c

Please sign in to comment.