Skip to content

Commit

Permalink
[JIT SSA] Allow updating shape functions without recompilation (pytor…
Browse files Browse the repository at this point in the history
…ch#83629)

In order to avoid extra round trips, and avoid confusion in places such as
this to manually pull in the latest copy of the shape_functions.py file

This also fixes the cases where people pull in the wrong version of the file. This can happen in cases such as when developers run `python setup.py install` instead of `python setup.py develop` to generate their current copy of Pytorch.
Pull Request resolved: pytorch#83629
Approved by: https://github.com/davidberard98
  • Loading branch information
Gamrix authored and pytorchmergebot committed Aug 22, 2022
1 parent 53cda90 commit eff28d6
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions torchgen/shape_functions/gen_jit_shape_functions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,34 @@
#!/usr/bin/env python3
import importlib.util
import os
import sys
from itertools import chain
from pathlib import Path

from torch.jit._shape_functions import (
bounded_compute_graph_mapping,
shape_compute_graph_mapping,
)

# Manually importing the shape function module based on current directory
# instead of torch imports to avoid needing to recompile Pytorch before
# running the script

file_path = Path.cwd() / "torch" / "jit" / "_shape_functions.py"
module_name = "torch.jit._shape_functions"

err_msg = """Could not find shape functions file, please make sure
you are in the root directory of the Pytorch git repo"""
if not file_path.exists():
raise Exception(err_msg)

spec = importlib.util.spec_from_file_location(module_name, file_path)
assert spec is not None
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
assert spec.loader is not None
assert module is not None
spec.loader.exec_module(module)

bounded_compute_graph_mapping = module.bounded_compute_graph_mapping
shape_compute_graph_mapping = module.shape_compute_graph_mapping


SHAPE_HEADER = r"""
/**
Expand Down

0 comments on commit eff28d6

Please sign in to comment.