-
Notifications
You must be signed in to change notification settings - Fork 425
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Automatically generate operator tests (#2754)
Summary: Pull Request resolved: #2754 ## Context One of the most time consuming parts of adding new operators is writing tests to verify that the implementation is correct. This changeset introduces a codegen solution for automatically generating tests. The goal is to introduce a simple interface to specify what inputs an operator should be checked with, and have a 1 button solution for generating the code and executing operator tests. ## Usage Overview From the developer's perspective, they only need to interact with `op_tests/cases.py`. The file is very simple: ``` # Prime numbers dim sizes for testing XL = 113 L = 89 M2 = 41 M1 = 37 M = 29 S2 = 11 S1 = 7 S = 5 XS = 3 ... def get_mm_inputs(): return [ ((M1, L), (L, M2)), ((S1, S2), (S2, M)), ] test_cases = { "aten.add.Tensor": get_binary_elementwise_inputs(), "aten.sub.Tensor": get_binary_elementwise_inputs(), "aten.div.Tensor": get_binary_elementwise_inputs(), "aten.mul.Tensor": get_binary_elementwise_inputs(), "aten.mm.default": get_mm_inputs(), } ``` It just contains a mapping from the name an operator is registered under in the operator registry to a list of inputs for which tests should be generated. To generate and run tests: ``` buck run //xplat/executorch/backends/vulkan/test/op_tests:compute_graph_op_tests_bin ``` ## Design Overview The code generation is mostly built on top of [torchgen](https://github.com/pytorch/pytorch/tree/main/torchgen), which is PyTorch's codegen system for parsing [native_function.yaml](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml) and generating C++ ATen functions from it. The basic idea is: 1. Using the operator registry name, find the corresponding native function for native_function.yaml 2. Use the function schema from the parsed native function to generate test fixtures that can build a Vulkan compute graph for the operator 3. Individual test cases can be generated by creating ATen tensors and calling the ATen operator to get a reference output, then using the test fixture to get a Vulkan output and compare it to the reference output. 4. GTest [test parameterization](https://github.com/google/googletest/blob/main/googletest/samples/sample8_unittest.cc) is used to test each test case under a combination of dtypes, storage types, and memory layout [Example generated cpp](https://www.internalfb.com/phabricator/paste/view/P1202279441) Reviewed By: copyrightly Differential Revision: D55446638 fbshipit-source-id: 93ca8e7cd43cee1e2678c489d6f2227507ef256f
- Loading branch information
1 parent
1c98d78
commit d4b3e5c
Showing
6 changed files
with
903 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
load(":targets.bzl", "define_common_targets") | ||
|
||
oncall("executorch") | ||
|
||
define_common_targets(is_fbcode = True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
from executorch.backends.vulkan.test.op_tests.utils.codegen import VkTestSuite | ||
|
||
|
||
# Prime numbers dim sizes for testing | ||
XL = 113 | ||
L = 89 | ||
M2 = 41 | ||
M1 = 37 | ||
M = 29 | ||
S2 = 11 | ||
S1 = 7 | ||
S = 5 | ||
XS = 3 | ||
|
||
|
||
def get_binary_elementwise_inputs(): | ||
return VkTestSuite( | ||
[ | ||
((M1, M2), (M1, M2)), | ||
((M1, M2), (M1, 1), 2.0), | ||
((M1, M2), (1, M2)), | ||
((S, S1, S2), (S, S1, S2)), | ||
((S, S1, S2), (S, S1, 1), 2.0), | ||
((S, S1, S2), (S, 1, S2), 2.0), | ||
] | ||
) | ||
|
||
|
||
def get_mm_inputs(): | ||
test_suite = VkTestSuite( | ||
[ | ||
((M1, L), (L, M2)), | ||
((S1, S2), (S2, M)), | ||
], | ||
) | ||
test_suite.prepacked_args = ["mat2"] | ||
return test_suite | ||
|
||
|
||
def get_pool2d_inputs(): | ||
test_suite = VkTestSuite( | ||
[ | ||
((S, M1, M2), [2, 2], [1, 1], [0, 0], [1, 1]), | ||
] | ||
) | ||
test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"] | ||
return test_suite | ||
|
||
|
||
test_suites = { | ||
"aten.add.Tensor": get_binary_elementwise_inputs(), | ||
"aten.sub.Tensor": get_binary_elementwise_inputs(), | ||
"aten.div.Tensor": get_binary_elementwise_inputs(), | ||
"aten.mul.Tensor": get_binary_elementwise_inputs(), | ||
"aten.mm.default": get_mm_inputs(), | ||
"aten.max_pool2d_with_indices.default": get_pool2d_inputs(), | ||
} | ||
|
||
prepacked_args = {"aten.mm.default": {"mat2"}} | ||
|
||
support_exceptions = { | ||
"aten.max_pool2d_with_indices.default": { | ||
"layouts": ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"] | ||
}, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import argparse | ||
import os | ||
|
||
from typing import Dict | ||
|
||
from executorch.backends.vulkan.test.op_tests.cases import test_suites | ||
|
||
from executorch.backends.vulkan.test.op_tests.utils.codegen import VkCppTestFileGen | ||
from executorch.backends.vulkan.test.op_tests.utils.codegen_base import ( | ||
TestSuite, | ||
TestSuiteGen, | ||
) | ||
|
||
from torchgen.gen import parse_native_yaml, ParsedYaml | ||
from torchgen.model import DispatchKey, NativeFunction | ||
|
||
|
||
def registry_name(f: NativeFunction) -> str: | ||
name = str(f.namespace) + "." + str(f.func.name) | ||
if len(f.func.name.overload_name) == 0: | ||
name += ".default" | ||
return name | ||
|
||
|
||
def construct_f_map(parsed_yaml: ParsedYaml) -> Dict[str, NativeFunction]: | ||
f_map: Dict[str, NativeFunction] = {} | ||
for f in parsed_yaml.native_functions: | ||
f_map[registry_name(f)] = f | ||
return f_map | ||
|
||
|
||
def process_test_suites( | ||
cpp_generator: VkCppTestFileGen, | ||
f_map: Dict[str, NativeFunction], | ||
test_suites: Dict[str, TestSuite], | ||
) -> None: | ||
for registry_name, op_test_suite in test_suites.items(): | ||
f = f_map[registry_name] | ||
cpp_generator.add_suite(registry_name, f, op_test_suite) | ||
|
||
|
||
def generate_cpp( | ||
native_functions_yaml_path: str, tags_path: str, output_dir: str | ||
) -> None: | ||
output_file = os.path.join(output_dir, "op_tests.cpp") | ||
cpp_generator = VkCppTestFileGen(output_file) | ||
|
||
parsed_yaml = parse_native_yaml(native_functions_yaml_path, tags_path) | ||
f_map = construct_f_map(parsed_yaml) | ||
|
||
TestSuiteGen.backend_key = parsed_yaml.backend_indices[DispatchKey.CPU] | ||
|
||
process_test_suites(cpp_generator, f_map, test_suites) | ||
|
||
with open(output_file, "w") as file: | ||
file.write(cpp_generator.generate_cpp()) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser( | ||
description="Generate a simple Hello World C++ program." | ||
) | ||
parser.add_argument( | ||
"--aten-yaml-path", | ||
help="path to native_functions.yaml file.", | ||
) | ||
parser.add_argument( | ||
"--tags-path", | ||
help="Path to tags.yaml. Required by yaml parsing in codegen system.", | ||
) | ||
parser.add_argument("-o", "--output", help="Output directory", required=True) | ||
args = parser.parse_args() | ||
generate_cpp(args.aten_yaml_path, args.tags_path, args.output) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID") | ||
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") | ||
|
||
def define_common_targets(is_fbcode = False): | ||
if is_fbcode: | ||
return | ||
|
||
runtime.python_library( | ||
name = "generate_op_tests_lib", | ||
srcs = native.glob(["utils/*.py"]) + [ | ||
"generate_op_tests.py", | ||
"cases.py", | ||
], | ||
base_module = "executorch.backends.vulkan.test.op_tests", | ||
deps = [ | ||
"//caffe2/torchgen:torchgen", | ||
"fbsource//third-party/pypi/expecttest:expecttest", | ||
], | ||
) | ||
|
||
runtime.python_binary( | ||
name = "generate_op_tests", | ||
main_module = "executorch.backends.vulkan.test.op_tests.generate_op_tests", | ||
deps = [ | ||
":generate_op_tests_lib", | ||
], | ||
) | ||
|
||
aten_src_path = runtime.external_dep_location("aten-src-path") | ||
genrule_cmd = [ | ||
"$(exe :generate_op_tests)", | ||
"--tags-path $(location {})/aten/src/ATen/native/tags.yaml".format(aten_src_path), | ||
"--aten-yaml-path $(location {})/aten/src/ATen/native/native_functions.yaml".format(aten_src_path), | ||
"-o $OUT", | ||
] | ||
|
||
runtime.genrule( | ||
name = "generated_op_tests_cpp", | ||
outs = { | ||
"op_tests.cpp": ["op_tests.cpp"], | ||
}, | ||
cmd = " ".join(genrule_cmd), | ||
default_outs = ["."], | ||
) | ||
|
||
runtime.cxx_binary( | ||
name = "compute_graph_op_tests_bin", | ||
srcs = [ | ||
":generated_op_tests_cpp[op_tests.cpp]", | ||
], | ||
define_static_target = False, | ||
deps = [ | ||
"//third-party/googletest:gtest_main", | ||
"//executorch/backends/vulkan:vulkan_graph_runtime", | ||
runtime.external_dep_location("libtorch"), | ||
], | ||
) | ||
|
||
runtime.cxx_test( | ||
name = "compute_graph_op_tests", | ||
srcs = [ | ||
":generated_op_tests_cpp[op_tests.cpp]", | ||
], | ||
contacts = ["[email protected]"], | ||
fbandroid_additional_loaded_sonames = [ | ||
"torch-code-gen", | ||
"vulkan_graph_runtime", | ||
"vulkan_graph_runtime_shaderlib", | ||
], | ||
platforms = [ANDROID], | ||
use_instrumentation_test = True, | ||
deps = [ | ||
"//third-party/googletest:gtest_main", | ||
"//executorch/backends/vulkan:vulkan_graph_runtime", | ||
runtime.external_dep_location("libtorch"), | ||
], | ||
) |
Oops, something went wrong.