Skip to content

Commit ba45531

Browse files
authored
fixes to save on fractional save_steps (axolotl-ai-cloud#1643)
1 parent 8a1572a commit ba45531

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

src/axolotl/core/trainer_builder.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
LossWatchDogCallback,
4444
SaveAxolotlConfigtoWandBCallback,
4545
SaveBetterTransformerModelCallback,
46-
SaveModelOnTrainEndCallback,
46+
SaveModelCallback,
4747
bench_eval_callback_factory,
4848
causal_lm_bench_eval_callback_factory,
4949
log_prediction_callback_factory,
@@ -945,7 +945,7 @@ def get_callbacks(self):
945945
if self.cfg.loss_watchdog_threshold is not None:
946946
callbacks.append(LossWatchDogCallback(self.cfg))
947947

948-
callbacks.append(SaveModelOnTrainEndCallback())
948+
callbacks.append(SaveModelCallback())
949949

950950
return callbacks
951951

@@ -1431,7 +1431,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
14311431

14321432
def get_callbacks(self):
14331433
callbacks = super().get_callbacks()
1434-
callbacks.append(SaveModelOnTrainEndCallback())
1434+
callbacks.append(SaveModelCallback())
14351435

14361436
return callbacks
14371437

src/axolotl/utils/callbacks/__init__.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import logging
6+
import math
67
import os
78
from shutil import copyfile
89
from tempfile import NamedTemporaryFile
@@ -775,7 +776,7 @@ def on_train_begin(
775776
return control
776777

777778

778-
class SaveModelOnTrainEndCallback(TrainerCallback):
779+
class SaveModelCallback(TrainerCallback):
779780
"""Callback to save model on train end"""
780781

781782
def on_step_end( # pylint: disable=unused-argument
@@ -788,6 +789,13 @@ def on_step_end( # pylint: disable=unused-argument
788789
# Save
789790
if state.global_step >= state.max_steps:
790791
control.should_save = True
792+
elif (
793+
args.save_strategy == IntervalStrategy.STEPS
794+
and state.save_steps < 1.0
795+
and state.global_step % math.ceil(state.save_steps * state.max_steps) == 0
796+
):
797+
# workaround to save model on fractional save_steps
798+
control.should_save = True
791799

792800
def on_train_end( # pylint: disable=unused-argument
793801
self, args, state, control, **kwargs

0 commit comments

Comments
 (0)