Add BlackForest Flux Support #815

Draft · wants to merge 10 commits into base: main
optimum/exporters/neuron/__init__.py (16 changes: 8 additions & 8 deletions)

@@ -19,36 +19,36 @@

 _import_structure = {
     "__main__": [
-        "infer_stable_diffusion_shapes_from_diffusers",
+        "infer_shapes_of_diffusers",
         "main_export",
-        "normalize_stable_diffusion_input_shapes",
+        "normalize_diffusers_input_shapes",
         "get_submodels_and_neuron_configs",
         "load_models_and_neuron_configs",
     ],
     "base": ["NeuronDefaultConfig"],
     "convert": ["export", "export_models", "validate_model_outputs", "validate_models_outputs"],
     "utils": [
         "build_stable_diffusion_components_mandatory_shapes",
-        "get_stable_diffusion_models_for_export",
+        "get_diffusion_models_for_export",
         "replace_stable_diffusion_submodels",
-        "get_submodels_for_export_stable_diffusion",
+        "get_submodels_for_export_diffusion",
     ],
 }

 if TYPE_CHECKING:
     from .__main__ import (
         get_submodels_and_neuron_configs,
-        infer_stable_diffusion_shapes_from_diffusers,
+        infer_shapes_of_diffusers,
         load_models_and_neuron_configs,
         main_export,
-        normalize_stable_diffusion_input_shapes,
+        normalize_diffusers_input_shapes,
     )
     from .base import NeuronDefaultConfig
     from .convert import export, export_models, validate_model_outputs, validate_models_outputs
     from .utils import (
         build_stable_diffusion_components_mandatory_shapes,
-        get_stable_diffusion_models_for_export,
-        get_submodels_for_export_stable_diffusion,
+        get_diffusion_models_for_export,
+        get_submodels_for_export_diffusion,
         replace_stable_diffusion_submodels,
     )
 else:
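The `__init__.py` changes are pure renames: the stable-diffusion-specific entry points become pipeline-agnostic so the same export path can serve Flux. A rough sketch of what downstream imports would look like after this PR (the package-level re-exports are an assumption based on the `_import_structure` lazy-module mapping above):

```python
# Sketch of post-rename imports; names come from the hunk above.
from optimum.exporters.neuron import (
    infer_shapes_of_diffusers,  # was: infer_stable_diffusion_shapes_from_diffusers
    normalize_diffusers_input_shapes,  # was: normalize_stable_diffusion_input_shapes
)
from optimum.exporters.neuron.utils import (
    get_diffusion_models_for_export,  # was: get_stable_diffusion_models_for_export
    get_submodels_for_export_diffusion,  # was: get_submodels_for_export_stable_diffusion
)
```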
optimum/exporters/neuron/__main__.py (52 changes: 25 additions & 27 deletions)

@@ -81,15 +81,17 @@
 NEURON_COMPILER = "Neuronx"


-if is_diffusers_available():
-    from diffusers import StableDiffusionXLPipeline
-
-
 if TYPE_CHECKING:
     from transformers import PreTrainedModel

     if is_diffusers_available():
-        from diffusers import DiffusionPipeline, ModelMixin, StableDiffusionPipeline
+        from diffusers import (
+            DiffusionPipeline,
+            ModelMixin,
+            StableDiffusionPipeline,
+            StableDiffusionXLPipeline,
+            FluxPipeline,
+        )


 logger = logging.get_logger()
@@ -191,7 +193,7 @@ def parse_optlevel(args: argparse.Namespace) -> Dict[str, bool]:
     return optlevel


-def normalize_stable_diffusion_input_shapes(
+def normalize_diffusers_input_shapes(
     args: argparse.Namespace,
 ) -> Dict[str, Dict[str, int]]:
     args = vars(args) if isinstance(args, argparse.Namespace) else args
@@ -214,20 +216,14 @@ def normalize_stable_diffusion_input_shapes(
     input_shapes = build_stable_diffusion_components_mandatory_shapes(**mandatory_shapes)
     return input_shapes


-def infer_stable_diffusion_shapes_from_diffusers(
+def infer_shapes_of_diffusers(
     input_shapes: Dict[str, Dict[str, int]],
-    model: Union["StableDiffusionPipeline", "StableDiffusionXLPipeline"],
+    model: Union["StableDiffusionPipeline", "StableDiffusionXLPipeline", "FluxPipeline"],
     has_controlnets: bool,
 ):
-    if model.tokenizer is not None:
-        max_sequence_length = model.tokenizer.model_max_length
-    elif hasattr(model, "tokenizer_2") and model.tokenizer_2 is not None:
-        max_sequence_length = model.tokenizer_2.model_max_length
-    else:
-        raise AttributeError(
-            f"Cannot infer max sequence_length from {type(model)} as there is no tokenizer as attribute."
-        )
+    max_sequence_length_1 = model.tokenizer.model_max_length if model.tokenizer is not None else None
+    max_sequence_length_2 = model.tokenizer_2.model_max_length if hasattr(model, "tokenizer_2") and model.tokenizer_2 is not None else None

     vae_encoder_num_channels = model.vae.config.in_channels
     vae_decoder_num_channels = model.vae.config.latent_channels
     vae_scale_factor = 2 ** (len(model.vae.config.block_out_channels) - 1) or 8
@@ -238,9 +234,12 @@ def infer_stable_diffusion_shapes_from_diffusers(

     # Text encoders
     if input_shapes["text_encoder"].get("sequence_length") is None:
-        input_shapes["text_encoder"].update({"sequence_length": max_sequence_length})
+        input_shapes["text_encoder"].update({"sequence_length": max_sequence_length_1})
     if hasattr(model, "text_encoder_2"):
-        input_shapes["text_encoder_2"] = input_shapes["text_encoder"]
+        input_shapes["text_encoder_2"] = {
+            "batch_size": input_shapes["text_encoder"]["batch_size"],
+            "sequence_length": max_sequence_length_2,
+        }

     # UNet or Transformer
     unet_or_transformer_name = "transformer" if hasattr(model, "transformer") else "unet"
@@ -252,8 +251,8 @@
"width": scaled_width,
}
)
if input_shapes["unet_or_transformer"].get("sequence_length") is None:
input_shapes["unet_or_transformer"]["sequence_length"] = max_sequence_length
if input_shapes["unet_or_transformer"].get("sequence_length") is None:
input_shapes["unet_or_transformer"]["sequence_length"] = max_sequence_length_2 or max_sequence_length_1
input_shapes["unet_or_transformer"]["vae_scale_factor"] = vae_scale_factor
input_shapes[unet_or_transformer_name] = input_shapes.pop("unet_or_transformer")
if unet_or_transformer_name == "transformer":
@@ -328,7 +327,7 @@ def get_submodels_and_neuron_configs(
         # TODO: Enable optional outputs for Stable Diffusion
         if output_attentions:
             raise ValueError(f"`output_attentions`is not supported by the {task} task yet.")
-        models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs_for_stable_diffusion(
+        models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs_for_diffusion(
             model=model,
             input_shapes=input_shapes,
             output=output,
@@ -380,8 +379,7 @@ def get_submodels_and_neuron_configs(
     maybe_save_preprocessors(model_name_or_path, output, src_subfolder=subfolder)
     return models_and_neuron_configs, output_model_names

-
-def _get_submodels_and_neuron_configs_for_stable_diffusion(
+def _get_submodels_and_neuron_configs_for_diffusion(
     model: Union["PreTrainedModel", "DiffusionPipeline"],
     input_shapes: Dict[str, int],
     output: Path,
@@ -397,12 +395,11 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
         raise RuntimeError(
             "Stable diffusion export is not supported by neuron-cc on inf1, please use neuronx-cc on either inf2/trn1 instead."
         )
-    input_shapes = infer_stable_diffusion_shapes_from_diffusers(
+    input_shapes = infer_shapes_of_diffusers(
         input_shapes=input_shapes,
         model=model,
         has_controlnets=controlnet_ids is not None,
     )
-
     # Saving the model config and preprocessor as this is needed sometimes.
     model.scheduler.save_pretrained(output.joinpath("scheduler"))
     if getattr(model, "tokenizer", None) is not None:
@@ -428,6 +425,7 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
         controlnet_ids=controlnet_ids,
         controlnet_input_shapes=input_shapes.get("controlnet", None),
         image_encoder_input_shapes=input_shapes.get("image_encoder", None),
+        text_encoder_2_input_shapes=input_shapes.get("text_encoder_2", None),
     )
     output_model_names = {
         DIFFUSION_MODEL_VAE_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_ENCODER_NAME, NEURON_FILE_NAME),
@@ -754,7 +752,7 @@ def main():
     library_name = TasksManager.infer_library_from_model(args.model, cache_dir=args.cache_dir)

     if library_name == "diffusers":
-        input_shapes = normalize_stable_diffusion_input_shapes(args)
+        input_shapes = normalize_diffusers_input_shapes(args)
         submodels = {"unet": args.unet}
     elif library_name == "sentence_transformers":
         input_shapes = normalize_sentence_transformers_input_shapes(args)
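The behavioral change in this file is in `infer_shapes_of_diffusers`: rather than picking a single `max_sequence_length` and raising when no tokenizer is found, the exporter now records each tokenizer's limit separately and lets the UNet/transformer fall back from `tokenizer_2` to `tokenizer`. This matters for Flux, which pairs a CLIP tokenizer (`tokenizer`) with a T5 one (`tokenizer_2`) whose sequence limits differ. A standalone sketch of the fallback, illustrative only and not code from this PR:

```python
from typing import Optional


def sequence_length_fallback(pipeline) -> Optional[int]:
    """Mirror the PR's logic: prefer tokenizer_2's limit, else tokenizer's."""
    max_len_1 = pipeline.tokenizer.model_max_length if pipeline.tokenizer is not None else None
    max_len_2 = (
        pipeline.tokenizer_2.model_max_length
        if hasattr(pipeline, "tokenizer_2") and pipeline.tokenizer_2 is not None
        else None
    )
    # `or` falls back when max_len_2 is None (or otherwise falsy), matching
    # `max_sequence_length_2 or max_sequence_length_1` in the diff above.
    return max_len_2 or max_len_1
```

Relatedly, `text_encoder_2` now gets its own shape dict with its own `sequence_length` instead of aliasing the first text encoder's, which is what makes the differing CLIP/T5 limits representable.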
optimum/exporters/neuron/base.py (20 changes: 14 additions & 6 deletions)

@@ -380,17 +380,24 @@ def unflatten_inputs(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:

         return unflatten

-    def patch_model_for_export(
+    def patch_model_and_prepare_aliases(
         self,
         model: "PreTrainedModel",
-        dummy_inputs: Optional[Dict[str, torch.Tensor]] = None,
+        input_names: List[str] = None,
         forward_with_tuple: bool = False,
         eligible_outputs: Optional[List[Union[str, int]]] = None,
         device: Optional[str] = None,
     ):
         """
-        Checks if inputs order of the model's forward pass correspond to the generated dummy inputs to ensure the dummy inputs tuple used for
-        tracing are under the correct order.
+        Patch the model and generate aliases for tracing.
+
+        This function performs the following:
+        1. Verifies that the input order of the model's `forward` method matches the structure
+           of the generated dummy inputs. This ensures the dummy inputs tuple is correctly ordered
+           for tracing.
+        2. Applies model sharding if tensor parallelism is enabled (using `CUSTOM_MODEL_WRAPPER`).
+        3. Prepares I/O aliases to identify specific input tensors as state tensors.
+           These state tensors will remain on the device, helping to reduce host-device I/O overhead.
         """
         output_hidden_states = self.output_hidden_states

@@ -430,6 +437,7 @@ def forward(self, *input):
                 return outputs

         if self.CUSTOM_MODEL_WRAPPER is None:
-            return ModelWrapper(model, list(dummy_inputs.keys()))
+            # Order dummy input and build empty alias
+            return ModelWrapper(model, input_names), {}
         else:
-            return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys()))
+            return self.CUSTOM_MODEL_WRAPPER(model, input_names), {}
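With this change, the patch method returns a `(wrapper, aliases)` pair instead of a bare wrapper, and takes the already-ordered `input_names` rather than the full `dummy_inputs` dict, so call sites that trace the model need to unpack two values. A hedged sketch of an adjusted caller follows; the surrounding tracing code and variable names are assumptions, not shown in this diff:

```python
# Hypothetical call site: `neuron_config` is an instance of a
# NeuronDefaultConfig subclass, `model` the submodel to trace, and
# `dummy_inputs` an ordered mapping of input names to torch tensors
# prepared elsewhere in the exporter.
wrapper, aliases = neuron_config.patch_model_and_prepare_aliases(
    model=model,
    input_names=list(dummy_inputs.keys()),
)
# `aliases` marks inputs to keep resident on the Neuron device as state
# tensors; per the diff, the default (non-custom-wrapper) path returns
# an empty dict.
```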