Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support model cards #176

Merged
merged 13 commits into from
Jan 5, 2025
8 changes: 8 additions & 0 deletions finetrainers/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class Args:
validation_every_n_epochs: Optional[int] = None
validation_every_n_steps: Optional[int] = None
enable_model_cpu_offload: bool = False
fps: int = None

# Miscellaneous arguments
tracker_name: str = "finetrainers"
Expand Down Expand Up @@ -190,6 +191,7 @@ def to_dict(self) -> Dict[str, Any]:
"validation_every_n_epochs": self.validation_every_n_epochs,
"validation_every_n_steps": self.validation_every_n_steps,
"enable_model_cpu_offload": self.enable_model_cpu_offload,
"fps": self.fps,
},
"miscellaneous_arguments": {
"tracker_name": self.tracker_name,
Expand Down Expand Up @@ -626,6 +628,12 @@ def _add_validation_arguments(parser: argparse.ArgumentParser) -> None:
default=False,
help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
)
parser.add_argument(
"--fps",
type=int,
default=8,
help="FPS to use to run inference and serialize the videos with.",
)


def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
Expand Down
1 change: 1 addition & 0 deletions finetrainers/cogvideox/cogvideox_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def validation(
"generator": generator,
"return_dict": True,
"output_type": "pil",
"fps": kwargs.get("fps", None),
}
generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
output = pipeline(**generation_kwargs).frames[0]
Expand Down
1 change: 1 addition & 0 deletions finetrainers/hunyuan_video/hunyuan_video_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ def validation(
"generator": generator,
"return_dict": True,
"output_type": "pil",
"fps": kwargs.get("fps", None),
}
generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
output = pipeline(**generation_kwargs).frames[0]
Expand Down
1 change: 1 addition & 0 deletions finetrainers/ltx_video/ltx_video_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def validation(
"frame_rate": frame_rate,
"num_videos_per_prompt": num_videos_per_prompt,
"generator": generator,
"fps": kwargs.get("fps", None),
"return_dict": True,
"output_type": "pil",
}
Expand Down
108 changes: 24 additions & 84 deletions finetrainers/trainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import logging
import math
import os
import random
from datetime import datetime, timedelta
from pathlib import Path
Expand All @@ -24,13 +23,10 @@
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params
from diffusers.utils import export_to_video, load_image, load_video
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from tqdm import tqdm

import wandb

from .args import _INVERSE_DTYPE_MAP, Args, validate_args
from .constants import (
FINETRAINERS_LOG_LEVEL,
Expand All @@ -41,6 +37,7 @@
from .dataset import BucketSampler, PrecomputedDataset, VideoDatasetWithResizing
from .models import get_config_from_model_name
from .state import State
from .utils.artifact_utils import generate_artifacts, log_artifacts
from .utils.checkpointing import get_intermediate_ckpt_path, get_latest_ckpt_path_to_resume_from
from .utils.data_utils import should_perform_precomputation
from .utils.diffusion_utils import (
Expand All @@ -50,7 +47,7 @@
prepare_sigmas,
prepare_target,
)
from .utils.file_utils import string_to_filename
from .utils.hub_utils import save_model_card
from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous
from .utils.model_utils import resolve_vae_cls_from_ckpt_path
from .utils.optimizer_utils import get_optimizer
Expand Down Expand Up @@ -912,90 +909,33 @@ def validate(self, step: int, final_validation: bool = False) -> None:
)
pipeline.load_lora_weights(self.args.output_dir)

all_processes_artifacts = []
for i in range(num_validation_samples):
# Skip current validation on all processes but one
if i % accelerator.num_processes != accelerator.process_index:
continue

prompt = self.args.validation_prompts[i]
image = self.args.validation_images[i]
video = self.args.validation_videos[i]
height = self.args.validation_heights[i]
width = self.args.validation_widths[i]
num_frames = self.args.validation_num_frames[i]

if image is not None:
image = load_image(image)
if video is not None:
video = load_video(video)

logger.debug(
f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
main_process_only=False,
)
validation_artifacts = self.model_config["validation"](
pipeline=pipeline,
prompt=prompt,
image=image,
video=video,
height=height,
width=width,
num_frames=num_frames,
num_videos_per_prompt=self.args.num_validation_videos_per_prompt,
generator=self.state.generator,
# todo support passing `fps` for supported pipelines.
)

prompt_filename = string_to_filename(prompt)[:25]
artifacts = {
"image": {"type": "image", "value": image},
"video": {"type": "video", "value": video},
}
for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})
logger.debug(
f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
main_process_only=False,
)

for key, value in list(artifacts.items()):
artifact_type = value["type"]
artifact_value = value["value"]
if artifact_type not in ["image", "video"] or artifact_value is None:
continue

extension = "png" if artifact_type == "image" else "mp4"
filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}.{extension}"
filename = os.path.join(self.args.output_dir, filename)

if artifact_type == "image":
logger.debug(f"Saving image to {filename}")
artifact_value.save(filename)
artifact_value = wandb.Image(filename)
elif artifact_type == "video":
logger.debug(f"Saving video to {filename}")
# TODO: this should be configurable here as well as in validation runs where we call the pipeline that has `fps`.
export_to_video(artifact_value, filename, fps=15)
artifact_value = wandb.Video(filename, caption=prompt)

all_processes_artifacts.append(artifact_value)
all_processes_artifacts, prompts_to_filenames = generate_artifacts(
model_config=self.model_config,
pipeline=pipeline,
args=self.args,
generator=self.state.generator,
step=step,
num_processes=accelerator.num_processes,
process_index=accelerator.process_index,
trackers=accelerator.trackers,
final_validation=final_validation,
)

all_artifacts = gather_object(all_processes_artifacts)

if accelerator.is_main_process:
tracker_key = "final" if final_validation else "validation"
for tracker in accelerator.trackers:
if tracker.name == "wandb":
image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
tracker.log(
{
tracker_key: {"images": image_artifacts, "videos": video_artifacts},
},
step=step,
)

log_artifacts(artifacts=all_artifacts, trackers=accelerator.trackers, tracker_key=tracker_key, step=step)
if final_validation:
video_filenames = list(prompts_to_filenames.values())
prompts = list(prompts_to_filenames.keys())
save_model_card(
args=self.args,
repo_id=self.state.repo_id,
videos=video_filenames,
validation_prompts=prompts,
fps=self.args.fps,
)
# Remove all hooks that might have been added during pipeline initialization to the models
pipeline.remove_all_hooks()
del pipeline
Expand Down
156 changes: 156 additions & 0 deletions finetrainers/utils/artifact_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import argparse
import inspect
import os
import re

import torch
from accelerate.logging import get_logger
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video, load_image, load_video

from ..constants import FINETRAINERS_LOG_LEVEL
from ..utils.file_utils import string_to_filename


# Module-level logger registered under the "finetrainers" name; its verbosity is
# taken from the FINETRAINERS_LOG_LEVEL constant.
logger = get_logger("finetrainers")
logger.setLevel(FINETRAINERS_LOG_LEVEL)


def generate_artifacts(
    model_config: dict,
    pipeline: DiffusionPipeline,
    args: argparse.Namespace,
    generator: torch.Generator,
    step: int,
    num_processes: int,
    process_index: int,
    trackers: list,
    final_validation: bool = False,
) -> tuple:
    """Run validation for this process's share of the prompts and save the outputs.

    Validation samples are split round-robin: sample ``i`` is handled by the process
    for which ``i % num_processes == process_index``. For every handled sample the
    model's ``"validation"`` callable is invoked, and each image/video artifact (the
    conditioning inputs as well as the generated outputs) is written to
    ``args.output_dir``.

    Args:
        model_config: Mapping providing the ``"validation"`` callable and the
            ``"pipeline_cls"`` class whose ``__call__`` signature is inspected.
        pipeline: Pipeline instance forwarded to the validation callable.
        args: Run configuration. The attributes read here are ``validation_prompts``,
            ``validation_images``, ``validation_videos``, ``validation_heights``,
            ``validation_widths``, ``validation_num_frames``,
            ``num_validation_videos_per_prompt``, ``fps`` and ``output_dir``.
        generator: Random generator forwarded to the validation callable.
        step: Training step, embedded in the saved filenames.
        num_processes: Total number of processes sharing the validation work.
        process_index: Index of the current process.
        trackers: Active trackers; each must expose a ``name`` attribute.
        final_validation: When true, files are prefixed ``final-`` instead of
            ``validation-`` and the prompt-to-filename mapping is filled in.

    Returns:
        A ``(artifacts, prompts_to_filenames)`` tuple. ``artifacts`` holds one entry
        per saved artifact: a ``wandb.Image`` / ``wandb.Video`` when a wandb tracker
        is active, otherwise the raw artifact value. ``prompts_to_filenames`` maps a
        prompt to the last file saved for it and is only populated on process 0
        during the final validation.
    """
    wandb_tracking = any("wandb" in tracker.name for tracker in trackers)
    if wandb_tracking:
        import wandb

    # The pipeline signature does not change between samples, so inspect it once.
    # Exact-name membership is used on purpose: a substring test would also fire for
    # parameters such as ``target_fps`` and make the validation pass an unsupported
    # ``fps`` keyword to the pipeline.
    has_fps_in_call = "fps" in inspect.signature(model_config["pipeline_cls"].__call__).parameters

    all_processes_artifacts = []
    num_validation_samples = len(args.validation_prompts)
    prompts_to_filenames = {}
    for i in range(num_validation_samples):
        # Skip current validation on all processes but one
        if i % num_processes != process_index:
            continue

        prompt = args.validation_prompts[i]
        image = args.validation_images[i]
        video = args.validation_videos[i]
        height = args.validation_heights[i]
        width = args.validation_widths[i]
        num_frames = args.validation_num_frames[i]

        if image is not None:
            image = load_image(image)
        if video is not None:
            video = load_video(video)

        logger.debug(
            f"Validating sample {i + 1}/{num_validation_samples} on process {process_index}. Prompt: {prompt}",
            main_process_only=False,
        )
        validation_kwargs = {
            "pipeline": pipeline,
            "prompt": prompt,
            "image": image,
            "video": video,
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "num_videos_per_prompt": args.num_validation_videos_per_prompt,
            "generator": generator,
        }
        if has_fps_in_call:
            validation_kwargs.update({"fps": args.fps})
        validation_artifacts = model_config["validation"](**validation_kwargs)

        prompt_filename = string_to_filename(prompt)[:25]
        artifacts = {
            "image": {"type": "image", "value": image},
            "video": {"type": "video", "value": video},
        }
        # A dedicated index name keeps the outer sample index ``i`` from being shadowed.
        for artifact_index, (artifact_type, artifact_value) in enumerate(validation_artifacts):
            artifacts.update({f"artifact_{artifact_index}": {"type": artifact_type, "value": artifact_value}})
        logger.debug(
            f"Validation artifacts on process {process_index}: {list(artifacts.keys())}",
            main_process_only=False,
        )

        for artifact in artifacts.values():
            artifact_type = artifact["type"]
            artifact_value = artifact["value"]
            if artifact_type not in ["image", "video"] or artifact_value is None:
                continue

            extension = "png" if artifact_type == "image" else "mp4"
            prefix = "final-" if final_validation else "validation-"
            # NOTE(review): the name does not include the artifact key, so several
            # artifacts of the same type for one prompt (e.g. the conditioning video
            # and a generated video) are written to the same path - confirm intended.
            filename = os.path.join(args.output_dir, f"{prefix}{step}-{process_index}-{prompt_filename}.{extension}")

            if artifact_type == "image":
                logger.debug(f"Saving image to {filename}")
                artifact_value.save(filename)
                if wandb_tracking:
                    artifact_value = wandb.Image(filename)
            elif artifact_type == "video":
                logger.debug(f"Saving video to {filename}")
                export_to_video(artifact_value, filename, fps=args.fps)
                if wandb_tracking:
                    artifact_value = wandb.Video(filename, caption=prompt, fps=args.fps)

            all_processes_artifacts.append(artifact_value)
            # limit to first process only as this will go into the model card.
            if process_index == 0 and final_validation:
                prompts_to_filenames[prompt] = filename

    return all_processes_artifacts, prompts_to_filenames


def log_artifacts(artifacts: list, trackers: list, tracker_key: str, step: int) -> None:
    """Send the gathered validation artifacts to every wandb tracker.

    Only trackers whose ``name`` is exactly ``"wandb"`` are supported. Artifacts are
    split into ``wandb.Image`` and ``wandb.Video`` instances and logged as
    ``{tracker_key: {"images": [...], "videos": [...]}}`` at the given ``step``.
    Entries of any other type are ignored.

    Args:
        artifacts: Artifacts collected from all processes.
        trackers: Active trackers; each must expose ``name`` and ``log``.
        tracker_key: Top-level key under which the artifacts are logged.
        step: Step value forwarded to the tracker's ``log`` call.
    """
    wandb_trackers = [tracker for tracker in trackers if tracker.name == "wandb"]
    if not wandb_trackers:
        # Warn once, and only when nothing at all can log. Previously the message
        # was tied to the per-tracker loop, so it could appear even though a wandb
        # tracker had logged, and never appeared for an empty tracker list.
        logger.warning("No supported tracker found for which logging is available.")
        return

    # Imported lazily so wandb is only required when it is actually in use.
    import wandb

    # The split does not depend on the tracker, so compute it a single time.
    image_artifacts = [artifact for artifact in artifacts if isinstance(artifact, wandb.Image)]
    video_artifacts = [artifact for artifact in artifacts if isinstance(artifact, wandb.Video)]
    for tracker in wandb_trackers:
        tracker.log(
            {
                tracker_key: {"images": image_artifacts, "videos": video_artifacts},
            },
            step=step,
        )


def get_latest_step_files(files):
    """Return the validation videos that belong to the highest training step.

    A file counts as a validation video when its base name has the form
    ``validation-<step>-<process_index>-<name>.mp4``. Directory components are
    ignored for matching, so both bare names and full paths are accepted.

    Args:
        files: Iterable of file names or paths.

    Returns:
        The entries of ``files`` (unchanged, in input order) whose step number is the
        largest one found, or an empty list when nothing matches.
    """
    # Used with ``fullmatch`` so the whole base name must fit: leftovers such as
    # ``validation-10-0-cat.mp4.part`` are not mistaken for finished videos.
    pattern = re.compile(r"validation-(\d+)-\d+-.+\.mp4")

    latest_step = -1
    latest_files = []

    for file in files:
        match = pattern.fullmatch(os.path.basename(file))
        if match is None:
            continue

        step = int(match.group(1))
        # Update the latest step and reset the list if a higher step is found
        if step > latest_step:
            latest_step = step
            latest_files = [file]
        elif step == latest_step:
            latest_files.append(file)

    return latest_files
Loading