Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

support model cards #176

Merged
merged 13 commits into from
Jan 5, 2025
16 changes: 15 additions & 1 deletion finetrainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
prepare_target,
)
from .utils.file_utils import string_to_filename
from .utils.hub_utils import save_model_card
from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous
from .utils.model_utils import resolve_vae_cls_from_ckpt_path
from .utils.optimizer_utils import get_optimizer
Expand Down Expand Up @@ -923,6 +924,7 @@ def validate(self, step: int, final_validation: bool = False) -> None:
pipeline.load_lora_weights(self.args.output_dir)

all_processes_artifacts = []
prompts_to_filenames = {}
for i in range(num_validation_samples):
# Skip current validation on all processes but one
if i % accelerator.num_processes != accelerator.process_index:
Expand Down Expand Up @@ -976,7 +978,10 @@ def validate(self, step: int, final_validation: bool = False) -> None:
continue

extension = "png" if artifact_type == "image" else "mp4"
filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}.{extension}"
filename = "validation-" if not final_validation else "final-"
filename += f"{step}-{accelerator.process_index}-{prompt_filename}.{extension}"
if accelerator.is_main_process and extension == "mp4":
prompts_to_filenames[prompt] = filename
filename = os.path.join(self.args.output_dir, filename)

if artifact_type == "image":
Expand Down Expand Up @@ -1005,6 +1010,15 @@ def validate(self, step: int, final_validation: bool = False) -> None:
},
step=step,
)
if self.args.push_to_hub and final_validation:
video_filenames = list(prompts_to_filenames.values())
prompts = list(prompts_to_filenames.keys())
save_model_card(
args=self.args,
repo_id=self.state.repo_id,
videos=video_filenames,
validation_prompts=prompts,
)

# Remove all hooks that might have been added during pipeline initialization to the models
pipeline.remove_all_hooks()
Expand Down
77 changes: 77 additions & 0 deletions finetrainers/utils/hub_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import os
from typing import List, Union

import numpy as np
import wandb
from diffusers.utils import export_to_video
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from PIL import Image


def save_model_card(
    args,
    repo_id: str,
    videos: Union[List[str], Union[List[Image.Image], List[np.ndarray]]],
    validation_prompts: List[str],
    fps: int = 30,
) -> None:
    """Create/update and save a model-card README for a LoRA finetune.

    Args:
        args: Training arguments; must expose ``output_dir`` and
            ``pretrained_model_name_or_path``.
        repo_id: Hub repository id that the card's download link points to.
        videos: Validation artifacts — either filenames (relative to
            ``args.output_dir``, assumed to already exist on disk) or raw
            frame sequences that are exported here as ``final_video_{i}.mp4``.
        validation_prompts: Prompts paired one-to-one with ``videos``.
        fps: Frame rate used when exporting raw frames to video.
    """
    widget_dict = []
    output_dir = str(args.output_dir)
    if videos is not None and len(videos) > 0:
        # NOTE: zip silently truncates to the shorter of videos/prompts;
        # callers are expected to pass equal-length lists.
        for i, (video, validation_prompt) in enumerate(zip(videos, validation_prompts)):
            if not isinstance(video, str):
                export_to_video(video, os.path.join(output_dir, f"final_video_{i}.mp4"), fps=fps)
            widget_dict.append(
                {
                    # The hub gallery widget needs a non-empty text field.
                    "text": validation_prompt if validation_prompt else " ",
                    "output": {"url": video if isinstance(video, str) else f"final_video_{i}.mp4"},
                }
            )

    model_description = f"""
# LoRA Finetune

<Gallery />

## Model description

This is a lora finetune of model: `{args.pretrained_model_name_or_path}`.

The model was trained using [`finetrainers`](https://github.com/a-r-r-o-w/finetrainers).

## Download model

[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.

## Usage

Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.

```py
TODO
```

For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
"""
    # Bug fix: `wandb.run` is None when no wandb run is active (tracking
    # disabled or a different tracker in use); dereferencing `.url` on it
    # would raise AttributeError. Guard before accessing the URL.
    if wandb.run is not None and wandb.run.url:
        model_description += f"""
Find out the wandb run URL and training configurations [here]({wandb.run.url}).
"""

    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        base_model=args.pretrained_model_name_or_path,
        model_description=model_description,
        widget=widget_dict,
    )
    tags = [
        "text-to-video",
        "diffusers-training",
        "diffusers",
        "lora",
        "template:sd-lora",
    ]

    model_card = populate_model_card(model_card, tags=tags)
    # Use the normalized `output_dir` local (args.output_dir may be a Path).
    model_card.save(os.path.join(output_dir, "README.md"))