diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 15961a203dd4..86ffffd7d5df 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -74,6 +74,7 @@ def text_encoder_attn_modules(text_encoder): "HunyuanVideoLoraLoaderMixin", "SanaLoraLoaderMixin", "Lumina2LoraLoaderMixin", + "WanLoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = [ @@ -112,6 +113,7 @@ def text_encoder_attn_modules(text_encoder): SD3LoraLoaderMixin, StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, + WanLoraLoaderMixin, ) from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 7802e307c028..9ba125be102a 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -4110,6 +4110,311 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components) +class WanLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + return state_dict + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`WanTransformer3DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + """ + super().unfuse_lora(components=components) + + class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index da038b9fdca5..ee7467fdfe35 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -53,6 +53,7 @@ "LTXVideoTransformer3DModel": lambda model_cls, weights: weights, "SanaTransformer2DModel": lambda model_cls, weights: weights, "Lumina2Transformer2DModel": lambda model_cls, weights: weights, + "WanTransformer3DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 33e9daf70fe4..259afa547bc5 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -13,14 +13,15 @@ # limitations under the License. import math -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import logging +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward from ..attention_processor import Attention from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed @@ -109,9 +110,9 @@ class WanImageEmbedding(torch.nn.Module): def __init__(self, in_features: int, out_features: int): super().__init__() - self.norm1 = nn.LayerNorm(in_features) + self.norm1 = FP32LayerNorm(in_features) self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu") - self.norm2 = nn.LayerNorm(out_features) + self.norm2 = FP32LayerNorm(out_features) def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: hidden_states = self.norm1(encoder_hidden_states_image) @@ -287,7 +288,7 @@ def forward( return hidden_states -class WanTransformer3DModel(ModelMixin, ConfigMixin): +class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): r""" A Transformer model for video-like data used in the Wan model. @@ -391,7 +392,23 @@ def forward( encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + batch_size, num_channels, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.config.patch_size post_patch_num_frames = num_frames // p_t @@ -432,6 +449,10 @@ def forward( hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: return (output,) diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 062a2c21fd09..fd6135878492 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -13,7 +13,7 @@ # limitations under the License. import html -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import ftfy import regex as re @@ -21,6 +21,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import WanLoraLoaderMixin from ...models import AutoencoderKLWan, WanTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring @@ -86,7 +87,7 @@ def prompt_clean(text): return text -class WanPipeline(DiffusionPipeline): +class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): r""" Pipeline for text-to-video generation using Wan. @@ -299,10 +300,10 @@ def check_inputs( def prepare_latents( self, batch_size: int, - num_channels_latents: 16, - height: int = 720, - width: int = 1280, - num_latent_frames: int = 21, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -311,6 +312,7 @@ def prepare_latents( if latents is not None: return latents.to(device=device, dtype=dtype) + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 shape = ( batch_size, num_channels_latents, @@ -347,14 +349,18 @@ def current_timestep(self): def interrupt(self): return self._interrupt + @property + def attention_kwargs(self): + return self._attention_kwargs + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, negative_prompt: Union[str, List[str]] = None, - height: int = 720, - width: int = 1280, + height: int = 480, + width: int = 832, num_frames: int = 81, num_inference_steps: int = 50, guidance_scale: float = 5.0, @@ -365,6 +371,7 @@ def __call__( negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -378,11 +385,11 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - height (`int`, defaults to `720`): + height (`int`, defaults to `480`): The height in pixels of the generated image. - width (`int`, defaults to `1280`): + width (`int`, defaults to `832`): The width in pixels of the generated image. - num_frames (`int`, defaults to `129`): + num_frames (`int`, defaults to `81`): The number of frames in the generated video. num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -409,6 +416,10 @@ def __call__( The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. with the following arguments: `callback_on_step_end(self: @@ -445,6 +456,7 @@ def __call__( ) self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False @@ -481,14 +493,12 @@ def __call__( # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, height, width, - num_latent_frames, + num_frames, torch.float32, device, generator, @@ -512,6 +522,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, return_dict=False, )[0] @@ -520,6 +531,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, return_dict=False, )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index eff63efe5197..5dd80ce2d6ae 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -13,17 +13,17 @@ # limitations under the License. import html -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import ftfy -import numpy as np import PIL import regex as re import torch -from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput +from ...loaders import WanLoraLoaderMixin from ...models import AutoencoderKLWan, WanTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring @@ -103,7 +103,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class WanImageToVideoPipeline(DiffusionPipeline): +class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): r""" Pipeline for image-to-video generation using Wan. @@ -137,7 +137,7 @@ def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - image_encoder: CLIPVisionModel, + image_encoder: CLIPVisionModelWithProjection, image_processor: CLIPImageProcessor, transformer: WanTransformer3DModel, vae: AutoencoderKLWan, @@ -164,7 +164,7 @@ def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, + max_sequence_length: int = 512, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): @@ -291,15 +291,18 @@ def encode_prompt( def check_inputs( self, prompt, + negative_prompt, image, - max_area, + height, + width, prompt_embeds=None, + negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, ): if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}") - if max_area < 0: - raise ValueError(f"`max_area` has to be positive but are {max_area}.") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -313,80 +316,70 @@ def check_inputs( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") def prepare_latents( self, image: PipelineImageInput, batch_size: int, - num_channels_latents: 32, - height: int = 720, - width: int = 1280, - max_area: int = 720 * 1280, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, num_frames: int = 81, - num_latent_frames: int = 21, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - aspect_ratio = height / width - mod_value = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1] - height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value - width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value - - if latents is not None: - return latents.to(device=device, dtype=dtype) + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial - shape = ( - batch_size, - num_channels_latents, - num_latent_frames, - int(height) // self.vae_scale_factor_spatial, - int(width) // self.vae_scale_factor_spatial, - ) + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) - image = self.video_processor.preprocess(image, height=height, width=width)[:, :, None] + image = image.unsqueeze(2) video_condition = torch.cat( - [image, torch.zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 ) video_condition = video_condition.to(device=device, dtype=dtype) + if isinstance(generator, list): latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator] latents = latent_condition = torch.cat(latent_condition) else: latent_condition = retrieve_latents(self.vae.encode(video_condition), generator) latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) - mask_lat_size = torch.ones( - batch_size, - 1, - num_frames, - int(height) // self.vae_scale_factor_spatial, - int(width) // self.vae_scale_factor_spatial, - ) + + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) mask_lat_size[:, :, list(range(1, num_frames))] = 0 first_frame_mask = mask_lat_size[:, :, 0:1] first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) - mask_lat_size = mask_lat_size.view( - batch_size, - -1, - self.vae_scale_factor_temporal, - int(height) // self.vae_scale_factor_spatial, - int(width) // self.vae_scale_factor_spatial, - ) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) mask_lat_size = mask_lat_size.transpose(1, 2) mask_lat_size = mask_lat_size.to(latent_condition.device) @@ -412,6 +405,10 @@ def current_timestep(self): def interrupt(self): return self._interrupt + @property + def attention_kwargs(self): + return self._attention_kwargs + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -419,7 +416,8 @@ def __call__( image: PipelineImageInput, prompt: Union[str, List[str]] = None, negative_prompt: Union[str, List[str]] = None, - max_area: int = 720 * 1280, + height: int = 480, + width: int = 832, num_frames: int = 81, num_inference_steps: int = 50, guidance_scale: float = 5.0, @@ -430,6 +428,7 @@ def __call__( negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -445,9 +444,15 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - max_area (`int`, defaults to `1280 * 720`): - The maximum area in pixels of the generated image. - num_frames (`int`, defaults to `129`): + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, defaults to `480`): + The height of the generated video. + width (`int`, defaults to `832`): + The width of the generated video. + num_frames (`int`, defaults to `81`): The number of frames in the generated video. num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -474,6 +479,10 @@ def __call__( The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. with the following arguments: `callback_on_step_end(self: @@ -504,13 +513,17 @@ def __call__( # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, + negative_prompt, image, - max_area, + height, + width, prompt_embeds, + negative_prompt_embeds, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False @@ -537,36 +550,29 @@ def __call__( ) # Encode image embedding - image_embeds = self.encode_image(image) - image_embeds = image_embeds.repeat(batch_size, 1, 1) - transformer_dtype = self.transformer.dtype prompt_embeds = prompt_embeds.to(transformer_dtype) if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + image_embeds = self.encode_image(image) + image_embeds = image_embeds.repeat(batch_size, 1, 1) image_embeds = image_embeds.to(transformer_dtype) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - if isinstance(image, torch.Tensor): - height, width = image.shape[-2:] - else: - width, height = image.size - # 5. Prepare latent variables num_channels_latents = self.vae.config.z_dim - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) latents, condition = self.prepare_latents( image, batch_size * num_videos_per_prompt, num_channels_latents, height, width, - max_area, num_frames, - num_latent_frames, torch.float32, device, generator, @@ -591,6 +597,7 @@ def __call__( timestep=timestep, encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, return_dict=False, )[0] @@ -600,6 +607,7 @@ def __call__( timestep=timestep, encoder_hidden_states=negative_prompt_embeds, encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, return_dict=False, )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py new file mode 100644 index 000000000000..c2498fa68c3d --- /dev/null +++ b/tests/lora/test_lora_layers_wan.py @@ -0,0 +1,143 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKLWan, + FlowMatchEulerDiscreteScheduler, + WanPipeline, + WanTransformer3DModel, +) +from diffusers.utils.testing_utils import ( + floats_tensor, + require_peft_backend, + skip_mps, +) + + +sys.path.append(".") + +from utils import PeftLoraLoaderMixinTests # noqa: E402 + + +@require_peft_backend +@skip_mps +class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = WanPipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler + scheduler_classes = [FlowMatchEulerDiscreteScheduler] + scheduler_kwargs = {} + + transformer_kwargs = { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 12, + "in_channels": 16, + "out_channels": 16, + "text_dim": 32, + "freq_dim": 256, + "ffn_dim": 32, + "num_layers": 2, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "rope_max_seq_len": 32, + } + transformer_cls = WanTransformer3DModel + vae_kwargs = { + "base_dim": 3, + "z_dim": 16, + "dim_mult": [1, 1, 1, 1], + "num_res_blocks": 1, + "temperal_downsample": [False, True, True], + } + vae_cls = AutoencoderKLWan + has_two_text_encoders = True + tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + + text_encoder_target_modules = ["q", "k", "v", "o"] + + @property + def output_shape(self): + return (1, 9, 32, 32, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 16 + num_channels = 4 + num_frames = 9 + num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1 + sizes = (4, 4) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "", + "num_frames": num_frames, + "num_inference_steps": 1, + "guidance_scale": 6.0, + "height": 32, + "width": 32, + "max_sequence_length": sequence_length, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + def test_simple_inference_with_text_lora_denoiser_fused_multi(self): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + + def test_simple_inference_with_text_denoiser_lora_unfused(self): + super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + + @unittest.skip("Not supported in Wan.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in Wan.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in Wan.") + def test_modify_padding_mode(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in Wan.") + def test_simple_inference_with_partial_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in Wan.") + def test_simple_inference_with_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in Wan.") + def test_simple_inference_with_text_lora_and_scale(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in Wan.") + def test_simple_inference_with_text_lora_fused(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in Wan.") + def test_simple_inference_with_text_lora_save_load(self): + pass diff --git a/tests/lora/utils.py b/tests/lora/utils.py index a94198efaa64..17f6c9ccdf98 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1594,11 +1594,17 @@ def test_lora_fuse_nan(self): ].weight += float("inf") else: named_modules = [name for name, _ in pipe.transformer.named_modules()] + tower_name = ( + "transformer_blocks" + if any(name == "transformer_blocks" for name in named_modules) + else "blocks" + ) + transformer_tower = getattr(pipe.transformer, tower_name) has_attn1 = any("attn1" in name for name in named_modules) if has_attn1: - pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf") + transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf") else: - pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") + transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") # with `safe_fusing=True` we should see an Error with self.assertRaises(ValueError): diff --git a/tests/pipelines/wan/test_wan_image_to_video.py b/tests/pipelines/wan/test_wan_image_to_video.py index b898545c147b..53fa37dfae99 100644 --- a/tests/pipelines/wan/test_wan_image_to_video.py +++ b/tests/pipelines/wan/test_wan_image_to_video.py @@ -125,7 +125,8 @@ def get_dummy_inputs(self, device, seed=0): "image": image, "prompt": "dance monkey", "negative_prompt": "negative", # TODO - "max_area": 1024, + "height": image_height, + "width": image_width, "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0, @@ -147,8 +148,8 @@ def test_inference(self): video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (9, 3, 32, 32)) - expected_video = torch.randn(9, 3, 32, 32) + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) max_diff = np.abs(generated_video - expected_video).max() self.assertLessEqual(max_diff, 1e10)