TKW initial commit
- Readme describing the TKW approach
- initial op set: read, write, mma, tiled_loop, construct_register_from_metadata
- wave decorator for the block-based programming model
- types for the yet-to-be-specified Memory and Register abstractions
- initial gemm test that traces successfully (see the usage sketch below)
- interface on top of fx.Node to rely less on string matching when navigating the fx.Graph
- small changes to the tracing contexts to use op handlers from the node interfaces

Signed-off-by: Martin Lücke <[email protected]>
martin-luecke committed Jun 10, 2024
1 parent c4db712 commit 7dfabaa
Showing 11 changed files with 564 additions and 3 deletions.
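
For orientation, here is a hedged sketch of what the traced gemm kernel might look like under this op set. It is reconstructed only from the node dataclasses and handlers added below; the module aliases (tkl/tkw), the wave() arguments, the sym/f16/f32 helpers, and the exact op signatures are illustrative assumptions, not the code of the actual test.

# Hedged sketch of a gemm-style kernel under the new programming model.
# Everything here is inferred from the node dataclasses in this commit; the
# module aliases, decorator arguments, and symbol/dtype helpers are assumptions.
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw  # assumed import path for the TKW ops

M, N, K = tkl.sym.M, tkl.sym.N, tkl.sym.K  # assumed symbolic dims
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE      # Memory expects an IndexExpr here


@tkw.wave()  # block-based programming model decorator added by this commit
def gemm(
    a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
    b: tkl.Memory[K, N, ADDRESS_SPACE, tkl.f16],
    c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
):
    # ConstructRegisterFromMetadataNode(shape, dtype, value)
    acc = tkw.construct_register_from_metadata((M, N), tkl.f32, 0.0)

    # ReductionNode.handle returns a decorator that traces the loop body into a
    # subgraph, recording axis, init_args, subgraph_name and implicit captures.
    @tkw.tiled_loop(K, init_args=[acc])
    def body(acc):
        a_reg = tkw.read(a)                 # ReadNode(memory, elements_per_thread, type)
        b_reg = tkw.read(b)
        return tkw.mma(a_reg, b_reg, acc)   # MmaNode(lhs, rhs, acc)

    # WriteNode(register_, memory, elements_per_thread)
    tkw.write(body, c, elements_per_thread=4)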
159 changes: 159 additions & 0 deletions shark_turbine/kernel/_support/nodes.py
@@ -0,0 +1,159 @@
from abc import ABC
from dataclasses import dataclass
from typing import Any, Optional, Sequence, Type, TypeVar, final
import torch.fx as fx

from ..lang.functional_types import Memory, Register
from .._support.indexing import IndexExpr
from .._support.dtype import DataType


def get_node_name(string: str, skip_first: bool = True):
    """Convert a CamelCase node class name to its snake_case op name,
    e.g. "ReadNode" -> "read", "ConstructRegisterFromMetadataNode" -> "construct_register_from_metadata"."""
    snakeString = ""
    if skip_first:
        snakeString += string[0].lower()
        string = string[1:]
    for i in string:
        if i.isupper():
            snakeString += "_" + i.lower()
        else:
            snakeString += i
    # Drop the "_node" suffix
    return snakeString[:-5]


CustomNodeT = TypeVar("CustomNodeT", bound="CustomNode")
PlaceholderNodeT = TypeVar("PlaceholderNodeT", bound="PlaceholderNode")


@dataclass
class CustomNode(ABC):
    """
    Base class for all custom fx nodes.
    """

    graph: fx.Graph
    op: Any

    @classmethod
    def from_fx_node(cls: Type[CustomNodeT], node: fx.Node) -> CustomNodeT:
        instance = cls(node.graph, node.op, *node.args)
        instance.fx_node = node
        return instance

    def __str__(self) -> str:
        name = get_node_name(self.__class__.__name__)
        # print all variables of the node apart from graph and op
        vars_list = [f"{key}={value}" for key, value in vars(self).items()][2:]
        vars_str = ", ".join(vars_list)
        return f"{name}({vars_str})"

    def custom_string(self, value_map: dict[str, str]) -> str:
        # If a subclass does not define custom printing we revert to the default
        return str(self)

    def add_to_graph(self, region_graph):
        arg_list = tuple([value for _, value in vars(self).items()][2:])
        self.fx_node = region_graph.create_proxy(
            "call_function",
            target=self.op,
            args=arg_list,
            kwargs={},
        )

    @classmethod
    def handle(cls, graph, *args, **kwargs) -> fx.Node:
        node = cls(graph, *args, **kwargs)
        node.add_to_graph(graph)
        return node.fx_node

    @property
    def name(self) -> str:
        if hasattr(self, "_name"):
            return self._name
        return self.fx_node.name
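

# Editorial sketch (not part of this commit) of how an op subclasses CustomNode:
# operands are just dataclass fields after `graph` and `op`; add_to_graph() emits
# them as the fx.Node args with the op callable as target, and from_fx_node()
# rebuilds the typed wrapper from a traced node's args.
#
#     @dataclass
#     class ScaleNode(CustomNode):          # hypothetical example op
#         value: fx.Node
#         factor: float
#
#     fx_node = ScaleNode.handle(region_graph, scale_op, value_proxy, 2.0)
#     ScaleNode.from_fx_node(fx_node)       # -> ScaleNode with value=value_proxy, factor=2.0
#
# get_node_name("ScaleNode") -> "scale", the name used for printing and for
# registering a handle_scale handler on the tracing context.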


@final
@dataclass
class UnknownNode(CustomNode):
    """
    Represents an fx.Node that has no corresponding CustomNode class.
    """

    args: Sequence[Any]
    kwargs: dict[Any, Any]

    @classmethod
    def from_fx_node(cls, node: fx.Node) -> "UnknownNode":
        return cls(node.graph, node.op, node.args, node.kwargs)


@dataclass
class PlaceholderNode(CustomNode):
    """
    Represents a placeholder node in the graph, i.e. an input to a function.
    """

    _name: str
    type: Optional[DataType]

    @classmethod
    def from_fx_node(cls: Type[PlaceholderNodeT], node: fx.Node) -> PlaceholderNodeT:
        return cls(node.graph, node.op, node.name, node.type)


# Nodes modeling TKW operations in the kernel language


@dataclass
class ConstructRegisterFromMetadataNode(CustomNode):
    shape: tuple[IndexExpr, ...]
    dtype: DataType
    value: float


@dataclass
class MmaNode(CustomNode):
    lhs: fx.Node
    rhs: fx.Node
    acc: fx.Node


@dataclass
class ReadNode(CustomNode):
    memory: fx.Proxy
    elements_per_thread: Optional[Any] = None
    type: Optional[Type[Register]] = None


@dataclass
class ReductionNode(CustomNode):
    axis: IndexExpr
    init_args: Sequence[Any]
    subgraph_name: str
    implicit_captures: Sequence[fx.Proxy]

    @classmethod
    def handle(cls, graph, *args, **kwargs):
        def wrapper(f):
            with graph.subtracer() as subtracer:
                subgraph_name, implicit_captures = subtracer.trace(f)
            node = ReductionNode(
                graph,
                *args,
                **kwargs,
                subgraph_name=subgraph_name,
                implicit_captures=implicit_captures,
            )
            node.add_to_graph(graph)
            return node.fx_node

        return wrapper


@dataclass
class WriteNode(CustomNode):
    register_: fx.Proxy
    memory: fx.Proxy
    elements_per_thread: Optional[Any]
24 changes: 24 additions & 0 deletions shark_turbine/kernel/_support/tracing.py
@@ -30,6 +30,7 @@
from ..lang.types import (
    Index,
)
from .nodes import CustomNode, PlaceholderNode, ReductionNode, UnknownNode

from .regions import RegionGraph, SubgraphTracer

@@ -177,6 +178,29 @@ def __init__(self, region_graph: RegionGraph, *, grid_type: Type[Grid]):
            backed_sym_index_type(BoundedRelation(0, n, upper_inclusive=False))
            for n in grid_type.symbolic_shape
        ]
        self.custom_ops: dict[str, CustomNode] = {}

    def register_custom_op(self, name: str, op: CustomNode):
        self.custom_ops[name] = op

        def handler(*args, **kwargs):
            return op.handle(self.region_graph, *args, **kwargs)

        setattr(self, f"handle_{name}", handler)

    def node(self, node: fx.Node) -> CustomNode:
        for name, nodeT in self.custom_ops.items():
            # The fx_nodes have a suffix depending on the number of occurrences
            # of similar nodes. We are only interested in the name prefix.

            # TODO: Instead simply strip the suffix off, as this matching depends
            # on the iteration order for nodes sharing a prefix,
            # e.g. read and read_shared
            if node.name.startswith(name):
                return nodeT.from_fx_node(node)
        if node.op == "placeholder":
            return PlaceholderNode.from_fx_node(node)
        return UnknownNode.from_fx_node(node)
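
    # Example wiring (illustrative; the op names below follow the commit
    # message, but the actual registration call site is not shown in this diff):
    #
    #     self.register_custom_op("read", ReadNode)
    #     self.register_custom_op("write", WriteNode)
    #     self.register_custom_op("mma", MmaNode)
    #     self.register_custom_op("construct_register_from_metadata",
    #                             ConstructRegisterFromMetadataNode)
    #     self.register_custom_op("tiled_loop", ReductionNode)
    #
    # After registration the context gains handle_read, handle_mma, ... methods
    # (each forwarding to <Node>.handle on the region graph), and node() maps a
    # traced fx.Node back to its typed CustomNode wrapper by name prefix.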

### ========================================================================
### Core Operations
3 changes: 3 additions & 0 deletions shark_turbine/kernel/compiler/ir.py
@@ -22,11 +22,13 @@
    IntegerType,
    Location,
    Operation,
    OpResult,
    MemRefType,
    ShapedType,
    StringAttr,
    SymbolTable,
    Type as IrType,
    UnitAttr,
    Value,
    VectorType,
)
@@ -37,6 +39,7 @@
    flow as flow_d,
    func as func_d,
    math as math_d,
    memref as memref_d,
    stream as stream_d,
    vector as vector_d,
    scf as scf_d,
1 change: 1 addition & 0 deletions shark_turbine/kernel/lang/__init__.py
@@ -1,6 +1,7 @@
from .prims import *
from .types import *
from .kernel_buffer import *
from .functional_types import *
from .grid import *

# Include publics from the _support library.
140 changes: 140 additions & 0 deletions shark_turbine/kernel/lang/functional_types.py
@@ -0,0 +1,140 @@
from typing import Optional, Type, TypeVar, cast, ClassVar
from enum import Enum


from ..lang.kernel_buffer import KernelBufferUsage
from .._support.shaped_type import ShapedDataType
from .._support.dtype import DataType
from ...kernel._support.indexing import IndexExpr
import torch

__all__ = [
    "Memory",
    "Register",
    "AddressSpace",
]

MemoryTypeT = TypeVar("MemoryTypeT")


class AddressSpace(Enum):
    REGISTER = 0
    SHARED_MEMORY = 1
    GLOBAL_MEMORY = 2


class _MemoryStorage(ShapedDataType):
    def new_subtype(
        cls: Type[MemoryTypeT],
        *,
        symbolic_shape: tuple[IndexExpr, ...],
        address_space: AddressSpace,
        dtype: DataType,
        usage: Optional[KernelBufferUsage] = None,
    ) -> Type[MemoryTypeT]:
        init_symbolic_shape = symbolic_shape
        init_dtype = dtype
        init_address_space = (
            address_space if address_space else AddressSpace.REGISTER.value
        )
        init_usage = usage

        class MemoryType(cls):
            symbolic_shape = init_symbolic_shape
            rank = len(symbolic_shape)
            address_space = init_address_space
            dtype = init_dtype
            usage = init_usage

        return cast(Type[MemoryTypeT], MemoryType)


class Memory(metaclass=_MemoryStorage):
    """
    Represents storage anywhere in the memory hierarchy except registers.
    Parameterized by a shape, address space and element type. The allocated
    memory is traversed by an iterator that specifies the offset, stride
    and size along each dimension.
    """

    symbolic_shape: ClassVar[tuple[IndexExpr, ...]]
    address_space: ClassVar[int]
    rank: ClassVar[int]
    dtype: ClassVar[DataType]
    usage: ClassVar[Optional[KernelBufferUsage]]

    def __init__(self, tensor: torch.Tensor) -> None:
        assert isinstance(tensor, torch.Tensor), f"Expected Tensor but got {tensor}"
        self._tensor = tensor
        self.symbolic_shape = None

    def __class_getitem__(
        cls, shape_and_dtype: tuple[IndexExpr | DataType, ...]
    ) -> Type["Memory"]:
        """Syntax: `Memory[shape1, ..., shapeN, addressSpace, dtype, Optional[usage]]`"""
        if not isinstance(shape_and_dtype, tuple) or len(shape_and_dtype) < 3:
            raise TypeError(f"Expected at least 3 arguments, got: {shape_and_dtype}")

        shift = 0
        usage = None
        if isinstance(shape_and_dtype[-1], KernelBufferUsage):
            shift = 1
            usage = shape_and_dtype[-1]
        shape = shape_and_dtype[: -2 - shift]
        addressSpace = shape_and_dtype[-2 - shift]
        dtype = shape_and_dtype[-1 - shift]

        if not all(isinstance(s, IndexExpr) for s in shape):
            raise TypeError(f"Expected shape to be a tuple of IndexExpr, got {shape}")
        if not isinstance(dtype, DataType):
            raise TypeError(f"Expected dtype to be a DataType, got {dtype}")
        if not isinstance(addressSpace, IndexExpr):
            raise TypeError(
                f"Expected addressSpace to be an AddressSpace, got {addressSpace}"
            )

        shape = cast(tuple[IndexExpr, ...], shape)
        dtype = cast(DataType, dtype)
        addressSpace = cast(AddressSpace, addressSpace)

        return cls.new_subtype(
            symbolic_shape=shape, address_space=addressSpace, dtype=dtype, usage=usage
        )


class Register(metaclass=_MemoryStorage):
    "Represents virtual registers. Parameterized by a shape and element type."
    symbolic_shape: ClassVar[tuple[IndexExpr, ...]]
    rank: ClassVar[int]
    dtype: ClassVar[DataType]

    def __init__(self, shape, dtype) -> None:
        if not isinstance(shape, tuple):
            raise TypeError(f"Expected shape to be a tuple, got: {shape}")
        self.symbolic_shape = shape
        self.rank = len(self.symbolic_shape)
        self.dtype = dtype
        self.value = None

    def set(self, value) -> None:
        self.value = value

    def __class_getitem__(
        cls, shape_and_dtype: tuple[IndexExpr | DataType, ...]
    ) -> Type["Register"]:

        if not isinstance(shape_and_dtype, tuple) or len(shape_and_dtype) < 2:
            raise TypeError(f"Expected at least 2 arguments, got: {shape_and_dtype}")

        shape = shape_and_dtype[:-1]
        dtype = shape_and_dtype[-1]

        shape = cast(tuple[IndexExpr, ...], shape)
        dtype = cast(DataType, dtype)
        return cls.new_subtype(
            symbolic_shape=shape, dtype=dtype, address_space=AddressSpace.REGISTER.value
        )


def is_memory_meta_derived(t: type) -> bool:
    return isinstance(t, _MemoryStorage)
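
A short usage sketch of the subscript syntax defined above; the symbolic dimension names and the f16/f32 element types are assumptions for illustration, and the address space is passed as a symbol because __class_getitem__ requires an IndexExpr.

# Illustrative only: the `sym` symbol helpers and the `f16`/`f32` dtype names
# are assumed to be exported by the surrounding kernel language, not this file.
from shark_turbine.kernel.lang import Memory, Register, sym, f16, f32

M, N = sym.M, sym.N
ADDRESS_SPACE = sym.ADDRESS_SPACE  # must be an IndexExpr per the check above

A = Memory[M, N, ADDRESS_SPACE, f16]   # parameterized Memory subtype, rank 2
Acc = Register[M, N, f32]              # Register subtype in AddressSpace.REGISTER

assert A.symbolic_shape == (M, N) and A.rank == 2
assert Acc.dtype == f32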
7 changes: 4 additions & 3 deletions shark_turbine/kernel/ops/base.py
@@ -7,6 +7,7 @@
from .._support import context

T = TypeVar("T")
OpDispatcherT = TypeVar("OpDispatcherT", bound="OpDispatcher")


class OpDispatcher:
@@ -18,9 +19,9 @@ def handle_{idname}(self, operator, *args, **kwargs)

    __tk_context_idname__ = "OpDispatcher"

    @staticmethod
    def current() -> "OpDispatcher":
        return context.current(OpDispatcher)
    @classmethod
    def current(cls: Type[OpDispatcherT]) -> OpDispatcherT:
        return context.current(cls)

    def __enter__(self) -> "OpDispatcher":
        return context.push(OpDispatcher, self)
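
    # Editorial note on the change above: current() is now a classmethod typed
    # with OpDispatcherT, so a dispatcher subclass gets back its own type (and
    # its handle_{idname} methods) from current(), e.g.
    #
    #     ctx = MyTracingContext.current()   # hypothetical subclass name
    #     ctx.handle_read(read_op, memory)   # handler installed via register_custom_op
    #
    # Assuming context.current keys off __tk_context_idname__ (inherited here),
    # the lookup itself is unchanged and only the static type improves.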
(The remaining changed files in this commit are not shown here.)