From 3f09d976143f5755b157739588601756e6a4c4b6 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Sun, 25 May 2025 19:53:08 -0700
Subject: [PATCH 1/7] add flux example.

---
 examples/dynamo/flux_ptq.py | 175 ++++++++++++++++++++++++++++++++++++
 1 file changed, 175 insertions(+)
 create mode 100644 examples/dynamo/flux_ptq.py

diff --git a/examples/dynamo/flux_ptq.py b/examples/dynamo/flux_ptq.py
new file mode 100644
index 0000000000..f19e378b5f
--- /dev/null
+++ b/examples/dynamo/flux_ptq.py
@@ -0,0 +1,175 @@
+
+# %%
+# Import the following libraries
+# -----------------------------
+# Load the ModelOpt-modified model architecture and weights using Huggingface APIs
+# Add argument parsing for dtype selection
+import argparse
+import re
+
+import modelopt.torch.opt as mto
+import modelopt.torch.quantization as mtq
+import torch
+import torch_tensorrt
+from diffusers import FluxPipeline
+from diffusers.models.attention_processor import Attention
+from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+from modelopt.torch.quantization.utils import export_torch_mode
+from torch.export._trace import _export
+from transformers import AutoModelForCausalLM
+
+parser = argparse.ArgumentParser(
+    description="Run Flux quantization with different dtypes"
+)
+parser.add_argument(
+    "--dtype",
+    choices=["fp8", "int8"],
+    default="fp8",
+    help="Quantization data type to use (fp8 or int8)",
+)
+
+args = parser.parse_args()
+
+# Update enabled precisions based on dtype argument
+if args.dtype == "fp8":
+    enabled_precisions = {torch.float8_e4m3fn, torch.float16}
+    ptq_config = mtq.FP8_DEFAULT_CFG
+else:  # int8
+    enabled_precisions = {torch.int8, torch.float16}
+    ptq_config = mtq.INT8_DEFAULT_CFG
+    ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None
+print(f"\nUsing {args.dtype} quantization")
+# %%
+DEVICE = "cuda:0"
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    torch_dtype=torch.float16,
+)
+pipe.transformer = FluxTransformer2DModel(
+    num_layers=1, num_single_layers=1, guidance_embeds=True
+)
+
+pipe.to(DEVICE).to(torch.float16)
+# Store the config and transformer backbone
+config = pipe.transformer.config
+# global backbone
+backbone = pipe.transformer
+backbone.eval()
+
+def filter_func(name):
+    pattern = re.compile(
+        r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
+    )
+    return pattern.match(name) is not None
+
+def generate_image(pipe, prompt, image_name):
+    seed = 42
+    image = pipe(
+        prompt,
+        output_type="pil",
+        num_inference_steps=20,
+        generator=torch.Generator("cuda").manual_seed(seed),
+    ).images[0]
+    image.save(f"{image_name}.png")
+    print(f"Image generated using {image_name} model saved as {image_name}.png")
+
+# %%
+# Quantization
+
+
+def do_calibrate(
+    pipe,
+    prompt: str,
+) -> None:
+    """
+    Run calibration steps on the pipeline using the given prompts.
+    """
+    image = pipe(
+        prompt,
+        output_type="pil",
+        num_inference_steps=20,
+        generator=torch.Generator("cuda").manual_seed(0),
+    ).images[0]
+
+
+def forward_loop(mod):
+    # Switch the pipeline's backbone, run calibration
+    pipe.transformer = mod
+    do_calibrate(
+        pipe=pipe,
+        prompt="test",
+    )
+
+backbone = mtq.quantize(backbone, ptq_config, forward_loop)
+mtq.disable_quantizer(backbone, filter_func)
+
+batch_size = 2
+BATCH = torch.export.Dim("batch", min=1, max=8)
+SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
+# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model.
+# To see this recommendation, you can try exporting using min=1, max=4096
+IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
+dynamic_shapes = {
+    "hidden_states": {0: BATCH},
+    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
+    "pooled_projections": {0: BATCH},
+    "timestep": {0: BATCH},
+    "txt_ids": {0: SEQ_LEN},
+    "img_ids": {0: IMG_ID},
+    "guidance": {0: BATCH},
+    "joint_attention_kwargs": {},
+    "return_dict": None,
+}
+# The guidance factor is of type torch.float32
+dummy_inputs = {
+    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(
+        DEVICE
+    ),
+    "encoder_hidden_states": torch.randn(
+        (batch_size, 512, 4096), dtype=torch.float16
+    ).to(DEVICE),
+    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(
+        DEVICE
+    ),
+    "timestep": torch.tensor([1.0] * batch_size, dtype=torch.float16).to(DEVICE),
+    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
+    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
+    "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32).to(DEVICE),
+    "joint_attention_kwargs": {},
+    "return_dict": False,
+}
+
+# This will create an exported program which is going to be compiled with Torch-TensorRT
+with export_torch_mode():
+    ep = _export(
+        backbone,
+        args=(),
+        kwargs=dummy_inputs,
+        dynamic_shapes=dynamic_shapes,
+        strict=False,
+        allow_complex_guards_as_runtime_asserts=True,
+    )
+
+
+trt_gm = torch_tensorrt.dynamo.compile(
+    ep,
+    inputs=dummy_inputs,
+    enabled_precisions=enabled_precisions,
+    truncate_double=True,
+    min_block_size=1,
+    debug=False,
+    use_python_runtime=True,
+    immutable_weights=True,
+    offload_module_to_cpu=True,
+)
+
+
+del ep
+pipe.transformer = trt_gm
+pipe.transformer.config = config
+
+
+# %%
+trt_gm.device = torch.device(DEVICE)
+# Function which generates images from the flux pipeline
+generate_image(pipe, ["A golden retriever"], "dog_code2")

From 9f803eec7d8b2b4d7233c5bfbf3dbe6dbb1ddcd5 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Tue, 27 May 2025 11:02:55 -0700
Subject: [PATCH 2/7] add flux

---
 .../dynamo/conversion/aten_ops_converters.py  |  36 +++-
 .../dynamo/conversion/impl/__init__.py        |   1 +
 .../dynamo/conversion/impl/attention.py       | 165 +++++++++++++++++
 .../dynamo/conversion/impl/quantize.py        |  52 +++---
 .../dynamo/conversion/impl/unsqueeze.py       |   4 +-
 .../lowering/passes/_aten_lowering_pass.py    |   2 +
 .../lower_scaled_dot_product_attention.py     | 169 ++++++++++++++++++
 .../perf/Flux/flux_quantization_debug.py      |  32 +++-
 8 files changed, 437 insertions(+), 24 deletions(-)
 create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/attention.py
 create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py
 rename examples/dynamo/flux_ptq.py => tools/perf/Flux/flux_quantization_debug.py (87%)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 08a1bb4ea4..d9afb36bd8 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -597,7 +597,9 @@ def aten_ops_neg( ) else: - @dynamo_tensorrt_converter(torch.ops.tensorrt.quantize_op.default) + @dynamo_tensorrt_converter( + torch.ops.tensorrt.quantize_op.default, supports_dynamic_shapes=True + ) def aten_ops_quantize_op( ctx: ConversionContext, target: Target, @@ -650,6 +652,38 @@ def aten_ops_dynamic_block_quantize_op( ) +def attention_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: + # Currently, `attn_mask` is not supported + return args_bounds_check(node.args, 3) is None + + +@dynamo_tensorrt_converter( + torch.nn.functional.scaled_dot_product_attention, + capability_validator=attention_validator, + supports_dynamic_shapes=True, +) +def tensorrt_scaled_dot_product_attention( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.attention.scaled_dot_product_attention( + ctx, + target, + SourceIR.TORCHTRT_LOWERED, + name, + args[0], + args[1], + args[2], + args_bounds_check(args, 5, False), + kwargs.get("scale", None), + ) + + @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True) def aten_ops_squeeze( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index 1f2d9d0de1..03903c1f07 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -2,6 +2,7 @@ activation, addmm, arange, + attention, cast, cat, condition, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/attention.py b/py/torch_tensorrt/dynamo/conversion/impl/attention.py new file mode 100644 index 0000000000..9cc4a30ccf --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/attention.py @@ -0,0 +1,165 @@ +import math +from typing import Optional, Union + +import numpy as np +import tensorrt as trt +from torch.fx.node import Target +from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import ( + SourceIR, + cast_trt_tensor, + get_trt_tensor, +) +from torch_tensorrt.fx.types import TRTTensor + + +def tril( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + # the lower triangle of the tensor means the rows greater than and equal to the cols + row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0) + col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1) + rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col) + arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1 + ) + # get the rows + row_tensor = impl.elementwise.trunc_div( + ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col + ) + # get the cols + col_tensor = impl.elementwise.fmod( + ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col + ) + cond = impl.elementwise.ge( + ctx, target, source_ir, name + "_ge", 
row_tensor, col_tensor + ) + return impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape", cond, [row, col] + ) + + +def scaled_dot_product_attention( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + query: TRTTensor, + key: TRTTensor, + value: TRTTensor, + is_causal: bool, + scale: Optional[float], +) -> TRTTensor: + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + mm = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_mm", + query, + key, + other_matrix_op=trt.MatrixOperation.TRANSPOSE, + ) + if scale is None: + scale = query.shape[-1] + if scale < 0: + # dynamic shape + scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1) + sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale) + else: + # static shape + sqrt_scaled = math.sqrt(scale) + scaled = impl.elementwise.div( + ctx, + target, + source_ir, + name + "_scale", + mm, + sqrt_scaled, + ) + else: + scaled = impl.elementwise.mul( + ctx, + target, + source_ir, + name + "_scale", + mm, + scale, + ) + + if is_causal: + L, S = query.shape[-2], key.shape[-2] + if L >= 0 and S >= 0: + # static shape + attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype)) + temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) + attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) + attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") + else: + # if any of the L or S is dynamic shape + if L < 0: + L = impl.shape.shape( + ctx, target, source_ir, name + "_shape_0", query, -2 + ) + if S < 0: + S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2) + + LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S) + + # this is to generate a tensor which has shape (L, S), type is int32 + arange_tensor = impl.arange.arange( + ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1 + ) + shape_tensor = impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S] + ) + + # since we want our attn_bias to be in float32, so cast it to float32 + shape_tensor = cast_trt_tensor( + ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir + ) + + # initialize the attn_bias as the zeros tensor + attn_bias = impl.elementwise.mul( + ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0 + ) + + # generate the mask tensor + tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor) + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + "_logical_not", tril_tensor + ) + inf_tensor = impl.elementwise.mul( + ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf") + ) + cond = impl.elementwise.eq( + ctx, target, source_ir, name + "_cond_true", temp_mask, bool(True) + ) + # mask out the certain part of the attn_bias + attn_bias = impl.condition.select( + ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond + ) + + scaled = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias + ) + + softmax = impl.normalization.softmax( + ctx, target, source_ir, name + "_softmax", scaled, -1, False + ) + out = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_out", + softmax, + value, + ) + + return out diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py 
b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index e472ed3092..16703e8ae4 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -3,6 +3,7 @@ import numpy as np import tensorrt as trt import torch +import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR @@ -28,42 +29,53 @@ def quantize( """ with unset_fake_temporarily(): - if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in ( - trt.float32, - trt.float16, - ): - raise ValueError( - f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16" - ) + if isinstance(input_tensor, (torch.Tensor, TRTTensor)): + input_tensor = get_trt_tensor(ctx, input_tensor, name) + if input_tensor.dtype not in ( + trt.float32, + trt.float16, + ): + raise ValueError( + f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16" + ) if num_bits != 8 or exponent_bits not in (0, 4): raise ValueError( f"quantize converter currently only accept INT8 or FP8 based quantize, got {num_bits=}, {exponent_bits=}" ) + else: + raise ValueError( + f"quantize converter received an input of {type(input_tensor)} type. Supported types: torch.Tensor | TRTTensor" + ) + if num_bits == 8 and exponent_bits == 0: max_bound = 127 elif num_bits == 8 and exponent_bits == 4: max_bound = 448 - amax = to_torch(amax, None) - scale = torch.divide(amax, max_bound) - scale = get_trt_tensor(ctx, scale, name + "_scale") - # Add Q node - quantize_layer = ctx.net.add_quantize(input_tensor, scale) + if not isinstance(amax, trt.ITensor): + amax = to_torch(amax, None) + scale = torch.divide(amax, max_bound) + scale = get_trt_tensor(ctx, amax, name + "_scale") + else: + scale = impl.elementwise_divide( + ctx, target, source_ir, name + "_scale", amax, max_bound + ) + if num_bits == 8 and exponent_bits == 0: - quantize_layer.set_output_type(0, trt.DataType.INT8) + dtype = trt.DataType.INT8 elif num_bits == 8 and exponent_bits == 4: - quantize_layer.set_output_type(0, trt.DataType.FP8) + dtype = trt.DataType.FP8 + # Add Q node + quantize_layer = ctx.net.add_quantize(input_tensor, scale, dtype) set_layer_name(quantize_layer, target, name + "_quantize", source_ir) q_output = quantize_layer.get_output(0) # Add DQ node - dequantize_layer = ctx.net.add_dequantize(q_output, scale) + dequantize_layer = ctx.net.add_dequantize( + q_output, scale, output_type=input_tensor.dtype + ) + dequantize_layer.to_type = input_tensor.dtype set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir) - if num_bits == 8 and exponent_bits == 0: - dequantize_layer.precision = trt.DataType.INT8 - elif num_bits == 8 and exponent_bits == 4: - # Set DQ layer precision to FP8 - dequantize_layer.precision = trt.DataType.FP8 dq_output = dequantize_layer.get_output(0) return dq_output diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py index 02ecf98bfe..4aa6559713 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py @@ -11,6 +11,8 @@ ) from torch_tensorrt.dynamo.types import TRTTensor +from packaging import version as pkg_version + logger = logging.getLogger(__name__) @@ -24,7 +26,7 @@ def unsqueeze( ) -> TRTTensor: from importlib.metadata import 
version - if version("tensorrt") < "10.7.0": + if pkg_version.parse(version("tensorrt")) < pkg_version.parse("10.7.0"): logger.warning( f"IUnsqueezeLayer is supported starting from TensorRT 10.7.0, using the old unsqueeze implementation in the current TensorRT version: {version('tensorrt')}" ) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 2ecc45ecf3..553151da7a 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -9,6 +9,7 @@ from .constant_folding import constant_fold from .fuse_distributed_ops import fuse_distributed_ops from .fuse_prims_broadcast import fuse_prims_broadcast +from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention from .pass_manager import DynamoPassManager from .remove_assert_nodes import remove_assert_nodes from .remove_detach import remove_detach @@ -23,6 +24,7 @@ repair_input_as_output, fuse_prims_broadcast, replace_max_pool_with_indices, + lower_scaled_dot_product_attention, remove_assert_nodes, accumulate_fp32_matmul, remove_num_users_is_0_nodes, diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py new file mode 100644 index 0000000000..40fd587615 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py @@ -0,0 +1,169 @@ +import copy +import logging +import operator +from typing import Callable, Sequence, Tuple + +import torch +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) +REPLACEABLE_ATEN_OPS = { + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, +} + + +def lower_scaled_dot_product_attention( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Replace specific versions of scaled_dot_product_attention with an equivalent + implementation which can be easily converted to TRT + """ + original_fns, replacement = scaled_dot_product_attention_replacement() + replaced_nodes = [] + # For each original function, search for it in the graph and replace + for original in original_fns: + replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters( + gm, + original, + replacement, + ignore_literals=True, + ) + + if replaced_nodes: + # Repair instances which use the kwargs field (specifically the "scale" kwarg) + # Also repair instances which specified the is_causal or attn_bias fields + for match in replaced_nodes: + attention_node_replaced = None + # Seek the attention operator being replaced + for node in match.nodes_map: + if node.target in REPLACEABLE_ATEN_OPS: + attention_node_replaced = match.nodes_map[node] + break + + assert attention_node_replaced is not None + assert len(match.replacements) == 1 + + new_attention_node = match.replacements[0] + + assert ( + new_attention_node.target + == torch.nn.functional.scaled_dot_product_attention + ) + + # Copy the metadata of the replaced attention node to the new node + # TODO: Investigate why there are multiple FakeTensors in the metadata. 
+ # We only use the first one as it contains the output shape information for this node. + if "val" in attention_node_replaced.meta: + new_attention_node.meta["val"] = copy.copy( + attention_node_replaced.meta["val"][0] + ) + + # If the attention operator had keyword-args, copy them to the new node + if attention_node_replaced.kwargs: + new_attention_node.kwargs = {**attention_node_replaced.kwargs} + + # Set default args in new node: + # Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False + new_attention_node.args = new_attention_node.args + (None, 0.0, False) + + # The `is_causal` argument was specified + if ( + ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_flash_attention.default + ) + and args_bounds_check(attention_node_replaced.args, 4, False) + ) or ( + ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ) + and args_bounds_check(attention_node_replaced.args, 6, False) + ): + new_attention_node.args = ( + new_attention_node.args[:5] + (True,) + new_attention_node.args[6:] + ) + + # The `attn_bias` argument was specified + if ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ) and args_bounds_check(attention_node_replaced.args, 3) is not None: + new_attention_node.args = ( + new_attention_node.args[:3] + + attention_node_replaced.args[3] + + new_attention_node.args[4:] + ) + + gm = clean_up_graph_after_modifications(gm) + logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}") + + return gm + + +def scaled_dot_product_attention_replacement() -> Tuple[ + Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]], + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], +]: + """Constructs the original and replacement functions for efficient attention""" + + # Efficient Attention original graph + def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, + k, + v, + None, + False, + ) + out = operator.getitem(outputs, 0) + return out + + # Flash Attention original graph + def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, + k, + v, + ) + out = operator.getitem(outputs, 0) + return out + + # Efficient Attention w/Scale original graph + def efficient_scale( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, + k, + v, + None, + False, + scale=1.0, + ) + out = operator.getitem(outputs, 0) + return out + + # Flash Attention w/Scale original graph + def flash_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, + k, + v, + scale=1.0, + ) + out = operator.getitem(outputs, 0) + return out + + # Replacement graph consists of the functional version of scaled_dot_product_attention + def replacement( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention(query, key, value) + + return (efficient, flash, efficient_scale, flash_scale), replacement diff --git a/examples/dynamo/flux_ptq.py b/tools/perf/Flux/flux_quantization_debug.py similarity index 87% rename from examples/dynamo/flux_ptq.py rename to 
tools/perf/Flux/flux_quantization_debug.py index f19e378b5f..56ed5d92a4 100644 --- a/examples/dynamo/flux_ptq.py +++ b/tools/perf/Flux/flux_quantization_debug.py @@ -1,4 +1,3 @@ - # %% # Import the following libraries # ----------------------------- @@ -56,12 +55,14 @@ backbone = pipe.transformer backbone.eval() + def filter_func(name): pattern = re.compile( r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*" ) return pattern.match(name) is not None + def generate_image(pipe, prompt, image_name): seed = 42 image = pipe( @@ -73,6 +74,28 @@ def generate_image(pipe, prompt, image_name): image.save(f"{image_name}.png") print(f"Image generated using {image_name} model saved as {image_name}.png") + +def benchmark(prompt, inference_step, batch_size=1, iterations=1): + from time import time + + start = time() + for i in range(iterations): + image = pipe( + prompt, + output_type="pil", + num_inference_steps=inference_step, + num_images_per_prompt=batch_size, + ).images + end = time() + print(f"Batch Size: {batch_size}") + print("Time Elapse for", iterations, "iterations:", end - start) + print( + "Average Latency Per Step:", + (end - start) / inference_step / iterations / batch_size, + ) + return image + + # %% # Quantization @@ -100,6 +123,7 @@ def forward_loop(mod): prompt="test", ) + backbone = mtq.quantize(backbone, ptq_config, forward_loop) mtq.disable_quantizer(backbone, filter_func) @@ -155,10 +179,11 @@ def forward_loop(mod): ep, inputs=dummy_inputs, enabled_precisions=enabled_precisions, + use_explicit_typing=True, truncate_double=True, min_block_size=1, debug=False, - use_python_runtime=True, + # use_python_runtime=True, immutable_weights=True, offload_module_to_cpu=True, ) @@ -173,3 +198,6 @@ def forward_loop(mod): trt_gm.device = torch.device(DEVICE) # Function which generates images from the flux pipeline generate_image(pipe, ["A golden retriever"], "dog_code2") + + +# For this dummy model, the fp16 engine size is around 1GB, fp32 engine size is around 2GB From bb7760aff262e50478b9f56c0b055eb71793f4fa Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 9 Jun 2025 09:34:40 -0700 Subject: [PATCH 3/7] test --- tools/perf/Flux/flux_quantization_debug.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/perf/Flux/flux_quantization_debug.py b/tools/perf/Flux/flux_quantization_debug.py index 56ed5d92a4..43c817ebf0 100644 --- a/tools/perf/Flux/flux_quantization_debug.py +++ b/tools/perf/Flux/flux_quantization_debug.py @@ -179,7 +179,7 @@ def forward_loop(mod): ep, inputs=dummy_inputs, enabled_precisions=enabled_precisions, - use_explicit_typing=True, + use_explicit_typing=False, truncate_double=True, min_block_size=1, debug=False, From 376656963c0288b12dc9ba248b4194ddcf685f76 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 9 Jun 2025 09:52:26 -0700 Subject: [PATCH 4/7] test --- tools/perf/Flux/flex_perf.py | 97 ++++++++++ tools/perf/Flux/flux_quantization.py | 268 +++++++++++++++++++++++++++ tools/perf/Flux/register_sdpa.py | 184 ++++++++++++++++++ tools/perf/Flux/sdpa_converter.py | 176 ++++++++++++++++++ 4 files changed, 725 insertions(+) create mode 100644 tools/perf/Flux/flex_perf.py create mode 100644 tools/perf/Flux/flux_quantization.py create mode 100644 tools/perf/Flux/register_sdpa.py create mode 100644 tools/perf/Flux/sdpa_converter.py diff --git a/tools/perf/Flux/flex_perf.py b/tools/perf/Flux/flex_perf.py new file mode 100644 index 0000000000..b7aa608dd3 
--- /dev/null
+++ b/tools/perf/Flux/flex_perf.py
@@ -0,0 +1,97 @@
+from time import time
+
+import register_sdpa
+import torch
+import torch_tensorrt
+from diffusers import FluxPipeline
+
+for i in range(torch.cuda.device_count()):
+    print(torch.cuda.get_device_properties(i).name)
+
+DEVICE = "cuda:0"
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    torch_dtype=torch.bfloat16,
+)
+pipe.to(DEVICE).to(torch.bfloat16)
+backbone = pipe.transformer
+
+
+batch_size = 2
+BATCH = torch.export.Dim("batch", min=1, max=8)
+
+# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model.
+# To see this recommendation, you can try exporting using min=1, max=4096
+dynamic_shapes = {
+    "hidden_states": {0: BATCH},
+    "encoder_hidden_states": {0: BATCH},
+    "pooled_projections": {0: BATCH},
+    "timestep": {0: BATCH},
+    "txt_ids": {},
+    "img_ids": {},
+    "guidance": {0: BATCH},
+    "joint_attention_kwargs": {},
+    "return_dict": None,
+}
+
+settings = {
+    "strict": False,
+    "allow_complex_guards_as_runtime_asserts": True,
+    # "enabled_precisions": {torch.float16},
+    "use_explicit_typing": True,
+    "truncate_double": True,
+    "min_block_size": 1,
+    "debug": False,
+    # "use_python_runtime": True,
+    "immutable_weights": False,
+    "offload_module_to_cpu": True,
+}
+
+
+def generate_image(prompt, inference_step, batch_size=1, benchmark=False, iterations=1):
+
+    start = time()
+    for i in range(iterations):
+        image = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=inference_step,
+            num_images_per_prompt=batch_size,
+        ).images
+    end = time()
+    if benchmark:
+        print(f"Batch Size: {batch_size}")
+        print("Time Elapse for", iterations, "iterations:", end - start)
+        print(
+            "Average Latency Per Step:",
+            (end - start) / inference_step / iterations / batch_size,
+        )
+    return image
+
+
+pipe.to(torch.bfloat16)
+torch.cuda.empty_cache()
+# Warmup
+generate_image(["Test"], 20)
+print("Benchmark Original PyTorch Module Latency (bfloat16)")
+for batch_size in range(1, 3):
+    generate_image(["Test"], 20, batch_size=batch_size, benchmark=True, iterations=3)
+
+pipe.to(torch.float16)
+print("Benchmark Original PyTorch Module Latency (float16)")
+for batch_size in range(1, 3):
+    generate_image(["Test"], 20, batch_size=batch_size, benchmark=True, iterations=3)
+
+trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
+trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
+pipe.transformer = trt_gm
+
+start = time()
+generate_image(["Test"], 2, batch_size=2)
+end = time()
+print("Time Elapse compilation:", end - start)
+print()
+print("Benchmark TRT Accelerated Latency")
+for batch_size in range(1, 3):
+    generate_image(["Test"], 20, batch_size=batch_size, benchmark=True, iterations=3)
+torch.cuda.empty_cache()
diff --git a/tools/perf/Flux/flux_quantization.py b/tools/perf/Flux/flux_quantization.py
new file mode 100644
index 0000000000..9619bd9e92
--- /dev/null
+++ b/tools/perf/Flux/flux_quantization.py
@@ -0,0 +1,268 @@
+# %%
+# Import the following libraries
+# -----------------------------
+# Load the ModelOpt-modified model architecture and weights using Huggingface APIs
+# Add argument parsing for dtype selection
+import argparse
+import re
+
+import modelopt.torch.opt as mto
+import modelopt.torch.quantization as mtq
+import torch
+import torch_tensorrt
+from diffusers import FluxPipeline
+from diffusers.models.attention_processor import Attention
+from diffusers.models.transformers.transformer_flux import
FluxTransformer2DModel +from modelopt.torch.quantization.utils import export_torch_mode +from torch.export._trace import _export +from transformers import AutoModelForCausalLM + +parser = argparse.ArgumentParser( + description="Run Flux quantization with different dtypes" +) +parser.add_argument( + "--debug", + action="store_true", + default=False, + help="debug mode", +) +parser.add_argument( + "--dtype", + choices=["fp8", "int8", "fp4", "fp16", "bf16", "fp32"], + default="fp8", + help="Quantization data type to use (fp8 or int8 or fp4 or fp16 or bf16 or fp32)", +) + +parser.add_argument( + "--sdpa", + action="store_true", + default=False, + help="Register SDPA operator", +) + +parser.add_argument( + "--strong-typing", + action="store_true", + help="string type flag", +) + +args = parser.parse_args() +if args.sdpa: + import register_sdpa + +dtype = torch.float16 +ptq_config = None +use_explicit_typing = args.strong_typing +enabled_precisions = [ + torch.float32, +] + +# Update enabled precisions based on dtype argument +if args.dtype == "fp8": + ( + enabled_precisions.extend([torch.float8_e4m3fn, torch.float16]) + if not use_explicit_typing + else None + ) + ptq_config = mtq.FP8_DEFAULT_CFG +elif args.dtype == "int8": # int8 + ( + enabled_precisions.extend([torch.int8, torch.float16]) + if not use_explicit_typing + else None + ) + ptq_config = mtq.INT8_DEFAULT_CFG +elif args.dtype == "fp4": + ptq_config = mtq.NVFP4_DEFAULT_CFG + use_explicit_typing = True +elif args.dtype == "fp16": + enabled_precisions.append(torch.float16) if not use_explicit_typing else None +elif args.dtype == "bf16": + dtype = torch.bfloat16 + ( + enabled_precisions.extend([torch.bfloat16, torch.float16]) + if not use_explicit_typing + else None + ) +elif args.dtype == "fp32": + dtype = torch.float32 +else: + raise ValueError(f"Invalid dtype: {args.dtype}") +print(f"\nUsing {args.dtype} quantization") +# %% +DEVICE = "cuda:0" +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.float16, +) + +total_params = sum(p.numel() for p in pipe.transformer.parameters()) +print(f"\n Total number of parameters: {total_params/1000/1000/1000}B") +if dtype in (torch.float16, torch.bfloat16): + total_size = total_params * 2 / 1024 / 1024 / 1024 + print(f"\n Total size: {total_size}GB") +elif dtype == torch.float32: + total_size = total_params * 4 / 1024 / 1024 / 1024 + print(f"\n Total size: {total_size}GB") + +if args.debug: + pipe.transformer = FluxTransformer2DModel( + num_layers=1, num_single_layers=1, guidance_embeds=True + ) + +pipe.to(DEVICE).to(dtype) +# Store the config and transformer backbone +config = pipe.transformer.config +# global backbone +backbone = pipe.transformer +backbone.eval() + + +def filter_func(name): + pattern = re.compile( + r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*" + ) + return pattern.match(name) is not None + + +def generate_image(pipe, prompt, image_name): + seed = 42 + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(seed), + ).images[0] + image.save(f"{image_name}.png") + print(f"Image generated using {image_name} model saved as {image_name}.png") + + +def benchmark(prompt, inference_step, batch_size=1, iterations=1): + from time import time + + start = time() + for i in range(iterations): + image = pipe( + prompt, + output_type="pil", + num_inference_steps=inference_step, + 
num_images_per_prompt=batch_size, + ).images + end = time() + print(f"Batch Size: {batch_size}") + print("Time Elapse for", iterations, "iterations:", end - start) + print( + "Average Latency Per Step:", + (end - start) / inference_step / iterations / batch_size, + ) + return image + + +# %% +# Quantization + + +def do_calibrate( + pipe, + prompt: str, +) -> None: + """ + Run calibration steps on the pipeline using the given prompts. + """ + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(0), + ).images[0] + + +def forward_loop(mod): + # Switch the pipeline's backbone, run calibration + pipe.transformer = mod + do_calibrate( + pipe=pipe, + prompt="test", + ) + + +if ptq_config is not None: + backbone = mtq.quantize(backbone, ptq_config, forward_loop) + mtq.disable_quantizer(backbone, filter_func) +else: + print("No quantization config provided, skipping quantization") + +batch_size = 2 +BATCH = torch.export.Dim("batch", min=1, max=8) +SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512) +# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model. +# To see this recommendation, you can try exporting using min=1, max=4096 +IMG_ID = torch.export.Dim("img_id", min=3586, max=4096) +dynamic_shapes = { + "hidden_states": {0: BATCH}, + "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN}, + "pooled_projections": {0: BATCH}, + "timestep": {0: BATCH}, + "txt_ids": {0: SEQ_LEN}, + "img_ids": {0: IMG_ID}, + "guidance": {0: BATCH}, + "joint_attention_kwargs": {}, + "return_dict": None, +} +# The guidance factor is of type torch.float32 +dummy_inputs = { + "hidden_states": torch.randn((batch_size, 4096, 64), dtype=dtype).to(DEVICE), + "encoder_hidden_states": torch.randn((batch_size, 512, 4096), dtype=dtype).to( + DEVICE + ), + "pooled_projections": torch.randn((batch_size, 768), dtype=dtype).to(DEVICE), + "timestep": torch.tensor([1.0] * batch_size, dtype=dtype).to(DEVICE), + "txt_ids": torch.randn((512, 3), dtype=dtype).to(DEVICE), + "img_ids": torch.randn((4096, 3), dtype=dtype).to(DEVICE), + "guidance": torch.tensor([1.0] * batch_size, dtype=dtype).to(DEVICE), + "joint_attention_kwargs": {}, + "return_dict": False, +} + +# This will create an exported program which is going to be compiled with Torch-TensorRT +with export_torch_mode(): + ep = _export( + backbone, + args=(), + kwargs=dummy_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + +with torch_tensorrt.logging.debug(): + trt_gm = torch_tensorrt.dynamo.compile( + ep, + inputs=dummy_inputs, + enabled_precisions=enabled_precisions, + use_explicit_typing=use_explicit_typing, + truncate_double=True, + min_block_size=1, + debug=args.debug, + immutable_weights=True, + offload_module_to_cpu=True, + ) + + +del ep +pipe.transformer = trt_gm +pipe.transformer.config = config + + +# %% +trt_gm.device = torch.device(DEVICE) +# Function which generates images from the flux pipeline +generate_image(pipe, ["A golden retriever"], "dog_code2") + +if not args.debug: + print(f"Benchmark TRT Module Latency at ({args.dtype}) started") + for batch_size in range(1, 9): + benchmark(["Test"], 20, batch_size=batch_size, iterations=3) + print(f"Benchmark TRT Module Latency at ({args.dtype}) ended") + +# For this dummy model, the fp16 engine size is around 1GB, fp32 engine size is around 2GB diff --git a/tools/perf/Flux/register_sdpa.py b/tools/perf/Flux/register_sdpa.py new file mode 100644 
index 0000000000..9afbbda7d4 --- /dev/null +++ b/tools/perf/Flux/register_sdpa.py @@ -0,0 +1,184 @@ +import copy +import logging +import operator +from typing import Callable, Sequence, Tuple + +import torch +from sdpa_converter import * +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check +from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) + +# Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention +# This is because we want to have SDPA as a standalone operator in the graph and invoke the custom converter for it. +TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default) +TORCH_TRT_DECOMPOSITIONS.pop( + torch.ops.aten._scaled_dot_product_efficient_attention.default +) +TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_flash_attention.default) + +REPLACEABLE_ATEN_OPS = { + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, +} + + +@_aten_lowering_pass +def lower_scaled_dot_product_attention( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Replace specific versions of scaled_dot_product_attention with an equivalent + implementation which can be easily converted to TRT + """ + original_fns, replacement = scaled_dot_product_attention_replacement() + replaced_nodes = [] + # For each original function, search for it in the graph and replace + for original in original_fns: + replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters( + gm, + original, + replacement, + ignore_literals=True, + ) + + if replaced_nodes: + # Repair instances which use the kwargs field (specifically the "scale" kwarg) + # Also repair instances which specified the is_causal or attn_bias fields + for match in replaced_nodes: + attention_node_replaced = None + # Seek the attention operator being replaced + for node in match.nodes_map: + if node.target in REPLACEABLE_ATEN_OPS: + attention_node_replaced = match.nodes_map[node] + break + + assert attention_node_replaced is not None + assert len(match.replacements) == 1 + + new_attention_node = match.replacements[0] + + assert ( + new_attention_node.target + == torch.nn.functional.scaled_dot_product_attention + ) + + # Copy the metadata of the replaced attention node to the new node + # TODO: Investigate why there are multiple FakeTensors in the metadata. + # We only use the first one as it contains the output shape information for this node. + if "val" in attention_node_replaced.meta: + new_attention_node.meta["val"] = copy.copy( + attention_node_replaced.meta["val"][0] + ) + + # If the attention operator had keyword-args, copy them to the new node + if attention_node_replaced.kwargs: + new_attention_node.kwargs = {**attention_node_replaced.kwargs} + + # Set default args in new node: + # Tensor? 
attn_mask=None, float dropout_p=0.0, bool is_causal=False + new_attention_node.args = new_attention_node.args + (None, 0.0, False) + + # The `is_causal` argument was specified + if ( + ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_flash_attention.default + ) + and args_bounds_check(attention_node_replaced.args, 4, False) + ) or ( + ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ) + and args_bounds_check(attention_node_replaced.args, 6, False) + ): + new_attention_node.args = ( + new_attention_node.args[:5] + (True,) + new_attention_node.args[6:] + ) + + # The `attn_bias` argument was specified + if ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ) and args_bounds_check(attention_node_replaced.args, 3) is not None: + new_attention_node.args = ( + new_attention_node.args[:3] + + attention_node_replaced.args[3] + + new_attention_node.args[4:] + ) + + gm = clean_up_graph_after_modifications(gm) + logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}") + + return gm + + +def scaled_dot_product_attention_replacement() -> Tuple[ + Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]], + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], +]: + """Constructs the original and replacement functions for efficient attention""" + + # Efficient Attention original graph + def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, + k, + v, + None, + False, + ) + out = operator.getitem(outputs, 0) + return out + + # Flash Attention original graph + def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, + k, + v, + ) + out = operator.getitem(outputs, 0) + return out + + # Efficient Attention w/Scale original graph + def efficient_scale( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, + k, + v, + None, + False, + scale=1.0, + ) + out = operator.getitem(outputs, 0) + return out + + # Flash Attention w/Scale original graph + def flash_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, + k, + v, + scale=1.0, + ) + out = operator.getitem(outputs, 0) + return out + + # Replacement graph consists of the functional version of scaled_dot_product_attention + def replacement( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention(query, key, value) + + return (efficient, flash, efficient_scale, flash_scale), replacement diff --git a/tools/perf/Flux/sdpa_converter.py b/tools/perf/Flux/sdpa_converter.py new file mode 100644 index 0000000000..903324dff5 --- /dev/null +++ b/tools/perf/Flux/sdpa_converter.py @@ -0,0 +1,176 @@ +import logging +import math +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import tensorrt as trt +import torch +import torch_tensorrt +from torch.fx.node import Target +from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from 
torch_tensorrt.dynamo.conversion.converter_utils import ( + SourceIR, + cast_trt_tensor, + get_trt_tensor, +) +from torch_tensorrt.fx.types import TRTTensor + +logger = logging.getLogger(__name__) + + +def tril( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + row: TRTTensor, + col: TRTTensor, +) -> TRTTensor: + row_arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 + ) + row_reshape_tensor = impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1] + ) + + col_arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1 + ) + col_reshape_tensor = impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col] + ) + + mask = impl.elementwise.ge( + ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor + ) + return mask + + +@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter( + torch.nn.functional.scaled_dot_product_attention, + enabled=True, + supports_dynamic_shapes=True, +) +def scaled_dot_product_attention( + ctx: torch_tensorrt.dynamo.conversion.ConversionContext, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: str, +) -> TRTTensor: + # TODO: Handle attn_mask and is_causal arguments in the future + query, key, value, attn_mask, dropout_p, is_causal = args + logger.info( + "Ignoring attn_mask and is_causal arguments provided by the original graph. " + "This converter expects is_causal to be an input to the graph. For prefill phase, is_causal=True " + "and for generate phase, is_causal=False since we pass only 1 input token at a time" + ) + + # TODO: remove this once we have a better way to handle the causal mask + scale = kwargs.get("scale", None) + source_ir = SourceIR.ATEN + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + mm = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_mm", + query, + key, + other_matrix_op=trt.MatrixOperation.TRANSPOSE, + ) + if scale is None: + scale = query.shape[-1] + if scale < 0: + # dynamic shape + scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1) + sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale) + else: + # static shape + sqrt_scaled = math.sqrt(scale) + scaled = impl.elementwise.div( + ctx, + target, + source_ir, + name + "_scale", + mm, + sqrt_scaled, + ) + else: + scaled = impl.elementwise.mul( + ctx, + target, + source_ir, + name + "_scale", + mm, + scale, + ) + + # If is_causal is True, we need to generate a causal mask + if is_causal: + L, S = query.shape[-2], key.shape[-2] + if L >= 0 and S >= 0: + # static shape + attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype)) + temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) + attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) + attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") + else: + # if any of the L or S is dynamic shape + if L < 0: + L = impl.shape.shape( + ctx, target, source_ir, name + "_shape_0", query, 2 + ) + if S < 0: + S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) + + # generate the mask tensor + tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + 
"_logical_not", tril_tensor + ) + temp_mask_casted = cast_trt_tensor( + ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir + ) + one_minus_temp_mask = impl.elementwise.sub( + ctx, + target, + source_ir, + name + "_one_minus_temp_mask", + 1.0, + temp_mask_casted, + ) + attn_bias = impl.unary.log( + ctx, target, source_ir, name + "_log", one_minus_temp_mask + ) + + scaled_add_attn_bias = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias + ) + else: + scaled_add_attn_bias = scaled + + # Create a if condition to check if is_causal is True + if isinstance(is_causal, TRTTensor): + if_layer = ctx.net.add_if_conditional() + condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled + if_layer.set_condition(condition) + output_layer = if_layer.add_output(true_branch, false_branch) + scaled_add_attn_bias = output_layer.get_output(0) + + softmax = impl.normalization.softmax( + ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False + ) + out = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_out", + softmax, + value, + ) + + return out From f757ac426496c1c389c623c6e667f3b8eb79aacf Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 9 Jun 2025 16:53:32 -0700 Subject: [PATCH 5/7] test --- tools/perf/Flux/flux_quantization.py | 34 ++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/tools/perf/Flux/flux_quantization.py b/tools/perf/Flux/flux_quantization.py index 9619bd9e92..fde2ca0de4 100644 --- a/tools/perf/Flux/flux_quantization.py +++ b/tools/perf/Flux/flux_quantization.py @@ -4,6 +4,7 @@ # Load the ModelOpt-modified model architecture and weights using Huggingface APIs # Add argument parsing for dtype selection import argparse +import gc import re import modelopt.torch.opt as mto @@ -88,12 +89,12 @@ dtype = torch.float32 else: raise ValueError(f"Invalid dtype: {args.dtype}") -print(f"\nUsing {args.dtype} quantization") +print(f"\nUsing {args.dtype} quantization with {args=}") # %% DEVICE = "cuda:0" pipe = FluxPipeline.from_pretrained( "black-forest-labs/FLUX.1-dev", - torch_dtype=torch.float16, + torch_dtype=dtype, ) total_params = sum(p.numel() for p in pipe.transformer.parameters()) @@ -219,11 +220,15 @@ def forward_loop(mod): "timestep": torch.tensor([1.0] * batch_size, dtype=dtype).to(DEVICE), "txt_ids": torch.randn((512, 3), dtype=dtype).to(DEVICE), "img_ids": torch.randn((4096, 3), dtype=dtype).to(DEVICE), - "guidance": torch.tensor([1.0] * batch_size, dtype=dtype).to(DEVICE), + "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32).to(DEVICE), "joint_attention_kwargs": {}, "return_dict": False, } + +torch.cuda.empty_cache() +torch.cuda.reset_peak_memory_stats() +gc.collect() # This will create an exported program which is going to be compiled with Torch-TensorRT with export_torch_mode(): ep = _export( @@ -235,6 +240,14 @@ def forward_loop(mod): allow_complex_guards_as_runtime_asserts=True, ) +peak_memory = torch.cuda.max_memory_allocated() / (1024**3) +peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) +print(f"Peak memory allocated during torch-export: {peak_memory=}GB {peak_reserved=}GB") + +torch.cuda.empty_cache() +torch.cuda.reset_peak_memory_stats() +gc.collect() + with torch_tensorrt.logging.debug(): trt_gm = torch_tensorrt.dynamo.compile( ep, @@ -248,20 +261,33 @@ def forward_loop(mod): offload_module_to_cpu=True, ) +peak_memory = torch.cuda.max_memory_allocated() / (1024**3) +peak_reserved = 
torch.cuda.max_memory_reserved() / (1024**3) +print( + f"Peak memory allocated during torch dynamo compilation: {peak_memory=}GB {peak_reserved=}GB" +) del ep pipe.transformer = trt_gm pipe.transformer.config = config +torch.cuda.empty_cache() +torch.cuda.reset_peak_memory_stats() +gc.collect() # %% + trt_gm.device = torch.device(DEVICE) # Function which generates images from the flux pipeline generate_image(pipe, ["A golden retriever"], "dog_code2") +peak_memory = torch.cuda.max_memory_allocated() / (1024**3) +peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) +print(f"Peak memory allocated during inference: {peak_memory=}GB {peak_reserved=}GB") + if not args.debug: print(f"Benchmark TRT Module Latency at ({args.dtype}) started") - for batch_size in range(1, 9): + for batch_size in range(1, 3): benchmark(["Test"], 20, batch_size=batch_size, iterations=3) print(f"Benchmark TRT Module Latency at ({args.dtype}) ended") From 43e9abe52421481f325c7ab43f225fce1a9189f6 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 12 Jun 2025 07:44:17 -0700 Subject: [PATCH 6/7] add save load --- tools/perf/Flux/flux_quantization.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/tools/perf/Flux/flux_quantization.py b/tools/perf/Flux/flux_quantization.py index fde2ca0de4..212f7823a2 100644 --- a/tools/perf/Flux/flux_quantization.py +++ b/tools/perf/Flux/flux_quantization.py @@ -14,6 +14,7 @@ from diffusers import FluxPipeline from diffusers.models.attention_processor import Attention from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel +from modelopt.core.torch.quantization.config import NVFP4_FP8_MHA_CONFIG from modelopt.torch.quantization.utils import export_torch_mode from torch.export._trace import _export from transformers import AutoModelForCausalLM @@ -27,6 +28,12 @@ default=False, help="debug mode", ) +parser.add_argument( + "--mha", + action="store_true", + default=False, + help="NVFP4_FP8_MHA_CONFIG mode", +) parser.add_argument( "--dtype", choices=["fp8", "int8", "fp4", "fp16", "bf16", "fp32"], @@ -74,7 +81,10 @@ ) ptq_config = mtq.INT8_DEFAULT_CFG elif args.dtype == "fp4": - ptq_config = mtq.NVFP4_DEFAULT_CFG + if args.mha: + ptq_config = NVFP4_FP8_MHA_CONFIG + else: + ptq_config = mtq.NVFP4_DEFAULT_CFG # mtq.NVFP4_DEFAULT_CFG use_explicit_typing = True elif args.dtype == "fp16": enabled_precisions.append(torch.float16) if not use_explicit_typing else None @@ -106,10 +116,10 @@ total_size = total_params * 4 / 1024 / 1024 / 1024 print(f"\n Total size: {total_size}GB") -if args.debug: - pipe.transformer = FluxTransformer2DModel( - num_layers=1, num_single_layers=1, guidance_embeds=True - ) +# if args.debug: +# pipe.transformer = FluxTransformer2DModel( +# num_layers=1, num_single_layers=1, guidance_embeds=True +# ) pipe.to(DEVICE).to(dtype) # Store the config and transformer backbone @@ -141,6 +151,7 @@ def generate_image(pipe, prompt, image_name): def benchmark(prompt, inference_step, batch_size=1, iterations=1): from time import time + print(f"Benchmark TRT Module Latency started with {batch_size=} {iterations=}") start = time() for i in range(iterations): image = pipe( @@ -287,7 +298,7 @@ def forward_loop(mod): if not args.debug: print(f"Benchmark TRT Module Latency at ({args.dtype}) started") - for batch_size in range(1, 3): + for batch_size in range(1, 9): benchmark(["Test"], 20, batch_size=batch_size, iterations=3) print(f"Benchmark TRT Module Latency at ({args.dtype}) ended") From 
d54149f15e6d41f7cff8a2134416e2588005cd77 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 12 Jun 2025 07:46:02 -0700 Subject: [PATCH 7/7] test --- tools/perf/Flux/benchmark.sh | 41 +++ tools/perf/Flux/flux_quantization_load.py | 87 +++++++ tools/perf/Flux/flux_quantization_save.py | 295 ++++++++++++++++++++++ 3 files changed, 423 insertions(+) create mode 100644 tools/perf/Flux/benchmark.sh create mode 100644 tools/perf/Flux/flux_quantization_load.py create mode 100644 tools/perf/Flux/flux_quantization_save.py diff --git a/tools/perf/Flux/benchmark.sh b/tools/perf/Flux/benchmark.sh new file mode 100644 index 0000000000..0a1c67bd2f --- /dev/null +++ b/tools/perf/Flux/benchmark.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,memory.used --format=csv,nounits -lms 1000 >> fp4_nvidia-smi.log & +NVIDIA_SMI_PID=$! +python tools/perf/Flux/flux_quantization.py --dtype fp4 2>&1 | tee fp4.log +python tools/perf/Flux/flux_quantization.py --dtype fp4 --sdpa 2>&1 | tee fp4_sdpa.log +#python tools/perf/Flux/flux_quantization.py --dtype fp4 --sdpa --mha 2>&1 | tee fp4_sdpa_mha.log +kill $NVIDIA_SMI_PID +#scp fp4*.log lanl@d5c2237-lcedt.dyn.nvidia.com:/home/lanl/git/script/flux/gb100/0611/ +sleep 10 + + +nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,memory.used --format=csv,nounits -lms 1000 >> fp8_nvidia-smi.log & +NVIDIA_SMI_PID=$! +python tools/perf/Flux/flux_quantization.py --dtype fp8 2>&1 | tee fp8.log +python tools/perf/Flux/flux_quantization.py --dtype fp8 --sdpa 2>&1 | tee fp8_sdpa.log +#python tools/perf/Flux/flux_quantization.py --dtype fp8 --sdpa --mha 2>&1 | tee fp8_sdpa_mha.log +kill $NVIDIA_SMI_PID +# scp fp8*.log lanl@d5c2237-lcedt.dyn.nvidia.com:/home/lanl/git/script/flux/gb100/0611/ +sleep 10 + + +nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,memory.used --format=csv,nounits -lms 1000 >> fp16_nvidia-smi.log & +NVIDIA_SMI_PID=$! +python tools/perf/Flux/flux_quantization.py --dtype fp16 2>&1 | tee fp16.log +python tools/perf/Flux/flux_quantization.py --dtype fp16 --sdpa 2>&1 | tee fp16_sdpa.log +kill $NVIDIA_SMI_PID +# scp fp16*.log lanl@d5c2237-lcedt.dyn.nvidia.com:/home/lanl/git/script/flux/gb100/0611/ +sleep 10 + +# nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,memory.used --format=csv,nounits -lms 1000 >> bf16_nvidia-smi.log & +# NVIDIA_SMI_PID=$! +# python tools/perf/Flux/flux_quantization.py --dtype bf16 2>&1 | tee bf16.log +# kill $NVIDIA_SMI_PID +# sleep 10 + +# nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,memory.used --format=csv,nounits -lms 1000 >> fp32_nvidia-smi.log & +# NVIDIA_SMI_PID=$! 
+# python tools/perf/Flux/flux_quantization.py --dtype fp32 2>&1 | tee fp32.log
+# kill $NVIDIA_SMI_PID
+# sleep 10
\ No newline at end of file
diff --git a/tools/perf/Flux/flux_quantization_load.py b/tools/perf/Flux/flux_quantization_load.py
new file mode 100644
index 0000000000..583f86a80b
--- /dev/null
+++ b/tools/perf/Flux/flux_quantization_load.py
@@ -0,0 +1,87 @@
+# %%
+# Import the following libraries
+# -----------------------------
+# Load the ModelOpt-modified model architecture and weights using Huggingface APIs
+# Add argument parsing for dtype selection
+import argparse
+import gc
+import re
+
+import modelopt.torch.opt as mto
+import modelopt.torch.quantization as mtq
+import torch
+import torch_tensorrt
+from diffusers import FluxPipeline
+from diffusers.models.attention_processor import Attention
+from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+from modelopt.core.torch.quantization.config import NVFP4_FP8_MHA_CONFIG
+from modelopt.torch.quantization.utils import export_torch_mode
+from torch.export._trace import _export
+from transformers import AutoModelForCausalLM
+
+parser = argparse.ArgumentParser(
+    description="Run Flux quantization with different dtypes"
+)
+parser.add_argument(
+    "--path",
+    type=str,
+    required=True,
+    help="path to the saved TensorRT exported program (.ep)",
+)
+
+args = parser.parse_args()
+
+
+def generate_image(pipe, prompt, image_name):
+    seed = 42
+    image = pipe(
+        prompt,
+        output_type="pil",
+        num_inference_steps=20,
+        generator=torch.Generator("cuda").manual_seed(seed),
+    ).images[0]
+    image.save(f"{image_name}.png")
+    print(f"Image generated using {image_name} model saved as {image_name}.png")
+
+
+def benchmark(prompt, inference_step, batch_size=1, iterations=1):
+    from time import time
+
+    print(f"Benchmark TRT Module Latency started with {batch_size=} {iterations=}")
+    start = time()
+    for i in range(iterations):
+        image = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=inference_step,
+            num_images_per_prompt=batch_size,
+        ).images
+    end = time()
+    print(f"Batch Size: {batch_size}")
+    print("Time Elapse for", iterations, "iterations:", end - start)
+    print(
+        "Average Latency Per Step:",
+        (end - start) / inference_step / iterations / batch_size,
+    )
+    return image
+
+
+# .module() turns the loaded ExportedProgram into a callable GraphModule
+loaded_trt_module = torch.export.load(args.path).module()
+
+DEVICE = "cuda:0"
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    torch_dtype=torch.float16,
+)
+
+config = pipe.transformer.config
+pipe.transformer = loaded_trt_module
+pipe.transformer.config = config
+
+generate_image(pipe, ["beach and kids"], "beach_kids")
+
+
+print(f"Benchmark TRT Module Latency ({args.path=}) started")
+for batch_size in range(1, 9):
+    benchmark(["Test"], 20, batch_size=batch_size, iterations=3)
+print(f"Benchmark TRT Module Latency ({args.path=}) ended")
diff --git a/tools/perf/Flux/flux_quantization_save.py b/tools/perf/Flux/flux_quantization_save.py
new file mode 100644
index 0000000000..b6774ca13c
--- /dev/null
+++ b/tools/perf/Flux/flux_quantization_save.py
@@ -0,0 +1,295 @@
+# %%
+# Import the following libraries
+# -----------------------------
+# Load the ModelOpt-modified model architecture and weights using Huggingface APIs
+# Add argument parsing for dtype selection
+import argparse
+import gc
+import os
+import re
+import tempfile
+
+import modelopt.torch.opt as mto
+import modelopt.torch.quantization as mtq
+import torch
+import torch_tensorrt
+from diffusers import FluxPipeline
+from diffusers.models.attention_processor import Attention
+from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+from modelopt.core.torch.quantization.config import NVFP4_FP8_MHA_CONFIG
+from modelopt.torch.quantization.utils import export_torch_mode
+from torch.export._trace import _export
+from transformers import AutoModelForCausalLM
+
+parser = argparse.ArgumentParser(
+    description="Run Flux quantization with different dtypes"
+)
+parser.add_argument(
+    "--debug",
+    action="store_true",
+    default=False,
+    help="debug mode",
+)
+parser.add_argument(
+    "--dtype",
+    choices=["fp8", "int8", "fp4", "fp16", "bf16", "fp32"],
+    default="fp8",
+    help="Quantization data type to use (fp8 or int8 or fp4 or fp16 or bf16 or fp32)",
+)
+
+parser.add_argument(
+    "--mha",
+    action="store_true",
+    default=False,
+    help="NVFP4_FP8_MHA_CONFIG mode",
+)
+
+parser.add_argument(
+    "--model",
+    choices=[
+        "all",
+        "transformer",
+        "tokenizer",
+        "tokenizer_2",
+        "vae",
+        "text_encoder",
+        "text_encoder_2",
+    ],
+    default="transformer",
+    help="Model to use (all or transformer or tokenizer or tokenizer_2 or vae or text_encoder or text_encoder_2)",
+)
+
+parser.add_argument(
+    "--sdpa",
+    action="store_true",
+    default=False,
+    help="Register SDPA operator",
+)
+
+parser.add_argument(
+    "--strong-typing",
+    action="store_true",
+    help="use explicit (strong) typing during compilation",
+)
+
+args = parser.parse_args()
+if args.sdpa:
+    import register_sdpa
+
+dtype = torch.float16
+ptq_config = None
+use_explicit_typing = args.strong_typing
+enabled_precisions = [
+    torch.float32,
+]
+
+# Update enabled precisions based on dtype argument
+if args.dtype == "fp8":
+    (
+        enabled_precisions.extend([torch.float8_e4m3fn, torch.float16])
+        if not use_explicit_typing
+        else None
+    )
+    ptq_config = mtq.FP8_DEFAULT_CFG
+elif args.dtype == "int8":  # int8
+    (
+        enabled_precisions.extend([torch.int8, torch.float16])
+        if not use_explicit_typing
+        else None
+    )
+    ptq_config = mtq.INT8_DEFAULT_CFG
+elif args.dtype == "fp4":
+    if args.mha:
+        ptq_config = NVFP4_FP8_MHA_CONFIG
+    else:
+        ptq_config = mtq.NVFP4_DEFAULT_CFG
+    use_explicit_typing = True
+elif args.dtype == "fp16":
+    enabled_precisions.append(torch.float16) if not use_explicit_typing else None
+elif args.dtype == "bf16":
+    dtype = torch.bfloat16
+    (
+        enabled_precisions.extend([torch.bfloat16, torch.float16])
+        if not use_explicit_typing
+        else None
+    )
+elif args.dtype == "fp32":
+    dtype = torch.float32
+else:
+    raise ValueError(f"Invalid dtype: {args.dtype}")
+print(f"\nUsing {args.dtype} quantization with {args=}")
+# %%
+DEVICE = "cuda:0"
+
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    torch_dtype=dtype,
+    return_dict=False,
+)
+
+total_params = sum(p.numel() for p in pipe.transformer.parameters())
+print(f"\n Total number of parameters: {total_params/1000/1000/1000}B")
+if dtype in (torch.float16, torch.bfloat16):
+    total_size = total_params * 2 / 1024 / 1024 / 1024
+    print(f"\n Total size: {total_size}GB")
+elif dtype == torch.float32:
+    total_size = total_params * 4 / 1024 / 1024 / 1024
+    print(f"\n Total size: {total_size}GB")
+
+
+if args.model == "all":
+    model = pipe.transformer
+elif args.model == "transformer":
+    model = pipe.transformer
+elif args.model == "tokenizer":
+    model = pipe.tokenizer
+elif args.model == "tokenizer_2":
+    model = pipe.tokenizer_2
+elif args.model == "vae":
+    model = pipe.vae
+elif args.model == "text_encoder":
+    model = pipe.text_encoder
+elif args.model == "text_encoder_2":
+    model = pipe.text_encoder_2
+
+pipe.to(DEVICE).to(dtype)
+# Store the config and transformer backbone
+config = model.config
+# global backbone
+backbone = model
+backbone.eval()
+
+
+def filter_func(name):
+    pattern = re.compile(
+        r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
+    )
+    return pattern.match(name) is not None
+
+
+def generate_image(pipe, prompt, image_name):
+    seed = 42
+    image = pipe(
+        prompt,
+        output_type="pil",
+        num_inference_steps=20,
+        generator=torch.Generator("cuda").manual_seed(seed),
+    ).images[0]
+    image.save(f"{image_name}.png")
+    print(f"Image generated using {image_name} model saved as {image_name}.png")
+
+
+def benchmark(prompt, inference_step, batch_size=1, iterations=1):
+    from time import time
+
+    start = time()
+    for i in range(iterations):
+        image = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=inference_step,
+            num_images_per_prompt=batch_size,
+        ).images
+    end = time()
+    print(f"Batch Size: {batch_size}")
+    print("Time Elapse for", iterations, "iterations:", end - start)
+    print(
+        "Average Latency Per Step:",
+        (end - start) / inference_step / iterations / batch_size,
+    )
+    return image
+
+
+batch_size = 2
+BATCH = torch.export.Dim("batch", min=1, max=8)
+SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
+# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model.
+# To see this recommendation, you can try exporting using min=1, max=4096
+IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
+dynamic_shapes = {
+    "hidden_states": {0: BATCH},
+    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
+    "pooled_projections": {0: BATCH},
+    "timestep": {0: BATCH},
+    "txt_ids": {0: SEQ_LEN},
+    "img_ids": {0: IMG_ID},
+    "guidance": {0: BATCH},
+    "joint_attention_kwargs": {},
+    "return_dict": None,
+}
+# The guidance factor is of type torch.float32
+dummy_inputs = {
+    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=dtype).to(DEVICE),
+    "encoder_hidden_states": torch.randn((batch_size, 512, 4096), dtype=dtype).to(
+        DEVICE
+    ),
+    "pooled_projections": torch.randn((batch_size, 768), dtype=dtype).to(DEVICE),
+    "timestep": torch.tensor([1.0] * batch_size, dtype=dtype).to(DEVICE),
+    "txt_ids": torch.randn((512, 3), dtype=dtype).to(DEVICE),
+    "img_ids": torch.randn((4096, 3), dtype=dtype).to(DEVICE),
+    "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32).to(DEVICE),
+    "joint_attention_kwargs": {},
+    "return_dict": False,
+}
+
+
+def do_calibrate(
+    pipe,
+    prompt: str,
+) -> None:
+    """
+    Run calibration steps on the pipeline using the given prompts.
+ """ + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(0), + ).images[0] + + +def forward_loop(mod): + # Switch the pipeline's backbone, run calibration + pipe.transformer = mod + do_calibrate( + pipe=pipe, + prompt="test", + ) + + +if ptq_config is not None: + backbone = mtq.quantize(backbone, ptq_config, forward_loop) + mtq.disable_quantizer(backbone, filter_func) +else: + print("No quantization config provided, skipping quantization") + + +# This will create an exported program which is going to be compiled with Torch-TensorRT +with export_torch_mode(): + ep = _export( + backbone, + args=(), + kwargs=dummy_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + +with torch_tensorrt.logging.debug(): + trt_gm = torch_tensorrt.dynamo.compile( + ep, + inputs=dummy_inputs, + enabled_precisions=enabled_precisions, + use_explicit_typing=use_explicit_typing, + truncate_double=True, + min_block_size=1, + debug=args.debug, + immutable_weights=True, + offload_module_to_cpu=True, + ) + trt_ep_path = os.path.join( + tempfile.gettempdir(), + f"{args.model}_{args.dtype}_{args.debug}_{args.sdpa}_{args.mha}_trt.ep", + ) + torch_tensorrt.save(trt_gm, trt_ep_path, inputs=dummy_inputs) + print(f"TRT module saved to {trt_ep_path}")