Skip to content

Commit

Permalink
TKW initial commit
Browse files Browse the repository at this point in the history
- Readme describing TKW approach
- initial op set: read, write, mma, tiled_loop, construct_register_from_metadata
- wave decorator for block-based programming model
- types for yet to be specified Memory and for Registers
- initial gemm test that traces successfully
- interface on top of fx.Node to rely less on string matching when navigating the fx.Graph.
- small changes to tracing contexts to use op handlers from node interfaces
  • Loading branch information
martin-luecke committed May 15, 2024
1 parent 3520958 commit e3b0abc
Show file tree
Hide file tree
Showing 12 changed files with 568 additions and 13 deletions.
4 changes: 2 additions & 2 deletions iree-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
iree-compiler==20240427.876
iree-runtime==20240427.876
iree-compiler==20240514.893
iree-runtime==20240514.893
159 changes: 159 additions & 0 deletions shark_turbine/kernel/_support/nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from abc import ABC
from dataclasses import dataclass
from typing import Any, Optional, Sequence, Type, TypeVar, final
import torch.fx as fx

from ..lang.functional_types import Memory, Register
from .._support.indexing import IndexExpr
from .._support.dtype import DataType


def get_node_name(string: str, skip_first: bool = True):
snakeString = ""
if skip_first:
snakeString += string[0].lower()
string = string[1:]
for i in string:
if i.isupper():
snakeString += "_" + i.lower()
else:
snakeString += i
# Drop the "_node" suffix
return snakeString[:-5]


CustomNodeT = TypeVar("CustomNodeT", bound="CustomNode")
PlaceholderNodeT = TypeVar("PlaceholderNodeT", bound="PlaceholderNode")


@dataclass
class CustomNode(ABC):
"""
Base class for all custom fx nodes.
"""

graph: fx.Graph
op: Any

@classmethod
def from_fx_node(cls: Type[CustomNodeT], node: fx.Node) -> CustomNodeT:
instance = cls(node.graph, node.op, *node.args)
instance.fx_node = node
return instance

def __str__(self) -> str:
name = get_node_name(self.__class__.__name__)
# print all variables of the node apart from graph and op
vars_list = [f"{key}={value}" for key, value in vars(self).items()][2:]
vars_str = ", ".join(vars_list)
return f"{name}({vars_str})"

def custom_string(self, value_map: dict[str, str]) -> str:
# If a subclass does not define custom printing we revert to the default
return str(self)

def add_to_graph(self, region_graph):
arg_list = tuple([value for _, value in vars(self).items()][2:])
self.fx_node = region_graph.create_proxy(
"call_function",
target=self.op,
args=arg_list,
kwargs={},
)

@classmethod
def handle(cls, graph, *args, **kwargs) -> fx.Node:
node = cls(graph, *args, **kwargs)
node.add_to_graph(graph)
return node.fx_node

@property
def name(self) -> str:
if hasattr(self, "_name"):
return self._name
return self.fx_node.name


@final
@dataclass
class UnknownNode(CustomNode):
"""
Represents an fx.Node that has no corresponding CustomNode class.
"""

args: Sequence[Any]
kwargs: dict[Any, Any]

@classmethod
def from_fx_node(cls, node: fx.Node) -> "UnknownNode":
return cls(node.graph, node.op, node.args, node.kwargs)


@dataclass
class PlaceholderNode(CustomNode):
"""
Represents a placeholder node in the graph, i.e. an input to a function.
"""

_name: str
type: Optional[DataType]

@classmethod
def from_fx_node(cls: Type[PlaceholderNodeT], node: fx.Node) -> PlaceholderNodeT:
return cls(node.graph, node.op, node.name, node.type)


# Nodes modeling TKW operations in the kernel language


@dataclass
class ConstructRegisterFromMetadataNode(CustomNode):
shape: tuple[IndexExpr, ...]
dtype: DataType
value: float


@dataclass
class MmaNode(CustomNode):
lhs: fx.Node
rhs: fx.Node
acc: fx.Node


@dataclass
class ReadNode(CustomNode):
memory: fx.Proxy
elements_per_thread: Optional[Any] = None
type: Optional[Type[Register]] = None


@dataclass
class TiledLoop(CustomNode):
axis: IndexExpr
init_args: Sequence[Any]
subgraph_name: str
implicit_captures: Sequence[fx.Proxy]

@classmethod
def handle(cls, graph, *args, **kwargs):
def wrapper(f):
with graph.subtracer() as subtracer:
subgraph_name, implicit_captures = subtracer.trace(f)
node = TiledLoop(
graph,
*args,
**kwargs,
subgraph_name=subgraph_name,
implicit_captures=implicit_captures,
)
node.add_to_graph(graph)
return node.fx_node

return wrapper


@dataclass
class WriteNode(CustomNode):
register_: fx.Proxy
memory: fx.Proxy
elements_per_thread: Optional[Any]
36 changes: 28 additions & 8 deletions shark_turbine/kernel/_support/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ..lang.types import (
Index,
)
from .nodes import CustomNode, PlaceholderNode, UnknownNode, TiledLoop

from .regions import RegionGraph, SubgraphTracer

Expand Down Expand Up @@ -177,6 +178,29 @@ def __init__(self, region_graph: RegionGraph, *, grid_type: Type[Grid]):
backed_sym_index_type(BoundedRelation(0, n, upper_inclusive=False))
for n in grid_type.symbolic_shape
]
self.custom_ops: dict[str, CustomNode] = {}

def register_custom_op(self, name: str, op: CustomNode):
self.custom_ops[name] = op

def handler(*args, **kwargs):
return op.handle(self.region_graph, *args, **kwargs)

setattr(self, f"handle_{name}", handler)

def node(self, node: fx.Node) -> CustomNode:
for name, nodeT in self.custom_ops.items():
# The fx_nodes have a suffix depending on the number of occurrences
# of similar nodes. We are only interested in the name prefix.

# TODO: Instead simply shape the suffix off, as this depends on the
# iteration order for nodes sharing a prefix,
# e.g. read and read_shared
if node.name.startswith(name):
return nodeT.from_fx_node(node)
if node.op == "placeholder":
return PlaceholderNode.from_fx_node(node)
return UnknownNode.from_fx_node(node)

### ========================================================================
### Core Operations
Expand Down Expand Up @@ -389,14 +413,11 @@ def __call__(self, *args, **kwargs):
return launch_context.launch(self, args, kwargs)

@abstractmethod
def eager_execute(self, args, kwargs):
...
def eager_execute(self, args, kwargs): ...

def aot_execute(self, args, kwargs):
...
def aot_execute(self, args, kwargs): ...

def test_execute(self, args, kwargs):
...
def test_execute(self, args, kwargs): ...


class LaunchContext(ABC):
Expand Down Expand Up @@ -435,8 +456,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
context.pop(LaunchContext, self)

@abstractmethod
def launch(self, launchable: Launchable, args, kwargs):
...
def launch(self, launchable: Launchable, args, kwargs): ...


class DebugLaunchContext(LaunchContext):
Expand Down
6 changes: 6 additions & 0 deletions shark_turbine/kernel/compiler/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,28 @@
IntegerType,
Location,
Operation,
OpResult,
MemRefType,
ShapedType,
StringAttr,
SymbolTable,
Type as IrType,
UnitAttr,
Value,
VectorType,
)

from iree.compiler.dialects import (
amdgpu as amdgpu_d,
arith as arith_d,
builtin as builtin_d,
flow as flow_d,
func as func_d,
gpu as gpu_d,
math as math_d,
memref as memref_d,
stream as stream_d,
transform as transform_d,
vector as vector_d,
scf as scf_d,
)
1 change: 1 addition & 0 deletions shark_turbine/kernel/lang/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .prims import *
from .types import *
from .kernel_buffer import *
from .functional_types import *
from .grid import *

# Include publics from the _support library.
Expand Down
Loading

0 comments on commit e3b0abc

Please sign in to comment.