[Specialized Kernel] Propagate Specialized Kernel Support through ComputeCodegenUnboxedKernels (pytorch#103113)

Updates ComputeCodegenUnboxedKernels to accept kernel information and write it out to RegisterCodegenUnboxedKernels.cpp.

Differential Revision: [D46486195](https://our.internmc.facebook.com/intern/diff/D46486195/)
Pull Request resolved: pytorch#103113
Approved by: https://github.com/larryliu0820, https://github.com/kirklandsign
Jack-Khuu authored and pytorchmergebot committed Jun 14, 2023
1 parent e3ee5b0 commit e9674d1
Showing 9 changed files with 217 additions and 75 deletions.
40 changes: 20 additions & 20 deletions test/edge/operator_registry.cpp
@@ -4,40 +4,40 @@
 namespace torch {
 namespace executor {
 
-OperatorRegistry& getOperatorRegistry() {
-  static OperatorRegistry operator_registry;
-  return operator_registry;
+KernelRegistry& getKernelRegistry() {
+  static KernelRegistry kernel_registry;
+  return kernel_registry;
 }
 
-bool register_operators(const ArrayRef<Operator>& operators) {
-  return getOperatorRegistry().register_operators(operators);
+bool register_kernels(const ArrayRef<Kernel>& kernels) {
+  return getKernelRegistry().register_kernels(kernels);
 }
 
-bool OperatorRegistry::register_operators(
-    const ArrayRef<Operator>& operators) {
-  for (const auto& op : operators) {
-    this->operators_map_[op.name_] = op.op_;
+bool KernelRegistry::register_kernels(
+    const ArrayRef<Kernel>& kernels) {
+  for (const auto& kernel : kernels) {
+    this->kernels_map_[kernel.name_] = kernel.kernel_;
   }
   return true;
 }
 
-bool hasOpsFn(const char* name) {
-  return getOperatorRegistry().hasOpsFn(name);
+bool hasKernelFn(const char* name) {
+  return getKernelRegistry().hasKernelFn(name);
 }
 
-bool OperatorRegistry::hasOpsFn(const char* name) {
-  auto op = this->operators_map_.find(name);
-  return op != this->operators_map_.end();
+bool KernelRegistry::hasKernelFn(const char* name) {
+  auto kernel = this->kernels_map_.find(name);
+  return kernel != this->kernels_map_.end();
 }
 
-OpFunction& getOpsFn(const char* name) {
-  return getOperatorRegistry().getOpsFn(name);
+KernelFunction& getKernelFn(const char* name) {
+  return getKernelRegistry().getKernelFn(name);
 }
 
-OpFunction& OperatorRegistry::getOpsFn(const char* name) {
-  auto op = this->operators_map_.find(name);
-  TORCH_CHECK_MSG(op != this->operators_map_.end(), "Operator not found!");
-  return op->second;
+KernelFunction& KernelRegistry::getKernelFn(const char* name) {
+  auto kernel = this->kernels_map_.find(name);
+  TORCH_CHECK_MSG(kernel != this->kernels_map_.end(), "Kernel not found!");
+  return kernel->second;
 }
44 changes: 22 additions & 22 deletions test/edge/operator_registry.h
@@ -11,60 +11,60 @@
 namespace torch {
 namespace executor {
 
-using OpFunction = std::function<void(RuntimeContext&, EValue**)>;
+using KernelFunction = std::function<void(RuntimeContext&, EValue**)>;
 
 template<typename T>
 using ArrayRef = at::ArrayRef<T>;
 
 #define EXECUTORCH_SCOPE_PROF(x)
 
-struct Operator {
+struct Kernel {
   const char* name_;
-  OpFunction op_;
+  KernelFunction kernel_;
 
-  Operator() = default;
+  Kernel() = default;
 
   /**
    * We are doing a copy of the string pointer instead of duplicating the string
-   * itself, we require the lifetime of the operator name to be at least as long
-   * as the operator registry.
+   * itself, we require the lifetime of the kernel name to be at least as long
+   * as the kernel registry.
    */
-  explicit Operator(const char* name, OpFunction func)
-      : name_(name), op_(func) {}
+  explicit Kernel(const char* name, KernelFunction func)
+      : name_(name), kernel_(func) {}
 };
 
 /**
- * See OperatorRegistry::hasOpsFn()
+ * See KernelRegistry::hasKernelFn()
  */
-bool hasOpsFn(const char* name);
+bool hasKernelFn(const char* name);
 
 /**
- * See OperatorRegistry::getOpsFn()
+ * See KernelRegistry::getKernelFn()
 */
-OpFunction& getOpsFn(const char* name);
+KernelFunction& getKernelFn(const char* name);
 
 
-[[nodiscard]] bool register_operators(const ArrayRef<Operator>&);
+[[nodiscard]] bool register_kernels(const ArrayRef<Kernel>&);
 
-struct OperatorRegistry {
+struct KernelRegistry {
  public:
-  OperatorRegistry() : operatorRegSize_(0) {}
+  KernelRegistry() : kernelRegSize_(0) {}
 
-  bool register_operators(const ArrayRef<Operator>&);
+  bool register_kernels(const ArrayRef<Kernel>&);
 
   /**
-   * Checks whether an operator with a given name is registered
+   * Checks whether an kernel with a given name is registered
   */
-  bool hasOpsFn(const char* name);
+  bool hasKernelFn(const char* name);
 
   /**
-   * Checks whether an operator with a given name is registered
+   * Checks whether an kernel with a given name is registered
   */
-  OpFunction& getOpsFn(const char* name);
+  KernelFunction& getKernelFn(const char* name);
 
  private:
-  std::map<const char*, OpFunction> operators_map_;
-  uint32_t operatorRegSize_;
+  std::map<const char*, KernelFunction> kernels_map_;
+  uint32_t kernelRegSize_;
 };
 
 } // namespace executor
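For orientation, below is a minimal sketch (not part of this commit) of how the renamed API is used from C++. The kernel name "demo::add.out", the no-op kernel body, and the variable names are hypothetical placeholders; only the types and functions come from the header above.

#include <operator_registry.h>

namespace torch {
namespace executor {
namespace {

// Stand-in kernel body; a generated kernel would unbox the EValue stack and
// call the native implementation.
void demo_kernel(RuntimeContext& /*context*/, EValue** /*stack*/) {}

// Kernel stores only the name pointer, so the string literal must outlive the
// registry (see the comment on the Kernel constructor above).
static Kernel demo_kernels[] = {
    Kernel("demo::add.out", &demo_kernel),
};

// Registration runs at static-initialization time, matching the pattern used
// by the generated RegisterCodegenUnboxedKernels.cpp shown below.
static auto demo_registered =
    register_kernels(ArrayRef<Kernel>(demo_kernels, demo_kernels + 1));

} // namespace
} // namespace executor
} // namespace torch

// Lookup then mirrors the updated tests:
//   hasKernelFn("demo::add.out") and getKernelFn("demo::add.out").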
18 changes: 9 additions & 9 deletions test/edge/templates/RegisterCodegenUnboxedKernels.cpp
@@ -1,25 +1,25 @@
 #include <operator_registry.h>
-#include "Functions.h"
+#include "${fn_header}" // Generated Function import headers
 
 namespace torch {
 namespace executor {
 
 namespace {
-using OpArrayRef = ::at::ArrayRef<::torch::executor::Operator>;
+using KernelArrayRef = ::at::ArrayRef<::torch::executor::Kernel>;
 
-static Operator operators_to_register[] = {
-    ${unboxed_ops} // Generated operators
+static Kernel kernels_to_register[] = {
+    ${unboxed_kernels} // Generated operators
 };
 
 // Explicitly convert to ArrayRef, so that the API can take an empty C array of
-// Operators.
-static OpArrayRef op_array_ref(
-    operators_to_register,
-    operators_to_register + sizeof(operators_to_register) / sizeof(Operator));
+// Kernels.
+static KernelArrayRef kernel_array_ref(
+    kernels_to_register,
+    kernels_to_register + sizeof(kernels_to_register) / sizeof(Kernel));
 
 // Return value not used. Keep the static variable assignment to register
 // operators in static initialization time.
-static auto success_with_op_reg = register_operators(op_array_ref);
+static auto success_with_kernel_reg = register_kernels(kernel_array_ref);
 } // namespace
 } // namespace executor
 } // namespace torch
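After template substitution, each entry in kernels_to_register is a Kernel initializer. A sketch of one generated entry, assembled from the expected strings in the new tests in tools/test/test_executorch_gen.py further down; the operator custom_1::op_1, the kernel key, and the indentation are the test's illustrative values, not a real operator:

Kernel(
    "custom_1::op_1",
    "v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3",  // kernel key, emitted only for specialized kernels
    [](torch::executor::RuntimeContext & context, EValue** stack) {
        EXECUTORCH_SCOPE_PROF("native_call_op_1");
        bool result_ = at::native::default_kernel(context, );
        *stack[0] = EValue(result_);
    }
),

The key string encodes each argument's dtype and dim order, here Double (ScalarType 7) with dim order 0,1,2,3 for all three arguments, per the test's type and dim-order aliases; default (non-specialized) kernels omit the key entirely. The dangling "context, " argument list mirrors the test's zero-argument operator in the expected output.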
8 changes: 4 additions & 4 deletions test/edge/test_operator_registration.cpp
@@ -11,8 +11,8 @@ TEST(OperatorRegistrationTest, Add) {
   values[1] = EValue(at::ones({2, 3}));
   values[2] = EValue(int64_t(1));
   values[3] = EValue(at::zeros({2, 3}));
-  ASSERT_TRUE(hasOpsFn("aten::add.out"));
-  auto op = getOpsFn("aten::add.out");
+  ASSERT_TRUE(hasKernelFn("aten::add.out"));
+  auto op = getKernelFn("aten::add.out");
 
   EValue* kernel_values[4];
   for (size_t i = 0; i < 4; i++) {
@@ -33,8 +33,8 @@ TEST(OperatorRegistrationTest, CustomAdd3) {
   values[1] = EValue(at::ones({2, 3}));
   values[2] = EValue(at::ones({2, 3}));
   values[3] = EValue(at::zeros({2, 3}));
-  ASSERT_TRUE(hasOpsFn("custom::add_3.out"));
-  auto op = getOpsFn("custom::add_3.out");
+  ASSERT_TRUE(hasKernelFn("custom::add_3.out"));
+  auto op = getKernelFn("custom::add_3.out");
 
   EValue* kernel_values[4];
   for (size_t i = 0; i < 4; i++) {
89 changes: 87 additions & 2 deletions tools/test/test_executorch_gen.py
@@ -5,10 +5,11 @@
 
 import yaml
 
-from torchgen.executorch.model import ETKernelIndex
+from torchgen.executorch.model import ETKernelIndex, ETKernelKey
 from torchgen.gen import LineLoader
 
 from torchgen.gen_executorch import (
+    ComputeCodegenUnboxedKernels,
     gen_functions_declarations,
     parse_yaml_files,
     translate_native_yaml,
@@ -397,7 +398,6 @@ def test_aten_lib_has_context_arg(self) -> None:
             selector=SelectiveBuilder.get_nop_selector(),
             use_aten_lib=True,
         )
-        print(declarations)
         self.assertTrue(
             """
 namespace custom_1 {
@@ -411,3 +411,88 @@ def test_aten_lib_has_context_arg(self) -> None:
 """
             in declarations
         )
+
+
+class TestComputeCodegenUnboxedKernels(unittest.TestCase):
+    def setUp(self) -> None:
+        (
+            self.native_function_no_kern,
+            _,
+        ) = NativeFunction.from_yaml(
+            {
+                "func": "custom_1::op_1() -> bool",
+                "dispatch": {"CPU": "unused_kernel_1"},
+            },
+            loc=Location(__file__, 1),
+            valid_tags=set(),
+        )
+
+        self.default_kernel_key = ETKernelKey(default=True)
+        self.default_backend_metadata = BackendMetadata(
+            "default_kernel", False, "at::native"
+        )
+        self.default_kernel_entry = (
+            [self.default_kernel_key],
+            self.default_backend_metadata,
+        )
+
+    def test_codegen_unboxed_specialized(self) -> None:
+        specialized_kernel_key = ETKernelKey.gen_from_yaml(
+            {"self": ("T0", "D0"), "other": ("T0", "D0"), "out": ("T0", "D0")},
+            {"T0": ["Double"]},
+            {"D0": [0, 1, 2, 3]},
+        )
+        selector = SelectiveBuilder.get_nop_selector()
+        use_aten_lib = False
+        entry = (
+            self.native_function_no_kern,
+            (specialized_kernel_key, self.default_backend_metadata),
+        )
+
+        result = ComputeCodegenUnboxedKernels(selector, use_aten_lib)(entry)
+        # Concat used to prevent whitespace stripping
+        expected_str = (
+            """
+Kernel(
+"custom_1::op_1",
+"v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3",
+[](torch::executor::RuntimeContext & context, EValue** stack) {
+"""
+            + """
+EXECUTORCH_SCOPE_PROF("native_call_op_1");
+bool result_ = at::native::default_kernel(context, );
+*stack[0] = EValue(result_);
+}
+),
+"""
+        )
+
+        self.assertEqual(expected_str, result)
+
+    def test_codegen_unboxed_default(self) -> None:
+        selector = SelectiveBuilder.get_nop_selector()
+        use_aten_lib = False
+        entry = (self.native_function_no_kern, self.default_kernel_entry)
+
+        result = ComputeCodegenUnboxedKernels(selector, use_aten_lib)(entry)
+        # Concat used to prevent whitespace stripping
+        expected_str = (
+            """
+Kernel(
+"custom_1::op_1",
+[](torch::executor::RuntimeContext & context, EValue** stack) {
+"""
+            + """
+EXECUTORCH_SCOPE_PROF("native_call_op_1");
+bool result_ = at::native::default_kernel(context, );
+*stack[0] = EValue(result_);
+}
+),
+"""
+        )
+
+        self.assertEqual(expected_str, result)
15 changes: 14 additions & 1 deletion torchgen/context.py
@@ -1,7 +1,7 @@
 import contextlib
 
 import functools
-from typing import Callable, Dict, Iterator, Optional, TypeVar, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union
 
 import torchgen.local as local
 from torchgen.model import (
@@ -33,6 +33,8 @@
     str,
 )
 
+F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction])
+
 
 @contextlib.contextmanager
 def native_function_manager(
@@ -90,6 +92,17 @@ def wrapper(slf: S, f: F) -> T:
     return wrapper
 
 
+def method_with_nested_native_function(
+    func: Callable[[S, F3], T]
+) -> Callable[[S, F3], T]:
+    @functools.wraps(func)
+    def wrapper(slf: S, f: F3) -> T:
+        with native_function_manager(f[0]):
+            return func(slf, f)
+
+    return wrapper
+
+
 # Convenience decorator for functions that explicitly take in a BackendIndex,
 # instead of indirectly taking one in as a closure
 def with_native_function_and_index(
7 changes: 5 additions & 2 deletions torchgen/executorch/model.py
@@ -62,7 +62,7 @@ class ETKernelKey:
     def gen_from_yaml(
         args: Dict[str, Tuple[str, str]],
         type_alias_map: Dict[str, List[str]],  # TODO: Support unwrapped str val
-        dim_order_alias_map: Dict[str, List[str]],
+        dim_order_alias_map: Dict[str, List[int]],
     ) -> List["ETKernelKey"]:
         """Generate ETKernelKeys from arg kernel specs
         Multiple ETKernelKeys are returned due to dtype permutations from utilizing
@@ -194,7 +194,10 @@ def _to_backend_index(self) -> BackendIndex:
             assert (
                 len(kernel_dict.values()) == 1
             ), f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}"
-            index[op] = kernel_dict[ETKernelKey(default=True)]
+            index[op] = kernel_dict.get(
+                ETKernelKey(default=True),
+                BackendMetadata(kernel="", structured=False, cpp_namespace=""),
+            )
         return BackendIndex(
             dispatch_key=DispatchKey.CPU,
             use_out_as_primary=False,
2 changes: 1 addition & 1 deletion torchgen/executorch/parse.py
@@ -62,7 +62,7 @@ def parse_from_yaml(ei: Dict[str, object]) -> Dict[ETKernelKey, BackendMetadata]
     kernel_keys = (
         [ETKernelKey((), default=True)]
         if arg_meta is None
-        else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias)
+        else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias)  # type: ignore[arg-type]
     )
 
     for kernel_key in kernel_keys: