
Commit 0d9a008

Added flux demo and quantization
1 parent dd06bd8 commit 0d9a008

File tree

17 files changed: +872, -251 lines

MODULE.bazel

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ new_local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local.
 new_local_repository(
     name = "cuda",
     build_file = "@//third_party/cuda:BUILD",
-    path = "/usr/local/cuda-12.8/",
+    path = "/usr/local/cuda-12.9/",
 )

 # for Jetson

examples/apps/flux_demo.py

Lines changed: 275 additions & 0 deletions
@@ -0,0 +1,275 @@
import argparse
import os
import re
import sys
import time

import gradio as gr
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt
from accelerate.hooks import remove_hook_from_module
from diffusers import FluxPipeline
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo"))
from register_sdpa import *

DEVICE = "cuda:0"


def compile_model(
    args,
) -> tuple[
    FluxPipeline, FluxTransformer2DModel, torch_tensorrt.MutableTorchTensorRTModule
]:

    if args.dtype == "fp8":
        enabled_precisions = {torch.float8_e4m3fn, torch.float16}
        ptq_config = mtq.FP8_DEFAULT_CFG

    elif args.dtype == "int8":
        enabled_precisions = {torch.int8, torch.float16}
        ptq_config = mtq.INT8_DEFAULT_CFG
        ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None

    elif args.dtype == "fp16":
        enabled_precisions = {torch.float16}

    print(f"\nUsing {args.dtype}")

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.float16,
    ).to(torch.float16)

    if args.debug:
        pipe.transformer = FluxTransformer2DModel(
            num_layers=1, num_single_layers=1, guidance_embeds=True
        ).to(torch.float16)

    if args.low_vram_mode:
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(DEVICE)

    backbone = pipe.transformer
    backbone.eval()

    def filter_func(name):
        pattern = re.compile(
            r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
        )
        return pattern.match(name) is not None

    def do_calibrate(
        pipe,
        prompt: str,
    ) -> None:
        """
        Run calibration steps on the pipeline using the given prompts.
        """
        image = pipe(
            prompt,
            output_type="pil",
            num_inference_steps=20,
            generator=torch.Generator("cuda").manual_seed(0),
        ).images[0]

    def forward_loop(mod):
        # Switch the pipeline's backbone, run calibration
        pipe.transformer = mod
        do_calibrate(
            pipe=pipe,
            prompt="test",
        )

    if args.dtype != "fp16":
        backbone = mtq.quantize(backbone, ptq_config, forward_loop)
        mtq.disable_quantizer(backbone, filter_func)

    batch_size = 2 if args.dynamic_shapes else 1
    if args.dynamic_shapes:
        BATCH = torch.export.Dim("batch", min=1, max=8)
        dynamic_shapes = {
            "hidden_states": {0: BATCH},
            "encoder_hidden_states": {0: BATCH},
            "pooled_projections": {0: BATCH},
            "timestep": {0: BATCH},
            "txt_ids": {},
            "img_ids": {},
            "guidance": {0: BATCH},
            "joint_attention_kwargs": {},
            "return_dict": None,
        }
    else:
        dynamic_shapes = None

    settings = {
        "strict": False,
        "allow_complex_guards_as_runtime_asserts": True,
        "enabled_precisions": enabled_precisions,
        "truncate_double": True,
        "min_block_size": 1,
        "debug": False,
        "use_python_runtime": True,
        "immutable_weights": False,
        "offload_module_to_cpu": True,
    }
    if args.low_vram_mode:
        pipe.remove_all_hooks()
        pipe.enable_sequential_cpu_offload()
        remove_hook_from_module(pipe.transformer, recurse=True)
        pipe.transformer.to(DEVICE)
    trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
    if dynamic_shapes:
        trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
    pipe.transformer = trt_gm

    image = pipe(
        "Test",
        output_type="pil",
        num_inference_steps=2,
        num_images_per_prompt=batch_size,
    ).images

    torch.cuda.empty_cache()

    if args.low_vram_mode:
        pipe.remove_all_hooks()
        pipe.to(DEVICE)

    return pipe, backbone, trt_gm


def launch_gradio(pipeline, backbone, trt_gm):

    def generate_image(prompt, inference_step, batch_size=2):
        start_time = time.time()
        image = pipeline(
            prompt,
            output_type="pil",
            num_inference_steps=inference_step,
            num_images_per_prompt=batch_size,
        ).images
        end_time = time.time()
        return image, end_time - start_time

    def model_change(model):
        if model == "Torch Model":
            pipeline.transformer = backbone
            backbone.to(DEVICE)
        else:
            backbone.to("cpu")
            pipeline.transformer = trt_gm
            torch.cuda.empty_cache()

    def load_lora(path):
        pipeline.load_lora_weights(
            path,
            adapter_name="lora1",
        )
        pipeline.set_adapters(["lora1"], adapter_weights=[1])
        pipeline.fuse_lora()
        pipeline.unload_lora_weights()
        print("LoRA loaded! Begin refitting")
        generate_image(["Test"], 2)
        print("Refitting Finished!")

    # Create Gradio interface
    with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:
        gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT")

        with gr.Row():
            with gr.Column():
                # Input components
                prompt_input = gr.Textbox(
                    label="Prompt", placeholder="Enter your prompt here...", lines=3
                )
                model_dropdown = gr.Dropdown(
                    choices=["Torch Model", "Torch-TensorRT Accelerated Model"],
                    value="Torch-TensorRT Accelerated Model",
                    label="Model Variant",
                )

                lora_upload_path = gr.Textbox(
                    label="LoRA Path",
                    placeholder="Enter the LoRA checkpoint path here. It could be a local path or a Hugging Face URL.",
                    value="gokaygokay/Flux-Engrave-LoRA",
                    lines=2,
                )
                num_steps = gr.Slider(
                    minimum=20, maximum=100, value=20, step=1, label="Inference Steps"
                )
                batch_size = gr.Slider(
                    minimum=1, maximum=8, value=1, step=1, label="Batch Size"
                )

                generate_btn = gr.Button("Generate Image")
                load_lora_btn = gr.Button("Load LoRA")

            with gr.Column():
                # Output component
                output_image = gr.Gallery(label="Generated Image")
                time_taken = gr.Textbox(
                    label="Generation Time (seconds)", interactive=False
                )

        # Connect the button to the generation function
        model_dropdown.change(model_change, inputs=[model_dropdown])
        load_lora_btn.click(
            fn=load_lora,
            inputs=[
                lora_upload_path,
            ],
        )

        # Update generate button click to include time output
        generate_btn.click(
            fn=generate_image,
            inputs=[
                prompt_input,
                num_steps,
                batch_size,
            ],
            outputs=[output_image, time_taken],
        )
    demo.launch()


def main(args):
    pipe, backbone, trt_gm = compile_model(args)
    launch_gradio(pipe, backbone, trt_gm)


# Launch the interface
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run Flux quantization with different dtypes"
    )

    parser.add_argument(
        "--dtype",
        choices=["fp8", "int8", "fp16"],
        default="fp16",
        help="Select the data type to use (fp8 or int8 or fp16)",
    )
    parser.add_argument(
        "--low_vram_mode",
        action="store_true",
        help="Use low VRAM mode when you have a small GPU (<=32GB)",
    )
    parser.add_argument(
        "--dynamic_shapes",
        "-d",
        action="store_true",
        help="Use dynamic shapes",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Use debug mode",
    )
    args = parser.parse_args()
    main(args)
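
Assuming the new file's dependencies (gradio, diffusers, accelerate, torch_tensorrt, and the ModelOpt quantization toolkit imported as modelopt) are installed, the demo should be launchable from examples/apps with an invocation along these lines; the flags correspond to the argparse options defined above:

python flux_demo.py --dtype fp8 --low_vram_mode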

examples/dynamo/mutable_torchtrt_module_example.py

Lines changed: 4 additions & 5 deletions
@@ -22,6 +22,7 @@
 import torch
 import torch_tensorrt as torch_trt
 import torchvision.models as models
+from diffusers import DiffusionPipeline

 np.random.seed(5)
 torch.manual_seed(5)
@@ -31,7 +32,7 @@
 # Initialize the Mutable Torch TensorRT Module with settings.
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 settings = {
-    "use_python": False,
+    "use_python_runtime": False,
     "enabled_precisions": {torch.float32},
     "immutable_weights": False,
 }
@@ -40,7 +41,6 @@
 mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
 # You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module.
 mutable_module(*inputs)
-
 # %%
 # Make modifications to the mutable module.
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -73,13 +73,12 @@
 # Stable Diffusion with Huggingface
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-from diffusers import DiffusionPipeline

 with torch.no_grad():
     settings = {
         "use_python_runtime": True,
         "enabled_precisions": {torch.float16},
-        "debug": True,
+        "debug": False,
         "immutable_weights": False,
     }

@@ -106,7 +105,7 @@
             "text_embeds": {0: BATCH},
             "time_ids": {0: BATCH},
         },
-        "return_dict": False,
+        "return_dict": None,
     }
     pipe.unet.set_expected_dynamic_shape_range(
         args_dynamic_shapes, kwargs_dynamic_shapes
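
For reference, a minimal sketch (not part of this diff) of the corrected settings key in use, assembled from the context lines above; the ResNet-18 model and the input shape are illustrative assumptions:

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

# The settings dict now uses "use_python_runtime" rather than the old "use_python" key.
settings = {
    "use_python_runtime": False,
    "enabled_precisions": {torch.float32},
    "immutable_weights": False,
}

# Illustrative model and input (assumed, not taken from the example file shown here).
model = models.resnet18().eval().to("cuda")
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

# Compilation happens on the first call to the mutable module.
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
mutable_module(*inputs)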

examples/dynamo/refit_engine_example.py

Lines changed: 1 addition & 0 deletions
@@ -101,6 +101,7 @@
 )

 # Check the output
+model2.to("cuda")
 expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs)
 for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
     assert torch.allclose(
