From a327b73f2d43b384154c73a52fe012e32fcbcf06 Mon Sep 17 00:00:00 2001 From: winskuo-quic Date: Mon, 6 Jan 2025 16:12:51 +0800 Subject: [PATCH] Add unit test to validate the size of the Spill-Fill buffer. --- backends/qualcomm/tests/models.py | 12 +++++++ backends/qualcomm/tests/test_qnn_delegate.py | 37 ++++++++++++++++++++ backends/qualcomm/utils/utils.py | 8 +++-- 3 files changed, 54 insertions(+), 3 deletions(-) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 3faa1dfbe9..96aab87826 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -596,6 +596,18 @@ def forward(self, input_pos, k_val): return k_out +class LargeTensorLinear(torch.nn.Module): + def __init__(self): + super().__init__() + hidden_dim = 4096 + self.linear1 = torch.nn.Linear(512, hidden_dim) + self.linear2 = torch.nn.Linear(hidden_dim, 512) + + def forward(self, x): + x1 = self.linear1(x) + self.linear1(x) + return self.linear2(x1) + + class LayerNorm(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 4b489ea515..388177a38e 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -1581,6 +1581,24 @@ def test_qnn_backend_skip_node_op(self): skip_node_op_set={"aten.add.Tensor"}, ) + def test_qnn_backend_spill_fill_buffer_size(self): + module = LargeTensorLinear() # noqa: F405 + sample_input = (torch.randn(1, 256, 512),) + edge_prog = capture_program(module, sample_input) + + backend_options = generate_htp_compiler_spec( + use_fp16=True, + use_multi_contexts=True, + ) + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=self.chipset_table[TestQNN.model], + backend_options=backend_options, + ) + partitioner = QnnPartitioner(compiler_specs) + edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner) + max_sf_size = update_spill_fill_size(edge_prog.exported_program) + self.assertNotEqual(0, max_sf_size) + def test_qnn_backend_multi_contexts(self): module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) @@ -2007,6 +2025,25 @@ def calibrator(gm): ).to_executorch() self.verify_output(module, sample_input, exec_prog) + def test_qnn_backend_spill_fill_buffer_size(self): + module = LargeTensorLinear() # noqa: F405 + sample_input = (torch.randn(1, 256, 512),) + module = self.get_qdq_module(module, sample_input) + edge_prog = capture_program(module, sample_input) + + backend_options = generate_htp_compiler_spec( + use_fp16=False, + use_multi_contexts=True, + ) + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=self.chipset_table[TestQNN.model], + backend_options=backend_options, + ) + partitioner = QnnPartitioner(compiler_specs) + edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner) + max_sf_size = update_spill_fill_size(edge_prog.exported_program) + self.assertNotEqual(0, max_sf_size) + def test_qnn_backend_graph_level_mixed_precision(self): module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 2e0ee4f7c6..a647453b09 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -268,15 +268,17 @@ def set_spec(module, options): options.backend_options.htp_options.max_sf_buf_size = max_sf_buf_size set_spec(module, options) + max_sf_size, modules_map = 0, {} if isinstance(exported_program, list): - max_sf_size, modules_map = 0, {} for prog in exported_program: max_sf_buf_size, module_map = get_program_info(prog) max_sf_size = max(max_sf_size, max_sf_buf_size) modules_map.update(module_map) - update_program(max_sf_size, modules_map) else: - update_program(*get_program_info(exported_program)) + max_sf_size, module_map = get_program_info(exported_program) + update_program(max_sf_size, module_map) + + return max_sf_size def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]: