From 40f090f250e86c7c340610501473977a78570fcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lu=CC=88cke?= Date: Wed, 15 May 2024 12:06:42 +0200 Subject: [PATCH] TKW initial commit - Readme describing TKW approach - initial op set: read, write, mma, tiled_loop, construct_register_from_metadata - wave decorator for block-based programming model - types for yet to be specified Memory and for Registers - initial gemm test that traces successfully - interface on top of fx.Node to rely less on string matching when navigating the fx.Graph. - small changes to tracing contexts to use op handlers from node interfaces --- iree-requirements.txt | 4 +- shark_turbine/kernel/_support/nodes.py | 159 ++++++++++++++++++ shark_turbine/kernel/_support/tracing.py | 36 +++- shark_turbine/kernel/compiler/ir.py | 6 + shark_turbine/kernel/lang/__init__.py | 1 + shark_turbine/kernel/lang/functional_types.py | 140 +++++++++++++++ shark_turbine/kernel/ops/base.py | 7 +- shark_turbine/kernel/wave/README.md | 28 +++ shark_turbine/kernel/wave/__init__.py | 2 + shark_turbine/kernel/wave/ops.py | 31 ++++ shark_turbine/kernel/wave/wave.py | 91 ++++++++++ tests/kernel/wave_gemm_test.py | 76 +++++++++ 12 files changed, 568 insertions(+), 13 deletions(-) create mode 100644 shark_turbine/kernel/_support/nodes.py create mode 100644 shark_turbine/kernel/lang/functional_types.py create mode 100644 shark_turbine/kernel/wave/README.md create mode 100644 shark_turbine/kernel/wave/__init__.py create mode 100644 shark_turbine/kernel/wave/ops.py create mode 100644 shark_turbine/kernel/wave/wave.py create mode 100644 tests/kernel/wave_gemm_test.py diff --git a/iree-requirements.txt b/iree-requirements.txt index f1cb04cc3..0bb7b33f1 100644 --- a/iree-requirements.txt +++ b/iree-requirements.txt @@ -1,2 +1,2 @@ -iree-compiler==20240427.876 -iree-runtime==20240427.876 +iree-compiler==20240514.893 +iree-runtime==20240514.893 diff --git a/shark_turbine/kernel/_support/nodes.py b/shark_turbine/kernel/_support/nodes.py new file mode 100644 index 000000000..053f4d030 --- /dev/null +++ b/shark_turbine/kernel/_support/nodes.py @@ -0,0 +1,159 @@ +from abc import ABC +from dataclasses import dataclass +from typing import Any, Optional, Sequence, Type, TypeVar, final +import torch.fx as fx + +from ..lang.functional_types import Memory, Register +from .._support.indexing import IndexExpr +from .._support.dtype import DataType + + +def get_node_name(string: str, skip_first: bool = True): + snakeString = "" + if skip_first: + snakeString += string[0].lower() + string = string[1:] + for i in string: + if i.isupper(): + snakeString += "_" + i.lower() + else: + snakeString += i + # Drop the "_node" suffix + return snakeString[:-5] + + +CustomNodeT = TypeVar("CustomNodeT", bound="CustomNode") +PlaceholderNodeT = TypeVar("PlaceholderNodeT", bound="PlaceholderNode") + + +@dataclass +class CustomNode(ABC): + """ + Base class for all custom fx nodes. + """ + + graph: fx.Graph + op: Any + + @classmethod + def from_fx_node(cls: Type[CustomNodeT], node: fx.Node) -> CustomNodeT: + instance = cls(node.graph, node.op, *node.args) + instance.fx_node = node + return instance + + def __str__(self) -> str: + name = get_node_name(self.__class__.__name__) + # print all variables of the node apart from graph and op + vars_list = [f"{key}={value}" for key, value in vars(self).items()][2:] + vars_str = ", ".join(vars_list) + return f"{name}({vars_str})" + + def custom_string(self, value_map: dict[str, str]) -> str: + # If a subclass does not define custom printing we revert to the default + return str(self) + + def add_to_graph(self, region_graph): + arg_list = tuple([value for _, value in vars(self).items()][2:]) + self.fx_node = region_graph.create_proxy( + "call_function", + target=self.op, + args=arg_list, + kwargs={}, + ) + + @classmethod + def handle(cls, graph, *args, **kwargs) -> fx.Node: + node = cls(graph, *args, **kwargs) + node.add_to_graph(graph) + return node.fx_node + + @property + def name(self) -> str: + if hasattr(self, "_name"): + return self._name + return self.fx_node.name + + +@final +@dataclass +class UnknownNode(CustomNode): + """ + Represents an fx.Node that has no corresponding CustomNode class. + """ + + args: Sequence[Any] + kwargs: dict[Any, Any] + + @classmethod + def from_fx_node(cls, node: fx.Node) -> "UnknownNode": + return cls(node.graph, node.op, node.args, node.kwargs) + + +@dataclass +class PlaceholderNode(CustomNode): + """ + Represents a placeholder node in the graph, i.e. an input to a function. + """ + + _name: str + type: Optional[DataType] + + @classmethod + def from_fx_node(cls: Type[PlaceholderNodeT], node: fx.Node) -> PlaceholderNodeT: + return cls(node.graph, node.op, node.name, node.type) + + +# Nodes modeling TKW operations in the kernel language + + +@dataclass +class ConstructRegisterFromMetadataNode(CustomNode): + shape: tuple[IndexExpr, ...] + dtype: DataType + value: float + + +@dataclass +class MmaNode(CustomNode): + lhs: fx.Node + rhs: fx.Node + acc: fx.Node + + +@dataclass +class ReadNode(CustomNode): + memory: fx.Proxy + elements_per_thread: Optional[Any] = None + type: Optional[Type[Register]] = None + + +@dataclass +class ReductionNode(CustomNode): + axis: IndexExpr + init_args: Sequence[Any] + subgraph_name: str + implicit_captures: Sequence[fx.Proxy] + + @classmethod + def handle(cls, graph, *args, **kwargs): + def wrapper(f): + with graph.subtracer() as subtracer: + subgraph_name, implicit_captures = subtracer.trace(f) + node = ReductionNode( + graph, + *args, + **kwargs, + subgraph_name=subgraph_name, + implicit_captures=implicit_captures, + ) + node.add_to_graph(graph) + return node.fx_node + + return wrapper + + +@dataclass +class WriteNode(CustomNode): + register_: fx.Proxy + memory: fx.Proxy + elements_per_thread: Optional[Any] diff --git a/shark_turbine/kernel/_support/tracing.py b/shark_turbine/kernel/_support/tracing.py index b89818829..5a610c206 100644 --- a/shark_turbine/kernel/_support/tracing.py +++ b/shark_turbine/kernel/_support/tracing.py @@ -30,6 +30,7 @@ from ..lang.types import ( Index, ) +from .nodes import CustomNode, PlaceholderNode, ReductionNode, UnknownNode from .regions import RegionGraph, SubgraphTracer @@ -177,6 +178,29 @@ def __init__(self, region_graph: RegionGraph, *, grid_type: Type[Grid]): backed_sym_index_type(BoundedRelation(0, n, upper_inclusive=False)) for n in grid_type.symbolic_shape ] + self.custom_ops: dict[str, CustomNode] = {} + + def register_custom_op(self, name: str, op: CustomNode): + self.custom_ops[name] = op + + def handler(*args, **kwargs): + return op.handle(self.region_graph, *args, **kwargs) + + setattr(self, f"handle_{name}", handler) + + def node(self, node: fx.Node) -> CustomNode: + for name, nodeT in self.custom_ops.items(): + # The fx_nodes have a suffix depending on the number of occurrences + # of similar nodes. We are only interested in the name prefix. + + # TODO: Instead simply shape the suffix off, as this depends on the + # iteration order for nodes sharing a prefix, + # e.g. read and read_shared + if node.name.startswith(name): + return nodeT.from_fx_node(node) + if node.op == "placeholder": + return PlaceholderNode.from_fx_node(node) + return UnknownNode.from_fx_node(node) ### ======================================================================== ### Core Operations @@ -389,14 +413,11 @@ def __call__(self, *args, **kwargs): return launch_context.launch(self, args, kwargs) @abstractmethod - def eager_execute(self, args, kwargs): - ... + def eager_execute(self, args, kwargs): ... - def aot_execute(self, args, kwargs): - ... + def aot_execute(self, args, kwargs): ... - def test_execute(self, args, kwargs): - ... + def test_execute(self, args, kwargs): ... class LaunchContext(ABC): @@ -435,8 +456,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): context.pop(LaunchContext, self) @abstractmethod - def launch(self, launchable: Launchable, args, kwargs): - ... + def launch(self, launchable: Launchable, args, kwargs): ... class DebugLaunchContext(LaunchContext): diff --git a/shark_turbine/kernel/compiler/ir.py b/shark_turbine/kernel/compiler/ir.py index 560b85cd2..5e669620c 100644 --- a/shark_turbine/kernel/compiler/ir.py +++ b/shark_turbine/kernel/compiler/ir.py @@ -22,22 +22,28 @@ IntegerType, Location, Operation, + OpResult, MemRefType, ShapedType, StringAttr, SymbolTable, Type as IrType, + UnitAttr, Value, VectorType, ) from iree.compiler.dialects import ( + amdgpu as amdgpu_d, arith as arith_d, builtin as builtin_d, flow as flow_d, func as func_d, + gpu as gpu_d, math as math_d, + memref as memref_d, stream as stream_d, + transform as transform_d, vector as vector_d, scf as scf_d, ) diff --git a/shark_turbine/kernel/lang/__init__.py b/shark_turbine/kernel/lang/__init__.py index ad1ed5aac..97423a6a5 100644 --- a/shark_turbine/kernel/lang/__init__.py +++ b/shark_turbine/kernel/lang/__init__.py @@ -1,6 +1,7 @@ from .prims import * from .types import * from .kernel_buffer import * +from .functional_types import * from .grid import * # Include publics from the _support library. diff --git a/shark_turbine/kernel/lang/functional_types.py b/shark_turbine/kernel/lang/functional_types.py new file mode 100644 index 000000000..3366214c0 --- /dev/null +++ b/shark_turbine/kernel/lang/functional_types.py @@ -0,0 +1,140 @@ +from typing import Optional, Type, TypeVar, cast, ClassVar +from enum import Enum + + +from ..lang.kernel_buffer import KernelBufferUsage +from .._support.shaped_type import ShapedDataType +from .._support.dtype import DataType +from ...kernel._support.indexing import IndexExpr +import torch + +__all__ = [ + "Memory", + "Register", + "AddressSpace", +] + +MemoryTypeT = TypeVar("MemoryTypeT") + + +class AddressSpace(Enum): + REGISTER = 0 + SHARED_MEMORY = 1 + GLOBAL_MEMORY = 2 + + +class _MemoryStorage(ShapedDataType): + def new_subtype( + cls: Type[MemoryTypeT], + *, + symbolic_shape: tuple[IndexExpr, ...], + address_space: AddressSpace, + dtype: DataType, + usage: Optional[KernelBufferUsage] = None, + ) -> Type[MemoryTypeT]: + init_symbolic_shape = symbolic_shape + init_dtype = dtype + init_address_space = ( + address_space if address_space else AddressSpace.REGISTER.value + ) + init_usage = usage + + class MemoryType(cls): + symbolic_shape = init_symbolic_shape + rank = len(symbolic_shape) + address_space = init_address_space + dtype = init_dtype + usage = init_usage + + return cast(Type[MemoryTypeT], MemoryType) + + +class Memory(metaclass=_MemoryStorage): + """ + Represents storage anywhere in the memory hierarchy except registers. + Parameterized by a shape, address space and element type. The allocated + memory is traversed by an iterator that specifies the offset, stride + and size along each dimension. + """ + + symbolic_shape: ClassVar[tuple[IndexExpr, ...]] + address_space: ClassVar[int] + rank: ClassVar[int] + dtype: ClassVar[DataType] + usage: ClassVar[Optional[KernelBufferUsage]] + + def __init__(self, tensor: torch.Tensor) -> None: + assert isinstance(tensor, torch.Tensor), f"Expected Tensor but got {tensor}" + self._tensor = tensor + self.symbolic_shape = None + + def __class_getitem__( + cls, shape_and_dtype: tuple[IndexExpr | DataType, ...] + ) -> Type["Memory"]: + """Syntax: `Memory[shape1, ...., shapeN, addressSpace, dtype, Optional[usage]]""" + if not isinstance(shape_and_dtype, tuple) or len(shape_and_dtype) < 3: + raise TypeError(f"Expected at least 3 arguments, got: {shape_and_dtype}") + + shift = 0 + usage = None + if isinstance(shape_and_dtype[-1], KernelBufferUsage): + shift = 1 + usage = shape_and_dtype[-1] + shape = shape_and_dtype[: -2 - shift] + addressSpace = shape_and_dtype[-2 - shift] + dtype = shape_and_dtype[-1 - shift] + + if not all(isinstance(s, IndexExpr) for s in shape): + raise TypeError(f"Expected shape to be a tuple of IndexExpr, got {shape}") + if not isinstance(dtype, DataType): + raise TypeError(f"Expected dtype to be a DataType, got {dtype}") + if not isinstance(addressSpace, IndexExpr): + raise TypeError( + f"Expected addressSpace to be a AddressSpace, got {addressSpace}" + ) + + shape = cast(tuple[IndexExpr, ...], shape) + dtype = cast(DataType, dtype) + addressSpace = cast(AddressSpace, addressSpace) + + return cls.new_subtype( + symbolic_shape=shape, address_space=addressSpace, dtype=dtype, usage=usage + ) + + +class Register(metaclass=_MemoryStorage): + "Represents virtual registers. Parameterized by a shape and element type." + symbolic_shape: ClassVar[tuple[IndexExpr, ...]] + rank: ClassVar[int] + dtype: ClassVar[DataType] + + def __init__(self, shape, dtype) -> None: + if not isinstance(shape, tuple): + raise TypeError(f"Expected at shape to be a tuple, got: {shape}") + self.symbolic_shape = shape + self.rank = len(self.symbolic_shape) + self.dtype = dtype + self.value = None + + def set(self, value) -> None: + self.value = value + + def __class_getitem__( + cls, shape_and_dtype: tuple[IndexExpr | DataType, ...] + ) -> Type["Register"]: + + if not isinstance(shape_and_dtype, tuple) or len(shape_and_dtype) < 2: + raise TypeError(f"Expected at least 2 arguments, got: {shape_and_dtype}") + + shape = shape_and_dtype[:-1] + dtype = shape_and_dtype[-1] + + shape = cast(tuple[IndexExpr, ...], shape) + dtype = cast(DataType, dtype) + return cls.new_subtype( + symbolic_shape=shape, dtype=dtype, address_space=AddressSpace.REGISTER.value + ) + + +def is_memory_meta_derived(t: type) -> bool: + return isinstance(t, _MemoryStorage) diff --git a/shark_turbine/kernel/ops/base.py b/shark_turbine/kernel/ops/base.py index a726c2546..882919977 100644 --- a/shark_turbine/kernel/ops/base.py +++ b/shark_turbine/kernel/ops/base.py @@ -7,6 +7,7 @@ from .._support import context T = TypeVar("T") +OpDispatcherT = TypeVar("OpDispatcherT", bound="OpDispatcher") class OpDispatcher: @@ -18,9 +19,9 @@ def handle_{idname}(self, operator, *args, **kwargs) __tk_context_idname__ = "OpDispatcher" - @staticmethod - def current() -> "OpDispatcher": - return context.current(OpDispatcher) + @classmethod + def current(cls: Type[OpDispatcherT]) -> OpDispatcherT: + return context.current(cls) def __enter__(self) -> "OpDispatcher": return context.push(OpDispatcher, self) diff --git a/shark_turbine/kernel/wave/README.md b/shark_turbine/kernel/wave/README.md new file mode 100644 index 000000000..1d1c11633 --- /dev/null +++ b/shark_turbine/kernel/wave/README.md @@ -0,0 +1,28 @@ +# TKW: A Wave Kernel Language for GPUs + +TKW is a high-level programming language designed to simplify the development of GPU micro-kernels by abstracting over intricate details of GPU hardware. It allows developers to write efficient micro-kernels, focusing on core computations while inferring the required data transfers and indexing automatically. TKW implements a wave-based programming model to express programs leveraging coalesced memory accesses effortlessly and supports the explicit use of matrix multiplication intrinsics. + +## Design Goals +TKW is designed with several key goals in mind to facilitate efficient GPU programming and maximize performance: + +1. Abstract over hardware details: Simplify the development of GPU micro-kernels by hiding the complex details of synchronization, thread management, and memory transactions. + - Automatically infer efficient data movement strategies across the memory hierarchy, ensuring efficient use of memory bandwidth. + - Leverage hardware details (such as instruction specifications) to determine indexing. +2. Provide users with low-level control + - Expose an interface to customize the instruction scheduling + - Provide low-level control over how the computation is performed by exposing low-level GPU instructions. This empowers developers to directly leverage hardware-specific features to achieve maximum performance. +3. Systematically expressing constraints to leverage solvers / auto-tuning + - Represent specific tiling possibilities around a micro-kernel using symbolic constraints. This forms a searchable space that allows for fine-tuning by exploring various tiling configurations. + +## Wave-based Programming Model +TKW leverages a wave-based programming model that is specifically designed to take advantage of the parallel processing power of GPUs. + +In GPU programming, a wavefront is a group of threads (or work items) that execute the same instruction in lockstep. In particular, coalesced memory accesses by all threads in a wavefront are executed together. This is analogous to the concept of a "warp" in NVIDIA's CUDA programming model. +Typically, on AMD GPUs, a wavefront contains 32 or 64 threads, all of which participate in executing the same instruction at the same time. + +In this representation, memory access patterns are more naturally optimized for coalescing, reducing the complexity and manual effort required to achieve optimized memory transactions. In consequence, programmers can focus more on the core computational logic rather than the intricacies of thread coordination and memory management. +This approach contrasts with traditional models like OpenCL and CUDA, which often require more explicit management of threads and synchronization. + +## Gemm example + +https://github.com/iree-org/iree-turbine/blob/6178ac8eeeb9456f315b6408c9dfc90fc7176e4f/tests/kernel/wave_gemm_test.py#L12-L71 \ No newline at end of file diff --git a/shark_turbine/kernel/wave/__init__.py b/shark_turbine/kernel/wave/__init__.py new file mode 100644 index 000000000..eae496f35 --- /dev/null +++ b/shark_turbine/kernel/wave/__init__.py @@ -0,0 +1,2 @@ +from .ops import * +from .wave import * diff --git a/shark_turbine/kernel/wave/ops.py b/shark_turbine/kernel/wave/ops.py new file mode 100644 index 000000000..f1870b912 --- /dev/null +++ b/shark_turbine/kernel/wave/ops.py @@ -0,0 +1,31 @@ +from ..ops.base import ( + define_op, +) + +__all__ = [ + "construct_register_from_metadata", + "read", + "write", + "mma", + "reduction", +] + + +@define_op +def construct_register_from_metadata(shape, dtype, value) -> None: ... + + +@define_op +def read(memory: "Memory", elements_pre_thread) -> "Register": ... + + +@define_op +def write(register: "Register", memory: "Memory", elements_pre_thread) -> None: ... + + +@define_op +def mma(lhs: "Register", rhs: "Register", acc: "Register") -> "Register": ... + + +@define_op +def reduction(axis: "IndexExpr", init_args): ... diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py new file mode 100644 index 000000000..a695749d1 --- /dev/null +++ b/shark_turbine/kernel/wave/wave.py @@ -0,0 +1,91 @@ +from typing import Callable, Optional +import inspect + +from ..compiler.ir import Context, Operation +from ..lang import Grid +from .._support.tracing import ( + CapturedTrace, + CompiledContext, + KernelRegionGraph, + Launchable, +) +from .._support.nodes import * + + +__all__ = ["wave"] + + +def wave(): + def decorator(f: Callable[[Any], Any]) -> "LaunchableWave": + return LaunchableWave(f.__name__, f) + + return decorator + + +class LaunchableWave(Launchable): + def __init__( + self, + name: str, + eager_function: Callable[[Any], Any], + ): + super().__init__(eager_function) + + self.grid_type = Grid[None, None] + self._name = name + self._f = eager_function + self._sig = inspect.signature(eager_function) + + def _trace(self, dump: bool = False) -> CapturedTrace: + region_graph = KernelRegionGraph() + with CompiledContext(region_graph, grid_type=self.grid_type) as context: + custom_ops: dict[str, Type[CustomNode]] = { + "construct_register_from_metadata": ConstructRegisterFromMetadataNode, + "mma": MmaNode, + "read": ReadNode, + "write": WriteNode, + "reduction": ReductionNode, + "placeholder": PlaceholderNode, + } + + # Register custom ops + for name, op in custom_ops.items(): + context.register_custom_op(name, op) + + with region_graph.subtracer() as subtracer: + root_name, _ = subtracer.trace(self._f) + trace = CapturedTrace(region_graph, root_name) + + if dump: + print(trace.get_root_graph()) + for node in trace.get_root_graph().nodes: + print(context.node(node)) + return trace + + def _trace_and_get_kernel_signature( + self, + args, + kwargs, + dump: bool = False, + context: Optional[Context] = None, + module_op: Optional[Operation] = None, + ) -> CapturedTrace: + # Trace the function. + trace = self._trace(dump=dump) + + # TODO: Get kernel signature from the trace. + # We want to reuse the existing kernel_codegen for this which + # requires making it aware of tkw.Memory + return trace + + def test_execute(self, args, kwargs): + # For now only tracing + self._trace_and_get_kernel_signature(args, kwargs, dump=True) + + def aot_execute(self, args, kwargs): + raise NotImplementedError("AOT execution for wave not implemented yet.") + + def eager_execute(self, args, kwargs): + raise NotImplementedError("Eager execution for wave not implemented yet.") + + def __repr__(self): + return f"tk.wave @{self._name}[{self.grid_type}]" diff --git a/tests/kernel/wave_gemm_test.py b/tests/kernel/wave_gemm_test.py new file mode 100644 index 000000000..45fd440b8 --- /dev/null +++ b/tests/kernel/wave_gemm_test.py @@ -0,0 +1,76 @@ +import logging +import unittest +import torch +import shark_turbine.kernel as tk +import shark_turbine.kernel.lang as tkl +import shark_turbine.kernel.wave as tkw + + +class Test(unittest.TestCase): + def testGemm(self): + + # Input sizes + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + # Workgroup tile sizes + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Wave-level micro-kernel. + # Since warps are not directly addressable, there is no + # explicit notion of a warp id (like a workgroup or thread id). + # This kernel uses the input sizes M, N, K throughout, as the tiling + # and data movement strategy is determined during the compilation process. + # These can be influenced by introducing constraints. + @tkw.wave() + def gemm( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32], + ): + # c_reg: tkw.Register[M, N, tkl.f32] + c_reg = tkw.construct_register_from_metadata((M, N), tkl.f32, 0.0) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. + @tkw.reduction(K, init_args=[c_reg]) + def repeat(c_reg) -> tkl.Register[M, N, tkl.f32]: + # a_reg: tkw.Register[M, K, tkl.f16] + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # b_reg: tkw.Register[N, K, tkl.f16] + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # c_reg: tkw.Register[M, N, tkl.f32] + c_reg = tkw.mma(a_reg, b_reg, c_reg) + return c_reg + + # repeat represents the results of the loop + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 1, + BLOCK_M: 32, + BLOCK_N: 32, + BLOCK_K: 32, + M: 64, + N: 128, + K: 256, + } + with tk.gen.TestLaunchContext(hyperparams): + a = torch.randn(64, 256, dtype=torch.float16) + b = torch.randn(128, 256, dtype=torch.float16) + c = torch.zeros(64, 128, dtype=torch.float32) + gemm(a, b, c) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main()