Enable import_symbolic_shape_expressions in the FxImporter. (iree-org…#179)

* Adds an option, `aot.export(import_symbolic_shape_expressions=True)`,
to enable emission of torch-mlir symbolic shape constraints. This
currently defaults to False until IREE is ready to ingest these
constraints by default (a usage sketch follows the list below).

Rough sequence of work in IREE proper:

* Custom lowering of `torch.symbolic_int` and
`torch.bind_symbolic_shape` ops to IREE util "assume" ops. Note that we
are only planning to lower "terminal" bindings (basically function
arguments and a couple of other such categories).
* Canonicalizations to ensure that assume equalities are == 0 (versus
the native form from torch, where they assume a non-zero equality).
* Fusion will clone corresponding bindings on dependent dims into
dispatch regions.
* Extend the existing linalg shape analysis and make it queryable by codegen.
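
For reference, a minimal usage sketch of the new flag. It mirrors the test added in this change; the toy model and shapes are illustrative only:

```python
import torch
from shark_turbine.aot import export

class M(torch.nn.Module):
    def forward(self, x):
        return x * 2.0

# A dynamic batch dimension causes symbolic shape bindings to be emitted.
batch = torch.export.Dim("batch")
output = export(
    M(),
    args=(torch.randn(4, 8),),
    dynamic_shapes={"x": {0: batch}},
    import_symbolic_shape_expressions=True,
)
asm = str(output.mlir_module)
assert "bind_symbolic_shape" in asm
```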

---------

Signed-off-by: Stella Laurenzo <[email protected]>
stellaraccident authored Sep 30, 2024
1 parent 92ad900 commit 621cbe1
Showing 5 changed files with 102 additions and 8 deletions.
26 changes: 21 additions & 5 deletions shark_turbine/aot/compiled_module.py
@@ -41,6 +41,7 @@

from .support.ir_utils import (
ModuleBuilder,
ModuleBuilderOptions,
)


@@ -162,11 +163,13 @@ class CompiledModuleClassInfo:
__slots__ = [
"all_exports",
"ir_module_name",
"options",
]

def __init__(self, *, ir_module_name: str):
def __init__(self, *, ir_module_name: str, options: ModuleBuilderOptions):
self.ir_module_name = ir_module_name
self.all_exports: Dict[str, Exportable] = dict()
self.options = options

def add_export(self, key: str, value: Exportable):
if key in self.all_exports:
@@ -370,13 +373,23 @@ class CompiledModuleMeta(type):
# It is passed the dictionary of declared attributes and any keyword
# arguments from the class declaration:
# class Foo(Bar, kwarg="you probably just learned this is possible"):
def __new__(mcls, name: str, bases, dct, *, export_name: Optional[str] = None):
def __new__(
mcls,
name: str,
bases,
dct,
*,
export_name: Optional[str] = None,
options: Optional[ModuleBuilderOptions] = None,
):
if not _metaclass_setup_complete:
return type.__new__(mcls, name, bases, dct)

ir_module_name = _derive_ir_module_name(name, export_name)
logger.debug("Create new CompiledModule: %s", ir_module_name)
info = CompiledModuleClassInfo(ir_module_name=ir_module_name)
info = CompiledModuleClassInfo(
ir_module_name=ir_module_name, options=options or ModuleBuilderOptions()
)

# Process that attributes that were set as part of class definition.
# Any attributes that we decide are part of the compiled module
@@ -436,6 +449,7 @@ def create_from_dict(
dct: dict,
*,
export_name: Optional[str] = None,
options: Optional[ModuleBuilderOptions] = None,
) -> CompiledModuleMeta:
"""Creates a CompiledModule subclass with an explicit dictionary of members.
@@ -446,7 +460,9 @@ class Foo(CompiledModule, export_name="bar"):
def member(): ...
```
"""
return CompiledModuleMeta(name, (cls,), dct, export_name=export_name)
return CompiledModuleMeta(
name, (cls,), dct, export_name=export_name, options=options
)

@staticmethod
def get_class_info(cls: CompiledModuleMeta) -> CompiledModuleClassInfo:
@@ -596,7 +612,7 @@ def __new__(
module_op.attributes["sym_name"] = StringAttr.get(
class_info.ir_module_name, context=context
)
module_builder = ModuleBuilder(module_op)
module_builder = ModuleBuilder(module_op, options=class_info.options)
info = CompiledModuleInstanceInfo(class_info, module_builder=module_builder)
_all_compiled_module_instance_infos[self] = info

16 changes: 15 additions & 1 deletion shark_turbine/aot/exporter.py
@@ -26,6 +26,7 @@
from .builtins import *
from .compiled_module import (
CompiledModule,
ModuleBuilderOptions,
ImportPhase,
)
from .fx_programs import FxPrograms
@@ -175,6 +176,7 @@ def export(
module_name: Optional[str] = None,
function_name: Optional[str] = None,
strict_export: bool = True,
import_symbolic_shape_expressions: bool = False,
) -> ExportOutput:
"""Exports a torch.nn.Module.
@@ -223,6 +225,7 @@
module_name: Optional[str] = None,
function_name: Optional[str] = None,
strict_export: bool = True,
import_symbolic_shape_expressions: bool = False,
) -> ExportOutput:
"""Generic export of supported entities.
@@ -270,11 +273,19 @@ def export(
"LambdaCompiledModule",
{(function_name or "main"): mdl},
export_name=module_name or "module",
options=ModuleBuilderOptions(
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
),
)

elif isinstance(mdl, FxPrograms):
TransformedModule = CompiledModule.create_from_dict(
"LambdaCompiledModule", mdl.programs, export_name=module_name or "module"
"LambdaCompiledModule",
mdl.programs,
export_name=module_name or "module",
options=ModuleBuilderOptions(
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
),
)
elif isinstance(mdl, torch.nn.Module):
# Normalize arguments for torch.export.
@@ -302,6 +313,9 @@ def export(
"LambdaCompiledModule",
{(function_name or "main"): exported_program},
export_name=module_name or "module",
options=ModuleBuilderOptions(
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
),
)
elif issubclass(mdl, CompiledModule):
TransformedModule = mdl
13 changes: 12 additions & 1 deletion shark_turbine/aot/support/ir_utils.py
@@ -7,6 +7,7 @@

from typing import Callable, Dict, Optional, Sequence, Tuple

from dataclasses import dataclass
from pathlib import Path
import tempfile

@@ -148,6 +149,12 @@ def infer_external_from_tensor(
###############################################################################


@dataclass
class ModuleBuilderOptions:
# Whether to import torch symbolic shape expressions for ExportedPrograms.
import_symbolic_shape_expressions: bool = False


class ModuleBuilder:
"""Wrapper around module and IR accounting for a module being built."""

@@ -159,14 +166,18 @@ class ModuleBuilder:
"last_global_op",
"ip",
"module_op",
"options",
"symbol_table",
"global_ref_tracker",
"native_type_converter",
"_auto_symbol_counts",
]

def __init__(self, module_op: Operation):
def __init__(
self, module_op: Operation, *, options: Optional[ModuleBuilderOptions] = None
):
self.module_op = module_op
self.options = options or ModuleBuilderOptions()
self.context = module_op.context
self.body = module_op.regions[0].blocks[0]
self.symbol_table = SymbolTable(module_op)
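
For lower-level embedding outside of `aot.export`, the same options dataclass can be passed to `ModuleBuilder` directly. A sketch, assuming `module_op` is a pre-built MLIR `builtin.module` Operation obtained elsewhere:

```python
from shark_turbine.aot.support.ir_utils import ModuleBuilder, ModuleBuilderOptions

# `module_op` is assumed to be an existing MLIR module Operation.
builder = ModuleBuilder(
    module_op,
    options=ModuleBuilderOptions(import_symbolic_shape_expressions=True),
)
```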
5 changes: 4 additions & 1 deletion shark_turbine/aot/support/procedural/exported_program.py
@@ -181,7 +181,10 @@ def import_exported_program(
) -> ExportedProgramIntrinsic:
fx_importer = _create_fx_importer(module_builder)
entry_func_op = fx_importer.import_program(
exported_program, func_name=symbol_name, func_visibility=symbol_visibility
exported_program,
func_name=symbol_name,
func_visibility=symbol_visibility,
import_symbolic_shape_expressions=module_builder.options.import_symbolic_shape_expressions,
)

module_call_graph = exported_program.module_call_graph
50 changes: 50 additions & 0 deletions tests/aot/dynamic_shape_export_test.py
@@ -0,0 +1,50 @@
import torch

import pytest

from shark_turbine.aot import *


@pytest.mark.parametrize(
"import_symbolic_shape_expressions",
[
True,
False,
],
)
def test_exported_program_dynamic_shapes(import_symbolic_shape_expressions):
class M(torch.nn.Module):
def __init__(self):
super().__init__()

self.branch1 = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU())
self.branch2 = torch.nn.Sequential(
torch.nn.Linear(128, 64), torch.nn.ReLU()
)
self.buffer = torch.ones(32)

def forward(self, x1, x2):
out1 = self.branch1(x1)
out2 = self.branch2(x2)
return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

# Create a dynamic batch size
batch = torch.export.Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

output = export(
M(),
args=example_args,
dynamic_shapes=dynamic_shapes,
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
)
output.print_readable()
asm = str(output.mlir_module)

if import_symbolic_shape_expressions:
assert "bind_symbolic_shape" in asm
else:
assert "bind_symbolic_shape" not in asm
