-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Readme describing TKW approach - initial op set: read, write, mma, tiled_loop, construct_register_from_metadata - wave decorator for block-based programming model - types for yet to be specified Memory and for Registers - initial gemm test that traces successfully - interface on top of fx.Node to rely less on string matching when navigating the fx.Graph. - small changes to tracing contexts to use op handlers from node interfaces
- Loading branch information
1 parent
c4db712
commit bb4f0ba
Showing
12 changed files
with
568 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
from abc import ABC | ||
from dataclasses import dataclass | ||
from typing import Any, Optional, Sequence, Type, TypeVar, final | ||
import torch.fx as fx | ||
|
||
from ..lang.functional_types import Memory, Register | ||
from .._support.indexing import IndexExpr | ||
from .._support.dtype import DataType | ||
|
||
|
||
def get_node_name(string: str, skip_first: bool = True): | ||
snakeString = "" | ||
if skip_first: | ||
snakeString += string[0].lower() | ||
string = string[1:] | ||
for i in string: | ||
if i.isupper(): | ||
snakeString += "_" + i.lower() | ||
else: | ||
snakeString += i | ||
# Drop the "_node" suffix | ||
return snakeString[:-5] | ||
|
||
|
||
CustomNodeT = TypeVar("CustomNodeT", bound="CustomNode") | ||
PlaceholderNodeT = TypeVar("PlaceholderNodeT", bound="PlaceholderNode") | ||
|
||
|
||
@dataclass | ||
class CustomNode(ABC): | ||
""" | ||
Base class for all custom fx nodes. | ||
""" | ||
|
||
graph: fx.Graph | ||
op: Any | ||
|
||
@classmethod | ||
def from_fx_node(cls: Type[CustomNodeT], node: fx.Node) -> CustomNodeT: | ||
instance = cls(node.graph, node.op, *node.args) | ||
instance.fx_node = node | ||
return instance | ||
|
||
def __str__(self) -> str: | ||
name = get_node_name(self.__class__.__name__) | ||
# print all variables of the node apart from graph and op | ||
vars_list = [f"{key}={value}" for key, value in vars(self).items()][2:] | ||
vars_str = ", ".join(vars_list) | ||
return f"{name}({vars_str})" | ||
|
||
def custom_string(self, value_map: dict[str, str]) -> str: | ||
# If a subclass does not define custom printing we revert to the default | ||
return str(self) | ||
|
||
def add_to_graph(self, region_graph): | ||
arg_list = tuple([value for _, value in vars(self).items()][2:]) | ||
self.fx_node = region_graph.create_proxy( | ||
"call_function", | ||
target=self.op, | ||
args=arg_list, | ||
kwargs={}, | ||
) | ||
|
||
@classmethod | ||
def handle(cls, graph, *args, **kwargs) -> fx.Node: | ||
node = cls(graph, *args, **kwargs) | ||
node.add_to_graph(graph) | ||
return node.fx_node | ||
|
||
@property | ||
def name(self) -> str: | ||
if hasattr(self, "_name"): | ||
return self._name | ||
return self.fx_node.name | ||
|
||
|
||
@final | ||
@dataclass | ||
class UnknownNode(CustomNode): | ||
""" | ||
Represents an fx.Node that has no corresponding CustomNode class. | ||
""" | ||
|
||
args: Sequence[Any] | ||
kwargs: dict[Any, Any] | ||
|
||
@classmethod | ||
def from_fx_node(cls, node: fx.Node) -> "UnknownNode": | ||
return cls(node.graph, node.op, node.args, node.kwargs) | ||
|
||
|
||
@dataclass | ||
class PlaceholderNode(CustomNode): | ||
""" | ||
Represents a placeholder node in the graph, i.e. an input to a function. | ||
""" | ||
|
||
_name: str | ||
type: Optional[DataType] | ||
|
||
@classmethod | ||
def from_fx_node(cls: Type[PlaceholderNodeT], node: fx.Node) -> PlaceholderNodeT: | ||
return cls(node.graph, node.op, node.name, node.type) | ||
|
||
|
||
# Nodes modeling TKW operations in the kernel language | ||
|
||
|
||
@dataclass | ||
class ConstructRegisterFromMetadataNode(CustomNode): | ||
shape: tuple[IndexExpr, ...] | ||
dtype: DataType | ||
value: float | ||
|
||
|
||
@dataclass | ||
class MmaNode(CustomNode): | ||
lhs: fx.Node | ||
rhs: fx.Node | ||
acc: fx.Node | ||
|
||
|
||
@dataclass | ||
class ReadNode(CustomNode): | ||
memory: fx.Proxy | ||
elements_per_thread: Optional[Any] = None | ||
type: Optional[Type[Register]] = None | ||
|
||
|
||
@dataclass | ||
class ReductionNode(CustomNode): | ||
axis: IndexExpr | ||
init_args: Sequence[Any] | ||
subgraph_name: str | ||
implicit_captures: Sequence[fx.Proxy] | ||
|
||
@classmethod | ||
def handle(cls, graph, *args, **kwargs): | ||
def wrapper(f): | ||
with graph.subtracer() as subtracer: | ||
subgraph_name, implicit_captures = subtracer.trace(f) | ||
node = ReductionNode( | ||
graph, | ||
*args, | ||
**kwargs, | ||
subgraph_name=subgraph_name, | ||
implicit_captures=implicit_captures, | ||
) | ||
node.add_to_graph(graph) | ||
return node.fx_node | ||
|
||
return wrapper | ||
|
||
|
||
@dataclass | ||
class WriteNode(CustomNode): | ||
register_: fx.Proxy | ||
memory: fx.Proxy | ||
elements_per_thread: Optional[Any] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
from typing import Optional, Type, TypeVar, cast, ClassVar | ||
from enum import Enum | ||
|
||
|
||
from ..lang.kernel_buffer import KernelBufferUsage | ||
from .._support.shaped_type import ShapedDataType | ||
from .._support.dtype import DataType | ||
from ...kernel._support.indexing import IndexExpr | ||
import torch | ||
|
||
__all__ = [ | ||
"Memory", | ||
"Register", | ||
"AddressSpace", | ||
] | ||
|
||
MemoryTypeT = TypeVar("MemoryTypeT") | ||
|
||
|
||
class AddressSpace(Enum): | ||
REGISTER = 0 | ||
SHARED_MEMORY = 1 | ||
GLOBAL_MEMORY = 2 | ||
|
||
|
||
class _MemoryStorage(ShapedDataType): | ||
def new_subtype( | ||
cls: Type[MemoryTypeT], | ||
*, | ||
symbolic_shape: tuple[IndexExpr, ...], | ||
address_space: AddressSpace, | ||
dtype: DataType, | ||
usage: Optional[KernelBufferUsage] = None, | ||
) -> Type[MemoryTypeT]: | ||
init_symbolic_shape = symbolic_shape | ||
init_dtype = dtype | ||
init_address_space = ( | ||
address_space if address_space else AddressSpace.REGISTER.value | ||
) | ||
init_usage = usage | ||
|
||
class MemoryType(cls): | ||
symbolic_shape = init_symbolic_shape | ||
rank = len(symbolic_shape) | ||
address_space = init_address_space | ||
dtype = init_dtype | ||
usage = init_usage | ||
|
||
return cast(Type[MemoryTypeT], MemoryType) | ||
|
||
|
||
class Memory(metaclass=_MemoryStorage): | ||
""" | ||
Represents storage anywhere in the memory hierarchy except registers. | ||
Parameterized by a shape, address space and element type. The allocated | ||
memory is traversed by an iterator that specifies the offset, stride | ||
and size along each dimension. | ||
""" | ||
|
||
symbolic_shape: ClassVar[tuple[IndexExpr, ...]] | ||
address_space: ClassVar[int] | ||
rank: ClassVar[int] | ||
dtype: ClassVar[DataType] | ||
usage: ClassVar[Optional[KernelBufferUsage]] | ||
|
||
def __init__(self, tensor: torch.Tensor) -> None: | ||
assert isinstance(tensor, torch.Tensor), f"Expected Tensor but got {tensor}" | ||
self._tensor = tensor | ||
self.symbolic_shape = None | ||
|
||
def __class_getitem__( | ||
cls, shape_and_dtype: tuple[IndexExpr | DataType, ...] | ||
) -> Type["Memory"]: | ||
"""Syntax: `Memory[shape1, ...., shapeN, addressSpace, dtype, Optional[usage]]""" | ||
if not isinstance(shape_and_dtype, tuple) or len(shape_and_dtype) < 3: | ||
raise TypeError(f"Expected at least 3 arguments, got: {shape_and_dtype}") | ||
|
||
shift = 0 | ||
usage = None | ||
if isinstance(shape_and_dtype[-1], KernelBufferUsage): | ||
shift = 1 | ||
usage = shape_and_dtype[-1] | ||
shape = shape_and_dtype[: -2 - shift] | ||
addressSpace = shape_and_dtype[-2 - shift] | ||
dtype = shape_and_dtype[-1 - shift] | ||
|
||
if not all(isinstance(s, IndexExpr) for s in shape): | ||
raise TypeError(f"Expected shape to be a tuple of IndexExpr, got {shape}") | ||
if not isinstance(dtype, DataType): | ||
raise TypeError(f"Expected dtype to be a DataType, got {dtype}") | ||
if not isinstance(addressSpace, IndexExpr): | ||
raise TypeError( | ||
f"Expected addressSpace to be a AddressSpace, got {addressSpace}" | ||
) | ||
|
||
shape = cast(tuple[IndexExpr, ...], shape) | ||
dtype = cast(DataType, dtype) | ||
addressSpace = cast(AddressSpace, addressSpace) | ||
|
||
return cls.new_subtype( | ||
symbolic_shape=shape, address_space=addressSpace, dtype=dtype, usage=usage | ||
) | ||
|
||
|
||
class Register(metaclass=_MemoryStorage): | ||
"Represents virtual registers. Parameterized by a shape and element type." | ||
symbolic_shape: ClassVar[tuple[IndexExpr, ...]] | ||
rank: ClassVar[int] | ||
dtype: ClassVar[DataType] | ||
|
||
def __init__(self, shape, dtype) -> None: | ||
if not isinstance(shape, tuple): | ||
raise TypeError(f"Expected at shape to be a tuple, got: {shape}") | ||
self.symbolic_shape = shape | ||
self.rank = len(self.symbolic_shape) | ||
self.dtype = dtype | ||
self.value = None | ||
|
||
def set(self, value) -> None: | ||
self.value = value | ||
|
||
def __class_getitem__( | ||
cls, shape_and_dtype: tuple[IndexExpr | DataType, ...] | ||
) -> Type["Register"]: | ||
|
||
if not isinstance(shape_and_dtype, tuple) or len(shape_and_dtype) < 2: | ||
raise TypeError(f"Expected at least 2 arguments, got: {shape_and_dtype}") | ||
|
||
shape = shape_and_dtype[:-1] | ||
dtype = shape_and_dtype[-1] | ||
|
||
shape = cast(tuple[IndexExpr, ...], shape) | ||
dtype = cast(DataType, dtype) | ||
return cls.new_subtype( | ||
symbolic_shape=shape, dtype=dtype, address_space=AddressSpace.REGISTER.value | ||
) | ||
|
||
|
||
def is_memory_meta_derived(t: type) -> bool: | ||
return isinstance(t, _MemoryStorage) |
Oops, something went wrong.