diff --git a/torchgen/gen_executorch.py b/torchgen/gen_executorch.py index 6f5df46944f0f..8fc7f594b8689 100644 --- a/torchgen/gen_executorch.py +++ b/torchgen/gen_executorch.py @@ -426,6 +426,10 @@ def gen_headers( backend_indices=backend_indices, native_function_decl_gen=dest.compute_native_function_declaration, ), + "headers": [ + "#include ", + "#include ", + ], }, ) aten_headers.append('#include "CustomOpsNativeFunctions.h"') @@ -444,16 +448,26 @@ def gen_headers( ), }, ) + headers = { + "headers": [ + "#include // at::Tensor etc.", + "#include // TORCH_API", + "#include ", + ], + } if use_aten_lib: cpu_fm.write( "NativeFunctions.h", - lambda: { - "nativeFunctions_declarations": get_native_function_declarations( - grouped_native_functions=native_functions, - backend_indices=backend_indices, - native_function_decl_gen=dest.compute_native_function_declaration, - ), - }, + lambda: dict( + { + "nativeFunctions_declarations": get_native_function_declarations( + grouped_native_functions=native_functions, + backend_indices=backend_indices, + native_function_decl_gen=dest.compute_native_function_declaration, + ), + }, + **headers, + ), ) else: ns_grouped_kernels = get_ns_grouped_kernels( @@ -463,11 +477,14 @@ def gen_headers( ) cpu_fm.write( "NativeFunctions.h", - lambda: { - "nativeFunctions_declarations": get_native_function_declarations_from_ns_grouped_kernels( - ns_grouped_kernels=ns_grouped_kernels, - ), - }, + lambda: dict( + { + "nativeFunctions_declarations": get_native_function_declarations_from_ns_grouped_kernels( + ns_grouped_kernels=ns_grouped_kernels, + ), + }, + **headers, + ), )