
Commit 452bdb4

AOT plugin: examples with RMSNORM
1 parent 2b867d3 commit 452bdb4

File tree

1 file changed: +281 additions, 0 deletions
@@ -0,0 +1,281 @@
"""
.. _auto_generate_converters:

Automatically Generate a Plugin for a Custom Kernel
===================================================================

We are going to demonstrate how to automatically generate a plugin for a custom kernel with Torch-TensorRT,
using the new Python-based plugin system introduced in TensorRT 10.7.

Torch-TensorRT supports falling back to PyTorch implementations of operations when Torch-TensorRT
does not know how to compile them in TensorRT. However, this comes at the cost of a graph break and will reduce the performance of the model.
The easiest way to address a lack of support for an operator is to add either a decomposition (see:
`Writing lowering passes for the Dynamo frontend <https://pytorch.org/TensorRT/contributors/writing_dynamo_aten_lowering_passes.html>`_), which defines the operator
in terms of PyTorch ops that are supported in Torch-TensorRT, or a converter (see:
`Writing converters for the Dynamo frontend <https://pytorch.org/TensorRT/contributors/dynamo_converters.html>`_), which defines the operator in terms of TensorRT operators.

In some cases there isn't a great way to do either of these, perhaps because the operator is a custom kernel that is not part of standard PyTorch, or
TensorRT cannot support it natively.

For these cases, it is possible to use a TensorRT plugin to replace the operator **inside** the TensorRT engine, thereby avoiding
the performance and resource overhead of a graph break.

Previously, this involved a complex process of not only building a performant kernel but also setting it up to run in TensorRT (see: `Using Custom Kernels within TensorRT Engines with Torch-TensorRT <https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/custom_kernel_plugins.html>`_).
With TensorRT 10.7, there is a new Python-native plugin system which greatly streamlines this process. This
plugin system also allows Torch-TensorRT to automatically generate the necessary conversion code to convert the
operation in PyTorch to TensorRT.
"""

# %%
# Writing Custom Operators in PyTorch
# -----------------------------------------
#
# Previous tutorials already cover creating custom operators in PyTorch which later get used with Torch-TensorRT.
# Here we define an RMSNorm operator backed by a Triton kernel. This operator is then registered as a custom op in PyTorch,
# along with its host launch code as well as a "meta-kernel". A meta-kernel is a function that describes the shape and data type
# transformations that the operator will perform. This meta-kernel is used by Dynamo and Torch-TensorRT, so it
# is necessary to define.
#

from typing import Tuple, Union

import tensorrt as trt
import tensorrt.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl

@triton.jit
def rms_norm_kernel(
    x_ptr,
    w_ptr,
    n,
    x_stride,
    o_stride,
    o_ptr,
    EPS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    # Each program instance normalizes one row of the input.
    i = tl.program_id(axis=0).to(tl.int64)

    x_row = x_ptr + i * x_stride
    o_row = o_ptr + i * o_stride

    # Find the root mean square for the given row.
    square_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, n, BLOCK_SIZE):
        offsets = off + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n

        x = tl.load(x_row + offsets, mask=mask, other=0.0).to(tl.float32)

        square_sum += x * x

    # Compute the norm.
    rms = tl.rsqrt(tl.sum(square_sum) / n + EPS)

    # out[i] = x[i] * rms * weight[i]
    for off in range(0, n, BLOCK_SIZE):
        offsets = off + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n

        x = tl.load(x_row + offsets, mask=mask).to(tl.float32)
        w = tl.load(w_ptr + offsets, mask=mask).to(tl.float32)

        # Normalize in float32, then scale by the weights.
        result = w * (x * rms)

        tl.store(o_row + offsets, result, mask=mask)

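# %%
# The host launch code below wraps the Triton kernel as a PyTorch custom op: it allocates the output tensor,
# computes the launch grid, and invokes ``rms_norm_kernel``. The registered name (``flashinfer::rmsnorm``)
# is what Torch-TensorRT will later match against the TensorRT plugin of the same name.
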
@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=())  # type: ignore[misc]
def flashinfer_rmsnorm(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    # Ensure the tensors are on the GPU
    assert input.is_cuda

    # Create output tensor
    output = torch.empty_like(input)

    # Define block size
    BLOCK_SIZE = 64

    b, n = input.shape

    # Note: rms_norm_kernel processes one row per program. Since n == BLOCK_SIZE in the example below,
    # this grid launches exactly one program per row.
    grid = lambda meta: (triton.cdiv(input.numel(), meta["BLOCK_SIZE"]),)

    rms_norm_kernel[grid](
        input, weight, n, n, n, output, EPS=eps, BLOCK_SIZE=BLOCK_SIZE
    )

    return output

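# %%
# Registering the Plugin with TensorRT
# -------------------------------------------------------------------
#
# Next, the operation is described to TensorRT's Python plugin system. The ``@trtp.register`` decorator below
# registers a plugin under the same ``flashinfer::rmsnorm`` name and returns a tensor descriptor for the
# output; since RMSNorm does not change the shape or dtype of its input, ``input.like()`` is sufficient.
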
@trtp.register("flashinfer::rmsnorm")
def add_plugin_desc(
    input: trtp.TensorDesc, weight: trtp.TensorDesc, eps: float
) -> Tuple[trtp.TensorDesc]:
    return input.like()

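# %%
# Providing an Ahead-of-Time (AOT) Implementation
# -------------------------------------------------------------------
#
# The ``@trtp.aot_impl`` decorator supplies the plugin's ahead-of-time implementation. Here the Triton kernel
# is compiled to PTX with ``triton.compile``, and TensorRT is handed everything it needs to embed and launch
# the kernel inside the engine: the compiled kernel name, the PTX itself, the launch parameters (grid, block,
# and shared memory sizes), and symbolic expressions for the extra scalar kernel arguments
# (``n`` and the row strides).
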
@trtp.aot_impl("flashinfer::rmsnorm")
def flashinfer_rmsnorm(
    input: trtp.TensorDesc,
    weight: trtp.TensorDesc,
    eps: float,
    outputs: Tuple[trtp.TensorDesc],
    tactic: int,
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    assert tactic == 0
    block_size = 64

    type_str = "fp32" if input.dtype == trt.float32 else "fp16"

    # Compile the Triton kernel to PTX for the plugin's input dtype.
    src = triton.compiler.ASTSource(
        fn=rms_norm_kernel,
        signature={
            "x_ptr": f"*{type_str}",
            "w_ptr": f"*{type_str}",
            "n": "i32",
            "x_stride": "i32",
            "o_stride": "i32",
            "o_ptr": f"*{type_str}",
            "EPS": "constexpr",
            "BLOCK_SIZE": "constexpr",
        },
        constants={
            "EPS": eps,
            "BLOCK_SIZE": block_size,
        },
    )

    compiled_kernel = triton.compile(src)
    launch_params = trtp.KernelLaunchParams()

    inp_dims = input.shape_expr
    out_dims = outputs[0].shape_expr

    b = inp_dims[0]
    n = inp_dims[1]

    # grid dims
    launch_params.grid_x = trtp.cdiv(out_dims.numel(), block_size)
    # block dims
    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
    # shared memory
    launch_params.shared_mem = compiled_kernel.metadata.shared

    # Scalar arguments passed to the kernel at runtime: n, x_stride, and o_stride (all equal to the row length).
    extra_args = trtp.SymIntExprs(3)
    extra_args[0] = trtp.SymInt32(n)
    extra_args[1] = trtp.SymInt32(n)
    extra_args[2] = trtp.SymInt32(n)

    return (
        compiled_kernel.metadata.name,
        compiled_kernel.asm["ptx"],
        launch_params,
        extra_args,
    )

# %%
# The meta kernel (fake implementation) for RMSNorm simply mirrors the shape and dtype of the input, since the
# operation does not change the shape of its input.


@torch.library.register_fake("flashinfer::rmsnorm")
def _(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return torch.empty_like(input)

# %%
# Torch-TensorRT also offers an automatic plugin creation feature, which registers the plugin for us through
# the TensorRT QDP APIs. Because we registered the plugin manually with ``trtp.register`` and ``trtp.aot_impl``
# above, the call is left commented out here.
# torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
#     "flashinfer::rmsnorm"
# )

# %%
# Generating the Converter
# -------------------------------------------------------------------
# Given that we have defined the custom operator in PyTorch and TensorRT, we can now generate the converter for the operation.
# As long as the namespace and names match, the following function will automatically generate the converter for the operation.
# Passing ``aot=True`` tells the converter to use the ahead-of-time implementation registered above.


torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "flashinfer::rmsnorm",
    supports_dynamic_shapes=False,
    requires_output_allocator=False,
    aot=True,
)

222+
# # %%
223+
# # Above two commands can be replaced with the following single one line:
224+
# torch_tensorrt.dynamo.conversion.plugins.custom_op("torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True)
225+
226+
227+
# %%
# Using our converter with a model
# -------------------------------------------------------------------
#
# Now we can use our custom operator in a model and compile it with Torch-TensorRT.
# We can see that the custom operator is used as one of the operations in the forward pass of the model.
# The process of compiling the model at this point is identical to standard Torch-TensorRT usage.
class MyModel(torch.nn.Module):  # type: ignore[misc]
    def __init__(self):
        super().__init__()

    def forward(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        res = torch.ops.flashinfer.rmsnorm.default(input, weight)

        return res

# Optionally, the custom op can be cross-checked against the reference implementation in the ``flashinfer``
# library (requires ``flashinfer`` to be installed and imported):
#
# input_tensor = torch.randn(10, 20, device="cuda", dtype=torch.float16)  # 10 samples, 20 features
#
# # Weight tensor (usually learnable parameters)
# weight_tensor = torch.ones(20, device="cuda", dtype=torch.float16)  # Scaling factor for the features
#
# # Small epsilon for numerical stability
# eps = 1e-5
#
# # Apply RMS Normalization using flashinfer
# output_tensor = flashinfer.norm.rmsnorm(input_tensor, weight_tensor, eps)
#
# print(output_tensor)


my_model = MyModel().to("cuda")
m = torch.randn((64, 64), device="cuda", dtype=torch.float16)
n = torch.randn((64,), device="cuda", dtype=torch.float16)


with torch_tensorrt.logging.info():
    model_trt = torch_tensorrt.compile(
        my_model,
        inputs=[m, n],
        debug=True,
        min_block_size=1,
        enabled_precisions={torch.float16},
    )
    res = model_trt(m, n)

print(res)
print(my_model(m, n))

# for i in range(300):
#     res = model_trt(m, n)
#     assert torch.allclose(res, my_model(m, n))
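# %%
# As an optional sanity check, report the maximum absolute difference between the TensorRT result and the
# eager PyTorch result; for fp16 inputs this is typically small (on the order of 1e-2 or less).
print("Max abs difference vs eager PyTorch:", (res - my_model(m, n)).abs().max().item())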
print("Ran with custom plugin!")
