Add fp4 support #3532

Draft: wants to merge 9 commits into base: main
Changes from 6 commits
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
@@ -21,14 +21,14 @@ repos:
- id: clang-format
types_or: [c++, c, cuda]
- repo: https://github.com/keith/pre-commit-buildifier
rev: 6.4.0
rev: 8.0.3
hooks:
- id: buildifier
args:
- --warnings=all
- id: buildifier-lint
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.23
rev: v0.24.1
hooks:
- id: validate-pyproject
- repo: https://github.com/pycqa/isort
@@ -37,17 +37,17 @@ repos:
- id: isort
name: isort (python)
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.9.0"
rev: "v1.15.0"
hooks:
- id: mypy
exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.3
rev: v0.11.7
hooks:
- id: ruff
- repo: https://github.com/psf/black
rev: 24.3.0
rev: 25.1.0
hooks:
- id: black
exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
@@ -57,7 +57,7 @@ repos:
- id: typos
- repo: https://github.com/astral-sh/uv-pre-commit
# uv version.
rev: 0.5.5
rev: 0.7.1
hooks:
# Update the uv lockfile
- id: uv-lock
19 changes: 19 additions & 0 deletions py/torch_tensorrt/_enums.py
@@ -80,6 +80,12 @@ class dtype(Enum):
:meta hide-value:
"""

f4 = auto()
"""4 bit floating-point number, equivalent to ``dtype.fp4`` and ``dtype.float4``

:meta hide-value:
"""

uint8 = u8
int8 = i8

@@ -91,6 +97,9 @@ class dtype(Enum):
float8 = f8
fp8 = f8

float4 = f4
fp4 = f4

half = f16
fp16 = f16
float16 = f16
@@ -162,6 +171,8 @@ def _from(
return dtype.i32
elif t == torch.float8_e4m3fn:
return dtype.f8
elif t == torch.float4_e2m1fn_x2:
return dtype.f4
elif t == torch.half:
return dtype.f16
elif t == torch.float:
@@ -188,6 +199,8 @@ def _from(
return dtype.i8
elif t == trt.DataType.FP8:
return dtype.f8
elif t == trt.DataType.FP4:
return dtype.fp4
elif t == trt.DataType.INT32:
return dtype.i32
elif t == trt.DataType.INT64:
@@ -357,6 +370,8 @@ def to(
return torch.long
elif self == dtype.f8:
return torch.float8_e4m3fn
elif self == dtype.f4:
return torch.float4_e2m1fn_x2
elif self == dtype.f16:
return torch.half
elif self == dtype.f32:
Expand Down Expand Up @@ -394,6 +409,8 @@ def to(
return trt.DataType.BOOL
elif self == dtype.bf16:
return trt.DataType.BF16
elif self == dtype.f4:
return trt.DataType.FP4
elif use_default:
return trt.DataType.FLOAT
else:
@@ -410,6 +427,8 @@ def to(
return np.int64
elif self == dtype.f16:
return np.float16
elif self == dtype.f4:
return np.float4_e2m1fn_x2
elif self == dtype.f32:
return np.float32
elif self == dtype.f64:
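Taken together, these enum changes give dtype.f4 the same alias and conversion surface as the existing f8 entry. A minimal usage sketch, assuming a PyTorch build that ships torch.float4_e2m1fn_x2 and a TensorRT build that exposes trt.DataType.FP4:

```python
import tensorrt as trt
import torch
from torch_tensorrt import dtype

# The new aliases all resolve to the same enum member.
assert dtype.fp4 is dtype.f4 and dtype.float4 is dtype.f4

# Round-trips mirror the existing f8 handling shown in this diff.
assert dtype._from(torch.float4_e2m1fn_x2) == dtype.f4
assert dtype.f4.to(torch.dtype) == torch.float4_e2m1fn_x2
assert dtype.f4.to(trt.DataType) == trt.DataType.FP4
```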
22 changes: 8 additions & 14 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -250,13 +250,10 @@ def cross_compile_for_windows(
"\nThis feature is unimplemented in Torch-TRT Dynamo currently."
)

if use_explicit_typing:
if len(enabled_precisions) != 1 or not any(
x in enabled_precisions for x in {torch.float32, dtype.f32}
):
raise AssertionError(
f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
)
if use_explicit_typing and len(enabled_precisions) != 1:
raise AssertionError(
f"When use_explicit_typing is enabled, allow only 1 precision in the enabled_precisions but found {enabled_precisions}"
)

if use_fp32_acc:
logger.debug(
@@ -581,13 +578,10 @@ def compile(
"\nThis feature is unimplemented in Torch-TRT Dynamo currently."
)

if use_explicit_typing:
if len(enabled_precisions) != 1 or not any(
x in enabled_precisions for x in {torch.float32, dtype.f32}
):
raise AssertionError(
f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
)
if use_explicit_typing and len(enabled_precisions) != 1:
raise AssertionError(
f"When use_explicit_typing is enabled, allow only 1 precision in the enabled_precisions but found {enabled_precisions}"
)

if use_fp32_acc:
logger.debug(
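With the relaxed validation, strongly typed compilation accepts any single entry in enabled_precisions rather than only torch.float32. A hedged sketch of a call this change is meant to permit; the model and inputs are illustrative only:

```python
import torch
import torch_tensorrt

model = torch.nn.Linear(64, 64).cuda().eval()
inputs = [torch.randn(8, 64).cuda()]

# Previously this raised unless enabled_precisions was exactly {torch.float32};
# now any single precision passes the use_explicit_typing check.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    use_explicit_typing=True,
    enabled_precisions={torch_tensorrt.dtype.f4},
)
```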
9 changes: 8 additions & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
@@ -29,7 +29,14 @@
REQUIRE_FULL_COMPILATION = False
DRYRUN = False
HARDWARE_COMPATIBLE = False
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
SUPPORTED_KERNEL_PRECISIONS = {
dtype.f32,
dtype.f16,
dtype.bf16,
dtype.i8,
dtype.f8,
dtype.f4,
}
TIMING_CACHE_PATH = os.path.join(
tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
)
34 changes: 27 additions & 7 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -39,7 +39,7 @@
CallingConvention,
)
from torch_tensorrt.dynamo.conversion._TRTBuilderMonitor import TRTBulderMonitor
from torch_tensorrt.dynamo.conversion.converter_utils import (
from torch_tensorrt.dynamo.conversion.converter_utils import ( # global_reference_holder,
get_node_io,
get_node_name,
get_trt_tensor,
@@ -274,16 +274,28 @@ def _populate_trt_builder_config(
self.compilation_settings.dla_global_dram_size,
)

if dtype.float16 in self.compilation_settings.enabled_precisions:
if (
not self.compilation_settings.use_explicit_typing
and dtype.float16 in self.compilation_settings.enabled_precisions
):
builder_config.set_flag(trt.BuilderFlag.FP16)

if dtype.int8 in self.compilation_settings.enabled_precisions:
if (
not self.compilation_settings.use_explicit_typing
and dtype.int8 in self.compilation_settings.enabled_precisions
):
builder_config.set_flag(trt.BuilderFlag.INT8)

if dtype.fp8 in self.compilation_settings.enabled_precisions:
if (
not self.compilation_settings.use_explicit_typing
and dtype.fp8 in self.compilation_settings.enabled_precisions
):
builder_config.set_flag(trt.BuilderFlag.FP8)

if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
if (
not self.compilation_settings.use_explicit_typing
and dtype.bfloat16 in self.compilation_settings.enabled_precisions
):
builder_config.set_flag(trt.BuilderFlag.BF16)

if self.compilation_settings.sparse_weights:
@@ -413,7 +425,14 @@ def find_weight(
state_dict: state of the graph module
"""
with unset_fake_temporarily():
network_weight = torch.from_numpy(np_map[weight_name]).to(device)
if isinstance(np_map[weight_name], np.ndarray):
network_weight = torch.from_numpy(np_map[weight_name]).to(device)
elif isinstance(np_map[weight_name], torch.Tensor):
network_weight = np_map[weight_name].to(device)
else:
raise ValueError(
f"Unsupported weight type: {type(np_map[weight_name])}, currently only support numpy.ndarray | torch.Tensor"
)
for sd_w_name, sd_weight in state_dict.items():
if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
del state_dict[sd_w_name]
@@ -737,7 +756,8 @@ def run(
self.ctx.net, builder_config
)
assert serialized_engine

# clear the global reference holder after the engine is built
# global_reference_holder.clear()
_LOGGER.info(
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
)
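The find_weight change means entries in np_map can now be backed by either numpy arrays or torch tensors (packed fp4 weights have no numpy representation). A standalone sketch of the normalization that branch performs, with illustrative names:

```python
import numpy as np
import torch


def to_network_weight(value, device):
    # Mirrors the branch added to find_weight: accept a numpy array or a
    # torch.Tensor and move it to the comparison device.
    if isinstance(value, np.ndarray):
        return torch.from_numpy(value).to(device)
    if isinstance(value, torch.Tensor):
        return value.to(device)
    raise ValueError(
        f"Unsupported weight type: {type(value)}, "
        "currently only numpy.ndarray | torch.Tensor are supported"
    )


# Both entries normalize to a torch.Tensor on the requested device.
print(to_network_weight(np.ones((2, 2), dtype=np.float32), "cpu"))
print(to_network_weight(torch.ones(2, 2), "cpu"))
```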
36 changes: 36 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -617,6 +617,42 @@ def aten_ops_quantize_op(
)


try:
import modelopt.torch.quantization as mtq # noqa: F401

assert torch.ops.tensorrt.dynamic_block_quantize_op.default
except Exception as e:
_LOGGER.warning(
"Unable to import dynamic block quantize op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling dynamic blockquantized models"
)
else:

@dynamo_tensorrt_converter(
torch.ops.tensorrt.dynamic_block_quantize_op.default,
supports_dynamic_shapes=True,
)
def aten_ops_dynamic_block_quantize_op(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.nvfp4_quantize.quantize(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6],
)


@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True)
def aten_ops_squeeze(
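The converter above only registers when modelopt and the custom op are importable. A hedged end-to-end sketch of how a model would reach it; the mtq.NVFP4_DEFAULT_CFG config name and the calibration loop follow the modelopt documentation and are assumptions, not part of this diff:

```python
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda().eval()
sample = torch.randn(8, 64).cuda()


def forward_loop(m):
    # A few representative batches for calibration.
    m(sample)


# Quantizing with an NVFP4 config inserts tensorrt.dynamic_block_quantize_op
# nodes, which the converter registered above lowers to TensorRT FP4 layers.
quantized = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop=forward_loop)
exported = torch.export.export(quantized, (sample,))
trt_model = torch_tensorrt.dynamo.compile(
    exported,
    inputs=[sample],
    use_explicit_typing=True,
    enabled_precisions={torch_tensorrt.dtype.f4},
)
```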
31 changes: 29 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -25,6 +25,9 @@

_LOGGER: logging.Logger = logging.getLogger(__name__)

# global reference holder for the objects that should not be garbage collected before the engine is built
global_reference_holder: List[torch.Tensor] = []


def get_node_name(node: torch.fx.Node) -> str:
# nn_module_stack preserves the call stack of pytorch nn.modules
@@ -361,12 +364,37 @@ def create_constant(
shape = list(torch_value.shape)

if torch_value is not None:
if torch_value.dtype in (torch.float8_e4m3fn, torch.uint8):

# IConstantLayer does not support UINT8; it only supports FP4 data packed into uint8
if torch_value.dtype == torch.uint8:
count = torch_value.numel() * 2
shape[-1] = shape[-1] * 2
dtype = trt.DataType.FP4
else:
count = torch_value.numel()
dtype = trt.DataType.FP8
weights = trt.Weights(
type=dtype,
ptr=torch_value.data_ptr(),
count=count,
)
constant = ctx.net.add_constant(
shape,
weights,
)
constant.name = name
# TODO: confirm with @dheeraj @naren whether I can use ctx.mapping as the reference holder to prevent the torch tensor from being garbage collected.
ctx.mapping[name + " CONSTANT"] = torch_value.reshape(-1)
# if yes, then the following global_reference_holder is no longer needed
# global_reference_holder.append(torch_value)
return constant.get_output(0)

if torch_value.dtype == torch.bfloat16:
torch_value_fp32 = torch_value.to(torch.float32)
numpy_value = torch_value_fp32.numpy()
else:
numpy_value = torch_value.numpy()

ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1)
constant = ctx.net.add_constant(
shape,
@@ -381,7 +409,6 @@
trt.DataType.BF16,
name + "_bf16_cast",
)

return constant.get_output(0)
else:
raise ValueError(
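Because two FP4 values are packed into each uint8 byte, create_constant doubles both the element count and the last dimension relative to the uint8 container. A small worked sketch of that arithmetic:

```python
import torch

packed = torch.zeros(128, 32, dtype=torch.uint8)  # each byte holds two fp4 values

count = packed.numel() * 2       # 4096 bytes -> 8192 fp4 elements
shape = list(packed.shape)
shape[-1] = shape[-1] * 2        # logical fp4 shape becomes [128, 64]

assert count == 8192 and shape == [128, 64]
```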
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -14,6 +14,7 @@
matmul,
nccl_ops,
normalization,
nvfp4_quantize,
pad,
permutation,
pool,