diff --git a/torchgen/shape_functions/gen_jit_shape_functions.py b/torchgen/shape_functions/gen_jit_shape_functions.py index 5cffce838f89d8..c6336a69518b2b 100644 --- a/torchgen/shape_functions/gen_jit_shape_functions.py +++ b/torchgen/shape_functions/gen_jit_shape_functions.py @@ -1,12 +1,34 @@ #!/usr/bin/env python3 +import importlib.util import os +import sys from itertools import chain from pathlib import Path -from torch.jit._shape_functions import ( - bounded_compute_graph_mapping, - shape_compute_graph_mapping, -) + +# Manually importing the shape function module based on current directory +# instead of torch imports to avoid needing to recompile Pytorch before +# running the script + +file_path = Path.cwd() / "torch" / "jit" / "_shape_functions.py" +module_name = "torch.jit._shape_functions" + +err_msg = """Could not find shape functions file, please make sure +you are in the root directory of the Pytorch git repo""" +if not file_path.exists(): + raise Exception(err_msg) + +spec = importlib.util.spec_from_file_location(module_name, file_path) +assert spec is not None +module = importlib.util.module_from_spec(spec) +sys.modules[module_name] = module +assert spec.loader is not None +assert module is not None +spec.loader.exec_module(module) + +bounded_compute_graph_mapping = module.bounded_compute_graph_mapping +shape_compute_graph_mapping = module.shape_compute_graph_mapping + SHAPE_HEADER = r""" /**