diff --git a/requirements.txt b/requirements.txt
index 28eb6965c..bc4eaf0fd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
 # Build/test requirements.
+Jinja2==3.1.3
 numpy==1.26.3
 pytest==8.0.0
 pytest-xdist==3.5.0
diff --git a/setup.py b/setup.py
index 5071c5c48..71db252f5 100644
--- a/setup.py
+++ b/setup.py
@@ -104,6 +104,7 @@ def initialize_options(self):
         f"iree-compiler{get_version_spec('iree-compiler')}",
         f"iree-runtime{get_version_spec('iree-runtime')}",
         "torch>=2.3.0",
+        f"Jinja2{get_version_spec('Jinja2')}",
     ],
     extras_require={
         "testing": [
diff --git a/shark_turbine/kernel/compiler/vector_codegen.py b/shark_turbine/kernel/compiler/vector_codegen.py
index 5ecb8ef93..3a8aa6f98 100644
--- a/shark_turbine/kernel/compiler/vector_codegen.py
+++ b/shark_turbine/kernel/compiler/vector_codegen.py
@@ -4,6 +4,7 @@
 actual loads/stores/computes to local vectors using PyTorch tensor level
 operations executed as threads over a grid.
 """
+
 from typing import Any, Callable, Type, Optional, Sequence, Union, List
 
 import types
diff --git a/shark_turbine/ops/_jinja_test_ops.py b/shark_turbine/ops/_jinja_test_ops.py
new file mode 100644
index 000000000..6b3c5e6c8
--- /dev/null
+++ b/shark_turbine/ops/_jinja_test_ops.py
@@ -0,0 +1,51 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ..support.ir_imports import (
+    RankedTensorType,
+)
+
+from ..runtime.op_reg import (
+    CustomOp,
+    KernelBuilder,
+    KernelSelection,
+    def_library,
+    impl_helper,
+)
+
+__all__ = [
+    "test_add",
+]
+
+LIBRARY = def_library("_turbine_jinja_test")
+_templates = impl_helper.JinjaTemplateLoader(__name__)
+
+
+@CustomOp.register(library=LIBRARY)
+class test_add(CustomOp):
+    signature = "test_add(Tensor t1, Tensor t2) -> (Tensor)"
+
+    def select(self, ksel: KernelSelection):
+        t1_desc = ksel.arg_tensor(0)
+        t1_desc.specialize_all_dims()
+        t2_desc = ksel.arg_tensor(1)
+        t2_desc.specialize_all_dims()
+        result_desc = ksel.return_new_tensor(list(t1_desc.t.shape), t1_desc.t.dtype)
+        result_desc.specialize_all_dims()
+
+    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
+        result_type = kb.arg_bindings[0].type  # type: ignore
+        rtt = RankedTensorType(result_type)
+        function_name = f"turbine_test_add_jinja_{rtt.rank}d_{str(rtt.element_type)}"
+        func_op = _templates.inline_template_function(
+            kb,
+            "test_add_jinja",
+            function_name,
+            rank=rtt.rank,
+            element_type=str(rtt.element_type),
+            tensor_type=str(rtt),
+        )
+        kb.yield_results(*impl_helper.call_function(func_op, *kb.arg_bindings))
diff --git a/shark_turbine/ops/_str_format_test_ops.py b/shark_turbine/ops/_str_format_test_ops.py
new file mode 100644
index 000000000..6988a4aef
--- /dev/null
+++ b/shark_turbine/ops/_str_format_test_ops.py
@@ -0,0 +1,73 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ..support.ir_imports import (
+    RankedTensorType,
+)
+
+from ..runtime.op_reg import (
+    CustomOp,
+    KernelBuilder,
+    KernelSelection,
+    def_library,
+    impl_helper,
+)
+
+__all__ = [
+    "test_add",
+    "syntax_error",
+]
+
+LIBRARY = def_library("_turbine_str_format_test")
+_templates = impl_helper.StrFormatTemplateLoader(__name__)
+
+
+@CustomOp.register(library=LIBRARY)
+class test_add(CustomOp):
+    signature = "test_add(Tensor t1, Tensor t2) -> (Tensor)"
+
+    def select(self, ksel: KernelSelection):
+        t1_desc = ksel.arg_tensor(0)
+        t1_desc.specialize_all_dims()
+        t2_desc = ksel.arg_tensor(1)
+        t2_desc.specialize_all_dims()
+        result_desc = ksel.return_new_tensor(list(t1_desc.t.shape), t1_desc.t.dtype)
+        result_desc.specialize_all_dims()
+
+    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
+        result_type = kb.arg_bindings[0].type  # type: ignore
+        rtt = RankedTensorType(result_type)
+        function_name = (
+            f"turbine_test_add_strformat_{rtt.rank}d_{str(rtt.element_type)}"
+        )
+        func_op = _templates.inline_template_function(
+            kb,
+            "test_add_strformat",
+            function_name,
+            rank=rtt.rank,
+            element_type=str(rtt.element_type),
+            tensor_type=str(rtt),
+        )
+        kb.yield_results(*impl_helper.call_function(func_op, *kb.arg_bindings))
+
+
+@CustomOp.register(library=LIBRARY)
+class syntax_error(CustomOp):
+    signature = "syntax_error(Tensor t1) -> (Tensor)"
+
+    def select(self, ksel: KernelSelection):
+        t1_desc = ksel.arg_tensor(0)
+        t1_desc.specialize_all_dims()
+        result_desc = ksel.return_new_tensor(list(t1_desc.t.shape), t1_desc.t.dtype)
+        result_desc.specialize_all_dims()
+
+    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
+        function_name = "syntax_error"
+        func_op = _templates.inline_template_function(
+            kb,
+            "test_syntax_error",
+            function_name,
+        )
+        kb.yield_results(*impl_helper.call_function(func_op, *kb.arg_bindings))
diff --git a/shark_turbine/ops/iree.py b/shark_turbine/ops/iree.py
index ce6577285..e28826db8 100644
--- a/shark_turbine/ops/iree.py
+++ b/shark_turbine/ops/iree.py
@@ -8,7 +8,6 @@
 from typing import cast
 
 from ..support.ir_imports import (
-    Operation,
     RankedTensorType,
     StringAttr,
     Value,
@@ -61,24 +60,3 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
         key = cast(AttrArg, ksel.arg_descs[0])
         _emit_tensor_trace(kb, cast(str, key.v), [kb.arg_bindings[1]])
         kb.yield_results(kb.arg_bindings[1])
-
-
-@CustomOp.register(library=IREE_LIBRARY)
-class _test_add(CustomOp):
-    signature = "_test_add(Tensor t1, Tensor t2) -> (Tensor)"
-
-    def select(self, ksel: KernelSelection):
-        t1_desc = ksel.arg_tensor(0)
-        t1_desc.specialize_all_dims()
-        t2_desc = ksel.arg_tensor(1)
-        t2_desc.specialize_all_dims()
-        result_desc = ksel.return_new_tensor(list(t1_desc.t.shape), t1_desc.t.dtype)
-        result_desc.specialize_all_dims()
-
-    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
-        t1, t2 = kb.arg_bindings
-        result_type = t1.type  # type: ignore
-        result = Operation.create(
-            "tosa.add", results=[result_type], operands=[t1, t2]
-        ).result
-        kb.yield_results(result)
diff --git a/shark_turbine/ops/templates/test_add_jinja.mlir b/shark_turbine/ops/templates/test_add_jinja.mlir
new file mode 100644
index 000000000..9fb723389
--- /dev/null
+++ b/shark_turbine/ops/templates/test_add_jinja.mlir
@@ -0,0 +1,12 @@
+!tensor_type = {{tensor_type}}
+
+module {
+
+util.func private @turbine_test_add_jinja_{{rank}}d_{{element_type}}(
+    %a: !tensor_type, %b: !tensor_type
+) -> !tensor_type {
+  %out = tensor.empty() : !tensor_type
+  %0 = linalg.add ins(%a, %b : !tensor_type, !tensor_type) outs(%out : !tensor_type) -> !tensor_type
+  util.return %0 : !tensor_type
+}
+}
diff --git a/shark_turbine/ops/templates/test_add_strformat.mlir b/shark_turbine/ops/templates/test_add_strformat.mlir
new file mode 100644
index 000000000..5b4e00f3d
--- /dev/null
+++ b/shark_turbine/ops/templates/test_add_strformat.mlir
@@ -0,0 +1,12 @@
+!tensor_type = {tensor_type}
+
+module {{
+
+util.func private @turbine_test_add_strformat_{rank}d_{element_type}(
+    %a: !tensor_type, %b: !tensor_type
+) -> !tensor_type {{
+  %out = tensor.empty() : !tensor_type
+  %0 = linalg.add ins(%a, %b : !tensor_type, !tensor_type) outs(%out : !tensor_type) -> !tensor_type
+  util.return %0 : !tensor_type
+}}
+}}
diff --git a/shark_turbine/ops/templates/test_syntax_error.mlir b/shark_turbine/ops/templates/test_syntax_error.mlir
new file mode 100644
index 000000000..928339561
--- /dev/null
+++ b/shark_turbine/ops/templates/test_syntax_error.mlir
@@ -0,0 +1 @@
+THIS IS A SYNTAX ERROR
diff --git a/shark_turbine/runtime/op_reg/__init__.py b/shark_turbine/runtime/op_reg/__init__.py
index c18790d33..32b0733bd 100644
--- a/shark_turbine/runtime/op_reg/__init__.py
+++ b/shark_turbine/runtime/op_reg/__init__.py
@@ -5,3 +5,4 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from .base import *
+from . import impl_helper
diff --git a/shark_turbine/runtime/op_reg/impl_helper.py b/shark_turbine/runtime/op_reg/impl_helper.py
new file mode 100644
index 000000000..471e6ccfa
--- /dev/null
+++ b/shark_turbine/runtime/op_reg/impl_helper.py
@@ -0,0 +1,182 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Helpers for implementing ops.
+
+Typical usage:
+
+```
+  _templates = JinjaTemplateLoader(__name__)
+
+  def generate(kb: KernelBuilder):
+      func_op = _templates.inline_template_function(
+          kb, "my_template", "function_name", **kwargs)
+      return call_function(func_op, *values)
+```
+"""
+
+from typing import Sequence
+
+from abc import ABC, abstractmethod
+import logging
+import textwrap
+
+from ...support.logging import runtime_logger as logger
+
+from ...support.ir_imports import (
+    FlatSymbolRefAttr,
+    FunctionType,
+    MLIRError,
+    Operation,
+    StringAttr,
+    TypeAttr,
+    Value,
+)
+
+from ...transforms.merger import Merger
+
+from .base import (
+    KernelBuilder,
+)
+
+
+__all__ = [
+    "TemplateLoader",
+    "StrFormatTemplateLoader",
+    "JinjaTemplateLoader",
+    "call_function",
+]
+
+
+class TemplateLoader(ABC):
+    """Base class for templates that can be loaded by name."""
+
+    @abstractmethod
+    def load_template(self, kb: KernelBuilder, name: str, **kwargs) -> Operation:
+        """Loads a template by name and kwargs, returning the module operation."""
+        ...
+
+    def _parse_module_asm(self, kb: KernelBuilder, asm: str) -> Operation:
+        try:
+            module_op = Operation.parse(asm, context=kb.context)
+        except MLIRError as e:
+            lines = asm.splitlines()
+            lines_numbered = "\n".join(
+                [f"      {str(i+1):>5}: {l}" for i, l in enumerate(lines)]
+            )
+            raise RuntimeError(
+                f"Error parsing generated op template:"
+                f"\n{textwrap.indent(str(e), '  ')}"
+                f"\n{lines_numbered}"
+            )
+        return module_op.operation
+
+    def inline_template_function(
+        self,
+        kb: KernelBuilder,
+        template_file: str,
+        function_name: str,
+        **kwargs,
+    ) -> Operation:
+        """Inlines a template module by first expanding its ASM via **kwargs.
+
+        Returns the inlined symbol `function_name`, which is expected to have been
+        defined in the template.
+        """
+        try:
+            return kb.symbol_table[function_name]
+        except KeyError:
+            pass
+        source_module_op = self.load_template(kb, template_file, **kwargs)
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.debug(
+                "Generated kernel IR %s:\n%s", function_name, str(source_module_op)
+            )
+        merger = Merger(
+            source_module_op, kb.module_body.owner, target_symbol_table=kb.symbol_table
+        )
+        merger.merge()
+        return kb.symbol_table[function_name]
+
+
+class StrFormatTemplateLoader(TemplateLoader):
+    """Template loader that uses str.format.
+
+    Usage:
+      _templates = StrFormatTemplateLoader(__name__)
+
+    By default, this will resolve a template like "foo" from templates/foo.mlir
+    in the package directory.
+    """
+
+    def __init__(
+        self,
+        package_name: str,
+        package_path: str = "templates",
+        *,
+        suffix: str = ".mlir",
+    ):
+        self.parent_package_name = ".".join(package_name.split(".")[0:-1])
+        self.package_path = package_path
+        self.suffix = suffix
+
+    def load_template(self, kb: KernelBuilder, name: str, **kwargs) -> Operation:
+        from importlib import resources
+
+        res = (
+            resources.files(self.parent_package_name)
+            / self.package_path
+            / f"{name}{self.suffix}"
+        )
+        contents = res.read_text().format(**kwargs)
+        return self._parse_module_asm(kb, contents)
+
+
+class JinjaTemplateLoader(TemplateLoader):
+    """Template loader based on jinja templates.
+
+    Usage:
+      _templates = JinjaTemplateLoader(__name__)
+
+    By default, this will resolve a template like "foo" from templates/foo.mlir
+    in the package directory.
+    """
+
+    def __init__(
+        self,
+        package_name: str,
+        package_path: str = "templates",
+        *,
+        suffix: str = ".mlir",
+    ):
+        try:
+            from jinja2 import Environment, PackageLoader
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError(
+                "Cannot use JinjaTemplateLoader if jinja2 is not installed"
+            ) from e
+        self.env = Environment(loader=PackageLoader(package_name, package_path))
+        self.suffix = suffix
+
+    def load_template(self, kb: KernelBuilder, name: str, **kwargs) -> Operation:
+        template_file = f"{name}{self.suffix}"
+        contents = self.env.get_template(template_file).render(**kwargs)
+        return self._parse_module_asm(kb, contents)
+
+
+def call_function(target_function: Operation, *operands: Value) -> Sequence[Value]:
+    """Emits a util.call for a util.func target function operation."""
+    target_symbol = FlatSymbolRefAttr.get(
+        StringAttr(target_function.attributes["sym_name"]).value_bytes
+    )
+    ftype = FunctionType(TypeAttr(target_function.attributes["function_type"]).value)
+    return Operation.create(
+        "util.call",
+        results=ftype.results,
+        operands=operands,
+        attributes={
+            "callee": target_symbol,
+        },
+    ).results
diff --git a/tests/runtime/device_test.py b/tests/runtime/device_test.py
index f6c172415..e78aff8e2 100644
--- a/tests/runtime/device_test.py
+++ b/tests/runtime/device_test.py
@@ -151,20 +151,28 @@ def testFromTorchDevice(self):
         print(device.dump_device_info())
 
     def testJit(self):
-        from shark_turbine.ops import iree as iree_ops
+        from shark_turbine.ops import _str_format_test_ops as test_ops
 
         t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cuda:0")
-        result = iree_ops._test_add(t, t)
+        result = test_ops.test_add(t, t)
         expected = torch.tensor([2.0, 4.0, 6.0, 8.0, 10.0], device="cpu")
         torch.testing.assert_close(result.cpu(), expected)
 
 
 class TorchCPUInterop(unittest.TestCase):
-    def testJit(self):
-        from shark_turbine.ops import iree as iree_ops
+    def testJitStrFormat(self):
+        from shark_turbine.ops import _str_format_test_ops as test_ops
+
+        t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu")
+        result = test_ops.test_add(t, t)
+        expected = torch.tensor([2.0, 4.0, 6.0, 8.0, 10.0], device="cpu")
+        torch.testing.assert_close(result, expected)
+
+    def testJitJinja(self):
+        from shark_turbine.ops import _jinja_test_ops as test_ops
 
         t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu")
-        result = iree_ops._test_add(t, t)
+        result = test_ops.test_add(t, t)
         expected = torch.tensor([2.0, 4.0, 6.0, 8.0, 10.0], device="cpu")
         torch.testing.assert_close(result, expected)
diff --git a/tests/runtime/op_reg/impl_helper_test.py b/tests/runtime/op_reg/impl_helper_test.py
new file mode 100644
index 000000000..b0797c2d6
--- /dev/null
+++ b/tests/runtime/op_reg/impl_helper_test.py
@@ -0,0 +1,29 @@
+# Copyright 2023 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import unittest
+
+import torch
+
+from shark_turbine.ops import _str_format_test_ops
+
+
+class KernelRegTest(unittest.TestCase):
+    def testError(self):
+        t = torch.randn(3, 4)
+        try:
+            _str_format_test_ops.syntax_error(t)
+            self.fail("Expected RuntimeError")
+        except RuntimeError as e:
+            message = str(e)
+            self.assertIn("error:", message)
+            self.assertIn("1: THIS IS A SYNTAX ERROR", message)
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
+    unittest.main()
diff --git a/tests/transforms/general/rename_parameters_test.py b/tests/transforms/general/rename_parameters_test.py
index 203e6b455..74fc67538 100644
--- a/tests/transforms/general/rename_parameters_test.py
+++ b/tests/transforms/general/rename_parameters_test.py
@@ -38,9 +38,9 @@ def testBasic(self):
                 "WEIGHT": "weight",
                 ("foo", "params.classifier.bias"): ("bar", "BIAS"),
             },
-            rename_callback=lambda scope, name: ("XXX", "YYY")
-            if name == "default"
-            else None,
+            rename_callback=lambda scope, name: (
+                ("XXX", "YYY") if name == "default" else None
+            ),
         ).run()
         module_asm = str(module_op)
         print(module_asm)
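
A note on the two template dialects above: `test_add_strformat.mlir` doubles every literal MLIR brace because `str.format` reserves `{`/`}` for substitution, while the Jinja template leaves MLIR's braces alone and marks substitutions as `{{name}}`. A minimal standalone sketch of the difference (illustrative, not part of the diff; requires jinja2):

```python
import jinja2

# str.format: {name} substitutes, so literal MLIR braces must be doubled.
strformat_src = "module {{\n  util.func @f_{rank}d() {{ }}\n}}"
print(strformat_src.format(rank=2))

# Jinja2: {{name}} substitutes, so literal braces need no escaping.
jinja_src = "module {\n  util.func @f_{{rank}}d() { }\n}"
print(jinja2.Template(jinja_src).render(rank=2))

# Both print the same text:
# module {
#   util.func @f_2d() { }
# }
```

This is also why `StrFormatTemplateLoader` needs only the standard library, while `JinjaTemplateLoader` imports jinja2 lazily and raises `ModuleNotFoundError` if it is missing; the `Jinja2` entries in requirements.txt and setup.py back the optional path.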
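
For a quick end-to-end check mirroring the new CPU tests, both template paths can be exercised directly from PyTorch. A sketch, assuming `shark_turbine` is importable and jinja2 is installed:

```python
import torch

from shark_turbine.ops import _jinja_test_ops, _str_format_test_ops

t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
# Each call expands its .mlir template, merges the generated util.func into
# the kernel module, and dispatches through the custom-op runtime.
torch.testing.assert_close(_str_format_test_ops.test_add(t, t), t + t)
torch.testing.assert_close(_jinja_test_ops.test_add(t, t), t + t)
```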