enable sharding in hybrid mode
shewu-quic committed Jan 8, 2025
1 parent 97ab146 commit 9def8ff
Showing 3 changed files with 273 additions and 63 deletions.
15 changes: 11 additions & 4 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -1655,8 +1655,12 @@ def test_qnn_backend_multi_graphs(self):
to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i]))
for i, edge_prog in enumerate(edge_progs)
]
prog_mgr = generate_multi_graph_program(
compiler_specs=compiler_specs[0], exported_programs=exported_programs
prog_mgr, _ = generate_multi_graph_program(
compiler_specs=compiler_specs[0],
processed_bytes=[
prog.graph_module.lowered_module_0.processed_bytes
for prog in exported_programs
],
)
for index, module in enumerate(modules):
self.verify_output(
@@ -2120,9 +2124,12 @@ def test_qnn_backend_multi_graphs(self):
to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i]))
for i, edge_prog in enumerate(edge_progs)
]
prog_mgr = generate_multi_graph_program(
prog_mgr, _ = generate_multi_graph_program(
compiler_specs=compiler_specs[0],
exported_programs=exported_programs,
processed_bytes=[
prog.graph_module.lowered_module_0.processed_bytes
for prog in exported_programs
],
)
for index, module in enumerate(modules):
self.verify_output(
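For reference, a minimal sketch of the new calling convention exercised by the tests above, assuming compiler_specs and exported_programs have already been prepared as shown there: generate_multi_graph_program now takes the lowered programs' processed_bytes directly instead of the ExportedProgram objects, and returns a (program manager, bundle programs) tuple. The output filename is illustrative.

    processed_bytes = [
        prog.graph_module.lowered_module_0.processed_bytes
        for prog in exported_programs
    ]
    prog_mgr, bundle_progs = generate_multi_graph_program(
        compiler_specs=compiler_specs[0],
        processed_bytes=processed_bytes,
    )
    # Serialize the multi-method program to a .pte file.
    with open("multi_graph.pte", "wb") as f:
        prog_mgr.write_to_file(f)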
181 changes: 135 additions & 46 deletions backends/qualcomm/utils/utils.py
@@ -6,6 +6,7 @@

import operator
import re
import time
import warnings
from collections import OrderedDict
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Tuple
@@ -740,17 +741,17 @@ def preprocess_binary(ctx_bin, compiler_specs):
for k, v in type_map.items():
dtype_map.setdefault(v, k)

qnn_in_order, executorch_in_order, executorch_out_order = [], [], []
qnn_in_order, executorch_in_order, executorch_out_order = None, None, None
if custom_info is not None:
# since some context binaries might fail to open on host
# if they are compiled with special flags:
# e.g. weight sharing
# use custom information here instead
inputs = build_tensor(custom_info["graph_inputs"], dtype_map)
outputs = build_tensor(custom_info["graph_outputs"], dtype_map)
qnn_in_order = custom_info["qnn_in_order"]
executorch_in_order = custom_info["executorch_in_order"]
executorch_out_order = custom_info["executorch_out_order"]
qnn_in_order = custom_info.get("qnn_in_order", None)
executorch_in_order = custom_info.get("executorch_in_order", None)
executorch_out_order = custom_info.get("executorch_out_order", None)
graph_name = custom_info["graph_name"]
else:
# get context-binary io tensor info through qnn manager
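For reference, a hedged sketch of the custom_info dictionary this branch consumes; after this change the three order lists are optional and resolved via .get(..., None), so they may be omitted entirely. The graph name and tensor wrappers below are placeholders, not values from the commit.

    custom_info = {
        "graph_inputs": graph_inputs["forward"],     # QNN tensor wrappers
        "graph_outputs": graph_outputs["forward"],
        "graph_name": "forward",
        # Optional after this change; default to None when absent.
        "qnn_in_order": None,
        "executorch_in_order": None,
        "executorch_out_order": None,
    }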
@@ -800,7 +801,9 @@ def draw_graph(title, path, graph_module: torch.fx.GraphModule):

def generate_multi_graph_program(
compiler_specs: List[CompileSpec],
exported_programs: List[ExportedProgram] = None,
processed_bytes: List[bytes],
input_nodes_dict: List[torch.fx.Node] = None,
output_nodes_dict: List[torch.fx.Node] = None,
backend_config: ExecutorchBackendConfig = None,
constant_methods: Optional[Dict[str, Any]] = None,
) -> ExecutorchProgramManager:
@@ -813,10 +816,6 @@ def generate_multi_graph_program(
executorch_in_order,
executorch_out_order,
) = ({}, {}, {}, {}, {})

processed_bytes = [
prog.graph_module.lowered_module_0.processed_bytes for prog in exported_programs
]
qnn_mgr = PyQnnManagerAdaptor.QnnManager(
generate_qnn_executorch_option(compiler_specs), processed_bytes
)
@@ -829,38 +828,36 @@ def generate_multi_graph_program(
graph_outputs[graph_name] = qnn_mgr.GetGraphOutputs(graph_name)

# We need to obtain the order of the IOs to correctly map QNN with nn.module
for i, graph_name in enumerate(graph_names):
# input
input_names = [
node.name
for node in exported_programs[i].graph_module.graph.nodes
if node.op == "placeholder"
]
qnn_input_names = [wrapper.GetName() for wrapper in graph_inputs[graph_name]]
input_order_list = []
for input_name in input_names:
# e.g., input_0_tokens_0
pattern = rf"^input_(\d+)_({input_name})_(\d+)$"
for j in range(len(qnn_input_names)):
if re.match(pattern, qnn_input_names[j]):
input_order_list.append(j)
break
assert (
len(input_order_list) == len(input_names) == len(qnn_input_names)
), "Order list length is different from names"
executorch_in_order[graph_name] = input_order_list
qnn_in_order[graph_name] = sorted(
range(len(input_order_list)), key=lambda k: input_order_list[k]
)

# output
get_item_list = [
node
for node in exported_programs[i].graph_module.graph.nodes
if node.op == "output"
][0].args[0]
output_order_list = [item.args[1] for item in get_item_list]
executorch_out_order[graph_name] = output_order_list
for graph_name in graph_names:
if input_nodes_dict:
# input
input_names = [node.name for node in input_nodes_dict[graph_name]]
qnn_input_names = [
wrapper.GetName() for wrapper in graph_inputs[graph_name]
]
# The inputs of an intermediate module that include call_function nodes
# cannot be reordered by node name
if len(input_names) == len(qnn_input_names):
input_order_list = []
for input_name in input_names:
# e.g., input_0_tokens_0
pattern = rf"^input_(\d+)_({input_name})_(\d+)$"
for j in range(len(qnn_input_names)):
if re.match(pattern, qnn_input_names[j]):
input_order_list.append(j)
break
assert len(input_order_list) == len(
input_names
), "Order list length is different from names"
executorch_in_order[graph_name] = input_order_list
qnn_in_order[graph_name] = sorted(
range(len(input_order_list)), key=lambda k: input_order_list[k]
)
if output_nodes_dict:
# output
get_item_list = output_nodes_dict[graph_name][0].args[0]
output_order_list = [item.args[1] for item in get_item_list]
executorch_out_order[graph_name] = output_order_list

qnn_mgr.Destroy()

@@ -869,15 +866,15 @@ def generate_multi_graph_program(
bundle_progs = [
from_context_binary(
ctx_path=binary_info,
op_name=f"loader_{graph_name}",
op_name=f"loader_{graph_name}_{int(time.time())}",
soc_model=compiler_options.soc_info.soc_model,
custom_info={
"graph_inputs": graph_inputs[graph_name],
"graph_outputs": graph_outputs[graph_name],
"graph_name": graph_name,
"qnn_in_order": qnn_in_order[graph_name],
"executorch_in_order": executorch_in_order[graph_name],
"executorch_out_order": executorch_out_order[graph_name],
"qnn_in_order": qnn_in_order.get(graph_name, None),
"executorch_in_order": executorch_in_order.get(graph_name, None),
"executorch_out_order": executorch_out_order.get(graph_name, None),
},
)
for graph_name in graph_names
@@ -900,9 +897,101 @@ def generate_multi_graph_program(
break

edge_prog_mgr = edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs))
return edge_prog_mgr.to_executorch(
exec_prog = edge_prog_mgr.to_executorch(
config=backend_config or ExecutorchBackendConfig()
)
return exec_prog, bundle_progs


def generate_composite_llama_program(
graph_names: List[str],
sample_inputs_list: List[Tuple[Any]],
lower_module_dict: Dict[str, List[LoweredBackendModule]],
call_delegate_node_name_dict: Dict[str, List[str]],
call_delegate_inputs_dict: Dict[str, List[Tuple[str, int | None]]],
outputs_dict: Dict[str, List[Tuple[str, int]]],
backend_config: ExecutorchBackendConfig = None,
constant_methods: Optional[Dict[str, Any]] = None,
) -> ExecutorchProgramManager:
class CompositeLlamaModule(torch.nn.Module):
def __init__(
self,
lower_module_list,
call_delegate_node_name_list,
call_delegate_inputs_list,
outputs_list,
) -> None:
super().__init__()
self.lower_module_list = lower_module_list
self.call_delegate_node_name_list = call_delegate_node_name_list
self.call_delegate_inputs_list = call_delegate_inputs_list
self.outputs_list = outputs_list

def reorder(
self,
call_delegate_inputs: List[Tuple[str, int | None]],
module_inputs: dict[str, torch.Tensor],
all_ret: dict[str, torch.Tensor],
) -> Tuple[torch.Tensor]:
ret = []
for name, index in call_delegate_inputs:
if index is not None:
# Get tensor from previous results
ret.append(all_ret[name][index])
else:
# Get tensor from the inputs of module
ret.append(module_inputs[name])
return tuple(ret)

def forward(
self,
tokens: torch.Tensor,
atten_mask: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
*args,
) -> Tuple[torch.Tensor]:
all_ret = {}
module_input_dict = {
"tokens": tokens,
"atten_mask": atten_mask,
"input_pos": input_pos,
}
for num, arg in enumerate(args):
module_input_dict[f"args_{num}"] = arg
for lower_module, call_delegate_node_name, call_delegate_inputs in zip(
self.lower_module_list,
self.call_delegate_node_name_list,
self.call_delegate_inputs_list,
):
inp = self.reorder(call_delegate_inputs, module_input_dict, all_ret)
ret = lower_module(*inp)
all_ret[call_delegate_node_name] = ret
llama_outputs = []
for output_src_name, index in self.outputs_list:
llama_outputs.append(all_ret[output_src_name][index])
return tuple(llama_outputs)

progs_dict = {}
for graph_name, sample_inputs in zip(graph_names, sample_inputs_list):
composite_llama_module = CompositeLlamaModule(
lower_module_dict[graph_name],
call_delegate_node_name_dict[graph_name],
call_delegate_inputs_dict[graph_name],
outputs_dict[graph_name],
)
prog = torch.export.export(composite_llama_module, sample_inputs)
progs_dict[graph_name] = prog
# leverage ExecutorchProgramManager to generate a .pte with multiple methods
edge_prog_mgr = to_edge(
progs_dict,
constant_methods=constant_methods,
# do not alter name for custom op
compile_config=EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False),
)
exec_prog = edge_prog_mgr.to_executorch(
config=backend_config or ExecutorchBackendConfig()
)
return exec_prog
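A hedged sketch of how the new helper might be driven; every name, shape, and index below is illustrative only, and shard_0_lowered / shard_1_lowered stand in for LoweredBackendModule shards produced by an earlier lowering pass.

    import torch

    graph_names = ["forward"]
    sample_inputs_list = [(
        torch.ones(1, 32, dtype=torch.int32),    # tokens
        torch.zeros(1, 1, 32, 32),               # atten_mask
        torch.tensor([0], dtype=torch.int32),    # input_pos
    )]
    lower_module_dict = {"forward": [shard_0_lowered, shard_1_lowered]}
    call_delegate_node_name_dict = {"forward": ["delegate_0", "delegate_1"]}
    call_delegate_inputs_dict = {
        "forward": [
            # Shard 0 reads the module inputs directly (index None).
            [("tokens", None), ("atten_mask", None), ("input_pos", None)],
            # Shard 1 reads output 0 of shard 0 plus the original mask/pos.
            [("delegate_0", 0), ("atten_mask", None), ("input_pos", None)],
        ]
    }
    outputs_dict = {"forward": [("delegate_1", 0)]}

    exec_prog = generate_composite_llama_program(
        graph_names=graph_names,
        sample_inputs_list=sample_inputs_list,
        lower_module_dict=lower_module_dict,
        call_delegate_node_name_dict=call_delegate_node_name_dict,
        call_delegate_inputs_dict=call_delegate_inputs_dict,
        outputs_dict=outputs_dict,
    )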


def generate_htp_compiler_spec(