Commit 4928f8b

update
1 parent ea01a78 commit 4928f8b

2 files changed: 49 additions & 84 deletions


docsrc/index.rst

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ Tutorials
 * :ref:`custom_kernel_plugins`
 * :ref:`auto_generate_converters`
 * :ref:`auto_generate_plugins`
+* :ref:`aot_rmsnorm_plugins`
 * :ref:`mutable_torchtrt_module_example`
 * :ref:`weight_streaming_example`
 * :ref:`pre_allocated_output_example`

examples/dynamo/aot_flashinfer_plugin.py renamed to examples/dynamo/aot_rmsnorm_plugins.py

Lines changed: 48 additions & 84 deletions
@@ -1,7 +1,7 @@
 """
-.. _auto_generate_converters:
+.. _aot_rmsnorm_plugins:
 
-Automatically Generate a Plugin for a Custom Kernel
+Automatically Generate an AOT Plugin for a Custom Kernel with TensorRT
 ===================================================================
 
 We are going to demonstrate how to automatically generate a plugin for a custom kernel using Torch-TensorRT using
@@ -21,20 +21,22 @@
 the performance and resource overhead from a graph break.
 
 Previously this involved a complex process in not only building a performant kernel but setting it up to run in TensorRT (see: `Using Custom Kernels within TensorRT Engines with Torch-TensorRT <https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/custom_kernel_plugins.html>`_).
-With TensorRT 10.7, there is a new Python native plugin system which greatly streamlines this process. This
+As of TensorRT 10.7, there is a new Python native plugin system which greatly streamlines this process. This
 plugin system also allows Torch-TensorRT to automatically generate the necessary conversion code to convert the
 operation in PyTorch to TensorRT.
+
+In addition, Torch-TensorRT can automatically generate a TensorRT plugin (see: `Automatically Generate a Plugin for a Custom Kernel <https://docs.pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/auto_generate_plugins.html>`_).
+However, the above method generates a JIT plugin, which might not satisfy a user's performance requirements.
+To address this, Torch-TensorRT also supports automatic generation of a TensorRT AOT plugin, which wraps a function that defines an Ahead-of-Time (AOT) implementation for a plugin that is already registered.
+This provides a performance boost compared to the JIT plugin.
+
 """
 
 # %%
 # Writing Custom Operators in PyTorch
 # -----------------------------------------
 #
-# Pervious tutorials already cover creating custom operators in PyTorch which later get used with Torch-TensorRT.
-# Here we define a simple elementwise multiplication operator in Triton. This operator is then registered as a custom op in PyTorch.
-# with its host launch code as well as a "meta-kernel", A meta-kernel is a function that describes the shape and data type
-# transformations that the operator will perform. This meta-kernel is used by Dynamo and Torch-TensorRT, so it
-# is necessary to define.
+# Here we define a Triton kernel which will later be compiled ahead of time for the TensorRT plugin.
 #
 
 from typing import Tuple, Union
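As a point of reference while reading the kernel changes below, here is a plain-PyTorch sketch of the RMSNorm computation the Triton kernel implements. It is illustrative only and not part of the commit; the helper name `rmsnorm_reference` is made up.

import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm over the last dimension: y = x / sqrt(mean(x^2) + eps) * weight
    rms = torch.sqrt(torch.mean(x.float() * x.float(), dim=-1, keepdim=True) + eps)
    return (x / rms).to(x.dtype) * weight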
@@ -101,27 +103,41 @@ def flashinfer_rmsnorm(
     # Create output tensor
     output = torch.empty_like(input)
 
-    # Define block size
-    BLOCK_SIZE = 64
-
     b, n = input.shape
 
+    BLOCK_SIZE = 256
+
     grid = lambda meta: (triton.cdiv(input.numel(), meta["BLOCK_SIZE"]),)
 
-    rms_norm_kernel[grid](
-        input, weight, n, n, n, output, EPS=eps, BLOCK_SIZE=BLOCK_SIZE
+    num_warps = max(8, min(32, BLOCK_SIZE // 256))
+
+    rms_norm_kernel[(b,)](
+        input,
+        weight,
+        n,
+        n,
+        n,
+        output,
+        EPS=eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
     )
 
     return output
 
 
-@trtp.register("flashinfer::rmsnorm")
-def add_plugin_desc(
-    input: trtp.TensorDesc, weight: trtp.TensorDesc, eps: float
-) -> Tuple[trtp.TensorDesc]:
-    return input.like()
+# Use the Torch-TensorRT automatic plugin generation feature, which generates:
+# 1. A TensorRT JIT plugin
+# 2. A custom op converter
+#
+# Torch-TensorRT will use an AOT plugin if one is registered. If there is no registered AOT plugin,
+# Torch-TensorRT will fall back to the JIT plugin generated by this line.
+torch_tensorrt.dynamo.conversion.plugins.custom_op(
+    "flashinfer::rmsnorm", supports_dynamic_shapes=True, requires_output_allocator=False
+)
 
 
+# TensorRT AOT plugin implementation. If this function is not provided, Torch-TensorRT will fall back to the JIT plugin.
 @trtp.aot_impl("flashinfer::rmsnorm")
 def flashinfer_rmsnorm(
     input: trtp.TensorDesc,
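The one-line `custom_op` registration added above subsumes the explicit two-step registration the previous version of this example used (and which is removed further down in this diff). A sketch of a roughly equivalent explicit form, based on those removed calls (flag-for-flag equivalence is not guaranteed):

import torch_tensorrt

# Generate the TensorRT QDP (JIT) plugin for the already-registered custom op ...
torch_tensorrt.dynamo.conversion.plugins.generate_plugin("flashinfer::rmsnorm")

# ... and generate the converter that maps torch.ops.flashinfer.rmsnorm to that plugin.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "flashinfer::rmsnorm",
    supports_dynamic_shapes=True,
    requires_output_allocator=False,
)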
@@ -133,8 +149,12 @@ def flashinfer_rmsnorm(
     Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
 ]:
     assert tactic == 0
-    block_size = 64
-    # breakpoint()
+
+    inp_dims = input.shape_expr
+
+    b = inp_dims[0]
+    n = inp_dims[1]
+    block_size = 256
 
     type_str = "fp32" if input.dtype == trt.float32 else "fp16"
 
@@ -150,26 +170,16 @@
             "EPS": "constexpr",
             "BLOCK_SIZE": "constexpr",
         },
-        constants={
-            "EPS": eps,
-            "BLOCK_SIZE": block_size,
-        },
+        constants={"EPS": eps, "BLOCK_SIZE": block_size},
     )
 
     compiled_kernel = triton.compile(src)
     launch_params = trtp.KernelLaunchParams()
 
-    inp_dims = input.shape_expr
-    out_dims = outputs[0].shape_expr
-
-    b = inp_dims[0]
-    n = inp_dims[1]
-    # breakpoint()
-
     # grid dims
-    launch_params.grid_x = trtp.cdiv(out_dims.numel(), block_size)
-    # block dims
+    launch_params.grid_x = b
     launch_params.block_x = compiled_kernel.metadata.num_warps * 32
+
     # shared memory
     launch_params.shared_mem = compiled_kernel.metadata.shared
 
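The launch parameters above mirror the eager Triton launch added earlier in this diff: one block per input row (`grid_x = b`) and `num_warps * 32` threads per block. A quick illustrative check of the eager-side numbers; BLOCK_SIZE and the num_warps formula are copied from the updated kernel launch, and this snippet is not part of the commit:

BLOCK_SIZE = 256
num_warps = max(8, min(32, BLOCK_SIZE // 256))  # -> 8 warps
threads_per_block = num_warps * 32              # -> 256 threads per row-block
print(num_warps, threads_per_block)             # 8 256

On the AOT side, the thread count is taken from `compiled_kernel.metadata.num_warps` rather than from this formula, so the two paths need not pick identical configurations.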
@@ -197,35 +207,7 @@ def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:
 
 
 # %%
-# Here we use automatic plugin creation feature in Torch-TensorRT which enables plugin registration using
-# TensorRT QDP APIs
-# torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
-#     "flashinfer::rmsnorm"
-# )
-
-
-# # %%
-# # Generating the Converter
-# # -------------------------------------------------------------------
-# # Given that we have defined the custom operator in PyTorch and TensorRT, we can now generate the converter for the operation.
-# # As long as the namespace and names match, the following function will automatically generate the converter for the operation.
-
-
-torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
-    "flashinfer::rmsnorm",
-    supports_dynamic_shapes=False,
-    requires_output_allocator=False,
-    aot=True,
-)
-
-
-# # %%
-# # Above two commands can be replaced with the following single one line:
-# torch_tensorrt.dynamo.conversion.plugins.custom_op("torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True)
-
-
-# %%
-# Using our converter with a model
+# Using the AOT Plugin within a model
 # -------------------------------------------------------------------
 #
 # Now we can use our custom operator in a model and compile it with Torch-TensorRT.
@@ -236,31 +218,15 @@ def __init__(self):
         super().__init__()
 
     def forward(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
-        # z = torch.add(x, y)
         res = torch.ops.flashinfer.rmsnorm.default(input, weight)
 
         return res
 
 
-# input_tensor = torch.randn(10, 20, device="cuda", dtype=torch.float16)  # 10 samples, 20 features
-
-# # Weight tensor (usually learnable parameters)
-# weight_tensor = torch.ones(20, device = "cuda", dtype=torch.float16)  # Scaling factor for the features
-
-# # Small epsilon for numerical stability
-# eps = 1e-5
-
-# Apply RMS Normalization using flashinfer
-# output_tensor = flashinfer.norm.rmsnorm(input_tensor, weight_tensor, eps)
-
-# print(output_tensor)
-
-
 my_model = MyModel().to("cuda")
 m = torch.randn((64, 64), device="cuda", dtype=torch.float16)
 n = torch.randn((64,), device="cuda", dtype=torch.float16)
 
-
 with torch_tensorrt.logging.info():
     model_trt = torch_tensorrt.compile(
         my_model,
@@ -271,11 +237,9 @@ def forward(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
     )
     res = model_trt(m, n)
 
-    print(res)
-    print(my_model(m, n))
-    # for i in range(300):
-    #     res = model_trt(m, n)
-    #     assert torch.allclose(res, my_model(m, n))
+    for i in range(300):
+        res = model_trt(m, n)
+        assert torch.allclose(res, my_model(m, n))
 
 
-print("Ran with custom plugin!")
+print("Ran with AOT plugin!")
