diff --git a/backends/vulkan/test/op_tests/TARGETS b/backends/vulkan/test/op_tests/TARGETS
new file mode 100644
index 0000000000..e84397dc20
--- /dev/null
+++ b/backends/vulkan/test/op_tests/TARGETS
@@ -0,0 +1,5 @@
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets(is_fbcode = True)
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
new file mode 100644
index 0000000000..91b36a368b
--- /dev/null
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from executorch.backends.vulkan.test.op_tests.utils.codegen import VkTestSuite
+
+
+# Prime number dim sizes for testing
+XL = 113
+L = 89
+M2 = 41
+M1 = 37
+M = 29
+S2 = 11
+S1 = 7
+S = 5
+XS = 3
+
+
+def get_binary_elementwise_inputs():
+    return VkTestSuite(
+        [
+            ((M1, M2), (M1, M2)),
+            ((M1, M2), (M1, 1), 2.0),
+            ((M1, M2), (1, M2)),
+            ((S, S1, S2), (S, S1, S2)),
+            ((S, S1, S2), (S, S1, 1), 2.0),
+            ((S, S1, S2), (S, 1, S2), 2.0),
+        ]
+    )
+
+
+def get_mm_inputs():
+    test_suite = VkTestSuite(
+        [
+            ((M1, L), (L, M2)),
+            ((S1, S2), (S2, M)),
+        ],
+    )
+    test_suite.prepacked_args = ["mat2"]
+    return test_suite
+
+
+def get_pool2d_inputs():
+    test_suite = VkTestSuite(
+        [
+            ((S, M1, M2), [2, 2], [1, 1], [0, 0], [1, 1]),
+        ]
+    )
+    # Assign a copy instead of mutating, so the class-level defaults in
+    # VkTestSuite.supports are not changed for every other test suite.
+    test_suite.supports = {
+        **test_suite.supports,
+        "layouts": ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"],
+    }
+    return test_suite
+
+
+test_suites = {
+    "aten.add.Tensor": get_binary_elementwise_inputs(),
+    "aten.sub.Tensor": get_binary_elementwise_inputs(),
+    "aten.div.Tensor": get_binary_elementwise_inputs(),
+    "aten.mul.Tensor": get_binary_elementwise_inputs(),
+    "aten.mm.default": get_mm_inputs(),
+    "aten.max_pool2d_with_indices.default": get_pool2d_inputs(),
+}
diff --git a/backends/vulkan/test/op_tests/generate_op_tests.py b/backends/vulkan/test/op_tests/generate_op_tests.py
new file mode 100644
index 0000000000..ef4dc0af91
--- /dev/null
+++ b/backends/vulkan/test/op_tests/generate_op_tests.py
@@ -0,0 +1,79 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+
+from typing import Dict
+
+from executorch.backends.vulkan.test.op_tests.cases import test_suites
+
+from executorch.backends.vulkan.test.op_tests.utils.codegen import VkCppTestFileGen
+from executorch.backends.vulkan.test.op_tests.utils.codegen_base import (
+    TestSuite,
+    TestSuiteGen,
+)
+
+from torchgen.gen import parse_native_yaml, ParsedYaml
+from torchgen.model import DispatchKey, NativeFunction
+
+
+def registry_name(f: NativeFunction) -> str:
+    name = str(f.namespace) + "." + str(f.func.name)
+    if len(f.func.name.overload_name) == 0:
+        name += ".default"
+    return name
+
+
+def construct_f_map(parsed_yaml: ParsedYaml) -> Dict[str, NativeFunction]:
+    f_map: Dict[str, NativeFunction] = {}
+    for f in parsed_yaml.native_functions:
+        f_map[registry_name(f)] = f
+    return f_map
+
+
+def process_test_suites(
+    cpp_generator: VkCppTestFileGen,
+    f_map: Dict[str, NativeFunction],
+    test_suites: Dict[str, TestSuite],
+) -> None:
+    for registry_name, op_test_suite in test_suites.items():
+        f = f_map[registry_name]
+        cpp_generator.add_suite(registry_name, f, op_test_suite)
+
+
+def generate_cpp(
+    native_functions_yaml_path: str, tags_path: str, output_dir: str
+) -> None:
+    output_file = os.path.join(output_dir, "op_tests.cpp")
+    cpp_generator = VkCppTestFileGen(output_file)
+
+    parsed_yaml = parse_native_yaml(native_functions_yaml_path, tags_path)
+    f_map = construct_f_map(parsed_yaml)
+
+    TestSuiteGen.backend_key = parsed_yaml.backend_indices[DispatchKey.CPU]
+
+    process_test_suites(cpp_generator, f_map, test_suites)
+
+    with open(output_file, "w") as file:
+        file.write(cpp_generator.generate_cpp())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Generate a C++ test file for Vulkan compute graph ops."
+    )
+    parser.add_argument(
+        "--aten-yaml-path",
+        help="Path to native_functions.yaml file.",
+    )
+    parser.add_argument(
+        "--tags-path",
+        help="Path to tags.yaml. Required by yaml parsing in codegen system.",
+    )
+    parser.add_argument("-o", "--output", help="Output directory", required=True)
+    args = parser.parse_args()
+    generate_cpp(args.aten_yaml_path, args.tags_path, args.output)
diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl
new file mode 100644
index 0000000000..79cf418fc3
--- /dev/null
+++ b/backends/vulkan/test/op_tests/targets.bzl
@@ -0,0 +1,77 @@
+load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID")
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+def define_common_targets(is_fbcode = False):
+    if is_fbcode:
+        return
+
+    runtime.python_library(
+        name = "generate_op_tests_lib",
+        srcs = native.glob(["utils/*.py"]) + [
+            "generate_op_tests.py",
+            "cases.py",
+        ],
+        base_module = "executorch.backends.vulkan.test.op_tests",
+        deps = [
+            "//caffe2/torchgen:torchgen",
+            "fbsource//third-party/pypi/expecttest:expecttest",
+        ],
+    )
+
+    runtime.python_binary(
+        name = "generate_op_tests",
+        main_module = "executorch.backends.vulkan.test.op_tests.generate_op_tests",
+        deps = [
+            ":generate_op_tests_lib",
+        ],
+    )
+
+    aten_src_path = runtime.external_dep_location("aten-src-path")
+    genrule_cmd = [
+        "$(exe :generate_op_tests)",
+        "--tags-path $(location {})/aten/src/ATen/native/tags.yaml".format(aten_src_path),
+        "--aten-yaml-path $(location {})/aten/src/ATen/native/native_functions.yaml".format(aten_src_path),
+        "-o $OUT",
+    ]
+
+    runtime.genrule(
+        name = "generated_op_tests_cpp",
+        outs = {
+            "op_tests.cpp": ["op_tests.cpp"],
+        },
+        cmd = " ".join(genrule_cmd),
+        default_outs = ["."],
+    )
+
+    runtime.cxx_binary(
+        name = "compute_graph_op_tests_bin",
+        srcs = [
+            ":generated_op_tests_cpp[op_tests.cpp]",
+        ],
+        define_static_target = False,
+        deps = [
+            "//third-party/googletest:gtest_main",
+            "//executorch/backends/vulkan:vulkan_graph_runtime",
+            runtime.external_dep_location("libtorch"),
+        ],
+    )
+
+    runtime.cxx_test(
+        name = "compute_graph_op_tests",
+        srcs = [
+            ":generated_op_tests_cpp[op_tests.cpp]",
+        ],
+        contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"],
["oncall+ai_infra_mobile_platform@xmail.facebook.com"], + fbandroid_additional_loaded_sonames = [ + "torch-code-gen", + "vulkan_graph_runtime", + "vulkan_graph_runtime_shaderlib", + ], + platforms = [ANDROID], + use_instrumentation_test = True, + deps = [ + "//third-party/googletest:gtest_main", + "//executorch/backends/vulkan:vulkan_graph_runtime", + runtime.external_dep_location("libtorch"), + ], + ) diff --git a/backends/vulkan/test/op_tests/utils/codegen.py b/backends/vulkan/test/op_tests/utils/codegen.py new file mode 100644 index 0000000000..9d99374e29 --- /dev/null +++ b/backends/vulkan/test/op_tests/utils/codegen.py @@ -0,0 +1,450 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + +from typing import Any, List, Optional, Union + +from executorch.backends.vulkan.test.op_tests.utils.codegen_base import ( + AT_INT_ARRAY_REF, + AT_SCALAR, + AT_TENSOR, + BOOL, + CppTestFileGen, + TENSOR_TUPLE, + TestSuite, + TestSuiteGen, +) +from torchgen.api import cpp +from torchgen.api.types import CppSignatureGroup + +from torchgen.gen import generate_static_dispatch_backend_call +from torchgen.model import NativeFunction + +################################## +## Custom Test Suite Definition ## +################################## + + +@dataclass +class VkTestSuite(TestSuite): + supports = { + "storage_types": ["api::StorageType::TEXTURE_3D"], + "layouts": [ + "api::GPUMemoryLayout::TENSOR_WIDTH_PACKED", + "api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED", + ], + } + + +########################## +## Code Generator Class ## +########################## + + +@dataclass +class ATenArg: + name: str + cpp_type: str + default: Optional[str] + + +@dataclass +class ValueRef: + name: str + src_cpp_name: str + src_cpp_type: str + is_in: bool = False + is_out: bool = False + requires_prepack: bool = False + supports_prepack: bool = False + + +ValueRefList = Union[ValueRef, List[ValueRef]] + + +class ComputeGraphGen: + def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite): + self.op_reg_name = op_reg_name + self.f = f + self.suite_def = suite_def + + self.f_sig = CppSignatureGroup.from_native_function( + self.f, method=False, fallback_binding=self.f.manual_cpp_binding + ).most_faithful_signature() + + self.graph = "graph" + self.dot = "->" + + self.args = [] + self.out = None + self.refs = {} + + self.should_prepack = False + + for binding in self.f_sig.arguments(): + arg = binding.argument + ctype = cpp.argumenttype_type( + arg.type, mutable=arg.is_write, binds=arg.name + ) + cpp_type = ctype.cpp_type(strip_ref=True) + + self.args.append( + ATenArg(name=arg.name, cpp_type=cpp_type, default=arg.default) + ) + + requires_prepack = "weight" in arg.name + supports_prepack = False + if arg.name in self.suite_def.prepacked_args: + supports_prepack = True + + self.refs[arg.name] = ValueRef( + name=f"{arg.name}_ref", + src_cpp_name=arg.name, + src_cpp_type=cpp_type, + is_in=(cpp_type == AT_TENSOR), + requires_prepack=requires_prepack, + supports_prepack=supports_prepack, + ) + + ret_type = cpp.returns_type(self.f.func.returns, symint=False).cpp_type() + self.out = ATenArg(name="out", cpp_type=ret_type, default=None) + if ret_type == AT_TENSOR: + self.refs["out"] = ValueRef( + name="out_ref", src_cpp_name="out", src_cpp_type=ret_type, is_out=True + ) + elif ret_type == TENSOR_TUPLE: + 
self.refs["out"] = [ + ValueRef( + name="out_ref_first", + src_cpp_name="std::get<0>(out)", + src_cpp_type="at::Tensor", + is_out=True, + ), + ValueRef( + name="out_ref_second", + src_cpp_name="std::get<1>(out)", + src_cpp_type="at::Tensor", + is_out=True, + ), + ValueRef( + name="out_ref", + src_cpp_name="out", + src_cpp_type=ret_type, + is_out=False, + ), + ] + + ## ATen code generation + + def gen_decl(self, fn_name: str, ret_type: str = "void") -> str: + cpp_args = [a.decl() for a in self.f_sig.arguments()] + cpp_args_str = ", ".join(cpp_args) + return f"{ret_type} {fn_name}({cpp_args_str})" + + def create_aten_fn_call(self) -> str: + func_call = generate_static_dispatch_backend_call( + self.f_sig, self.f, TestSuiteGen.backend_key + )[7:].replace("::cpu", "") + + return func_call + + def create_out_src(self) -> str: + return f"{self.out.cpp_type} out = " + self.create_aten_fn_call() + + ## Graph code generation utils + + def prepack_ref(self, ref: ValueRef) -> bool: + if ref.requires_prepack: + return True + else: + return ref.supports_prepack and self.should_prepack + + def create_value_for(self, ref: ValueRefList) -> str: + if isinstance(ref, list): + ret_str = "" + for r in ref: + ret_str += self.create_value_for(r) + return ret_str + + prepack = self.prepack_ref(ref) + + cpp_type = "IOValueRef" if (ref.is_in and not prepack) else "ValueRef" + ret_str = f"{cpp_type} {ref.name} = {self.graph}{self.dot}" + if ref.src_cpp_type == AT_TENSOR and not prepack: + ret_str += "add_input_tensor(" if ref.is_in else "add_tensor(" + ret_str += f"{ref.src_cpp_name}.sizes().vec(), " + ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type())); \n" + elif ref.src_cpp_type == AT_TENSOR and prepack: + ret_str += f"add_tensorref({ref.src_cpp_name}.sizes().vec(), " + ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type()), " + ret_str += f"{ref.src_cpp_name}.const_data_ptr()); \n" + elif ref.src_cpp_type == AT_SCALAR: + # TODO(ssjia): generalize this to work with all scalar types + ret_str += f"add_scalar({ref.src_cpp_name}.toDouble()); \n" + elif ref.src_cpp_type == AT_INT_ARRAY_REF: + ret_str += f"add_scalar_list({ref.src_cpp_name}.vec()); \n" + elif ref.src_cpp_type == BOOL: + ret_str += f"add_scalar({ref.src_cpp_name}); \n" + elif ref.src_cpp_type == TENSOR_TUPLE: + ret_str += f"add_value_list({{{ref.name}_first, {ref.name}_second}}); \n" + else: + raise RuntimeError(f"Unsupported cpp type {ref.src_cpp_type}") + + return ret_str + + def create_op_call(self) -> str: + deref = "*" if self.dot == "->" else "" + op_create_code = f'VK_GET_OP_FN("{self.op_reg_name}")({deref}{self.graph}, {{' + + for aten_arg in self.args: + ref = self.refs[aten_arg.name] + op_create_code += ( + f"{ref.name}.value, " + if (ref.is_in and not self.prepack_ref(ref)) or ref.is_out + else f"{ref.name}, " + ) + + op_create_code += "out_ref});\n" + return op_create_code + + def set_output(self, ref: ValueRefList) -> str: + if isinstance(ref, list): + ret_str = "" + for r in ref[:-1]: + ret_str += self.set_output(r) + return ret_str + + assert ref.src_cpp_type == AT_TENSOR and ref.is_out + ret_str = f"ValueRef {ref.name}_staging = {self.graph}{self.dot}" + ret_str += f"set_output_tensor({ref.name});\n" + return ret_str + + def virtual_resize(self, ref: ValueRefList) -> str: + assert ref.src_cpp_type == AT_TENSOR and ref.is_in + if self.prepack_ref(ref): + return "" + ret_str = f"{self.graph}{self.dot}get_val({ref.name}.value).toTensor()" + ret_str += f".virtual_resize({ref.src_cpp_name}.sizes().vec());\n" + 
+
+    def copy_into_staging(self, ref: ValueRefList) -> str:
+        assert ref.src_cpp_type == AT_TENSOR and ref.is_in
+        if self.prepack_ref(ref):
+            return ""
+        ret_str = f"{self.graph}{self.dot}copy_into_staging("
+        ret_str += f"{ref.name}.staging, "
+        ret_str += f"{ref.src_cpp_name}.const_data_ptr(), "
+        ret_str += f"{ref.src_cpp_name}.numel());\n"
+        return ret_str
+
+    def declare_vk_out_for(self, ref: Union[ValueRef, List[ValueRef]]) -> str:
+        if isinstance(ref, list):
+            ret_str = ""
+            for r in ref[:-1]:
+                ret_str += self.declare_vk_out_for(r)
+            return ret_str
+
+        return f"at::Tensor vk_{ref.name} = at::empty_like({ref.src_cpp_name});\n"
+
+    def copy_from_staging(self, ref: ValueRefList) -> str:
+        if isinstance(ref, list):
+            ret_str = ""
+            for r in ref[:-1]:
+                ret_str += self.copy_from_staging(r)
+            return ret_str
+
+        assert ref.src_cpp_type == AT_TENSOR and ref.is_out
+        ret_str = f"{self.graph}{self.dot}copy_from_staging({ref.name}_staging, "
+        ret_str += f"vk_{ref.name}.mutable_data_ptr(), vk_{ref.name}.numel());\n"
+
+        return ret_str
+
+    ## Misc. code generation utilities
+
+    def check_graph_out(self, ref: ValueRefList) -> str:
+        if isinstance(ref, list):
+            ret_str = ""
+            for r in ref[:-1]:
+                ret_str += self.check_graph_out(r)
+            return ret_str
+
+        return f"EXPECT_TRUE(check_close({ref.src_cpp_name}, vk_{ref.name}));\n"
+
+    ## Top level code generation
+
+    def gen_graph_build_code(self) -> str:
+        graph_build = self.create_out_src()
+
+        for aten_arg in self.args:
+            graph_build += self.create_value_for(self.refs[aten_arg.name])
+
+        graph_build += self.create_value_for(self.refs["out"])
+        graph_build += self.create_op_call()
+
+        graph_build += self.set_output(self.refs["out"])
+
+        graph_build += f"{self.graph}{self.dot}prepare();\n"
+        graph_build += f"{self.graph}{self.dot}encode_prepack();\n"
+        graph_build += f"{self.graph}{self.dot}prepack();\n"
+        graph_build += f"{self.graph}{self.dot}encode_execute();\n"
+
+        return graph_build
+
+    def gen_graph_exec_code(self) -> str:
+        graph_exec = ""
+        for aten_arg in self.args:
+            ref = self.refs[aten_arg.name]
+            if ref.is_in:
+                graph_exec += self.virtual_resize(ref)
+                graph_exec += self.copy_into_staging(ref)
+
+        graph_exec += f"{self.graph}{self.dot}propagate_resize();\n"
+        graph_exec += f"{self.graph}{self.dot}execute();\n"
+
+        graph_exec += self.declare_vk_out_for(self.refs["out"])
+        graph_exec += self.copy_from_staging(self.refs["out"])
+
+        return graph_exec
+
+    def gen_op_check_fn(self) -> str:
+        op_name = self.f.func.name.unambiguous_name()
+        op_check_fn = self.gen_decl(f"check_{op_name}") + " {"
+        if self.should_prepack:
+            op_check_fn = self.gen_decl(f"prepacked_check_{op_name}") + " {"
+        op_check_fn += self.gen_graph_build_code()
+        op_check_fn += self.gen_graph_exec_code()
+        op_check_fn += self.check_graph_out(self.refs["out"])
+        op_check_fn += "}\n"
+        return op_check_fn
+
+
+##################################
+## Test Fixture Code Generation ##
+##################################
+
+test_fixture_template = """
+class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple<api::StorageType, api::GPUMemoryLayout>> {{
+ protected:
+  ComputeGraph* graph;
+  at::ScalarType test_dtype = at::kFloat;
+
+  void SetUp() override {{
+    GraphConfig config;
+    api::StorageType default_storage_type;
+    api::GPUMemoryLayout default_memory_layout;
+    std::tie(default_storage_type, default_memory_layout) = GetParam();
+    config.setStorageTypeOverride(default_storage_type);
+    config.setMemoryLayoutOverride(default_memory_layout);
+    graph = new ComputeGraph(config);
+  }}
+
+  void TearDown() override {{
+    delete graph;
+    graph = nullptr;
+  }}
+
+  {check_fn}
+
+  {prepacked_check_fn}
+
+}};
+"""
+
+
+class VkTestSuiteGen(TestSuiteGen):
+    def __init__(self, op_reg_name: str, f: NativeFunction, inputs: VkTestSuite):
+        super().__init__(f, inputs)
+        self.op_reg_name = op_reg_name
+        self.generator = ComputeGraphGen(self.op_reg_name, self.f, self.suite_def)
+
+    def generate_fixture_cpp(self) -> str:
+        check_fn = ""
+        if not self.suite_def.requires_prepack:
+            check_fn = self.generator.gen_op_check_fn()
+
+        prepacked_check_fn = ""
+        if self.suite_def.supports_prepack():
+            self.generator.should_prepack = True
+            prepacked_check_fn = self.generator.gen_op_check_fn()
+
+        return test_fixture_template.format(
+            op_name=self.op_name,
+            check_fn=check_fn,
+            prepacked_check_fn=prepacked_check_fn,
+        )
+
+    def gen_parameterization(self) -> str:
+        storage_types = self.suite_def.supports["storage_types"]
+        layouts = self.suite_def.supports["layouts"]
+
+        return f"""
+        INSTANTIATE_TEST_SUITE_P(
+            StorageLayoutCombos_{self.op_name},
+            GeneratedOpsTest_{self.op_name},
+            ::testing::Combine(
+                ::testing::Values({', '.join(storage_types)}),
+                ::testing::Values({', '.join(layouts)})));
+        """
+
+
+###############################
+## Test File Code Generation ##
+###############################
+
+preamble_str = """
+#include <executorch/backends/vulkan/runtime/api/api.h>
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+#include <tuple>
+
+using namespace at::native::vulkan;
+
+api::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
+  switch(at_scalartype) {
+    case c10::kFloat:
+      return api::kFloat;
+    case c10::kHalf:
+      return api::kHalf;
+    case c10::kInt:
+      return api::kInt;
+    case c10::kLong:
+      return api::kInt;
+    default:
+      VK_THROW("Unsupported at::ScalarType!");
+  }
+}
+
+#ifdef USE_VULKAN_FP16_INFERENCE
+bool check_close(at::Tensor& t1, at::Tensor& t2, float rtol=1e-2, float atol=1e-3) {
+#else
+bool check_close(at::Tensor& t1, at::Tensor& t2, float rtol=1e-5, float atol=1e-8) {
+#endif
+  // Skip checking index tensors
+  if (t1.scalar_type() == at::kLong || t2.scalar_type() == at::kLong) {
+    return true;
+  }
+  bool is_close = at::allclose(t1, t2, rtol, atol);
+  if (!is_close) {
+    std::cout << "t1:" << t1 << std::endl;
+    std::cout << "t2:" << t2 << std::endl;
+  }
+  return is_close;
+}
+"""
+
+
+class VkCppTestFileGen(CppTestFileGen):
+    def __init__(self, out_path: str):
+        super().__init__(out_path)
+
+    def generate_preamble(self) -> str:
+        return preamble_str
+
+    def add_suite(self, op_reg_name: str, f: NativeFunction, all_input_cases) -> None:
+        suites_gen = VkTestSuiteGen(op_reg_name, f, all_input_cases)
+        self.suites_gens.append(suites_gen)
diff --git a/backends/vulkan/test/op_tests/utils/codegen_base.py b/backends/vulkan/test/op_tests/utils/codegen_base.py
new file mode 100644
index 0000000000..2af45030ec
--- /dev/null
+++ b/backends/vulkan/test/op_tests/utils/codegen_base.py
@@ -0,0 +1,220 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Any, List
+
+from torchgen.api import cpp
+from torchgen.api.types import CppSignatureGroup
+from torchgen.model import Argument, NativeFunction
+
+########################
+## ATen code patterns ##
+########################
+
+AT_TENSOR = "at::Tensor"
+AT_SCALAR = "at::Scalar"
+AT_INT_ARRAY_REF = "at::IntArrayRef"
+BOOL = "bool"
+TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor>"
+
+###########################
+## Test Suite Definition ##
+###########################
+
+
+@dataclass
+class TestSuite:
+    input_cases: List[Any]
+    prepacked_args = []
+    requires_prepack = False
+
+    def supports_prepack(self):
+        return len(self.prepacked_args) > 0
+
+
+###########################
+## Test Suite Generation ##
+###########################
+
+test_fixture_template = """
+class GeneratedOpsTest_{op_name} : public ::testing::Test {{
+}};
+"""
+
+test_suite_template = """
+TEST_P(GeneratedOpsTest_{op_name}, {case_name}) {{
+    {create_ref_data}
+    {create_and_check_out}
+}}
+"""
+
+
+def init_list_str(pylist: Any) -> str:
+    if pylist == "[]":
+        return "{" + "}"
+
+    if not isinstance(pylist, (list, tuple)):
+        pylist = [pylist]
+
+    init_list_str = "{"
+    for s in pylist:
+        init_list_str += f"{s}, "
+    init_list_str = init_list_str[:-2] + "}"
+    return init_list_str
+
+
+def get_or_return_default(arg: Argument, inputs: List[Any], i: int):
+    if i < len(inputs):
+        return inputs[i]
+    else:
+        assert arg.default is not None
+        return arg.default
+
+
+class TestSuiteGen:
+    backend_key = None
+
+    def __init__(self, f: NativeFunction, test_suite: TestSuite):
+        self.f = f
+        self.suite_def = test_suite
+        self.op_name = f.func.name.unambiguous_name()
+
+        self.f_sig = CppSignatureGroup.from_native_function(
+            self.f, method=False, fallback_binding=self.f.manual_cpp_binding
+        ).most_faithful_signature()
+
+    def gen_case_name(self, inputs: List[Any], prepack: bool = False) -> str:
+        name_str = self.op_name
+        if prepack:
+            name_str += "_prepack"
+        for arg_sizes_or_val in inputs:
+            name_str += "_"
+            if isinstance(arg_sizes_or_val, tuple):
+                for size in arg_sizes_or_val:
+                    name_str += str(size) + "x"
+                name_str = name_str[:-1]
+            elif isinstance(arg_sizes_or_val, list):
+                for size in arg_sizes_or_val:
+                    name_str += str(size) + "c"
+                name_str = name_str[:-1]
+            else:
+                name_str += str(arg_sizes_or_val).replace(".", "p")
+        return name_str
+
+    def create_input_data(self, arg: Argument, data: Any) -> str:
+        ctype = cpp.argumenttype_type(arg.type, mutable=arg.is_write, binds=arg.name)
+        cpp_type = ctype.cpp_type(strip_ref=True)
+
+        if cpp_type == AT_INT_ARRAY_REF:
+            ret_str = f"std::vector<int64_t> {arg.name} = "
+        else:
+            ret_str = f"{cpp_type} {arg.name} = "
+
+        if cpp_type == AT_TENSOR:
+            ret_str += f"make_rand_tensor({init_list_str(data)}, test_dtype);"
+        elif cpp_type == AT_SCALAR:
+            ret_str += f"{data};"
+        elif cpp_type == AT_INT_ARRAY_REF:
+            ret_str += f"{init_list_str(data)};"
+        elif cpp_type == BOOL:
+            ret_str += f"{str(data).lower()};"
+        else:
+            raise RuntimeError(f"Unsupported cpp type {cpp_type}")
+        return ret_str + "\n"
+
+    def gen_create_ref_data(self, inputs: List[Any]) -> str:
+        ref_code = ""
+
+        for i, binding in enumerate(self.f_sig.arguments()):
+            arg = binding.argument
+            arg_data = get_or_return_default(arg, inputs, i)
+            ref_code += self.create_input_data(arg, arg_data)
+
+        return ref_code
+
+    def gen_create_and_check_out(self, prepack=False) -> str:
+        test_str = f"check_{self.op_name}("
+        if prepack:
+            test_str = f"prepacked_check_{self.op_name}("
+        for binding in self.f_sig.arguments():
+            arg = binding.argument
+            test_str += f"{arg.name}, "
+        test_str = test_str[:-2] + ");"
+        return test_str
+
+    def gen_parameterization(self) -> str:
+        return ""
+
+    def generate_fixture_cpp(self) -> str:
+        return test_fixture_template.format(op_name=self.f.func.name)
+
+    def generate_case_cpp(self, inputs, prepack=False) -> str:
+        return test_suite_template.format(
+            op_name=f"{self.op_name}",
+            case_name=self.gen_case_name(inputs, prepack),
+            create_ref_data=self.gen_create_ref_data(inputs),
+            create_and_check_out=self.gen_create_and_check_out(prepack),
+        )
+
+    def generate_suite_cpp(self) -> str:
+        suite_cpp = self.generate_fixture_cpp() + "\n"
+        for inputs in self.suite_def.input_cases:
+            if not self.suite_def.requires_prepack:
+                suite_cpp += self.generate_case_cpp(inputs)
+            if self.suite_def.supports_prepack():
+                suite_cpp += self.generate_case_cpp(inputs, prepack=True)
+
+        suite_cpp += self.gen_parameterization()
+        return suite_cpp
+
+
+##########################
+## Test File Generation ##
+##########################
+
+cpp_test_template = """
+#include <gtest/gtest.h>
+
+#include <ATen/ATen.h>
+
+{preamble}
+
+at::Tensor make_rand_tensor(
+    std::vector<int64_t> sizes,
+    at::ScalarType dtype = at::kFloat,
+    float high = 1.0,
+    float low = 0.0) {{
+  if (high == 1.0 && low == 0.0)
+    return at::rand(sizes, at::device(at::kCPU).dtype(dtype));
+
+  return at::rand(sizes, at::device(at::kCPU).dtype(dtype)) * (high - low) + low;
+}}
+
+{test_suites_cpp}
+"""
+
+
+class CppTestFileGen:
+    def __init__(self, out_path):
+        self.out_path = out_path
+        self.suites_gens = []
+
+    def generate_cpp(self) -> str:
+        return cpp_test_template.format(
+            preamble=self.generate_preamble(),
+            test_suites_cpp=self.generate_test_suites_cpp(),
+        )
+
+    def generate_preamble(self) -> str:
+        return ""
+
+    def generate_test_suites_cpp(self) -> str:
+        return "\n".join([h.generate_suite_cpp() for h in self.suites_gens])
+
+    def add_suite(self, f: NativeFunction, test_suite: TestSuite) -> None:
+        suites_gen = TestSuiteGen(f, test_suite)
+        self.suites_gens.append(suites_gen)
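
For orientation, here is roughly what the generated op_tests.cpp contains for the first aten.add.Tensor case in cases.py, stitched together from cpp_test_template, test_fixture_template, test_suite_template, and gen_parameterization above. This is an illustrative sketch, not generator output; exact declarations and formatting come from the codegen.

// Sketch of one generated test case for aten.add.Tensor (inputs (M1, M2) = (37, 41)).
// The real fixture body (SetUp/TearDown and the generated check_add_Tensor)
// is emitted by ComputeGraphGen as shown in codegen.py.
class GeneratedOpsTest_add_Tensor : public ::testing::TestWithParam<
    ::std::tuple<api::StorageType, api::GPUMemoryLayout>> {
 protected:
  ComputeGraph* graph;
  at::ScalarType test_dtype = at::kFloat;
  // SetUp()/TearDown() and check_add_Tensor(...) generated as above
};

TEST_P(GeneratedOpsTest_add_Tensor, add_Tensor_37x41_37x41) {
  at::Tensor self = make_rand_tensor({37, 41}, test_dtype);
  at::Tensor other = make_rand_tensor({37, 41}, test_dtype);
  at::Scalar alpha = 1;  // default filled in by get_or_return_default
  check_add_Tensor(self, other, alpha);
}

INSTANTIATE_TEST_SUITE_P(
    StorageLayoutCombos_add_Tensor,
    GeneratedOpsTest_add_Tensor,
    ::testing::Combine(
        ::testing::Values(api::StorageType::TEXTURE_3D),
        ::testing::Values(
            api::GPUMemoryLayout::TENSOR_WIDTH_PACKED,
            api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED)));

Each test builds the reference output with the ATen CPU kernel, replays the op through a ComputeGraph parameterized by storage type and memory layout, and compares the staged-out Vulkan result against the reference with check_close.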