Skip to content

Commit

Permalink
[torchgen] Generate wrapper functions under custom namespaces (pytorc…
Browse files Browse the repository at this point in the history
…h#81744)

Summary:
A follow up of pytorch#81581. Before these 2 PRs, if an operator with custom kernel namespace is added to `native_functions.yaml` (or any other yaml consumed by `torchgen`), although we are able to recognize the custom kernel in files such as `NativeFunctions.h` and `RegisterCPU.cpp`, we still generate backend specific wrappers under the hardcoded `at` namespace. This changes the behavior, by generating wrapper functions under custom namespaces.

For example, if the entries in yaml file looks like:

```
 - func: op_1(Tensor(a) self) -> Tensor(a)
  dispatch:
    CPU: at::op_1_kernel # ATen kernel

- func: op_2(Tensor(a) self) -> Tensor(a)
  dispatch:
    CPU: custom::op_2_kernel # custom kernel
```

We generate the following code for `CPUFunctions_inl.h` and `RegisterCPU.cpp`:

`CPUFunctions_inl.h`:
```
namespace at {
namespace cpu {
TORCH_API at::Tensor & op_1(const at::Tensor & self);
} // namespace cpu
} // namespace at

namespace custom {
namespace cpu {
TORCH_API at::Tensor & op_2(const at::Tensor & self);
} // namespace cpu
} // namespace custom

```

Notice the difference between `at::cpu` and `custom::cpu`.

Then the definition for these can be found in `RegisterCPU.cpp`.

`RegisterCPU.cpp`:
```
#include "CPUFunctions.h"

namespace at {

namespace {
at::Tensor & wrapper_op_1(const at::Tensor & self) {
    // No device check
  // DeviceGuard omitted
  return at::native::op_1_kernel(self);
}
} // anonymous namespace

TORCH_LIBRARY_IMPL(aten, CPU, m) {
m.impl("op_1", TORCH_FN(wrapper_op_1));
}

namespace cpu {
at::Tensor & op_1(at::Tensor & self) {
  return wrapper_op_1(self);
}
} // namespace cpu
} // namespace at

namespace custom {

namespace {
at::Tensor & wrapper_op_2(const at::Tensor & self) {
    // No device check
  // DeviceGuard omitted
  return at::native::op_2_kernel(self);
}
} // anonymous namespace

TORCH_LIBRARY_IMPL(aten, CPU, m) {
m.impl("op_2", TORCH_FN(wrapper_op_2));
}

namespace cpu {
at::Tensor & op_2(at::Tensor & self) {
  return wrapper_op_2(self);
}
} // namespace cpu
} // namespace custom

```

The benefit for this change is that it unifies all the namespaces derived from custom ops. In the example above, there are:

1. `custom::native` for kernels
2. `custom::<dispatch_key>` e.g., `custom::cpu` for wrappers

This customized operator will have nothing to do with `at::native`, `at::cpu` etc.

Test Plan: This is very hard to test. I will refactor this logic, abstract out some layers so it's testable. Will do it in coming PRs

Differential Revision: D37972772

Pull Request resolved: pytorch#81744
Approved by: https://github.com/bdhirsh
  • Loading branch information
larryliu0820 authored and pytorchmergebot committed Aug 4, 2022
1 parent cda8635 commit 406ce69
Show file tree
Hide file tree
Showing 8 changed files with 259 additions and 128 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1877,6 +1877,7 @@ test_suite(
"aten/src/ATen/templates/LazyIr.h",
"aten/src/ATen/templates/LazyNonNativeIr.h",
"aten/src/ATen/templates/RegisterDispatchKey.cpp",
"aten/src/ATen/templates/RegisterDispatchDefinitions.ini",
"aten/src/ATen/native/native_functions.yaml",
"aten/src/ATen/native/tags.yaml",
"aten/src/ATen/native/ts_native_functions.yaml",
Expand Down
5 changes: 0 additions & 5 deletions aten/src/ATen/templates/DispatchKeyFunctions_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,5 @@

${DispatchKeyFunctions_inl_includes}

namespace at {
namespace ${dispatch_namespace} {

${dispatch_namespaced_declarations}

} // namespace ${dispatch_namespace}
} // namespace at
24 changes: 24 additions & 0 deletions aten/src/ATen/templates/RegisterDispatchDefinitions.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
${ns_prologue}

// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid
// ambiguity with conflicting identifiers that may have been defined in
// at namespace already.
namespace {

${dispatch_helpers}

${dispatch_anonymous_definitions}

${static_init_dispatch_registrations}

} // anonymous namespace

${deferred_dispatch_registrations}

namespace ${dispatch_namespace} {

${dispatch_namespaced_definitions}

} // namespace ${dispatch_namespace}

${ns_epilogue}
27 changes: 2 additions & 25 deletions aten/src/ATen/templates/RegisterDispatchKey.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,28 +50,5 @@
$dispatch_headers
$ops_headers


namespace at {

// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid
// ambiguity with conflicting identifiers that may have been defined in
// at namespace already.
namespace {

${dispatch_helpers}

${dispatch_anonymous_definitions}

${static_init_dispatch_registrations}

} // anonymous namespace

${deferred_dispatch_registrations}

namespace ${dispatch_namespace} {

${dispatch_namespaced_definitions}

} // namespace ${dispatch_namespace}

} // namespace at
// See template file RegisterDispatchDefinitions.ini
$dispatch_definitions
1 change: 1 addition & 0 deletions build.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def define_targets(rules):
":LazyIr.h",
":LazyNonNativeIr.h",
":RegisterDispatchKey.cpp",
":RegisterDispatchDefinitions.ini",
":native_functions.yaml",
":shape_inference.h",
":tags.yaml",
Expand Down
250 changes: 181 additions & 69 deletions torchgen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torchgen.api.native as native
import torchgen.api.structured as structured
import torchgen.dest as dest

from torchgen.api import cpp
from torchgen.api.translate import translate
from torchgen.api.types import (
Expand Down Expand Up @@ -1408,6 +1409,168 @@ def get_native_function_declarations(
return declarations


def get_kernel_namespace(
*, f: Union[NativeFunction, NativeFunctionsGroup], backend_idx: BackendIndex
) -> str:
backend_metadata = backend_idx.get_kernel(f)
assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, (
f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} "
f"with dispatch key {backend_idx.dispatch_key}"
f" has a namespace {backend_metadata.cpp_namespace} and it's not ending with '::native'."
)
return (
backend_metadata.cpp_namespace if backend_metadata else DEFAULT_KERNEL_NAMESPACE
)


# Return native function definitions grouped by dispatch key and custom namespace.
# Used in RegisterDispatchKey.cpp and etc.
def get_native_function_definitions(
*,
fm: FileManager,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
dispatch_key: DispatchKey,
backend_idx: BackendIndex,
selector: SelectiveBuilder,
rocm: bool,
skip_dispatcher_op_registration: bool,
gen_dispatch_helpers: bool,
) -> List[str]:
definitions: List[str] = []
ns_definitions: Dict[str, List[str]] = defaultdict(list)
anonymous_definitions: Dict[str, List[str]] = defaultdict(list)
registrations: Dict[str, Dict[str, List[str]]] = defaultdict(dict)
newline = "\n"
ns_gen = dest.RegisterDispatchKey(
backend_idx,
Target.NAMESPACED_DEFINITION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
)
anonymous_gen = dest.RegisterDispatchKey(
backend_idx,
Target.ANONYMOUS_DEFINITION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
)
reg_gen = dest.RegisterDispatchKey(
backend_idx,
Target.REGISTRATION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
)
for f in grouped_native_functions:
kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
"::native", ""
)

ns_definitions[kernel_namespace].extend(
ns_gen(f),
)
anonymous_definitions[kernel_namespace].extend(
anonymous_gen(f),
)
namespace = (
f.namespace if isinstance(f, NativeFunction) else f.functional.namespace
)
if namespace not in registrations[kernel_namespace]:
registrations[kernel_namespace] = defaultdict(list)
registrations[kernel_namespace][namespace].extend(
reg_gen(f),
)

for kernel_namespace in ns_definitions:
if len(ns_definitions[kernel_namespace]) == 0:
continue
ns_helper = NamespaceHelper(namespace_str=kernel_namespace)
registration_body = ""
for namespace in registrations[kernel_namespace]:
if not registrations[kernel_namespace][namespace]:
continue
registration_body += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{newline.join(registrations[kernel_namespace][namespace])}
}};"""
definitions.extend(
fm.substitute_with_template(
"RegisterDispatchDefinitions.ini",
lambda: {
"ns_prologue": ns_helper.prologue,
"ns_epilogue": ns_helper.epilogue,
"dispatch_helpers": dest.gen_registration_helpers(backend_idx)
if gen_dispatch_helpers
else [],
"dispatch_anonymous_definitions": anonymous_definitions[
kernel_namespace
],
"static_init_dispatch_registrations": ""
if skip_dispatcher_op_registration
else registration_body,
"deferred_dispatch_registrations": "",
"dispatch_namespace": dispatch_key.lower(),
"dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
},
).split(newline)
)

return definitions


# Return native function declarations grouped by dispatch key and custom namespace.
# Used in CPUFunctions_inl.h and etc.
def get_namespaced_declaration(
*,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
dispatch_key: DispatchKey,
backend_idx: BackendIndex,
selector: SelectiveBuilder,
rocm: bool,
) -> List[str]:
declarations: List[str] = []
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
newline = "\n"
func = dest.RegisterDispatchKey(
backend_idx,
Target.NAMESPACED_DECLARATION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=False,
)
for f in grouped_native_functions:
namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
"native", dispatch_key.lower()
)

ns_grouped_kernels[namespace].extend(
func(f),
)

for namespace, kernels in ns_grouped_kernels.items():
if len(kernels) == 0:
continue
ns_helper = NamespaceHelper(
namespace_str=namespace, entity_name="", max_level=3
)
ordered_kernels = list(OrderedDict.fromkeys(kernels))
declarations.extend(
f"""
{ns_helper.prologue}
{newline.join(ordered_kernels)}
{ns_helper.epilogue}
""".split(
newline
)
)
return declarations


# Return native function schema registration code for aten and other namespaces.
def get_native_function_schema_registrations(
*,
Expand Down Expand Up @@ -1550,18 +1713,12 @@ def gen_aggregated_headers(
lambda: {
"DispatchKeyFunctions_inl_includes": [],
"dispatch_namespace": dispatch_key.lower(),
"dispatch_namespaced_declarations": list(
concatMap(
dest.RegisterDispatchKey(
backend_indices[dispatch_key],
Target.NAMESPACED_DECLARATION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=False,
),
grouped_native_functions,
)
"dispatch_namespaced_declarations": get_namespaced_declaration(
grouped_native_functions=grouped_native_functions,
dispatch_key=dispatch_key,
backend_idx=backend_indices[dispatch_key],
selector=selector,
rocm=rocm,
),
},
)
Expand Down Expand Up @@ -1998,33 +2155,17 @@ def operator_headers() -> List[str]:
)
ns_grouped_native_functions[namespace].append(grouped_native_function)

static_init_dispatch_registrations = ""
for namespace, functions in ns_grouped_native_functions.items():
dispatch_registrations_body = (
""
if skip_dispatcher_op_registration
else "\n".join(
list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.REGISTRATION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
),
functions,
)
)
)
)

static_init_dispatch_registrations += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{dispatch_registrations_body}
}};"""
dispatch_namespace = str(dispatch_key).lower()
dispatch_definitions = get_native_function_definitions(
fm=fm,
grouped_native_functions=grouped_native_functions,
dispatch_key=dispatch_key,
backend_idx=backend_index,
selector=selector,
rocm=rocm,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
gen_dispatch_helpers=True,
)
fm.write_with_template(
f"Register{dispatch_key}.cpp",
"RegisterDispatchKey.cpp",
Expand All @@ -2037,37 +2178,8 @@ def operator_headers() -> List[str]:
backend_index, per_operator_headers, rocm
),
"ops_headers": operator_headers(),
"DispatchKey": dispatch_key,
"dispatch_namespace": dispatch_key.lower(),
"dispatch_helpers": dest.gen_registration_helpers(backend_index),
"dispatch_namespaced_definitions": list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.NAMESPACED_DEFINITION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
),
grouped_native_functions,
)
),
"dispatch_anonymous_definitions": list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.ANONYMOUS_DEFINITION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
),
grouped_native_functions,
)
),
"static_init_dispatch_registrations": static_init_dispatch_registrations,
"deferred_dispatch_registrations": "",
"dispatch_helpers": "",
"dispatch_definitions": dispatch_definitions,
},
)

Expand Down
Loading

0 comments on commit 406ce69

Please sign in to comment.