forked from iree-org/iree-turbine
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable import_symbolic_shape_expressions in the FxImporter. (iree-org#179)
* Adds an option to `aot.export(import_symbolic_shape_expressions=True)` to enable emission of torch-mlir symbolic shape constraints. This is currently set to False until IREE is ready to ingest these by default. Rough sequence of work in IREE proper: * Custom lowering of `torch.symbolic_int` and `torch.bind_symbolic_shape` ops to IREE util "assume" ops. Note that we are only planning to lower "terminal" bindings (basically function arguments and a couple of other such categories). * Canonicalizations to ensure that assume equalities are == 0 (versus the native form from torch where they assume a non-zero equality). * Fusion will clone corresponding bindings on dependent dims into dispatch regions. * Existing linalg shape analysis extended and queryable by codegen. --------- Signed-off-by: Stella Laurenzo <[email protected]>
- Loading branch information
1 parent
92ad900
commit 621cbe1
Showing
5 changed files
with
102 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import torch

import pytest

from shark_turbine.aot import *


@pytest.mark.parametrize(
    "import_symbolic_shape_expressions",
    [
        True,
        False,
    ],
)
def test_exported_program_dynamic_shapes(import_symbolic_shape_expressions):
    """Export a two-branch module with a dynamic batch dim and check that
    symbolic shape bindings appear in the MLIR iff the flag is set.

    When `import_symbolic_shape_expressions=True`, the exporter should emit
    `torch.bind_symbolic_shape` ops for the dynamic batch dimension; when
    False, no such ops should be present in the module text.
    """

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

            # Two independent branches with different input feature sizes,
            # so each forward input gets its own shape constraints.
            self.branch1 = torch.nn.Sequential(
                torch.nn.Linear(64, 32), torch.nn.ReLU()
            )
            self.branch2 = torch.nn.Sequential(
                torch.nn.Linear(128, 64), torch.nn.ReLU()
            )
            self.buffer = torch.ones(32)

        def forward(self, x1, x2):
            out1 = self.branch1(x1)
            out2 = self.branch2(x2)
            return (out1 + self.buffer, out2)

    example_args = (torch.randn(32, 64), torch.randn(32, 128))

    # Create a dynamic batch size
    batch = torch.export.Dim("batch")
    # Specify that the first dimension of each input is that batch size
    dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

    output = export(
        M(),
        args=example_args,
        dynamic_shapes=dynamic_shapes,
        import_symbolic_shape_expressions=import_symbolic_shape_expressions,
    )
    output.print_readable()
    asm = str(output.mlir_module)

    # The emitted MLIR should contain symbolic shape bindings exactly when
    # the feature flag was enabled.
    if import_symbolic_shape_expressions:
        assert "bind_symbolic_shape" in asm
    else:
        assert "bind_symbolic_shape" not in asm