Skip to content

Commit

Permalink
TKW initial commit
Browse files Browse the repository at this point in the history
- Readme describing TKW approach
- initial op set: read, write, mma, tiled_loop, register
- wave decorator for block-based programming model
- types for the yet-to-be-specified Memory and for Registers
- initial gemm test that traces successfully
- unit tests for tracing the initial set of ops
- add lit and filecheck dependencies for unit tests
- interface on top of fx.Node to rely less on string matching when navigating the fx.Graph.
- small changes to tracing contexts to use op handlers from node interfaces

Signed-off-by: Martin Lücke <[email protected]>
  • Loading branch information
martin-luecke committed Jun 14, 2024
1 parent c4db712 commit 40f3a08
Show file tree
Hide file tree
Showing 15 changed files with 1,013 additions and 3 deletions.
48 changes: 48 additions & 0 deletions lit_tests/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# RUN: python %s

import pytest
from typing import Callable
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
import torch


M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE


def launch(func: Callable[[], None]) -> Callable[[], None]:
    """
    Run a function as part of the test suite in a test launch context.

    When this file is executed as a script, ``func`` is called immediately
    inside a ``tk.gen.TestLaunchContext`` that provides default values for
    the hyperparameters (``M``, ``N``, ``K`` and ``ADDRESS_SPACE``). When the
    file is only imported, ``func`` is not called.

    Returns ``func`` unchanged, so this can be applied as a decorator.
    """
    if __name__ == "__main__":
        with tk.gen.TestLaunchContext(
            {
                M: 16,
                N: 16,
                K: 16,
                ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
            }
        ):
            func()
    return func


@launch
def test_read():
    """Expect the stub ``NotImplementedError`` when a kernel that reads runs."""

    @tkw.wave()
    def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
        tkw.read(a)

    a = torch.randn(16, 16, dtype=torch.float16)
    # The read op only has a stub implementation so far; calling the kernel
    # is expected to raise with exactly this message.
    with pytest.raises(
        NotImplementedError, match="Read: Currently only stub implementation"
    ):
        test(a)


# TODO: Add more tests once we have more than a stub implementation.
189 changes: 189 additions & 0 deletions lit_tests/kernel/wave/tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# RUN: python %s | FileCheck %s

from typing import Callable
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
from shark_turbine.kernel.ops.wave_ops import get_custom


M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE


def run(func: Callable[[], None]) -> Callable[[], None]:
    """Run a function as part of the test suite.

    ``func`` is called immediately only when this file is executed as a
    script; on import it is left uncalled. Returns ``func`` unchanged, so
    this can be applied as a decorator.
    """
    if __name__ == "__main__":
        func()

    return func


@run
def test_trace_empty():
    """Trace a kernel with an empty body and print its graph and custom nodes."""

    @tkw.wave_trace_only()
    def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
        pass

    trace = test()
    print(trace.get_root_graph())
    # CHECK: %a
    # CHECK: return None
    for node in trace.get_root_graph().nodes:
        print(get_custom(node))
    # CHECK: placeholder
    # CHECK-SAME: MemoryType[M, N].of(f16)
    # CHECK: unknown: output


@run
def test_trace_read():
    """Trace a kernel with a single read and print its graph and custom nodes."""

    @tkw.wave_trace_only()
    def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
        tkw.read(a)

    trace = test()
    print(trace.get_root_graph())
    # CHECK: %a
    # CHECK: %read
    # CHECK: return None
    for node in trace.get_root_graph().nodes:
        print(get_custom(node))
    # CHECK: placeholder
    # CHECK: read(memory=a
    # CHECK: unknown: output


@run
def test_trace_register():
    """Trace a kernel that creates a register and print its graph and nodes."""

    @tkw.wave_trace_only()
    def test(
        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    ):
        tkw.register([M, N], tkl.f16, 0.0)

    trace = test()
    print(trace.get_root_graph())
    # CHECK: %a
    # CHECK: %register
    # CHECK: return None
    for node in trace.get_root_graph().nodes:
        print(get_custom(node))
    # CHECK: placeholder
    # CHECK: new_register
    # CHECK-SAME: shape=[M, N], dtype=f16
    # CHECK: unknown: output


@run
def test_trace_write():
    """Trace a kernel that writes a register to memory and print the result."""

    @tkw.wave_trace_only()
    def test(
        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    ):
        val = tkw.register([M, N], tkl.f16, 0.0)
        tkw.write(val, a, elements_per_thread=4)

    trace = test()
    print(trace.get_root_graph())
    # CHECK: %a
    # CHECK: %register
    # CHECK: %write
    # CHECK: return None
    for node in trace.get_root_graph().nodes:
        print(get_custom(node))
    # CHECK: placeholder
    # CHECK: new_register
    # CHECK: write
    # CHECK: unknown: output


@run
def test_trace_mma():
    """Trace a kernel with two reads feeding an mma and print the result."""

    @tkw.wave_trace_only()
    def test(
        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    ):
        reg_0 = tkw.read(a)
        reg_1 = tkw.read(a)
        acc = tkl.Register[M, N, tkl.f32](0.0)
        mma = tkw.mma(reg_0, reg_1, acc)

    trace = test()
    print(trace.get_root_graph())
    # CHECK: %a
    # CHECK: %read
    # CHECK: %read
    # CHECK: %register
    # CHECK: %mma
    # CHECK: return None
    for node in trace.get_root_graph().nodes:
        print(get_custom(node))
    # CHECK: placeholder
    # CHECK: read
    # CHECK: read
    # CHECK: new_register
    # CHECK: mma
    # CHECK: unknown: output


@run
def test_trace_gemm():
    """Trace a gemm-shaped kernel and print the root graph and region_0.

    The kernel body contains a reduction over ``K``; both the root graph and
    the reduction's subgraph (``region_0``) are printed, each followed by the
    custom-node view of its nodes.
    """

    @tkw.wave_trace_only()
    def gemm(
        A: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
        B: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
        C: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    ):
        c = tkl.Register[M, N, tkl.f32](0.0)

        @tkw.reduction(K, init_args=[c])
        def repeat(acc) -> tkl.Register[M, N, tkl.f32]:
            a = tkw.read(A)
            b = tkw.read(B)
            # NOTE(review): `acc` is rebound to a fresh register here, so the
            # incoming `acc` argument and the mma result are both unused and
            # the region returns the new register (the expectations below rely
            # on `return register`). Presumably a placeholder until real
            # accumulation is wired up -- confirm before changing.
            acc = tkl.Register[M, N, tkl.f32](0.0)
            mma = tkw.mma(a, b, acc)
            return acc

        tkw.write(repeat, C, elements_per_thread=4)

    trace = gemm()
    print(trace.get_root_graph())
    # CHECK: %a
    # CHECK: %b
    # CHECK: %c
    # CHECK: %register
    # CHECK: %reduction
    # CHECK: %write
    # CHECK: return None
    for node in trace.get_root_graph().nodes:
        print(get_custom(node))
    # CHECK: placeholder
    # CHECK: placeholder
    # CHECK: placeholder
    # CHECK: new_register
    # CHECK: reduction
    # CHECK: write
    # CHECK: unknown: output

    # TODO: Handle proper automatic printing of the subgraph in reduction
    print(trace.region_graph.subgraphs["region_0"])
    # CHECK: %acc
    # CHECK: %a
    # CHECK: %read
    # CHECK: %b
    # CHECK: %read_1
    # CHECK: %register
    # CHECK: %mma
    # CHECK: return register
    for node in trace.region_graph.subgraphs["region_0"].nodes:
        print(get_custom(node))
    # CHECK: placeholder
    # CHECK: placeholder
    # CHECK: read
    # CHECK: placeholder
    # CHECK: read
    # CHECK: new_register
    # CHECK: mma
    # CHECK: unknown: output
61 changes: 61 additions & 0 deletions lit_tests/lit.cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import shutil
import sys

import lit.formats
import lit.util

import lit.llvm

# Configuration file for the 'lit' test runner.
# `lit_config` and `config` are not defined in this file; they are expected
# to be provided by lit when it loads this configuration.
lit.llvm.initialize(lit_config, config)
# Keep this import after initialize(): lit.llvm only populates `llvm_config`
# once initialize() has run.
from lit.llvm import llvm_config

llvm_config.with_system_environment("PYTHONPATH")

# name: The name of this test suite.
config.name = "IREE_TURBINE"

config.test_format = lit.formats.ShTest()

# suffixes: A list of file extensions to treat as test files.
config.suffixes = [".py"]

# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)

# config.use_default_substitutions()
config.excludes = ["__init__.py", "lit.cfg.py", "lit.site.cfg.py"]

config.substitutions.extend(
    [
        ("%PYTHON", sys.executable),
    ]
)

# Find a suitable filecheck: try LLVM's FileCheck first, then fall back to the
# pure python `filecheck` executable. The first one found on PATH wins.
filecheck_exe = None
for _candidate, _description in (
    ("FileCheck", "LLVM FileCheck"),
    ("filecheck", "pure python filecheck"),
):
    filecheck_exe = shutil.which(_candidate)
    if filecheck_exe:
        print(f"Using {_description}: {filecheck_exe}")
        break

if filecheck_exe is not None:
    # Map the tool name used in test command lines to the executable found.
    config.substitutions.extend(
        [
            ("FileCheck", filecheck_exe),
        ]
    )
else:
    print(
        "FileCheck not found "
        "(install pure python version with 'pip install filecheck')"
    )

# Make the project importable from the tests by appending the repository root
# (the parent of this directory) to PYTHONPATH.
project_root = os.path.dirname(os.path.dirname(__file__))
llvm_config.with_environment("PYTHONPATH", project_root, append_path=True)
config.environment["FILECHECK_OPTS"] = "--dump-input=fail"
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Build/test requirements.
Jinja2==3.1.3
filecheck==0.0.24
numpy==1.26.3
parameterized==0.9.0
pytest==8.0.0
pytest-xdist==3.5.0
lit==18.1.7
mypy==1.8.0
setuptools
wheel
Expand Down
7 changes: 7 additions & 0 deletions shark_turbine/kernel/_support/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ..lang.types import (
Index,
)
from ..ops.wave_ops import CustomOp, Placeholder, Reduction, Unknown

from .regions import RegionGraph, SubgraphTracer

Expand Down Expand Up @@ -178,6 +179,12 @@ def __init__(self, region_graph: RegionGraph, *, grid_type: Type[Grid]):
for n in grid_type.symbolic_shape
]

def register_custom_op(self, name: str, op: CustomOp) -> None:
    """Install a ``handle_<name>`` attribute on this object that forwards to ``op``.

    Args:
        name: Suffix of the attribute to create; the handler is stored as
            ``handle_{name}``, replacing any existing attribute of that name.
        op: Object whose ``handle`` is invoked with this object's
            ``region_graph`` followed by the caller's arguments.
    """

    def handler(*args, **kwargs):
        # `self.region_graph` is read on every call rather than captured once,
        # so the handler always uses the current value of the attribute.
        return op.handle(self.region_graph, *args, **kwargs)

    setattr(self, f"handle_{name}", handler)

### ========================================================================
### Core Operations
### ========================================================================
Expand Down
3 changes: 3 additions & 0 deletions shark_turbine/kernel/compiler/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
IntegerType,
Location,
Operation,
OpResult,
MemRefType,
ShapedType,
StringAttr,
SymbolTable,
Type as IrType,
UnitAttr,
Value,
VectorType,
)
Expand All @@ -37,6 +39,7 @@
flow as flow_d,
func as func_d,
math as math_d,
memref as memref_d,
stream as stream_d,
vector as vector_d,
scf as scf_d,
Expand Down
1 change: 1 addition & 0 deletions shark_turbine/kernel/lang/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .prims import *
from .types import *
from .kernel_buffer import *
from .functional_types import *
from .grid import *

# Include publics from the _support library.
Expand Down
Loading

0 comments on commit 40f3a08

Please sign in to comment.