Inductor Freezing (pytorch#100652)
Adds a freezing pass, gated by inductor's `config.freezing`, that constant-folds parameters. It runs after functionalization in AOT Autograd, both to capture dispatching and to let subsequent passes operate on the functionalized graph. A few notes:

- There is an option to discard parameters, `config.freezing_discard_parameters`, which takes the current eager module's parameters and wraps them in a Tensor subclass that errors if used (a minimal sketch of this idea follows the list).
- I needed to expose `params_flat` in aot_autograd in order to discard old references when we constant fold away parameters, e.g. with AMP. I also exposed `fw_metadata` to avoid constant folding mutated parameters.
- Caching parameter transformations/constant folding across different inference runs is not yet implemented.
- Checking the version_counter of constant-folded params is not yet implemented.
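
As a rough sketch of the "errors if used" idea from the first note above (the `ErasedParameter` name and the `as_subclass`-based construction are illustrative assumptions, not the actual implementation):

```python
import torch

class ErasedParameter(torch.Tensor):
    # Any torch operation that touches this tensor raises, matching the error
    # message asserted in test_error_on_eager below.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        raise RuntimeError("Trying to Run Pytorch Eager Module After Dynamo Freezing")

def discard_parameters(mod: torch.nn.Module) -> None:
    # Illustrative only: replace each parameter attribute with an erased
    # stand-in so that later eager use of the module errors out.
    for submod in mod.modules():
        for name in list(submod._parameters.keys()):
            del submod._parameters[name]
            setattr(submod, name, torch.empty(0).as_subclass(ErasedParameter))
```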

I'm not really sure what the actual naming should be. In jit there was both "freezing", which was platform-agnostic, and "optimize for inference", which made device-specific optimizations. We're doing the latter here, but maybe freezing is a better name.
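
For reference, this is roughly how the new flags get exercised, mirroring the tests added in this PR (a minimal sketch; only `config.freezing` and `config.freezing_discard_parameters` come from this change, the rest is illustrative):

```python
import torch
from torch._inductor import config

config.freezing = True
# Optionally error out if the eager module is used after its parameters
# have been folded away.
config.freezing_discard_parameters = True

mod = torch.nn.Sequential(
    torch.nn.Conv2d(3, 32, kernel_size=3, bias=False),
    torch.nn.BatchNorm2d(32),
).eval()

@torch.compile()
def foo(mod, x):
    return mod(x)

with torch.no_grad():
    out = foo(mod, torch.rand(3, 3, 32, 32))
```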

Differential Revision: [D46244033](https://our.internmc.facebook.com/intern/diff/D46244033)
Pull Request resolved: pytorch#100652
Approved by: https://github.com/jansel
eellison authored and pytorchmergebot committed Jun 12, 2023
1 parent 54daf87 commit d083d44
Showing 9 changed files with 591 additions and 2 deletions.
258 changes: 258 additions & 0 deletions test/inductor/test_inductor_freezing.py
@@ -0,0 +1,258 @@
# Owner(s): ["module: inductor"]
import contextlib
import functools
import importlib
import os
import sys
import unittest
import weakref

import torch

import torch._dynamo
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from torch.testing._internal.common_utils import (
    IS_CI,
    IS_WINDOWS,
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
    TestCase as TorchTestCase,
)

if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_torchinductor yet\n"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")

from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests

importlib.import_module("functorch")
importlib.import_module("filelock")

from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA

HAS_MULTIGPU = HAS_CUDA and torch.cuda.device_count() >= 2
aten = torch.ops.aten
requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")


class TestCase(TorchTestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._stack = contextlib.ExitStack()
        cls._stack.enter_context(
            config.patch(
                {
                    "debug": True,
                    "cpp.min_chunk_size": 1,
                    "triton.autotune_pointwise": False,  # too slow
                    "implicit_fallbacks": False,
                    "freezing": True,
                    "freezing_discard_parameters": True,
                }
            )
        )

    @classmethod
    def tearDownClass(cls):
        cls._stack.close()
        super().tearDownClass()

    def setUp(self):
        torch._dynamo.reset()
        super().setUp()

    def tearDown(self):
        super().tearDown()
        torch._dynamo.reset()


class ConvBN(torch.nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001, dtype=torch.float)

    def forward(self, x):
        return self.bn(self.conv(x))


class OptimizeForInferenceTemplate(TestCase):
    def test_mutation(self):
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mutated_param = torch.nn.Parameter(torch.zeros([10, 10]))

            def forward(self):
                self.mutated_param.add_(10)
                return self.mutated_param

        with torch.no_grad():
            mod = Mod().to(self.device)
            out_eager = mod()
            out_eager2 = mod()

            mod = Mod().to(self.device)

            @torch.compile
            def foo(mod):
                return mod()

            out_comp = foo(mod)
            out_comp2 = foo(mod)

            self.assertEqual(out_eager, out_comp)
            self.assertEqual(out_eager2, out_comp2)

    def test_aliased_param_return(self):
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.aliased_param = torch.nn.Parameter(torch.zeros([10, 10]))

            def forward(self):
                return self.aliased_param[1:], self.aliased_param

        mod = Mod().to(self.device).eval()

        @torch.compile()
        def foo(mod):
            return mod()

        with torch.no_grad():
            mod_eager = mod()
            self.assertEqual(foo(mod), mod_eager)

    def test_autocast(self):
        if self.device == "cpu":
            raise unittest.SkipTest("MKLDNN Bug")

        mod = torch.nn.Linear(10, 10).to(self.device).eval()
        inp = torch.rand([10, 10]).to(self.device).to(torch.half)

        @torch.compile()
        def foo(mod, inp):
            return mod(inp)

        with torch.no_grad():
            with self.autocast():
                out_eager = mod(inp)
                out_compiled, code = run_and_get_code(foo, mod, inp)

                FileCheck().check_not("@triton.jit").run(code[0])
                self.assertEqual(out_eager, out_compiled)

    def test_error_on_eager(self):
        mod = ConvBN(3, 32, kernel_size=3, stride=2).eval().to(self.device)

        x = torch.rand(3, 3, 32, 32).to(self.device)

        @torch.compile()
        def foo(mod, x):
            return mod(x)

        with torch.no_grad():
            foo(mod, x)

        with self.assertRaisesRegex(
            RuntimeError, "Trying to Run Pytorch Eager Module After Dynamo Freezing"
        ):
            mod(x)

    def test_rng_op(self):
        @torch.compile()
        def foo():
            return torch.rand([4, 4], device=self.device) + 1

        with torch.no_grad():
            o1 = foo()
            o2 = foo()
            self.assertNotEqual(o1, o2)

    def test_symint_not_folded(self):
        def fn(a):
            return a.cos(), torch.zeros(a.shape[0], a.shape[1])

        fn_opt = torch._dynamo.optimize("inductor", dynamic=True)(fn)
        inp = torch.randn(2, 4, 6).to(self.device)
        torch._dynamo.mark_dynamic(inp, 0)
        torch._dynamo.mark_dynamic(inp, 1)

        with torch.no_grad():
            self.assertEqual(fn(inp), fn_opt(inp))
            inp2 = torch.randn(3, 5, 6).to(self.device)
            torch._dynamo.mark_dynamic(inp2, 0)
            torch._dynamo.mark_dynamic(inp2, 1)
            self.assertEqual(fn(inp2), fn_opt(inp2))

    def test_param_deallocated(self):
        # TODO: cpu path keeps an extra copy of graph around somewhere,
        # memory not as important for cpu
        if self.device == "cpu":
            raise unittest.SkipTest("NYI CPU")

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([10, 10]))

            def forward(self, x):
                return (self.param + 10) + x

        mod = Mod().eval().to(self.device)
        inp = torch.rand([10], device=self.device)

        with torch.no_grad():
            eager = mod(inp)

        weight_ref = weakref.ref(mod.param)

        @torch.compile()
        def foo(mod, inp):
            return mod(inp)

        with torch.no_grad():
            compiled = foo(mod, inp)

        self.assertEqual(eager, compiled)
        self.assertTrue(weight_ref() is None)


if HAS_CPU and not torch.backends.mps.is_available():

    class CpuTests(TestCase):
        common = check_model
        device = "cpu"
        autocast = torch.cpu.amp.autocast

    copy_tests(OptimizeForInferenceTemplate, CpuTests, "cpu")

if HAS_CUDA and not TEST_WITH_ASAN:

    class CudaTests(TestCase):
        common = check_model_cuda
        device = "cuda"
        autocast = torch.cuda.amp.autocast

    copy_tests(OptimizeForInferenceTemplate, CudaTests, "cuda")


del OptimizeForInferenceTemplate

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if (HAS_CPU or HAS_CUDA) and not TEST_WITH_ROCM:
        run_tests(needs="filelock")
13 changes: 12 additions & 1 deletion torch/_functorch/aot_autograd.py
@@ -1534,6 +1534,9 @@ def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *
        fake_mode = detect_fake_mode()
        seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
        flat_args.extend([seed, offset])

        if torch._guards.TracingContext.get():
            torch._guards.TracingContext.get().fw_metadata = fw_metadata
        compiled_fw = compiler(fw_module, flat_args)

        # This boxed_call handling happens inside create_runtime_wrapper as well.
@@ -2776,13 +2779,18 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
        # 1) There is a check in the debug compiler at the end
        # 2) It does not matter as these are fake tensors

        if torch._guards.TracingContext.get():
            torch._guards.TracingContext.get().fw_metadata = fw_metadata

        # The compiler needs to use this field to find the original model outputs
        # from the AOTAutograd fwd module's outputs, so that it can make sure
        # optimizations like layout optimization do not change those tensors' layout.
        # TODO: once https://github.com/pytorch/pytorch/pull/100652/files#r1212002707 is in,
        # change to access fw_metadata from the global tracing context.
        fw_module.meta["original_output_start_index"] = fw_metadata.num_mutated_inputs

        compiled_fw_func = aot_config.fw_compiler(
            fw_module, adjusted_flat_args
        )
@@ -3663,7 +3671,7 @@ def aot_module_simplified(
        **dict(mod.named_buffers(remove_duplicate=False)),
    }
    params_flat, params_spec = pytree.tree_flatten(params)
    params_flat = tuple(params_flat)
    params_flat = list(params_flat)
    params_len = len(params_flat)

    functional_call = create_functional_call(mod, params_spec, params_len)
@@ -3679,6 +3687,9 @@
    # First, the params
    full_args.extend(params_flat)

    if torch._guards.TracingContext.get():
        torch._guards.TracingContext.get().params_flat = params_flat

    aot_autograd_arg_pos_to_source = None
    # Then, the params 1:1 mapped sources, if relevant.
    if hasattr(mod, "_param_name_to_source"):
3 changes: 3 additions & 0 deletions torch/_guards.py
@@ -472,6 +472,9 @@ def __init__(self, fake_mode):
        self.fake_mode = fake_mode
        self.frame_summary_stack = []
        self.loc_in_frame = None
        # this is only set after aot_autograd
        self.fw_metadata = None
        self.params_flat = None

    @staticmethod
    def extract_stack():
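
The two fields added above are what the freezing pass reads back at compile time; a hedged sketch of a consumer (attribute names come from the diff, the rest is illustrative):

```python
import torch._guards

ctx = torch._guards.TracingContext.get()
if ctx is not None:
    fw_metadata = ctx.fw_metadata  # forward metadata from AOT Autograd (e.g. which inputs are mutated)
    params_flat = ctx.params_flat  # flat list of the module's real params/buffers
```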