From 299e984721236de37c9cd1874f405600e0ec2979 Mon Sep 17 00:00:00 2001 From: Viveka Kulharia Date: Sun, 2 Mar 2025 04:45:04 -0800 Subject: [PATCH 1/3] save_model_card function is only called if self.args.push_to_hub is True --- finetrainers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py index 89324044..e310a097 100644 --- a/finetrainers/trainer.py +++ b/finetrainers/trainer.py @@ -875,7 +875,7 @@ def validate(self, step: int, final_validation: bool = False) -> None: if num_validation_samples == 0: logger.warning("No validation samples found. Skipping validation.") - if accelerator.is_main_process: + if accelerator.is_main_process and self.args.push_to_hub: save_model_card( args=self.args, repo_id=self.state.repo_id, From b1398d18de76b32d48e07bf6c8564e1b722112a4 Mon Sep 17 00:00:00 2001 From: Viveka Kulharia Date: Sun, 2 Mar 2025 04:55:07 -0800 Subject: [PATCH 2/3] formatting done --- finetrainers/trainer.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py index e310a097..c834a0e4 100644 --- a/finetrainers/trainer.py +++ b/finetrainers/trainer.py @@ -875,13 +875,14 @@ def validate(self, step: int, final_validation: bool = False) -> None: if num_validation_samples == 0: logger.warning("No validation samples found. Skipping validation.") - if accelerator.is_main_process and self.args.push_to_hub: - save_model_card( - args=self.args, - repo_id=self.state.repo_id, - videos=None, - validation_prompts=None, - ) + if accelerator.is_main_process: + if self.args.push_to_hub: + save_model_card( + args=self.args, + repo_id=self.state.repo_id, + videos=None, + validation_prompts=None, + ) return self.transformer.eval() From b2a8730b106579bf64cc0528623444f4d2578aa4 Mon Sep 17 00:00:00 2001 From: Viveka Kulharia Date: Sun, 2 Mar 2025 05:05:31 -0800 Subject: [PATCH 3/3] saving the model card only when final_validation is True --- finetrainers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py index c834a0e4..a3557ca9 100644 --- a/finetrainers/trainer.py +++ b/finetrainers/trainer.py @@ -876,7 +876,7 @@ def validate(self, step: int, final_validation: bool = False) -> None: if num_validation_samples == 0: logger.warning("No validation samples found. Skipping validation.") if accelerator.is_main_process: - if self.args.push_to_hub: + if self.args.push_to_hub and final_validation: save_model_card( args=self.args, repo_id=self.state.repo_id,