-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Readme describing TKW approach - initial op set: read, write, mma, tiled_loop, register - wave decorator for block-based programming model - types for yet to be specified Memory and for Registers - initial gemm test that traces successfully - unit tests for tracing the initial set of ops - add lit and filecheck dependencies for unit tests - interface on top of fx.Node to rely less on string matching when navigating the fx.Graph. - small changes to tracing contexts to use op handlers from node interfaces Signed-off-by: Martin Lücke <[email protected]>
- Loading branch information
1 parent
c4db712
commit 40f3a08
Showing
15 changed files
with
1,013 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# RUN: python %s | ||
|
||
import pytest | ||
from typing import Callable | ||
import shark_turbine.kernel as tk | ||
import shark_turbine.kernel.lang as tkl | ||
import shark_turbine.kernel.wave as tkw | ||
import torch | ||
|
||
|
||
M = tkl.sym.M | ||
N = tkl.sym.N | ||
K = tkl.sym.K | ||
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE | ||
|
||
|
||
def launch(func: Callable[[], None]) -> Callable[[], None]: | ||
""" | ||
Run a function as part of the test suite in a test launch context. | ||
Provides default values for the hyperparameters. | ||
""" | ||
if __name__ == "__main__": | ||
with tk.gen.TestLaunchContext( | ||
{ | ||
M: 16, | ||
N: 16, | ||
K: 16, | ||
ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value, | ||
} | ||
): | ||
func() | ||
return func | ||
|
||
|
||
@launch | ||
def test_read(): | ||
@tkw.wave() | ||
def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]): | ||
tkw.read(a) | ||
|
||
a = torch.randn(16, 16, dtype=torch.float16) | ||
with pytest.raises( | ||
NotImplementedError, match="Read: Currently only stub implementation" | ||
): | ||
test(a) | ||
|
||
|
||
# TODO: Add more tests once we have more than a stub implementation. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
# RUN: python %s | FileCheck %s | ||
|
||
from typing import Callable | ||
import shark_turbine.kernel.lang as tkl | ||
import shark_turbine.kernel.wave as tkw | ||
from shark_turbine.kernel.ops.wave_ops import get_custom | ||
|
||
|
||
M = tkl.sym.M | ||
N = tkl.sym.N | ||
K = tkl.sym.K | ||
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE | ||
|
||
|
||
def run(func: Callable[[], None]) -> Callable[[], None]: | ||
"""Run a function as part of the test suite.""" | ||
if __name__ == "__main__": | ||
func() | ||
|
||
return func | ||
|
||
|
||
@run | ||
def test_trace_empty(): | ||
@tkw.wave_trace_only() | ||
def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]): | ||
pass | ||
|
||
trace = test() | ||
print(trace.get_root_graph()) | ||
# CHECK: %a | ||
# CHECK: return None | ||
for node in trace.get_root_graph().nodes: | ||
print(get_custom(node)) | ||
# CHECK: placeholder | ||
# CHECK-SAME: MemoryType[M, N].of(f16) | ||
# CHECK: unknown: output | ||
|
||
|
||
@run | ||
def test_trace_read(): | ||
@tkw.wave_trace_only() | ||
def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]): | ||
tkw.read(a) | ||
|
||
trace = test() | ||
print(trace.get_root_graph()) | ||
# CHECK: %a | ||
# CHECK: %read | ||
# CHECK: return None | ||
for node in trace.get_root_graph().nodes: | ||
print(get_custom(node)) | ||
# CHECK: placeholder | ||
# CHECK: read(memory=a | ||
# CHECK: unknown: output | ||
|
||
|
||
@run | ||
def test_trace_register(): | ||
@tkw.wave_trace_only() | ||
def test( | ||
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], | ||
): | ||
tkw.register([M, N], tkl.f16, 0.0) | ||
|
||
trace = test() | ||
print(trace.get_root_graph()) | ||
# CHECK: %a | ||
# CHECK: %register | ||
# CHECK: return None | ||
for node in trace.get_root_graph().nodes: | ||
print(get_custom(node)) | ||
# CHECK: placeholder | ||
# CHECK: new_register | ||
# CHECK-SAME: shape=[M, N], dtype=f16 | ||
# CHECK: unknown: output | ||
|
||
|
||
@run | ||
def test_trace_write(): | ||
@tkw.wave_trace_only() | ||
def test( | ||
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], | ||
): | ||
val = tkw.register([M, N], tkl.f16, 0.0) | ||
tkw.write(val, a, elements_per_thread=4) | ||
|
||
trace = test() | ||
print(trace.get_root_graph()) | ||
# CHECK: %a | ||
# CHECK: %register | ||
# CHECK: %write | ||
# CHECK: return None | ||
for node in trace.get_root_graph().nodes: | ||
print(get_custom(node)) | ||
# CHECK: placeholder | ||
# CHECK: new_register | ||
# CHECK: write | ||
# CHECK: unknown: output | ||
|
||
|
||
@run | ||
def test_trace_mma(): | ||
@tkw.wave_trace_only() | ||
def test( | ||
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], | ||
): | ||
reg_0 = tkw.read(a) | ||
reg_1 = tkw.read(a) | ||
acc = tkl.Register[M, N, tkl.f32](0.0) | ||
mma = tkw.mma(reg_0, reg_1, acc) | ||
|
||
trace = test() | ||
print(trace.get_root_graph()) | ||
# CHECK: %a | ||
# CHECK: %read | ||
# CHECK: %read | ||
# CHECK: %register | ||
# CHECK: %mma | ||
# CHECK: return None | ||
for node in trace.get_root_graph().nodes: | ||
print(get_custom(node)) | ||
# CHECK: placeholder | ||
# CHECK: read | ||
# CHECK: read | ||
# CHECK: new_register | ||
# CHECK: mma | ||
# CHECK: unknown: output | ||
|
||
|
||
@run | ||
def test_trace_gemm(): | ||
@tkw.wave_trace_only() | ||
def gemm( | ||
A: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], | ||
B: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], | ||
C: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], | ||
): | ||
c = tkl.Register[M, N, tkl.f32](0.0) | ||
|
||
@tkw.reduction(K, init_args=[c]) | ||
def repeat(acc) -> tkl.Register[M, N, tkl.f32]: | ||
a = tkw.read(A) | ||
b = tkw.read(B) | ||
acc = tkl.Register[M, N, tkl.f32](0.0) | ||
mma = tkw.mma(a, b, acc) | ||
return acc | ||
|
||
tkw.write(repeat, C, elements_per_thread=4) | ||
|
||
trace = gemm() | ||
print(trace.get_root_graph()) | ||
# CHECK: %a | ||
# CHECK: %b | ||
# CHECK: %c | ||
# CHECK: %register | ||
# CHECK: %reduction | ||
# CHECK: %write | ||
# CHECK: return None | ||
for node in trace.get_root_graph().nodes: | ||
print(get_custom(node)) | ||
# CHECK: placeholder | ||
# CHECK: placeholder | ||
# CHECK: placeholder | ||
# CHECK: new_register | ||
# CHECK: reduction | ||
# CHECK: write | ||
# CHECK: unknown: output | ||
|
||
# TODO: Handle proper automatic printing of the subgraph in reduction | ||
print(trace.region_graph.subgraphs["region_0"]) | ||
# CHECK: %acc | ||
# CHECK: %a | ||
# CHECK: %read | ||
# CHECK: %b | ||
# CHECK: %read_1 | ||
# CHECK: %register | ||
# CHECK: %mma | ||
# CHECK: return register | ||
for node in trace.region_graph.subgraphs["region_0"].nodes: | ||
print(get_custom(node)) | ||
# CHECK: placeholder | ||
# CHECK: placeholder | ||
# CHECK: read | ||
# CHECK: placeholder | ||
# CHECK: read | ||
# CHECK: new_register | ||
# CHECK: mma | ||
# CHECK: unknown: output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import os | ||
import shutil | ||
import sys | ||
|
||
import lit.formats | ||
import lit.util | ||
|
||
import lit.llvm | ||
|
||
# Configuration file for the 'lit' test runner. | ||
lit.llvm.initialize(lit_config, config) | ||
from lit.llvm import llvm_config | ||
|
||
llvm_config.with_system_environment("PYTHONPATH") | ||
|
||
# name: The name of this test suite. | ||
config.name = "IREE_TURBINE" | ||
|
||
config.test_format = lit.formats.ShTest() | ||
|
||
# suffixes: A list of file extensions to treat as test files. | ||
config.suffixes = [".py"] | ||
|
||
# test_source_root: The root path where tests are located. | ||
config.test_source_root = os.path.dirname(__file__) | ||
|
||
# config.use_default_substitutions() | ||
config.excludes = ["__init__.py", "lit.cfg.py", "lit.site.cfg.py"] | ||
|
||
config.substitutions.extend( | ||
[ | ||
("%PYTHON", sys.executable), | ||
] | ||
) | ||
|
||
# Find a suitable filecheck. | ||
filecheck_exe = None | ||
if filecheck_exe is None: | ||
filecheck_exe = shutil.which("FileCheck") | ||
if filecheck_exe: | ||
print(f"Using LLVM FileCheck: {filecheck_exe}") | ||
if filecheck_exe is None: | ||
filecheck_exe = shutil.which("filecheck") | ||
if filecheck_exe: | ||
print(f"Using pure python filecheck: {filecheck_exe}") | ||
|
||
if filecheck_exe is not None: | ||
config.substitutions.extend( | ||
[ | ||
("FileCheck", filecheck_exe), | ||
] | ||
) | ||
else: | ||
print( | ||
"FileCheck not found " | ||
"(install pure python version with 'pip install filecheck')" | ||
) | ||
|
||
project_root = os.path.dirname(os.path.dirname(__file__)) | ||
lit.llvm.llvm_config.with_environment("PYTHONPATH", project_root, append_path=True) | ||
config.environment["FILECHECK_OPTS"] = "--dump-input=fail" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.