From 6a2f21289289eb8a26314d37bb75c1c90ee767d3 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 22 Apr 2024 21:27:57 -0700 Subject: [PATCH 1/7] Add device bridge support for HIP and CUDA. --- shark_turbine/runtime/device.py | 115 +++++++++++++++++++++++++-- shark_turbine/runtime/op_reg/base.py | 9 +-- tests/runtime/device_test.py | 35 ++++++++ 3 files changed, 149 insertions(+), 10 deletions(-) diff --git a/shark_turbine/runtime/device.py b/shark_turbine/runtime/device.py index a294525df..6208f4b8a 100644 --- a/shark_turbine/runtime/device.py +++ b/shark_turbine/runtime/device.py @@ -7,6 +7,7 @@ from functools import lru_cache from typing import Callable, Optional, Union from threading import local, Lock +import warnings import torch @@ -73,6 +74,7 @@ class DeviceState: "device", "driver", "instance", + "enumerated_info", ] def __init__( @@ -81,10 +83,33 @@ def __init__( driver: Union[str, HalDriver], device: Optional[HalDevice] = None, vm_instance: Optional[VmInstance] = None, + enumerated_info: Optional[dict] = None, ): self.instance = vm_instance or get_vm_instance() self.driver = driver if isinstance(driver, HalDriver) else get_driver(driver) self.device = device if device else self.driver.create_default_device() + self.enumerated_info = enumerated_info or {} + + @property + def enumerated_device_id(self) -> int: + try: + return self.enumerated_info["device_id"] + except KeyError as e: + raise RuntimeError("No enumerated device_id for device") from e + + @property + def enumerated_path(self) -> str: + try: + return self.enumerated_device_id["path"] + except KeyError as e: + raise RuntimeError("No enumerated path for device") from e + + @property + def enumerated_name(self) -> str: + try: + return self.enumerated_device_id["name"] + except KeyError as e: + raise RuntimeError("No enumerated name for device") from e @staticmethod @lru_cache(maxsize=None) @@ -243,6 +268,9 @@ def clear(self): ... raise MismatchedDeviceSetClearError() + def dump_device_info(self) -> str: + return self._s.driver.dump_device_info(self._s.enumerated_device_id) + def __repr__(self): return f"" @@ -271,6 +299,27 @@ def _device_import_torch_tensor_cpu(device: Device, t: torch.Tensor) -> HalBuffe return bv +def _device_import_torch_tensor_hip(device: Device, t: torch.Tensor) -> HalBufferView: + assert hasattr(t, "__dlpack__") + # TODO: The '0' below is the stream (on real-CUDA, it is 1) signifying + # the default stream. It should be the stream upon which we schedule. + bv = device.hal_device.from_dlpack(t) + print("FROM DLPACK:", bv) + return bv + # hal_device = device.hal_device + # element_type = dtype_to_element_type(t.dtype) + # # TODO: In this case, we should be importing the raw buffer, but this is not + # # generically exposed to Python in the IREE runtime. + # bv = device.hal_device.allocator.allocate_buffer_copy( + # memory_type=MemoryType.DEVICE_LOCAL, + # allowed_usage=BufferUsage.DEFAULT, + # device=hal_device, + # buffer=t.detach().numpy(), + # element_type=element_type, + # ) + # return bv + + def _device_export_torch_tensor_cpu( device: Device, bv: HalBufferView, like: torch.Tensor ) -> torch.Tensor: @@ -286,6 +335,8 @@ def _device_export_torch_tensor_cpu( # Mapping of torch tensor importers keyed by driver name. 
TORCH_TENSOR_IMPORTERS: dict[str, Callable[[Device, torch.Tensor], HalBufferView]] = { + "cuda": None, + "hip": _device_import_torch_tensor_hip, "local-sync": _device_import_torch_tensor_cpu, "local-task": _device_import_torch_tensor_cpu, } @@ -293,11 +344,15 @@ def _device_export_torch_tensor_cpu( TORCH_TENSOR_EXPORTERS: dict[ str, Callable[[Device, HalBufferView, torch.Tensor], torch.Tensor] ] = { + "cuda": None, + "hip": None, "local-sync": _device_export_torch_tensor_cpu, "local-task": _device_export_torch_tensor_cpu, } DEVICE_TARGET_COMPILE_FLAGS: dict[str, tuple[str, ...]] = { + "cuda": ("--iree-hal-target-backends=cuda",), + "hip": ("--iree-hal-target-backends=rocm",), "local-task": ( "--iree-hal-target-backends=llvm-cpu", "--iree-llvmcpu-target-cpu-features=host", @@ -357,14 +412,64 @@ def get_device_from_torch(torch_device: torch.device) -> Device: def _create_device_from_torch(torch_device: torch.device) -> Optional[Device]: torch_type = torch_device.type - uri = None if torch_type == "cpu": - uri = "local-task" + return Device("local-task") + elif torch_type == "cuda": + # Fork based on HIP or real CUDA. + props = torch.cuda.get_device_properties(torch_device) + if not hasattr(props, "gcnArchName"): + # Real CUDA. + return _create_cuda_device(torch_device, props) + else: + # HIP as CUDA. + return _create_hip_device(torch_device, props) + + return None - if uri is None: - return None - return Device(uri) +def _create_cuda_device(torch_device: torch.device, props) -> Optional[Device]: + device = _create_cuda_like_device(torch_device, props, "hip") + if device: + device.compile_target_flags = device.compile_target_flags + ( + f"--iree-hal-cuda-llvm-target-arch=sm_{props.major}{props.minor}", + ) + return device + + +def _create_hip_device(torch_device: torch.device, props) -> Optional[Device]: + device = _create_cuda_like_device(torch_device, props, "hip") + if device: + gcn_arch_name = props.gcnArchName + device.compile_target_flags = device.compile_target_flags + ( + f"--iree-rocm-target-chip={gcn_arch_name}", + ) + return device + + +def _create_cuda_like_device( + torch_device: torch.device, props, driver_name: str +) -> Optional[Device]: + if torch.cuda.device_count() > 1: + warnings.warn( + f"Multiple {driver_name} devices detected: Turbine does not yet " + f"guarantee stable device mapping" + ) + + requested_index = torch_device.index + driver = get_driver(driver_name) + available_infos = driver.query_available_devices() + if requested_index >= len(available_infos): + return None + device_info = available_infos[requested_index] + hal_device = driver.create_device(device_info) + device_state = DeviceState( + driver=driver, + device=hal_device, + vm_instance=get_vm_instance(), + enumerated_info=device_info, + ) + device = Device(device_state=device_state) + return device ############################################################################### diff --git a/shark_turbine/runtime/op_reg/base.py b/shark_turbine/runtime/op_reg/base.py index 2fcf7863f..d74c49a55 100644 --- a/shark_turbine/runtime/op_reg/base.py +++ b/shark_turbine/runtime/op_reg/base.py @@ -74,11 +74,10 @@ def def_library(ns) -> torch.library.Library: def default_dispatch_keys() -> list[str]: - # TODO: Dynamically determine what devices to register against. - # Note that we have to register against specific keys instead of the - # fallback, as fallback is too broad and breaks certain elements of - # fx tracing. 
- return ["CPU"] + keys = ["CPU"] + if torch.cuda.is_available(): + keys.append("CUDA") + return keys # All such custom kernels are registered in the 'turbine' library/namespace. diff --git a/tests/runtime/device_test.py b/tests/runtime/device_test.py index c37750cca..c0478e2cb 100644 --- a/tests/runtime/device_test.py +++ b/tests/runtime/device_test.py @@ -7,6 +7,7 @@ import logging import unittest import threading +import warnings import torch @@ -122,6 +123,40 @@ def testCompilerFlags(self): self.assertIn("--iree-hal-target-backends=llvm-cpu", d.compile_target_flags) +# Make CUDA testing conditional. +test_cuda_device = None +if torch.cuda.is_available(): + try: + test_cuda_device = torch.device("cuda:0") + except: + ... +if test_cuda_device is None: + warnings.warn("Not testing CUDA interop (device not available)") + + +@unittest.skipUnless(test_cuda_device, "CUDA not available") +class TorchCUDAInterop(unittest.TestCase): + def setUp(self): + self._cuda_props = torch.cuda.get_device_properties(test_cuda_device) + print(self._cuda_props) + print(dir(self._cuda_props)) + self._is_hip = False + if hasattr(self._cuda_props, "gcnArchName"): + print("Detected HIP device as CUDA") + self._is_hip = True + + def testFromTorchDevice(self): + torch_device = torch.device("cuda:0") + device = get_device_from_torch(torch_device) + print(device.dump_device_info()) + + def testJit(self): + t = torch.tensor([1, 2, 3, 4, 5]).to("cuda:0") + from shark_turbine.ops import iree as iree_ops + + iree_ops.trace_tensor("FOO", t) + + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main() From 0feccb8b45d79d9f7d871fcb150ece7888ff6a25 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 23 Apr 2024 16:20:10 -0700 Subject: [PATCH 2/7] Use real op --- shark_turbine/ops/iree.py | 21 +++++++++++++++++++++ shark_turbine/runtime/op_reg/compiler.py | 4 +++- tests/runtime/device_test.py | 4 ++-- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/shark_turbine/ops/iree.py b/shark_turbine/ops/iree.py index e28826db8..66167c886 100644 --- a/shark_turbine/ops/iree.py +++ b/shark_turbine/ops/iree.py @@ -8,6 +8,7 @@ from typing import cast from ..support.ir_imports import ( + Operation, RankedTensorType, StringAttr, Value, @@ -60,3 +61,23 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): key = cast(AttrArg, ksel.arg_descs[0]) _emit_tensor_trace(kb, cast(str, key.v), [kb.arg_bindings[1]]) kb.yield_results(kb.arg_bindings[1]) + + +@CustomOp.register(library=IREE_LIBRARY) +class _test_add(CustomOp): + signature = "_test_add(Tensor t1, Tensor t2) -> (Tensor)" + + def select(self, ksel: KernelSelection): + t1_desc = ksel.arg_tensor(0) + t1_desc.specialize_all_dims() + t2_desc = ksel.arg_tensor(1) + t2_desc.specialize_all_dims() + result_desc = ksel.return_new_tensor(t1_desc.t.shape, t1_desc.t.dtype) + result_desc.specialize_all_dims() + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + t1, t2 = kb.arg_bindings + result = Operation.create( + "tosa.add", results=[t1.type], operands=[t1, t2] + ).result + kb.yield_results(result) diff --git a/shark_turbine/runtime/op_reg/compiler.py b/shark_turbine/runtime/op_reg/compiler.py index 154070e4b..bb34c0172 100644 --- a/shark_turbine/runtime/op_reg/compiler.py +++ b/shark_turbine/runtime/op_reg/compiler.py @@ -100,8 +100,10 @@ def compile_standalone_kernel( with kb.ip, Location.unknown(): ksel.op.generate(ksel, kb) kb.module_op.verify() + # DO NOT SUBMIT: https://github.com/iree-org/iree/issues/17132 + 
enable_debug_info = False module_asm = kb.module_op.get_asm( - binary=True, enable_debug_info=True, print_generic_op_form=False + binary=True, enable_debug_info=enable_debug_info, print_generic_op_form=False ) generation_time = default_timer() - start diff --git a/tests/runtime/device_test.py b/tests/runtime/device_test.py index c0478e2cb..386ab2849 100644 --- a/tests/runtime/device_test.py +++ b/tests/runtime/device_test.py @@ -151,10 +151,10 @@ def testFromTorchDevice(self): print(device.dump_device_info()) def testJit(self): - t = torch.tensor([1, 2, 3, 4, 5]).to("cuda:0") from shark_turbine.ops import iree as iree_ops - iree_ops.trace_tensor("FOO", t) + t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]).to("cuda:0") + print(iree_ops._test_add(t, t)) if __name__ == "__main__": From 75a9955506f010a52b1d84207ffdb644bc53b5d9 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 23 Apr 2024 19:25:36 -0700 Subject: [PATCH 3/7] Implement import/export. --- shark_turbine/runtime/device.py | 102 ++++++++++++++++++++++---------- 1 file changed, 72 insertions(+), 30 deletions(-) diff --git a/shark_turbine/runtime/device.py b/shark_turbine/runtime/device.py index 6208f4b8a..e93959343 100644 --- a/shark_turbine/runtime/device.py +++ b/shark_turbine/runtime/device.py @@ -75,6 +75,8 @@ class DeviceState: "driver", "instance", "enumerated_info", + "torch_device", + "dlpack_device_type_code", ] def __init__( @@ -84,11 +86,15 @@ def __init__( device: Optional[HalDevice] = None, vm_instance: Optional[VmInstance] = None, enumerated_info: Optional[dict] = None, + torch_device: Optional[torch.device] = None, + dlpack_device_type_code: int = 0, ): self.instance = vm_instance or get_vm_instance() self.driver = driver if isinstance(driver, HalDriver) else get_driver(driver) self.device = device if device else self.driver.create_default_device() self.enumerated_info = enumerated_info or {} + self.torch_device = torch_device + self.dlpack_device_type_code = dlpack_device_type_code @property def enumerated_device_id(self) -> int: @@ -164,7 +170,10 @@ class Device: compile_target_flags: tuple[str, ...] def __new__( - cls, uri: Optional[str] = None, *, device_state: Optional[DeviceState] = None + cls, + uri: Optional[str] = None, + *, + device_state: Optional[DeviceState] = None, ): if uri is not None: # Construction by URI is cached on the thread. @@ -284,6 +293,11 @@ def __exit__(self, type, value, traceback): _CURRENT_THREAD.stack.pop() +################################################################################ +# CPU import/export +################################################################################ + + def _device_import_torch_tensor_cpu(device: Device, t: torch.Tensor) -> HalBufferView: hal_device = device.hal_device element_type = dtype_to_element_type(t.dtype) @@ -299,27 +313,6 @@ def _device_import_torch_tensor_cpu(device: Device, t: torch.Tensor) -> HalBuffe return bv -def _device_import_torch_tensor_hip(device: Device, t: torch.Tensor) -> HalBufferView: - assert hasattr(t, "__dlpack__") - # TODO: The '0' below is the stream (on real-CUDA, it is 1) signifying - # the default stream. It should be the stream upon which we schedule. - bv = device.hal_device.from_dlpack(t) - print("FROM DLPACK:", bv) - return bv - # hal_device = device.hal_device - # element_type = dtype_to_element_type(t.dtype) - # # TODO: In this case, we should be importing the raw buffer, but this is not - # # generically exposed to Python in the IREE runtime. 
- # bv = device.hal_device.allocator.allocate_buffer_copy( - # memory_type=MemoryType.DEVICE_LOCAL, - # allowed_usage=BufferUsage.DEFAULT, - # device=hal_device, - # buffer=t.detach().numpy(), - # element_type=element_type, - # ) - # return bv - - def _device_export_torch_tensor_cpu( device: Device, bv: HalBufferView, like: torch.Tensor ) -> torch.Tensor: @@ -333,10 +326,44 @@ def _device_export_torch_tensor_cpu( return torch.from_numpy(mapped_array) +################################################################################ +# CUDA and HIP import/export +################################################################################ + + +def _device_import_torch_tensor_cuda_hip( + device: Device, t: torch.Tensor +) -> HalBufferView: + # TODO: The 'None' here tells the producer to synchronize on the default + # stream. For async, we should advance our timeline and signal when an + # event is raised on Torch's stream at the current position. + capsule = t.__dlpack__(None) + bv = device.hal_device.from_dlpack_capsule(capsule) + return bv + + +def _device_export_torch_tensor_cuda_hip( + device: Device, bv: HalBufferView, like: torch.Tensor +) -> torch.Tensor: + state = device._s + device_type_code = state.dlpack_device_type_code + assert device_type_code > 0 + device_index = state.torch_device.index + t = torch.from_dlpack( + device.hal_device.create_dlpack_capsule(bv, device_type_code, device_index) + ) + if t.dtype != like.dtype: + t = t.view(like.dtype) + # TODO: For async, we should enqueue an event on Torch's stream which will + # signal when this tensor is produced (i.e. at the current point in our + # timeline). + return t + + # Mapping of torch tensor importers keyed by driver name. TORCH_TENSOR_IMPORTERS: dict[str, Callable[[Device, torch.Tensor], HalBufferView]] = { - "cuda": None, - "hip": _device_import_torch_tensor_hip, + "cuda": _device_import_torch_tensor_cuda_hip, + "hip": _device_import_torch_tensor_cuda_hip, "local-sync": _device_import_torch_tensor_cpu, "local-task": _device_import_torch_tensor_cpu, } @@ -344,8 +371,8 @@ def _device_export_torch_tensor_cpu( TORCH_TENSOR_EXPORTERS: dict[ str, Callable[[Device, HalBufferView, torch.Tensor], torch.Tensor] ] = { - "cuda": None, - "hip": None, + "cuda": _device_export_torch_tensor_cuda_hip, + "hip": _device_export_torch_tensor_cuda_hip, "local-sync": _device_export_torch_tensor_cpu, "local-task": _device_export_torch_tensor_cpu, } @@ -413,7 +440,18 @@ def get_device_from_torch(torch_device: torch.device) -> Device: def _create_device_from_torch(torch_device: torch.device) -> Optional[Device]: torch_type = torch_device.type if torch_type == "cpu": - return Device("local-task") + cpu_driver = get_driver("local-task") + cpu_enumerated = cpu_driver.query_available_devices() + assert len(cpu_enumerated) >= 1 + cpu_default = cpu_enumerated[0] + cpu_device_state = DeviceState( + driver=cpu_driver, + device=cpu_driver.create_default_device(), + enumerated_info=cpu_default, + torch_device=torch_device, + dlpack_device_type_code=1, + ) + return Device(device_state=cpu_device_state) elif torch_type == "cuda": # Fork based on HIP or real CUDA. props = torch.cuda.get_device_properties(torch_device) @@ -428,7 +466,8 @@ def _create_device_from_torch(torch_device: torch.device) -> Optional[Device]: def _create_cuda_device(torch_device: torch.device, props) -> Optional[Device]: - device = _create_cuda_like_device(torch_device, props, "hip") + # Note that the dlpack device type code for real CUDA ROCM is 2. 
+ device = _create_cuda_like_device(torch_device, props, "hip", 2) if device: device.compile_target_flags = device.compile_target_flags + ( f"--iree-hal-cuda-llvm-target-arch=sm_{props.major}{props.minor}", @@ -437,7 +476,8 @@ def _create_cuda_device(torch_device: torch.device, props) -> Optional[Device]: def _create_hip_device(torch_device: torch.device, props) -> Optional[Device]: - device = _create_cuda_like_device(torch_device, props, "hip") + # Note that the dlpack device type code for ROCM is 10. + device = _create_cuda_like_device(torch_device, props, "hip", 10) if device: gcn_arch_name = props.gcnArchName device.compile_target_flags = device.compile_target_flags + ( @@ -447,7 +487,7 @@ def _create_hip_device(torch_device: torch.device, props) -> Optional[Device]: def _create_cuda_like_device( - torch_device: torch.device, props, driver_name: str + torch_device: torch.device, props, driver_name: str, dlpack_device_type_code: int ) -> Optional[Device]: if torch.cuda.device_count() > 1: warnings.warn( @@ -467,6 +507,8 @@ def _create_cuda_like_device( device=hal_device, vm_instance=get_vm_instance(), enumerated_info=device_info, + torch_device=torch_device, + dlpack_device_type_code=dlpack_device_type_code, ) device = Device(device_state=device_state) return device From c59b870f238219e0daa30ea7e42d92f47aa77f50 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 24 Apr 2024 19:11:44 -0700 Subject: [PATCH 4/7] Reach working state. --- shark_turbine/runtime/device.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/shark_turbine/runtime/device.py b/shark_turbine/runtime/device.py index e93959343..f4d2ffd4c 100644 --- a/shark_turbine/runtime/device.py +++ b/shark_turbine/runtime/device.py @@ -334,6 +334,9 @@ def _device_export_torch_tensor_cpu( def _device_import_torch_tensor_cuda_hip( device: Device, t: torch.Tensor ) -> HalBufferView: + # We currently only support contiguous, so ensure that. + if not t.is_contiguous(): + t = t.contiguous() # TODO: The 'None' here tells the producer to synchronize on the default # stream. For async, we should advance our timeline and signal when an # event is raised on Torch's stream at the current position. From ae715544506a5a2938c75931da498d3e4931be41 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 24 Apr 2024 19:17:11 -0700 Subject: [PATCH 5/7] Finishes tests. 
--- tests/runtime/device_test.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/runtime/device_test.py b/tests/runtime/device_test.py index 386ab2849..47567ae66 100644 --- a/tests/runtime/device_test.py +++ b/tests/runtime/device_test.py @@ -153,8 +153,20 @@ def testFromTorchDevice(self): def testJit(self): from shark_turbine.ops import iree as iree_ops - t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]).to("cuda:0") - print(iree_ops._test_add(t, t)) + t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cuda:0") + result = iree_ops._test_add(t, t) + expected = torch.tensor([ 2., 4., 6., 8., 10.], device="cpu") + torch.testing.assert_close(result.cpu(), expected) + + +class TorchCPUInterop(unittest.TestCase): + def testJit(self): + from shark_turbine.ops import iree as iree_ops + + t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu") + result = iree_ops._test_add(t, t) + expected = torch.tensor([ 2., 4., 6., 8., 10.], device="cpu") + torch.testing.assert_close(result, expected) if __name__ == "__main__": From 2b8e3017cbace640fdd0461bee1098c606d3ac3c Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 24 Apr 2024 19:18:15 -0700 Subject: [PATCH 6/7] Black --- tests/runtime/device_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/runtime/device_test.py b/tests/runtime/device_test.py index 47567ae66..f6c172415 100644 --- a/tests/runtime/device_test.py +++ b/tests/runtime/device_test.py @@ -155,7 +155,7 @@ def testJit(self): t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cuda:0") result = iree_ops._test_add(t, t) - expected = torch.tensor([ 2., 4., 6., 8., 10.], device="cpu") + expected = torch.tensor([2.0, 4.0, 6.0, 8.0, 10.0], device="cpu") torch.testing.assert_close(result.cpu(), expected) @@ -165,7 +165,7 @@ def testJit(self): t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu") result = iree_ops._test_add(t, t) - expected = torch.tensor([ 2., 4., 6., 8., 10.], device="cpu") + expected = torch.tensor([2.0, 4.0, 6.0, 8.0, 10.0], device="cpu") torch.testing.assert_close(result, expected) From 1d4afcc2d8ad7b14c55fef190e5771ae5dc33169 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 24 Apr 2024 19:30:26 -0700 Subject: [PATCH 7/7] Mypy --- shark_turbine/ops/iree.py | 5 +++-- shark_turbine/runtime/device.py | 8 +++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/shark_turbine/ops/iree.py b/shark_turbine/ops/iree.py index 66167c886..ce6577285 100644 --- a/shark_turbine/ops/iree.py +++ b/shark_turbine/ops/iree.py @@ -72,12 +72,13 @@ def select(self, ksel: KernelSelection): t1_desc.specialize_all_dims() t2_desc = ksel.arg_tensor(1) t2_desc.specialize_all_dims() - result_desc = ksel.return_new_tensor(t1_desc.t.shape, t1_desc.t.dtype) + result_desc = ksel.return_new_tensor(list(t1_desc.t.shape), t1_desc.t.dtype) result_desc.specialize_all_dims() def generate(self, ksel: KernelSelection, kb: KernelBuilder): t1, t2 = kb.arg_bindings + result_type = t1.type # type: ignore result = Operation.create( - "tosa.add", results=[t1.type], operands=[t1, t2] + "tosa.add", results=[result_type], operands=[t1, t2] ).result kb.yield_results(result) diff --git a/shark_turbine/runtime/device.py b/shark_turbine/runtime/device.py index f4d2ffd4c..402366017 100644 --- a/shark_turbine/runtime/device.py +++ b/shark_turbine/runtime/device.py @@ -106,14 +106,14 @@ def enumerated_device_id(self) -> int: @property def enumerated_path(self) -> str: try: - return self.enumerated_device_id["path"] + return 
self.enumerated_info["path"] except KeyError as e: raise RuntimeError("No enumerated path for device") from e @property def enumerated_name(self) -> str: try: - return self.enumerated_device_id["name"] + return self.enumerated_info["name"] except KeyError as e: raise RuntimeError("No enumerated name for device") from e @@ -351,7 +351,9 @@ def _device_export_torch_tensor_cuda_hip( state = device._s device_type_code = state.dlpack_device_type_code assert device_type_code > 0 - device_index = state.torch_device.index + torch_device = state.torch_device + assert torch_device is not None + device_index = torch_device.index t = torch.from_dlpack( device.hal_device.create_dlpack_capsule(bv, device_type_code, device_index) )
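
For reference, below is a minimal sketch (not part of the patches above) of how the pieces this series adds fit together: `get_device_from_torch` resolves a Turbine `Device` for a `torch.device`, and the `TORCH_TENSOR_IMPORTERS` / `TORCH_TENSOR_EXPORTERS` tables in `shark_turbine/runtime/device.py` move tensors across the boundary (DLPack capsules on CUDA/HIP, a host buffer copy on CPU). The helper `roundtrip_tensor` and the simplified `_DRIVER_FOR_TORCH_TYPE` mapping are hypothetical names used only for illustration; the real code picks the driver per enumerated device, including the HIP-as-CUDA fork, which this sketch glosses over.

    import torch

    from shark_turbine.runtime.device import (
        TORCH_TENSOR_EXPORTERS,
        TORCH_TENSOR_IMPORTERS,
        get_device_from_torch,
    )

    # Assumption: collapse torch device types onto driver keys. "cuda" covers both
    # real CUDA and HIP here because the series registers the same importer and
    # exporter functions under both keys.
    _DRIVER_FOR_TORCH_TYPE = {"cpu": "local-task", "cuda": "cuda"}


    def roundtrip_tensor(t: torch.Tensor) -> torch.Tensor:
        # Resolve the Turbine Device backing this torch.device ("cpu" -> local-task,
        # "cuda" -> the CUDA or HIP HAL driver depending on the torch build).
        device = get_device_from_torch(t.device)
        driver = _DRIVER_FOR_TORCH_TYPE[t.device.type]
        # torch.Tensor -> HalBufferView: DLPack capsule on CUDA/HIP, host copy on CPU.
        bv = TORCH_TENSOR_IMPORTERS[driver](device, t)
        # HalBufferView -> torch.Tensor, using `t` as the dtype/shape reference.
        return TORCH_TENSOR_EXPORTERS[driver](device, bv, t)


    if __name__ == "__main__":
        cpu_t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
        print(roundtrip_tensor(cpu_t))
        if torch.cuda.is_available():
            print(roundtrip_tensor(cpu_t.to("cuda:0")).cpu())

On CPU the round trip goes through an allocator copy, while on CUDA/HIP it stays on device via DLPack, mirroring what the `TorchCPUInterop` and `TorchCUDAInterop` tests in this series exercise through `iree_ops._test_add`.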