|
| 1 | +""" |
| 2 | +.. _auto_generate_converters: |
| 3 | +
|
| 4 | +Automatically Generate a Plugin for a Custom Kernel |
| 5 | +=================================================================== |
| 6 | +
|
| 7 | +We are going to demonstrate how to automatically generate a plugin for a custom kernel using Torch-TensorRT using |
| 8 | +the new Python based plugin system in TensorRT 10.7. |
| 9 | +
|
| 10 | +Torch-TensorRT supports falling back to PyTorch implementations of operations in the case that Torch-TensorRT |
| 11 | +does not know how to compile them in TensorRT. However, this comes at the cost of a graph break and will reduce the performance of the model. |
| 12 | +The easiest way to fix lack of support for ops is by adding a decomposition (see: |
| 13 | +`Writing lowering passes for the Dynamo frontend <https://pytorch.org/TensorRT/contributors/writing_dynamo_aten_lowering_passes.html>`_) - which defines the operator |
| 14 | +in terms of PyTorch ops that are supported in Torch-TensorRT or a converter (see: |
| 15 | +`Writing converters for the Dynamo frontend <https://pytorch.org/TensorRT/contributors/dynamo_converters.html>`_) - which defines the operator in terms of TensorRT operators. |
| 16 | +
|
| 17 | +In some cases there isn't a great way to do either of these, perhaps because the operator is a custom kernel that is not part of standard PyTorch or |
| 18 | +TensorRT cannot support it natively. |
| 19 | +
|
| 20 | +For these cases, it is possible to use a TensorRT plugin to replace the operator **inside** the TensorRT engine, thereby avoiding |
| 21 | +the performance and resource overhead from a graph break. |
| 22 | +
|
| 23 | +Previously this involved a complex process in not only building a performant kernel but setting it up to run in TensorRT (see: `Using Custom Kernels within TensorRT Engines with Torch-TensorRT <https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/custom_kernel_plugins.html>`_). |
| 24 | +With TensorRT 10.7, there is a new Python native plugin system which greatly streamlines this process. This |
| 25 | +plugin system also allows Torch-TensorRT to automatically generate the necessary conversion code to convert the |
| 26 | +operation in PyTorch to TensorRT. |
| 27 | +""" |
| 28 | + |
| 29 | +# %% |
| 30 | +# Writing Custom Operators in PyTorch |
| 31 | +# ----------------------------------------- |
| 32 | +# |
| 33 | +# Pervious tutorials already cover creating custom operators in PyTorch which later get used with Torch-TensorRT. |
| 34 | +# Here we define a simple elementwise multiplication operator in Triton. This operator is then registered as a custom op in PyTorch. |
| 35 | +# with its host launch code as well as a "meta-kernel", A meta-kernel is a function that describes the shape and data type |
| 36 | +# transformations that the operator will perform. This meta-kernel is used by Dynamo and Torch-TensorRT, so it |
| 37 | +# is necessary to define. |
| 38 | +# |
| 39 | + |
| 40 | +from typing import Tuple, Union |
| 41 | + |
| 42 | +import tensorrt as trt |
| 43 | +import tensorrt.plugin as trtp |
| 44 | +import torch |
| 45 | +import torch_tensorrt |
| 46 | +import triton |
| 47 | +import triton.language as tl |
| 48 | + |
| 49 | + |
| 50 | +@triton.jit |
| 51 | +def rms_norm_kernel( |
| 52 | + x_ptr, |
| 53 | + w_ptr, |
| 54 | + n, |
| 55 | + x_stride, |
| 56 | + o_stride, |
| 57 | + o_ptr, |
| 58 | + EPS: tl.constexpr, |
| 59 | + BLOCK_SIZE: tl.constexpr, |
| 60 | +) -> None: |
| 61 | + i = tl.program_id(axis=0).to(tl.int64) |
| 62 | + |
| 63 | + x_row = x_ptr + i * x_stride |
| 64 | + o_row = o_ptr + i * o_stride |
| 65 | + |
| 66 | + # Find the root mean square for the given row. |
| 67 | + square_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32) |
| 68 | + for off in range(0, n, BLOCK_SIZE): |
| 69 | + offsets = off + tl.arange(0, BLOCK_SIZE) |
| 70 | + mask = offsets < n |
| 71 | + |
| 72 | + x = tl.load(x_row + offsets, mask=mask, other=0.0).to(tl.float32) |
| 73 | + |
| 74 | + square_sum += x * x |
| 75 | + |
| 76 | + # Compute the norm. |
| 77 | + rms = tl.rsqrt(tl.sum(square_sum) / n + EPS) |
| 78 | + |
| 79 | + # x[i] = r[i] + x[i] / rms * weight[i] |
| 80 | + for off in range(0, n, BLOCK_SIZE): |
| 81 | + offsets = off + tl.arange(0, BLOCK_SIZE) |
| 82 | + mask = offsets < n |
| 83 | + |
| 84 | + x = tl.load(x_row + offsets, mask=mask).to(tl.float32) |
| 85 | + w = tl.load(w_ptr + offsets, mask=mask).to(tl.float32) |
| 86 | + |
| 87 | + # Multiply x with RMS on float32, but cast to the narrower type before |
| 88 | + # multiplying with the weights to replicate the HF behaviour precisely. |
| 89 | + result = w * (x * rms) |
| 90 | + |
| 91 | + tl.store(o_row + offsets, result, mask=mask) |
| 92 | + |
| 93 | + |
| 94 | +@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=()) # type: ignore[misc] |
| 95 | +def flashinfer_rmsnorm( |
| 96 | + input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 |
| 97 | +) -> torch.Tensor: |
| 98 | + # Ensure the tensors are on the GPU |
| 99 | + assert input.is_cuda |
| 100 | + |
| 101 | + # Create output tensor |
| 102 | + output = torch.empty_like(input) |
| 103 | + |
| 104 | + # Define block size |
| 105 | + BLOCK_SIZE = 64 |
| 106 | + |
| 107 | + b, n = input.shape |
| 108 | + |
| 109 | + grid = lambda meta: (triton.cdiv(input.numel(), meta["BLOCK_SIZE"]),) |
| 110 | + |
| 111 | + rms_norm_kernel[grid]( |
| 112 | + input, weight, n, n, n, output, EPS=eps, BLOCK_SIZE=BLOCK_SIZE |
| 113 | + ) |
| 114 | + |
| 115 | + return output |
| 116 | + |
| 117 | + |
| 118 | +@trtp.register("flashinfer::rmsnorm") |
| 119 | +def add_plugin_desc( |
| 120 | + input: trtp.TensorDesc, weight: trtp.TensorDesc, eps: float |
| 121 | +) -> Tuple[trtp.TensorDesc]: |
| 122 | + return input.like() |
| 123 | + |
| 124 | + |
| 125 | +@trtp.aot_impl("flashinfer::rmsnorm") |
| 126 | +def flashinfer_rmsnorm( |
| 127 | + input: trtp.TensorDesc, |
| 128 | + weight: trtp.TensorDesc, |
| 129 | + eps: float, |
| 130 | + outputs: Tuple[trtp.TensorDesc], |
| 131 | + tactic: int, |
| 132 | +) -> Tuple[ |
| 133 | + Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs |
| 134 | +]: |
| 135 | + assert tactic == 0 |
| 136 | + block_size = 64 |
| 137 | + # breakpoint() |
| 138 | + |
| 139 | + type_str = "fp32" if input.dtype == trt.float32 else "fp16" |
| 140 | + |
| 141 | + src = triton.compiler.ASTSource( |
| 142 | + fn=rms_norm_kernel, |
| 143 | + signature={ |
| 144 | + "x_ptr": f"*{type_str}", |
| 145 | + "w_ptr": f"*{type_str}", |
| 146 | + "n": "i32", |
| 147 | + "x_stride": "i32", |
| 148 | + "o_stride": "i32", |
| 149 | + "o_ptr": f"*{type_str}", |
| 150 | + "EPS": "constexpr", |
| 151 | + "BLOCK_SIZE": "constexpr", |
| 152 | + }, |
| 153 | + constants={ |
| 154 | + "EPS": eps, |
| 155 | + "BLOCK_SIZE": block_size, |
| 156 | + }, |
| 157 | + ) |
| 158 | + |
| 159 | + compiled_kernel = triton.compile(src) |
| 160 | + launch_params = trtp.KernelLaunchParams() |
| 161 | + |
| 162 | + inp_dims = input.shape_expr |
| 163 | + out_dims = outputs[0].shape_expr |
| 164 | + |
| 165 | + b = inp_dims[0] |
| 166 | + n = inp_dims[1] |
| 167 | + # breakpoint() |
| 168 | + |
| 169 | + # grid dims |
| 170 | + launch_params.grid_x = trtp.cdiv(out_dims.numel(), block_size) |
| 171 | + # block dims |
| 172 | + launch_params.block_x = compiled_kernel.metadata.num_warps * 32 |
| 173 | + # shared memory |
| 174 | + launch_params.shared_mem = compiled_kernel.metadata.shared |
| 175 | + |
| 176 | + extra_args = trtp.SymIntExprs(3) |
| 177 | + extra_args[0] = trtp.SymInt32(n) |
| 178 | + extra_args[1] = trtp.SymInt32(n) |
| 179 | + extra_args[2] = trtp.SymInt32(n) |
| 180 | + |
| 181 | + return ( |
| 182 | + compiled_kernel.metadata.name, |
| 183 | + compiled_kernel.asm["ptx"], |
| 184 | + launch_params, |
| 185 | + extra_args, |
| 186 | + ) |
| 187 | + |
| 188 | + |
| 189 | +# %% |
| 190 | +# The meta kernel for an elementwise operation is just the shape and dtype of one of the inputs since we will not change the shape |
| 191 | +# in the course of the operation. |
| 192 | + |
| 193 | + |
| 194 | +@torch.library.register_fake("flashinfer::rmsnorm") |
| 195 | +def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor: |
| 196 | + return input |
| 197 | + |
| 198 | + |
| 199 | +# %% |
| 200 | +# Here we use automatic plugin creation feature in Torch-TensorRT which enables plugin registration using |
| 201 | +# TensorRT QDP APIs |
| 202 | +# torch_tensorrt.dynamo.conversion.plugins.generate_plugin( |
| 203 | +# "flashinfer::rmsnorm" |
| 204 | +# ) |
| 205 | + |
| 206 | + |
| 207 | +# # %% |
| 208 | +# # Generating the Converter |
| 209 | +# # ------------------------------------------------------------------- |
| 210 | +# # Given that we have defined the custom operator in PyTorch and TensorRT, we can now generate the converter for the operation. |
| 211 | +# # As long as the namespace and names match, the following function will automatically generate the converter for the operation. |
| 212 | + |
| 213 | + |
| 214 | +torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter( |
| 215 | + "flashinfer::rmsnorm", |
| 216 | + supports_dynamic_shapes=False, |
| 217 | + requires_output_allocator=False, |
| 218 | + aot=True, |
| 219 | +) |
| 220 | + |
| 221 | + |
| 222 | +# # %% |
| 223 | +# # Above two commands can be replaced with the following single one line: |
| 224 | +# torch_tensorrt.dynamo.conversion.plugins.custom_op("torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True) |
| 225 | + |
| 226 | + |
| 227 | +# %% |
| 228 | +# Using our converter with a model |
| 229 | +# ------------------------------------------------------------------- |
| 230 | +# |
| 231 | +# Now we can use our custom operator in a model and compile it with Torch-TensorRT. |
| 232 | +# We can see that the custom operator is used as one of the operations in the forward pass of the model. |
| 233 | +# The process of compiling the model at this point is identical to standard Torch-TensorRT usage. |
| 234 | +class MyModel(torch.nn.Module): # type: ignore[misc] |
| 235 | + def __init__(self): |
| 236 | + super().__init__() |
| 237 | + |
| 238 | + def forward(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: |
| 239 | + # z = torch.add(x, y) |
| 240 | + res = torch.ops.flashinfer.rmsnorm.default(input, weight) |
| 241 | + |
| 242 | + return res |
| 243 | + |
| 244 | + |
| 245 | +# input_tensor = torch.randn(10, 20, device="cuda", dtype=torch.float16) # 10 samples, 20 features |
| 246 | + |
| 247 | +# # Weight tensor (usually learnable parameters) |
| 248 | +# weight_tensor = torch.ones(20, device = "cuda", dtype=torch.float16) # Scaling factor for the features |
| 249 | + |
| 250 | +# # Small epsilon for numerical stability |
| 251 | +# eps = 1e-5 |
| 252 | + |
| 253 | +# Apply RMS Normalization using flashinfer |
| 254 | +# output_tensor = flashinfer.norm.rmsnorm(input_tensor, weight_tensor, eps) |
| 255 | + |
| 256 | +# print(output_tensor) |
| 257 | + |
| 258 | + |
| 259 | +my_model = MyModel().to("cuda") |
| 260 | +m = torch.randn((64, 64), device="cuda", dtype=torch.float16) |
| 261 | +n = torch.randn((64,), device="cuda", dtype=torch.float16) |
| 262 | + |
| 263 | + |
| 264 | +with torch_tensorrt.logging.info(): |
| 265 | + model_trt = torch_tensorrt.compile( |
| 266 | + my_model, |
| 267 | + inputs=[m, n], |
| 268 | + debug=True, |
| 269 | + min_block_size=1, |
| 270 | + enabled_precisions={torch.float16}, |
| 271 | + ) |
| 272 | + res = model_trt(m, n) |
| 273 | + |
| 274 | + print(res) |
| 275 | + print(my_model(m, n)) |
| 276 | + # for i in range(300): |
| 277 | + # res = model_trt(m, n) |
| 278 | + # assert torch.allclose(res, my_model(m, n)) |
| 279 | + |
| 280 | + |
| 281 | +print("Ran with custom plugin!") |
0 commit comments