Skip to content

Commit

Permalink
Automatically generate operator tests (#2754)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2754

## Context

One of the most time consuming parts of adding new operators is writing tests to verify that the implementation is correct. This changeset introduces a codegen solution for automatically generating tests. The goal is to introduce a simple interface to specify what inputs an operator should be checked with, and have a 1 button solution for generating the code and executing operator tests.

## Usage Overview

From the developer's perspective, they only need to interact with `op_tests/cases.py`. The file is very simple:

```
# Prime numbers dim sizes for testing
XL = 113
L = 89
M2 = 41
M1 = 37
M = 29
S2 = 11
S1 = 7
S = 5
XS = 3

...

def get_mm_inputs():
    return [
        ((M1, L), (L, M2)),
        ((S1, S2), (S2, M)),
    ]

test_cases = {
    "aten.add.Tensor": get_binary_elementwise_inputs(),
    "aten.sub.Tensor": get_binary_elementwise_inputs(),
    "aten.div.Tensor": get_binary_elementwise_inputs(),
    "aten.mul.Tensor": get_binary_elementwise_inputs(),
    "aten.mm.default": get_mm_inputs(),
}
```

It just contains a mapping from the name an operator is registered under in the operator registry to a list of inputs for which tests should be generated.

To generate and run tests:

```
buck run //xplat/executorch/backends/vulkan/test/op_tests:compute_graph_op_tests_bin
```

## Design Overview

The code generation is mostly built on top of [torchgen](https://github.com/pytorch/pytorch/tree/main/torchgen), which is PyTorch's codegen system for parsing [native_function.yaml](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml) and generating C++ ATen functions from it. The basic idea is:

1. Using the operator registry name, find the corresponding native function for native_function.yaml
2. Use the function schema from the parsed native function to generate test fixtures that can build a Vulkan compute graph for the operator
3. Individual test cases can be generated by creating ATen tensors and calling the ATen operator to get a reference output, then using the test fixture to get a Vulkan output and compare it to the reference output.
4. GTest [test parameterization](https://github.com/google/googletest/blob/main/googletest/samples/sample8_unittest.cc) is used to test each test case under a combination of dtypes, storage types, and memory layout

[Example generated cpp](https://www.internalfb.com/phabricator/paste/view/P1202279441)

Reviewed By: copyrightly

Differential Revision: D55446638

fbshipit-source-id: 93ca8e7cd43cee1e2678c489d6f2227507ef256f
  • Loading branch information
SS-JIA authored and facebook-github-bot committed Mar 29, 2024
1 parent 1c98d78 commit d4b3e5c
Show file tree
Hide file tree
Showing 6 changed files with 903 additions and 0 deletions.
5 changes: 5 additions & 0 deletions backends/vulkan/test/op_tests/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets(is_fbcode = True)
72 changes: 72 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from executorch.backends.vulkan.test.op_tests.utils.codegen import VkTestSuite


# Prime numbers dim sizes for testing
XL = 113
L = 89
M2 = 41
M1 = 37
M = 29
S2 = 11
S1 = 7
S = 5
XS = 3


def get_binary_elementwise_inputs():
return VkTestSuite(
[
((M1, M2), (M1, M2)),
((M1, M2), (M1, 1), 2.0),
((M1, M2), (1, M2)),
((S, S1, S2), (S, S1, S2)),
((S, S1, S2), (S, S1, 1), 2.0),
((S, S1, S2), (S, 1, S2), 2.0),
]
)


def get_mm_inputs():
test_suite = VkTestSuite(
[
((M1, L), (L, M2)),
((S1, S2), (S2, M)),
],
)
test_suite.prepacked_args = ["mat2"]
return test_suite


def get_pool2d_inputs():
test_suite = VkTestSuite(
[
((S, M1, M2), [2, 2], [1, 1], [0, 0], [1, 1]),
]
)
test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
return test_suite


test_suites = {
"aten.add.Tensor": get_binary_elementwise_inputs(),
"aten.sub.Tensor": get_binary_elementwise_inputs(),
"aten.div.Tensor": get_binary_elementwise_inputs(),
"aten.mul.Tensor": get_binary_elementwise_inputs(),
"aten.mm.default": get_mm_inputs(),
"aten.max_pool2d_with_indices.default": get_pool2d_inputs(),
}

prepacked_args = {"aten.mm.default": {"mat2"}}

support_exceptions = {
"aten.max_pool2d_with_indices.default": {
"layouts": ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
},
}
79 changes: 79 additions & 0 deletions backends/vulkan/test/op_tests/generate_op_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os

from typing import Dict

from executorch.backends.vulkan.test.op_tests.cases import test_suites

from executorch.backends.vulkan.test.op_tests.utils.codegen import VkCppTestFileGen
from executorch.backends.vulkan.test.op_tests.utils.codegen_base import (
TestSuite,
TestSuiteGen,
)

from torchgen.gen import parse_native_yaml, ParsedYaml
from torchgen.model import DispatchKey, NativeFunction


def registry_name(f: NativeFunction) -> str:
name = str(f.namespace) + "." + str(f.func.name)
if len(f.func.name.overload_name) == 0:
name += ".default"
return name


def construct_f_map(parsed_yaml: ParsedYaml) -> Dict[str, NativeFunction]:
f_map: Dict[str, NativeFunction] = {}
for f in parsed_yaml.native_functions:
f_map[registry_name(f)] = f
return f_map


def process_test_suites(
cpp_generator: VkCppTestFileGen,
f_map: Dict[str, NativeFunction],
test_suites: Dict[str, TestSuite],
) -> None:
for registry_name, op_test_suite in test_suites.items():
f = f_map[registry_name]
cpp_generator.add_suite(registry_name, f, op_test_suite)


def generate_cpp(
native_functions_yaml_path: str, tags_path: str, output_dir: str
) -> None:
output_file = os.path.join(output_dir, "op_tests.cpp")
cpp_generator = VkCppTestFileGen(output_file)

parsed_yaml = parse_native_yaml(native_functions_yaml_path, tags_path)
f_map = construct_f_map(parsed_yaml)

TestSuiteGen.backend_key = parsed_yaml.backend_indices[DispatchKey.CPU]

process_test_suites(cpp_generator, f_map, test_suites)

with open(output_file, "w") as file:
file.write(cpp_generator.generate_cpp())


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate a simple Hello World C++ program."
)
parser.add_argument(
"--aten-yaml-path",
help="path to native_functions.yaml file.",
)
parser.add_argument(
"--tags-path",
help="Path to tags.yaml. Required by yaml parsing in codegen system.",
)
parser.add_argument("-o", "--output", help="Output directory", required=True)
args = parser.parse_args()
generate_cpp(args.aten_yaml_path, args.tags_path, args.output)
77 changes: 77 additions & 0 deletions backends/vulkan/test/op_tests/targets.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets(is_fbcode = False):
if is_fbcode:
return

runtime.python_library(
name = "generate_op_tests_lib",
srcs = native.glob(["utils/*.py"]) + [
"generate_op_tests.py",
"cases.py",
],
base_module = "executorch.backends.vulkan.test.op_tests",
deps = [
"//caffe2/torchgen:torchgen",
"fbsource//third-party/pypi/expecttest:expecttest",
],
)

runtime.python_binary(
name = "generate_op_tests",
main_module = "executorch.backends.vulkan.test.op_tests.generate_op_tests",
deps = [
":generate_op_tests_lib",
],
)

aten_src_path = runtime.external_dep_location("aten-src-path")
genrule_cmd = [
"$(exe :generate_op_tests)",
"--tags-path $(location {})/aten/src/ATen/native/tags.yaml".format(aten_src_path),
"--aten-yaml-path $(location {})/aten/src/ATen/native/native_functions.yaml".format(aten_src_path),
"-o $OUT",
]

runtime.genrule(
name = "generated_op_tests_cpp",
outs = {
"op_tests.cpp": ["op_tests.cpp"],
},
cmd = " ".join(genrule_cmd),
default_outs = ["."],
)

runtime.cxx_binary(
name = "compute_graph_op_tests_bin",
srcs = [
":generated_op_tests_cpp[op_tests.cpp]",
],
define_static_target = False,
deps = [
"//third-party/googletest:gtest_main",
"//executorch/backends/vulkan:vulkan_graph_runtime",
runtime.external_dep_location("libtorch"),
],
)

runtime.cxx_test(
name = "compute_graph_op_tests",
srcs = [
":generated_op_tests_cpp[op_tests.cpp]",
],
contacts = ["[email protected]"],
fbandroid_additional_loaded_sonames = [
"torch-code-gen",
"vulkan_graph_runtime",
"vulkan_graph_runtime_shaderlib",
],
platforms = [ANDROID],
use_instrumentation_test = True,
deps = [
"//third-party/googletest:gtest_main",
"//executorch/backends/vulkan:vulkan_graph_runtime",
runtime.external_dep_location("libtorch"),
],
)
Loading

0 comments on commit d4b3e5c

Please sign in to comment.