diff --git a/shark_turbine/ops/iree.py b/shark_turbine/ops/iree.py index e28826db8..ce6577285 100644 --- a/shark_turbine/ops/iree.py +++ b/shark_turbine/ops/iree.py @@ -8,6 +8,7 @@ from typing import cast from ..support.ir_imports import ( + Operation, RankedTensorType, StringAttr, Value, @@ -60,3 +61,24 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): key = cast(AttrArg, ksel.arg_descs[0]) _emit_tensor_trace(kb, cast(str, key.v), [kb.arg_bindings[1]]) kb.yield_results(kb.arg_bindings[1]) + + +@CustomOp.register(library=IREE_LIBRARY) +class _test_add(CustomOp): # Test op: generate() emits one "tosa.add" over its two tensor arguments. + signature = "_test_add(Tensor t1, Tensor t2) -> (Tensor)" + + def select(self, ksel: KernelSelection): + t1_desc = ksel.arg_tensor(0) + t1_desc.specialize_all_dims() + t2_desc = ksel.arg_tensor(1) + t2_desc.specialize_all_dims() + result_desc = ksel.return_new_tensor(list(t1_desc.t.shape), t1_desc.t.dtype) # result takes t1's shape and dtype; t2 is not compared against t1 here + result_desc.specialize_all_dims() + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + t1, t2 = kb.arg_bindings + result_type = t1.type # type: ignore + result = Operation.create( + "tosa.add", results=[result_type], operands=[t1, t2] + ).result + kb.yield_results(result) diff --git a/shark_turbine/runtime/device.py b/shark_turbine/runtime/device.py index a294525df..402366017 100644 --- a/shark_turbine/runtime/device.py +++ b/shark_turbine/runtime/device.py @@ -7,6 +7,7 @@ from functools import lru_cache from typing import Callable, Optional, Union from threading import local, Lock +import warnings import torch @@ -73,6 +74,9 @@ class DeviceState: "device", "driver", "instance", + "enumerated_info", + "torch_device", + "dlpack_device_type_code", ] def __init__( @@ -81,10 +85,37 @@ def __init__( driver: Union[str, HalDriver], device: Optional[HalDevice] = None, vm_instance: Optional[VmInstance] = None, + enumerated_info: Optional[dict] = None, + torch_device: Optional[torch.device] = None, + dlpack_device_type_code: int = 0, # default 0; the CUDA/HIP export path asserts this is > 0 ): self.instance = vm_instance or 
get_vm_instance() self.driver = driver if isinstance(driver, HalDriver) else get_driver(driver) self.device = device if device else self.driver.create_default_device() + self.enumerated_info = enumerated_info or {} # None becomes {}, so the enumerated_* properties below raise RuntimeError + self.torch_device = torch_device + self.dlpack_device_type_code = dlpack_device_type_code + + @property + def enumerated_device_id(self) -> int: # raises RuntimeError when enumerated_info has no "device_id" key + try: + return self.enumerated_info["device_id"] + except KeyError as e: + raise RuntimeError("No enumerated device_id for device") from e + + @property + def enumerated_path(self) -> str: # raises RuntimeError when enumerated_info has no "path" key + try: + return self.enumerated_info["path"] + except KeyError as e: + raise RuntimeError("No enumerated path for device") from e + + @property + def enumerated_name(self) -> str: # raises RuntimeError when enumerated_info has no "name" key + try: + return self.enumerated_info["name"] + except KeyError as e: + raise RuntimeError("No enumerated name for device") from e @staticmethod @lru_cache(maxsize=None) @@ -139,7 +170,10 @@ class Device: compile_target_flags: tuple[str, ...] def __new__( - cls, uri: Optional[str] = None, *, device_state: Optional[DeviceState] = None + cls, + uri: Optional[str] = None, + *, + device_state: Optional[DeviceState] = None, ): if uri is not None: # Construction by URI is cached on the thread. @@ -243,6 +277,9 @@ def clear(self): ... 
raise MismatchedDeviceSetClearError() + def dump_device_info(self) -> str: # forwards to the driver's dump_device_info using this device's enumerated device id + return self._s.driver.dump_device_info(self._s.enumerated_device_id) + def __repr__(self): return f"" @@ -256,6 +293,11 @@ def __exit__(self, type, value, traceback): _CURRENT_THREAD.stack.pop() +################################################################################ +# CPU import/export +################################################################################ + + def _device_import_torch_tensor_cpu(device: Device, t: torch.Tensor) -> HalBufferView: hal_device = device.hal_device element_type = dtype_to_element_type(t.dtype) @@ -284,8 +326,49 @@ def _device_export_torch_tensor_cpu( return torch.from_numpy(mapped_array) +################################################################################ +# CUDA and HIP import/export +################################################################################ + + +def _device_import_torch_tensor_cuda_hip( + device: Device, t: torch.Tensor +) -> HalBufferView: + # Only contiguous tensors are supported, so make the input contiguous when it is not. + if not t.is_contiguous(): + t = t.contiguous() + # TODO: The 'None' here tells the producer to synchronize on the default + # stream. For async, we should advance our timeline and signal when an + # event is raised on Torch's stream at the current position. 
+ capsule = t.__dlpack__(None) + bv = device.hal_device.from_dlpack_capsule(capsule) + return bv + + +def _device_export_torch_tensor_cuda_hip( + device: Device, bv: HalBufferView, like: torch.Tensor +) -> torch.Tensor: + state = device._s + device_type_code = state.dlpack_device_type_code + assert device_type_code > 0 # NOTE(review): guards the DeviceState default of 0; asserts are stripped under python -O + torch_device = state.torch_device + assert torch_device is not None # NOTE(review): guards the DeviceState default of None; also stripped under -O + device_index = torch_device.index + t = torch.from_dlpack( + device.hal_device.create_dlpack_capsule(bv, device_type_code, device_index) + ) + if t.dtype != like.dtype: # reinterpret the imported tensor as like's dtype via Tensor.view(dtype) + t = t.view(like.dtype) + # TODO: For async, we should enqueue an event on Torch's stream which will + # signal when this tensor is produced (i.e. at the current point in our + # timeline). + return t + + # Mapping of torch tensor importers keyed by driver name. TORCH_TENSOR_IMPORTERS: dict[str, Callable[[Device, torch.Tensor], HalBufferView]] = { + "cuda": _device_import_torch_tensor_cuda_hip, + "hip": _device_import_torch_tensor_cuda_hip, "local-sync": _device_import_torch_tensor_cpu, "local-task": _device_import_torch_tensor_cpu, } @@ -293,11 +376,15 @@ def _device_export_torch_tensor_cpu( TORCH_TENSOR_EXPORTERS: dict[ str, Callable[[Device, HalBufferView, torch.Tensor], torch.Tensor] ] = { + "cuda": _device_export_torch_tensor_cuda_hip, + "hip": _device_export_torch_tensor_cuda_hip, "local-sync": _device_export_torch_tensor_cpu, "local-task": _device_export_torch_tensor_cpu, } DEVICE_TARGET_COMPILE_FLAGS: dict[str, tuple[str, ...]] = { + "cuda": ("--iree-hal-target-backends=cuda",), + "hip": ("--iree-hal-target-backends=rocm",), "local-task": ( "--iree-hal-target-backends=llvm-cpu", "--iree-llvmcpu-target-cpu-features=host", @@ -357,14 +444,79 @@ def get_device_from_torch(torch_device: torch.device) -> Device: def _create_device_from_torch(torch_device: torch.device) -> Optional[Device]: torch_type = torch_device.type - uri = None if torch_type == "cpu": - uri = "local-task" + cpu_driver = 
get_driver("local-task") + cpu_enumerated = cpu_driver.query_available_devices() + assert len(cpu_enumerated) >= 1 + cpu_default = cpu_enumerated[0] + cpu_device_state = DeviceState( + driver=cpu_driver, + device=cpu_driver.create_default_device(), + enumerated_info=cpu_default, + torch_device=torch_device, + dlpack_device_type_code=1, + ) + return Device(device_state=cpu_device_state) + elif torch_type == "cuda": + # Fork based on HIP or real CUDA. + props = torch.cuda.get_device_properties(torch_device) + if not hasattr(props, "gcnArchName"): + # Real CUDA. + return _create_cuda_device(torch_device, props) + else: + # HIP as CUDA. + return _create_hip_device(torch_device, props) + + return None + + +def _create_cuda_device(torch_device: torch.device, props) -> Optional[Device]: + # Note that the dlpack device type code for real CUDA is 2. + device = _create_cuda_like_device(torch_device, props, "cuda", 2) + if device: + device.compile_target_flags = device.compile_target_flags + ( + f"--iree-hal-cuda-llvm-target-arch=sm_{props.major}{props.minor}", + ) + return device + + +def _create_hip_device(torch_device: torch.device, props) -> Optional[Device]: + # Note that the dlpack device type code for ROCM is 10. 
+ device = _create_cuda_like_device(torch_device, props, "hip", 10) + if device: + gcn_arch_name = props.gcnArchName + device.compile_target_flags = device.compile_target_flags + ( + f"--iree-rocm-target-chip={gcn_arch_name}", + ) + return device - if uri is None: - return None - return Device(uri) +def _create_cuda_like_device( + torch_device: torch.device, props, driver_name: str, dlpack_device_type_code: int +) -> Optional[Device]: + if torch.cuda.device_count() > 1: + warnings.warn( + f"Multiple {driver_name} devices detected: Turbine does not yet " + f"guarantee stable device mapping" + ) + + requested_index = torch_device.index # NOTE(review): index is None for an index-less device such as torch.device("cuda"); the >= comparison below would then raise TypeError - confirm callers always pass an indexed device + driver = get_driver(driver_name) + available_infos = driver.query_available_devices() + if requested_index >= len(available_infos): # torch index is used directly as the position in the driver's enumeration + return None + device_info = available_infos[requested_index] + hal_device = driver.create_device(device_info) + device_state = DeviceState( + driver=driver, + device=hal_device, + vm_instance=get_vm_instance(), + enumerated_info=device_info, + torch_device=torch_device, + dlpack_device_type_code=dlpack_device_type_code, + ) + device = Device(device_state=device_state) + return device ############################################################################### diff --git a/shark_turbine/runtime/op_reg/base.py b/shark_turbine/runtime/op_reg/base.py index 2fcf7863f..d74c49a55 100644 --- a/shark_turbine/runtime/op_reg/base.py +++ b/shark_turbine/runtime/op_reg/base.py @@ -74,11 +74,10 @@ def def_library(ns) -> torch.library.Library: def default_dispatch_keys() -> list[str]: - # TODO: Dynamically determine what devices to register against. - # Note that we have to register against specific keys instead of the - # fallback, as fallback is too broad and breaks certain elements of - # fx tracing. - return ["CPU"] + keys = ["CPU"] # always register "CPU"; "CUDA" is appended only when torch reports CUDA as available + if torch.cuda.is_available(): + keys.append("CUDA") + return keys # All such custom kernels are registered in the 'turbine' library/namespace. 
diff --git a/shark_turbine/runtime/op_reg/compiler.py b/shark_turbine/runtime/op_reg/compiler.py index 154070e4b..bb34c0172 100644 --- a/shark_turbine/runtime/op_reg/compiler.py +++ b/shark_turbine/runtime/op_reg/compiler.py @@ -100,8 +100,10 @@ def compile_standalone_kernel( with kb.ip, Location.unknown(): ksel.op.generate(ksel, kb) kb.module_op.verify() + # TODO: Debug info is disabled here; see https://github.com/iree-org/iree/issues/17132 and re-enable once it is resolved. + enable_debug_info = False module_asm = kb.module_op.get_asm( - binary=True, enable_debug_info=True, print_generic_op_form=False + binary=True, enable_debug_info=enable_debug_info, print_generic_op_form=False ) generation_time = default_timer() - start diff --git a/tests/runtime/device_test.py b/tests/runtime/device_test.py index c37750cca..f6c172415 100644 --- a/tests/runtime/device_test.py +++ b/tests/runtime/device_test.py @@ -7,6 +7,7 @@ import logging import unittest import threading +import warnings import torch @@ -122,6 +123,52 @@ def testCompilerFlags(self): self.assertIn("--iree-hal-target-backends=llvm-cpu", d.compile_target_flags) +# Make CUDA testing conditional. +test_cuda_device = None +if torch.cuda.is_available(): + try: + test_cuda_device = torch.device("cuda:0") + except Exception: + ... 
+if test_cuda_device is None: + warnings.warn("Not testing CUDA interop (device not available)") + + +@unittest.skipUnless(test_cuda_device, "CUDA not available") +class TorchCUDAInterop(unittest.TestCase): + def setUp(self): + self._cuda_props = torch.cuda.get_device_properties(test_cuda_device) + print(self._cuda_props) + print(dir(self._cuda_props)) + self._is_hip = False # NOTE(review): _is_hip is set here but not read by any test visible in this class + if hasattr(self._cuda_props, "gcnArchName"): + print("Detected HIP device as CUDA") + self._is_hip = True + + def testFromTorchDevice(self): # smoke test: passes if device creation and dump_device_info() do not raise; nothing is asserted + torch_device = torch.device("cuda:0") + device = get_device_from_torch(torch_device) + print(device.dump_device_info()) + + def testJit(self): + from shark_turbine.ops import iree as iree_ops + + t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cuda:0") + result = iree_ops._test_add(t, t) + expected = torch.tensor([2.0, 4.0, 6.0, 8.0, 10.0], device="cpu") + torch.testing.assert_close(result.cpu(), expected) # result is moved to CPU before comparing with the CPU expected tensor + + +class TorchCPUInterop(unittest.TestCase): + def testJit(self): + from shark_turbine.ops import iree as iree_ops + + t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu") + result = iree_ops._test_add(t, t) + expected = torch.tensor([2.0, 4.0, 6.0, 8.0, 10.0], device="cpu") + torch.testing.assert_close(result, expected) + + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main()