File tree 1 file changed +10
-7
lines changed
1 file changed +10
-7
lines changed Original file line number Diff line number Diff line change @@ -640,14 +640,17 @@ def _create_run_config(
640
640
641
641
if self ._train_dataloader is not None :
642
642
local_batches = len (self ._train_dataloader )
643
+ total_batches = local_batches * self ._accelerator .num_processes
643
644
num_update_steps_per_epoch = math .ceil (
644
- local_batches / gradient_accumulation_steps
645
+ total_batches / gradient_accumulation_steps
645
646
)
647
+
646
648
else :
647
649
num_update_steps_per_epoch = 0
648
650
649
651
if max_num_train_steps is None :
650
- max_num_train_steps = num_epochs * num_update_steps_per_epoch
652
+ # Add 1 to ensure we don't stop early due to rounding
653
+ max_num_train_steps = (num_epochs * num_update_steps_per_epoch ) + 1
651
654
else :
652
655
num_epochs = math .ceil (max_num_train_steps / num_update_steps_per_epoch )
653
656
@@ -749,11 +752,11 @@ def _run_training(self):
749
752
)
750
753
break
751
754
752
- # if reached_max_steps:
753
- # self.print(
754
- # f"Reached max number of training steps {self.run_config.max_num_train_steps} in epoch {epoch + 1}"
755
- # )
756
- # break
755
+ if reached_max_steps :
756
+ self .print (
757
+ f"Reached max number of training steps { self .run_config .max_num_train_steps } in epoch { epoch + 1 } "
758
+ )
759
+ break
757
760
758
761
self .training_run_end ()
759
762
self .callback_handler .call_event (
You can’t perform that action at this time.
0 commit comments