Skip to content

Qualcomm AI Engine Direct - multi-method support in to_edge_transform_and_lower_to_qnn #11436

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 25 additions & 41 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
generate_htp_compiler_spec,
generate_qnn_executorch_compiler_spec,
PyQnnManagerAdaptor,
QnnPartitioner,
rewrite_prepared_observer,
skip_annotation,
to_edge_transform_and_lower_to_qnn,
Expand Down Expand Up @@ -89,12 +88,8 @@
from executorch.examples.models.torchvision_vit.model import TorchVisionViTModel

from executorch.examples.models.wav2letter import Wav2LetterModel
from executorch.exir import EdgeProgramManager, to_edge
from executorch.exir.backend.backend_api import (
disable_validation,
MethodProgramsPartitionerSpec,
to_backend,
)
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import disable_validation


class TestQNNFloatingPointOperator(TestQNN):
Expand Down Expand Up @@ -2701,22 +2696,18 @@ def test_qnn_backend_multi_graphs(self):
)
for graph_name in graph_names
]
# TODO: retire capture_program once we figure out how to extract
# intermediate graph from official lowering API
edge_progs = {
graph_name: capture_program(module, sample_input).exported_program
for graph_name, module, sample_input in zip(
graph_names, modules, sample_inputs
)
}
partitioners = {
graph_name: QnnPartitioner(compiler_spec)
for graph_name, compiler_spec in zip(graph_names, compiler_specs)
}
lowered_ep_dict = to_backend(
MethodProgramsPartitionerSpec(edge_progs, partitioners)

modules_dict = {}
sample_inputs_dict = {}
compiler_specs_dict = {}
for i, graph_name in enumerate(graph_names):
modules_dict[graph_name] = modules[i]
sample_inputs_dict[graph_name] = sample_inputs[i]
compiler_specs_dict[graph_name] = compiler_specs[i]
delegated_program = to_edge_transform_and_lower_to_qnn(
modules_dict, sample_inputs_dict, compiler_specs_dict
)
executorch_prog = EdgeProgramManager(lowered_ep_dict).to_executorch()
executorch_prog = delegated_program.to_executorch()
for index, module in enumerate(modules):
self.verify_output(
module=module,
Expand Down Expand Up @@ -3375,28 +3366,21 @@ def test_qnn_backend_multi_graphs(self):
)
for graph_name in graph_names
]
# TODO: retire capture_program once we figure out how to extract
# intermediate graph from official lowering API
for i, module in enumerate(modules):
module_exported = torch.export.export(module, sample_inputs[i]).module()
modules_dict = {}
sample_inputs_dict = {}
compiler_specs_dict = {}
for i, graph_name in enumerate(graph_names):
module_exported = torch.export.export(modules[i], sample_inputs[i]).module()
module_prepared = prepare_pt2e(module_exported, make_quantizer())
module_prepared(*sample_inputs[i])
modules[i] = convert_pt2e(module_prepared)

edge_progs = {
graph_name: capture_program(module, sample_input).exported_program
for graph_name, module, sample_input in zip(
graph_names, modules, sample_inputs
)
}
partitioners = {
graph_name: QnnPartitioner(compiler_spec)
for graph_name, compiler_spec in zip(graph_names, compiler_specs)
}
lowered_ep_dict = to_backend(
MethodProgramsPartitionerSpec(edge_progs, partitioners)
modules_dict[graph_name] = convert_pt2e(module_prepared)
sample_inputs_dict[graph_name] = sample_inputs[i]
compiler_specs_dict[graph_name] = compiler_specs[i]
delegated_program = to_edge_transform_and_lower_to_qnn(
modules_dict, sample_inputs_dict, compiler_specs_dict
)
executorch_prog = EdgeProgramManager(lowered_ep_dict).to_executorch()

executorch_prog = delegated_program.to_executorch()
for index, module in enumerate(modules):
self.verify_output(
module=module,
Expand Down
142 changes: 104 additions & 38 deletions backends/qualcomm/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,60 +317,126 @@ def get_decomp_table(passes_job) -> Dict[torch._ops.OperatorBase, Callable]:


def to_edge_transform_and_lower_to_qnn(
module: Union[torch.nn.Module, torch.fx.GraphModule],
inputs: Tuple[torch.Tensor],
compiler_specs: List[CompileSpec],
module: Union[
torch.nn.Module,
torch.fx.GraphModule,
Dict[str, torch.nn.Module],
Dict[str, torch.fx.GraphModule],
],
inputs: Union[Tuple[torch.Tensor], Dict[str, Tuple[torch.Tensor]]],
compiler_specs: Union[List[Any], Dict[str, List[Any]]],
constant_methods: Optional[Dict[str, Any]] = None,
dynamic_shapes: Optional[Dict] = None,
dep_table: Optional[Dict] = None,
passes_job: Optional[OrderedDict] = None,
passes_job: Optional[Union[OrderedDict, Dict[str, OrderedDict]]] = None,
skip_node_id_set: Optional[set] = None,
skip_node_op_set: Optional[set] = None,
) -> EdgeProgramManager:
"""
Transforms and lowers a given PyTorch module to QNN backend.
Transforms and lowers a given PyTorch module to the QNN backend.

Args:
module (Union[torch.nn.Module, torch.fx.GraphModule]): The PyTorch module or fx.GraphModule to be transformed.
inputs (Tuple[torch.Tensor]): The input tensors for the module.
compiler_specs (List[CompileSpec]): Compiler specs for Qualcomm AI Engine Direct.
constant_methods (Optional[Dict[str, Any]]): An optional dictionary of method name to the constant value
returned by that method in eager mode. Often used to store config information on
Edge models.
dynamic_shapes (Optional[Dict]): Information about dynamic shapes.
dep_table (Optional[Dict]): Dependency table for the transformation passes.
passes_job (Optional[OrderedDict]): Ordered dictionary of transformation passes.
skip_node_id_set (Optional[set]): Set of node IDs to skip during partitioning.
skip_node_op_set (Optional[set]): Set of node operations to skip during partitioning.
module (Union[torch.nn.Module, torch.fx.GraphModule, Dict[str, torch.nn.Module], Dict[str, torch.fx.GraphModule]]):
The PyTorch module or fx.GraphModule to be transformed.
inputs (Union[Tuple[torch.Tensor], Dict[str, Tuple[torch.Tensor]]]):
The input tensors for the module.
compiler_specs (Union[List[Any], Dict[str, List[Any]]]):
Compiler specifications for Qualcomm AI Engine Direct.
constant_methods (Optional[Dict[str, Any]]):
An optional dictionary mapping method names to constant values returned by those methods in eager mode.
Often used to store configuration information on Edge models.
dynamic_shapes (Optional[Dict]):
Information about dynamic shapes.
dep_table (Optional[Dict]):
Dependency table for the transformation passes.
passes_job (Optional[Union[OrderedDict, Dict[str, OrderedDict]]]):
Ordered dictionary of transformation passes.
skip_node_id_set (Optional[set]):
Set of node IDs to skip during partitioning.
skip_node_op_set (Optional[set]):
Set of node operations to skip during partitioning.

Returns:
EdgeProgramManager: The manager for the edge program after transformation and lowering.
EdgeProgramManager:
The manager for the edge program after transformation and lowering.
"""
ep = torch.export.export(module, inputs, dynamic_shapes=dynamic_shapes, strict=True)
# This transformation is primarily intended for the LiftConstantScalarOperands pass
# to avoid creating temporary tensors in the operation builder.
# However, this pass will create a get_attr node, which should be converted
# into a lifted tensor constant by the lift_constant_tensor_pass.
# If placed in the to_edge_transform_passes, it will be executed
# after the lift_constant_tensor_pass, causing the operation builder
# to fail to correctly retrieve the parameter by the get_parameter.
ep = QnnPassManager().transform_for_export_pipeline(ep)
transform_passes = QnnPassManager().get_to_edge_transform_passes(
ep, passes_job=passes_job, dep_table=dep_table
)
qnn_partitioner = QnnPartitioner(
compiler_specs,
skip_node_id_set=skip_node_id_set,
skip_node_op_set=skip_node_op_set,
)
edge_program_manager = to_edge_transform_and_lower(
ep,

def ensure_graph_specific_dict(value, graph_names):
    """Normalize *value* into a per-graph dictionary keyed by *graph_names*.

    Three cases are handled:
      * ``None``            -> every graph maps to ``None``
      * already per-graph   -> a dict whose keys equal *graph_names* is
                               returned unchanged
      * anything else       -> the same value (including a dict with
                               non-matching keys) is broadcast to every graph

    Examples:
        >>> ensure_graph_specific_dict(None, ["forward1", "forward2"])
        {'forward1': None, 'forward2': None}
        >>> ensure_graph_specific_dict(spec, ["forward1", "forward2"])
        {'forward1': spec, 'forward2': spec}
        >>> ensure_graph_specific_dict({"other_key": spec}, ["forward1"])
        {'forward1': {'other_key': spec}}
    """
    if value is None:
        # No value supplied at all: give every graph an explicit None entry.
        return dict.fromkeys(graph_names)
    is_per_graph = isinstance(value, dict) and graph_names == value.keys()
    if is_per_graph:
        # Already shaped as {graph_name: value}; pass through untouched.
        return value
    # Broadcast the single (non-graph-specific) value to all graphs.
    return {name: value for name in graph_names}

if not isinstance(module, dict):
module = {"forward": module}

# Ensure attributes are graph-specific dictionaries
graph_names = module.keys()
inputs = ensure_graph_specific_dict(inputs, graph_names)
compiler_specs = ensure_graph_specific_dict(compiler_specs, graph_names)
dynamic_shapes = ensure_graph_specific_dict(dynamic_shapes, graph_names)
dep_table = ensure_graph_specific_dict(dep_table, graph_names)
passes_job = ensure_graph_specific_dict(passes_job, graph_names)

# Prepare programs and partitioners
aten_programs = {}
transform_passes = {}
qnn_partitioners = {
graph_name: [
QnnPartitioner(
compiler_specs[graph_name],
skip_node_id_set=skip_node_id_set,
skip_node_op_set=skip_node_op_set,
)
]
for graph_name in graph_names
}

for graph_name, m in module.items():
ep = torch.export.export(
m,
inputs[graph_name],
dynamic_shapes=dynamic_shapes[graph_name],
strict=True,
)
# This transformation is primarily intended for the LiftConstantScalarOperands pass
# to avoid creating temporary tensors in the operation builder.
# However, this pass will create a get_attr node, which should be converted
# into a lifted tensor constant by the lift_constant_tensor_pass.
# If placed in the to_edge_transform_passes, it will be executed
# after the lift_constant_tensor_pass, causing the operation builder
# to fail to correctly retrieve the parameter by the get_parameter.
aten_programs[graph_name] = QnnPassManager().transform_for_export_pipeline(ep)
transform_passes[graph_name] = QnnPassManager().get_to_edge_transform_passes(
ep, passes_job=passes_job[graph_name], dep_table=dep_table[graph_name]
)

return to_edge_transform_and_lower(
aten_programs,
transform_passes=transform_passes,
partitioner=[qnn_partitioner],
partitioner=qnn_partitioners,
constant_methods=constant_methods,
compile_config=qnn_edge_config(),
)
return edge_program_manager


def capture_program(
Expand Down
58 changes: 27 additions & 31 deletions examples/qualcomm/oss_scripts/llama/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,44 +692,43 @@ def permute(w, heads):
)
for graph_name in graph_names
]

# TODO: retire capture_program once we figure out how to extract
# intermediate graph from official lowering API
edge_progs = {
graph_name: capture_program(
module=llama_instance.llama_graph_module,
inputs=sample_input,
dep_table=llama_instance.dep_table,
passes_job=llama_instance.passes_job,
).exported_program
for graph_name, llama_instance, sample_input in zip(
graph_names, llama_instance_list, sample_inputs_list
)
}
for n in edge_progs[graph_names[0]].graph.nodes:
edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
{
graph_name: instance.llama_graph_module
for graph_name, instance in zip(graph_names, llama_instance_list)
},
{
graph_name: inputs
for graph_name, inputs in zip(graph_names, sample_inputs_list)
},
{
graph_name: compiler_spec
for graph_name, compiler_spec in zip(graph_names, compiler_specs)
},
llama_instance_list[1].llama_meta,
dep_table={
graph_name: instance.dep_table
for graph_name, instance in zip(graph_names, llama_instance_list)
},
passes_job={
graph_name: instance.passes_job
for graph_name, instance in zip(graph_names, llama_instance_list)
},
skip_node_op_set={"llama.fallback.default"},
)
for n in list(edge_prog_mgr._edge_programs.values())[0].graph.nodes:
if n.op == "output":
for node, output_encoding in n.meta[QCOM_QUANT_ATTRS_MAP].items():
if node.meta["val"].size() in llama_instance_list[0].io_shape:
quant_attrs = output_encoding

partitioners = {
graph_name: QnnPartitioner(
compiler_spec, skip_node_op_set={"llama.fallback.default"}
)
for graph_name, compiler_spec in zip(graph_names, compiler_specs)
}

lowered_ep_dict = to_backend(
MethodProgramsPartitionerSpec(edge_progs, partitioners)
)

if args.num_sharding > 1:
# TODO: add arg parser of spill_fill_size since weight-sharing based
# context binaries cannot be opened in x86 host
pass

if args.verbose:
for ep in lowered_ep_dict.values():
for ep in edge_prog_mgr._edge_programs.values():
print_delegation_info(ep.graph_module)

executorch_config = ExecutorchBackendConfig(
Expand All @@ -743,10 +742,7 @@ def permute(w, heads):
),
extract_delegate_segments=True,
)
exec_prog_mgr = EdgeProgramManager(
edge_programs=lowered_ep_dict,
constant_methods=llama_instance_list[1].llama_meta,
).to_executorch(executorch_config)
exec_prog_mgr = edge_prog_mgr.to_executorch(executorch_config)

with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file:
exec_prog_mgr.write_to_file(file)
Expand Down
Loading