From 01c900e039b2cbd35842cab99783c29498c8657a Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sun, 25 May 2025 10:50:01 -0700 Subject: [PATCH 1/7] Add fp4 support --- .pre-commit-config.yaml | 12 +- MODULE.bazel | 4 +- examples/dynamo/vgg16_ptq.py | 12 +- py/torch_tensorrt/_enums.py | 19 + py/torch_tensorrt/dynamo/_compiler.py | 14 +- py/torch_tensorrt/dynamo/_defaults.py | 9 +- .../dynamo/conversion/_TRTInterpreter.py | 4 +- .../dynamo/conversion/aten_ops_converters.py | 33 ++ .../dynamo/conversion/converter_utils.py | 28 +- .../dynamo/conversion/impl/__init__.py | 1 + .../dynamo/conversion/impl/addmm.py | 6 +- .../dynamo/conversion/impl/nvfp4_quantize.py | 368 ++++++++++++++++++ .../dynamo/conversion/impl/permutation.py | 6 +- .../lowering/passes/constant_folding.py | 5 + pyproject.toml | 10 +- tests/py/dynamo/models/test_models_export.py | 87 ++++- 16 files changed, 587 insertions(+), 31 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f31305568d..a7b91eec34 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,14 +21,14 @@ repos: - id: clang-format types_or: [c++, c, cuda] - repo: https://github.com/keith/pre-commit-buildifier - rev: 6.4.0 + rev: 8.0.3 hooks: - id: buildifier args: - --warnings=all - id: buildifier-lint - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.23 + rev: v0.24.1 hooks: - id: validate-pyproject - repo: https://github.com/pycqa/isort @@ -37,17 +37,17 @@ repos: - id: isort name: isort (python) - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v1.9.0" + rev: "v1.15.0" hooks: - id: mypy exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py" - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.3.3 + rev: v0.11.7 hooks: - id: ruff - repo: https://github.com/psf/black - rev: 24.3.0 + rev: 25.1.0 hooks: - id: black exclude: ^examples/custom_converters/elu_converter/setup.py|^docs @@ -57,7 +57,7 @@ repos: - id: typos - repo: https://github.com/astral-sh/uv-pre-commit # uv version. 
- rev: 0.5.5 + rev: 0.7.1 hooks: # Update the uv lockfile - id: uv-lock diff --git a/MODULE.bazel b/MODULE.bazel index 008c7f53fc..66a879afcf 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -94,9 +94,9 @@ http_archive( http_archive( name = "tensorrt", build_file = "@//third_party/tensorrt/archive:BUILD", - strip_prefix = "TensorRT-10.9.0.34", + strip_prefix = "TensorRT-10.10.0.31", urls = [ - "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.9.0/tars/TensorRT-10.9.0.34.Linux.x86_64-gnu.cuda-12.8.tar.gz", + "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.10.0/tars/TensorRT-10.10.0.31.Linux.x86_64-gnu.cuda-12.9.tar.gz", ], ) diff --git a/examples/dynamo/vgg16_ptq.py b/examples/dynamo/vgg16_ptq.py index 7fa943040e..7af490afb4 100644 --- a/examples/dynamo/vgg16_ptq.py +++ b/examples/dynamo/vgg16_ptq.py @@ -169,7 +169,6 @@ def vgg16(num_classes=1000, init_weights=False): data = iter(training_dataloader) images, _ = next(data) - crit = nn.CrossEntropyLoss() # %% @@ -200,8 +199,11 @@ def calibrate_loop(model): quant_cfg = mtq.INT8_DEFAULT_CFG elif args.quantize_type == "fp8": quant_cfg = mtq.FP8_DEFAULT_CFG +elif args.quantize_type == "fp4": + quant_cfg = mtq.NVFP4_DEFAULT_CFG # PTQ with in-place replacement to quantized modules mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + # model has FP8 qdq nodes at this point # %% @@ -233,12 +235,20 @@ def calibrate_loop(model): with export_torch_mode(): # Compile the model with Torch-TensorRT Dynamo backend input_tensor = images.cuda() + torch.onnx.export(model, input_tensor, "mtq_vgg16_model.onnx") exp_program = torch.export.export(model, (input_tensor,), strict=False) if args.quantize_type == "int8": enabled_precisions = {torch.int8} elif args.quantize_type == "fp8": enabled_precisions = {torch.float8_e4m3fn} + elif args.quantize_type == "fp4": + enabled_precisions = { + torch.float4_e2m1fn_x2, + torch.float8_e4m3fn, + torch.float16, + torch.float32, + } trt_model = torchtrt.dynamo.compile( exp_program, inputs=[input_tensor], diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index c706c345d6..e0a78e1a0b 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -80,6 +80,12 @@ class dtype(Enum): :meta hide-value: """ + f4 = auto() + """4 bit floating-point number, equivalent to ``dtype.fp4`` and ``dtype.float4`` + + :meta hide-value: + """ + uint8 = u8 int8 = i8 @@ -91,6 +97,9 @@ class dtype(Enum): float8 = f8 fp8 = f8 + float4 = f4 + fp4 = f4 + half = f16 fp16 = f16 float16 = f16 @@ -162,6 +171,8 @@ def _from( return dtype.i32 elif t == torch.float8_e4m3fn: return dtype.f8 + elif t == torch.float4_e2m1fn_x2: + return dtype.f4 elif t == torch.half: return dtype.f16 elif t == torch.float: @@ -188,6 +199,8 @@ def _from( return dtype.i8 elif t == trt.DataType.FP8: return dtype.f8 + elif t == trt.DataType.FP4: + return dtype.fp4 elif t == trt.DataType.INT32: return dtype.i32 elif t == trt.DataType.INT64: @@ -357,6 +370,8 @@ def to( return torch.long elif self == dtype.f8: return torch.float8_e4m3fn + elif self == dtype.f4: + return torch.float4_e2m1fn_x2 elif self == dtype.f16: return torch.half elif self == dtype.f32: @@ -394,6 +409,8 @@ def to( return trt.DataType.BOOL elif self == dtype.bf16: return trt.DataType.BF16 + elif self == dtype.f4: + return trt.DataType.FP4 elif use_default: return trt.DataType.FLOAT else: @@ -410,6 +427,8 @@ def to( return np.int64 elif self == dtype.f16: return np.float16 + elif self == dtype.f4: + return 
np.float4_e2m1fn_x2 elif self == dtype.f32: return np.float32 elif self == dtype.f64: diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 831ce37305..4db11daa78 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -581,13 +581,13 @@ def compile( "\nThis feature is unimplemented in Torch-TRT Dynamo currently." ) - if use_explicit_typing: - if len(enabled_precisions) != 1 or not any( - x in enabled_precisions for x in {torch.float32, dtype.f32} - ): - raise AssertionError( - f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}" - ) + # if use_explicit_typing: + # if len(enabled_precisions) != 1 or not any( + # x in enabled_precisions for x in {torch.float32, dtype.f32} + # ): + # raise AssertionError( + # f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}" + # ) if use_fp32_acc: logger.debug( diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index aafd1072f4..921cb37646 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -29,7 +29,14 @@ REQUIRE_FULL_COMPILATION = False DRYRUN = False HARDWARE_COMPATIBLE = False -SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8} +SUPPORTED_KERNEL_PRECISIONS = { + dtype.f32, + dtype.f16, + dtype.bf16, + dtype.i8, + dtype.f8, + dtype.f4, +} TIMING_CACHE_PATH = os.path.join( tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin" ) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 39a1ed957d..ecf08f38c4 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -274,13 +274,13 @@ def _populate_trt_builder_config( self.compilation_settings.dla_global_dram_size, ) - if dtype.float16 in self.compilation_settings.enabled_precisions: + if not self.compilation_settings.use_explicit_typing and dtype.float16 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.FP16) if dtype.int8 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.INT8) - if dtype.fp8 in self.compilation_settings.enabled_precisions: + if not self.compilation_settings.use_explicit_typing and dtype.fp8 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.FP8) if dtype.bfloat16 in self.compilation_settings.enabled_precisions: diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 1fed1f9a1f..08a1bb4ea4 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -617,6 +617,39 @@ def aten_ops_quantize_op( ) +try: + import modelopt.torch.quantization as mtq # noqa: F401 + + assert torch.ops.tensorrt.dynamic_block_quantize_op.default +except Exception as e: + _LOGGER.warning( + "Unable to import dynamic block quantize op. 
Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling dynamic blockquantized models" + ) +else: + + @dynamo_tensorrt_converter(torch.ops.tensorrt.dynamic_block_quantize_op.default) + def aten_ops_dynamic_block_quantize_op( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.nvfp4_quantize.nvfp4_quantize( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[6], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True) def aten_ops_squeeze( diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 685f40b254..eb18a14eca 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -361,12 +361,37 @@ def create_constant( shape = list(torch_value.shape) if torch_value is not None: + if torch_value.dtype == torch.float8_e4m3fn: + weights = trt.Weights( + type=trt.DataType.FP8, + ptr=torch_value.data_ptr(), + count=torch_value.numel(), + ) + constant = ctx.net.add_constant( + shape, + weights, + ) + constant.name = name + return constant.get_output(0) + # Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8 + if torch_value.dtype == torch.uint8: + weights = trt.Weights( + type=trt.DataType.FP4, + ptr=torch_value.data_ptr(), + count=torch_value.numel() * 2, + ) + shape[-1] = shape[-1] * 2 + constant = ctx.net.add_constant( + shape, + weights, + ) + constant.name = name + return constant.get_output(0) if torch_value.dtype == torch.bfloat16: torch_value_fp32 = torch_value.to(torch.float32) numpy_value = torch_value_fp32.numpy() else: numpy_value = torch_value.numpy() - ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1) constant = ctx.net.add_constant( shape, @@ -381,7 +406,6 @@ def create_constant( trt.DataType.BF16, name + "_bf16_cast", ) - return constant.get_output(0) else: raise ValueError( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index df580b1516..1f2d9d0de1 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -14,6 +14,7 @@ matmul, nccl_ops, normalization, + nvfp4_quantize, pad, permutation, pool, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/addmm.py b/py/torch_tensorrt/dynamo/conversion/impl/addmm.py index 1a0690852a..73d98acfdf 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/addmm.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/addmm.py @@ -7,7 +7,7 @@ from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.fx.types import TRTTensor - +import os def addmm( ctx: ConversionContext, @@ -21,6 +21,10 @@ def addmm( beta: Union[float, int], alpha: Union[float, int], ) -> TRTTensor: + if os.getenv("DISABLE_GEMM", "false").lower() == "true": + print("lan added disable_gemm is set, skip addmm and returning mat2") + return mat2 + print("lan added disable_gemm is not set, doing addmm") mm = impl.matmul.matrix_multiply(ctx, target, source_ir, f"{name}_mm", mat1, mat2) 
if alpha != 1: mm = impl.elementwise.mul( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py new file mode 100644 index 0000000000..1c2f297764 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ -0,0 +1,368 @@ +import os +from typing import Optional, Union + +import numpy as np +import tensorrt as trt +import torch +import torch_tensorrt.dynamo.conversion.impl as impl +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import ( + get_trt_tensor, + to_torch, +) +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTTensor + + +def nvfp4_quantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_tensor: TRTTensor, + block_size: int, + amax: Union[np.ndarray, torch.Tensor], + num_bits: int, + exponent_bits: int, + scale_num_bits: int, + scale_exponent_bits: int, +) -> TRTTensor: + """ + Adds quantize and dequantize ops (QDQ) which quantize to FP4 based + on the output_type set and dequantizes them back. + """ + print( + f"lan added nvfp4_quantize entered: {target=} {source_ir=} {name=} {input_tensor.shape=} {input_tensor.dtype=} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}" + ) + if len(input_tensor.shape) not in (2, 3): + raise ValueError( + f"nvfp4_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D" + ) + with unset_fake_temporarily(): + axis = -1 + global_scale = _calculate_global_scale(ctx, name, amax) + print(f"lan added input_tensor: {input_tensor.shape=} {input_tensor.dtype=}") + print(f"lan added global_scale: {global_scale.shape=} {global_scale.dtype=}") + if ".weight_quantizer" in name: + output = _static_double_quantize( + ctx, + target, + source_ir, + name, + input_tensor, + global_scale, + axis, + ) + elif ".input_quantizer" in name: + # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 + output = _dynamic_double_quantize( + ctx, + target, + source_ir, + name, + input_tensor, + global_scale, + axis, + ) + + else: + raise ValueError( + f"quantizer received an input of {name}. Supported values: weight_quantizer | input_quantizer" + ) + return output + + +def _dynamic_double_quantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_tensor: torch.Tensor, + global_scale: torch.Tensor, + axis: int = -1, + block_size: int = 16, + output_type: trt.DataType = trt.DataType.FP4, + scale_type: trt.DataType = trt.DataType.FP8, +) -> TRTTensor: + """ + quantize input tensor to fp4 + Parameters: + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR] + name: str + input_tensor : Tensor (On GPU) + The input tensor. + global_scale : Tensor (On GPU) + The global per-tensor scaling factor. It should contain only 1 element. + axis : int + The axis to quantize. Default is -1 (the last axis). + block_size : int + The block size for quantization. Default is 16. + output_type : trt.DataType + The data type for quantized data. Default is FP4. + scale_type : trt.DataType + The data type for block scale. Default is FP8. 
+
+    """
+    if os.getenv("DISABLE_DYNAMIC_QUANTIZE", "false").lower() == "true":
+        print("lan added disable_dynamic_quantize is set, skipping dynamic quantize")
+        return input_tensor
+    print("lan added disable_dynamic_quantize is not set, doing dynamic quantize")
+    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
+
+    if input_tensor.dtype not in [trt.DataType.HALF, trt.DataType.FLOAT]:
+        raise ValueError(
+            f"Currently try float16, float32 only on input tensor for now. Unsupported dtype: {input_tensor.dtype}"
+        )
+    # dynamic quantize input tensor to fp4
+    dynamic_quantize_layer = ctx.net.add_dynamic_quantize(
+        input_tensor,
+        axis,
+        block_size,
+        output_type,
+        scale_type,
+    )
+    dynamic_quantize_layer.set_input(1, global_scale)
+    set_layer_name(
+        dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir
+    )
+    quantized_data_in_fp4 = dynamic_quantize_layer.get_output(0)
+    quantized_scale_in_fp8 = dynamic_quantize_layer.get_output(1)
+
+    return _double_dequantize(
+        ctx,
+        target,
+        source_ir,
+        name,
+        quantized_data_in_fp4,
+        quantized_scale_in_fp8,
+        global_scale,
+        axis,
+        input_tensor.dtype,
+    )
+
+
+def _double_dequantize(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    quantized_data_in_fp4: TRTTensor,
+    quantized_scale_in_fp8: TRTTensor,
+    global_scale: torch.Tensor,
+    axis: int = -1,
+    output_type: trt.DataType = trt.DataType.FLOAT,
+) -> TRTTensor:
+    # dequantize scale from fp8 to original dtype (default is float32)
+    dequantize_scale_layer = ctx.net.add_dequantize(
+        quantized_scale_in_fp8, global_scale, output_type
+    )
+    dequantize_scale_layer.axis = axis
+    dequantize_scale_layer.to_type = output_type
+    set_layer_name(
+        dequantize_scale_layer, target, name + "_dequantize_scale", source_ir
+    )
+    dequantized_scale = dequantize_scale_layer.get_output(0)
+
+    # dequantize quantized_data_in_fp4 from fp4 to original dtype (default is float32)
+    dequantize_data_layer = ctx.net.add_dequantize(
+        quantized_data_in_fp4, dequantized_scale, output_type
+    )
+    dequantize_data_layer.axis = axis
+    dequantize_data_layer.to_type = output_type
+    set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir)
+    dequantized_data = dequantize_data_layer.get_output(0)
+    return dequantized_data
+
+
+def _static_double_quantize(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    weights_tensor: torch.Tensor,
+    global_scale: torch.Tensor,
+    axis: int,
+) -> TRTTensor:
+    """
+    Parameters:
+        ctx: ConversionContext,
+        target: Target,
+        source_ir: Optional[SourceIR],
+        name: str,
+        weights_tensor : Tensor (On GPU)
+            The input tensor for weights.
+        global_scale : Tensor (On GPU)
+            The global per-tensor scaling factor. It should contain only 1 element.
+        axis: int
+            The axis to quantize. Default is -1 (the last axis).
+ Returns: + quantized data tensor in fp4 + """ + if os.getenv("DISABLE_STATIC_QUANTIZE", "false").lower() == "true": + print("lan added disable_static_quantize is set, skipping static quantize") + return get_trt_tensor(ctx, weights_tensor, name + "_weights") + print( + "lan added static disable_static_quantize is not set, doing static double quantize " + ) + import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor + + if weights_tensor.dtype == torch.float16: + original_dtype = trt.DataType.HALF + elif weights_tensor.dtype == torch.float32: + original_dtype = trt.DataType.FLOAT + else: + raise ValueError( + f"Currently try float16, float32 only on weights tensor. Unsupported dtype: {weights_tensor.dtype}" + ) + block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( + weights_tensor, + 16, + global_scale, + )[0] + weights_tensor_fp4 = nvfp4_tensor.NVFP4QTensor.quantize( + weights_tensor, + 16, + block_scale_fp8, + global_scale, + )[0]._quantized_data + + block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") + global_scale = to_torch(global_scale, None) + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + weights_tensor_fp4 = get_trt_tensor(ctx, weights_tensor_fp4, name + "_weights_fp4") + print( + f"lan added weights_tensor_fp4: {weights_tensor_fp4.shape=} {weights_tensor_fp4.dtype=}" + ) + dequantized_data = _double_dequantize( + ctx, + target, + source_ir, + name, + weights_tensor_fp4, + block_scale_fp8, + global_scale, + axis, + original_dtype, + ) + return dequantized_data + + +def _static_double_quantize_transpose( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weights_tensor: torch.Tensor, + global_scale: torch.Tensor, + axis: int, +) -> TRTTensor: + """ + Parameters: + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weights_tensor : Tensor (On GPU) + The input tensor for weights. + global_scale : Tensor (On GPU) + The global per-tensor scaling factor. It should contain only 1 element. + axis: int + The axis to quantize. Default is -1 (the last axis). + Returns: + quantized data tensor in fp4 + """ + axis = -2 + import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor + + if weights_tensor.dtype == torch.float16: + original_dtype = trt.DataType.HALF + elif weights_tensor.dtype == torch.float32: + original_dtype = trt.DataType.FLOAT + else: + raise ValueError( + f"Currently try float16, float32 only on weights tensor. 
Unsupported dtype: {weights_tensor.dtype}" + ) + block_scale = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( + weights_tensor, + 16, + global_scale, + keep_high_precision=True, + )[0] + weights_tensor_scaled = nvfp4_tensor.NVFP4QTensor.quantize( + weights_tensor, + 16, + block_scale, + global_scale, + keep_high_precision=True, + ) + + block_scale = block_scale.transpose(0, 1) + weights_tensor_scaled = weights_tensor_scaled.transpose(0, 1) + + block_scale_fp8 = block_scale.to(torch.float8_e4m3fn) + weights_tensor_uint4 = nvfp4_tensor.NVFP4QTensor._cast_fp4(weights_tensor_scaled) + weights_tensor_uint8 = ( + weights_tensor_uint4[..., 1::2] << 4 + ) | weights_tensor_uint4[..., 0::2] + + print( + f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}" + ) + print( + f"lan added weights_tensor_uint4: {weights_tensor_uint4.shape=} {weights_tensor_uint4.dtype=} {weights_tensor_uint4=}" + ) + print( + f"lan added weights_tensor_uint8: {weights_tensor_uint8.shape=} {weights_tensor_uint8.dtype=} {weights_tensor_uint8=}" + ) + + block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") + global_scale = to_torch(global_scale, None) + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + weights_tensor_fp4 = get_trt_tensor( + ctx, weights_tensor_uint8, name + "_weights_fp4" + ) + print( + f"lan added weights_tensor_fp4: {weights_tensor_fp4.shape=} {weights_tensor_fp4.dtype=}" + ) + dequantized_data = _double_dequantize( + ctx, + target, + source_ir, + name, + weights_tensor_fp4, + block_scale_fp8, + global_scale, + axis, + original_dtype, + ) + dequantized_data = impl.permutation.permute( + ctx, + target, + source_ir, + name + "_dequantized_data_transposed", + dequantized_data, + (-1, -2), + ) + return dequantized_data + + +def _calculate_global_scale( + ctx: ConversionContext, + name: str, + amax: torch.Tensor, +) -> torch.Tensor: + # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) + if amax is None or amax == 0: + amax = 1.0 + amax = to_torch( + amax, None + ) # amax is calculated from input_tensor.abs().amax().float() + global_scale = torch.divide(amax, 6 * 448) + if global_scale == 0: + global_scale = 1.0 + return global_scale diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index 1537d0fdbe..4408b62809 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -13,7 +13,7 @@ ) from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape from torch_tensorrt.fx.types import TRTTensor - +import os def permute( ctx: ConversionContext, @@ -23,6 +23,10 @@ def permute( input: TRTTensor, permutation: Sequence[int], ) -> TRTTensor: + if os.getenv("DISABLE_GEMM", "false").lower() == "true": + print("lan added disable_gemm is set, skip permute") + return input + print("lan added disable_gemm is not set, doing permute") if not isinstance(input, TRTTensor): raise RuntimeError( f"permute received input {input} that is not a TensorRT ITensor" diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 6ebefc5509..190b6752b4 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -101,4 +101,9 @@ def __init__(self, *args: Any, 
**kwargs: Any) -> None: # TODO: Update this function when quantization is added def is_impure(self, node: torch.fx.node.Node) -> bool: + if node.target in ( + torch.ops.tensorrt.quantize_op.default, + torch.ops.tensorrt.dynamic_block_quantize_op.default, + ): + return True return False diff --git a/pyproject.toml b/pyproject.toml index 3bb857e3e0..e2878bc7bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires = [ "cffi>=1.15.1", "typing-extensions>=4.7.0", "future>=0.18.3", - "tensorrt-cu12>=10.9.0,<10.10.0", + "tensorrt-cu12>=10.10.0,<10.11.0", "torch>=2.8.0.dev,<2.9.0", "pybind11==2.6.2", "numpy", @@ -56,10 +56,10 @@ keywords = [ ] dependencies = [ "torch>=2.8.0.dev,<2.9.0", - "tensorrt>=10.9.0,<10.10.0", - "tensorrt-cu12>=10.9.0,<10.10.0", - "tensorrt-cu12-bindings>=10.9.0,<10.10.0", - "tensorrt-cu12-libs>=10.9.0,<10.10.0", + "tensorrt>=10.10.0,<10.11.0", + "tensorrt-cu12>=10.10.0,<10.11.0", + "tensorrt-cu12-bindings>=10.10.0,<10.11.0", + "tensorrt-cu12-libs>=10.10.0,<10.11.0", "packaging>=23", "numpy", "typing-extensions>=4.7.0", diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 0c28b23bba..175d3d79d7 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -3,7 +3,6 @@ import platform import unittest from importlib import metadata - import pytest import timm import torch @@ -14,7 +13,7 @@ from packaging.version import Version assertions = unittest.TestCase() - +import os @pytest.mark.unit def test_resnet18(ir): @@ -199,6 +198,88 @@ def test_resnet18_half(ir): torch._dynamo.reset() +# @unittest.skipIf( +# torch.cuda.get_device_capability() < (10, 0), +# "FP4 quantization requires compute capability 10.0 or later", +# ) +@unittest.skipIf( + not importlib.util.find_spec("modelopt"), + "ModelOpt is required to run this test", +) +@pytest.mark.unit +def test_base_fp4(ir): + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.utils import export_torch_mode + dtype = torch.float16 + + class SimpleNetwork(torch.nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + self.linear1 = torch.nn.Linear( + in_features=64, out_features=32, bias=True, dtype=dtype + ) + + def forward(self, x): + x = self.linear1(x) + return x + + def calibrate_loop(model): + """Simple calibration function for testing.""" + model(input_tensor) + + input_tensor = torch.ones(128, 64, dtype=dtype).cuda() + + + model = SimpleNetwork().eval().cuda() + model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=dtype).cuda()) + model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=dtype).cuda()) + print(f"lan added amax: {input_tensor.abs().amax()=}") + print(f"lan added amax: {model.linear1.weight.abs().amax()=}") + expected_output = model(input_tensor) + print(f"lan added model input: {input_tensor=}") + print(f"lan added model weight: {model.linear1.weight=}") + print(f"lan added model bias: {model.linear1.bias=}") + + quant_cfg = mtq.NVFP4_DEFAULT_CFG + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + # model has qdq nodes at this point + with torch.no_grad(): + with export_torch_mode(): + exp_program = torch.export.export(model, (input_tensor,), strict=False) + from torch.fx import passes + + trt_model = torchtrt.dynamo.compile( + exp_program, + inputs=[input_tensor], + enabled_precisions={ + torch.float4_e2m1fn_x2, + torch.float8_e4m3fn, + torch.float32, + torch.float16, + }, + min_block_size=1, + 
debug=True, + cache_built_engines=False, + reuse_cached_engines=False, + use_explicit_typing=dtype == torch.float16, + ) + + outputs_trt = trt_model(input_tensor) + if os.getenv("DISABLE_GEMM", "false").lower() == "true": + print("lan added disable_gemm is set, compring result with weights") + expected_output = model.linear1.weight + else: + print("lan added disable_gemm is not set, compring result with pytorch") + + print(f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}") + print(f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}") + + abs_diff = torch.abs(expected_output - outputs_trt) + print(f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}") + print(f"lan added abs_diff: {abs_diff=}") + assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8) + + @unittest.skipIf( torch.cuda.get_device_capability() < (8, 9), "FP8 quantization requires compute capability 8.9 or later", @@ -230,8 +311,8 @@ def calibrate_loop(model): input_tensor = torch.randn(1, 10).cuda() model = SimpleNetwork().eval().cuda() - quant_cfg = mtq.FP8_DEFAULT_CFG + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) # model has FP8 qdq nodes at this point output_pyt = model(input_tensor) From f989864d01321e20c9f7e536ad324aaffadc009b Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 28 May 2025 08:49:37 -0700 Subject: [PATCH 2/7] add dynamic_shape support --- .../dynamo/conversion/aten_ops_converters.py | 5 +- tests/py/dynamo/models/test_models_export.py | 110 +++++++++++++++++- 2 files changed, 108 insertions(+), 7 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 08a1bb4ea4..d8ef865fd3 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -627,7 +627,10 @@ def aten_ops_quantize_op( ) else: - @dynamo_tensorrt_converter(torch.ops.tensorrt.dynamic_block_quantize_op.default) + @dynamo_tensorrt_converter( + torch.ops.tensorrt.dynamic_block_quantize_op.default, + supports_dynamic_shapes=True, + ) def aten_ops_dynamic_block_quantize_op( ctx: ConversionContext, target: Target, diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 175d3d79d7..1601d20ed5 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -3,6 +3,7 @@ import platform import unittest from importlib import metadata + import pytest import timm import torch @@ -15,6 +16,7 @@ assertions = unittest.TestCase() import os + @pytest.mark.unit def test_resnet18(ir): model = models.resnet18(pretrained=True).eval().to("cuda") @@ -198,6 +200,96 @@ def test_resnet18_half(ir): torch._dynamo.reset() +# @unittest.skipIf( +# torch.cuda.get_device_capability() < (10, 0), +# "FP4 quantization requires compute capability 10.0 or later", +# ) +@unittest.skipIf( + not importlib.util.find_spec("modelopt"), + "ModelOpt is required to run this test", +) +@pytest.mark.unit +def test_base_fp4_dynamic_shapes(ir): + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.utils import export_torch_mode + + dtype = torch.float16 + + class SimpleNetwork(torch.nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + self.linear1 = torch.nn.Linear( + 
in_features=64, out_features=32, bias=True, dtype=dtype + ) + + def forward(self, x): + x = self.linear1(x) + return x + + def calibrate_loop(model): + """Simple calibration function for testing.""" + model(dummy_inputs) + + BATCH_SIZE = torch.export.Dim("BATCH_SIZE", min=16, max=128) + batch_size = 64 + dummy_inputs = torch.ones(batch_size, 64, dtype=dtype).cuda() + + model = SimpleNetwork().eval().cuda() + # model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=dtype).cuda()) + # model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=dtype).cuda()) + + print(f"lan added model weight: {model.linear1.weight=}") + print(f"lan added model bias: {model.linear1.bias=}") + + quant_cfg = mtq.NVFP4_DEFAULT_CFG + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + # model has qdq nodes at this point + with torch.no_grad(): + with export_torch_mode(): + exp_program = torch.export.export( + model, (dummy_inputs,), strict=False, dynamic_shapes=({0: BATCH_SIZE},) + ) + + trt_model = torchtrt.dynamo.compile( + exp_program, + inputs=[dummy_inputs], + enabled_precisions={ + torch.float4_e2m1fn_x2, + torch.float8_e4m3fn, + torch.float32, + torch.float16, + }, + min_block_size=1, + debug=True, + cache_built_engines=False, + reuse_cached_engines=False, + use_explicit_typing=dtype == torch.float16, + ) + batch_size = 128 + input_tensor = torch.ones(batch_size, 64, dtype=dtype).cuda() + expected_output = model(input_tensor) + outputs_trt = trt_model(input_tensor) + if os.getenv("DISABLE_GEMM", "false").lower() == "true": + print("lan added disable_gemm is set, compring result with weights") + expected_output = model.linear1.weight + else: + print("lan added disable_gemm is not set, compring result with pytorch") + + print( + f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}" + ) + print( + f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}" + ) + + abs_diff = torch.abs(expected_output - outputs_trt) + print( + f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}" + ) + print(f"lan added abs_diff: {abs_diff=}") + assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8) + + # @unittest.skipIf( # torch.cuda.get_device_capability() < (10, 0), # "FP4 quantization requires compute capability 10.0 or later", @@ -210,6 +302,7 @@ def test_resnet18_half(ir): def test_base_fp4(ir): import modelopt.torch.quantization as mtq from modelopt.torch.quantization.utils import export_torch_mode + dtype = torch.float16 class SimpleNetwork(torch.nn.Module): @@ -229,17 +322,16 @@ def calibrate_loop(model): input_tensor = torch.ones(128, 64, dtype=dtype).cuda() - model = SimpleNetwork().eval().cuda() model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=dtype).cuda()) model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=dtype).cuda()) print(f"lan added amax: {input_tensor.abs().amax()=}") print(f"lan added amax: {model.linear1.weight.abs().amax()=}") expected_output = model(input_tensor) - print(f"lan added model input: {input_tensor=}") + print(f"lan added model input: {input_tensor=}") print(f"lan added model weight: {model.linear1.weight=}") print(f"lan added model bias: {model.linear1.bias=}") - + quant_cfg = mtq.NVFP4_DEFAULT_CFG mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) # model has qdq nodes at this point @@ -271,11 +363,17 @@ def calibrate_loop(model): 
else: print("lan added disable_gemm is not set, compring result with pytorch") - print(f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}") - print(f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}") + print( + f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}" + ) + print( + f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}" + ) abs_diff = torch.abs(expected_output - outputs_trt) - print(f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}") + print( + f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}" + ) print(f"lan added abs_diff: {abs_diff=}") assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8) From b99a9dbe419f242f744486419da40aa3edf00a7a Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 5 Jun 2025 10:50:27 -0700 Subject: [PATCH 3/7] clean up the PR --- examples/dynamo/vgg16_ptq.py | 12 +- py/torch_tensorrt/dynamo/_compiler.py | 22 +-- .../dynamo/conversion/_TRTInterpreter.py | 24 ++- .../dynamo/conversion/aten_ops_converters.py | 2 +- .../dynamo/conversion/converter_utils.py | 35 ++-- .../dynamo/conversion/impl/addmm.py | 6 +- .../dynamo/conversion/impl/nvfp4_quantize.py | 152 ++---------------- .../dynamo/conversion/impl/permutation.py | 6 +- tests/py/dynamo/models/test_models_export.py | 70 ++------ 9 files changed, 75 insertions(+), 254 deletions(-) diff --git a/examples/dynamo/vgg16_ptq.py b/examples/dynamo/vgg16_ptq.py index 7af490afb4..7fa943040e 100644 --- a/examples/dynamo/vgg16_ptq.py +++ b/examples/dynamo/vgg16_ptq.py @@ -169,6 +169,7 @@ def vgg16(num_classes=1000, init_weights=False): data = iter(training_dataloader) images, _ = next(data) + crit = nn.CrossEntropyLoss() # %% @@ -199,11 +200,8 @@ def calibrate_loop(model): quant_cfg = mtq.INT8_DEFAULT_CFG elif args.quantize_type == "fp8": quant_cfg = mtq.FP8_DEFAULT_CFG -elif args.quantize_type == "fp4": - quant_cfg = mtq.NVFP4_DEFAULT_CFG # PTQ with in-place replacement to quantized modules mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - # model has FP8 qdq nodes at this point # %% @@ -235,20 +233,12 @@ def calibrate_loop(model): with export_torch_mode(): # Compile the model with Torch-TensorRT Dynamo backend input_tensor = images.cuda() - torch.onnx.export(model, input_tensor, "mtq_vgg16_model.onnx") exp_program = torch.export.export(model, (input_tensor,), strict=False) if args.quantize_type == "int8": enabled_precisions = {torch.int8} elif args.quantize_type == "fp8": enabled_precisions = {torch.float8_e4m3fn} - elif args.quantize_type == "fp4": - enabled_precisions = { - torch.float4_e2m1fn_x2, - torch.float8_e4m3fn, - torch.float16, - torch.float32, - } trt_model = torchtrt.dynamo.compile( exp_program, inputs=[input_tensor], diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 4db11daa78..46af714c51 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -250,13 +250,10 @@ def cross_compile_for_windows( "\nThis feature is unimplemented in Torch-TRT Dynamo currently." 
) - if use_explicit_typing: - if len(enabled_precisions) != 1 or not any( - x in enabled_precisions for x in {torch.float32, dtype.f32} - ): - raise AssertionError( - f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}" - ) + # if use_explicit_typing and len(enabled_precisions) != 1: + # raise AssertionError( + # f"When use_explicit_typing is enabled, allow only 1 precision in the enabled_precisions but found {enabled_precisions}" + # ) if use_fp32_acc: logger.debug( @@ -581,13 +578,10 @@ def compile( "\nThis feature is unimplemented in Torch-TRT Dynamo currently." ) - # if use_explicit_typing: - # if len(enabled_precisions) != 1 or not any( - # x in enabled_precisions for x in {torch.float32, dtype.f32} - # ): - # raise AssertionError( - # f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}" - # ) + # if use_explicit_typing and len(enabled_precisions) != 1: + # raise AssertionError( + # f"When use_explicit_typing is enabled, allow only 1 precision in the enabled_precisions but found {enabled_precisions}" + # ) if use_fp32_acc: logger.debug( diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index ecf08f38c4..acc05935c6 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -43,6 +43,7 @@ get_node_io, get_node_name, get_trt_tensor, + global_reference_holder, to_torch, ) from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device @@ -274,16 +275,28 @@ def _populate_trt_builder_config( self.compilation_settings.dla_global_dram_size, ) - if not self.compilation_settings.use_explicit_typing and dtype.float16 in self.compilation_settings.enabled_precisions: + if ( + not self.compilation_settings.use_explicit_typing + and dtype.float16 in self.compilation_settings.enabled_precisions + ): builder_config.set_flag(trt.BuilderFlag.FP16) - if dtype.int8 in self.compilation_settings.enabled_precisions: + if ( + not self.compilation_settings.use_explicit_typing + and dtype.int8 in self.compilation_settings.enabled_precisions + ): builder_config.set_flag(trt.BuilderFlag.INT8) - if not self.compilation_settings.use_explicit_typing and dtype.fp8 in self.compilation_settings.enabled_precisions: + if ( + not self.compilation_settings.use_explicit_typing + and dtype.fp8 in self.compilation_settings.enabled_precisions + ): builder_config.set_flag(trt.BuilderFlag.FP8) - if dtype.bfloat16 in self.compilation_settings.enabled_precisions: + if ( + not self.compilation_settings.use_explicit_typing + and dtype.bfloat16 in self.compilation_settings.enabled_precisions + ): builder_config.set_flag(trt.BuilderFlag.BF16) if self.compilation_settings.sparse_weights: @@ -737,7 +750,8 @@ def run( self.ctx.net, builder_config ) assert serialized_engine - + # clear the global reference holder after engin is built + global_reference_holder.clear() _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index d8ef865fd3..9cb81b9b62 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -638,7 +638,7 @@ def aten_ops_dynamic_block_quantize_op( kwargs: 
Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.nvfp4_quantize.nvfp4_quantize( + return impl.nvfp4_quantize.quantize( ctx, target, SourceIR.ATEN, diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index eb18a14eca..4e47bb1511 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -25,6 +25,9 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) +# global reference holder for the objects that should not be garbage collected before the engine is built +global_reference_holder: List[torch.Tensor] = [] + def get_node_name(node: torch.fx.Node) -> str: # nn_module_stack preserves the call stack of pytorch nn.modules @@ -361,32 +364,30 @@ def create_constant( shape = list(torch_value.shape) if torch_value is not None: - if torch_value.dtype == torch.float8_e4m3fn: + if torch_value.dtype in (torch.float8_e4m3fn, torch.uint8): + # global_reference_holder.append(torch_value) + # Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8 + if torch_value.dtype == torch.uint8: + count = torch_value.numel() * 2 + shape[-1] = shape[-1] * 2 + dtype = trt.DataType.FP4 + else: + count = torch_value.numel() + dtype = trt.DataType.FP8 weights = trt.Weights( - type=trt.DataType.FP8, + type=dtype, ptr=torch_value.data_ptr(), - count=torch_value.numel(), + count=count, ) constant = ctx.net.add_constant( shape, weights, ) constant.name = name + # TODO: confirm with @dheeraj whether it is ok to put the torch tensor here, since the fp8 torch tensor cannot have the equivalent of numpy array + ctx.mapping[name + " CONSTANT"] = torch_value.reshape(-1) return constant.get_output(0) - # Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8 - if torch_value.dtype == torch.uint8: - weights = trt.Weights( - type=trt.DataType.FP4, - ptr=torch_value.data_ptr(), - count=torch_value.numel() * 2, - ) - shape[-1] = shape[-1] * 2 - constant = ctx.net.add_constant( - shape, - weights, - ) - constant.name = name - return constant.get_output(0) + if torch_value.dtype == torch.bfloat16: torch_value_fp32 = torch_value.to(torch.float32) numpy_value = torch_value_fp32.numpy() diff --git a/py/torch_tensorrt/dynamo/conversion/impl/addmm.py b/py/torch_tensorrt/dynamo/conversion/impl/addmm.py index 73d98acfdf..1a0690852a 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/addmm.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/addmm.py @@ -7,7 +7,7 @@ from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.fx.types import TRTTensor -import os + def addmm( ctx: ConversionContext, @@ -21,10 +21,6 @@ def addmm( beta: Union[float, int], alpha: Union[float, int], ) -> TRTTensor: - if os.getenv("DISABLE_GEMM", "false").lower() == "true": - print("lan added disable_gemm is set, skip addmm and returning mat2") - return mat2 - print("lan added disable_gemm is not set, doing addmm") mm = impl.matmul.matrix_multiply(ctx, target, source_ir, f"{name}_mm", mat1, mat2) if alpha != 1: mm = impl.elementwise.mul( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py index 1c2f297764..0ac39622cf 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ 
-1,23 +1,20 @@ -import os from typing import Optional, Union import numpy as np import tensorrt as trt import torch -import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( get_trt_tensor, - to_torch, ) from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTTensor -def nvfp4_quantize( +def quantize( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], @@ -34,9 +31,6 @@ def nvfp4_quantize( Adds quantize and dequantize ops (QDQ) which quantize to FP4 based on the output_type set and dequantizes them back. """ - print( - f"lan added nvfp4_quantize entered: {target=} {source_ir=} {name=} {input_tensor.shape=} {input_tensor.dtype=} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}" - ) if len(input_tensor.shape) not in (2, 3): raise ValueError( f"nvfp4_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D" @@ -44,8 +38,6 @@ def nvfp4_quantize( with unset_fake_temporarily(): axis = -1 global_scale = _calculate_global_scale(ctx, name, amax) - print(f"lan added input_tensor: {input_tensor.shape=} {input_tensor.dtype=}") - print(f"lan added global_scale: {global_scale.shape=} {global_scale.dtype=}") if ".weight_quantizer" in name: output = _static_double_quantize( ctx, @@ -80,7 +72,7 @@ def _dynamic_double_quantize( target: Target, source_ir: Optional[SourceIR], name: str, - input_tensor: torch.Tensor, + input_tensor: TRTTensor, global_scale: torch.Tensor, axis: int = -1, block_size: int = 16, @@ -94,8 +86,8 @@ def _dynamic_double_quantize( target: Target, source_ir: Optional[SourceIR] name: str - input_tensor : Tensor (On GPU) - The input tensor. + input_tensor : TRTTensor (On GPU) + The input TRTTensor. global_scale : Tensor (On GPU) The global per-tensor scaling factor. It should contain only 1 element. axis : int @@ -108,15 +100,15 @@ def _dynamic_double_quantize( The data type for block scale. Default is FP8. """ - if os.getenv("DISABLE_DYNAMIC_QUANTIZE", "false").lower() == "true": - print("lan added disable_dynamic_quantize is set, skipping dynamic quantize") - return input_tensor - print("lan added disable_dynamic_quantize is not set, doing dynamic quantize") global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") - if input_tensor.dtype not in [trt.DataType.HALF, trt.DataType.FLOAT]: + if input_tensor.dtype not in [ + trt.DataType.HALF, + trt.DataType.FLOAT, + trt.DataType.BF16, + ]: raise ValueError( - f"Currently try float16, float32 only on input tensor for now. 
Unsupported dtype: {input_tensor.dtype}" + f"Currently supported input tensor type is float16 | float32 | bfloat16, got Unsupported dtype: {input_tensor.dtype}" ) # dynamic quantize input tensor to fp4 dynamic_quantize_layer = ctx.net.add_dynamic_quantize( @@ -203,21 +195,18 @@ def _static_double_quantize( Returns: quantized data tensor in fp4 """ - if os.getenv("DISABLE_STATIC_QUANTIZE", "false").lower() == "true": - print("lan added disable_static_quantize is set, skipping static quantize") - return get_trt_tensor(ctx, weights_tensor, name + "_weights") - print( - "lan added static disable_static_quantize is not set, doing static double quantize " - ) + import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor if weights_tensor.dtype == torch.float16: original_dtype = trt.DataType.HALF elif weights_tensor.dtype == torch.float32: original_dtype = trt.DataType.FLOAT + elif weights_tensor.dtype == torch.bfloat16: + original_dtype = trt.DataType.BF16 else: raise ValueError( - f"Currently try float16, float32 only on weights tensor. Unsupported dtype: {weights_tensor.dtype}" + f"Currently supported weights tensor type is float16 | float32 | bfloat16, got Unsupported dtype: {weights_tensor.dtype}" ) block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( weights_tensor, @@ -232,103 +221,9 @@ def _static_double_quantize( )[0]._quantized_data block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") - global_scale = to_torch(global_scale, None) global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") weights_tensor_fp4 = get_trt_tensor(ctx, weights_tensor_fp4, name + "_weights_fp4") - print( - f"lan added weights_tensor_fp4: {weights_tensor_fp4.shape=} {weights_tensor_fp4.dtype=}" - ) - dequantized_data = _double_dequantize( - ctx, - target, - source_ir, - name, - weights_tensor_fp4, - block_scale_fp8, - global_scale, - axis, - original_dtype, - ) - return dequantized_data - -def _static_double_quantize_transpose( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - weights_tensor: torch.Tensor, - global_scale: torch.Tensor, - axis: int, -) -> TRTTensor: - """ - Parameters: - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - weights_tensor : Tensor (On GPU) - The input tensor for weights. - global_scale : Tensor (On GPU) - The global per-tensor scaling factor. It should contain only 1 element. - axis: int - The axis to quantize. Default is -1 (the last axis). - Returns: - quantized data tensor in fp4 - """ - axis = -2 - import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor - - if weights_tensor.dtype == torch.float16: - original_dtype = trt.DataType.HALF - elif weights_tensor.dtype == torch.float32: - original_dtype = trt.DataType.FLOAT - else: - raise ValueError( - f"Currently try float16, float32 only on weights tensor. 
Unsupported dtype: {weights_tensor.dtype}" - ) - block_scale = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( - weights_tensor, - 16, - global_scale, - keep_high_precision=True, - )[0] - weights_tensor_scaled = nvfp4_tensor.NVFP4QTensor.quantize( - weights_tensor, - 16, - block_scale, - global_scale, - keep_high_precision=True, - ) - - block_scale = block_scale.transpose(0, 1) - weights_tensor_scaled = weights_tensor_scaled.transpose(0, 1) - - block_scale_fp8 = block_scale.to(torch.float8_e4m3fn) - weights_tensor_uint4 = nvfp4_tensor.NVFP4QTensor._cast_fp4(weights_tensor_scaled) - weights_tensor_uint8 = ( - weights_tensor_uint4[..., 1::2] << 4 - ) | weights_tensor_uint4[..., 0::2] - - print( - f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}" - ) - print( - f"lan added weights_tensor_uint4: {weights_tensor_uint4.shape=} {weights_tensor_uint4.dtype=} {weights_tensor_uint4=}" - ) - print( - f"lan added weights_tensor_uint8: {weights_tensor_uint8.shape=} {weights_tensor_uint8.dtype=} {weights_tensor_uint8=}" - ) - - block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") - global_scale = to_torch(global_scale, None) - global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") - weights_tensor_fp4 = get_trt_tensor( - ctx, weights_tensor_uint8, name + "_weights_fp4" - ) - print( - f"lan added weights_tensor_fp4: {weights_tensor_fp4.shape=} {weights_tensor_fp4.dtype=}" - ) dequantized_data = _double_dequantize( ctx, target, @@ -340,14 +235,6 @@ def _static_double_quantize_transpose( axis, original_dtype, ) - dequantized_data = impl.permutation.permute( - ctx, - target, - source_ir, - name + "_dequantized_data_transposed", - dequantized_data, - (-1, -2), - ) return dequantized_data @@ -357,12 +244,7 @@ def _calculate_global_scale( amax: torch.Tensor, ) -> torch.Tensor: # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) - if amax is None or amax == 0: - amax = 1.0 - amax = to_torch( - amax, None - ) # amax is calculated from input_tensor.abs().amax().float() - global_scale = torch.divide(amax, 6 * 448) - if global_scale == 0: - global_scale = 1.0 + assert len(amax.shape) == 0, "amax should be a scalar" + global_scale = amax / 6 / 448 + global_scale.masked_fill_(global_scale == 0, 1.0) return global_scale diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index 4408b62809..1537d0fdbe 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -13,7 +13,7 @@ ) from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape from torch_tensorrt.fx.types import TRTTensor -import os + def permute( ctx: ConversionContext, @@ -23,10 +23,6 @@ def permute( input: TRTTensor, permutation: Sequence[int], ) -> TRTTensor: - if os.getenv("DISABLE_GEMM", "false").lower() == "true": - print("lan added disable_gemm is set, skip permute") - return input - print("lan added disable_gemm is not set, doing permute") if not isinstance(input, TRTTensor): raise RuntimeError( f"permute received input {input} that is not a TensorRT ITensor" diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 1601d20ed5..cafd206301 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -14,7 +14,6 @@ from packaging.version 
import Version assertions = unittest.TestCase() -import os @pytest.mark.unit @@ -235,11 +234,6 @@ def calibrate_loop(model): dummy_inputs = torch.ones(batch_size, 64, dtype=dtype).cuda() model = SimpleNetwork().eval().cuda() - # model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=dtype).cuda()) - # model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=dtype).cuda()) - - print(f"lan added model weight: {model.linear1.weight=}") - print(f"lan added model bias: {model.linear1.bias=}") quant_cfg = mtq.NVFP4_DEFAULT_CFG mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) @@ -255,39 +249,20 @@ def calibrate_loop(model): inputs=[dummy_inputs], enabled_precisions={ torch.float4_e2m1fn_x2, - torch.float8_e4m3fn, - torch.float32, - torch.float16, }, min_block_size=1, debug=True, cache_built_engines=False, reuse_cached_engines=False, - use_explicit_typing=dtype == torch.float16, + use_explicit_typing=True, ) batch_size = 128 input_tensor = torch.ones(batch_size, 64, dtype=dtype).cuda() expected_output = model(input_tensor) outputs_trt = trt_model(input_tensor) - if os.getenv("DISABLE_GEMM", "false").lower() == "true": - print("lan added disable_gemm is set, compring result with weights") - expected_output = model.linear1.weight - else: - print("lan added disable_gemm is not set, compring result with pytorch") - - print( - f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}" - ) - print( - f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}" - ) - abs_diff = torch.abs(expected_output - outputs_trt) - print( - f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}" - ) - print(f"lan added abs_diff: {abs_diff=}") - assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8) + print(f"max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}") + assert torch.allclose(expected_output, outputs_trt, rtol=0.3, atol=0.3) # @unittest.skipIf( @@ -299,11 +274,11 @@ def calibrate_loop(model): "ModelOpt is required to run this test", ) @pytest.mark.unit -def test_base_fp4(ir): +def test_base_fp4_static_shapes(ir): import modelopt.torch.quantization as mtq from modelopt.torch.quantization.utils import export_torch_mode - dtype = torch.float16 + dtype = torch.bfloat16 class SimpleNetwork(torch.nn.Module): def __init__(self): @@ -320,17 +295,10 @@ def calibrate_loop(model): """Simple calibration function for testing.""" model(input_tensor) - input_tensor = torch.ones(128, 64, dtype=dtype).cuda() + input_tensor = torch.randn(128, 64, dtype=dtype).cuda() model = SimpleNetwork().eval().cuda() - model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=dtype).cuda()) - model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=dtype).cuda()) - print(f"lan added amax: {input_tensor.abs().amax()=}") - print(f"lan added amax: {model.linear1.weight.abs().amax()=}") expected_output = model(input_tensor) - print(f"lan added model input: {input_tensor=}") - print(f"lan added model weight: {model.linear1.weight=}") - print(f"lan added model bias: {model.linear1.bias=}") quant_cfg = mtq.NVFP4_DEFAULT_CFG mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) @@ -345,37 +313,17 @@ def calibrate_loop(model): inputs=[input_tensor], enabled_precisions={ torch.float4_e2m1fn_x2, - torch.float8_e4m3fn, - torch.float32, - torch.float16, }, min_block_size=1, debug=True, 
cache_built_engines=False, reuse_cached_engines=False, - use_explicit_typing=dtype == torch.float16, + use_explicit_typing=True, ) - outputs_trt = trt_model(input_tensor) - if os.getenv("DISABLE_GEMM", "false").lower() == "true": - print("lan added disable_gemm is set, compring result with weights") - expected_output = model.linear1.weight - else: - print("lan added disable_gemm is not set, compring result with pytorch") - - print( - f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}" - ) - print( - f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}" - ) - abs_diff = torch.abs(expected_output - outputs_trt) - print( - f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}" - ) - print(f"lan added abs_diff: {abs_diff=}") - assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8) + print(f"max/mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}") + assert torch.allclose(expected_output, outputs_trt, rtol=0.3, atol=0.3) @unittest.skipIf( From 4e226d0e7298625caffbb51d527b70690ffb6e7c Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 5 Jun 2025 11:27:32 -0700 Subject: [PATCH 4/7] test --- .../dynamo/conversion/_TRTInterpreter.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index acc05935c6..05fa292d10 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -39,11 +39,10 @@ CallingConvention, ) from torch_tensorrt.dynamo.conversion._TRTBuilderMonitor import TRTBulderMonitor -from torch_tensorrt.dynamo.conversion.converter_utils import ( +from torch_tensorrt.dynamo.conversion.converter_utils import ( # global_reference_holder, get_node_io, get_node_name, get_trt_tensor, - global_reference_holder, to_torch, ) from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device @@ -426,7 +425,14 @@ def find_weight( state_dict: state of the graph module """ with unset_fake_temporarily(): - network_weight = torch.from_numpy(np_map[weight_name]).to(device) + if isinstance(np_map[weight_name], np.ndarray): + network_weight = torch.from_numpy(np_map[weight_name]).to(device) + elif isinstance(np_map[weight_name], torch.Tensor): + network_weight = np_map[weight_name].to(device) + else: + raise ValueError( + f"Unsupported weight type: {type(np_map[weight_name])}, currently only support numpy.ndarray | torch.Tensor" + ) for sd_w_name, sd_weight in state_dict.items(): if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device): del state_dict[sd_w_name] @@ -751,7 +757,7 @@ def run( ) assert serialized_engine # clear the global reference holder after engin is built - global_reference_holder.clear() + # global_reference_holder.clear() _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) From 951a920e5f98e7dd2c0bc58bb1feedeb09d04c7b Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 5 Jun 2025 11:43:19 -0700 Subject: [PATCH 5/7] test --- py/torch_tensorrt/dynamo/_compiler.py | 16 ++++++++-------- .../dynamo/conversion/converter_utils.py | 6 ++++-- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 
46af714c51..a6d3fce7f1 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -250,10 +250,10 @@ def cross_compile_for_windows( "\nThis feature is unimplemented in Torch-TRT Dynamo currently." ) - # if use_explicit_typing and len(enabled_precisions) != 1: - # raise AssertionError( - # f"When use_explicit_typing is enabled, allow only 1 precision in the enabled_precisions but found {enabled_precisions}" - # ) + if use_explicit_typing and len(enabled_precisions) != 1: + raise AssertionError( + f"When use_explicit_typing is enabled, allow only 1 precision in the enabled_precisions but found {enabled_precisions}" + ) if use_fp32_acc: logger.debug( @@ -578,10 +578,10 @@ def compile( "\nThis feature is unimplemented in Torch-TRT Dynamo currently." ) - # if use_explicit_typing and len(enabled_precisions) != 1: - # raise AssertionError( - # f"When use_explicit_typing is enabled, allow only 1 precision in the enabled_precisions but found {enabled_precisions}" - # ) + if use_explicit_typing and len(enabled_precisions) != 1: + raise AssertionError( + f"When use_explicit_typing is enabled, allow only 1 precision in the enabled_precisions but found {enabled_precisions}" + ) if use_fp32_acc: logger.debug( diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 4e47bb1511..dcc60b26e9 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -365,7 +365,7 @@ def create_constant( if torch_value is not None: if torch_value.dtype in (torch.float8_e4m3fn, torch.uint8): - # global_reference_holder.append(torch_value) + # Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8 if torch_value.dtype == torch.uint8: count = torch_value.numel() * 2 @@ -384,8 +384,10 @@ def create_constant( weights, ) constant.name = name - # TODO: confirm with @dheeraj whether it is ok to put the torch tensor here, since the fp8 torch tensor cannot have the equivalent of numpy array + # TODO: confirm with @dheeraj @naren whether i can use ctx.mapping as the reference holder to prevent the torch tensor being garbage collected. 
ctx.mapping[name + " CONSTANT"] = torch_value.reshape(-1) + # if yes, then the following global_reference_holder is no longer needed + # global_reference_holder.append(torch_value) return constant.get_output(0) if torch_value.dtype == torch.bfloat16: From 43cdec8c1a041b63833b28a48c423a4821504fd5 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 5 Jun 2025 14:14:18 -0700 Subject: [PATCH 6/7] address comments --- .../dynamo/conversion/_ConversionContext.py | 5 +++ .../dynamo/conversion/_TRTInterpreter.py | 14 ++------ .../dynamo/conversion/aten_ops_converters.py | 4 +-- .../dynamo/conversion/converter_utils.py | 35 ++++++++++--------- .../dynamo/conversion/impl/__init__.py | 2 +- ..._quantize.py => dynamic_block_quantize.py} | 2 +- 6 files changed, 30 insertions(+), 32 deletions(-) rename py/torch_tensorrt/dynamo/conversion/impl/{nvfp4_quantize.py => dynamic_block_quantize.py} (98%) diff --git a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py index 141b68f3e7..9da6d8cbdc 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py @@ -1,6 +1,8 @@ from dataclasses import dataclass, field +from typing import Union import numpy as np +import torch from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.types import TRTNetwork @@ -21,3 +23,6 @@ class ConversionContext: ) requires_output_allocator: bool = False mapping: dict[str, np.array] = field(default_factory=dict) + weights_reference_holder: dict[str, Union[torch.Tensor, np.array]] = field( + default_factory=dict + ) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 05fa292d10..dc2a8e7fb1 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -39,7 +39,7 @@ CallingConvention, ) from torch_tensorrt.dynamo.conversion._TRTBuilderMonitor import TRTBulderMonitor -from torch_tensorrt.dynamo.conversion.converter_utils import ( # global_reference_holder, +from torch_tensorrt.dynamo.conversion.converter_utils import ( get_node_io, get_node_name, get_trt_tensor, @@ -425,14 +425,7 @@ def find_weight( state_dict: state of the graph module """ with unset_fake_temporarily(): - if isinstance(np_map[weight_name], np.ndarray): - network_weight = torch.from_numpy(np_map[weight_name]).to(device) - elif isinstance(np_map[weight_name], torch.Tensor): - network_weight = np_map[weight_name].to(device) - else: - raise ValueError( - f"Unsupported weight type: {type(np_map[weight_name])}, currently only support numpy.ndarray | torch.Tensor" - ) + network_weight = torch.from_numpy(np_map[weight_name]).to(device) for sd_w_name, sd_weight in state_dict.items(): if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device): del state_dict[sd_w_name] @@ -756,8 +749,7 @@ def run( self.ctx.net, builder_config ) assert serialized_engine - # clear the global reference holder after engin is built - # global_reference_holder.clear() + _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 9cb81b9b62..543de347d0 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -623,7 
+623,7 @@ def aten_ops_quantize_op( assert torch.ops.tensorrt.dynamic_block_quantize_op.default except Exception as e: _LOGGER.warning( - "Unable to import dynamic block quantize op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling dynamic blockquantized models" + "Unable to import quantize op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models" ) else: @@ -638,7 +638,7 @@ def aten_ops_dynamic_block_quantize_op( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.nvfp4_quantize.quantize( + return impl.dynamic_block_quantize.quantize( ctx, target, SourceIR.ATEN, diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index dcc60b26e9..4d3f070398 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -25,9 +25,6 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -# global reference holder for the objects that should not be garbage collected before the engine is built -global_reference_holder: List[torch.Tensor] = [] - def get_node_name(node: torch.fx.Node) -> str: # nn_module_stack preserves the call stack of pytorch nn.modules @@ -364,30 +361,34 @@ def create_constant( shape = list(torch_value.shape) if torch_value is not None: - if torch_value.dtype in (torch.float8_e4m3fn, torch.uint8): + if torch_value.dtype == torch.float8_e4m3fn: + weights = trt.Weights( + type=trt.DataType.FP8, + ptr=torch_value.data_ptr(), + count=torch_value.numel(), + ) + constant = ctx.net.add_constant( + shape, + weights, + ) + constant.name = name + ctx.weights_reference_holder[name + " FP8_CONSTANT"] = torch_value + return constant.get_output(0) + if torch_value.dtype == torch.uint8: # Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8 - if torch_value.dtype == torch.uint8: - count = torch_value.numel() * 2 - shape[-1] = shape[-1] * 2 - dtype = trt.DataType.FP4 - else: - count = torch_value.numel() - dtype = trt.DataType.FP8 + shape[-1] = shape[-1] * 2 weights = trt.Weights( - type=dtype, + type=trt.DataType.FP4, ptr=torch_value.data_ptr(), - count=count, + count=torch_value.numel() * 2, ) constant = ctx.net.add_constant( shape, weights, ) constant.name = name - # TODO: confirm with @dheeraj @naren whether i can use ctx.mapping as the reference holder to prevent the torch tensor being garbage collected. 
- ctx.mapping[name + " CONSTANT"] = torch_value.reshape(-1) - # if yes, then the following global_reference_holder is no longer needed - # global_reference_holder.append(torch_value) + ctx.weights_reference_holder[name + " FP4_CONSTANT"] = torch_value return constant.get_output(0) if torch_value.dtype == torch.bfloat16: diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index 1f2d9d0de1..10af2ad892 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -7,6 +7,7 @@ condition, conv, deconv, + dynamic_block_quantize, elementwise, embedding, full, @@ -14,7 +15,6 @@ matmul, nccl_ops, normalization, - nvfp4_quantize, pad, permutation, pool, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/dynamic_block_quantize.py similarity index 98% rename from py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py rename to py/torch_tensorrt/dynamo/conversion/impl/dynamic_block_quantize.py index 0ac39622cf..ec9048fca4 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/dynamic_block_quantize.py @@ -33,7 +33,7 @@ def quantize( """ if len(input_tensor.shape) not in (2, 3): raise ValueError( - f"nvfp4_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D" + f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D" ) with unset_fake_temporarily(): axis = -1 From 416c39668717fa216a3e9f3c72bbbd9d6eb3c59e Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 6 Jun 2025 09:39:41 -0700 Subject: [PATCH 7/7] resolve comments --- py/torch_tensorrt/dynamo/_compiler.py | 22 +++++++++----- .../dynamo/conversion/_ConversionContext.py | 5 +++- .../dynamo/conversion/_TRTInterpreter.py | 22 ++++---------- .../dynamo/conversion/converter_utils.py | 21 ++++++++++--- .../conversion/impl/dynamic_block_quantize.py | 30 ++++++++++++++++--- tests/py/dynamo/models/test_models_export.py | 24 ++++++--------- 6 files changed, 76 insertions(+), 48 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index a6d3fce7f1..831ce37305 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -250,10 +250,13 @@ def cross_compile_for_windows( "\nThis feature is unimplemented in Torch-TRT Dynamo currently." ) - if use_explicit_typing and len(enabled_precisions) != 1: - raise AssertionError( - f"When use_explicit_typing is enabled, allow only 1 precision in the enabled_precisions but found {enabled_precisions}" - ) + if use_explicit_typing: + if len(enabled_precisions) != 1 or not any( + x in enabled_precisions for x in {torch.float32, dtype.f32} + ): + raise AssertionError( + f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}" + ) if use_fp32_acc: logger.debug( @@ -578,10 +581,13 @@ def compile( "\nThis feature is unimplemented in Torch-TRT Dynamo currently." 
) - if use_explicit_typing and len(enabled_precisions) != 1: - raise AssertionError( - f"When use_explicit_typing is enabled, allow only 1 precision in the enabled_precisions but found {enabled_precisions}" - ) + if use_explicit_typing: + if len(enabled_precisions) != 1 or not any( + x in enabled_precisions for x in {torch.float32, dtype.f32} + ): + raise AssertionError( + f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}" + ) if use_fp32_acc: logger.debug( diff --git a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py index 9da6d8cbdc..1c4926bcfa 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py @@ -23,6 +23,9 @@ class ConversionContext: ) requires_output_allocator: bool = False mapping: dict[str, np.array] = field(default_factory=dict) - weights_reference_holder: dict[str, Union[torch.Tensor, np.array]] = field( + cpu_weights_reference_holder: dict[str, Union[torch.Tensor, np.array]] = field( default_factory=dict ) + + def clear_cpu_weights_reference_holder(self) -> None: + self.cpu_weights_reference_holder.clear() diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index dc2a8e7fb1..bb1a77b4eb 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -274,28 +274,16 @@ def _populate_trt_builder_config( self.compilation_settings.dla_global_dram_size, ) - if ( - not self.compilation_settings.use_explicit_typing - and dtype.float16 in self.compilation_settings.enabled_precisions - ): + if dtype.float16 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.FP16) - if ( - not self.compilation_settings.use_explicit_typing - and dtype.int8 in self.compilation_settings.enabled_precisions - ): + if dtype.int8 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.INT8) - if ( - not self.compilation_settings.use_explicit_typing - and dtype.fp8 in self.compilation_settings.enabled_precisions - ): + if dtype.fp8 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.FP8) - if ( - not self.compilation_settings.use_explicit_typing - and dtype.bfloat16 in self.compilation_settings.enabled_precisions - ): + if dtype.bfloat16 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.BF16) if self.compilation_settings.sparse_weights: @@ -755,6 +743,8 @@ def run( ) _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") + self.ctx.clear_cpu_weights_reference_holder() + self._save_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 4d3f070398..b5b7cce868 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -326,6 +326,7 @@ def create_constant( name: str, dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]], min_rank: Optional[int] = 1, + target_quantized_type: Optional[TRTDataType] = None, ) -> TRTTensor: """ Add a TensorRT constant layer whose value is `value` to `ctx.net`. 
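# --- Editor's note (illustrative sketch, not part of the patch) --------------
# Rationale for the cpu_weights_reference_holder used in the hunks around here:
# trt.Weights is constructed from torch_value.data_ptr(), so TensorRT only
# records a raw pointer and the CPU tensor must stay alive until the engine is
# serialized; the context clears the holder afterwards. A minimal stand-in
# sketch (hold_weights and the plain dict below are hypothetical names, not
# APIs from this repo):

import torch

cpu_weights_reference_holder: dict[str, torch.Tensor] = {}

def hold_weights(name: str, torch_value: torch.Tensor) -> int:
    # Keep a reference so the underlying storage is not garbage collected
    # while the builder still points at it via data_ptr().
    cpu_weights_reference_holder[name] = torch_value
    return torch_value.data_ptr()

ptr = hold_weights("linear1_weight FP4_CONSTANT", torch.ones(4, 8, dtype=torch.uint8))
# Once the engine is built and serialized, the holder can be cleared, which is
# what clear_cpu_weights_reference_holder() does in this patch.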
@@ -338,6 +339,7 @@ def create_constant(
         dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]): If a dtype is given,
             we will convert the type of the given `value` to this dtype.
         min_rank (int): minimum rank of the constant tensor.
+        target_quantized_type (Optional[TRTDataType]): If given, the quantized TensorRT dtype (e.g. FP4) that the packed `value` represents.
     Returns:
         A TensorRT ITensor that represents the given value.
     """
@@ -372,11 +374,18 @@ def create_constant(
                 weights,
             )
             constant.name = name
-            ctx.weights_reference_holder[name + " FP8_CONSTANT"] = torch_value
+            ctx.cpu_weights_reference_holder[name + " FP8_CONSTANT"] = torch_value
             return constant.get_output(0)

         if torch_value.dtype == torch.uint8:
-            # Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8
+            if (
+                target_quantized_type is None
+                or target_quantized_type != trt.DataType.FP4
+            ):
+                # IConstantLayer does not support uint8 directly; it only accepts FP4 data packed into uint8
+                raise ValueError(
+                    f"Currently the only supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}"
+                )
             shape[-1] = shape[-1] * 2
             weights = trt.Weights(
                 type=trt.DataType.FP4,
@@ -388,7 +397,7 @@ def create_constant(
                 weights,
             )
             constant.name = name
-            ctx.weights_reference_holder[name + " FP4_CONSTANT"] = torch_value
+            ctx.cpu_weights_reference_holder[name + " FP4_CONSTANT"] = torch_value
             return constant.get_output(0)

         if torch_value.dtype == torch.bfloat16:
@@ -423,6 +432,7 @@ def get_trt_tensor(
     name: str,
     dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None,
     min_rank: int = 1,
+    target_quantized_type: Optional[TRTDataType] = None,
 ) -> TRTTensor:
     """
     Given a value of random type, we try to convert it to a TensorRT ITensor.
@@ -436,6 +446,7 @@ def get_trt_tensor(
         dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]): If dtype is provided,
             the given value will be converted to this dtype.
         min_rank (int): minimum rank of the constant tensor.
+        target_quantized_type (Optional[TRTDataType]): If given, the quantized TensorRT dtype (e.g. FP4) that the packed `value` represents.
     Returns:
         A TensorRT ITensor that represents the given value.
     """
@@ -448,7 +459,9 @@ def get_trt_tensor(
         input_val = input_val.astype(np.float32)

     if isinstance(input_val, (torch.Tensor, np.ndarray, int, float, bool)):
-        return create_constant(ctx, input_val, name, dtype, min_rank)
+        return create_constant(
+            ctx, input_val, name, dtype, min_rank, target_quantized_type
+        )
     elif isinstance(input_val, TRTTensor):
         return input_val
     else:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/dynamic_block_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/dynamic_block_quantize.py
index ec9048fca4..f76a84dea5 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/dynamic_block_quantize.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/dynamic_block_quantize.py
@@ -49,7 +49,6 @@ def quantize(
             axis,
         )
     elif ".input_quantizer" in name:
-        # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8
         output = _dynamic_double_quantize(
             ctx,
             target,
@@ -59,7 +58,6 @@ def quantize(
             global_scale,
             axis,
         )
-
     else:
         raise ValueError(
             f"quantizer received an input of {name}. Supported values: weight_quantizer | input_quantizer"
@@ -149,6 +147,20 @@ def _double_dequantize(
     axis: int = -1,
     output_type: trt.DataType = trt.DataType.FLOAT,
 ) -> TRTTensor:
+    """
+    Double dequantize: first dequantize the scale from FP8 to the original dtype (default float32),
+    then dequantize the data from FP4 to the original dtype (default float32).
+    Parameters:
+        ctx: ConversionContext,
+        target: Target,
+        source_ir: Optional[SourceIR]
+        name: str
+        quantized_data_in_fp4: TRTTensor
+        quantized_scale_in_fp8: TRTTensor
+        global_scale: torch.Tensor
+        axis: int
+        output_type: trt.DataType
+    """
     # dequantize scale from fp8 to orignal dtype(default is float32)
     dequantize_scale_layer = ctx.net.add_dequantize(
         quantized_scale_in_fp8, global_scale, output_type
@@ -220,9 +232,19 @@ def _static_double_quantize(
         global_scale,
     )[0]._quantized_data

-    block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8")
+    block_scale_fp8 = get_trt_tensor(
+        ctx,
+        block_scale_fp8,
+        name + "_block_scale_fp8",
+        target_quantized_type=trt.DataType.FP8,
+    )
     global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
-    weights_tensor_fp4 = get_trt_tensor(ctx, weights_tensor_fp4, name + "_weights_fp4")
+    weights_tensor_fp4 = get_trt_tensor(
+        ctx,
+        weights_tensor_fp4,
+        name + "_weights_fp4",
+        target_quantized_type=trt.DataType.FP4,
+    )

     dequantized_data = _double_dequantize(
         ctx,
diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py
index cafd206301..a70755a562 100644
--- a/tests/py/dynamo/models/test_models_export.py
+++ b/tests/py/dynamo/models/test_models_export.py
@@ -199,10 +199,10 @@ def test_resnet18_half(ir):
     torch._dynamo.reset()


-# @unittest.skipIf(
-#     torch.cuda.get_device_capability() < (10, 0),
-#     "FP4 quantization requires compute capability 10.0 or later",
-# )
+@unittest.skipIf(
+    torch.cuda.get_device_capability() < (10, 0),
+    "FP4 quantization requires compute capability 10.0 or later",
+)
 @unittest.skipIf(
     not importlib.util.find_spec("modelopt"),
     "ModelOpt is required to run this test",
@@ -247,9 +247,6 @@ def calibrate_loop(model):
     trt_model = torchtrt.dynamo.compile(
         exp_program,
         inputs=[dummy_inputs],
-        enabled_precisions={
-            torch.float4_e2m1fn_x2,
-        },
         min_block_size=1,
         debug=True,
         cache_built_engines=False,
@@ -261,14 +258,14 @@ def calibrate_loop(model):
     expected_output = model(input_tensor)
     outputs_trt = trt_model(input_tensor)
     abs_diff = torch.abs(expected_output - outputs_trt)
-    print(f"max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}")
+    print(f"max/mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}")
     assert torch.allclose(expected_output, outputs_trt, rtol=0.3, atol=0.3)


-# @unittest.skipIf(
-#     torch.cuda.get_device_capability() < (10, 0),
-#     "FP4 quantization requires compute capability 10.0 or later",
-# )
+@unittest.skipIf(
+    torch.cuda.get_device_capability() < (10, 0),
+    "FP4 quantization requires compute capability 10.0 or later",
+)
 @unittest.skipIf(
     not importlib.util.find_spec("modelopt"),
     "ModelOpt is required to run this test",
@@ -311,9 +308,6 @@ def calibrate_loop(model):
     trt_model = torchtrt.dynamo.compile(
         exp_program,
         inputs=[input_tensor],
-        enabled_precisions={
-            torch.float4_e2m1fn_x2,
-        },
         min_block_size=1,
         debug=True,
         cache_built_engines=False,
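# --- Editor's note (illustrative sketch, not part of the patch) --------------
# The rewritten _calculate_global_scale earlier in this series computes
# amax / 6 / 448: 6 is the largest magnitude representable in FP4 (E2M1) and
# 448 the largest in FP8 (E4M3), so amax is mapped onto the maximum value an
# FP8 block scale times an FP4 element can represent. A standalone sketch
# (calculate_global_scale is a hypothetical name):

import torch

def calculate_global_scale(amax: torch.Tensor) -> torch.Tensor:
    # amax is a scalar, typically input_tensor.abs().amax().float()
    assert amax.ndim == 0, "amax should be a scalar"
    global_scale = amax / (6.0 * 448.0)
    # Guard against an all-zero tensor producing a zero scale.
    return global_scale.masked_fill(global_scale == 0, 1.0)

print(calculate_global_scale(torch.tensor(3.0)))  # ~0.0011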
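# --- Editor's note (illustrative sketch, not part of the patch) --------------
# Why create_constant doubles shape[-1] and the element count for FP4 weights:
# each uint8 byte stores two 4-bit FP4 codes, low nibble first, matching the
# packing used when the weights are quantized. A minimal packing/unpacking
# sketch over already-quantized 4-bit codes (pack_fp4_codes/unpack_fp4_codes
# are hypothetical names, not APIs from this repo or modelopt):

import torch

def pack_fp4_codes(codes_uint4: torch.Tensor) -> torch.Tensor:
    # codes_uint4: uint8 tensor whose values are 4-bit FP4 codes (0..15);
    # the last dimension must be even. Two codes are packed per output byte.
    assert codes_uint4.shape[-1] % 2 == 0
    return (codes_uint4[..., 1::2] << 4) | codes_uint4[..., 0::2]

def unpack_fp4_codes(packed_uint8: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_fp4_codes: each byte yields two codes, which is why the
    # TRT FP4 constant is declared with shape[-1] * 2 elements.
    low = packed_uint8 & 0x0F
    high = (packed_uint8 >> 4) & 0x0F
    return torch.stack((low, high), dim=-1).flatten(-2)

packed = pack_fp4_codes(torch.randint(0, 16, (32, 64), dtype=torch.uint8))
assert unpack_fp4_codes(packed).shape[-1] == 2 * packed.shape[-1]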
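# --- Editor's note (illustrative sketch, not part of the patch) --------------
# With the stricter check in _compiler.py above, an NVFP4-quantized module is
# compiled with use_explicit_typing=True and the default enabled_precisions
# (float32 only); the FP4/FP8 Q/DQ nodes inserted by modelopt select the
# low-precision kernels. A condensed usage sketch mirroring the updated test
# (model/input shapes here are arbitrary):

import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt as torchtrt
from modelopt.torch.quantization.utils import export_torch_mode

model = torch.nn.Linear(64, 32).cuda().eval()
inputs = torch.randn(128, 64).cuda()

# Calibrate and insert NVFP4 quantizers in place.
mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop=lambda m: m(inputs))

with export_torch_mode():
    exp_program = torch.export.export(model, (inputs,), strict=False)
    trt_model = torchtrt.dynamo.compile(
        exp_program,
        inputs=[inputs],
        use_explicit_typing=True,  # required for FP4; precisions come from the graph
        min_block_size=1,
    )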