|
| 1 | +import argparse |
| 2 | +import os |
| 3 | +import re |
| 4 | +import sys |
| 5 | +import time |
| 6 | + |
| 7 | +import gradio as gr |
| 8 | +import modelopt.torch.quantization as mtq |
| 9 | +import torch |
| 10 | +import torch_tensorrt |
| 11 | +from accelerate.hooks import remove_hook_from_module |
| 12 | +from diffusers import FluxPipeline |
| 13 | +from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel |
| 14 | + |
| 15 | +# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py |
| 16 | +sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo")) |
| 17 | +from register_sdpa import * |
| 18 | + |
| 19 | +DEVICE = "cuda:0" |
| 20 | + |
| 21 | + |
| 22 | +def compile_model( |
| 23 | + args, |
| 24 | +) -> tuple[ |
| 25 | + FluxPipeline, FluxTransformer2DModel, torch_tensorrt.MutableTorchTensorRTModule |
| 26 | +]: |
| 27 | + |
| 28 | + if args.dtype == "fp8": |
| 29 | + enabled_precisions = {torch.float8_e4m3fn, torch.float16} |
| 30 | + ptq_config = mtq.FP8_DEFAULT_CFG |
| 31 | + |
| 32 | + elif args.dtype == "int8": |
| 33 | + enabled_precisions = {torch.int8, torch.float16} |
| 34 | + ptq_config = mtq.INT8_DEFAULT_CFG |
| 35 | + ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None |
| 36 | + |
| 37 | + elif args.dtype == "fp16": |
| 38 | + enabled_precisions = {torch.float16} |
| 39 | + |
| 40 | + print(f"\nUsing {args.dtype}") |
| 41 | + |
| 42 | + pipe = FluxPipeline.from_pretrained( |
| 43 | + "black-forest-labs/FLUX.1-dev", |
| 44 | + torch_dtype=torch.float16, |
| 45 | + ).to(torch.float16) |
| 46 | + |
| 47 | + if args.debug: |
| 48 | + pipe.transformer = FluxTransformer2DModel( |
| 49 | + num_layers=1, num_single_layers=1, guidance_embeds=True |
| 50 | + ).to(torch.float16) |
| 51 | + |
| 52 | + if args.low_vram_mode: |
| 53 | + pipe.enable_model_cpu_offload() |
| 54 | + else: |
| 55 | + pipe.to(DEVICE) |
| 56 | + |
| 57 | + backbone = pipe.transformer |
| 58 | + backbone.eval() |
| 59 | + |
| 60 | + def filter_func(name): |
| 61 | + pattern = re.compile( |
| 62 | + r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*" |
| 63 | + ) |
| 64 | + return pattern.match(name) is not None |
| 65 | + |
| 66 | + def do_calibrate( |
| 67 | + pipe, |
| 68 | + prompt: str, |
| 69 | + ) -> None: |
| 70 | + """ |
| 71 | + Run calibration steps on the pipeline using the given prompts. |
| 72 | + """ |
| 73 | + image = pipe( |
| 74 | + prompt, |
| 75 | + output_type="pil", |
| 76 | + num_inference_steps=20, |
| 77 | + generator=torch.Generator("cuda").manual_seed(0), |
| 78 | + ).images[0] |
| 79 | + |
| 80 | + def forward_loop(mod): |
| 81 | + # Switch the pipeline's backbone, run calibration |
| 82 | + pipe.transformer = mod |
| 83 | + do_calibrate( |
| 84 | + pipe=pipe, |
| 85 | + prompt="test", |
| 86 | + ) |
| 87 | + |
| 88 | + if args.dtype != "fp16": |
| 89 | + backbone = mtq.quantize(backbone, ptq_config, forward_loop) |
| 90 | + mtq.disable_quantizer(backbone, filter_func) |
| 91 | + |
| 92 | + batch_size = 2 if args.dynamic_shapes else 1 |
| 93 | + if args.dynamic_shapes: |
| 94 | + BATCH = torch.export.Dim("batch", min=1, max=8) |
| 95 | + dynamic_shapes = { |
| 96 | + "hidden_states": {0: BATCH}, |
| 97 | + "encoder_hidden_states": {0: BATCH}, |
| 98 | + "pooled_projections": {0: BATCH}, |
| 99 | + "timestep": {0: BATCH}, |
| 100 | + "txt_ids": {}, |
| 101 | + "img_ids": {}, |
| 102 | + "guidance": {0: BATCH}, |
| 103 | + "joint_attention_kwargs": {}, |
| 104 | + "return_dict": None, |
| 105 | + } |
| 106 | + else: |
| 107 | + dynamic_shapes = None |
| 108 | + |
| 109 | + settings = { |
| 110 | + "strict": False, |
| 111 | + "allow_complex_guards_as_runtime_asserts": True, |
| 112 | + "enabled_precisions": enabled_precisions, |
| 113 | + "truncate_double": True, |
| 114 | + "min_block_size": 1, |
| 115 | + "debug": False, |
| 116 | + "use_python_runtime": True, |
| 117 | + "immutable_weights": False, |
| 118 | + "offload_module_to_cpu": True, |
| 119 | + } |
| 120 | + if args.low_vram_mode: |
| 121 | + pipe.remove_all_hooks() |
| 122 | + pipe.enable_sequential_cpu_offload() |
| 123 | + remove_hook_from_module(pipe.transformer, recurse=True) |
| 124 | + pipe.transformer.to(DEVICE) |
| 125 | + trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) |
| 126 | + if dynamic_shapes: |
| 127 | + trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes) |
| 128 | + pipe.transformer = trt_gm |
| 129 | + |
| 130 | + image = pipe( |
| 131 | + "Test", |
| 132 | + output_type="pil", |
| 133 | + num_inference_steps=2, |
| 134 | + num_images_per_prompt=batch_size, |
| 135 | + ).images |
| 136 | + |
| 137 | + torch.cuda.empty_cache() |
| 138 | + |
| 139 | + if args.low_vram_mode: |
| 140 | + pipe.remove_all_hooks() |
| 141 | + pipe.to(DEVICE) |
| 142 | + |
| 143 | + return pipe, backbone, trt_gm |
| 144 | + |
| 145 | + |
| 146 | +def launch_gradio(pipeline, backbone, trt_gm): |
| 147 | + |
| 148 | + def generate_image(prompt, inference_step, batch_size=2): |
| 149 | + start_time = time.time() |
| 150 | + image = pipeline( |
| 151 | + prompt, |
| 152 | + output_type="pil", |
| 153 | + num_inference_steps=inference_step, |
| 154 | + num_images_per_prompt=batch_size, |
| 155 | + ).images |
| 156 | + end_time = time.time() |
| 157 | + return image, end_time - start_time |
| 158 | + |
| 159 | + def model_change(model): |
| 160 | + if model == "Torch Model": |
| 161 | + pipeline.transformer = backbone |
| 162 | + backbone.to(DEVICE) |
| 163 | + else: |
| 164 | + backbone.to("cpu") |
| 165 | + pipeline.transformer = trt_gm |
| 166 | + torch.cuda.empty_cache() |
| 167 | + |
| 168 | + def load_lora(path): |
| 169 | + pipeline.load_lora_weights( |
| 170 | + path, |
| 171 | + adapter_name="lora1", |
| 172 | + ) |
| 173 | + pipeline.set_adapters(["lora1"], adapter_weights=[1]) |
| 174 | + pipeline.fuse_lora() |
| 175 | + pipeline.unload_lora_weights() |
| 176 | + print("LoRA loaded! Begin refitting") |
| 177 | + generate_image(pipeline, ["Test"], 2) |
| 178 | + print("Refitting Finished!") |
| 179 | + |
| 180 | + # Create Gradio interface |
| 181 | + with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo: |
| 182 | + gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT") |
| 183 | + |
| 184 | + with gr.Row(): |
| 185 | + with gr.Column(): |
| 186 | + # Input components |
| 187 | + prompt_input = gr.Textbox( |
| 188 | + label="Prompt", placeholder="Enter your prompt here...", lines=3 |
| 189 | + ) |
| 190 | + model_dropdown = gr.Dropdown( |
| 191 | + choices=["Torch Model", "Torch-TensorRT Accelerated Model"], |
| 192 | + value="Torch-TensorRT Accelerated Model", |
| 193 | + label="Model Variant", |
| 194 | + ) |
| 195 | + |
| 196 | + lora_upload_path = gr.Textbox( |
| 197 | + label="LoRA Path", |
| 198 | + placeholder="Enter the LoRA checkpoint path here. It could be a local path or a Hugging Face URL.", |
| 199 | + value="gokaygokay/Flux-Engrave-LoRA", |
| 200 | + lines=2, |
| 201 | + ) |
| 202 | + num_steps = gr.Slider( |
| 203 | + minimum=20, maximum=100, value=20, step=1, label="Inference Steps" |
| 204 | + ) |
| 205 | + batch_size = gr.Slider( |
| 206 | + minimum=1, maximum=8, value=1, step=1, label="Batch Size" |
| 207 | + ) |
| 208 | + |
| 209 | + generate_btn = gr.Button("Generate Image") |
| 210 | + load_lora_btn = gr.Button("Load LoRA") |
| 211 | + |
| 212 | + with gr.Column(): |
| 213 | + # Output component |
| 214 | + output_image = gr.Gallery(label="Generated Image") |
| 215 | + time_taken = gr.Textbox( |
| 216 | + label="Generation Time (seconds)", interactive=False |
| 217 | + ) |
| 218 | + |
| 219 | + # Connect the button to the generation function |
| 220 | + model_dropdown.change(model_change, inputs=[model_dropdown]) |
| 221 | + load_lora_btn.click( |
| 222 | + fn=load_lora, |
| 223 | + inputs=[ |
| 224 | + lora_upload_path, |
| 225 | + ], |
| 226 | + ) |
| 227 | + |
| 228 | + # Update generate button click to include time output |
| 229 | + generate_btn.click( |
| 230 | + fn=generate_image, |
| 231 | + inputs=[ |
| 232 | + prompt_input, |
| 233 | + num_steps, |
| 234 | + batch_size, |
| 235 | + ], |
| 236 | + outputs=[output_image, time_taken], |
| 237 | + ) |
| 238 | + demo.launch() |
| 239 | + |
| 240 | + |
| 241 | +def main(args): |
| 242 | + pipe, backbone, trt_gm = compile_model(args) |
| 243 | + launch_gradio(pipe, backbone, trt_gm) |
| 244 | + |
| 245 | + |
| 246 | +# Launch the interface |
| 247 | +if __name__ == "__main__": |
| 248 | + parser = argparse.ArgumentParser( |
| 249 | + description="Run Flux quantization with different dtypes" |
| 250 | + ) |
| 251 | + |
| 252 | + parser.add_argument( |
| 253 | + "--dtype", |
| 254 | + choices=["fp8", "int8", "fp16"], |
| 255 | + default="fp16", |
| 256 | + help="Select the data type to use (fp8 or int8 or fp16)", |
| 257 | + ) |
| 258 | + parser.add_argument( |
| 259 | + "--low_vram_mode", |
| 260 | + action="store_true", |
| 261 | + help="Use low VRAM mode when you have a small GPU (<=32GB)", |
| 262 | + ) |
| 263 | + parser.add_argument( |
| 264 | + "--dynamic_shapes", |
| 265 | + "-d", |
| 266 | + action="store_true", |
| 267 | + help="Use dynamic shapes", |
| 268 | + ) |
| 269 | + parser.add_argument( |
| 270 | + "--debug", |
| 271 | + action="store_true", |
| 272 | + help="Use debug mode", |
| 273 | + ) |
| 274 | + args = parser.parse_args() |
| 275 | + main(args) |
0 commit comments