
Commit cf104c6

Fixed the lowering pass
1 parent 9f75c42 commit cf104c6

5 files changed (+76, -157 lines)

examples/apps/flux_demo.py

Lines changed: 0 additions & 10 deletions
@@ -44,11 +44,6 @@ def compile_model(
         torch_dtype=torch.float16,
     ).to(torch.float16)
 
-    if args.debug:
-        pipe.transformer = FluxTransformer2DModel(
-            num_layers=1, num_single_layers=1, guidance_embeds=True
-        ).to(torch.float16)
-
     if args.low_vram_mode:
         pipe.enable_model_cpu_offload()
     else:
@@ -266,10 +261,5 @@ def main(args):
         action="store_true",
         help="Use dynamic shapes",
     )
-    parser.add_argument(
-        "--debug",
-        action="store_true",
-        help="Use debug mode",
-    )
     args = parser.parse_args()
     main(args)
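
The deleted --debug path swapped in a truncated transformer so the example could be iterated on quickly. A minimal, standalone sketch of that pattern, kept here for reference: the constructor arguments come verbatim from the deleted lines, while the helper name is hypothetical and `pipe` is assumed to be an already-built diffusers FluxPipeline.

import torch
from diffusers import FluxTransformer2DModel

def use_tiny_transformer(pipe) -> None:
    # Debug-only shortcut: replace the full FLUX transformer with a 1-layer one
    # so compilation and inference finish quickly during local iteration.
    pipe.transformer = FluxTransformer2DModel(
        num_layers=1, num_single_layers=1, guidance_embeds=True
    ).to(torch.float16)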

examples/dynamo/register_sdpa.py

Lines changed: 75 additions & 140 deletions
@@ -32,153 +32,88 @@
 
 
 @_aten_lowering_pass
-def lower_scaled_dot_product_attention(
+def replace_variants_of_sdpa(
     gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
-    """Replace specific versions of scaled_dot_product_attention with an equivalent
-    implementation which can be easily converted to TRT
+    """Replace scaled_dot_product_attention with an equivalent
+    implementation which can be accurately converted to TRT
     """
-    original_fns, replacement = scaled_dot_product_attention_replacement()
-    replaced_nodes = []
-    # For each original function, search for it in the graph and replace
-    for original in original_fns:
-        replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters(
-            gm,
-            original,
-            replacement,
-            ignore_literals=True,
-        )
-
-    if replaced_nodes:
-        # Repair instances which use the kwargs field (specifically the "scale" kwarg)
-        # Also repair instances which specified the is_causal or attn_bias fields
-        for match in replaced_nodes:
-            attention_node_replaced = None
-            # Seek the attention operator being replaced
-            for node in match.nodes_map:
-                if node.target in REPLACEABLE_ATEN_OPS:
-                    attention_node_replaced = match.nodes_map[node]
-                    break
-
-            assert attention_node_replaced is not None
-            assert len(match.replacements) == 1
-
-            new_attention_node = match.replacements[0]
-
-            assert (
-                new_attention_node.target
-                == torch.nn.functional.scaled_dot_product_attention
-            )
-
-            # Copy the metadata of the replaced attention node to the new node
-            # TODO: Investigate why there are multiple FakeTensors in the metadata.
-            # We only use the first one as it contains the output shape information for this node.
-            if "val" in attention_node_replaced.meta:
-                new_attention_node.meta["val"] = copy.copy(
-                    attention_node_replaced.meta["val"][0]
-                )
-
-            # If the attention operator had keyword-args, copy them to the new node
-            if attention_node_replaced.kwargs:
-                new_attention_node.kwargs = {**attention_node_replaced.kwargs}
-
-            # Set default args in new node:
-            # Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False
-            new_attention_node.args = new_attention_node.args + (None, 0.0, False)
-
-            # The `is_causal` argument was specified
+    attn_mask = None
+    is_causal = True
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
             if (
-                (
-                    attention_node_replaced.target
-                    == torch.ops.aten._scaled_dot_product_flash_attention.default
-                )
-                and args_bounds_check(attention_node_replaced.args, 4, False)
-            ) or (
-                (
-                    attention_node_replaced.target
-                    == torch.ops.aten._scaled_dot_product_efficient_attention.default
-                )
-                and args_bounds_check(attention_node_replaced.args, 6, False)
+                node.target
+                == torch.ops.aten._scaled_dot_product_efficient_attention.default
+            ):
+                if len(node.args) == 7:
+                    (
+                        query,
+                        key,
+                        value,
+                        attn_bias,
+                        compute_log_sumexp,
+                        dropout_p,
+                        is_causal,
+                    ) = node.args
+                elif len(node.args) == 5:
+                    query, key, value, attn_mask, is_causal = node.args
+                    dropout_p = 0.0
+                else:
+                    raise ValueError(
+                        f"Unexpected number of arguments for {node.target} in the graph"
+                    )
+            elif (
+                node.target
+                == torch.ops.aten._scaled_dot_product_flash_attention.default
             ):
-                new_attention_node.args = (
-                    new_attention_node.args[:5] + (True,) + new_attention_node.args[6:]
+                if len(node.args) == 6:
+                    query, key, value, dropout_p, is_causal, return_debug_mask = (
+                        node.args
+                    )
+                elif len(node.args) == 3:
+                    query, key, value = node.args
+                    dropout_p = 0.0
+                    is_causal = True
+                else:
+                    raise ValueError(
+                        f"Unexpected number of arguments for {node.target} in the graph"
+                    )
+            if attn_mask is not None:
+                logger.warning(
+                    f"This current version of SDPA converter does not support attn_mask for {node.target} in the graph. Ignoring it and using is_causal=True configuration."
                 )
 
-            # The `attn_bias` argument was specified
-            if (
-                attention_node_replaced.target
-                == torch.ops.aten._scaled_dot_product_efficient_attention.default
-            ) and args_bounds_check(attention_node_replaced.args, 3) is not None:
-                new_attention_node.args = (
-                    new_attention_node.args[:3]
-                    + attention_node_replaced.args[3]
-                    + new_attention_node.args[4:]
-                )
+            modified_input_args = (query, key, value, None, dropout_p, is_causal)
 
-    gm = clean_up_graph_after_modifications(gm)
-    logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")
+            # Create a new node with torch.nn.functional.scaled_dot_product_attention
+            # The input args is (query, key, value, is_causal). kwargs has scale
+            with gm.graph.inserting_after(node):
+                new_node = gm.graph.call_function(
+                    torch.nn.functional.scaled_dot_product_attention,
+                    args=modified_input_args,
+                    kwargs={"scale": node.kwargs.get("scale", None)},
+                )
 
+            # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
+            new_node.meta = copy.copy(node.meta)
+            # Check if there's a getitem node following this attention node
+            for user in list(node.users):
+                if user.op == "call_function" and user.target == operator.getitem:
+                    # If the getitem is extracting the first element (the output tensor)
+                    if user.args[1] == 0:
+                        # Replace all uses of the getitem with the new attention node
+                        user.replace_all_uses_with(new_node)
+                        new_node.meta["val"] = new_node.meta["val"][0]
+            # Replace all uses of the original node with the new node
+            node.replace_all_uses_with(new_node)
+
+            gm.graph.erase_node(node)
+
+    # Clean up the graph
+    clean_up_graph_after_modifications(gm)
+
+    logger.info(
+        "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
+    )
     return gm
-
-
-def scaled_dot_product_attention_replacement() -> Tuple[
-    Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
-    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-]:
-    """Constructs the original and replacement functions for efficient attention"""
-
-    # Efficient Attention original graph
-    def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
-        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
-            q,
-            k,
-            v,
-            None,
-            False,
-        )
-        out = operator.getitem(outputs, 0)
-        return out
-
-    # Flash Attention original graph
-    def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
-        outputs = torch.ops.aten._scaled_dot_product_flash_attention.default(
-            q,
-            k,
-            v,
-        )
-        out = operator.getitem(outputs, 0)
-        return out
-
-    # Efficient Attention w/Scale original graph
-    def efficient_scale(
-        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
-    ) -> torch.Tensor:
-        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
-            q,
-            k,
-            v,
-            None,
-            False,
-            scale=1.0,
-        )
-        out = operator.getitem(outputs, 0)
-        return out
-
-    # Flash Attention w/Scale original graph
-    def flash_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
-        outputs = torch.ops.aten._scaled_dot_product_flash_attention.default(
-            q,
-            k,
-            v,
-            scale=1.0,
-        )
-        out = operator.getitem(outputs, 0)
-        return out
-
-    # Replacement graph consists of the functional version of scaled_dot_product_attention
-    def replacement(
-        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> torch.Tensor:
-        return torch.nn.functional.scaled_dot_product_attention(query, key, value)
-
-    return (efficient, flash, efficient_scale, flash_scale), replacement
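
The rewritten pass above walks gm.graph.nodes directly instead of relying on the subgraph rewriter. Below is a minimal, standalone sketch of that FX graph-surgery pattern (insert a replacement node after the original, rewire its users, erase the old node); the relu-to-gelu swap is purely illustrative and is not part of the Torch-TensorRT pass.

import torch
import torch.fx
import torch.nn.functional as F


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu(x) + 1.0


gm = torch.fx.symbolic_trace(Toy())

for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is F.relu:
        # Insert the replacement right after the node being rewritten,
        # mirroring gm.graph.inserting_after(node) in the pass above.
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_function(F.gelu, args=node.args)
        new_node.meta = dict(node.meta)       # carry over node metadata
        node.replace_all_uses_with(new_node)  # rewire every consumer
        gm.graph.erase_node(node)             # drop the original op

gm.graph.lint()
gm.recompile()
print(gm.code)  # forward now calls gelu instead of relu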

py/torch_tensorrt/dynamo/conversion/impl/quantize.py

Lines changed: 0 additions & 1 deletion
@@ -75,7 +75,6 @@ def quantize(
         dtype = trt.DataType.FP8
         max_bound = 448
 
-
     axis = None
     # int8 weight quantization is per-channel quantization(it can have one or multiple amax values)
    if dtype == trt.DataType.INT8 and amax.numel() > 1:

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         # Set of known quantization ops to be excluded from constant folding.
         # Currently, we exclude all quantization ops coming from modelopt library.
-        self.quantization_ops = set()
+        self.quantization_ops: Set[torch._ops.OpOverload] = set()
         try:
             # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
             import modelopt.torch.quantization as mtq
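
The new annotation presumes Set is imported from typing at the top of constant_folding.py (the import sits outside this hunk). A minimal sketch of the assumed usage, with an arbitrary example op:

from typing import Set

import torch

quantization_ops: Set[torch._ops.OpOverload] = set()
quantization_ops.add(torch.ops.aten.relu.default)  # any OpOverload can be a member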

tools/perf/Flux/flux_perf.py

Lines changed: 0 additions & 5 deletions
@@ -56,11 +56,6 @@ def main(args):
         action="store_true",
         help="Use dynamic shapes",
     )
-    parser.add_argument(
-        "--debug",
-        action="store_true",
-        help="Use debug mode",
-    )
     parser.add_argument("--max_batch_size", type=int, default=1)
     args = parser.parse_args()
     main(args)
