
Commit 6c628d7

update step handling
1 parent 6446e08 commit 6c628d7

1 file changed: +10 -7 lines changed

pytorch_accelerated/trainer.py (+10 -7)
@@ -640,14 +640,17 @@ def _create_run_config(
 
         if self._train_dataloader is not None:
             local_batches = len(self._train_dataloader)
+            total_batches = local_batches * self._accelerator.num_processes
             num_update_steps_per_epoch = math.ceil(
-                local_batches / gradient_accumulation_steps
+                total_batches / gradient_accumulation_steps
             )
+
         else:
             num_update_steps_per_epoch = 0
 
         if max_num_train_steps is None:
-            max_num_train_steps = num_epochs * num_update_steps_per_epoch
+            # Add 1 to ensure we don't stop early due to rounding
+            max_num_train_steps = (num_epochs * num_update_steps_per_epoch) + 1
         else:
             num_epochs = math.ceil(max_num_train_steps / num_update_steps_per_epoch)
 
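The step arithmetic above can be sanity-checked in isolation. A minimal sketch, assuming hypothetical values for the process count, per-process dataloader length, gradient accumulation steps, and epoch count (none of these numbers come from the commit itself):

import math

# Hypothetical example values, not taken from the library
num_processes = 4                  # e.g. Accelerator.num_processes
local_batches = 25                 # len(train_dataloader) on one process
gradient_accumulation_steps = 2
num_epochs = 3

# Steps per epoch are now derived from the global batch count across processes
total_batches = local_batches * num_processes
num_update_steps_per_epoch = math.ceil(
    total_batches / gradient_accumulation_steps
)  # 50

# The extra +1 guards against stopping one update early due to rounding
max_num_train_steps = (num_epochs * num_update_steps_per_epoch) + 1  # 151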
@@ -749,11 +752,11 @@ def _run_training(self):
                )
                break
 
-            # if reached_max_steps:
-            #     self.print(
-            #         f"Reached max number of training steps {self.run_config.max_num_train_steps} in epoch {epoch + 1}"
-            #     )
-            #     break
+            if reached_max_steps:
+                self.print(
+                    f"Reached max number of training steps {self.run_config.max_num_train_steps} in epoch {epoch + 1}"
+                )
+                break
 
        self.training_run_end()
        self.callback_handler.call_event(
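A rough sketch of how the restored check short-circuits the epoch loop; the real logic lives in Trainer._run_training, and the names and values below are illustrative stand-ins rather than the library's API:

# Hypothetical values standing in for run_config and the trainer's step counter
max_num_train_steps = 120
num_update_steps_per_epoch = 50
step_count = 0

for epoch in range(3):
    reached_max_steps = False
    for _ in range(num_update_steps_per_epoch):
        step_count += 1
        if step_count >= max_num_train_steps:
            reached_max_steps = True
            break
    # Once the step budget is exhausted, report it and stop the outer epoch loop
    if reached_max_steps:
        print(
            f"Reached max number of training steps {max_num_train_steps} in epoch {epoch + 1}"
        )
        break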
