Skip to content

Commit

Permalink
[executorch] Let custom ops registration code only import ATen headers (
Browse files Browse the repository at this point in the history
pytorch#107064)

Summary: Basically we generate `CustomOpsNativeFunctions.h` for registering custom ops into PyTorch JIT runtime. This header needs to hookup with the C++ kernel implementation of all the custom ops. For this reason it should include ATen headers instead of Executorch headers. This PR changes it.

Test Plan: Rely on existing CI jobs

Differential Revision: D48282828

Pull Request resolved: pytorch#107064
Approved by: https://github.com/kirklandsign
  • Loading branch information
larryliu0820 authored and pytorchmergebot committed Aug 13, 2023
1 parent f26aa2d commit ddd2f68
Showing 1 changed file with 29 additions and 12 deletions.
41 changes: 29 additions & 12 deletions torchgen/gen_executorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,10 @@ def gen_headers(
backend_indices=backend_indices,
native_function_decl_gen=dest.compute_native_function_declaration,
),
"headers": [
"#include <ATen/ATen.h>",
"#include <torch/torch.h>",
],
},
)
aten_headers.append('#include "CustomOpsNativeFunctions.h"')
Expand All @@ -444,16 +448,26 @@ def gen_headers(
),
},
)
headers = {
"headers": [
"#include <executorch/runtime/core/exec_aten/exec_aten.h> // at::Tensor etc.",
"#include <executorch/codegen/macros.h> // TORCH_API",
"#include <executorch/runtime/kernel/kernel_runtime_context.h>",
],
}
if use_aten_lib:
cpu_fm.write(
"NativeFunctions.h",
lambda: {
"nativeFunctions_declarations": get_native_function_declarations(
grouped_native_functions=native_functions,
backend_indices=backend_indices,
native_function_decl_gen=dest.compute_native_function_declaration,
),
},
lambda: dict(
{
"nativeFunctions_declarations": get_native_function_declarations(
grouped_native_functions=native_functions,
backend_indices=backend_indices,
native_function_decl_gen=dest.compute_native_function_declaration,
),
},
**headers,
),
)
else:
ns_grouped_kernels = get_ns_grouped_kernels(
Expand All @@ -463,11 +477,14 @@ def gen_headers(
)
cpu_fm.write(
"NativeFunctions.h",
lambda: {
"nativeFunctions_declarations": get_native_function_declarations_from_ns_grouped_kernels(
ns_grouped_kernels=ns_grouped_kernels,
),
},
lambda: dict(
{
"nativeFunctions_declarations": get_native_function_declarations_from_ns_grouped_kernels(
ns_grouped_kernels=ns_grouped_kernels,
),
},
**headers,
),
)


Expand Down

0 comments on commit ddd2f68

Please sign in to comment.