-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Readme describing TKW approach - initial op set: read, write, mma, tiled_loop, register - wave decorator for block-based programming model - types for yet to be specified Memory and for Registers - unit tests for new types - initial gemm test that traces successfully - unit tests for tracing the initial set of ops - add lit and filecheck dependencies for unit tests - interface on top of fx.Node to rely less on string matching when navigating the fx.Graph. - small changes to tracing contexts to use op handlers from node interfaces Signed-off-by: Martin Lücke <[email protected]>
- Loading branch information
1 parent
c4db712
commit 0c2dc53
Showing
18 changed files
with
1,155 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# RUN: python %s | ||
|
||
import pytest | ||
from typing import Callable | ||
import shark_turbine.kernel as tk | ||
import shark_turbine.kernel.lang as tkl | ||
import shark_turbine.kernel.wave as tkw | ||
import torch | ||
|
||
|
||
M = tkl.sym.M | ||
N = tkl.sym.N | ||
K = tkl.sym.K | ||
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE | ||
|
||
|
||
def launch(func: Callable[[], None]) -> Callable[[], None]: | ||
""" | ||
Run a function as part of the test suite in a test launch context. | ||
Provides default values for the hyperparameters. | ||
""" | ||
if __name__ == "__main__": | ||
with tk.gen.TestLaunchContext( | ||
{ | ||
M: 16, | ||
N: 16, | ||
K: 16, | ||
ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value, | ||
} | ||
): | ||
func() | ||
return func | ||
|
||
|
||
@launch | ||
def test_read(): | ||
@tkw.wave() | ||
def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]): | ||
tkw.read(a) | ||
|
||
a = torch.randn(16, 16, dtype=torch.float16) | ||
with pytest.raises( | ||
NotImplementedError, match="Read: Currently only stub implementation" | ||
): | ||
test(a) | ||
|
||
|
||
# TODO: Add more tests once we have more than a stub implementation. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,235 @@ | ||
# RUN: python %s | FileCheck %s | ||
|
||
from typing import Callable | ||
from shark_turbine.kernel._support.tracing import CapturedTrace | ||
import shark_turbine.kernel.lang as tkl | ||
import shark_turbine.kernel.wave as tkw | ||
from shark_turbine.kernel.ops.wave_ops import get_custom, Read, Write | ||
|
||
M = tkl.sym.M | ||
N = tkl.sym.N | ||
K = tkl.sym.K | ||
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE | ||
|
||
|
||
def run(func: Callable[[], None]) -> Callable[[], None]: | ||
"""Run a function as part of the test suite.""" | ||
if __name__ == "__main__": | ||
func() | ||
|
||
return func | ||
|
||
|
||
def print_trace(trace: CapturedTrace): | ||
""" | ||
Prints all subgraphs of a trace starting with the root graph. | ||
The graphs are printed first in the torch printing format and then using | ||
our custom node format. | ||
""" | ||
# The root graph is at the back so we print the subgraphs in reverse order | ||
for subgraph in reversed(list(trace.region_graph.subgraphs.values())): | ||
print(subgraph) | ||
for node in subgraph.nodes: | ||
print(get_custom(node)) | ||
|
||
|
||
@run | ||
def test_trace_empty(): | ||
@tkw.wave_trace_only() | ||
def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]): | ||
pass | ||
|
||
trace = test() | ||
print_trace(trace) | ||
# CHECK: %a | ||
# CHECK: return None | ||
|
||
# Custom format: | ||
# CHECK: placeholder | ||
# CHECK-SAME: MemoryType[M, N].of(f16) | ||
# CHECK: unknown: output | ||
|
||
|
||
@run | ||
def test_trace_empty_then_add_nodes(): | ||
""" | ||
This tests the modification of a graph after the trace has been created. | ||
""" | ||
|
||
@tkw.wave_trace_only() | ||
def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]): | ||
pass | ||
|
||
trace = test() | ||
|
||
graph = trace.get_root_graph() | ||
a = list(graph.nodes)[0] | ||
# Insert at the end of the graph | ||
with graph.inserting_before(list(graph.nodes)[-1]): | ||
read = Read(a).add_to_graph(graph) | ||
write = Write(read, a, 4).add_to_graph(graph) | ||
|
||
print_trace(trace) | ||
# CHECK: %a | ||
# CHECK: %read | ||
# CHECK: %write | ||
# CHECK: return None | ||
|
||
# Custom format: | ||
# CHECK: placeholder | ||
# CHECK-SAME: MemoryType[M, N].of(f16) | ||
# CHECK: read(memory=a | ||
# CHECK: write(register_=read, memory=a | ||
# CHECK: unknown: output | ||
|
||
|
||
@run | ||
def test_trace_read(): | ||
@tkw.wave_trace_only() | ||
def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]): | ||
tkw.read(a) | ||
|
||
trace = test() | ||
print_trace(trace) | ||
# CHECK: %a | ||
# CHECK: %read | ||
# CHECK: return None | ||
|
||
# Custom format: | ||
# CHECK: placeholder | ||
# CHECK: read(memory=a | ||
# CHECK: unknown: output | ||
|
||
|
||
@run | ||
def test_trace_register(): | ||
@tkw.wave_trace_only() | ||
def test( | ||
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], | ||
): | ||
tkw.register([M, N], tkl.f16, 0.0) | ||
|
||
trace = test() | ||
print_trace(trace) | ||
# CHECK: %a | ||
# CHECK: %register | ||
# CHECK: return None | ||
|
||
# Custom format: | ||
# CHECK: placeholder | ||
# CHECK: register | ||
# CHECK-SAME: shape=[M, N], dtype=f16 | ||
# CHECK: unknown: output | ||
|
||
|
||
@run | ||
def test_trace_write(): | ||
@tkw.wave_trace_only() | ||
def test( | ||
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], | ||
): | ||
val = tkw.register([M, N], tkl.f16, 0.0) | ||
tkw.write(val, a, elements_per_thread=4) | ||
|
||
trace = test() | ||
print_trace(trace) | ||
# CHECK: %a | ||
# CHECK: %register | ||
# CHECK: %write | ||
# CHECK: return None | ||
|
||
# Custom format: | ||
# CHECK: placeholder | ||
# CHECK: register | ||
# CHECK: write | ||
# CHECK: unknown: output | ||
|
||
|
||
@run | ||
def test_trace_mma(): | ||
@tkw.wave_trace_only() | ||
def test( | ||
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], | ||
): | ||
reg_0 = tkw.read(a) | ||
reg_1 = tkw.read(a) | ||
acc = tkl.Register[M, N, tkl.f32](0.0) | ||
mma = tkw.mma(reg_0, reg_1, acc) | ||
|
||
trace = test() | ||
print_trace(trace) | ||
# CHECK: %a | ||
# CHECK: %read | ||
# CHECK: %read | ||
# CHECK: %register | ||
# CHECK: %mma | ||
# CHECK: return None | ||
|
||
# Custom format: | ||
# CHECK: placeholder | ||
# CHECK: read | ||
# CHECK: read | ||
# CHECK: register | ||
# CHECK: mma | ||
# CHECK: unknown: output | ||
|
||
|
||
@run | ||
def test_trace_gemm(): | ||
@tkw.wave_trace_only() | ||
def gemm( | ||
A: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], | ||
B: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], | ||
C: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], | ||
): | ||
c = tkl.Register[M, N, tkl.f32](0.0) | ||
|
||
@tkw.reduction(K, init_args=[c]) | ||
def repeat(acc) -> tkl.Register[M, N, tkl.f32]: | ||
a = tkw.read(A) | ||
b = tkw.read(B) | ||
acc = tkl.Register[M, N, tkl.f32](0.0) | ||
mma = tkw.mma(a, b, acc) | ||
return acc | ||
|
||
tkw.write(repeat, C, elements_per_thread=4) | ||
|
||
trace = gemm() | ||
print_trace(trace) | ||
# Root graph: | ||
# CHECK: %a | ||
# CHECK: %b | ||
# CHECK: %c | ||
# CHECK: %register | ||
# CHECK: %reduction | ||
# CHECK: %write | ||
# CHECK: return None | ||
|
||
# Root graph in custom format: | ||
# CHECK: placeholder | ||
# CHECK: placeholder | ||
# CHECK: placeholder | ||
# CHECK: register | ||
# CHECK: reduction | ||
# CHECK: write | ||
# CHECK: unknown: output | ||
|
||
# Subgraph: | ||
# CHECK: %acc | ||
# CHECK: %a | ||
# CHECK: %read | ||
# CHECK: %b | ||
# CHECK: %read_1 | ||
# CHECK: %register | ||
# CHECK: %mma | ||
# CHECK: return register | ||
|
||
# Subgraph in custom format: | ||
# CHECK: placeholder | ||
# CHECK: placeholder | ||
# CHECK: read | ||
# CHECK: placeholder | ||
# CHECK: read | ||
# CHECK: register | ||
# CHECK: mma | ||
# CHECK: unknown: output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import os | ||
import shutil | ||
import sys | ||
|
||
import lit.formats | ||
import lit.util | ||
|
||
import lit.llvm | ||
|
||
from shark_turbine.support.logging import get_logger | ||
|
||
logger = get_logger("turbine.lit_tests") | ||
|
||
# Configuration file for the 'lit' test runner. | ||
lit.llvm.initialize(lit_config, config) | ||
from lit.llvm import llvm_config | ||
|
||
llvm_config.with_system_environment("PYTHONPATH") | ||
|
||
# name: The name of this test suite. | ||
config.name = "IREE_TURBINE" | ||
|
||
config.test_format = lit.formats.ShTest() | ||
|
||
# suffixes: A list of file extensions to treat as test files. | ||
config.suffixes = [".py"] | ||
|
||
# test_source_root: The root path where tests are located. | ||
config.test_source_root = os.path.dirname(__file__) | ||
|
||
# config.use_default_substitutions() | ||
config.excludes = ["__init__.py", "lit.cfg.py", "lit.site.cfg.py"] | ||
|
||
config.substitutions.extend( | ||
[ | ||
("%PYTHON", sys.executable), | ||
] | ||
) | ||
|
||
# Find a suitable filecheck. | ||
filecheck_exe = None | ||
if filecheck_exe is None: | ||
filecheck_exe = shutil.which("FileCheck") | ||
if filecheck_exe: | ||
logger.debug(f"Using LLVM FileCheck: {filecheck_exe}") | ||
if filecheck_exe is None: | ||
filecheck_exe = shutil.which("filecheck") | ||
if filecheck_exe: | ||
logger.debug(f"Using pure python filecheck: {filecheck_exe}") | ||
|
||
if filecheck_exe is not None: | ||
config.substitutions.extend( | ||
[ | ||
("FileCheck", filecheck_exe), | ||
] | ||
) | ||
else: | ||
logger.error( | ||
"FileCheck not found " | ||
"(install pure python version with 'pip install filecheck')" | ||
) | ||
|
||
project_root = os.path.dirname(os.path.dirname(__file__)) | ||
lit.llvm.llvm_config.with_environment("PYTHONPATH", project_root, append_path=True) | ||
config.environment["FILECHECK_OPTS"] = "--dump-input=fail" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.