diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index ca1aa78ef17..bf99c34f1bf 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -8,7 +8,6 @@ from .annotate_quant_attrs import AnnotateQuantAttrs from .annotate_stack import AnnotateStack from .annotate_unbind import AnnotateUnbind -from .convert_bmm_to_matmul import ConvertBmmToMatmul from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d from .convert_square_to_pow import ConvertSquareToPow from .decompose_any import DecomposeAny @@ -45,7 +44,6 @@ AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind, - ConvertBmmToMatmul, ConvertConv1dToConv2d, ConvertSquareToPow, DecomposeAny, diff --git a/backends/qualcomm/_passes/convert_bmm_to_matmul.py b/backends/qualcomm/_passes/convert_bmm_to_matmul.py deleted file mode 100644 index 84e1ff26aa1..00000000000 --- a/backends/qualcomm/_passes/convert_bmm_to_matmul.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import operator -from collections import Counter -from typing import List - -import torch -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions - - -class ConvertBmmToMatmul(ExportPass): - """ - Replace bmm to matmul, because bmm is eqaul to matmul in QNN. - Handle missing quantization tag for bmm op. - """ - - view_copy = exir_ops.edge.aten.view_copy.default - expand_copy = exir_ops.edge.aten.expand_copy.default - clone = exir_ops.edge.aten.clone.default - bmm = exir_ops.edge.aten.bmm.default - matmul = exir_ops.edge.aten.matmul.default - patterns = [ - {expand_copy: 2, view_copy: 3, bmm: 1}, - {expand_copy: 2, view_copy: 3, bmm: 1, clone: 1}, - {bmm: 1}, - ] - - def __init__(self): - super(ConvertBmmToMatmul, self).__init__() - - def _get_ordered_inputs( - self, inputs: List[torch.fx.Node], output: torch.fx.Node - ) -> List[torch.fx.Node]: - bmm_inputs = [] - for arg in output.args: - while arg not in inputs: - arg = arg.args[0] - bmm_inputs.append(arg) - return bmm_inputs - - def call(self, graph_module: torch.fx.GraphModule): - graph = graph_module.graph - partitions = get_source_partitions( - graph, - [operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default], - ) - for _, src_partitions in partitions.items(): - for src_partition in src_partitions: - op_cnt = Counter([n.target for n in src_partition.nodes]) - if op_cnt not in self.patterns: - raise AssertionError( - "Found a new pattern needed be converted to linear op" - ) - - inputs = src_partition.input_nodes - bmm_node = [n for n in src_partition.nodes if n.target == self.bmm][0] - output = src_partition.output_nodes[0] - # the order of src_partition.inputs is not guaranteed. - lhs, rhs = self._get_ordered_inputs(inputs, bmm_node) - with graph_module.graph.inserting_before(output): - # replace bmm to matmul, because bmm is eqaul to matmul in qnn. - matmul_node = graph.create_node( - "call_function", self.matmul, (lhs, rhs) - ) - matmul_node.meta = output.meta - for user in output.users.copy(): - user.replace_input_with(output, matmul_node) - - graph.eliminate_dead_code() - graph_module.recompile() - return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index bb6a4dd0a67..a46f2ac53ce 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -13,7 +13,6 @@ AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind, - ConvertBmmToMatmul, ConvertConv1dToConv2d, ConvertSquareToPow, DecomposeAny, @@ -80,7 +79,6 @@ def get_capture_program_passes(): (AnnotateQuantAttrs, True), (AnnotateStack, True), (AnnotateUnbind, True), - (ConvertBmmToMatmul, True), (ConvertConv1dToConv2d, True), (DecomposeAny, True), (DecomposeColIm, True), diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 898e2d5b1f6..ef52d2c190a 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -64,7 +64,6 @@ def get_passes_dependency_for_capture_program(): AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind, - ConvertBmmToMatmul, ConvertConv1dToConv2d, DecomposeAny, DecomposeColIm, @@ -85,12 +84,10 @@ def get_passes_dependency_for_capture_program(): AnnotateAdaptiveAvgPool1D: [RemoveRedundancy], AnnotateQuantAttrs: [ RecomposePixelUnshuffle, - ConvertBmmToMatmul, RemoveRedundancy, ], AnnotateStack: [RemoveRedundancy], AnnotateUnbind: [RemoveRedundancy], - ConvertBmmToMatmul: [RecomposePixelUnshuffle], DecomposeAny: [RemoveRedundancy], DecomposeColIm: [FoldQDQ], DecomposeLinalgVectorNorm: [RemoveRedundancy], diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py index 02318debfa6..4eee818efe5 100644 --- a/backends/qualcomm/partition/utils.py +++ b/backends/qualcomm/partition/utils.py @@ -50,6 +50,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]: torch.ops.aten.instance_norm.default, torch.ops.aten.leaky_relu.default, torch.ops.aten.linear.default, + torch.ops.aten.matmul.default, torch.ops.aten.pixel_shuffle.default, torch.ops.aten.pixel_unshuffle.default, torch.ops.aten.prelu.default,