From ddd2f682b974fa274771965266c2bc0786f1e747 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Sun, 13 Aug 2023 00:34:34 +0000 Subject: [PATCH] [executorch] Let custom ops registration code only import ATen headers (#107064) Summary: Basically we generate `CustomOpsNativeFunctions.h` for registering custom ops into PyTorch JIT runtime. This header needs to hookup with the C++ kernel implementation of all the custom ops. For this reason it should include ATen headers instead of Executorch headers. This PR changes it. Test Plan: Rely on existing CI jobs Differential Revision: D48282828 Pull Request resolved: https://github.com/pytorch/pytorch/pull/107064 Approved by: https://github.com/kirklandsign --- torchgen/gen_executorch.py | 41 +++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/torchgen/gen_executorch.py b/torchgen/gen_executorch.py index 6f5df46944f0f6..8fc7f594b8689f 100644 --- a/torchgen/gen_executorch.py +++ b/torchgen/gen_executorch.py @@ -426,6 +426,10 @@ def gen_headers( backend_indices=backend_indices, native_function_decl_gen=dest.compute_native_function_declaration, ), + "headers": [ + "#include ", + "#include ", + ], }, ) aten_headers.append('#include "CustomOpsNativeFunctions.h"') @@ -444,16 +448,26 @@ def gen_headers( ), }, ) + headers = { + "headers": [ + "#include // at::Tensor etc.", + "#include // TORCH_API", + "#include ", + ], + } if use_aten_lib: cpu_fm.write( "NativeFunctions.h", - lambda: { - "nativeFunctions_declarations": get_native_function_declarations( - grouped_native_functions=native_functions, - backend_indices=backend_indices, - native_function_decl_gen=dest.compute_native_function_declaration, - ), - }, + lambda: dict( + { + "nativeFunctions_declarations": get_native_function_declarations( + grouped_native_functions=native_functions, + backend_indices=backend_indices, + native_function_decl_gen=dest.compute_native_function_declaration, + ), + }, + **headers, + ), ) else: ns_grouped_kernels = get_ns_grouped_kernels( @@ -463,11 +477,14 @@ def gen_headers( ) cpu_fm.write( "NativeFunctions.h", - lambda: { - "nativeFunctions_declarations": get_native_function_declarations_from_ns_grouped_kernels( - ns_grouped_kernels=ns_grouped_kernels, - ), - }, + lambda: dict( + { + "nativeFunctions_declarations": get_native_function_declarations_from_ns_grouped_kernels( + ns_grouped_kernels=ns_grouped_kernels, + ), + }, + **headers, + ), )