TKW initial commit
- Readme describing TKW approach
- initial op set: read, write, mma, tiled_loop, register
- wave decorator for block-based programming model
- types for yet to be specified Memory and for Registers
- initial gemm test that traces successfully
- unit tests for tracing the initial set of ops
- lit and filecheck dependencies for unit tests
- interface on top of fx.Node to rely less on string matching when navigating the fx.Graph.
- small changes to tracing contexts to use op handlers from node interfaces

Signed-off-by: Martin Lücke <[email protected]>
martin-luecke committed Jun 13, 2024
1 parent c4db712 commit ec5b437
Showing 15 changed files with 1,019 additions and 3 deletions.
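The commit message lists the initial op set (read, write, mma, tiled_loop, register) and a wave decorator, but the excerpt below only shows the supporting types and tracing changes. The following is a rough sketch of the block-based style those pieces target; every import path, signature and decorator argument here is an assumption for illustration, not the committed API.

# Rough sketch only -- NOT code from this commit. Op/decorator signatures and
# module paths are guesses based on the commit message and the types added below.
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw   # user-facing module path assumed

M, N, K = tkl.sym.M, tkl.sym.N, tkl.sym.K            # symbolic dimensions
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE                # symbolic address space

@tkw.wave()                                          # wave decorator; arguments assumed
def gemm(
    a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
    b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
    c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
):
    acc = tkl.Register[M, N, tkl.f32](0.0)           # accumulator tile in registers

    @tkw.tiled_loop(K, init_args=[acc])              # iterate over the contraction dim; signature assumed
    def body(acc):
        a_reg = tkw.read(a)
        b_reg = tkw.read(b)
        return tkw.mma(a_reg, b_reg, acc)

    tkw.write(body, c)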
2 changes: 2 additions & 0 deletions requirements.txt
@@ -1,9 +1,11 @@
# Build/test requirements.
Jinja2==3.1.3
filecheck==0.0.24
numpy==1.26.3
parameterized==0.9.0
pytest==8.0.0
pytest-xdist==3.5.0
lit==18.1.7
mypy==1.8.0
setuptools
wheel
15 changes: 15 additions & 0 deletions shark_turbine/kernel/_support/tracing.py
@@ -30,6 +30,7 @@
from ..lang.types import (
    Index,
)
from ..ops.wave_ops import CustomOp, Placeholder, Reduction, Unknown

from .regions import RegionGraph, SubgraphTracer

@@ -178,6 +179,20 @@ def __init__(self, region_graph: RegionGraph, *, grid_type: Type[Grid]):
            for n in grid_type.symbolic_shape
        ]

    def register_custom_op(self, name: str, op: CustomOp):
        def handler(*args, **kwargs):
            return op.handle(self.region_graph, *args, **kwargs)

        setattr(self, f"handle_{name}", handler)

    def node(self, node: fx.Node) -> CustomOp:
        if node.op == "placeholder":
            return Placeholder.from_fx_node(node)
        # If the node was created as a CustomNode it has a corresponding field
        if hasattr(node, "tkw_op"):
            return node.tkw_op.from_fx_node(node)
        return Unknown.from_fx_node(node)

    ### ========================================================================
    ### Core Operations
    ### ========================================================================
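Not part of the diff: a short sketch of how the two new context methods are meant to work together. `ctx` is assumed to be an instance of the tracing context extended above, `graph` an fx.Graph produced during tracing, and `Read` one of the CustomOp classes in wave_ops (class name assumed).

# Illustrative sketch, not code from this commit.
from shark_turbine.kernel.ops.wave_ops import Unknown

ctx.register_custom_op("read", Read)   # installs ctx.handle_read(*args, **kwargs),
                                       # which builds the node via Read.handle(...)

for fx_node in graph.nodes:
    custom = ctx.node(fx_node)         # Placeholder, the originating tkw op, or Unknown
    if not isinstance(custom, Unknown):
        ...                            # navigate via the typed wrapper instead of
                                       # matching on node names/targets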
3 changes: 3 additions & 0 deletions shark_turbine/kernel/compiler/ir.py
@@ -22,11 +22,13 @@
    IntegerType,
    Location,
    Operation,
    OpResult,
    MemRefType,
    ShapedType,
    StringAttr,
    SymbolTable,
    Type as IrType,
    UnitAttr,
    Value,
    VectorType,
)
@@ -37,6 +39,7 @@
    flow as flow_d,
    func as func_d,
    math as math_d,
    memref as memref_d,
    stream as stream_d,
    vector as vector_d,
    scf as scf_d,
1 change: 1 addition & 0 deletions shark_turbine/kernel/lang/__init__.py
@@ -1,6 +1,7 @@
from .prims import *
from .types import *
from .kernel_buffer import *
from .functional_types import *
from .grid import *

# Include publics from the _support library.
136 changes: 136 additions & 0 deletions shark_turbine/kernel/lang/functional_types.py
@@ -0,0 +1,136 @@
from typing import Optional, Type, TypeVar, cast, ClassVar
from enum import Enum

from ..lang.kernel_buffer import KernelBufferUsage
from ..ops.wave_ops import register
from .._support.shaped_type import ShapedDataType
from .._support.dtype import DataType
from ...kernel._support.indexing import IndexExpr
import torch
import torch.fx

__all__ = [
"Memory",
"Register",
"AddressSpace",
]

MemoryTypeT = TypeVar("MemoryTypeT")


class AddressSpace(Enum):
    REGISTER = 0
    SHARED_MEMORY = 1
    GLOBAL_MEMORY = 2


class _MemoryStorage(ShapedDataType):
    def new_subtype(
        cls: Type[MemoryTypeT],
        *,
        symbolic_shape: tuple[IndexExpr, ...],
        address_space: AddressSpace,
        dtype: DataType,
        usage: Optional[KernelBufferUsage] = None,
    ) -> Type[MemoryTypeT]:
        init_symbolic_shape = symbolic_shape
        init_dtype = dtype
        init_address_space = (
            address_space if address_space else AddressSpace.REGISTER.value
        )
        init_usage = usage

        class MemoryType(cls):
            symbolic_shape = init_symbolic_shape
            rank = len(symbolic_shape)
            address_space = init_address_space
            dtype = init_dtype
            usage = init_usage

        return cast(Type[MemoryTypeT], MemoryType)


class Memory(metaclass=_MemoryStorage):
"""
Represents storage anywhere in the memory hierarchy except registers.
Parameterized by a shape, address space and element type. The allocated
memory is traversed by an iterator that specifies the offset, stride
and size along each dimension.
"""

symbolic_shape: ClassVar[tuple[IndexExpr, ...]]
address_space: ClassVar[int]
rank: ClassVar[int]
dtype: ClassVar[DataType]
usage: ClassVar[Optional[KernelBufferUsage]]

def __init__(self, tensor: torch.Tensor) -> None:
assert isinstance(tensor, torch.Tensor), f"Expected Tensor but got {tensor}"
self._tensor = tensor
self.symbolic_shape = None

def __class_getitem__(
cls, shape_and_dtype: tuple[IndexExpr | DataType, ...]
) -> Type["Memory"]:
"""Syntax: `Memory[shape1, ...., shapeN, addressSpace, dtype, Optional[usage]]"""
if not isinstance(shape_and_dtype, tuple) or len(shape_and_dtype) < 3:
raise TypeError(f"Expected at least 3 arguments, got: {shape_and_dtype}")

shift = 0
usage = None
if isinstance(shape_and_dtype[-1], KernelBufferUsage):
shift = 1
usage = shape_and_dtype[-1]
shape = shape_and_dtype[: -2 - shift]
addressSpace = shape_and_dtype[-2 - shift]
dtype = shape_and_dtype[-1 - shift]

if not all(isinstance(s, IndexExpr) for s in shape):
raise TypeError(f"Expected shape to be a tuple of IndexExpr, got {shape}")
if not isinstance(dtype, DataType):
raise TypeError(f"Expected dtype to be a DataType, got {dtype}")
if not isinstance(addressSpace, IndexExpr):
raise TypeError(
f"Expected addressSpace to be a AddressSpace, got {addressSpace}"
)

shape = cast(tuple[IndexExpr, ...], shape)
dtype = cast(DataType, dtype)
addressSpace = cast(AddressSpace, addressSpace)

return cls.new_subtype(
symbolic_shape=shape, address_space=addressSpace, dtype=dtype, usage=usage
)


class Register(metaclass=_MemoryStorage):
"Represents virtual registers. Parameterized by a shape and element type."
symbolic_shape: ClassVar[tuple[IndexExpr, ...]]
rank: ClassVar[int]
dtype: ClassVar[DataType]

def __new__(cls, value) -> torch.fx.Proxy:
return register(cls.symbolic_shape, cls.dtype, value)

def set(self, value) -> None:
self.value = value

def __class_getitem__(
cls, shape_and_dtype: tuple[IndexExpr | DataType, ...]
) -> Type["Register"]:

if not isinstance(shape_and_dtype, tuple) or len(shape_and_dtype) < 2:
raise TypeError(f"Expected at least 2 arguments, got: {shape_and_dtype}")

shape = shape_and_dtype[:-1]
dtype = shape_and_dtype[-1]

shape = cast(tuple[IndexExpr, ...], shape)
dtype = cast(DataType, dtype)
return cls.new_subtype(
symbolic_shape=shape, dtype=dtype, address_space=AddressSpace.REGISTER.value
)


def is_memory_meta_derived(t: type) -> bool:
    return isinstance(t, _MemoryStorage)
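Not part of the diff: a small sketch of how the new types are meant to be parameterized. It assumes M, N and ADDRESS_SPACE are IndexExpr symbols (e.g. obtained from the existing tkl.sym helper) and that f16/f32 DataType instances are re-exported from the lang package; both import locations are assumptions.

# Illustrative sketch, not code from this commit. Symbol/dtype import locations
# are assumed.
from shark_turbine.kernel.lang import sym, f16, f32
from shark_turbine.kernel.lang.functional_types import Memory, Register, AddressSpace

M, N = sym.M, sym.N
ADDRESS_SPACE = sym.ADDRESS_SPACE          # address space is passed as an IndexExpr symbol

A = Memory[M, N, ADDRESS_SPACE, f16]       # subscripting builds a new Memory subtype
assert A.rank == 2 and A.symbolic_shape == (M, N)

Acc = Register[M, N, f32]                  # registers take only a shape and element type
assert Acc.address_space == AddressSpace.REGISTER.value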
7 changes: 4 additions & 3 deletions shark_turbine/kernel/ops/base.py
@@ -7,6 +7,7 @@
from .._support import context

T = TypeVar("T")
OpDispatcherT = TypeVar("OpDispatcherT", bound="OpDispatcher")


class OpDispatcher:
@@ -18,9 +19,9 @@ def handle_{idname}(self, operator, *args, **kwargs)

    __tk_context_idname__ = "OpDispatcher"

    @staticmethod
    def current() -> "OpDispatcher":
        return context.current(OpDispatcher)
    @classmethod
    def current(cls: Type[OpDispatcherT]) -> OpDispatcherT:
        return context.current(cls)

    def __enter__(self) -> "OpDispatcher":
        return context.push(OpDispatcher, self)
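Not part of the diff: what the staticmethod-to-classmethod change enables, sketched with a hypothetical subclass. It assumes the context lookup resolves subclasses through the shared __tk_context_idname__, so current() called on a subclass hands back the instance typed as that subclass.

# Illustrative sketch, not code from this commit. WaveDispatcher is hypothetical.
class WaveDispatcher(OpDispatcher):
    def handle_read(self, operator, *args, **kwargs):
        ...

with WaveDispatcher() as d:
    ctx = WaveDispatcher.current()   # now returns (and is typed as) WaveDispatcher,
                                     # rather than the base OpDispatcher as before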