diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 4b489ea515..43b78d341c 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -1655,8 +1655,12 @@ def test_qnn_backend_multi_graphs(self):
             to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i]))
             for i, edge_prog in enumerate(edge_progs)
         ]
-        prog_mgr = generate_multi_graph_program(
-            compiler_specs=compiler_specs[0], exported_programs=exported_programs
+        prog_mgr, _ = generate_multi_graph_program(
+            compiler_specs=compiler_specs[0],
+            processed_bytes=[
+                prog.graph_module.lowered_module_0.processed_bytes
+                for prog in exported_programs
+            ],
         )
         for index, module in enumerate(modules):
             self.verify_output(
@@ -2120,9 +2124,12 @@ def test_qnn_backend_multi_graphs(self):
             to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i]))
             for i, edge_prog in enumerate(edge_progs)
         ]
-        prog_mgr = generate_multi_graph_program(
+        prog_mgr, _ = generate_multi_graph_program(
             compiler_specs=compiler_specs[0],
-            exported_programs=exported_programs,
+            processed_bytes=[
+                prog.graph_module.lowered_module_0.processed_bytes
+                for prog in exported_programs
+            ],
         )
         for index, module in enumerate(modules):
             self.verify_output(
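Note on the call-site change above: generate_multi_graph_program no longer accepts whole exported programs. Callers now pull each lowered module's processed_bytes out themselves and unpack the returned (program manager, bundle programs) pair. A minimal sketch of the new calling convention, mirroring the updated test and assuming edge_progs, compiler_specs, QnnPartitioner, and to_backend are already set up as in that test:

    exported_programs = [
        to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i]))
        for i, edge_prog in enumerate(edge_progs)
    ]
    # The API now consumes raw QNN context-binary bytes and also returns the
    # per-graph bundle programs alongside the ExecutorchProgramManager.
    prog_mgr, bundle_progs = generate_multi_graph_program(
        compiler_specs=compiler_specs[0],
        processed_bytes=[
            prog.graph_module.lowered_module_0.processed_bytes
            for prog in exported_programs
        ],
    )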
diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py
index 2e0ee4f7c6..3d2a9f8c85 100644
--- a/backends/qualcomm/utils/utils.py
+++ b/backends/qualcomm/utils/utils.py
@@ -6,6 +6,7 @@

 import operator
 import re
+import time
 import warnings
 from collections import OrderedDict
 from typing import Any, Callable, Dict, FrozenSet, List, Optional, Tuple
@@ -740,7 +741,7 @@ def preprocess_binary(ctx_bin, compiler_specs):
     for k, v in type_map.items():
         dtype_map.setdefault(v, k)

-    qnn_in_order, executorch_in_order, executorch_out_order = [], [], []
+    qnn_in_order, executorch_in_order, executorch_out_order = None, None, None
     if custom_info is not None:
         # since some context binaries might fail to open on host
         # if they are compiled with special flags:
@@ -748,9 +749,9 @@ def preprocess_binary(ctx_bin, compiler_specs):
         # e.g. weight sharing
         # use custom information here instead
         inputs = build_tensor(custom_info["graph_inputs"], dtype_map)
         outputs = build_tensor(custom_info["graph_outputs"], dtype_map)
-        qnn_in_order = custom_info["qnn_in_order"]
-        executorch_in_order = custom_info["executorch_in_order"]
-        executorch_out_order = custom_info["executorch_out_order"]
+        qnn_in_order = custom_info.get("qnn_in_order", None)
+        executorch_in_order = custom_info.get("executorch_in_order", None)
+        executorch_out_order = custom_info.get("executorch_out_order", None)
         graph_name = custom_info["graph_name"]
     else:
         # get context-binary io tensor info through qnn manager
@@ -800,7 +801,9 @@ def draw_graph(title, path, graph_module: torch.fx.GraphModule):

 def generate_multi_graph_program(
     compiler_specs: List[CompileSpec],
-    exported_programs: List[ExportedProgram] = None,
+    processed_bytes: List[bytes],
+    input_nodes_dict: Dict[str, List[torch.fx.Node]] = None,
+    output_nodes_dict: Dict[str, List[torch.fx.Node]] = None,
     backend_config: ExecutorchBackendConfig = None,
     constant_methods: Optional[Dict[str, Any]] = None,
 ) -> ExecutorchProgramManager:
@@ -813,10 +816,6 @@ def generate_multi_graph_program(
         executorch_in_order,
         executorch_out_order,
     ) = ({}, {}, {}, {}, {})
-
-    processed_bytes = [
-        prog.graph_module.lowered_module_0.processed_bytes for prog in exported_programs
-    ]
     qnn_mgr = PyQnnManagerAdaptor.QnnManager(
         generate_qnn_executorch_option(compiler_specs), processed_bytes
     )
@@ -829,38 +828,36 @@ def generate_multi_graph_program(
         graph_outputs[graph_name] = qnn_mgr.GetGraphOutputs(graph_name)

     # We need to obtain the order of the IOs to correctly map QNN with nn.module
-    for i, graph_name in enumerate(graph_names):
-        # input
-        input_names = [
-            node.name
-            for node in exported_programs[i].graph_module.graph.nodes
-            if node.op == "placeholder"
-        ]
-        qnn_input_names = [wrapper.GetName() for wrapper in graph_inputs[graph_name]]
-        input_order_list = []
-        for input_name in input_names:
-            # e.g., input_0_tokens_0
-            pattern = rf"^input_(\d+)_({input_name})_(\d+)$"
-            for j in range(len(qnn_input_names)):
-                if re.match(pattern, qnn_input_names[j]):
-                    input_order_list.append(j)
-                    break
-        assert (
-            len(input_order_list) == len(input_names) == len(qnn_input_names)
-        ), "Order list length is different from names"
-        executorch_in_order[graph_name] = input_order_list
-        qnn_in_order[graph_name] = sorted(
-            range(len(input_order_list)), key=lambda k: input_order_list[k]
-        )
-
-        # output
-        get_item_list = [
-            node
-            for node in exported_programs[i].graph_module.graph.nodes
-            if node.op == "output"
-        ][0].args[0]
-        output_order_list = [item.args[1] for item in get_item_list]
-        executorch_out_order[graph_name] = output_order_list
+    for graph_name in graph_names:
+        if input_nodes_dict:
+            # input
+            input_names = [node.name for node in input_nodes_dict[graph_name]]
+            qnn_input_names = [
+                wrapper.GetName() for wrapper in graph_inputs[graph_name]
+            ]
+            # The inputs of an intermediate module may include call_function nodes,
+            # which cannot be reordered by node name
+            if len(input_names) == len(qnn_input_names):
+                input_order_list = []
+                for input_name in input_names:
+                    # e.g., input_0_tokens_0
+                    pattern = rf"^input_(\d+)_({input_name})_(\d+)$"
+                    for j in range(len(qnn_input_names)):
+                        if re.match(pattern, qnn_input_names[j]):
+                            input_order_list.append(j)
+                            break
+                assert len(input_order_list) == len(
+                    input_names
+                ), "Order list length is different from names"
+                executorch_in_order[graph_name] = input_order_list
+                qnn_in_order[graph_name] = sorted(
+                    range(len(input_order_list)), key=lambda k: input_order_list[k]
+                )
+        if output_nodes_dict:
+            # output
+            get_item_list = output_nodes_dict[graph_name][0].args[0]
+            output_order_list = [item.args[1] for item in get_item_list]
+            executorch_out_order[graph_name] = output_order_list

     qnn_mgr.Destroy()

@@ -869,15 +866,15 @@ def generate_multi_graph_program(
     bundle_progs = [
         from_context_binary(
             ctx_path=binary_info,
-            op_name=f"loader_{graph_name}",
+            op_name=f"loader_{graph_name}_{int(time.time())}",
             soc_model=compiler_options.soc_info.soc_model,
             custom_info={
                 "graph_inputs": graph_inputs[graph_name],
                 "graph_outputs": graph_outputs[graph_name],
                 "graph_name": graph_name,
-                "qnn_in_order": qnn_in_order[graph_name],
-                "executorch_in_order": executorch_in_order[graph_name],
-                "executorch_out_order": executorch_out_order[graph_name],
+                "qnn_in_order": qnn_in_order.get(graph_name, None),
+                "executorch_in_order": executorch_in_order.get(graph_name, None),
+                "executorch_out_order": executorch_out_order.get(graph_name, None),
             },
         )
         for graph_name in graph_names
@@ -900,9 +897,101 @@ def generate_multi_graph_program(
             break

     edge_prog_mgr = edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs))
-    return edge_prog_mgr.to_executorch(
+    exec_prog = edge_prog_mgr.to_executorch(
+        config=backend_config or ExecutorchBackendConfig()
+    )
+    return exec_prog, bundle_progs
+
+
+def generate_composite_llama_program(
+    graph_names: List[str],
+    sample_inputs_list: List[Tuple[Any]],
+    lower_module_dict: Dict[str, List[LoweredBackendModule]],
+    call_delegate_node_name_dict: Dict[str, List[str]],
+    call_delegate_inputs_dict: Dict[str, List[Tuple[str, int | None]]],
+    outputs_dict: Dict[str, List[Tuple[str, int]]],
+    backend_config: ExecutorchBackendConfig = None,
+    constant_methods: Optional[Dict[str, Any]] = None,
+) -> ExecutorchProgramManager:
+    class CompositeLlamaModule(torch.nn.Module):
+        def __init__(
+            self,
+            lower_module_list,
+            call_delegate_node_name_list,
+            call_delegate_inputs_list,
+            outputs_list,
+        ) -> None:
+            super().__init__()
+            self.lower_module_list = lower_module_list
+            self.call_delegate_node_name_list = call_delegate_node_name_list
+            self.call_delegate_inputs_list = call_delegate_inputs_list
+            self.outputs_list = outputs_list
+
+        def reorder(
+            self,
+            call_delegate_inputs: List[Tuple[str, int | None]],
+            module_inputs: dict[str, torch.Tensor],
+            all_ret: dict[str, torch.Tensor],
+        ) -> Tuple[torch.Tensor]:
+            ret = []
+            for name, index in call_delegate_inputs:
+                if index is not None:
+                    # Get the tensor from a previous delegate's results
+                    ret.append(all_ret[name][index])
+                else:
+                    # Get the tensor from the module inputs
+                    ret.append(module_inputs[name])
+            return tuple(ret)
+
+        def forward(
+            self,
+            tokens: torch.Tensor,
+            atten_mask: torch.Tensor,
+            input_pos: Optional[torch.Tensor] = None,
+            *args,
+        ) -> Tuple[torch.Tensor]:
+            all_ret = {}
+            module_input_dict = {
+                "tokens": tokens,
+                "atten_mask": atten_mask,
+                "input_pos": input_pos,
+            }
+            for num, arg in enumerate(args):
+                module_input_dict[f"args_{num}"] = arg
+            for lower_module, call_delegate_node_name, call_delegate_inputs in zip(
+                self.lower_module_list,
+                self.call_delegate_node_name_list,
+                self.call_delegate_inputs_list,
+            ):
+                inp = self.reorder(call_delegate_inputs, module_input_dict, all_ret)
+                ret = lower_module(*inp)
+                all_ret[call_delegate_node_name] = ret
+            llama_outputs = []
+            for output_src_name, index in self.outputs_list:
+                llama_outputs.append(all_ret[output_src_name][index])
+            return tuple(llama_outputs)
+
+    progs_dict = {}
+    for graph_name, sample_inputs in zip(graph_names, sample_inputs_list):
+        composite_llama_module = CompositeLlamaModule(
+            lower_module_dict[graph_name],
+            call_delegate_node_name_dict[graph_name],
+            call_delegate_inputs_dict[graph_name],
+            outputs_dict[graph_name],
+        )
+        prog = torch.export.export(composite_llama_module, sample_inputs)
+        progs_dict[graph_name] = prog
+    # leverage ExecutorchProgramManager for generating pte with multiple methods
+    edge_prog_mgr = to_edge(
+        progs_dict,
+        constant_methods=constant_methods,
+        # do not alter the name of the custom op
+        compile_config=EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False),
+    )
+    exec_prog = edge_prog_mgr.to_executorch(
         config=backend_config or ExecutorchBackendConfig()
     )
+    return exec_prog


 def generate_htp_compiler_spec(
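The (name, index) pairs consumed by CompositeLlamaModule.reorder above follow a simple rule: an integer index selects one output of an earlier executorch_call_delegate result, while None selects a top-level module input by name. A standalone sketch of that resolution rule, with tensor names and shapes that are made up purely for illustration:

    from typing import Dict, List, Optional, Tuple

    import torch


    def reorder(
        call_delegate_inputs: List[Tuple[str, Optional[int]]],
        module_inputs: Dict[str, torch.Tensor],
        all_ret: Dict[str, Tuple[torch.Tensor, ...]],
    ) -> Tuple[torch.Tensor, ...]:
        # Same lookup rule as CompositeLlamaModule.reorder: an integer index picks
        # one output of a previously executed delegate; None picks a module input.
        ret = []
        for name, index in call_delegate_inputs:
            ret.append(all_ret[name][index] if index is not None else module_inputs[name])
        return tuple(ret)


    # Toy data: one module input plus a two-output "delegate" result.
    module_inputs = {"tokens": torch.zeros(1, 8, dtype=torch.long)}
    all_ret = {"executorch_call_delegate": (torch.ones(1, 8), torch.ones(1, 8, 4))}
    args = reorder(
        [("tokens", None), ("executorch_call_delegate", 1)], module_inputs, all_ret
    )
    assert args[1].shape == (1, 8, 4)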
diff --git a/examples/qualcomm/oss_scripts/llama3_2/llama.py b/examples/qualcomm/oss_scripts/llama3_2/llama.py
index df54a6196e..a18690e941 100755
--- a/examples/qualcomm/oss_scripts/llama3_2/llama.py
+++ b/examples/qualcomm/oss_scripts/llama3_2/llama.py
@@ -31,6 +31,7 @@
 from executorch.backends.qualcomm.utils.utils import (
     capture_program,
     convert_linear_to_conv2d,
+    generate_composite_llama_program,
     generate_htp_compiler_spec,
     generate_multi_graph_program,
     generate_qnn_executorch_compiler_spec,
@@ -365,7 +366,7 @@ def lowering_modules(
         if num_sharding > 0:
             update_spill_fill_size(edge_prog_mgr.exported_program())
         exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config)
-        with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file:
+        with open(f"{work_space}/{pte_filename}.pte", "wb") as file:
             exec_prog_mgr.write_to_file(file)

     def get_example_inputs(self, use_kv_cache=True):
@@ -435,7 +436,7 @@ def compile(args, pte_filename):
     use_fp16 = True
     fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32}

-    if args.ptq != None:
+    if args.ptq:
         use_fp16 = False
         fixed_point_type["kv_type"] = torch.uint8
         if args.ptq == "8a8w":
@@ -464,7 +465,7 @@ def compile(args, pte_filename):
             llama_instance_list[i].eval(), pte_filename
         )

-    if args.ptq != None:
+    if args.ptq:
         start_quantize_ts = time.time()
         for llama_instance in llama_instance_list:
             llama_instance.quantize(
@@ -481,6 +482,7 @@ def compile(args, pte_filename):
         logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}")

     start_lowering_ts = time.time()
+    quant_attrs = None

     if len(llama_instance_list) == 1:
         llama_instance_list[0].lowering_modules(
@@ -515,7 +517,9 @@ def compile(args, pte_filename):
                 edge_progs[i].exported_program.graph_module,
                 fixed_point_type,
             )
-        backend_options = generate_htp_compiler_spec(use_fp16=use_fp16)
+        backend_options = generate_htp_compiler_spec(
+            use_fp16=use_fp16, use_multi_contexts=args.num_sharding > 0
+        )
         graph_names = ["prefill_forward", "kv_forward"]
         compiler_specs = [
             generate_qnn_executorch_compiler_spec(
@@ -527,10 +531,17 @@ def compile(args, pte_filename):
             )
             for graph_name in graph_names
         ]
+        skip_node_op_set = {"llama.fallback.default"}
         exported_programs = [
-            to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i]))
+            to_backend(
+                edge_prog.exported_program,
+                QnnPartitioner(compiler_specs[i], skip_node_op_set=skip_node_op_set),
+            )
             for i, edge_prog in enumerate(edge_progs)
         ]
+        if args.num_sharding > 0:
+            for exported_program in exported_programs:
+                update_spill_fill_size(exported_program)

         executorch_config = ExecutorchBackendConfig(
             # For shared buffer, user must pass the memory address
@@ -544,14 +555,117 @@ def compile(args, pte_filename):
             extract_delegate_segments=True,
         )

-        prog_mgr = generate_multi_graph_program(
-            compiler_specs=compiler_specs[0],
-            exported_programs=exported_programs,
-            backend_config=executorch_config,
-            constant_methods=llama_instance_list[1].llama_meta,  # kv method meta
-        )
-        with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file:
-            prog_mgr.write_to_file(file)
+        lower_module_dict = {name: [] for name in graph_names}
+        call_delegate_inputs_dict = {name: [] for name in graph_names}
+        call_delegate_node_name_dict = {name: [] for name in graph_names}
+        outputs_dict = {name: [] for name in graph_names}
+        input_nodes_dict = {name: [] for name in graph_names}
+        for prog, graph_name in zip(exported_programs, graph_names):
+            for node in prog.graph_module.graph.nodes:
+                if (
+                    node.op == "call_function"
+                    and "executorch_call_delegate" in node.name
+                ):
+                    call_delegate_node_name_dict[graph_name].append(node.name)
+                    call_delegate_inputs_list = []
+                    for arg in node.args:
+                        if arg.op == "call_function":
+                            while "getitem" not in arg.name:
+                                arg = arg.args[0]
+                            call_delegate_inputs_list.append(
+                                (arg.args[0].name, arg.args[1])
+                            )
+                        elif arg.op == "placeholder":
+                            call_delegate_inputs_list.append((arg.name, None))
+                        # Nothing extra needs to be done for get_attr nodes
+                    call_delegate_inputs_dict[graph_name].append(
+                        call_delegate_inputs_list
+                    )
+                elif node.op == "output":
+                    for arg in node.args[0]:
+                        outputs_dict[graph_name].append((arg.args[0].name, arg.args[1]))
+
+        if args.num_sharding > 0:
+            bundle_progs_list = []
+            for num in range(args.num_sharding - 1, -1, -1):
+                processed_bytes = []
+                for prog, graph_name in zip(exported_programs, graph_names):
+                    processed_bytes.append(
+                        getattr(
+                            prog.graph_module, f"lowered_module_{num}"
+                        ).processed_bytes
+                    )
+
+                    call_delegate_node = [
+                        list(node.users.keys())[0]
+                        for node in prog.graph_module.graph.nodes
+                        if node.op == "get_attr"
+                        and node.name == f"lowered_module_{num}"
+                    ]
+                    input_nodes_dict[graph_name] = [
+                        node
+                        for node in call_delegate_node[0].args
+                        if node.op == "placeholder"
+                    ]
+
+                prog_mgr, bundle_progs = generate_multi_graph_program(
+                    compiler_specs=compiler_specs[0],
+                    processed_bytes=processed_bytes,
+                    input_nodes_dict=input_nodes_dict,
+                    backend_config=executorch_config,
+                    constant_methods=llama_instance_list[
+                        1
+                    ].llama_meta,  # kv method meta
+                )
+                bundle_progs_list.append(bundle_progs)
+                for graph_name in graph_names:
+                    lower_module_dict[graph_name].append(
+                        prog_mgr.exported_program(graph_name).graph_module._modules.get(
+                            "lowered_module_0"
+                        )
+                    )
+
+            exec_prog = generate_composite_llama_program(
+                graph_names=graph_names,
+                sample_inputs_list=sample_inputs_list,
+                lower_module_dict=lower_module_dict,
+                call_delegate_node_name_dict=call_delegate_node_name_dict,
+                call_delegate_inputs_dict=call_delegate_inputs_dict,
+                outputs_dict=outputs_dict,
+                backend_config=executorch_config,
+                constant_methods=llama_instance_list[1].llama_meta,  # kv method meta
+            )
+            with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file:
+                exec_prog.write_to_file(file)
+        else:
+            processed_bytes = []
+            input_nodes_dict = {name: [] for name in graph_names}
+            output_nodes_dict = {name: [] for name in graph_names}
+            for prog, graph_name in zip(exported_programs, graph_names):
+                processed_bytes.append(
+                    prog.graph_module.lowered_module_0.processed_bytes
+                )
+                input_nodes_dict[graph_name] = [
+                    node
+                    for node in prog.graph_module.graph.nodes
+                    if node.op == "placeholder"
+                ]
+                output_nodes_dict[graph_name] = [
+                    node
+                    for node in prog.graph_module.graph.nodes
+                    if node.op == "output"
+                ]
+
+            prog_mgr, _ = generate_multi_graph_program(
+                compiler_specs=compiler_specs[0],
+                processed_bytes=processed_bytes,
+                input_nodes_dict=input_nodes_dict,
+                output_nodes_dict=output_nodes_dict,
+                backend_config=executorch_config,
+                constant_methods=llama_instance_list[1].llama_meta,  # kv method meta
+            )
+            with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file:
+                prog_mgr.write_to_file(file)

         end_lowering_ts = time.time()
         logging.info(f"Time for compiling: {end_lowering_ts - start_lowering_ts}")
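Usage note on the flow above: both branches end with an ExecutorchProgramManager that carries one method per entry in graph_names (prefill_forward and kv_forward). A small sketch of how the resulting manager might be written out and sanity-checked; exec_prog, args, and pte_filename are assumed to be in scope exactly as in the script:

    import os

    # Sketch only: `exec_prog`, `args`, and `pte_filename` come from the script above.
    pte_path = os.path.join(args.artifact, f"{pte_filename}.pte")
    with open(pte_path, "wb") as file:
        exec_prog.write_to_file(file)

    # Both llama methods should be retrievable from the multi-method program manager;
    # exported_program(name) looks a method up by name, just as the script does when
    # collecting lowered modules per graph.
    for method_name in ("prefill_forward", "kv_forward"):
        exec_prog.exported_program(method_name)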