diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 656ded2559..f4cd257838 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -506,9 +506,10 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: batch_max_len = self.args.max_seq_length else: batch_size = 1 - batch_max_len = ( - self.args.per_device_train_batch_size * self.args.max_seq_length + train_batch_size = ( + self.state.train_batch_size or self.args.per_device_train_batch_size ) + batch_max_len = train_batch_size * self.args.max_seq_length return MultipackBatchSampler( RandomSampler(self.train_dataset), lengths=get_dataset_lengths(self.train_dataset), @@ -1379,6 +1380,10 @@ def build(self, total_num_steps): training_arguments_kwargs[ "per_device_eval_batch_size" ] = self.cfg.eval_batch_size + if self.cfg.auto_find_batch_size is not None: + training_arguments_kwargs[ + "auto_find_batch_size" + ] = self.cfg.auto_find_batch_size training_arguments_kwargs[ "gradient_accumulation_steps" ] = self.cfg.gradient_accumulation_steps @@ -1461,9 +1466,9 @@ def build(self, total_num_steps): ) training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing) - training_arguments_kwargs[ - "multipack_real_batches" - ] = not self.cfg.flash_attention + training_arguments_kwargs["multipack_real_batches"] = ( + not self.cfg.flash_attention or self.cfg.multipack_real_batches + ) training_arguments_kwargs["eval_sample_packing"] = bool( self.cfg.eval_sample_packing ) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 65a2c5409a..9044047cce 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -355,6 +355,8 @@ class HyperparametersConfig(BaseModel): }, ) + auto_find_batch_size: Optional[bool] = None + train_on_inputs: Optional[bool] = False group_by_length: 
Optional[bool] = None @@ -592,6 +594,7 @@ class Config: eval_sample_packing: Optional[bool] = None pad_to_sequence_len: Optional[bool] = None curriculum_sampling: Optional[bool] = None + multipack_real_batches: Optional[bool] = None # for PoSE context length extension use_pose: Optional[bool] = None diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 957ca57464..205c2894d1 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -11,6 +11,8 @@ import numpy as np from torch.utils.data import BatchSampler, Sampler +from axolotl.utils.distributed import reduce_and_broadcast + LOG = logging.getLogger("axolotl.utils.samplers.multipack") @@ -174,16 +176,46 @@ def num_batches(self): def efficiency(self): return self.eff_total_used / self.eff_total_slots + def gather_efficiency(self): + def calc_sample_packing_eff_est(estimates: List[float]): + LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}") + return 0.997 * max(estimates) + + sample_packing_actual_eff_all = reduce_and_broadcast( + lambda: self.efficiency(), # pylint: disable=unnecessary-lambda + calc_sample_packing_eff_est, + ) + sample_packing_eff_est = ( + math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0 + ) + return sample_packing_eff_est + + def gather_len_batches(self, num): + def calc_min_len(estimates: List[int]): + LOG.info(f"gather_len_batches: {repr(estimates)}") + return math.floor(0.998 * min(estimates)) + + min_len_batches = reduce_and_broadcast( + lambda: num, + calc_min_len, + ) + return min_len_batches + def __len__(self): - self.num_batches() - return self._len_est() + len_batches = self.num_batches() + return self.gather_len_batches(len_batches) def _len_est(self): + efficiency = ( + self.packing_efficiency_estimate + if self.packing_efficiency_estimate + else self.gather_efficiency() + ) world_size = int(os.getenv("WORLD_SIZE", "1")) lengths_sum =
np.sum(self.lengths) lengths_sum_per_device = lengths_sum // world_size LOG.info( - f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " + f"packing_efficiency_estimate: {efficiency} " f"total_num_tokens per device: {lengths_sum_per_device}" ) @@ -195,7 +227,7 @@ def _len_est(self): * math.floor( 0.99 * lengths_sum_per_device - / self.packing_efficiency_estimate + / efficiency // (self.batch_max_len * self.batch_size) ) - 1 diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index f4e1fc6cb8..1029fff13d 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -357,7 +357,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): main_process_only=True, ) else: - if cfg.flash_attention: + if cfg.flash_attention and not cfg.multipack_real_batches: sampler_batch_size = 1 batch_max_len = cfg.micro_batch_size * cfg.sequence_len else: