From eff28d61c9c961e6c2724b78f57441ee2e3e40cb Mon Sep 17 00:00:00 2001 From: John Clow Date: Wed, 17 Aug 2022 15:07:57 -0700 Subject: [PATCH] [JIT SSA] Allow updating shape functions without recompilation (#83629) In order to avoid extra round trips, and avoid confusion in places such as this to manually pull in the latest copy of the shape_functions.py file This also fixes the cases where people pull in the wrong version of the file. This can happen in cases such as when developers run `python setup.py install` instead of `python setup.py develop` to generate their current copy of Pytorch. Pull Request resolved: https://github.com/pytorch/pytorch/pull/83629 Approved by: https://github.com/davidberard98 --- .../gen_jit_shape_functions.py | 30 ++++++++++++++++--- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/torchgen/shape_functions/gen_jit_shape_functions.py b/torchgen/shape_functions/gen_jit_shape_functions.py index 5cffce838f89d..c6336a69518b2 100644 --- a/torchgen/shape_functions/gen_jit_shape_functions.py +++ b/torchgen/shape_functions/gen_jit_shape_functions.py @@ -1,12 +1,34 @@ #!/usr/bin/env python3 +import importlib.util import os +import sys from itertools import chain from pathlib import Path -from torch.jit._shape_functions import ( - bounded_compute_graph_mapping, - shape_compute_graph_mapping, -) + +# Manually importing the shape function module based on current directory +# instead of torch imports to avoid needing to recompile Pytorch before +# running the script + +file_path = Path.cwd() / "torch" / "jit" / "_shape_functions.py" +module_name = "torch.jit._shape_functions" + +err_msg = """Could not find shape functions file, please make sure +you are in the root directory of the Pytorch git repo""" +if not file_path.exists(): + raise Exception(err_msg) + +spec = importlib.util.spec_from_file_location(module_name, file_path) +assert spec is not None +module = importlib.util.module_from_spec(spec) +sys.modules[module_name] = module +assert spec.loader is not None +assert module is not None +spec.loader.exec_module(module) + +bounded_compute_graph_mapping = module.bounded_compute_graph_mapping +shape_compute_graph_mapping = module.shape_compute_graph_mapping + SHAPE_HEADER = r""" /**